Update: 增加信号更新函数 (#26);

新增interval初始化函数;
This commit is contained in:
binz 2024-06-24 21:59:34 +08:00
parent 30a998b58a
commit 721e2a42f5
2 changed files with 67 additions and 36 deletions

View File

@ -5,6 +5,7 @@ import pandas as pd
import numpy as np
import time
from trader import Trader
from typing import Union, Dict
from rich import print as rprint
from rich.table import Table
@ -60,6 +61,26 @@ class SpreadBacktest():
else:
self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
trade_time: str,
new_signal: pd.DataFrame):
"""
更新信号因子
Args:
trade_time (str): 信号时间
new_signal (pd.DataFrame): 新的更新信号
"""
self.trader.signal[trade_time] = new_signal
def update_interval(self,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
):
# 更新interval和weight
# 如果interval为固定比例则更新
self.trader.init_interval(interval)
@property
def account_history(self):
return self.trader.account_history

View File

@ -20,7 +20,8 @@ class Trader(Account):
Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (int, tuple, pd.Series): 交易间隔
interval (dict[str, (int, tuple, pd.Series)]):
交易间隔
num (int): 持仓数量
ascending (bool): 因子方向
with_st (bool): 是否包含st
@ -42,7 +43,7 @@ class Trader(Account):
"""
def __init__(self,
signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
num: int=100,
ascending: bool=False,
with_st: bool=False,
@ -78,24 +79,7 @@ class Trader(Account):
if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval
self.interval = []
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
self.init_interval(interval)
# num
if isinstance(num, int):
self.num = int(num)
@ -174,6 +158,29 @@ class Trader(Account):
raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root
def init_interval(self, interval):
"""
初始化interval
"""
interval_list = []
for s in self.signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
interval_list.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(interval_list)
def load_data(self,
date: str,
update_type: str='rtn'):
@ -259,7 +266,10 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓
try:
max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else:
last_position = pd.Series()
max_sell_num = self.num