Compare commits

..

No commits in common. "main" and "0.0.1" have entirely different histories.
main ... 0.0.1

8 changed files with 139 additions and 358 deletions

View File

@ -2,13 +2,3 @@
1. 基本实现交易流程 1. 基本实现交易流程
2. 新增强制筛选参数 2. 新增强制筛选参数
3. 新增自定义个股权重和仓位权重 3. 新增自定义个股权重和仓位权重
## Version 0.02
1. 优化非满仓情况下现金比例对收益的计算问题
2. 优化自定义权重下的判断逻辑
3. 修复没有满仓有现金比例时的收益计算问题
## Version 0.03
1. 新增分钟价格计算函数
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
3. 增加信号更新函数

View File

@ -1,33 +0,0 @@
import pandas as pd
import numpy as np
def amount_specified(
min_data: pd.DataFrame,
post_adj_factor: pd.DataFrame,
min_amount: float=0.
):
"""
计算指定金额下的平均后复权价格
Args:
min_data (pd.DataFrame): 指定日期分钟数据需至少包含代码(stock_code)分钟(Time)价格(price)成交量(vol)成交金额(amount)
post_adj_factor (pd.DataFrame): 后复权因子数据需至少包含股票后复权因子
min_amount (float): 指定最小金额下的平均价
"""
# 按照指定最小量获取平均价格
stock_amt = min_data.pivot_table(index='Time', columns='stock_code', values='amount').cumsum()
stock_vol = min_data.pivot_table(index='Time', columns='stock_code', values='vol').cumsum()
amount_price = stock_amt / stock_vol
amount_price.iloc[1:] = amount_price.iloc[1:].where(stock_amt <= min_amount, np.nan)
amount_price = amount_price.unstack().reset_index()
amount_price.columns = ['stock_code', 'time', 'price']
amount_price = amount_price.dropna(subset=['price']).drop_duplicates(subset='stock_code', keep='last')
# 计算后复权价格
amount_price = amount_price.merge(post_adj_factor, on=['stock_code'], how='left').dropna(subset=['factor'])
amount_price['open_post'] = amount_price['price'] * amount_price['factor']
return amount_price[['stock_code','open_post']]

View File

@ -2,5 +2,4 @@ from account import Account
from trader import Trader from trader import Trader
from spread_backtest import Spread_Backtest from spread_backtest import Spread_Backtest
__all__ = ['Account', 'Trader', 'Spread_Backtest']
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']

View File

@ -1,68 +0,0 @@
import numpy as np
def check_buy_exclude(buy_exclude):
buy_exclude_dict = dict()
# 检查剔除条件
for cond in ['amt_20', 'list_days', 'mkt', 'price']:
if cond == 'amt_20':
# 20日成交量均值
if cond in buy_exclude:
amt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(amt_exclude) == 1:
buy_exclude_dict[cond] = (amt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required')
if cond == 'list_days':
# 上市时间
if cond in buy_exclude:
list_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(list_exclude) == 1:
buy_exclude_dict[cond] = (list_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1])
else:
raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required')
if cond == 'mkt':
# 市值
if cond in buy_exclude:
mkt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(mkt_exclude) == 1:
buy_exclude_dict[cond] = (mkt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required')
if cond == 'price':
# 价格
if cond in buy_exclude:
price_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(price_exclude) == 1:
buy_exclude_dict[cond] = (price_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1])
else:
raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required')
return buy_exclude_dict

View File

@ -10,8 +10,8 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20', for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
'opening_info','ipo_days','margin_list','abnormal', 'recession']): 'abnormal', 'recession']):
if f in ['margin_list']: if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else: else:
@ -34,7 +34,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选 # 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1) next_date = gft.days_after(df.index.max(), 1)
next_list = [] next_list = []
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
if f in ['margin_list']: if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else: else:

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from typing import Union, Iterable from typing import Union, Iterable, Optional, Dict
class DataLoader(): class DataLoader():
""" """

View File

