diff --git a/check_funcs.py b/check_funcs.py new file mode 100644 index 0000000..20b8d7b --- /dev/null +++ b/check_funcs.py @@ -0,0 +1,68 @@ +import numpy as np + + +def check_buy_exclude(buy_exclude): + buy_exclude_dict = dict() + # 检查剔除条件 + for cond in ['amt_20', 'list_days', 'mkt', 'price']: + if cond == 'amt_20': + # 20日成交量均值 + if cond in buy_exclude: + amt_exclude = buy_exclude[cond] + else: + buy_exclude[cond] = (0, np.inf) + continue + if isinstance(buy_exclude[cond], (set,list,tuple)): + if len(amt_exclude) == 1: + buy_exclude_dict[cond] = (amt_exclude[0], np.inf) + else: + buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1]) + else: + raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required') + + if cond == 'list_days': + # 上市时间 + if cond in buy_exclude: + list_exclude = buy_exclude[cond] + else: + buy_exclude[cond] = (0, np.inf) + continue + if isinstance(buy_exclude[cond], (set,list,tuple)): + if len(list_exclude) == 1: + buy_exclude_dict[cond] = (list_exclude[0], np.inf) + else: + buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1]) + else: + raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required') + + if cond == 'mkt': + # 市值 + if cond in buy_exclude: + mkt_exclude = buy_exclude[cond] + else: + buy_exclude[cond] = (0, np.inf) + continue + if isinstance(buy_exclude[cond], (set,list,tuple)): + if len(mkt_exclude) == 1: + buy_exclude_dict[cond] = (mkt_exclude[0], np.inf) + else: + buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1]) + else: + raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required') + + if cond == 'price': + # 价格 + if cond in buy_exclude: + price_exclude = buy_exclude[cond] + else: + buy_exclude[cond] = (0, np.inf) + continue + if isinstance(buy_exclude[cond], (set,list,tuple)): + if len(price_exclude) == 1: + buy_exclude_dict[cond] = (price_exclude[0], np.inf) + else: + buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1]) + else: + raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required') + + return buy_exclude_dict \ No newline at end of file diff --git a/data_handler.py b/data_handler.py index 00c1b02..6b0f03d 100644 --- a/data_handler.py +++ b/data_handler.py @@ -10,8 +10,8 @@ if __name__ == '__main__': data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data' - for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list', - 'abnormal', 'recession']): + for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20', + 'opening_info','ipo_days','margin_list','abnormal', 'recession']): if f in ['margin_list']: tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) else: @@ -34,7 +34,7 @@ if __name__ == '__main__': # 更新下一日的数据用于筛选 next_date = gft.days_after(df.index.max(), 1) next_list = [] - for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): + for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): if f in ['margin_list']: next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) else: diff --git a/trader.py b/trader.py index dfd9a01..4465cc7 100644 --- a/trader.py +++ b/trader.py @@ -1,5 +1,4 @@ import pandas as pd -import numpy as np import sys import os import copy @@ -8,12 +7,13 @@ from typing import Union, Iterable, Dict from account import Account from dataloader import DataLoader +from check_funcs import check_buy_exclude sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") from db_tushare import get_factor_tools gft = get_factor_tools() - + class Trader(Account): """ 交易类: 用于控制每日交易情况 @@ -35,7 +35,7 @@ class Trader(Account): slippage (tuple): 买入和卖出滑点 commission (float): 佣金 tax (dict): 印花税 - exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 + force_exclude (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 - abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除 - receesion: 财报同比或环比下降50%以上 - qualified_opinion: 会计保留意见 @@ -50,18 +50,17 @@ class Trader(Account): data_root:dict={}, tick: bool=False, weight: str='avg', - amt_filter: set=(0,), - ipo_days: int=20, slippage :tuple=(0.001,0.001), commission: float=0.0001, + buy_exclude: Dict[str, Union[set,int,float]]={}, + force_exclude: list=[], + account: dict={}, tax: dict={ '1990-01-01': (0.001,0.001), '2008-04-24': (0.001,0.001), '2008-09-19': (0, 0.001), '2023-08-28': (0, 0.0005) }, - exclude_list: list=[], - account: dict={}, **kwargs) -> None: # 初始化账户 super().__init__(**account) @@ -100,21 +99,6 @@ class Trader(Account): self.weight = weight else: raise ValueError('invalid type for `weight`') - # amt_filter - if isinstance(amt_filter, (set,list,tuple)): - if len(amt_filter) == 1: - self.amt_filter_min = amt_filter[0] - self.amt_filter_max = np.inf - else: - self.amt_filter_min = amt_filter[0] - self.amt_filter_max = amt_filter[1] - else: - raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required') - # ipo_days - if isinstance(ipo_days, int): - self.ipo_days = ipo_days - else: - raise Exception('wrong type for ipo_days, `int` is required') # slippage if isinstance(slippage, tuple) and len(slippage) == 2: self.slippage = slippage @@ -130,17 +114,19 @@ class Trader(Account): self.tax = tax else: raise ValueError('tax should be dict.') - # exclude - if isinstance(exclude_list, list): - self.exclude_list = exclude_list + # buy exclude + self.buy_exclude = check_buy_exclude(buy_exclude) + # force exclude + if isinstance(force_exclude, list): + self.force_exclude = force_exclude optional_list = ['abnormal', 'recession'] - for item in exclude_list: + for item in force_exclude: if item in optional_list: pass else: raise ValueError(f"Unexpected keyword argument '{item}'") else: - raise ValueError('exclude_list should be list.') + raise ValueError('force_exclude should be list.') # data_root # 至少包含basic data路径,open信号默认使用basic_data if len(data_root) <= 0: @@ -276,36 +262,58 @@ class Trader(Account): # 获取用于筛选的数据 stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index() - stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index() - stock_amt_filter = stock_amt20.copy() - stock_amt_filter.loc[:] = 0 - stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1 - stock_amt_filter = stock_amt_filter.sort_index() - stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index() - stock_ipo_filter = stock_ipo_days.copy() - stock_ipo_filter.loc[:] = 0 - stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1 + exclude_data = dict() + for cond in self.buy_exclude: + if cond == 'amt_20': + stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index() + stock_amt_exclude = stock_amt20.copy() + stock_amt_exclude.loc[:] = 0 + stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1 + stock_amt_exclude = stock_amt_exclude.sort_index() + exclude_data[cond] = stock_amt_exclude + if cond == 'list_days': + stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index() + stock_ipo_exclude = stock_ipo_days.copy() + stock_ipo_exclude.loc[:] = 0 + stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1 + stock_ipo_exclude = stock_ipo_exclude.sort_index() + exclude_data[cond] = stock_ipo_exclude + if cond == 'mkt': + stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index() + stock_size_exclude = stock_size.copy() + stock_size_exclude.loc[:] = 0 + stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1 + stock_size_exclude = stock_size_exclude.sort_index() + exclude_data[cond] = stock_size_exclude + if cond == 'price': + stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index() + stock_price_exclude = stock_price.copy() + stock_price_exclude.loc[:] = 0 + stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1 + stock_price_exclude = stock_price_exclude.sort_index() + exclude_data[cond] = stock_price_exclude # 剔除列表 - # 包含强制列表和普通列表: + # 包含强制剔除列表和买入剔除列表: # 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中 - # 普通列表如果已经持有不会过滤只对新买入的过滤 + # 买入剔除列表如果已经持有不会过滤只对新买入的过滤 + # 强制过滤列表 exclude_stock = [] - for exclude in self.exclude_list: - if exclude == 'abnormal': + for cond in self.force_exclude: + if cond == 'abnormal': stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() - if exclude == 'recession': + if cond == 'recession': stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index() exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list() exclude_stock = list(set(exclude_stock)) force_exclude = copy.deepcopy(exclude_stock) - # 普通过滤列表 - normal_exclude = [] - normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list() - normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list() - normal_exclude = list(set(normal_exclude)) + # 买入过滤列表 + buy_exclude = [] + for cond in self.buy_exclude: + buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list() + buy_exclude = list(set(buy_exclude)) # 交易列表 # 仓位判断给与计算误差冗余 @@ -360,7 +368,7 @@ class Trader(Account): buy_list = [] # 剔除过滤条件后可买入列表 - after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude)) + after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() # 更新卖出后的持仓列表 @@ -469,7 +477,7 @@ class Trader(Account): buy_list = [] # 剔除过滤条件后可买入列表 - after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude)) + after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() # 更新卖出后的持仓列表