diff --git a/trader.py b/trader.py index 277841a..12f6787 100644 --- a/trader.py +++ b/trader.py @@ -497,9 +497,10 @@ class Trader(Account): next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list}) next_position['date'] = date + # 融资融券数量 margin_num = len(next_margin_list) - next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position) + next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), untradable_list, next_position) next_position['margin_trade'] = 0 next_position = next_position.set_index(['stock_code']) next_position.loc[next_margin_list, 'margin_trade'] = 1 @@ -633,7 +634,7 @@ class Trader(Account): if cur_pos['weight'].sum() == 0: pnl = 0 else: - cash = 1 - cur_pos['weight'].sum() + cash = max(0, 1 - cur_pos['weight'].sum()) pnl = ((cur_pos['end_weight'].sum() + cash) / (cur_pos['weight'].sum() + cash)) - 1 self.account *= 1+pnl self.account_history = self.account_history.append({