Compare commits

..

11 Commits
0.0.1 ... main

Author SHA1 Message Date
binz 0977e8c9fd Update: 新增mkt和price的买入过滤;
Update: 调整过滤参数名称和指定方式;(#21)
2024-06-26 23:36:58 +08:00
binz d671f820d4 Update: readme Ver 0.0.3 2024-06-25 21:01:59 +08:00
binz 721e2a42f5 Update: 增加信号更新函数 (#26);
新增interval初始化函数;
2024-06-24 21:59:34 +08:00
binz 30a998b58a Fix: 修复自定义价格时,非买卖标的的价格选取 (#24) 2024-06-17 20:35:25 +08:00
binz 411e1a3f78 Update: 新增指定量分钟价格计算 2024-06-16 23:31:39 +08:00
binz d5a2d9cb9c test tag 2024-06-14 21:18:40 +08:00
binz 28e1bb27c6 tag test 2024-06-14 21:09:41 +08:00
binz 43763db2dd Update: readme Ver 0.0.2 2024-06-13 19:16:47 +08:00
binz 11becd1d2a Bug: 修复position_ratio大于1时权重重新分配时的计算问题 (#23) 2024-06-12 14:39:20 +08:00
binz 158316aebc Bug: 修复position_ratio > 1时仓位计算函数缺失untradable_list导致计算错误的问题 2024-06-12 01:34:59 +08:00
binz 10e926bf28 Update:自定义权重支持不满仓的情况;Update:增加自定义权重总和判断;
Bug:当换仓数量大于持仓数量时全卖出和全买入导致的持仓计算问题‘;
2024-06-08 19:12:23 +08:00
8 changed files with 358 additions and 139 deletions

View File

@ -1,4 +1,14 @@
## Version 0.0.1 ## Version 0.0.1
1. 基本实现交易流程 1. 基本实现交易流程
2. 新增强制筛选参数 2. 新增强制筛选参数
3. 新增自定义个股权重和仓位权重 3. 新增自定义个股权重和仓位权重
## Version 0.0.2
1. 优化非满仓情况下现金比例对收益的计算问题
2. 优化自定义权重下的判断逻辑
3. 修复没有满仓有现金比例时的收益计算问题
## Version 0.0.3
1. 新增分钟价格计算函数
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
3. 增加信号更新函数

33
Specified_Price.py Normal file
View File

@ -0,0 +1,33 @@
import pandas as pd
import numpy as np
def amount_specified(
    min_data: pd.DataFrame,
    post_adj_factor: pd.DataFrame,
    min_amount: float=0.
):
    """
    Compute a post-adjusted running average price per stock for a given amount.

    For every stock the minute bars are accumulated over time and the running
    average price (cumulative amount / cumulative volume) is evaluated. From
    the second minute onwards a minute is only kept while its cumulative
    amount is still ``<= min_amount``; the first minute is never masked. The
    last kept minute of each stock supplies the price, which is then multiplied
    by the stock's adjustment factor.

    Args:
        min_data (pd.DataFrame): minute data; the columns read here are
            ``stock_code``, ``Time``, ``amount`` and ``vol``.
        post_adj_factor (pd.DataFrame): adjustment factors; the columns read
            here are ``stock_code`` and ``factor``.
        min_amount (float): cumulative-amount threshold used for the masking.

    Returns:
        pd.DataFrame: columns ``stock_code`` and ``open_post``; stocks without
        a usable price or without a factor are dropped.
    """
    def _running_total(field):
        # Time x stock_code table of `field`, accumulated down the time axis.
        table = min_data.pivot_table(index='Time', columns='stock_code', values=field)
        return table.cumsum()

    cum_amount = _running_total('amount')
    cum_volume = _running_total('vol')
    running_price = cum_amount / cum_volume

    # Blank out every minute after the first whose cumulative amount exceeds the threshold.
    within_threshold = cum_amount <= min_amount
    later_minutes = running_price.iloc[1:]
    running_price.iloc[1:] = later_minutes.where(within_threshold, np.nan)

    # Long format, one row per (stock, minute); keep the last surviving minute per stock.
    long_form = running_price.unstack().reset_index()
    long_form.columns = ['stock_code', 'time', 'price']
    long_form = long_form.dropna(subset=['price'])
    long_form = long_form.drop_duplicates(subset='stock_code', keep='last')

    # Apply the adjustment factor; stocks without a factor are discarded.
    adjusted = long_form.merge(post_adj_factor, on=['stock_code'], how='left')
    adjusted = adjusted.dropna(subset=['factor'])
    adjusted['open_post'] = adjusted['price'] * adjusted['factor']
    return adjusted[['stock_code', 'open_post']]

View File

@ -2,4 +2,5 @@ from account import Account
from trader import Trader from trader import Trader
from spread_backtest import Spread_Backtest from spread_backtest import Spread_Backtest
__all__ = ['Account', 'Trader', 'Spread_Backtest']
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']

68
check_funcs.py Normal file
View File

@ -0,0 +1,68 @@
import numpy as np
def check_buy_exclude(buy_exclude):
    """
    Validate and normalise the buy-side exclusion ranges.

    Supported conditions (any other key in ``buy_exclude`` is ignored):
        - ``amt_20``: 20-day average trading amount
        - ``list_days``: days since listing
        - ``mkt``: market capitalisation
        - ``price``: price

    Each supplied condition must be a ``set``, ``list`` or ``tuple`` holding a
    lower bound and, optionally, an upper bound. A single value means "no
    upper bound" and is completed with ``np.inf``. Lists and tuples are read
    positionally; a ``set`` has no order, so its values are sorted ascending
    before being read as ``(lower, upper)``.

    Args:
        buy_exclude (dict): mapping of condition name to its range. It is only
            read, never modified.

    Returns:
        dict: ``{condition: (lower, upper)}`` for every supported condition
        present in ``buy_exclude``. Conditions that were not supplied are left
        out of the result.

    Raises:
        Exception: if a supplied range is not a ``set``, ``list`` or ``tuple``.
        IndexError: if a supplied range is empty.
    """
    supported_conditions = ('amt_20', 'list_days', 'mkt', 'price')
    buy_exclude_dict = dict()
    for cond in supported_conditions:
        if cond not in buy_exclude:
            # Nothing supplied for this condition: skip it. Do not write a
            # default back into the caller's dict -- that object may be shared
            # (e.g. a default-argument dict), and mutating it would make a
            # later call behave differently from the first one.
            continue
        raw_range = buy_exclude[cond]
        if not isinstance(raw_range, (set, list, tuple)):
            raise Exception(f'wrong input type for buy exclude: {cond}, `set`, `list` or `tuple` is required')
        # Sets cannot be indexed; order them so that [0] is the lower bound.
        bounds = sorted(raw_range) if isinstance(raw_range, set) else raw_range
        if len(bounds) == 1:
            buy_exclude_dict[cond] = (bounds[0], np.inf)
        else:
            buy_exclude_dict[cond] = (bounds[0], bounds[1])
    return buy_exclude_dict

View File

@ -10,8 +10,8 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list', for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20',
'abnormal', 'recession']): 'opening_info','ipo_days','margin_list','abnormal', 'recession']):
if f in ['margin_list']: if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else: else:
@ -34,7 +34,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选 # 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1) next_date = gft.days_after(df.index.max(), 1)
next_list = [] next_list = []
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
if f in ['margin_list']: if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else: else:

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from typing import Union, Iterable, Optional, Dict from typing import Union, Iterable
class DataLoader(): class DataLoader():
""" """

