修正停牌时目标持仓计算问题

This commit is contained in:
binz 2024-05-25 23:48:53 +08:00
parent f5dbecd9ee
commit 20c1477b42
3 changed files with 89 additions and 33 deletions

View File

@ -9,7 +9,7 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list']):
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list','abnormal']):
if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else:
@ -32,7 +32,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1)
next_list = []
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list']):
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal']):
if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else:

View File

@ -25,10 +25,13 @@ class Spread_Backtest():
else:
bkt_start = max(start, self.trader.avaliable_date.index.min())
rec_end = min(end, self.trader.avaliable_date.index.max())
# 如果传入的第一个因子时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益
# 如果传入的第一个因子的有值时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益
first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
if rec_end < first_signal.index.max():
bkt_end = first_signal.loc[rec_end:].index.to_list()[1]
if rec_end in first_signal.index:
bkt_end = first_signal.loc[rec_end:].index.to_list()[1]
else:
bkt_end = first_signal.loc[rec_end:].index.to_list()[0]
else:
bkt_end = rec_end
print(f'回测区间: {bkt_start} - {bkt_end}')
@ -92,5 +95,4 @@ def print_year_rtn(year_rtn: pd.DataFrame) -> None:
table.add_row(*new_row)
rprint(table)

110
trader.py
View File

