diff --git a/trader.py b/trader.py index 6c6a673..784bba4 100644 --- a/trader.py +++ b/trader.py @@ -23,14 +23,18 @@ class Trader(Account): ascending (bool): 因子方向 with_st (bool): 是否包含st tick (bool): 是否开始tick模拟模式(开发中) - weight (str): 权重分配 - - avg: 平均分配,每天早盘重新分配,日中交易不重新分配 + weight ([str, pd.DataFrame]): 权重分配 + - avg (str): 平均分配,每天早盘重新分配,日中交易不重新分配 + - (pd.DataFrame): 自定义股票权重,包含每天个股指定的权重,会自动归一化 amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限 data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列 ipo_days (int): 筛选上市时间 slippage (tuple): 买入和卖出滑点 commission (float): 佣金 tax (dict): 印花税 + exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 + - abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除 + - report: 财报同比下降50%以上剔除 """ def __init__(self, signal: Dict[str, pd.DataFrame]=None, @@ -62,9 +66,9 @@ class Trader(Account): self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(), self.signal[s].index.max(), return_type='origin') else: raise ValueError('type of signal is invalid') - # ---------- + # -------------------- # 参数检验 - # ---------- + # -------------------- # interval self.interval = [] for s in signal: @@ -93,12 +97,17 @@ class Trader(Account): if isinstance(ascending, bool): self.ascending = ascending else: - raise ValueError('invalid type for ascending') + raise ValueError('invalid type for `ascending`') # with_st if isinstance(with_st, bool): self.with_st = with_st else: - raise ValueError('invalid type for with_st') + raise ValueError('invalid type for `with_st`') + # weight + if isinstance(weight, (str, pd.DataFrame)): + self.weight = weight + else: + raise ValueError('invalid type for `weight`') # amt_filter if isinstance(amt_filter, (set,list,tuple)): if len(amt_filter) == 1: @@ -173,7 +182,23 @@ class Trader(Account): pass else: self.today_data['close'] = DataLoader(self.today_data['basic'].data[['close_post']].rename(columns={'close_post':'price'})) - + + def get_weight(self, date, total_weight, next_position): + """ + 计算个股仓位 + """ + if isinstance(self.weight, str): + if self.weight == 'avg': + return total_weight / len(next_position) + if isinstance(self.weight, pd.DataFrame): + date_weight = self.weight.loc[date].dropna().sort_index() + try: + weight_list = date_weight.loc[next_position['stock_code'].to_list()].values + weight_list = total_weight * weight_list / sum(weight_list) + return weight_list + except: + raise ValueError(f'not found stock weight in {date}') + def get_next_position(self, date, factor): """ 计算下一时刻持仓 @@ -182,10 +207,15 @@ class Trader(Account): if len(self.position) > 0: last_position = self.position['stock_code'].values last_position = factor.loc[last_position].sort_values(ascending=self.ascending) - max_trade_num = max(int(self.interval.loc[date]*self.num), self.num-len(self.position)) + if len(self.position) <= self.num: + # 如果昨日持仓本身就不足持仓数量则不用足额换仓 + max_sell_num = min(int(self.interval.loc[date]*self.num), int(self.interval.loc[date]*self.num)+len(self.position)-self.num) + else: + # 如果昨日持仓本身就超额持仓数量则超额换仓 + max_sell_num = max(int(self.interval.loc[date]*self.num), int(self.interval.loc[date]*self.num)+len(self.position)-self.num) else: last_position = pd.Series() - max_trade_num = self.num + max_sell_num = self.num target_list = [] # 获取用于筛选的数据 stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index() @@ -198,6 +228,8 @@ class Trader(Account): stock_ipo_filter = stock_ipo_days.copy() stock_ipo_filter.loc[:] = 0 stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1 + # 剔除列表 + # 交易列表 if self.leverage <= 1.0: # 获取当前时间目标列表和冻结(无法交易)列表: @@ -249,23 +281,28 @@ class Trader(Account): continue else: sell_list.append(stock) - if len(sell_list) >= max_trade_num: + if len(sell_list) >= max_sell_num: break # ----- 买入 ----- # 卖出后持仓列表 after_sell_list = set(last_position.index) - set(sell_list) - max_trade_num = min(max_trade_num, self.num-len(last_position)+len(sell_list)) + cant_buy_list = [] # 涨停股记录 + max_buy_num = max(0, self.num-len(last_position)+len(sell_list)) for stock in target_list: if stock in after_sell_list: continue else: + if stock_status.loc[stock] in [4,6]: + cant_buy_list.append(stock) buy_list.append(stock) - if len(buy_list) == max_trade_num: + if len(buy_list) == max_buy_num: break next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))}) next_position['date'] = date - next_position['weight'] = self.leverage / len(next_position) + next_position['weight'] = self.get_weight(date, self.leverage, next_position) + # 剔除无法买入的涨停股,这部分仓位空出 + next_position = next_position[~next_position['stock_code'].isin(cant_buy_list)] next_position['margin_trade'] = 0 else: # 如果有杠杆 @@ -375,7 +412,7 @@ class Trader(Account): next_position['date'] = date # 融资融券数量 margin_num = len(next_margin_list) - next_position['weight'] = (1 + (margin_num / self.num)) / self.num + next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position) next_position['margin_trade'] = 0 next_position = next_position.set_index(['stock_code']) next_position.loc[next_margin_list, 'margin_trade'] = 1