View File

@ -5,6 +5,7 @@ import pandas as pd
import numpy as np import numpy as np
import time import time
from trader import Trader from trader import Trader
from typing import Union, Dict
from rich import print as rprint from rich import print as rprint
from rich.table import Table from rich.table import Table
@ -60,6 +61,26 @@ class SpreadBacktest():
else: else:
self.trader.update_signal(date, update_type='position') self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
                  trade_time: str,
                  new_signal: pd.DataFrame):
    """
    Update the signal (factor) stored for one trade time.

    Args:
        trade_time (str): signal time, used as the key of the entry to set
        new_signal (pd.DataFrame): the new signal stored under that key
    """
    # Adds or overwrites the entry in the trader's signal mapping.
    self.trader.signal[trade_time] = new_signal
def update_interval(self,
                    interval: Dict[str, Union[int,tuple,pd.Series]]={},
                    ):
    """
    Re-initialise the trader's interval from a new specification.

    Args:
        interval: interval specification, forwarded unchanged to
            ``self.trader.init_interval``.
    """
    # Update interval and weight
    # Update if interval is a fixed ratio
    # NOTE(review): only init_interval is called here; no weight update or
    # fixed-ratio check is visible in this method -- confirm against the callee.
    self.trader.init_interval(interval)
@property @property
def account_history(self): def account_history(self):
return self.trader.account_history return self.trader.account_history

