1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
| import os from pathlib import Path
import backtrader.feeds as btfeeds import backtrader as bt import tushare as ts import pandas as pd from datetime import datetime
class PandasDataExtend(bt.feeds.PandasData):
    """Pandas data feed for frames whose volume column is named ``vol``.

    The standard ``volume`` line is fed from the ``vol`` column, and the same
    column is additionally exposed as an extra ``vol`` line.
    """

    # Only the genuinely new line is declared here. open/high/low/close are
    # already inherited from the base feed; per the backtrader docs, `lines`
    # declared in a subclass are added to the inherited ones, so re-declaring
    # them would create duplicates.
    # NOTE(review): confirm against the installed backtrader version.
    lines = ('vol',)

    # Per the backtrader docs, every line of a PandasData subclass needs a
    # param of the same name that says which DataFrame column feeds it.
    params = (
        ('datetime', None),  # None: timestamps come from the DataFrame index
        ('open', 'open'),
        ('high', 'high'),
        ('low', 'low'),
        ('close', 'close'),
        ('volume', 'vol'),   # standard volume line reads the 'vol' column
        ('vol', 'vol'),      # column mapping for the extra 'vol' line
    )
def dataframe_to_datafeeds(df: pd.DataFrame, start_date: str = "20100101", end_date: str = "20250101"):
    """Wrap a DataFrame in a backtrader ``PandasData`` feed limited to a date range.

    Args:
        df: Price data to expose to backtrader.
        start_date: First date of the feed; parsed with ``pd.to_datetime``.
        end_date: Last date of the feed; parsed with ``pd.to_datetime``.

    Returns:
        A ``bt.feeds.PandasData`` instance built from ``df`` with ``fromdate``
        and ``todate`` set from the two date arguments.
    """
    # NOTE(review): this builds the stock PandasData rather than
    # PandasDataExtend; confirm that is intended for frames whose volume
    # column is named 'vol'.
    return bt.feeds.PandasData(
        dataname=df,
        fromdate=pd.to_datetime(start_date),
        # Honour the caller's end_date; it used to be ignored in favour of a
        # hard-coded "20240101".
        todate=pd.to_datetime(end_date),
    )
def load_data_from_csv(
    ts_code: str,
    start_date: str = "20100101",
    end_date: str = "20250101",
    adj: str = "hfq",
    data_dir: "str | Path | None" = None,
) -> pd.DataFrame:
    """Load daily bar data for one stock from a local CSV file.

    The file is expected at
    ``<data_dir>/<ts_code>_<start_date>_<end_date>_<adj>.csv``.

    Args:
        ts_code: Stock code; first component of the file name.
        start_date: Start date, format YYYYMMDD; part of the file name.
        end_date: End date, format YYYYMMDD; part of the file name.
        adj: Adjustment tag used in the file name (default ``"hfq"``).
        data_dir: Directory holding the CSV files. Defaults to the ``data``
            directory one level above the directory containing this module.

    Returns:
        A DataFrame indexed by ``trade_date`` (as datetimes), sorted by that
        index in ascending order, with missing values replaced by ``0.0``.

    Raises:
        ValueError: If the CSV file contains no data rows.
        KeyError: If the CSV file has no ``trade_date`` column.
    """
    if data_dir is None:
        data_dir = Path(__file__).resolve().parents[1] / 'data'
    csv_path = Path(data_dir) / f"{ts_code}_{start_date}_{end_date}_{adj}.csv"

    df = pd.read_csv(csv_path)
    if df.empty:
        raise ValueError("csv数据找不到!!!")

    raw_dates = df['trade_date']
    if pd.api.types.is_integer_dtype(raw_dates):
        # YYYYMMDD text is read back from CSV as integers; pd.to_datetime would
        # treat those as nanoseconds since the epoch, so parse them as dates.
        df['trade_date'] = pd.to_datetime(raw_dates.astype(str), format="%Y%m%d")
    else:
        df['trade_date'] = pd.to_datetime(raw_dates)

    return df.set_index('trade_date').sort_index().fillna(0.0)
|