From f103ac9d1e617d177e282f7a3451ef21169625e0 Mon Sep 17 00:00:00 2001 From: binz <123@123.com> Date: Sun, 2 Jun 2024 00:25:36 +0800 Subject: [PATCH] =?UTF-8?q?Update:=20exclude=5Flist=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E4=B8=9A=E7=BB=A9=E8=A1=B0=E9=80=80recession=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_handler.py | 5 +++-- dataloader.py | 1 - trader.py | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/data_handler.py b/data_handler.py index e388923..8968ff4 100644 --- a/data_handler.py +++ b/data_handler.py @@ -9,7 +9,8 @@ if __name__ == '__main__': data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data' - for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list','abnormal']): + for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list', + 'abnormal', 'recession']): if f in ['margin_list']: tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) else: @@ -32,7 +33,7 @@ if __name__ == '__main__': # 更新下一日的数据用于筛选 next_date = gft.days_after(df.index.max(), 1) next_list = [] - for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal']): + for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): if f in ['margin_list']: next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) else: diff --git a/dataloader.py b/dataloader.py index e522121..0d115af 100644 --- a/dataloader.py +++ b/dataloader.py @@ -20,7 +20,6 @@ class DataLoader(): - column(str): 查询列 """ res = pd.Series(index=target) - column_map = dict(zip(self.data.index.values, self.data[column].values)) column_type = self.data[column].dtype stock_list = list(set(self.data.index.values) & set(target)) res.loc[stock_list] = self.data.loc[stock_list, column].values diff --git a/trader.py b/trader.py index 1bac6b2..b6c451f 100644 --- a/trader.py +++ b/trader.py @@ -35,7 +35,8 @@ class Trader(Account): tax (dict): 印花税 exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 - abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除 - - report: 财报同比下降50%以上剔除 + - receesion: 财报同比或环比下降50%以上 + - qualified_opinion: 会计保留意见 account (Account): 账户设置,account.Account """ def __init__(self, @@ -147,6 +148,12 @@ class Trader(Account): # exclude if isinstance(exclude_list, list): self.exclude_list = exclude_list + optional_list = ['abnormal', 'recession'] + for item in exclude_list: + if item in optional_list: + pass + else: + raise ValueError(f"Unexpected keyword argument '{item}'") else: raise ValueError('exclude_list should be list.') # data_root @@ -163,7 +170,7 @@ class Trader(Account): continue else: if s not in data_root: - raise ValueError(f'data for signal {s} is not provided') + raise ValueError(f"data for signal {s} is not provided") self.data_root = data_root def load_data(self, @@ -253,6 +260,9 @@ class Trader(Account): if exclude == 'abnormal': stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() + if exclude == 'recession': + stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index() + exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list() exclude_stock = list(set(exclude_stock)) force_exclude = copy.deepcopy(exclude_stock) exclude_stock += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()