Update: 增加信号更新函数 (#26);

新增interval初始化函数;
This commit is contained in:
binz 2024-06-24 21:59:34 +08:00
parent 30a998b58a
commit 721e2a42f5
2 changed files with 67 additions and 36 deletions

View File

@ -5,6 +5,7 @@ import pandas as pd
import numpy as np import numpy as np
import time import time
from trader import Trader from trader import Trader
from typing import Union, Dict
from rich import print as rprint from rich import print as rprint
from rich.table import Table from rich.table import Table
@ -60,6 +61,26 @@ class SpreadBacktest():
else: else:
self.trader.update_signal(date, update_type='position') self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
trade_time: str,
new_signal: pd.DataFrame):
"""
更新信号因子
Args:
trade_time (str): 信号时间
new_signal (pd.DataFrame): 新的更新信号
"""
self.trader.signal[trade_time] = new_signal
def update_interval(self,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
):
# 更新interval和weight
# 如果interval为固定比例则更新
self.trader.init_interval(interval)
@property @property
def account_history(self): def account_history(self):
return self.trader.account_history return self.trader.account_history

View File

@ -19,30 +19,31 @@ class Trader(Account):
交易类: 用于控制每日交易情况 交易类: 用于控制每日交易情况
Args: Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行 signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (int, tuple, pd.Series): 交易间隔 interval (dict[str, (int, tuple, pd.Series)]):
num (int): 持仓数量 交易间隔
ascending (bool): 因子方向 num (int): 持仓数量
with_st (bool): 是否包含st ascending (bool): 因子方向
tick (bool): 是否开始tick模拟模式(开发中) with_st (bool): 是否包含st
weight ([str, pd.DataFrame]): 权重分配 tick (bool): 是否开始tick模拟模式(开发中)
weight ([str, pd.DataFrame]): 权重分配
- avg (str): 平均分配每天早盘重新分配日中交易不重新分配 - avg (str): 平均分配每天早盘重新分配日中交易不重新分配
- (pd.DataFrame): 自定义股票权重包含每天个股指定的权重会自动归一化 - (pd.DataFrame): 自定义股票权重包含每天个股指定的权重会自动归一化
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限 amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列 data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间 ipo_days (int): 筛选上市时间
slippage (tuple): 买入和卖出滑点 slippage (tuple): 买入和卖出滑点
commission (float): 佣金 commission (float): 佣金
tax (dict): 印花税 tax (dict): 印花税
exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓 exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除 - abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- receesion: 财报同比或环比下降50%以上 - receesion: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见 - qualified_opinion: 会计保留意见
account (Account): 账户设置account.Account account (Account): 账户设置account.Account
""" """
def __init__(self, def __init__(self,
signal: Dict[str, pd.DataFrame]=None, signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]=1, interval: Dict[str, Union[int,tuple,pd.Series]]={},
num: int=100, num: int=100,
ascending: bool=False, ascending: bool=False,
with_st: bool=False, with_st: bool=False,
@ -78,24 +79,7 @@ class Trader(Account):
if len(kwargs) > 0: if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval # interval
self.interval = [] self.init_interval(interval)
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num # num
if isinstance(num, int): if isinstance(num, int):
self.num = int(num) self.num = int(num)
@ -174,6 +158,29 @@ class Trader(Account):
raise ValueError(f"data for signal {s} is not provided") raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root self.data_root = data_root
def init_interval(self, interval):
"""
初始化interval
"""
interval_list = []
for s in self.signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
interval_list.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(interval_list)
def load_data(self, def load_data(self,
date: str, date: str,
update_type: str='rtn'): update_type: str='rtn'):
@ -259,7 +266,10 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量 # 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整 # 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓 # 这样能实现在因子值不足时也正常换仓
max_sell_num = self.interval.loc[date]*len(last_position) try:
max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else: else:
last_position = pd.Series() last_position = pd.Series()
max_sell_num = self.num max_sell_num = self.num