352
trader.py
View File

@ -1,5 +1,4 @@
import pandas as pd import pandas as pd
import numpy as np
import sys import sys
import os import os
import copy import copy
@ -8,59 +7,60 @@ from typing import Union, Iterable, Dict
from account import Account from account import Account
from dataloader import DataLoader from dataloader import DataLoader
from check_funcs import check_buy_exclude
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools from db_tushare import get_factor_tools
gft = get_factor_tools() gft = get_factor_tools()
class Trader(Account): class Trader(Account):
""" """
交易类: 用于控制每日交易情况 交易类: 用于控制每日交易情况
Args: Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行 signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (int, tuple, pd.Series): 交易间隔 interval (dict[str, (int, tuple, pd.Series)]):
num (int): 持仓数量 交易间隔
ascending (bool): 因子方向 num (int): 持仓数量
with_st (bool): 是否包含st ascending (bool): 因子方向
tick (bool): 是否开始tick模拟模式(开发中) with_st (bool): 是否包含st
weight ([str, pd.DataFrame]): 权重分配 tick (bool): 是否开始tick模拟模式(开发中)
weight ([str, pd.DataFrame]): 权重分配
- avg (str): 平均分配每天早盘重新分配日中交易不重新分配 - avg (str): 平均分配每天早盘重新分配日中交易不重新分配
- (pd.DataFrame): 自定义股票权重包含每天个股指定的权重会自动归一化 - (pd.DataFrame): 自定义股票权重包含每天个股指定的权重会自动归一化
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限 amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列 data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间 ipo_days (int): 筛选上市时间
slippage (tuple): 买入和卖出滑点 slippage (tuple): 买入和卖出滑点
commission (float): 佣金 commission (float): 佣金
tax (dict): 印花税 tax (dict): 印花税
exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓 force_exclude (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除 - abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- receesion: 财报同比或环比下降50%以上 - receesion: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见 - qualified_opinion: 会计保留意见
account (Account): 账户设置account.Account account (Account): 账户设置account.Account
""" """
def __init__(self, def __init__(self,
signal: Dict[str, pd.DataFrame]=None, signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]=1, interval: Dict[str, Union[int,tuple,pd.Series]]={},
num: int=100, num: int=100,
ascending: bool=False, ascending: bool=False,
with_st: bool=False, with_st: bool=False,
data_root:dict={}, data_root:dict={},
tick: bool=False, tick: bool=False,
weight: str='avg', weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001), slippage :tuple=(0.001,0.001),
commission: float=0.0001, commission: float=0.0001,
buy_exclude: Dict[str, Union[set,int,float]]={},
force_exclude: list=[],
account: dict={},
tax: dict={ tax: dict={
'1990-01-01': (0.001,0.001), '1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001), '2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001), '2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005) '2023-08-28': (0, 0.0005)
}, },
exclude_list: list=[],
account: dict={},
**kwargs) -> None: **kwargs) -> None:
# 初始化账户 # 初始化账户
super().__init__(**account) super().__init__(**account)
@ -78,24 +78,7 @@ class Trader(Account):
if len(kwargs) > 0: if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval # interval
self.interval = [] self.init_interval(interval)
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num # num
if isinstance(num, int): if isinstance(num, int):
self.num = int(num) self.num = int(num)
@ -116,21 +99,6 @@ class Trader(Account):
self.weight = weight self.weight = weight
else: else:
raise ValueError('invalid type for `weight`') raise ValueError('invalid type for `weight`')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage # slippage
if isinstance(slippage, tuple) and len(slippage) == 2: if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage self.slippage = slippage
@ -146,17 +114,19 @@ class Trader(Account):
self.tax = tax self.tax = tax
else: else:
raise ValueError('tax should be dict.') raise ValueError('tax should be dict.')
# exclude # buy exclude
if isinstance(exclude_list, list): self.buy_exclude = check_buy_exclude(buy_exclude)
self.exclude_list = exclude_list # force exclude
if isinstance(force_exclude, list):
self.force_exclude = force_exclude
optional_list = ['abnormal', 'recession'] optional_list = ['abnormal', 'recession']
for item in exclude_list: for item in force_exclude:
if item in optional_list: if item in optional_list:
pass pass
else: else:
raise ValueError(f"Unexpected keyword argument '{item}'") raise ValueError(f"Unexpected keyword argument '{item}'")
else: else:
raise ValueError('exclude_list should be list.') raise ValueError('force_exclude should be list.')
# data_root # data_root
# 至少包含basic data路径open信号默认使用basic_data # 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0: if len(data_root) <= 0:
@ -173,7 +143,30 @@ class Trader(Account):
if s not in data_root: if s not in data_root:
raise ValueError(f"data for signal {s} is not provided") raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root self.data_root = data_root
def init_interval(self, interval):
    """
    Build ``self.interval`` from the per-signal interval specification.

    For every signal name in ``self.signal`` the matching entry of
    ``interval`` is turned into a series:
        - ``int``: zeros over the signal's index with ``1`` written at every
          n-th position, starting at the first
        - ``tuple``: zeros over the signal's index with ``spec[1]`` written at
          every ``spec[0]``-th position, starting at the first
        - ``pd.Series``: used as given
    The per-signal series are concatenated in signal order and stored on
    ``self.interval``.

    Args:
        interval: mapping of signal name to its interval specification.

    Raises:
        ValueError: if a signal has no entry in ``interval`` or the entry has
            an unsupported type.
    """
    def _periodic(index, step, value):
        # Zero-filled series over `index`, with `value` at every `step`-th position.
        schedule = pd.Series(index=index, data=[0]*len(index))
        schedule[::step] = value
        return schedule

    schedules = []
    for name in self.signal:
        if name not in interval:
            raise ValueError(f'not found interval for signal {name}')
        spec = interval[name]
        if isinstance(spec, int):
            schedule = _periodic(self.signal[name].index, spec, 1)
        elif isinstance(spec, tuple):
            schedule = _periodic(self.signal[name].index, spec[0], spec[1])
        elif isinstance(spec, pd.Series):
            schedule = spec
        else:
            raise ValueError('invalid interval type')
        schedules.append(schedule)
    self.interval = pd.concat(schedules)
