Update: 新增mkt和price的买入过滤;

Update: 调整过滤参数名称和指定方式;(#21)
This commit is contained in:
binz 2024-06-26 23:36:58 +08:00
parent d671f820d4
commit 0977e8c9fd
3 changed files with 127 additions and 51 deletions

68
check_funcs.py Normal file
View File

@ -0,0 +1,68 @@
import numpy as np
def check_buy_exclude(buy_exclude):
buy_exclude_dict = dict()
# 检查剔除条件
for cond in ['amt_20', 'list_days', 'mkt', 'price']:
if cond == 'amt_20':
# 20日成交量均值
if cond in buy_exclude:
amt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(amt_exclude) == 1:
buy_exclude_dict[cond] = (amt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required')
if cond == 'list_days':
# 上市时间
if cond in buy_exclude:
list_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(list_exclude) == 1:
buy_exclude_dict[cond] = (list_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1])
else:
raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required')
if cond == 'mkt':
# 市值
if cond in buy_exclude:
mkt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(mkt_exclude) == 1:
buy_exclude_dict[cond] = (mkt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required')
if cond == 'price':
# 价格
if cond in buy_exclude:
price_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(price_exclude) == 1:
buy_exclude_dict[cond] = (price_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1])
else:
raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required')
return buy_exclude_dict

View File

@ -10,8 +10,8 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
'abnormal', 'recession']):
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20',
'opening_info','ipo_days','margin_list','abnormal', 'recession']):
if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else:
@ -34,7 +34,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1)
next_list = []
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else:

104
trader.py
View File

@ -1,5 +1,4 @@
import pandas as pd
import numpy as np
import sys
import os
import copy
@ -8,12 +7,13 @@ from typing import Union, Iterable, Dict
from account import Account
from dataloader import DataLoader
from check_funcs import check_buy_exclude
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
class Trader(Account):
"""
交易类: 用于控制每日交易情况
@ -35,7 +35,7 @@ class Trader(Account):
slippage (tuple): 买入和卖出滑点
commission (float): 佣金
tax (dict): 印花税
exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
force_exclude (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- receesion: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见
@ -50,18 +50,17 @@ class Trader(Account):
data_root:dict={},
tick: bool=False,
weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001),
commission: float=0.0001,
buy_exclude: Dict[str, Union[set,int,float]]={},
force_exclude: list=[],
account: dict={},
tax: dict={
'1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005)
},
exclude_list: list=[],
account: dict={},
**kwargs) -> None:
# 初始化账户
super().__init__(**account)
@ -100,21 +99,6 @@ class Trader(Account):
self.weight = weight
else:
raise ValueError('invalid type for `weight`')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage
if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage
@ -130,17 +114,19 @@ class Trader(Account):
self.tax = tax
else:
raise ValueError('tax should be dict.')
# exclude
if isinstance(exclude_list, list):
self.exclude_list = exclude_list
# buy exclude
self.buy_exclude = check_buy_exclude(buy_exclude)
# force exclude
if isinstance(force_exclude, list):
self.force_exclude = force_exclude
optional_list = ['abnormal', 'recession']
for item in exclude_list:
for item in force_exclude:
if item in optional_list:
pass
else:
raise ValueError(f"Unexpected keyword argument '{item}'")
else:
raise ValueError('exclude_list should be list.')
raise ValueError('force_exclude should be list.')
# data_root
# 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0:
@ -276,36 +262,58 @@ class Trader(Account):
# 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_filter = stock_amt20.copy()
stock_amt_filter.loc[:] = 0
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
stock_amt_filter = stock_amt_filter.sort_index()
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_filter = stock_ipo_days.copy()
stock_ipo_filter.loc[:] = 0
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
exclude_data = dict()
for cond in self.buy_exclude:
if cond == 'amt_20':
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_exclude = stock_amt20.copy()
stock_amt_exclude.loc[:] = 0
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
stock_amt_exclude = stock_amt_exclude.sort_index()
exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
# 剔除列表
# 包含强制列表和普通列表:
# 包含强制剔除列表和买入剔除列表:
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
# 普通列表如果已经持有不会过滤只对新买入的过滤
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤
# 强制过滤列表
exclude_stock = []
for exclude in self.exclude_list:
if exclude == 'abnormal':
for cond in self.force_exclude:
if cond == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if exclude == 'recession':
if cond == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock)
# 普通过滤列表
normal_exclude = []
normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
normal_exclude = list(set(normal_exclude))
# 买入过滤列表
buy_exclude = []
for cond in self.buy_exclude:
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
buy_exclude = list(set(buy_exclude))
# 交易列表
# 仓位判断给与计算误差冗余
@ -360,7 +368,7 @@ class Trader(Account):
buy_list = []
# 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表
@ -469,7 +477,7 @@ class Trader(Account):
buy_list = []
# 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表