diff --git a/account.py b/account.py index 880c099..d23ca97 100644 --- a/account.py +++ b/account.py @@ -39,8 +39,8 @@ class Account(): if isinstance(leverage, (int,float)): if leverage > 1.5: raise ValueError('leverage should less than 1.5') - if leverage > 1: - warnings.warn('leverage(杠杆率)大于1时,优先满足杠杆,会出现高于指定换手率的情况!', category=CustomWarning) + # if leverage > 1: + # warnings.warn('leverage(杠杆率)大于1时,优先满足杠杆,会出现高于指定换手率的情况!', category=CustomWarning) else: raise ValueError('leverage should be int or float') self.leverage = leverage diff --git a/spread_backtest.py b/spread_backtest.py index fd5c7b1..1b44473 100644 --- a/spread_backtest.py +++ b/spread_backtest.py @@ -72,6 +72,7 @@ class Spread_Backtest(): else: year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1 year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / 244)) - 1 + year_rtn.loc['Avg.Turnover'] = self.trader.account_history.groupby('date')['turnover'].sum().mean() year_rtn = year_rtn.applymap(lambda x: '{:.2%}'.format(x)) year_rtn = year_rtn.reset_index() year_rtn.columns = ['Year', 'Rtn'] diff --git a/trader.py b/trader.py index f2cfcd6..58eef0c 100644 --- a/trader.py +++ b/trader.py @@ -1,6 +1,7 @@ import pandas as pd import numpy as np import sys, os +import copy sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") from db_tushare import get_factor_tools gft = get_factor_tools() @@ -55,6 +56,7 @@ class Trader(Account): '2008-09-19': (0, 0.001), '2023-08-28': (0, 0.0005) }, + exclude_list: list=[], **kwargs) -> None: # 初始化账户 super().__init__(**kwargs.get('account', {})) @@ -138,6 +140,11 @@ class Trader(Account): self.tax = tax else: raise ValueError('tax should be dict.') + # exclude + if isinstance(exclude_list, list): + self.exclude_list = exclude_list + else: + raise ValueError('exclude_list should be list.') # data_root # 至少包含basic data路径,open信号默认使用basic_data if len(data_root) <= 0: @@ -229,7 +236,16 @@ class Trader(Account): stock_ipo_filter.loc[:] = 0 stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1 # 剔除列表 - + exclude_stock = [] + for exclude in self.exclude_list: + if exclude == 'abnormal': + stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() + exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() + exclude_stock = list(set(exclude_stock)) + force_exclude = copy.deepcopy(exclude_stock) + exclude_stock += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list() + exclude_stock += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list() + exclude_stock = list(set(exclude_stock)) # 交易列表 if self.leverage <= 1.0: # 获取当前时间目标列表和冻结(无法交易)列表: @@ -239,13 +255,12 @@ class Trader(Account): # 如果停牌或者跌停继续持有 if stock_status.loc[stock] in [0,2]: target_list.append(stock) - for stock in factor.dropna().sort_values(ascending=self.ascending).index.values: + # 剔除过滤条件后 + after_filter_list = list(set(factor.index) - set(exclude_stock)) + for stock in factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.values: # 目标持仓数量达到持仓数量则中止 if len(target_list) == self.num: break - # 成交量、上市时间过滤 - if (stock_amt_filter.loc[stock] != 1) or (stock_ipo_filter.loc[stock] != 1): - continue # ST过滤 if self.with_st: if stock_status.loc[stock] in [0,2,5,7]: @@ -257,7 +272,6 @@ class Trader(Account): # 非ST if stock_status.loc[stock] in [3,6]: target_list.append(stock) - # 如果停牌或者跌停继续持有 if stock_status.loc[stock] in [0,2,7]: if stock in last_position.index: @@ -268,12 +282,19 @@ class Trader(Account): buy_list = [] sell_list = [] # ----- 卖出 ----- + # 异常强制卖出 + for stock in last_position.index: + if stock in force_exclude: + sell_list.append(stock) + force_sell_num = len(sell_list) # 按照反向排名逐个卖出 if self.ascending: factor = factor.fillna(factor.max()+1) else: factor = factor.fillna(factor.min()-1) - for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]: + for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]: + if len(sell_list) >= max_sell_num + force_sell_num: + break if stock in target_list: continue else: @@ -281,9 +302,7 @@ class Trader(Account): continue else: sell_list.append(stock) - if len(sell_list) >= max_sell_num: - break - + sell_list = list(set(sell_list)) # ----- 买入 ----- # 卖出后持仓列表 after_sell_list = set(last_position.index) - set(sell_list) @@ -329,8 +348,9 @@ class Trader(Account): # 如果停牌或者跌停继续持有 if stock_status.loc[stock] in [0,2]: normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) - - for stock in factor.dropna().sort_values(ascending=self.ascending).index.values: + # 剔除过滤条件后 + after_filter_list = list(set(factor.index) - set(exclude_stock)) + for stock in factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.values: if stock_amt_filter.loc[stock] != 1: continue if self.with_st: @@ -360,7 +380,14 @@ class Trader(Account): last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list() else: last_margin_list = [] + # 异常强制卖出 + for stock in last_margin_list: + if stock in force_exclude: + sell_list.append(stock) + force_sell_num = len(sell_list) for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]: + if len(sell_list) >= int(max_sell_num * margin_ratio) + force_sell_num + 1: + break if stock in normal_list: continue else: @@ -368,15 +395,21 @@ class Trader(Account): continue else: sell_list.append(stock) - if len(sell_list) >= int(max_sell_num * margin_ratio) + 1: - break + sell_list = list(set(sell_list)) next_margin_list = list(set(last_margin_list) - set(sell_list)) # 更新非融资融券池 if len(last_position) > 0: last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list() else: last_normal_list = [] + # 异常强制卖出 + for stock in last_normal_list: + if stock in force_exclude: + sell_list.append(stock) + force_sell_num += 1 for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]: + if len(sell_list) >= max_sell_num + force_sell_num: + break if stock in normal_list: continue else: @@ -384,8 +417,7 @@ class Trader(Account): continue else: sell_list.append(stock) - if len(sell_list) >= max_sell_num: - break + sell_list = list(set(sell_list)) next_normal_list = list(set(last_normal_list) - set(sell_list)) # ----- 买入 ----- # 卖出后持仓列表 @@ -412,6 +444,7 @@ class Trader(Account): next_normal_list.append(stock) if len(next_normal_list) >= self.num - margin_needed: break + next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list}) next_position['date'] = date # 融资融券数量