Compare commits
No commits in common. "main" and "0.0.1" have entirely different histories.
10
README.md
10
README.md
|
@ -2,13 +2,3 @@
|
||||||
1. 基本实现交易流程
|
1. 基本实现交易流程
|
||||||
2. 新增强制筛选参数
|
2. 新增强制筛选参数
|
||||||
3. 新增自定义个股权重和仓位权重
|
3. 新增自定义个股权重和仓位权重
|
||||||
|
|
||||||
## Version 0.02
|
|
||||||
1. 优化非满仓情况下现金比例对收益的计算问题
|
|
||||||
2. 优化自定义权重下的判断逻辑
|
|
||||||
3. 修复没有满仓有现金比例时的收益计算问题
|
|
||||||
|
|
||||||
## Version 0.03
|
|
||||||
1. 新增分钟价格计算函数
|
|
||||||
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
|
|
||||||
3. 增加信号更新函数
|
|
|
@ -1,33 +0,0 @@
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def amount_specified(
|
|
||||||
min_data: pd.DataFrame,
|
|
||||||
post_adj_factor: pd.DataFrame,
|
|
||||||
min_amount: float=0.
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
计算指定金额下的平均后复权价格
|
|
||||||
|
|
||||||
Args:
|
|
||||||
min_data (pd.DataFrame): 指定日期分钟数据,需至少包含代码(stock_code)、分钟(Time)、价格(price)、成交量(vol)、成交金额(amount)列
|
|
||||||
post_adj_factor (pd.DataFrame): 后复权因子数据,需至少包含股票、后复权因子
|
|
||||||
min_amount (float): 指定最小金额下的平均价
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 按照指定最小量获取平均价格
|
|
||||||
stock_amt = min_data.pivot_table(index='Time', columns='stock_code', values='amount').cumsum()
|
|
||||||
stock_vol = min_data.pivot_table(index='Time', columns='stock_code', values='vol').cumsum()
|
|
||||||
amount_price = stock_amt / stock_vol
|
|
||||||
amount_price.iloc[1:] = amount_price.iloc[1:].where(stock_amt <= min_amount, np.nan)
|
|
||||||
amount_price = amount_price.unstack().reset_index()
|
|
||||||
amount_price.columns = ['stock_code', 'time', 'price']
|
|
||||||
amount_price = amount_price.dropna(subset=['price']).drop_duplicates(subset='stock_code', keep='last')
|
|
||||||
# 计算后复权价格
|
|
||||||
amount_price = amount_price.merge(post_adj_factor, on=['stock_code'], how='left').dropna(subset=['factor'])
|
|
||||||
amount_price['open_post'] = amount_price['price'] * amount_price['factor']
|
|
||||||
|
|
||||||
return amount_price[['stock_code','open_post']]
|
|
||||||
|
|
||||||
|
|
|
@ -2,5 +2,4 @@ from account import Account
|
||||||
from trader import Trader
|
from trader import Trader
|
||||||
from spread_backtest import Spread_Backtest
|
from spread_backtest import Spread_Backtest
|
||||||
|
|
||||||
|
__all__ = ['Account', 'Trader', 'Spread_Backtest']
|
||||||
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']
|
|
|
@ -1,68 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def check_buy_exclude(buy_exclude):
|
|
||||||
buy_exclude_dict = dict()
|
|
||||||
# 检查剔除条件
|
|
||||||
for cond in ['amt_20', 'list_days', 'mkt', 'price']:
|
|
||||||
if cond == 'amt_20':
|
|
||||||
# 20日成交量均值
|
|
||||||
if cond in buy_exclude:
|
|
||||||
amt_exclude = buy_exclude[cond]
|
|
||||||
else:
|
|
||||||
buy_exclude[cond] = (0, np.inf)
|
|
||||||
continue
|
|
||||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
|
||||||
if len(amt_exclude) == 1:
|
|
||||||
buy_exclude_dict[cond] = (amt_exclude[0], np.inf)
|
|
||||||
else:
|
|
||||||
buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1])
|
|
||||||
else:
|
|
||||||
raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required')
|
|
||||||
|
|
||||||
if cond == 'list_days':
|
|
||||||
# 上市时间
|
|
||||||
if cond in buy_exclude:
|
|
||||||
list_exclude = buy_exclude[cond]
|
|
||||||
else:
|
|
||||||
buy_exclude[cond] = (0, np.inf)
|
|
||||||
continue
|
|
||||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
|
||||||
if len(list_exclude) == 1:
|
|
||||||
buy_exclude_dict[cond] = (list_exclude[0], np.inf)
|
|
||||||
else:
|
|
||||||
buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1])
|
|
||||||
else:
|
|
||||||
raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required')
|
|
||||||
|
|
||||||
if cond == 'mkt':
|
|
||||||
# 市值
|
|
||||||
if cond in buy_exclude:
|
|
||||||
mkt_exclude = buy_exclude[cond]
|
|
||||||
else:
|
|
||||||
buy_exclude[cond] = (0, np.inf)
|
|
||||||
continue
|
|
||||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
|
||||||
if len(mkt_exclude) == 1:
|
|
||||||
buy_exclude_dict[cond] = (mkt_exclude[0], np.inf)
|
|
||||||
else:
|
|
||||||
buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1])
|
|
||||||
else:
|
|
||||||
raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required')
|
|
||||||
|
|
||||||
if cond == 'price':
|
|
||||||
# 价格
|
|
||||||
if cond in buy_exclude:
|
|
||||||
price_exclude = buy_exclude[cond]
|
|
||||||
else:
|
|
||||||
buy_exclude[cond] = (0, np.inf)
|
|
||||||
continue
|
|
||||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
|
||||||
if len(price_exclude) == 1:
|
|
||||||
buy_exclude_dict[cond] = (price_exclude[0], np.inf)
|
|
||||||
else:
|
|
||||||
buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1])
|
|
||||||
else:
|
|
||||||
raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required')
|
|
||||||
|
|
||||||
return buy_exclude_dict
|
|
|
@ -10,8 +10,8 @@ if __name__ == '__main__':
|
||||||
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
|
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
|
||||||
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
|
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
|
||||||
|
|
||||||
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20',
|
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
|
||||||
'opening_info','ipo_days','margin_list','abnormal', 'recession']):
|
'abnormal', 'recession']):
|
||||||
if f in ['margin_list']:
|
if f in ['margin_list']:
|
||||||
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
|
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
|
||||||
else:
|
else:
|
||||||
|
@ -34,7 +34,7 @@ if __name__ == '__main__':
|
||||||
# 更新下一日的数据用于筛选
|
# 更新下一日的数据用于筛选
|
||||||
next_date = gft.days_after(df.index.max(), 1)
|
next_date = gft.days_after(df.index.max(), 1)
|
||||||
next_list = []
|
next_list = []
|
||||||
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
|
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
|
||||||
if f in ['margin_list']:
|
if f in ['margin_list']:
|
||||||
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
|
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from typing import Union, Iterable
|
from typing import Union, Iterable, Optional, Dict
|
||||||
|
|
||||||
class DataLoader():
|
class DataLoader():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -5,7 +5,6 @@ import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
from trader import Trader
|
from trader import Trader
|
||||||
from typing import Union, Dict
|
|
||||||
from rich import print as rprint
|
from rich import print as rprint
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
|
||||||
|
@ -61,26 +60,6 @@ class SpreadBacktest():
|
||||||
else:
|
else:
|
||||||
self.trader.update_signal(date, update_type='position')
|
self.trader.update_signal(date, update_type='position')
|
||||||
|
|
||||||
# 更新数据
|
|
||||||
def update_signal(self,
|
|
||||||
trade_time: str,
|
|
||||||
new_signal: pd.DataFrame):
|
|
||||||
"""
|
|
||||||
更新信号因子
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trade_time (str): 信号时间
|
|
||||||
new_signal (pd.DataFrame): 新的更新信号
|
|
||||||
"""
|
|
||||||
self.trader.signal[trade_time] = new_signal
|
|
||||||
|
|
||||||
def update_interval(self,
|
|
||||||
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
|
||||||
):
|
|
||||||
# 更新interval和weight
|
|
||||||
# 如果interval为固定比例则更新
|
|
||||||
self.trader.init_interval(interval)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def account_history(self):
|
def account_history(self):
|
||||||
return self.trader.account_history
|
return self.trader.account_history
|
||||||
|
|
312
trader.py
312
trader.py
|
@ -1,4 +1,5 @@
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
|
@ -7,7 +8,6 @@ from typing import Union, Iterable, Dict
|
||||||
|
|
||||||
from account import Account
|
from account import Account
|
||||||
from dataloader import DataLoader
|
from dataloader import DataLoader
|
||||||
from check_funcs import check_buy_exclude
|
|
||||||
|
|
||||||
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
|
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
|
||||||
from db_tushare import get_factor_tools
|
from db_tushare import get_factor_tools
|
||||||
|
@ -20,8 +20,7 @@ class Trader(Account):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
||||||
interval (dict[str, (int, tuple, pd.Series)]):
|
interval (int, tuple, pd.Series): 交易间隔
|
||||||
交易间隔
|
|
||||||
num (int): 持仓数量
|
num (int): 持仓数量
|
||||||
ascending (bool): 因子方向
|
ascending (bool): 因子方向
|
||||||
with_st (bool): 是否包含st
|
with_st (bool): 是否包含st
|
||||||
|
@ -35,7 +34,7 @@ class Trader(Account):
|
||||||
slippage (tuple): 买入和卖出滑点
|
slippage (tuple): 买入和卖出滑点
|
||||||
commission (float): 佣金
|
commission (float): 佣金
|
||||||
tax (dict): 印花税
|
tax (dict): 印花税
|
||||||
force_exclude (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓
|
exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓
|
||||||
- abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除
|
- abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除
|
||||||
- receesion: 财报同比或环比下降50%以上
|
- receesion: 财报同比或环比下降50%以上
|
||||||
- qualified_opinion: 会计保留意见
|
- qualified_opinion: 会计保留意见
|
||||||
|
@ -43,24 +42,25 @@ class Trader(Account):
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
signal: Dict[str, pd.DataFrame]=None,
|
signal: Dict[str, pd.DataFrame]=None,
|
||||||
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
|
||||||
num: int=100,
|
num: int=100,
|
||||||
ascending: bool=False,
|
ascending: bool=False,
|
||||||
with_st: bool=False,
|
with_st: bool=False,
|
||||||
data_root:dict={},
|
data_root:dict={},
|
||||||
tick: bool=False,
|
tick: bool=False,
|
||||||
weight: str='avg',
|
weight: str='avg',
|
||||||
|
amt_filter: set=(0,),
|
||||||
|
ipo_days: int=20,
|
||||||
slippage :tuple=(0.001,0.001),
|
slippage :tuple=(0.001,0.001),
|
||||||
commission: float=0.0001,
|
commission: float=0.0001,
|
||||||
buy_exclude: Dict[str, Union[set,int,float]]={},
|
|
||||||
force_exclude: list=[],
|
|
||||||
account: dict={},
|
|
||||||
tax: dict={
|
tax: dict={
|
||||||
'1990-01-01': (0.001,0.001),
|
'1990-01-01': (0.001,0.001),
|
||||||
'2008-04-24': (0.001,0.001),
|
'2008-04-24': (0.001,0.001),
|
||||||
'2008-09-19': (0, 0.001),
|
'2008-09-19': (0, 0.001),
|
||||||
'2023-08-28': (0, 0.0005)
|
'2023-08-28': (0, 0.0005)
|
||||||
},
|
},
|
||||||
|
exclude_list: list=[],
|
||||||
|
account: dict={},
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
# 初始化账户
|
# 初始化账户
|
||||||
super().__init__(**account)
|
super().__init__(**account)
|
||||||
|
@ -78,7 +78,24 @@ class Trader(Account):
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
|
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
|
||||||
# interval
|
# interval
|
||||||
self.init_interval(interval)
|
self.interval = []
|
||||||
|
for s in signal:
|
||||||
|
if s in interval:
|
||||||
|
s_interval = interval[s]
|
||||||
|
if isinstance(s_interval, int):
|
||||||
|
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||||
|
df_interval[::s_interval] = 1
|
||||||
|
elif isinstance(s_interval, tuple):
|
||||||
|
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||||
|
df_interval[::s_interval[0]] = s_interval[1]
|
||||||
|
elif isinstance(s_interval, pd.Series):
|
||||||
|
df_interval = s_interval
|
||||||
|
else:
|
||||||
|
raise ValueError('invalid interval type')
|
||||||
|
self.interval.append(df_interval)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'not found interval for signal {s}')
|
||||||
|
self.interval = pd.concat(self.interval)
|
||||||
# num
|
# num
|
||||||
if isinstance(num, int):
|
if isinstance(num, int):
|
||||||
self.num = int(num)
|
self.num = int(num)
|
||||||
|
@ -99,6 +116,21 @@ class Trader(Account):
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
else:
|
else:
|
||||||
raise ValueError('invalid type for `weight`')
|
raise ValueError('invalid type for `weight`')
|
||||||
|
# amt_filter
|
||||||
|
if isinstance(amt_filter, (set,list,tuple)):
|
||||||
|
if len(amt_filter) == 1:
|
||||||
|
self.amt_filter_min = amt_filter[0]
|
||||||
|
self.amt_filter_max = np.inf
|
||||||
|
else:
|
||||||
|
self.amt_filter_min = amt_filter[0]
|
||||||
|
self.amt_filter_max = amt_filter[1]
|
||||||
|
else:
|
||||||
|
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
|
||||||
|
# ipo_days
|
||||||
|
if isinstance(ipo_days, int):
|
||||||
|
self.ipo_days = ipo_days
|
||||||
|
else:
|
||||||
|
raise Exception('wrong type for ipo_days, `int` is required')
|
||||||
# slippage
|
# slippage
|
||||||
if isinstance(slippage, tuple) and len(slippage) == 2:
|
if isinstance(slippage, tuple) and len(slippage) == 2:
|
||||||
self.slippage = slippage
|
self.slippage = slippage
|
||||||
|
@ -114,19 +146,17 @@ class Trader(Account):
|
||||||
self.tax = tax
|
self.tax = tax
|
||||||
else:
|
else:
|
||||||
raise ValueError('tax should be dict.')
|
raise ValueError('tax should be dict.')
|
||||||
# buy exclude
|
# exclude
|
||||||
self.buy_exclude = check_buy_exclude(buy_exclude)
|
if isinstance(exclude_list, list):
|
||||||
# force exclude
|
self.exclude_list = exclude_list
|
||||||
if isinstance(force_exclude, list):
|
|
||||||
self.force_exclude = force_exclude
|
|
||||||
optional_list = ['abnormal', 'recession']
|
optional_list = ['abnormal', 'recession']
|
||||||
for item in force_exclude:
|
for item in exclude_list:
|
||||||
if item in optional_list:
|
if item in optional_list:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected keyword argument '{item}'")
|
raise ValueError(f"Unexpected keyword argument '{item}'")
|
||||||
else:
|
else:
|
||||||
raise ValueError('force_exclude should be list.')
|
raise ValueError('exclude_list should be list.')
|
||||||
# data_root
|
# data_root
|
||||||
# 至少包含basic data路径,open信号默认使用basic_data
|
# 至少包含basic data路径,open信号默认使用basic_data
|
||||||
if len(data_root) <= 0:
|
if len(data_root) <= 0:
|
||||||
|
@ -144,29 +174,6 @@ class Trader(Account):
|
||||||
raise ValueError(f"data for signal {s} is not provided")
|
raise ValueError(f"data for signal {s} is not provided")
|
||||||
self.data_root = data_root
|
self.data_root = data_root
|
||||||
|
|
||||||
def init_interval(self, interval):
|
|
||||||
"""
|
|
||||||
初始化interval
|
|
||||||
"""
|
|
||||||
interval_list = []
|
|
||||||
for s in self.signal:
|
|
||||||
if s in interval:
|
|
||||||
s_interval = interval[s]
|
|
||||||
if isinstance(s_interval, int):
|
|
||||||
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
|
||||||
df_interval[::s_interval] = 1
|
|
||||||
elif isinstance(s_interval, tuple):
|
|
||||||
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
|
||||||
df_interval[::s_interval[0]] = s_interval[1]
|
|
||||||
elif isinstance(s_interval, pd.Series):
|
|
||||||
df_interval = s_interval
|
|
||||||
else:
|
|
||||||
raise ValueError('invalid interval type')
|
|
||||||
interval_list.append(df_interval)
|
|
||||||
else:
|
|
||||||
raise ValueError(f'not found interval for signal {s}')
|
|
||||||
self.interval = pd.concat(interval_list)
|
|
||||||
|
|
||||||
def load_data(self,
|
def load_data(self,
|
||||||
date: str,
|
date: str,
|
||||||
update_type: str='rtn'):
|
update_type: str='rtn'):
|
||||||
|
@ -180,19 +187,14 @@ class Trader(Account):
|
||||||
"""
|
"""
|
||||||
self.today_data = dict()
|
self.today_data = dict()
|
||||||
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
|
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
|
||||||
# 遍历从data_root中读取当日数据
|
|
||||||
for path in self.data_root:
|
|
||||||
if path == 'basic':
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
|
|
||||||
if update_type == 'position':
|
if update_type == 'position':
|
||||||
return True
|
return True
|
||||||
for s in self.signal:
|
for s in self.signal:
|
||||||
if s == 'open':
|
if s == 'open':
|
||||||
if s in self.data_root:
|
if s in self.data_root:
|
||||||
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
|
continue
|
||||||
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
else:
|
||||||
|
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||||
else:
|
else:
|
||||||
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
|
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
|
||||||
if 'close' in self.signal:
|
if 'close' in self.signal:
|
||||||
|
@ -207,12 +209,11 @@ class Trader(Account):
|
||||||
# 可执行日期
|
# 可执行日期
|
||||||
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
|
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
|
||||||
|
|
||||||
def get_weight(self, date, account_weight, untradable_list, next_position):
|
def get_weight(self, date, account_weight, next_position):
|
||||||
"""
|
"""
|
||||||
计算个股仓位
|
计算个股仓位
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
untradable_list (list): 无法交易列表
|
|
||||||
account_weight (float): 总权重,即当前持仓比例
|
account_weight (float): 总权重,即当前持仓比例
|
||||||
"""
|
"""
|
||||||
if isinstance(self.weight, str):
|
if isinstance(self.weight, str):
|
||||||
|
@ -220,26 +221,14 @@ class Trader(Account):
|
||||||
return account_weight / len(next_position)
|
return account_weight / len(next_position)
|
||||||
if isinstance(self.weight, pd.DataFrame):
|
if isinstance(self.weight, pd.DataFrame):
|
||||||
date_weight = self.weight.loc[date].dropna().sort_index()
|
date_weight = self.weight.loc[date].dropna().sort_index()
|
||||||
# untradable_list不要求指定权重用昨日权重填充
|
|
||||||
weight_list = pd.Series(index=next_position['stock_code'])
|
|
||||||
try:
|
try:
|
||||||
# 填充untradable_list权重
|
weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
|
||||||
if len(untradable_list) > 0:
|
if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松
|
||||||
weight_list.loc[untradable_list] = self.position.set_index('stock_code').loc[untradable_list, 'weight']
|
|
||||||
except Exception:
|
|
||||||
raise ValueError('not found stock weight for untradable stocks in last position.')
|
|
||||||
try:
|
|
||||||
# 获取tradable_list权重,并对untradable_list占据的仓位进行调整
|
|
||||||
tradable_list = list(set(next_position['stock_code']) - set(untradable_list))
|
|
||||||
# 剔除untradable_list仓位后剩余持仓根据自定义权重分配
|
|
||||||
weight_list.loc[tradable_list] = date_weight.loc[tradable_list].values / date_weight.loc[tradable_list].sum() * (account_weight - weight_list.loc[untradable_list].sum())
|
|
||||||
weight_list = weight_list.values
|
|
||||||
if sum(weight_list) > 1 + 1e-5: # 防止数据精度的影响,给与一定的宽松
|
|
||||||
raise Exception(f"total weight of {date} is larger then 1.")
|
raise Exception(f"total weight of {date} is larger then 1.")
|
||||||
|
weight_list = account_weight * weight_list
|
||||||
return weight_list
|
return weight_list
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(e)
|
raise ValueError(f'not found stock weight in {date}')
|
||||||
raise ValueError(f'not found specified stock weight in {date}')
|
|
||||||
|
|
||||||
def get_next_position(self, date, factor):
|
def get_next_position(self, date, factor):
|
||||||
"""
|
"""
|
||||||
|
@ -252,94 +241,56 @@ class Trader(Account):
|
||||||
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
|
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
|
||||||
# 不足的数量通过买入列表自适应调整
|
# 不足的数量通过买入列表自适应调整
|
||||||
# 这样能实现在因子值不足时也正常换仓
|
# 这样能实现在因子值不足时也正常换仓
|
||||||
try:
|
|
||||||
max_sell_num = self.interval.loc[date]*len(last_position)
|
max_sell_num = self.interval.loc[date]*len(last_position)
|
||||||
except Exception:
|
|
||||||
raise ValueError(f'not found interval in {date}')
|
|
||||||
else:
|
else:
|
||||||
last_position = pd.Series()
|
last_position = pd.Series()
|
||||||
max_sell_num = self.num
|
max_sell_num = self.num
|
||||||
|
|
||||||
# 获取用于筛选的数据
|
# 获取用于筛选的数据
|
||||||
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
|
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
|
||||||
exclude_data = dict()
|
|
||||||
for cond in self.buy_exclude:
|
|
||||||
if cond == 'amt_20':
|
|
||||||
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
|
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
|
||||||
stock_amt_exclude = stock_amt20.copy()
|
stock_amt_filter = stock_amt20.copy()
|
||||||
stock_amt_exclude.loc[:] = 0
|
stock_amt_filter.loc[:] = 0
|
||||||
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
|
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
|
||||||
stock_amt_exclude = stock_amt_exclude.sort_index()
|
stock_amt_filter = stock_amt_filter.sort_index()
|
||||||
exclude_data[cond] = stock_amt_exclude
|
|
||||||
if cond == 'list_days':
|
|
||||||
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
|
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
|
||||||
stock_ipo_exclude = stock_ipo_days.copy()
|
stock_ipo_filter = stock_ipo_days.copy()
|
||||||
stock_ipo_exclude.loc[:] = 0
|
stock_ipo_filter.loc[:] = 0
|
||||||
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
|
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
|
||||||
stock_ipo_exclude = stock_ipo_exclude.sort_index()
|
|
||||||
exclude_data[cond] = stock_ipo_exclude
|
|
||||||
if cond == 'mkt':
|
|
||||||
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
|
|
||||||
stock_size_exclude = stock_size.copy()
|
|
||||||
stock_size_exclude.loc[:] = 0
|
|
||||||
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
|
|
||||||
stock_size_exclude = stock_size_exclude.sort_index()
|
|
||||||
exclude_data[cond] = stock_size_exclude
|
|
||||||
if cond == 'price':
|
|
||||||
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
|
|
||||||
stock_price_exclude = stock_price.copy()
|
|
||||||
stock_price_exclude.loc[:] = 0
|
|
||||||
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
|
|
||||||
stock_price_exclude = stock_price_exclude.sort_index()
|
|
||||||
exclude_data[cond] = stock_price_exclude
|
|
||||||
|
|
||||||
# 剔除列表
|
# 剔除列表
|
||||||
# 包含强制剔除列表和买入剔除列表:
|
# 包含强制列表和普通列表:
|
||||||
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
|
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
|
||||||
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤
|
# 普通列表如果已经持有不会过滤只对新买入的过滤
|
||||||
|
|
||||||
# 强制过滤列表
|
# 强制过滤列表
|
||||||
exclude_stock = []
|
exclude_stock = []
|
||||||
for cond in self.force_exclude:
|
for exclude in self.exclude_list:
|
||||||
if cond == 'abnormal':
|
if exclude == 'abnormal':
|
||||||
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
|
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
|
||||||
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
|
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
|
||||||
if cond == 'recession':
|
if exclude == 'recession':
|
||||||
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
|
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
|
||||||
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
|
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
|
||||||
exclude_stock = list(set(exclude_stock))
|
exclude_stock = list(set(exclude_stock))
|
||||||
force_exclude = copy.deepcopy(exclude_stock)
|
force_exclude = copy.deepcopy(exclude_stock)
|
||||||
# 买入过滤列表
|
# 普通过滤列表
|
||||||
buy_exclude = []
|
normal_exclude = []
|
||||||
for cond in self.buy_exclude:
|
normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
|
||||||
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
|
normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
|
||||||
buy_exclude = list(set(buy_exclude))
|
normal_exclude = list(set(normal_exclude))
|
||||||
|
|
||||||
# 交易列表
|
# 交易列表
|
||||||
# 仓位判断给与计算误差冗余
|
if self.today_position_ratio <= 1.0:
|
||||||
if self.today_position_ratio <= 1.0 + 1e-5:
|
|
||||||
# 如果没有杠杆:
|
# 如果没有杠杆:
|
||||||
# 交易逻辑:
|
buy_list = []
|
||||||
# 1 判断卖出,如果当天跌停则减少实际卖出数量
|
|
||||||
# 2 判断买入:根据实际卖出数量和距离目标持仓数量判断买入数量,如果当天涨停则减少实际买入数量
|
|
||||||
untradable_list = []
|
|
||||||
|
|
||||||
# ----- 卖出 -----
|
|
||||||
sell_list = []
|
sell_list = []
|
||||||
limit_down_list = [] # 跌停股记录
|
untradable_list = []
|
||||||
|
target_list = []
|
||||||
# 遍历昨日持仓状态:
|
# ----- 卖出 -----
|
||||||
# 1 记录持仓状态
|
# 异常强制卖出
|
||||||
# 2 获取停牌股列表
|
|
||||||
# 3 获取异常强制卖出列表
|
|
||||||
last_position_status = pd.Series()
|
|
||||||
for stock in last_position.index:
|
for stock in last_position.index:
|
||||||
last_position_status.loc[stock] = stock_status.loc[stock]
|
if stock_status.loc[stock] in [0,2,5,7]:
|
||||||
if last_position_status.loc[stock] in [0,2]:
|
|
||||||
untradable_list.append(stock)
|
untradable_list.append(stock)
|
||||||
else:
|
|
||||||
if last_position_status.loc[stock] in [5,7]:
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
if stock in force_exclude:
|
if stock in force_exclude:
|
||||||
sell_list.append(stock)
|
sell_list.append(stock)
|
||||||
|
@ -354,21 +305,15 @@ class Trader(Account):
|
||||||
for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
|
for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
|
||||||
if len(sell_list) >= max_sell_num + force_sell_num:
|
if len(sell_list) >= max_sell_num + force_sell_num:
|
||||||
break
|
break
|
||||||
if last_position_status.loc[stock] in [0,2]:
|
if stock_status.loc[stock] in [0,2,5,7]:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if last_position_status.loc[stock] in [5,7]:
|
|
||||||
limit_down_list.append(stock)
|
|
||||||
sell_list.append(stock)
|
sell_list.append(stock)
|
||||||
sell_list = list(set(sell_list))
|
sell_list = list(set(sell_list))
|
||||||
# 实际卖出列表 = 卖出列表 - 跌停列表
|
|
||||||
sell_list = list(set(sell_list) - set(limit_down_list))
|
|
||||||
|
|
||||||
# ----- 买入 -----
|
# ----- 买入 -----
|
||||||
buy_list = []
|
|
||||||
|
|
||||||
# 剔除过滤条件后可买入列表
|
# 剔除过滤条件后可买入列表
|
||||||
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
|
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
|
||||||
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
||||||
|
|
||||||
# 更新卖出后的持仓列表
|
# 更新卖出后的持仓列表
|
||||||
|
@ -376,8 +321,6 @@ class Trader(Account):
|
||||||
limit_up_list = [] # 涨停股记录
|
limit_up_list = [] # 涨停股记录
|
||||||
max_buy_num = max(0, self.num-len(last_position)+len(sell_list))
|
max_buy_num = max(0, self.num-len(last_position)+len(sell_list))
|
||||||
for stock in target_list:
|
for stock in target_list:
|
||||||
if len(buy_list) == max_buy_num:
|
|
||||||
break
|
|
||||||
if stock in after_sell_list:
|
if stock in after_sell_list:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
@ -389,23 +332,29 @@ class Trader(Account):
|
||||||
if stock_status.loc[stock] in [4,6]:
|
if stock_status.loc[stock] in [4,6]:
|
||||||
limit_up_list.append(stock)
|
limit_up_list.append(stock)
|
||||||
buy_list.append(stock)
|
buy_list.append(stock)
|
||||||
buy_list = list(set(buy_list))
|
if len(buy_list) == max_buy_num:
|
||||||
|
break
|
||||||
# 剔除同时在sell_list和buy_list的股票
|
# 剔除同时在sell_list和buy_list的股票
|
||||||
duplicate_stock = set(sell_list) & set(buy_list)
|
duplicate_stock = set(sell_list) & set(buy_list)
|
||||||
sell_list = list(set(sell_list) - duplicate_stock)
|
sell_list = list(set(sell_list) - duplicate_stock)
|
||||||
buy_list = list(set(buy_list) - duplicate_stock)
|
buy_list = list(set(buy_list) - duplicate_stock)
|
||||||
|
|
||||||
# 生成下一期持仓
|
# 生成下一期持仓
|
||||||
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
|
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
|
||||||
next_position['date'] = date
|
next_position['date'] = date
|
||||||
next_position['weight'] = self.get_weight(date, self.today_position_ratio, untradable_list+limit_down_list, next_position)
|
next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position)
|
||||||
|
# 剔除无法买入的涨停股,这部分仓位空出
|
||||||
# 剔除无法买入且不在昨日持仓中的涨停股,这部分仓位空出
|
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
|
||||||
next_position = next_position[~next_position['stock_code'].isin(list(set(limit_up_list)-set(last_position.index)))]
|
|
||||||
next_position['margin_trade'] = 0
|
next_position['margin_trade'] = 0
|
||||||
else:
|
else:
|
||||||
# 如果有杠杆:
|
# 如果有杠杆:
|
||||||
|
def assign_stock(normal_list, margin_list, margin_needed, stock, status):
|
||||||
|
if status == 1:
|
||||||
|
if len(margin_list) < margin_needed:
|
||||||
|
margin_list.append(stock)
|
||||||
|
else:
|
||||||
|
if len(normal_list) < self.num - margin_needed:
|
||||||
|
normal_list.append(stock)
|
||||||
|
return normal_list, margin_list
|
||||||
# 计算需要融资融券标的数量
|
# 计算需要融资融券标的数量
|
||||||
margin_ratio = max(self.today_position_ratio-1, 0)
|
margin_ratio = max(self.today_position_ratio-1, 0)
|
||||||
margin_needed = round(self.num * margin_ratio)
|
margin_needed = round(self.num * margin_ratio)
|
||||||
|
@ -423,6 +372,7 @@ class Trader(Account):
|
||||||
last_normal_list = []
|
last_normal_list = []
|
||||||
|
|
||||||
# ----- 卖出 -----
|
# ----- 卖出 -----
|
||||||
|
buy_list = []
|
||||||
sell_list = []
|
sell_list = []
|
||||||
untradable_list = []
|
untradable_list = []
|
||||||
# 分别更新融资融券池的和非融资融券池
|
# 分别更新融资融券池的和非融资融券池
|
||||||
|
@ -474,10 +424,8 @@ class Trader(Account):
|
||||||
next_normal_list = list(set(last_normal_list) - set(sell_list))
|
next_normal_list = list(set(last_normal_list) - set(sell_list))
|
||||||
|
|
||||||
# ----- 买入 -----
|
# ----- 买入 -----
|
||||||
buy_list = []
|
|
||||||
|
|
||||||
# 剔除过滤条件后可买入列表
|
# 剔除过滤条件后可买入列表
|
||||||
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
|
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
|
||||||
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
||||||
|
|
||||||
# 更新卖出后的持仓列表
|
# 更新卖出后的持仓列表
|
||||||
|
@ -486,8 +434,6 @@ class Trader(Account):
|
||||||
# 融资融券池的和非融资融券池的分开更新
|
# 融资融券池的和非融资融券池的分开更新
|
||||||
# 更新融资融券池
|
# 更新融资融券池
|
||||||
for stock in target_list:
|
for stock in target_list:
|
||||||
if len(next_margin_list) >= margin_needed:
|
|
||||||
break
|
|
||||||
if stock in after_sell_list:
|
if stock in after_sell_list:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
@ -500,11 +446,10 @@ class Trader(Account):
|
||||||
if stock_status.loc[stock] in [4,6]:
|
if stock_status.loc[stock] in [4,6]:
|
||||||
limit_up_list.append(stock)
|
limit_up_list.append(stock)
|
||||||
next_margin_list.append(stock)
|
next_margin_list.append(stock)
|
||||||
next_margin_list = list(set(next_margin_list))
|
if len(next_margin_list) >= margin_needed:
|
||||||
|
break
|
||||||
# 更新非融资融券池
|
# 更新非融资融券池
|
||||||
for stock in target_list:
|
for stock in target_list:
|
||||||
if len(next_normal_list) >= self.num - len(next_margin_list):
|
|
||||||
break
|
|
||||||
if stock in (set(after_sell_list) | set(next_margin_list)):
|
if stock in (set(after_sell_list) | set(next_margin_list)):
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
@ -516,21 +461,20 @@ class Trader(Account):
|
||||||
if stock_status.loc[stock] in [4,6]:
|
if stock_status.loc[stock] in [4,6]:
|
||||||
limit_up_list.append(stock)
|
limit_up_list.append(stock)
|
||||||
next_normal_list.append(stock)
|
next_normal_list.append(stock)
|
||||||
next_normal_list = list(set(next_normal_list))
|
if len(next_normal_list) >= self.num - len(next_margin_list):
|
||||||
|
break
|
||||||
|
|
||||||
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
|
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
|
||||||
next_position['date'] = date
|
next_position['date'] = date
|
||||||
|
|
||||||
# 融资融券数量
|
# 融资融券数量
|
||||||
margin_num = len(next_margin_list)
|
margin_num = len(next_margin_list)
|
||||||
next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), untradable_list, next_position)
|
next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position)
|
||||||
next_position['margin_trade'] = 0
|
next_position['margin_trade'] = 0
|
||||||
next_position = next_position.set_index(['stock_code'])
|
next_position = next_position.set_index(['stock_code'])
|
||||||
next_position.loc[next_margin_list, 'margin_trade'] = 1
|
next_position.loc[next_margin_list, 'margin_trade'] = 1
|
||||||
next_position = next_position.reset_index()
|
next_position = next_position.reset_index()
|
||||||
# 剔除无法买入的涨停股,这部分仓位空出
|
# 剔除无法买入的涨停股,这部分仓位空出
|
||||||
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
|
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
|
||||||
|
|
||||||
# 检测当前持仓是否可以交易
|
# 检测当前持仓是否可以交易
|
||||||
frozen_list = []
|
frozen_list = []
|
||||||
if len(self.position) > 0:
|
if len(self.position) > 0:
|
||||||
|
@ -552,19 +496,13 @@ class Trader(Account):
|
||||||
buy_list (Iterable[str]): 买入目标
|
buy_list (Iterable[str]): 买入目标
|
||||||
sell_list (Iterable[str]): 卖出目标
|
sell_list (Iterable[str]): 卖出目标
|
||||||
"""
|
"""
|
||||||
if trade_time+'_trade' in self.today_data:
|
stock_price = self.today_data[trade_time]
|
||||||
trade_price = self.today_data[trade_time+'_trade']
|
|
||||||
basic_price = self.today_data[trade_time]
|
|
||||||
else:
|
|
||||||
basic_price = self.today_data[trade_time]
|
|
||||||
trade_price = basic_price
|
|
||||||
target_price = pd.Series(index=target)
|
target_price = pd.Series(index=target)
|
||||||
sell_list = list(set(target) & set(sell_list))
|
sell_list = list(set(target) & set(sell_list))
|
||||||
buy_list = list(set(target) & set(buy_list))
|
buy_list = list(set(target) & set(buy_list))
|
||||||
# 根据交易和非交易标的分别获取目标价格
|
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
|
||||||
target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
|
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
||||||
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
||||||
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
|
||||||
return target_price
|
return target_price
|
||||||
|
|
||||||
def check_update_status(self,
|
def check_update_status(self,
|
||||||
|
@ -663,8 +601,7 @@ class Trader(Account):
|
||||||
if cur_pos['weight'].sum() == 0:
|
if cur_pos['weight'].sum() == 0:
|
||||||
pnl = 0
|
pnl = 0
|
||||||
else:
|
else:
|
||||||
cash = max(0, 1 - cur_pos['weight'].sum())
|
pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum())
|
||||||
pnl = ((cur_pos['end_weight'].sum() + cash) / (cur_pos['weight'].sum() + cash)) - 1
|
|
||||||
self.account *= 1+pnl
|
self.account *= 1+pnl
|
||||||
self.account_history = self.account_history.append({
|
self.account_history = self.account_history.append({
|
||||||
'date': date,
|
'date': date,
|
||||||
|
@ -675,18 +612,6 @@ class Trader(Account):
|
||||||
}, ignore_index=True)
|
}, ignore_index=True)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def update_next_weight(position):
|
|
||||||
"""
|
|
||||||
根据收盘权重计算下一时刻新的个股权重
|
|
||||||
"""
|
|
||||||
if position['weight'].sum() <= 1 + 1e-5:
|
|
||||||
# 非融资情况
|
|
||||||
cash = max(0, 1 - position['weight'].sum())
|
|
||||||
return position['end_weight'] / (position['end_weight'].sum() + cash)
|
|
||||||
else:
|
|
||||||
return position['weight'].sum() * (position['end_weight'] / position['end_weight'].sum())
|
|
||||||
|
|
||||||
def update_signal(self,
|
def update_signal(self,
|
||||||
date:str,
|
date:str,
|
||||||
update_type='rtn'):
|
update_type='rtn'):
|
||||||
|
@ -705,9 +630,7 @@ class Trader(Account):
|
||||||
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
|
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
|
||||||
if date in self.position_history:
|
if date in self.position_history:
|
||||||
self.position_history.pop(date)
|
self.position_history.pop(date)
|
||||||
# ----- 更新当日回测数据 ------
|
# 更新持仓信号
|
||||||
# 更新当前日期和持仓信号
|
|
||||||
self.current_date = date
|
|
||||||
self.load_data(date, update_type)
|
self.load_data(date, update_type)
|
||||||
# 更新当日持仓比例
|
# 更新当日持仓比例
|
||||||
if isinstance(self.position_ratio, float):
|
if isinstance(self.position_ratio, float):
|
||||||
|
@ -723,14 +646,13 @@ class Trader(Account):
|
||||||
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
|
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
|
||||||
self.current_fee = fee
|
self.current_fee = fee
|
||||||
# 如果当前持仓不空,添加隔夜收益,否则直接买入
|
# 如果当前持仓不空,添加隔夜收益,否则直接买入
|
||||||
position_fields = ['stock_code','date','weight','margin_trade','open','close','end_weight']
|
|
||||||
if len(self.position) == 0:
|
if len(self.position) == 0:
|
||||||
cur_pos = pd.DataFrame(columns=position_fields)
|
cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
|
||||||
else:
|
else:
|
||||||
cur_pos = self.position.copy()
|
cur_pos = self.position.copy()
|
||||||
# 冻结列表
|
# 冻结列表
|
||||||
frozen_list = []
|
frozen_list = []
|
||||||
# ----- 遍历各个交易时间的信号 -----
|
# 遍历各个交易时间的信号
|
||||||
for _,trade_time in enumerate(self.signal):
|
for _,trade_time in enumerate(self.signal):
|
||||||
if self.check_update_status(date, trade_time):
|
if self.check_update_status(date, trade_time):
|
||||||
continue
|
continue
|
||||||
|
@ -741,7 +663,6 @@ class Trader(Account):
|
||||||
factor = self.signal[trade_time].loc[date]
|
factor = self.signal[trade_time].loc[date]
|
||||||
# 获取当前、持仓
|
# 获取当前、持仓
|
||||||
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
|
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
|
||||||
|
|
||||||
# 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算
|
# 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算
|
||||||
if update_type == 'position':
|
if update_type == 'position':
|
||||||
self.position_history[date] = next_position.copy()
|
self.position_history[date] = next_position.copy()
|
||||||
|
@ -753,11 +674,9 @@ class Trader(Account):
|
||||||
# 计算收益
|
# 计算收益
|
||||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
|
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
|
||||||
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
|
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
|
||||||
# 价格缺失用初始weight填充
|
|
||||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
|
||||||
self.update_account(date, trade_time, cur_pos, next_position)
|
self.update_account(date, trade_time, cur_pos, next_position)
|
||||||
# 更新仓位
|
# 更新仓位
|
||||||
cur_pos['weight'] = self.update_next_weight(cur_pos)
|
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
|
||||||
# 调整权重:买入、卖出、仓位再平衡
|
# 调整权重:买入、卖出、仓位再平衡
|
||||||
next_position = self.reblance_weight(trade_time, cur_pos, next_position)
|
next_position = self.reblance_weight(trade_time, cur_pos, next_position)
|
||||||
else:
|
else:
|
||||||
|
@ -772,19 +691,14 @@ class Trader(Account):
|
||||||
# 停牌价格不变
|
# 停牌价格不变
|
||||||
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
|
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
|
||||||
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
|
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
|
||||||
# 更新当日收益
|
|
||||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
|
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
|
||||||
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
|
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
|
||||||
# 价格缺失用初始weight填充
|
|
||||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
|
||||||
position_record = cur_pos.copy()
|
position_record = cur_pos.copy()
|
||||||
position_record['end_weight'] = self.update_next_weight(position_record)
|
position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
|
||||||
self.position_history[date] = position_record.copy()[position_fields]
|
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
|
||||||
# 更新当期收盘后个股仓位作为下一期的开盘仓位
|
|
||||||
cur_pos['weight'] = self.update_next_weight(cur_pos)
|
|
||||||
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
|
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
|
||||||
next_position['open'] = cur_pos['close']
|
next_position['open'] = cur_pos['close']
|
||||||
self.update_account(date, trade_time, cur_pos, cur_pos)
|
self.update_account(date, trade_time, cur_pos, cur_pos)
|
||||||
# 记录当前时刻最终持仓和个股权重
|
|
||||||
self.position = next_position.copy()
|
self.position = next_position.copy()
|
||||||
|
self.position_history[date] = position_record.copy()
|
||||||
return True
|
return True
|
Loading…
Reference in New Issue