@ -5,7 +5,6 @@ import pandas as pd
import numpy as np import numpy as np
import time import time
from trader import Trader from trader import Trader
from typing import Union, Dict
from rich import print as rprint from rich import print as rprint
from rich.table import Table from rich.table import Table
@ -61,26 +60,6 @@ class SpreadBacktest():
else: else:
self.trader.update_signal(date, update_type='position') self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
trade_time: str,
new_signal: pd.DataFrame):
"""
更新信号因子
Args:
trade_time (str): 信号时间
new_signal (pd.DataFrame): 新的更新信号
"""
self.trader.signal[trade_time] = new_signal
def update_interval(self,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
):
# 更新interval和weight
# 如果interval为固定比例则更新
self.trader.init_interval(interval)
@property @property
def account_history(self): def account_history(self):
return self.trader.account_history return self.trader.account_history

312
trader.py
View File

@ -1,4 +1,5 @@
import pandas as pd import pandas as pd
import numpy as np
import sys import sys
import os import os
import copy import copy
@ -7,7 +8,6 @@ from typing import Union, Iterable, Dict
from account import Account from account import Account
from dataloader import DataLoader from dataloader import DataLoader
from check_funcs import check_buy_exclude
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools from db_tushare import get_factor_tools
@ -20,8 +20,7 @@ class Trader(Account):
Args: Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行 signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (dict[str, (int, tuple, pd.Series)]): interval (int, tuple, pd.Series): 交易间隔
交易间隔
num (int): 持仓数量 num (int): 持仓数量
ascending (bool): 因子方向 ascending (bool): 因子方向
with_st (bool): 是否包含st with_st (bool): 是否包含st
@ -35,7 +34,7 @@ class Trader(Account):
slippage (tuple): 买入和卖出滑点 slippage (tuple): 买入和卖出滑点
commission (float): 佣金 commission (float): 佣金
tax (dict): 印花税 tax (dict): 印花税
force_exclude (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓 exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除 - abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- receesion: 财报同比或环比下降50%以上 - receesion: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见 - qualified_opinion: 会计保留意见
@ -43,24 +42,25 @@ class Trader(Account):
""" """
def __init__(self, def __init__(self,
signal: Dict[str, pd.DataFrame]=None, signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]={}, interval: Dict[str, Union[int,tuple,pd.Series]]=1,
num: int=100, num: int=100,
ascending: bool=False, ascending: bool=False,
with_st: bool=False, with_st: bool=False,
data_root:dict={}, data_root:dict={},
tick: bool=False, tick: bool=False,
weight: str='avg', weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001), slippage :tuple=(0.001,0.001),
commission: float=0.0001, commission: float=0.0001,
buy_exclude: Dict[str, Union[set,int,float]]={},
force_exclude: list=[],
account: dict={},
tax: dict={ tax: dict={
'1990-01-01': (0.001,0.001), '1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001), '2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001), '2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005) '2023-08-28': (0, 0.0005)
}, },
exclude_list: list=[],
account: dict={},
**kwargs) -> None: **kwargs) -> None:
# 初始化账户 # 初始化账户
super().__init__(**account) super().__init__(**account)
@ -78,7 +78,24 @@ class Trader(Account):
if len(kwargs) > 0: if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval # interval
self.init_interval(interval) self.interval = []
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num # num
if isinstance(num, int): if isinstance(num, int):
self.num = int(num) self.num = int(num)
@ -99,6 +116,21 @@ class Trader(Account):
self.weight = weight self.weight = weight
else: else:
raise ValueError('invalid type for `weight`') raise ValueError('invalid type for `weight`')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage # slippage
if isinstance(slippage, tuple) and len(slippage) == 2: if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage self.slippage = slippage
@ -114,19 +146,17 @@ class Trader(Account):
self.tax = tax self.tax = tax
else: else:
raise ValueError('tax should be dict.') raise ValueError('tax should be dict.')
# buy exclude # exclude
self.buy_exclude = check_buy_exclude(buy_exclude) if isinstance(exclude_list, list):
# force exclude self.exclude_list = exclude_list
if isinstance(force_exclude, list):
self.force_exclude = force_exclude
optional_list = ['abnormal', 'recession'] optional_list = ['abnormal', 'recession']
for item in force_exclude: for item in exclude_list:
if item in optional_list: if item in optional_list:
pass pass
else: else:
raise ValueError(f"Unexpected keyword argument '{item}'") raise ValueError(f"Unexpected keyword argument '{item}'")
else: else:
raise ValueError('force_exclude should be list.') raise ValueError('exclude_list should be list.')
# data_root # data_root
# 至少包含basic data路径open信号默认使用basic_data # 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0: if len(data_root) <= 0:
@ -144,29 +174,6 @@ class Trader(Account):
raise ValueError(f"data for signal {s} is not provided") raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root self.data_root = data_root
def init_interval(self, interval):
"""
初始化interval
"""
interval_list = []
for s in self.signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
interval_list.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(interval_list)
def load_data(self, def load_data(self,
date: str, date: str,
update_type: str='rtn'): update_type: str='rtn'):
@ -180,19 +187,14 @@ class Trader(Account):
""" """
self.today_data = dict() self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv')) self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
# 遍历从data_root中读取当日数据
for path in self.data_root:
if path == 'basic':
continue
else:
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
if update_type == 'position': if update_type == 'position':
return True return True
for s in self.signal: for s in self.signal:
if s == 'open': if s == 'open':
if s in self.data_root: if s in self.data_root:
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'})) continue
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'})) else:
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else: else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv')) self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal: if 'close' in self.signal:
@ -207,12 +209,11 @@ class Trader(Account):
# 可执行日期 # 可执行日期
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index() self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
def get_weight(self, date, account_weight, untradable_list, next_position): def get_weight(self, date, account_weight, next_position):
""" """
计算个股仓位 计算个股仓位
Args: Args:
untradable_list (list): 无法交易列表
account_weight (float): 总权重即当前持仓比例 account_weight (float): 总权重即当前持仓比例
""" """
if isinstance(self.weight, str): if isinstance(self.weight, str):
@ -220,26 +221,14 @@ class Trader(Account):
return account_weight / len(next_position) return account_weight / len(next_position)
if isinstance(self.weight, pd.DataFrame): if isinstance(self.weight, pd.DataFrame):
date_weight = self.weight.loc[date].dropna().sort_index() date_weight = self.weight.loc[date].dropna().sort_index()
# untradable_list不要求指定权重用昨日权重填充
weight_list = pd.Series(index=next_position['stock_code'])
try: try:
# 填充untradable_list权重 weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
if len(untradable_list) > 0: if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松
weight_list.loc[untradable_list] = self.position.set_index('stock_code').loc[untradable_list, 'weight']
except Exception:
raise ValueError('not found stock weight for untradable stocks in last position.')
try:
# 获取tradable_list权重并对untradable_list占据的仓位进行调整
tradable_list = list(set(next_position['stock_code']) - set(untradable_list))
# 剔除untradable_list仓位后剩余持仓根据自定义权重分配
weight_list.loc[tradable_list] = date_weight.loc[tradable_list].values / date_weight.loc[tradable_list].sum() * (account_weight - weight_list.loc[untradable_list].sum())
weight_list = weight_list.values
if sum(weight_list) > 1 + 1e-5: # 防止数据精度的影响,给与一定的宽松
raise Exception(f"total weight of {date} is larger then 1.") raise Exception(f"total weight of {date} is larger then 1.")
weight_list = account_weight * weight_list
return weight_list return weight_list
except Exception as e: except Exception:
print(e) raise ValueError(f'not found stock weight in {date}')
raise ValueError(f'not found specified stock weight in {date}')
def get_next_position(self, date, factor): def get_next_position(self, date, factor):
""" """
@ -252,94 +241,56 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量 # 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整 # 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓 # 这样能实现在因子值不足时也正常换仓
try:
max_sell_num = self.interval.loc[date]*len(last_position) max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else: else:
last_position = pd.Series() last_position = pd.Series()
max_sell_num = self.num max_sell_num = self.num
# 获取用于筛选的数据 # 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index() stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
exclude_data = dict()
for cond in self.buy_exclude:
if cond == 'amt_20':
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index() stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_exclude = stock_amt20.copy() stock_amt_filter = stock_amt20.copy()
stock_amt_exclude.loc[:] = 0 stock_amt_filter.loc[:] = 0
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1 stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
stock_amt_exclude = stock_amt_exclude.sort_index() stock_amt_filter = stock_amt_filter.sort_index()
exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index() stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy() stock_ipo_filter = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0 stock_ipo_filter.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1 stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
# 剔除列表 # 剔除列表
# 包含强制剔除列表和买入剔除列表: # 包含强制列表和普通列表:
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中 # 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤 # 普通列表如果已经持有不会过滤只对新买入的过滤
# 强制过滤列表 # 强制过滤列表
exclude_stock = [] exclude_stock = []
for cond in self.force_exclude: for exclude in self.exclude_list:
if cond == 'abnormal': if exclude == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if cond == 'recession': if exclude == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index() stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list() exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock)) exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock) force_exclude = copy.deepcopy(exclude_stock)
# 买入过滤列表 # 普通过滤列表
buy_exclude = [] normal_exclude = []
for cond in self.buy_exclude: normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list() normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
buy_exclude = list(set(buy_exclude)) normal_exclude = list(set(normal_exclude))
# 交易列表 # 交易列表
# 仓位判断给与计算误差冗余 if self.today_position_ratio <= 1.0:
if self.today_position_ratio <= 1.0 + 1e-5:
# 如果没有杠杆: # 如果没有杠杆:
# 交易逻辑: buy_list = []
# 1 判断卖出,如果当天跌停则减少实际卖出数量
# 2 判断买入:根据实际卖出数量和距离目标持仓数量判断买入数量,如果当天涨停则减少实际买入数量
untradable_list = []
# ----- 卖出 -----
sell_list = [] sell_list = []
limit_down_list = [] # 跌停股记录 untradable_list = []
target_list = []
# 遍历昨日持仓状态: # ----- 卖出 -----
# 1 记录持仓状态 # 异常强制卖出
# 2 获取停牌股列表
# 3 获取异常强制卖出列表
last_position_status = pd.Series()
for stock in last_position.index: for stock in last_position.index:
last_position_status.loc[stock] = stock_status.loc[stock] if stock_status.loc[stock] in [0,2,5,7]:
if last_position_status.loc[stock] in [0,2]:
untradable_list.append(stock) untradable_list.append(stock)
else:
if last_position_status.loc[stock] in [5,7]:
continue
else: else:
if stock in force_exclude: if stock in force_exclude:
sell_list.append(stock) sell_list.append(stock)
@ -354,21 +305,15 @@ class Trader(Account):
for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]: for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
if len(sell_list) >= max_sell_num + force_sell_num: if len(sell_list) >= max_sell_num + force_sell_num:
break break
if last_position_status.loc[stock] in [0,2]: if stock_status.loc[stock] in [0,2,5,7]:
continue continue
else: else:
if last_position_status.loc[stock] in [5,7]:
limit_down_list.append(stock)
sell_list.append(stock) sell_list.append(stock)
sell_list = list(set(sell_list)) sell_list = list(set(sell_list))
# 实际卖出列表 = 卖出列表 - 跌停列表
sell_list = list(set(sell_list) - set(limit_down_list))
# ----- 买入 ----- # ----- 买入 -----
buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -376,8 +321,6 @@ class Trader(Account):
limit_up_list = [] # 涨停股记录 limit_up_list = [] # 涨停股记录
max_buy_num = max(0, self.num-len(last_position)+len(sell_list)) max_buy_num = max(0, self.num-len(last_position)+len(sell_list))
for stock in target_list: for stock in target_list:
if len(buy_list) == max_buy_num:
break
if stock in after_sell_list: if stock in after_sell_list:
continue continue
else: else:
@ -389,23 +332,29 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
buy_list.append(stock) buy_list.append(stock)
buy_list = list(set(buy_list)) if len(buy_list) == max_buy_num:
break
# 剔除同时在sell_list和buy_list的股票 # 剔除同时在sell_list和buy_list的股票
duplicate_stock = set(sell_list) & set(buy_list) duplicate_stock = set(sell_list) & set(buy_list)
sell_list = list(set(sell_list) - duplicate_stock) sell_list = list(set(sell_list) - duplicate_stock)
buy_list = list(set(buy_list) - duplicate_stock) buy_list = list(set(buy_list) - duplicate_stock)
# 生成下一期持仓 # 生成下一期持仓
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))}) next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
next_position['date'] = date next_position['date'] = date
next_position['weight'] = self.get_weight(date, self.today_position_ratio, untradable_list+limit_down_list, next_position) next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position)
# 剔除无法买入的涨停股,这部分仓位空出
# 剔除无法买入且不在昨日持仓中的涨停股,这部分仓位空出 next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
next_position = next_position[~next_position['stock_code'].isin(list(set(limit_up_list)-set(last_position.index)))]
next_position['margin_trade'] = 0 next_position['margin_trade'] = 0
else: else:
# 如果有杠杆: # 如果有杠杆:
def assign_stock(normal_list, margin_list, margin_needed, stock, status):
if status == 1:
if len(margin_list) < margin_needed:
margin_list.append(stock)
else:
if len(normal_list) < self.num - margin_needed:
normal_list.append(stock)
return normal_list, margin_list
# 计算需要融资融券标的数量 # 计算需要融资融券标的数量
margin_ratio = max(self.today_position_ratio-1, 0) margin_ratio = max(self.today_position_ratio-1, 0)
margin_needed = round(self.num * margin_ratio) margin_needed = round(self.num * margin_ratio)
@ -423,6 +372,7 @@ class Trader(Account):
last_normal_list = [] last_normal_list = []
# ----- 卖出 ----- # ----- 卖出 -----
buy_list = []
sell_list = [] sell_list = []
untradable_list = [] untradable_list = []
# 分别更新融资融券池的和非融资融券池 # 分别更新融资融券池的和非融资融券池
@ -474,10 +424,8 @@ class Trader(Account):
next_normal_list = list(set(last_normal_list) - set(sell_list)) next_normal_list = list(set(last_normal_list) - set(sell_list))
# ----- 买入 ----- # ----- 买入 -----
buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -486,8 +434,6 @@ class Trader(Account):
# 融资融券池的和非融资融券池的分开更新 # 融资融券池的和非融资融券池的分开更新
# 更新融资融券池 # 更新融资融券池
for stock in target_list: for stock in target_list:
if len(next_margin_list) >= margin_needed:
break
if stock in after_sell_list: if stock in after_sell_list:
continue continue
else: else:
@ -500,11 +446,10 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
next_margin_list.append(stock) next_margin_list.append(stock)
next_margin_list = list(set(next_margin_list)) if len(next_margin_list) >= margin_needed:
break
# 更新非融资融券池 # 更新非融资融券池
for stock in target_list: for stock in target_list:
if len(next_normal_list) >= self.num - len(next_margin_list):
break
if stock in (set(after_sell_list) | set(next_margin_list)): if stock in (set(after_sell_list) | set(next_margin_list)):
continue continue
else: else:
@ -516,21 +461,20 @@ class Trader(Account):
if stock_status.loc[stock] in [4,6]: if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock) limit_up_list.append(stock)
next_normal_list.append(stock) next_normal_list.append(stock)
next_normal_list = list(set(next_normal_list)) if len(next_normal_list) >= self.num - len(next_margin_list):
break
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list}) next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
next_position['date'] = date next_position['date'] = date
# 融资融券数量 # 融资融券数量
margin_num = len(next_margin_list) margin_num = len(next_margin_list)
next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), untradable_list, next_position) next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position)
next_position['margin_trade'] = 0 next_position['margin_trade'] = 0
next_position = next_position.set_index(['stock_code']) next_position = next_position.set_index(['stock_code'])
next_position.loc[next_margin_list, 'margin_trade'] = 1 next_position.loc[next_margin_list, 'margin_trade'] = 1
next_position = next_position.reset_index() next_position = next_position.reset_index()
# 剔除无法买入的涨停股,这部分仓位空出 # 剔除无法买入的涨停股,这部分仓位空出
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)] next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
# 检测当前持仓是否可以交易 # 检测当前持仓是否可以交易
frozen_list = [] frozen_list = []
if len(self.position) > 0: if len(self.position) > 0:
@ -552,19 +496,13 @@ class Trader(Account):
buy_list (Iterable[str]): 买入目标 buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标 sell_list (Iterable[str]): 卖出目标
""" """
if trade_time+'_trade' in self.today_data: stock_price = self.today_data[trade_time]
trade_price = self.today_data[trade_time+'_trade']
basic_price = self.today_data[trade_time]
else:
basic_price = self.today_data[trade_time]
trade_price = basic_price
target_price = pd.Series(index=target) target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list)) sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list)) buy_list = list(set(target) & set(buy_list))
# 根据交易和非交易标的分别获取目标价格 target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[target] = basic_price.get(target, 'price').fillna(0) target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1]) target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price return target_price
def check_update_status(self, def check_update_status(self,
@ -663,8 +601,7 @@ class Trader(Account):
if cur_pos['weight'].sum() == 0: if cur_pos['weight'].sum() == 0:
pnl = 0 pnl = 0
else: else:
cash = max(0, 1 - cur_pos['weight'].sum()) pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum())
pnl = ((cur_pos['end_weight'].sum() + cash) / (cur_pos['weight'].sum() + cash)) - 1
self.account *= 1+pnl self.account *= 1+pnl
self.account_history = self.account_history.append({ self.account_history = self.account_history.append({
'date': date, 'date': date,
@ -675,18 +612,6 @@ class Trader(Account):
}, ignore_index=True) }, ignore_index=True)
return True return True
@staticmethod
def update_next_weight(position):
"""
根据收盘权重计算下一时刻新的个股权重
"""
if position['weight'].sum() <= 1 + 1e-5:
# 非融资情况
cash = max(0, 1 - position['weight'].sum())
return position['end_weight'] / (position['end_weight'].sum() + cash)
else:
return position['weight'].sum() * (position['end_weight'] / position['end_weight'].sum())
def update_signal(self, def update_signal(self,
date:str, date:str,
update_type='rtn'): update_type='rtn'):
@ -705,9 +630,7 @@ class Trader(Account):
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python') self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
if date in self.position_history: if date in self.position_history:
self.position_history.pop(date) self.position_history.pop(date)
# ----- 更新当日回测数据 ------ # 更新持仓信号
# 更新当前日期和持仓信号
self.current_date = date
self.load_data(date, update_type) self.load_data(date, update_type)
# 更新当日持仓比例 # 更新当日持仓比例
if isinstance(self.position_ratio, float): if isinstance(self.position_ratio, float):
@ -723,14 +646,13 @@ class Trader(Account):
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1]) fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
self.current_fee = fee self.current_fee = fee
# 如果当前持仓不空,添加隔夜收益,否则直接买入 # 如果当前持仓不空,添加隔夜收益,否则直接买入
position_fields = ['stock_code','date','weight','margin_trade','open','close','end_weight']
if len(self.position) == 0: if len(self.position) == 0:
cur_pos = pd.DataFrame(columns=position_fields) cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
else: else:
cur_pos = self.position.copy() cur_pos = self.position.copy()
# 冻结列表 # 冻结列表
frozen_list = [] frozen_list = []
# ----- 遍历各个交易时间的信号 ----- # 遍历各个交易时间的信号
for _,trade_time in enumerate(self.signal): for _,trade_time in enumerate(self.signal):
if self.check_update_status(date, trade_time): if self.check_update_status(date, trade_time):
continue continue
@ -741,7 +663,6 @@ class Trader(Account):
factor = self.signal[trade_time].loc[date] factor = self.signal[trade_time].loc[date]
# 获取当前、持仓 # 获取当前、持仓
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor) sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
# 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算 # 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算
if update_type == 'position': if update_type == 'position':
self.position_history[date] = next_position.copy() self.position_history[date] = next_position.copy()
@ -753,11 +674,9 @@ class Trader(Account):
# 计算收益 # 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn'] cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position) self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位 # 更新仓位
cur_pos['weight'] = self.update_next_weight(cur_pos) cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
# 调整权重:买入、卖出、仓位再平衡 # 调整权重:买入、卖出、仓位再平衡
next_position = self.reblance_weight(trade_time, cur_pos, next_position) next_position = self.reblance_weight(trade_time, cur_pos, next_position)
else: else:
@ -772,19 +691,14 @@ class Trader(Account):
# 停牌价格不变 # 停牌价格不变
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open'] cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open'] cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
# 更新当日收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1 cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1) cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy() position_record = cur_pos.copy()
position_record['end_weight'] = self.update_next_weight(position_record) position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
self.position_history[date] = position_record.copy()[position_fields] cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
# 更新当期收盘后个股仓位作为下一期的开盘仓位
cur_pos['weight'] = self.update_next_weight(cur_pos)
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']] next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
next_position['open'] = cur_pos['close'] next_position['open'] = cur_pos['close']
self.update_account(date, trade_time, cur_pos, cur_pos) self.update_account(date, trade_time, cur_pos, cur_pos)
# 记录当前时刻最终持仓和个股权重
self.position = next_position.copy() self.position = next_position.copy()
self.position_history[date] = position_record.copy()
return True return True