@ -4,7 +4,10 @@ import sys, os
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
from typing import Union, Iterable, Dict
from ordered_set import OrderedSet
from account import Account
from dataloader import DataLoader
@ -18,7 +21,6 @@ class Trader(Account):
interval (int, tuple, pd.Series): 交易间隔
num (int): 持仓数量
ascending (bool): 因子方向
fee (tuple): 买入成本和卖出成本
with_st (bool): 是否包含st
tick (bool): 是否开始tick模拟模式(开发中)
weight (str): 权重分配
@ -26,19 +28,29 @@ class Trader(Account):
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间
slippage (tuple): 买入和卖出滑点
commission (float): 佣金
tax (dict): 印花税
"""
def __init__(self,
signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
num: int=100,
ascending: bool=False,
fee :tuple=(0.001,0.002),
with_st: bool=False,
data_root:dict={},
tick: bool=False,
weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001),
commission: float=0.0001,
tax: dict={
'1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005)
},
**kwargs) -> None:
# 初始化账户
super().__init__(**kwargs.get('account', {}))
@ -50,6 +62,9 @@ class Trader(Account):
self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(), self.signal[s].index.max(), return_type='origin')
else:
raise ValueError('type of signal is invalid')
# ----------
# 参数检验
# ----------
# interval
self.interval = []
for s in signal:
@ -79,11 +94,6 @@ class Trader(Account):
self.ascending = ascending
else:
raise ValueError('invalid type for ascending')
# fee
if isinstance(fee, tuple) and len(fee) == 2:
self.fee = fee
else:
raise ValueError('invalid input for fee')
# with_st
if isinstance(with_st, bool):
self.with_st = with_st
@ -104,6 +114,21 @@ class Trader(Account):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage
if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage
else:
raise ValueError('slippage should be set.')
# commission
if isinstance(commission, float):
self.commission = commission
else:
raise ValueError('commission should be flaot.')
# tax
if isinstance(tax, dict):
self.tax = tax
else:
raise ValueError('tax should be dict.')
# data_root
# 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0:
@ -175,27 +200,38 @@ class Trader(Account):
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
# 交易列表
if self.leverage <= 1.0:
# 获取当前时间目标列表和冻结(无法交易)列表
# 获取当前时间目标列表和冻结(无法交易)列表:
# 1 保留冻结的昨日持仓
# 2 目标持仓列表会对当日因子进行ST、成交量、上市时间滤后按因子排序方向选取
for stock in last_position.index:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2]:
target_list.append(stock)
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
# 目标持仓数量达到持仓数量则中止
if len(target_list) == self.num:
break
# 成交量、上市时间过滤
if (stock_amt_filter.loc[stock] != 1) or (stock_ipo_filter.loc[stock] != 1):
continue
# ST过滤
if self.with_st:
if stock_status.loc[stock] in [0,2]:
if stock_status.loc[stock] in [0,2,5,7]:
if stock in last_position.index:
target_list.append(stock)
else:
target_list.append(stock)
else:
# 非ST
if stock_status.loc[stock] in [3,7]:
if stock_status.loc[stock] in [3,6]:
target_list.append(stock)
else:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2,6]:
if stock in last_position.index:
target_list.append(stock)
if len(target_list) == self.num:
break
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2,7]:
if stock in last_position.index:
target_list.append(stock)
target_list = list(OrderedSet(target_list))
# 如果没有杠杆
buy_list = []
sell_list = []
@ -205,7 +241,7 @@ class Trader(Account):
factor = factor.fillna(factor.max()+1)
else:
factor = factor.fillna(factor.min()-1)
for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
if stock in target_list:
continue
else:
@ -213,8 +249,9 @@ class Trader(Account):
continue
else:
sell_list.append(stock)
if len(sell_list) == max_trade_num:
if len(sell_list) >= max_trade_num:
break
# ----- 买入 -----
# 卖出后持仓列表
after_sell_list = set(last_position.index) - set(sell_list)
@ -244,25 +281,34 @@ class Trader(Account):
margin_ratio = max(self.leverage-1, 0)
margin_needed = round(self.num * margin_ratio)
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
# 获取当前时间目标列表和冻结(无法交易)列表
# 获取当前时间目标列表和冻结(无法交易)列表:
# 1 保留冻结的昨日持仓
# 2 目标持仓列表会对当日因子进行ST、成交量、上市时间滤后按因子排序方向选取
# 3 根据融资比例将股票分为常规池子和融资池子
normal_list = []
margin_list = []
for stock in last_position.index:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2]:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
if stock_amt_filter.loc[stock] != 1:
continue
if self.with_st:
if stock_status.loc[stock] in [0,2]:
if stock_status.loc[stock] in [0,2,5,7]:
if stock in last_position.index:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
# 非ST
if stock_status.loc[stock] in [3,7]:
if stock_status.loc[stock] in [3,6]:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2,6]:
if stock_status.loc[stock] in [0,2,7]:
if stock in last_position.index:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
if len(normal_list + margin_list) == self.num:
@ -329,7 +375,6 @@ class Trader(Account):
next_position['date'] = date
# 融资融券数量
margin_num = len(next_margin_list)
# next_position['weight'] = self.leverage*((margin_needed-margin_num)/margin_needed) / len(next_position)
next_position['weight'] = (1 + (margin_num / self.num)) / self.num
next_position['margin_trade'] = 0
next_position = next_position.set_index(['stock_code'])
@ -361,8 +406,8 @@ class Trader(Account):
sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list))
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.fee[1])
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.fee[0])
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price
def check_update_status(self,
@ -473,8 +518,8 @@ class Trader(Account):
return True
def update_signal(self,
date:str,
update_type='rtn'):
date:str,
update_type='rtn'):
"""
更新信号收益
@ -492,6 +537,14 @@ class Trader(Account):
self.position_history.pop(date)
# 更新持仓信号
self.load_data(date, update_type)
# 更新费用
fee = (self.commission + self.slippage[0], self.commission + self.slippage[1])
current_tax = (0.001, 0.001)
for time,tax_rate in self.tax.items():
if date > time:
current_tax = tax_rate
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
self.current_fee = fee
# 如果当前持仓不空,添加隔夜收益,否则直接买入
if len(self.position) == 0:
cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
@ -510,6 +563,7 @@ class Trader(Account):
factor = self.signal[trade_time].loc[date]
# 获取当前、持仓
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
# 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算
if update_type == 'position':
self.position_history[date] = next_position.copy()
return True