diff --git a/data_handler.py b/data_handler.py index cf7f052..e388923 100644 --- a/data_handler.py +++ b/data_handler.py @@ -9,7 +9,7 @@ if __name__ == '__main__': data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data' - for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list']): + for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list','abnormal']): if f in ['margin_list']: tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) else: @@ -32,7 +32,7 @@ if __name__ == '__main__': # 更新下一日的数据用于筛选 next_date = gft.days_after(df.index.max(), 1) next_list = [] - for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list']): + for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal']): if f in ['margin_list']: next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) else: diff --git a/spread_backtest.py b/spread_backtest.py index 426e597..fd5c7b1 100644 --- a/spread_backtest.py +++ b/spread_backtest.py @@ -25,10 +25,13 @@ class Spread_Backtest(): else: bkt_start = max(start, self.trader.avaliable_date.index.min()) rec_end = min(end, self.trader.avaliable_date.index.max()) - # 如果传入的第一个因子时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益 + # 如果传入的第一个因子的有值时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益 first_signal = self.trader.signal[list(self.trader.signal.keys())[0]] if rec_end < first_signal.index.max(): - bkt_end = first_signal.loc[rec_end:].index.to_list()[1] + if rec_end in first_signal.index: + bkt_end = first_signal.loc[rec_end:].index.to_list()[1] + else: + bkt_end = first_signal.loc[rec_end:].index.to_list()[0] else: bkt_end = rec_end print(f'回测区间: {bkt_start} - {bkt_end}') @@ -92,5 +95,4 @@ def print_year_rtn(year_rtn: pd.DataFrame) -> None: table.add_row(*new_row) rprint(table) - \ No newline at end of file diff --git a/trader.py b/trader.py index 029f279..6c6a673 100644 --- a/trader.py +++ b/trader.py @@ -4,7 +4,10 @@ import sys, os sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") from db_tushare import get_factor_tools gft = get_factor_tools() + from typing import Union, Iterable, Dict +from ordered_set import OrderedSet + from account import Account from dataloader import DataLoader @@ -18,7 +21,6 @@ class Trader(Account): interval (int, tuple, pd.Series): 交易间隔 num (int): 持仓数量 ascending (bool): 因子方向 - fee (tuple): 买入成本和卖出成本 with_st (bool): 是否包含st tick (bool): 是否开始tick模拟模式(开发中) weight (str): 权重分配 @@ -26,19 +28,29 @@ class Trader(Account): amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限 data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列 ipo_days (int): 筛选上市时间 + slippage (tuple): 买入和卖出滑点 + commission (float): 佣金 + tax (dict): 印花税 """ def __init__(self, signal: Dict[str, pd.DataFrame]=None, interval: Dict[str, Union[int,tuple,pd.Series]]=1, num: int=100, ascending: bool=False, - fee :tuple=(0.001,0.002), with_st: bool=False, data_root:dict={}, tick: bool=False, weight: str='avg', amt_filter: set=(0,), ipo_days: int=20, + slippage :tuple=(0.001,0.001), + commission: float=0.0001, + tax: dict={ + '1990-01-01': (0.001,0.001), + '2008-04-24': (0.001,0.001), + '2008-09-19': (0, 0.001), + '2023-08-28': (0, 0.0005) + }, **kwargs) -> None: # 初始化账户 super().__init__(**kwargs.get('account', {})) @@ -50,6 +62,9 @@ class Trader(Account): self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(), self.signal[s].index.max(), return_type='origin') else: raise ValueError('type of signal is invalid') + # ---------- + # 参数检验 + # ---------- # interval self.interval = [] for s in signal: @@ -79,11 +94,6 @@ class Trader(Account): self.ascending = ascending else: raise ValueError('invalid type for ascending') - # fee - if isinstance(fee, tuple) and len(fee) == 2: - self.fee = fee - else: - raise ValueError('invalid input for fee') # with_st if isinstance(with_st, bool): self.with_st = with_st @@ -104,6 +114,21 @@ class Trader(Account): self.ipo_days = ipo_days else: raise Exception('wrong type for ipo_days, `int` is required') + # slippage + if isinstance(slippage, tuple) and len(slippage) == 2: + self.slippage = slippage + else: + raise ValueError('slippage should be set.') + # commission + if isinstance(commission, float): + self.commission = commission + else: + raise ValueError('commission should be flaot.') + # tax + if isinstance(tax, dict): + self.tax = tax + else: + raise ValueError('tax should be dict.') # data_root # 至少包含basic data路径,open信号默认使用basic_data if len(data_root) <= 0: @@ -175,27 +200,38 @@ class Trader(Account): stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1 # 交易列表 if self.leverage <= 1.0: - # 获取当前时间目标列表和冻结(无法交易)列表 + # 获取当前时间目标列表和冻结(无法交易)列表: + # 1 保留冻结的昨日持仓 + # 2 目标持仓列表会对当日因子进行ST、成交量、上市时间滤后按因子排序方向选取 + for stock in last_position.index: + # 如果停牌或者跌停继续持有 + if stock_status.loc[stock] in [0,2]: + target_list.append(stock) for stock in factor.dropna().sort_values(ascending=self.ascending).index.values: + # 目标持仓数量达到持仓数量则中止 + if len(target_list) == self.num: + break + # 成交量、上市时间过滤 if (stock_amt_filter.loc[stock] != 1) or (stock_ipo_filter.loc[stock] != 1): continue + # ST过滤 if self.with_st: - if stock_status.loc[stock] in [0,2]: + if stock_status.loc[stock] in [0,2,5,7]: if stock in last_position.index: target_list.append(stock) else: target_list.append(stock) else: # 非ST - if stock_status.loc[stock] in [3,7]: + if stock_status.loc[stock] in [3,6]: target_list.append(stock) - else: - # 如果停牌或者跌停继续持有 - if stock_status.loc[stock] in [0,2,6]: - if stock in last_position.index: - target_list.append(stock) - if len(target_list) == self.num: - break + + # 如果停牌或者跌停继续持有 + if stock_status.loc[stock] in [0,2,7]: + if stock in last_position.index: + target_list.append(stock) + target_list = list(OrderedSet(target_list)) + # 如果没有杠杆 buy_list = [] sell_list = [] @@ -205,7 +241,7 @@ class Trader(Account): factor = factor.fillna(factor.max()+1) else: factor = factor.fillna(factor.min()-1) - for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]: + for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]: if stock in target_list: continue else: @@ -213,8 +249,9 @@ class Trader(Account): continue else: sell_list.append(stock) - if len(sell_list) == max_trade_num: + if len(sell_list) >= max_trade_num: break + # ----- 买入 ----- # 卖出后持仓列表 after_sell_list = set(last_position.index) - set(sell_list) @@ -244,25 +281,34 @@ class Trader(Account): margin_ratio = max(self.leverage-1, 0) margin_needed = round(self.num * margin_ratio) is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index() - # 获取当前时间目标列表和冻结(无法交易)列表 + # 获取当前时间目标列表和冻结(无法交易)列表: + # 1 保留冻结的昨日持仓 + # 2 目标持仓列表会对当日因子进行ST、成交量、上市时间滤后按因子排序方向选取 + # 3 根据融资比例将股票分为常规池子和融资池子 normal_list = [] margin_list = [] + + for stock in last_position.index: + # 如果停牌或者跌停继续持有 + if stock_status.loc[stock] in [0,2]: + normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) + for stock in factor.dropna().sort_values(ascending=self.ascending).index.values: if stock_amt_filter.loc[stock] != 1: continue if self.with_st: - if stock_status.loc[stock] in [0,2]: + if stock_status.loc[stock] in [0,2,5,7]: if stock in last_position.index: normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) else: normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) else: # 非ST - if stock_status.loc[stock] in [3,7]: + if stock_status.loc[stock] in [3,6]: normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) else: # 如果停牌或者跌停继续持有 - if stock_status.loc[stock] in [0,2,6]: + if stock_status.loc[stock] in [0,2,7]: if stock in last_position.index: normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock]) if len(normal_list + margin_list) == self.num: @@ -329,7 +375,6 @@ class Trader(Account): next_position['date'] = date # 融资融券数量 margin_num = len(next_margin_list) -# next_position['weight'] = self.leverage*((margin_needed-margin_num)/margin_needed) / len(next_position) next_position['weight'] = (1 + (margin_num / self.num)) / self.num next_position['margin_trade'] = 0 next_position = next_position.set_index(['stock_code']) @@ -361,8 +406,8 @@ class Trader(Account): sell_list = list(set(target) & set(sell_list)) buy_list = list(set(target) & set(buy_list)) target_price.loc[target] = stock_price.get(target, 'price').fillna(0) - target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.fee[1]) - target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.fee[0]) + target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1]) + target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0]) return target_price def check_update_status(self, @@ -473,8 +518,8 @@ class Trader(Account): return True def update_signal(self, - date:str, - update_type='rtn'): + date:str, + update_type='rtn'): """ 更新信号收益 @@ -492,6 +537,14 @@ class Trader(Account): self.position_history.pop(date) # 更新持仓信号 self.load_data(date, update_type) + # 更新费用 + fee = (self.commission + self.slippage[0], self.commission + self.slippage[1]) + current_tax = (0.001, 0.001) + for time,tax_rate in self.tax.items(): + if date > time: + current_tax = tax_rate + fee = (fee[0] + current_tax[0], fee[1] + current_tax[1]) + self.current_fee = fee # 如果当前持仓不空,添加隔夜收益,否则直接买入 if len(self.position) == 0: cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade']) @@ -510,6 +563,7 @@ class Trader(Account): factor = self.signal[trade_time].loc[date] # 获取当前、持仓 sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor) + # 区分回测模型和仓位模式:回撤模式会记录收益,仓位模式只记录下一日持仓并结束计算 if update_type == 'position': self.position_history[date] = next_position.copy() return True