def load_data(self, def load_data(self,
date: str, date: str,
update_type: str='rtn'): update_type: str='rtn'):
@ -187,14 +180,19 @@ class Trader(Account):
""" """
self.today_data = dict() self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv')) self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
# 遍历从data_root中读取当日数据
for path in self.data_root:
if path == 'basic':
continue
else:
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
if update_type == 'position': if update_type == 'position':
return True return True
for s in self.signal: for s in self.signal:
if s == 'open': if s == 'open':
if s in self.data_root: if s in self.data_root:
continue self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
else: self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else: else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv')) self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal: if 'close' in self.signal:
@ -209,11 +207,12 @@ class Trader(Account):
# 可执行日期 # 可执行日期
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index() self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
def get_weight(self, date, account_weight, next_position): def get_weight(self, date, account_weight, untradable_list, next_position):
""" """
计算个股仓位 计算个股仓位
Args: Args:
untradable_list (list): 无法交易列表
account_weight (float): 总权重即当前持仓比例 account_weight (float): 总权重即当前持仓比例
""" """
if isinstance(self.weight, str): if isinstance(self.weight, str):
@ -221,14 +220,26 @@ class Trader(Account):
return account_weight / len(next_position) return account_weight / len(next_position)
if isinstance(self.weight, pd.DataFrame): if isinstance(self.weight, pd.DataFrame):
date_weight = self.weight.loc[date].dropna().sort_index() date_weight = self.weight.loc[date].dropna().sort_index()
# untradable_list不要求指定权重用昨日权重填充
weight_list = pd.Series(index=next_position['stock_code'])
try: try:
weight_list = date_weight.loc[next_position['stock_code'].to_list()].values # 填充untradable_list权重
if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松 if len(untradable_list) > 0:
raise Exception(f"total weight of {date} is larger then 1.") weight_list.loc[untradable_list] = self.position.set_index('stock_code').loc[untradable_list, 'weight']
weight_list = account_weight * weight_list
return weight_list
except Exception: except Exception:
raise ValueError(f'not found stock weight in {date}') raise ValueError('not found stock weight for untradable stocks in last position.')
try:
# 获取tradable_list权重并对untradable_list占据的仓位进行调整
tradable_list = list(set(next_position['stock_code']) - set(untradable_list))
# 剔除untradable_list仓位后剩余持仓根据自定义权重分配
weight_list.loc[tradable_list] = date_weight.loc[tradable_list].values / date_weight.loc[tradable_list].sum() * (account_weight - weight_list.loc[untradable_list].sum())
weight_list = weight_list.values
if sum(weight_list) > 1 + 1e-5: # 防止数据精度的影响,给与一定的宽松
raise Exception(f"total weight of {date} is larger then 1.")
return weight_list
except Exception as e:
print(e)
raise ValueError(f'not found specified stock weight in {date}')
def get_next_position(self, date, factor): def get_next_position(self, date, factor):
""" """
@ -241,59 +252,97 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量 # 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整 # 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓 # 这样能实现在因子值不足时也正常换仓
max_sell_num = self.interval.loc[date]*len(last_position) try:
max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else: else:
last_position = pd.Series() last_position = pd.Series()
max_sell_num = self.num max_sell_num = self.num
# 获取用于筛选的数据 # 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index() stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index() exclude_data = dict()
stock_amt_filter = stock_amt20.copy() for cond in self.buy_exclude:
stock_amt_filter.loc[:] = 0 if cond == 'amt_20':
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1 stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_filter = stock_amt_filter.sort_index() stock_amt_exclude = stock_amt20.copy()
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index() stock_amt_exclude.loc[:] = 0
stock_ipo_filter = stock_ipo_days.copy() stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
stock_ipo_filter.loc[:] = 0 stock_amt_exclude = stock_amt_exclude.sort_index()
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1 exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
# 剔除列表 # 剔除列表
# 包含强制列表和普通列表: # 包含强制剔除列表和买入剔除列表:
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中 # 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
# 普通列表如果已经持有不会过滤只对新买入的过滤 # 买入剔除列表如果已经持有不会过滤只对新买入的过滤
# 强制过滤列表 # 强制过滤列表
exclude_stock = [] exclude_stock = []
for exclude in self.exclude_list: for cond in self.force_exclude:
if exclude == 'abnormal': if cond == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if exclude == 'recession': if cond == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index() stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list() exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock)) exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock) force_exclude = copy.deepcopy(exclude_stock)
# 普通过滤列表 # 买入过滤列表
normal_exclude = [] buy_exclude = []
normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list() for cond in self.buy_exclude:
normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list() buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
normal_exclude = list(set(normal_exclude)) buy_exclude = list(set(buy_exclude))
# 交易列表 # 交易列表
if self.today_position_ratio <= 1.0: # 仓位判断给与计算误差冗余
if self.today_position_ratio <= 1.0 + 1e-5:
# 如果没有杠杆: # 如果没有杠杆:
buy_list = [] # 交易逻辑:
sell_list = [] # 1 判断卖出,如果当天跌停则减少实际卖出数量
# 2 判断买入:根据实际卖出数量和距离目标持仓数量判断买入数量,如果当天涨停则减少实际买入数量
untradable_list = [] untradable_list = []
target_list = []
# ----- 卖出 ----- # ----- 卖出 -----
# 异常强制卖出 sell_list = []
limit_down_list = [] # 跌停股记录
# 遍历昨日持仓状态:
# 1 记录持仓状态
# 2 获取停牌股列表
# 3 获取异常强制卖出列表
last_position_status = pd.Series()
for stock in last_position.index: for stock in last_position.index:
if stock_status.loc[stock] in [0,2,5,7]: last_position_status.loc[stock] = stock_status.loc[stock]
if last_position_status.loc[stock] in [0,2]:
untradable_list.append(stock) untradable_list.append(stock)
else: else:
if stock in force_exclude: if last_position_status.loc[stock] in [5,7]:
sell_list.append(stock) continue
else:
if stock in force_exclude:
sell_list.append(stock)
force_sell_num = len(sell_list) force_sell_num = len(sell_list)
# 剔除无法交易列表后,按照当日因子反向排名逐个卖出 # 剔除无法交易列表后,按照当日因子反向排名逐个卖出
@ -305,15 +354,21 @@ class Trader(Account):
for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]: for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
if len(sell_list) >= max_sell_num + force_sell_num: if len(sell_list) >= max_sell_num + force_sell_num:
break break
if stock_status.loc[stock] in [0,2,5,7]: if last_position_status.loc[stock] in [0,2]:
continue continue
else: else:
if last_position_status.loc[stock] in [5,7]:
limit_down_list.append(stock)
sell_list.append(stock) sell_list.append(stock)
sell_list = list(set(sell_list)) sell_list = list(set(sell_list))
# 实际卖出列表 = 卖出列表 - 跌停列表
sell_list = list(set(sell_list) - set(limit_down_list))
# ----- 买入 ----- # ----- 买入 -----
buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -321,6 +376,8 @@ class Trader(Account):
limit_up_list = [] # 涨停股记录 limit_up_list = [] # 涨停股记录
max_buy_num = max(0, self.num-len(last_position)+len(sell_list)) max_buy_num = max(0, self.num-len(last_position)+len(sell_list))
for stock in target_list: for stock in target_list:
if len(buy_list) == max_buy_num:
break
if stock in after_sell_list: if stock in after_sell_list:
continue continue
else: else:
@ -332,29 +389,23 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
buy_list.append(stock) buy_list.append(stock)
if len(buy_list) == max_buy_num: buy_list = list(set(buy_list))
break
# 剔除同时在sell_list和buy_list的股票 # 剔除同时在sell_list和buy_list的股票
duplicate_stock = set(sell_list) & set(buy_list) duplicate_stock = set(sell_list) & set(buy_list)
sell_list = list(set(sell_list) - duplicate_stock) sell_list = list(set(sell_list) - duplicate_stock)
buy_list = list(set(buy_list) - duplicate_stock) buy_list = list(set(buy_list) - duplicate_stock)
# 生成下一期持仓 # 生成下一期持仓
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))}) next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
next_position['date'] = date next_position['date'] = date
next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position) next_position['weight'] = self.get_weight(date, self.today_position_ratio, untradable_list+limit_down_list, next_position)
# 剔除无法买入的涨停股,这部分仓位空出
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)] # 剔除无法买入且不在昨日持仓中的涨停股,这部分仓位空出
next_position = next_position[~next_position['stock_code'].isin(list(set(limit_up_list)-set(last_position.index)))]
next_position['margin_trade'] = 0 next_position['margin_trade'] = 0
else: else:
# 如果有杠杆: # 如果有杠杆:
def assign_stock(normal_list, margin_list, margin_needed, stock, status):
if status == 1:
if len(margin_list) < margin_needed:
margin_list.append(stock)
else:
if len(normal_list) < self.num - margin_needed:
normal_list.append(stock)
return normal_list, margin_list
# 计算需要融资融券标的数量 # 计算需要融资融券标的数量
margin_ratio = max(self.today_position_ratio-1, 0) margin_ratio = max(self.today_position_ratio-1, 0)
margin_needed = round(self.num * margin_ratio) margin_needed = round(self.num * margin_ratio)
@ -372,7 +423,6 @@ class Trader(Account):
last_normal_list = [] last_normal_list = []
# ----- 卖出 ----- # ----- 卖出 -----
buy_list = []
sell_list = [] sell_list = []
untradable_list = [] untradable_list = []
# 分别更新融资融券池的和非融资融券池 # 分别更新融资融券池的和非融资融券池
@ -424,8 +474,10 @@ class Trader(Account):
next_normal_list = list(set(last_normal_list) - set(sell_list)) next_normal_list = list(set(last_normal_list) - set(sell_list))
# ----- 买入 ----- # ----- 买入 -----
buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -434,6 +486,8 @@ class Trader(Account):
# 融资融券池的和非融资融券池的分开更新 # 融资融券池的和非融资融券池的分开更新
# 更新融资融券池 # 更新融资融券池
for stock in target_list: for stock in target_list:
if len(next_margin_list) >= margin_needed:
break
if stock in after_sell_list: if stock in after_sell_list:
continue continue
else: else:
@ -446,10 +500,11 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
next_margin_list.append(stock) next_margin_list.append(stock)
if len(next_margin_list) >= margin_needed: next_margin_list = list(set(next_margin_list))
break
# 更新非融资融券池 # 更新非融资融券池
for stock in target_list: for stock in target_list:
if len(next_normal_list) >= self.num - len(next_margin_list):
break
if stock in (set(after_sell_list) | set(next_margin_list)): if stock in (set(after_sell_list) | set(next_margin_list)):
continue continue
else: else:
@ -461,20 +516,21 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
next_normal_list.append(stock) next_normal_list.append(stock)
if len(next_normal_list) >= self.num - len(next_margin_list): next_normal_list = list(set(next_normal_list))
break
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list}) next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
next_position['date'] = date next_position['date'] = date
# 融资融券数量 # 融资融券数量
margin_num = len(next_margin_list) margin_num = len(next_margin_list)
next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position) next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), untradable_list, next_position)
next_position['margin_trade'] = 0 next_position['margin_trade'] = 0
next_position = next_position.set_index(['stock_code']) next_position = next_position.set_index(['stock_code'])
next_position.loc[next_margin_list, 'margin_trade'] = 1 next_position.loc[next_margin_list, 'margin_trade'] = 1
next_position = next_position.reset_index() next_position = next_position.reset_index()
# 剔除无法买入的涨停股,这部分仓位空出 # 剔除无法买入的涨停股,这部分仓位空出
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)] next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
# 检测当前持仓是否可以交易 # 检测当前持仓是否可以交易
frozen_list = [] frozen_list = []
if len(self.position) > 0: if len(self.position) > 0:
@ -496,13 +552,19 @@ class Trader(Account):
buy_list (Iterable[str]): 买入目标 buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标 sell_list (Iterable[str]): 卖出目标
""" """
stock_price = self.today_data[trade_time] if trade_time+'_trade' in self.today_data:
trade_price = self.today_data[trade_time+'_trade']
basic_price = self.today_data[trade_time]
else:
basic_price = self.today_data[trade_time]
trade_price = basic_price
target_price = pd.Series(index=target) target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list)) sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list)) buy_list = list(set(target) & set(buy_list))
target_price.loc[target] = stock_price.get(target, 'price').fillna(0) # 根据交易和非交易标的分别获取目标价格
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1]) target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0]) target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price return target_price
def check_update_status(self, def check_update_status(self,
@ -601,7 +663,8 @@ class Trader(Account):
if cur_pos['weight'].sum() == 0: if cur_pos['weight'].sum() == 0:
pnl = 0 pnl = 0
else: else:
pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum()) cash = max(0, 1 - cur_pos['weight'].sum())
pnl = ((cur_pos['end_weight'].sum() + cash) / (cur_pos['weight'].sum() + cash)) - 1
self.account *= 1+pnl self.account *= 1+pnl
self.account_history = self.account_history.append({ self.account_history = self.account_history.append({
'date': date, 'date': date,
@ -612,6 +675,18 @@ class Trader(Account):
}, ignore_index=True) }, ignore_index=True)
return True return True
@staticmethod
def update_next_weight(position):
    """
    Compute the per-stock weights for the next moment from the closing weights.

    Args:
        position: table exposing a 'weight' column and an 'end_weight'
            column (indexed with ``position[...]``; each column must
            support ``.sum()`` and element-wise division).

    Returns:
        The 'end_weight' column rescaled:
        * when the 'weight' total is at most 1 (with a 1e-5 tolerance),
          it is divided by its own total plus the leftover cash share
          ``max(0, 1 - total weight)``;
        * otherwise it is normalised to sum to 1 and then multiplied by
          the 'weight' total.
    """
    opening_total = position['weight'].sum()
    closing = position['end_weight']
    closing_total = closing.sum()

    # Anything that is not "total <= 1 (+ tolerance)" takes the second path.
    exceeds_full_position = not (opening_total <= 1 + 1e-5)
    if exceeds_full_position:
        # Total weight above 1: rescale so the total stays at its opening level.
        return opening_total * (closing / closing_total)

    # Non-margin (non-financing) case: the unallocated share counts as cash.
    idle_cash = max(0, 1 - opening_total)
    return closing / (closing_total + idle_cash)
def update_signal(self, def update_signal(self,
date:str, date:str,
update_type='rtn'): update_type='rtn'):
@ -630,7 +705,9 @@ class Trader(Account):
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python') self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
if date in self.position_history: if date in self.position_history:
self.position_history.pop(date) self.position_history.pop(date)
# 更新持仓信号 # ----- 更新当日回测数据 ------
# 更新当前日期和持仓信号
self.current_date = date
self.load_data(date, update_type) self.load_data(date, update_type)
# 更新当日持仓比例 # 更新当日持仓比例
if isinstance(self.position_ratio, float): if isinstance(self.position_ratio, float):
@ -646,13 +723,14 @@ class Trader(Account):
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1]) fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
self.current_fee = fee self.current_fee = fee
# 如果当前持仓不空,添加隔夜收益,否则直接买入 # 如果当前持仓不空,添加隔夜收益,否则直接买入
position_fields = ['stock_code','date','weight','margin_trade','open','close','end_weight']
if len(self.position) == 0: if len(self.position) == 0:
cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade']) cur_pos = pd.DataFrame(columns=position_fields)
else: else:
cur_pos = self.position.copy() cur_pos = self.position.copy()
# 冻结列表 # 冻结列表
frozen_list = [] frozen_list = []
# 遍历各个交易时间的信号 # ----- 遍历各个交易时间的信号 -----
for _,trade_time in enumerate(self.signal): for _,trade_time in enumerate(self.signal):
if self.check_update_status(date, trade_time): if self.check_update_status(date, trade_time):
continue continue
@ -663,6 +741,7 @@ class Trader(Account):
factor = self.signal[trade_time].loc[date] factor = self.signal[trade_time].loc[date]
# 获取当前、持仓 # 获取当前、持仓
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor) sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
# 区分回测模式和仓位模式:回测模式会记录收益,仓位模式只记录下一日持仓并结束计算 # 区分回测模式和仓位模式:回测模式会记录收益,仓位模式只记录下一日持仓并结束计算
if update_type == 'position': if update_type == 'position':
self.position_history[date] = next_position.copy() self.position_history[date] = next_position.copy()
@ -674,9 +753,11 @@ class Trader(Account):
# 计算收益 # 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn'] cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position) self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位 # 更新仓位
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum() cur_pos['weight'] = self.update_next_weight(cur_pos)
# 调整权重:买入、卖出、仓位再平衡 # 调整权重:买入、卖出、仓位再平衡
next_position = self.reblance_weight(trade_time, cur_pos, next_position) next_position = self.reblance_weight(trade_time, cur_pos, next_position)
else: else:
@ -691,14 +772,19 @@ class Trader(Account):
# 停牌价格不变 # 停牌价格不变
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open'] cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open'] cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
# 更新当日收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1 cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1) cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy() position_record = cur_pos.copy()
position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum() position_record['end_weight'] = self.update_next_weight(position_record)
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum() self.position_history[date] = position_record.copy()[position_fields]
# 更新当期收盘后个股仓位作为下一期的开盘仓位
cur_pos['weight'] = self.update_next_weight(cur_pos)
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']] next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
next_position['open'] = cur_pos['close'] next_position['open'] = cur_pos['close']
self.update_account(date, trade_time, cur_pos, cur_pos) self.update_account(date, trade_time, cur_pos, cur_pos)
# 记录当前时刻最终持仓和个股权重
self.position = next_position.copy() self.position = next_position.copy()
self.position_history[date] = position_record.copy()
return True return True