diff --git a/trader.py b/trader.py index 12f6787..d085f5b 100644 --- a/trader.py +++ b/trader.py @@ -647,12 +647,16 @@ class Trader(Account): return True @staticmethod - def update_by_end_weight(position): + def update_next_weight(position): """ - 根据收盘权重计算新的个股权重 + 根据收盘权重计算下一时刻新的个股权重 """ - cash = 1 - position['weight'].sum() - return position['end_weight'] / (cash + position['end_weight'].sum()) + if position['weight'].sum() <= 1 + 1e-5: + # 非融资情况 + cash = max(0, 1 - position['weight'].sum()) + return position['end_weight'] / (position['end_weight'].sum() + cash) + else: + return position['weight'].sum() * (position['end_weight'] / position['end_weight'].sum()) def update_signal(self, date:str, @@ -722,7 +726,7 @@ class Trader(Account): cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn'] self.update_account(date, trade_time, cur_pos, next_position) # 更新仓位 - cur_pos['weight'] = self.update_by_end_weight(cur_pos) + cur_pos['weight'] = self.update_next_weight(cur_pos) # 调整权重:买入、卖出、仓位再平衡 next_position = self.reblance_weight(trade_time, cur_pos, next_position) else: @@ -741,10 +745,10 @@ class Trader(Account): cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1 cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1) position_record = cur_pos.copy() - position_record['end_weight'] = self.update_by_end_weight(position_record) + position_record['end_weight'] = self.update_next_weight(position_record) self.position_history[date] = position_record.copy()[position_fields] # 更新当期收盘后个股仓位作为下一期的开盘仓位 - cur_pos['weight'] = self.update_by_end_weight(cur_pos) + cur_pos['weight'] = self.update_next_weight(cur_pos) next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']] next_position['open'] = cur_pos['close'] self.update_account(date, trade_time, cur_pos, cur_pos)