diff --git a/spread_backtest.py b/spread_backtest.py index 6c0479f..fb5df4b 100644 --- a/spread_backtest.py +++ b/spread_backtest.py @@ -5,6 +5,7 @@ import pandas as pd import numpy as np import time from trader import Trader +from typing import Union, Dict from rich import print as rprint from rich.table import Table @@ -60,6 +61,26 @@ class SpreadBacktest(): else: self.trader.update_signal(date, update_type='position') + # 更新数据 + def update_signal(self, + trade_time: str, + new_signal: pd.DataFrame): + """ + 更新信号因子 + + Args: + trade_time (str): 信号时间 + new_signal (pd.DataFrame): 新的更新信号 + """ + self.trader.signal[trade_time] = new_signal + + def update_interval(self, + interval: Dict[str, Union[int,tuple,pd.Series]]={}, + ): + # 更新interval和weight + # 如果interval为固定比例则更新 + self.trader.init_interval(interval) + @property def account_history(self): return self.trader.account_history diff --git a/trader.py b/trader.py index 4f71eb0..dfd9a01 100644 --- a/trader.py +++ b/trader.py @@ -19,30 +19,31 @@ class Trader(Account): 交易类: 用于控制每日交易情况 Args: - signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行 - interval (int, tuple, pd.Series): 交易间隔 - num (int): 持仓数量 - ascending (bool): 因子方向 - with_st (bool): 是否包含st - tick (bool): 是否开始tick模拟模式(开发中) - weight ([str, pd.DataFrame]): 权重分配 + signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行 + interval (dict[str, (int, tuple, pd.Series)]): + 交易间隔 + num (int): 持仓数量 + ascending (bool): 因子方向 + with_st (bool): 是否包含st + tick (bool): 是否开始tick模拟模式(开发中) + weight ([str, pd.DataFrame]): 权重分配 - avg (str): 平均分配,每天早盘重新分配,日中交易不重新分配 - (pd.DataFrame): 自定义股票权重,包含每天个股指定的权重,会自动归一化 - amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限 - data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列 - ipo_days (int): 筛选上市时间 - slippage (tuple): 买入和卖出滑点 - commission (float): 佣金 - tax (dict): 印花税 - exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 + amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限 + data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列 + ipo_days (int): 筛选上市时间 + slippage (tuple): 买入和卖出滑点 + commission (float): 佣金 + tax (dict): 印花税 + exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 - abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除 - receesion: 财报同比或环比下降50%以上 - qualified_opinion: 会计保留意见 - account (Account): 账户设置,account.Account + account (Account): 账户设置,account.Account """ def __init__(self, signal: Dict[str, pd.DataFrame]=None, - interval: Dict[str, Union[int,tuple,pd.Series]]=1, + interval: Dict[str, Union[int,tuple,pd.Series]]={}, num: int=100, ascending: bool=False, with_st: bool=False, @@ -78,24 +79,7 @@ class Trader(Account): if len(kwargs) > 0: raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") # interval - self.interval = [] - for s in signal: - if s in interval: - s_interval = interval[s] - if isinstance(s_interval, int): - df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index)) - df_interval[::s_interval] = 1 - elif isinstance(s_interval, tuple): - df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index)) - df_interval[::s_interval[0]] = s_interval[1] - elif isinstance(s_interval, pd.Series): - df_interval = s_interval - else: - raise ValueError('invalid interval type') - self.interval.append(df_interval) - else: - raise ValueError(f'not found interval for signal {s}') - self.interval = pd.concat(self.interval) + self.init_interval(interval) # num if isinstance(num, int): self.num = int(num) @@ -173,7 +157,30 @@ class Trader(Account): if s not in data_root: raise ValueError(f"data for signal {s} is not provided") self.data_root = data_root - + + def init_interval(self, interval): + """ + 初始化interval + """ + interval_list = [] + for s in self.signal: + if s in interval: + s_interval = interval[s] + if isinstance(s_interval, int): + df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index)) + df_interval[::s_interval] = 1 + elif isinstance(s_interval, tuple): + df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index)) + df_interval[::s_interval[0]] = s_interval[1] + elif isinstance(s_interval, pd.Series): + df_interval = s_interval + else: + raise ValueError('invalid interval type') + interval_list.append(df_interval) + else: + raise ValueError(f'not found interval for signal {s}') + self.interval = pd.concat(interval_list) + def load_data(self, date: str, update_type: str='rtn'): @@ -259,7 +266,10 @@ class Trader(Account): # 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量 # 不足的数量通过买入列表自适应调整 # 这样能实现在因子值不足时也正常换仓 - max_sell_num = self.interval.loc[date]*len(last_position) + try: + max_sell_num = self.interval.loc[date]*len(last_position) + except Exception: + raise ValueError(f'not found interval in {date}') else: last_position = pd.Series() max_sell_num = self.num