parent
30a998b58a
commit
721e2a42f5
|
@ -5,6 +5,7 @@ import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
from trader import Trader
|
from trader import Trader
|
||||||
|
from typing import Union, Dict
|
||||||
from rich import print as rprint
|
from rich import print as rprint
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
|
||||||
|
@ -60,6 +61,26 @@ class SpreadBacktest():
|
||||||
else:
|
else:
|
||||||
self.trader.update_signal(date, update_type='position')
|
self.trader.update_signal(date, update_type='position')
|
||||||
|
|
||||||
|
# 更新数据
|
||||||
|
def update_signal(self,
|
||||||
|
trade_time: str,
|
||||||
|
new_signal: pd.DataFrame):
|
||||||
|
"""
|
||||||
|
更新信号因子
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trade_time (str): 信号时间
|
||||||
|
new_signal (pd.DataFrame): 新的更新信号
|
||||||
|
"""
|
||||||
|
self.trader.signal[trade_time] = new_signal
|
||||||
|
|
||||||
|
def update_interval(self,
|
||||||
|
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
||||||
|
):
|
||||||
|
# 更新interval和weight
|
||||||
|
# 如果interval为固定比例则更新
|
||||||
|
self.trader.init_interval(interval)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def account_history(self):
|
def account_history(self):
|
||||||
return self.trader.account_history
|
return self.trader.account_history
|
||||||
|
|
50
trader.py
50
trader.py
|
@ -20,7 +20,8 @@ class Trader(Account):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
||||||
interval (int, tuple, pd.Series): 交易间隔
|
interval (dict[str, (int, tuple, pd.Series)]):
|
||||||
|
交易间隔
|
||||||
num (int): 持仓数量
|
num (int): 持仓数量
|
||||||
ascending (bool): 因子方向
|
ascending (bool): 因子方向
|
||||||
with_st (bool): 是否包含st
|
with_st (bool): 是否包含st
|
||||||
|
@ -42,7 +43,7 @@ class Trader(Account):
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
signal: Dict[str, pd.DataFrame]=None,
|
signal: Dict[str, pd.DataFrame]=None,
|
||||||
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
|
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
||||||
num: int=100,
|
num: int=100,
|
||||||
ascending: bool=False,
|
ascending: bool=False,
|
||||||
with_st: bool=False,
|
with_st: bool=False,
|
||||||
|
@ -78,24 +79,7 @@ class Trader(Account):
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
|
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
|
||||||
# interval
|
# interval
|
||||||
self.interval = []
|
self.init_interval(interval)
|
||||||
for s in signal:
|
|
||||||
if s in interval:
|
|
||||||
s_interval = interval[s]
|
|
||||||
if isinstance(s_interval, int):
|
|
||||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
|
||||||
df_interval[::s_interval] = 1
|
|
||||||
elif isinstance(s_interval, tuple):
|
|
||||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
|
||||||
df_interval[::s_interval[0]] = s_interval[1]
|
|
||||||
elif isinstance(s_interval, pd.Series):
|
|
||||||
df_interval = s_interval
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid interval type')
|
|
||||||
self.interval.append(df_interval)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'not found interval for signal {s}')
|
|
||||||
self.interval = pd.concat(self.interval)
|
|
||||||
# num
|
# num
|
||||||
if isinstance(num, int):
|
if isinstance(num, int):
|
||||||
self.num = int(num)
|
self.num = int(num)
|
||||||
|
@ -174,6 +158,29 @@ class Trader(Account):
|
||||||
raise ValueError(f"data for signal {s} is not provided")
|
raise ValueError(f"data for signal {s} is not provided")
|
||||||
self.data_root = data_root
|
self.data_root = data_root
|
||||||
|
|
||||||
|
def init_interval(self, interval):
|
||||||
|
"""
|
||||||
|
初始化interval
|
||||||
|
"""
|
||||||
|
interval_list = []
|
||||||
|
for s in self.signal:
|
||||||
|
if s in interval:
|
||||||
|
s_interval = interval[s]
|
||||||
|
if isinstance(s_interval, int):
|
||||||
|
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
||||||
|
df_interval[::s_interval] = 1
|
||||||
|
elif isinstance(s_interval, tuple):
|
||||||
|
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
||||||
|
df_interval[::s_interval[0]] = s_interval[1]
|
||||||
|
elif isinstance(s_interval, pd.Series):
|
||||||
|
df_interval = s_interval
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid interval type')
|
||||||
|
interval_list.append(df_interval)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'not found interval for signal {s}')
|
||||||
|
self.interval = pd.concat(interval_list)
|
||||||
|
|
||||||
def load_data(self,
|
def load_data(self,
|
||||||
date: str,
|
date: str,
|
||||||
update_type: str='rtn'):
|
update_type: str='rtn'):
|
||||||
|
@ -259,7 +266,10 @@ class Trader(Account):
|
||||||
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
|
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
|
||||||
# 不足的数量通过买入列表自适应调整
|
# 不足的数量通过买入列表自适应调整
|
||||||
# 这样能实现在因子值不足时也正常换仓
|
# 这样能实现在因子值不足时也正常换仓
|
||||||
|
try:
|
||||||
max_sell_num = self.interval.loc[date]*len(last_position)
|
max_sell_num = self.interval.loc[date]*len(last_position)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f'not found interval in {date}')
|
||||||
else:
|
else:
|
||||||
last_position = pd.Series()
|
last_position = pd.Series()
|
||||||
max_sell_num = self.num
|
max_sell_num = self.num
|
||||||
|
|
Loading…
Reference in New Issue