import pandas as pd
import numpy as np
import os
import sys
import copy
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
from typing import Union, Iterable, Dict
from ordered_set import OrderedSet
from account import Account
from dataloader import DataLoader


class Trader(Account):
    """
    Trader: drives the daily trading of a factor-ranked portfolio.

    Args:
        signal (dict[str, pd.DataFrame]): target factors, executed in dict order.
            Each key is a trade time (e.g. 'open'); the key 'close' is reserved.
        interval (int, tuple, pd.Series, or dict of these): rebalancing interval.
            A dict maps each signal key to its own interval; a bare value is
            applied to every signal. The resulting per-date value times `num`
            (truncated to int) caps how many names may be sold that day.
            - int k: value 1 on every k-th signal date, 0 otherwise
            - tuple (k, ratio): value `ratio` on every k-th signal date, 0 otherwise
            - pd.Series: explicit per-date value
        num (int): number of holdings.
        ascending (bool): factor direction; True ranks the smallest values first.
        with_st (bool): whether ST stocks may be selected.
        tick (bool): tick simulation mode (under development; currently unused).
        weight (str or pd.DataFrame): weight allocation
            - 'avg' (str): equal weight, reassigned at each rebalance
            - (pd.DataFrame): custom per-day stock weights, normalized automatically
        amt_filter (set, list or tuple): 20-day average turnover filter; first value
            is the lower bound, second the upper bound (the lower bound alone is
            allowed). A set is sorted first, so it is read as (lower, upper).
        data_root (dict): directory of daily price files per signal; files must hold
            stock_code and price columns. Must contain 'basic'; 'open' falls back
            to the basic data when not given.
        ipo_days (int): minimum number of days since listing.
        slippage (tuple): (buy, sell) slippage.
        commission (float): commission rate.
        tax (dict): stamp duty schedule, effective date -> (buy rate, sell rate).
        exclude_list (list): extra exclusion rules, enforced before normal rebalancing
            - abnormal: abnormal announcements (delisting, investigation, warning letters, ...)
            - recession: YoY or QoQ earnings drop above 50%
            ('qualified_opinion' is rejected by the validation in __init__.)
        account (dict): account settings passed to account.Account.
    """

    # Default stamp-duty schedule: effective date -> (buy rate, sell rate).
    DEFAULT_TAX = {
        '1990-01-01': (0.001, 0.001),
        '2008-04-24': (0.001, 0.001),
        '2008-09-19': (0, 0.001),
        '2023-08-28': (0, 0.0005),
    }

    def __init__(self,
                 signal: Dict[str, pd.DataFrame] = None,
                 interval: Union[int, tuple, pd.Series, Dict[str, Union[int, tuple, pd.Series]]] = 1,
                 num: int = 100,
                 ascending: bool = False,
                 with_st: bool = False,
                 data_root: dict = None,
                 tick: bool = False,
                 weight: str = 'avg',
                 amt_filter: set = (0,),
                 ipo_days: int = 20,
                 slippage: tuple = (0.001, 0.001),
                 commission: float = 0.0001,
                 tax: dict = None,
                 exclude_list: list = None,
                 account: dict = None,
                 **kwargs) -> None:
        # Initialise the account first.
        super().__init__({} if account is None else account)
        # Fresh containers per instance instead of shared mutable defaults.
        if tax is None:
            tax = dict(self.DEFAULT_TAX)
        if data_root is None:
            data_root = {}
        if exclude_list is None:
            exclude_list = []

        # signal: copy the dict so the caller's mapping is not rewritten below
        if not isinstance(signal, dict):
            raise ValueError('type of signal is invalid')
        if 'close' in signal:
            raise ValueError('signal key cannot be close')
        self.signal = dict(signal)
        for s in self.signal:
            self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(),
                                               self.signal[s].index.max(), return_type='origin')

        # --------------------
        # parameter validation
        # --------------------
        if len(kwargs) > 0:
            raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")

        # interval: a bare value applies to every signal
        if not isinstance(interval, dict):
            interval = {s: interval for s in self.signal}
        self.interval_by_signal = {}
        for s in self.signal:
            if s not in interval:
                raise ValueError(f'not found interval for signal {s}')
            s_interval = interval[s]
            dates = self.signal[s].index
            if isinstance(s_interval, int):
                df_interval = pd.Series(0, index=dates)
                df_interval.iloc[::s_interval] = 1
            elif isinstance(s_interval, tuple):
                # float series so the ratio can be stored without a dtype upcast
                df_interval = pd.Series(0.0, index=dates)
                df_interval.iloc[::s_interval[0]] = s_interval[1]
            elif isinstance(s_interval, pd.Series):
                df_interval = s_interval
            else:
                raise ValueError('invalid interval type')
            self.interval_by_signal[s] = df_interval
        # Concatenated view kept for backward compatibility; dates shared by several
        # signals appear more than once here, so lookups go through _interval_ratio.
        self.interval = pd.concat(list(self.interval_by_signal.values()))

        # num
        if not isinstance(num, int):
            raise ValueError('num should be int')
        self.num = int(num)

        # ascending
        if not isinstance(ascending, bool):
            raise ValueError('invalid type for `ascending`')
        self.ascending = ascending

        # with_st
        if not isinstance(with_st, bool):
            raise ValueError('invalid type for `with_st`')
        self.with_st = with_st

        # weight
        if not isinstance(weight, (str, pd.DataFrame)):
            raise ValueError('invalid type for `weight`')
        self.weight = weight

        # amt_filter: sets are unordered and not indexable, so sort them into (lower, upper)
        if not isinstance(amt_filter, (set, list, tuple)):
            raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
        if isinstance(amt_filter, set):
            amt_filter = sorted(amt_filter)
        self.amt_filter_min = amt_filter[0]
        self.amt_filter_max = np.inf if len(amt_filter) == 1 else amt_filter[1]

        # ipo_days
        if not isinstance(ipo_days, int):
            raise Exception('wrong type for ipo_days, `int` is required')
        self.ipo_days = ipo_days

        # slippage
        if not (isinstance(slippage, tuple) and len(slippage) == 2):
            raise ValueError('slippage should be set.')
        self.slippage = slippage

        # commission
        if not isinstance(commission, float):
            raise ValueError('commission should be flaot.')
        self.commission = commission

        # tax
        if not isinstance(tax, dict):
            raise ValueError('tax should be dict.')
        self.tax = tax

        # exclude_list
        if not isinstance(exclude_list, list):
            raise ValueError('exclude_list should be list.')
        optional_list = ['abnormal', 'recession']
        for item in exclude_list:
            if item not in optional_list:
                raise ValueError(f"Unexpected keyword argument '{item}'")
        self.exclude_list = exclude_list

        # data_root: must contain the basic data path; the 'open' signal
        # falls back to basic data when no dedicated path is given
        if len(data_root) <= 0:
            raise ValueError('num of data_root should be equal or greater than 1')
        if 'basic' not in data_root:
            raise ValueError('data_root should contain basic data root')
        self.data_root = data_root
        # tradable dates
        self.update_avaliable_date()
        for s in self.signal:
            if s != 'open' and s not in data_root:
                raise ValueError(f"data for signal {s} is not provided")

    def load_data(self, date: str, update_type: str = 'rtn'):
        """
        Load the data for one trading day into `self.today_data`.

        Args:
            date (str): trading date; files are read from `<root>/<date>.csv`.
            update_type (str): update mode
                - rtn: load basic data plus the price data of every signal
                - position: load basic data only (used for position decisions)

        Returns:
            bool: always True.
        """
        self.today_data = dict()
        self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'], f'{date}.csv'))
        if update_type == 'position':
            return True
        for s in self.signal:
            if s == 'open' and s not in self.data_root:
                # no dedicated open data: use the basic data's open_post column as price
                self.today_data[s] = DataLoader(
                    self.today_data['basic'].data[['open_post']].rename(columns={'open_post': 'price'}))
            else:
                self.today_data[s] = DataLoader(os.path.join(self.data_root[s], f'{date}.csv'))
        if 'close' not in self.signal:
            self.today_data['close'] = DataLoader(
                self.today_data['basic'].data[['close_post']].rename(columns={'close_post': 'price'}))
        return True

    def update_avaliable_date(self):
        """
        Refresh `self.avaliable_date` from the file names in the basic-data directory.
        """
        dates = [f.split('.')[0] for f in os.listdir(self.data_root['basic'])]
        self.avaliable_date = pd.Series(index=dates, dtype=float).sort_index()

    def get_weight(self, date, total_weight, next_position):
        """
        Compute the weight of each stock in `next_position`.

        Args:
            date: trading date (row label of the custom weight DataFrame).
            total_weight (float): total weight to distribute.
            next_position (pd.DataFrame): must contain a 'stock_code' column.

        Returns:
            float or np.ndarray: one weight per row ('avg' returns a scalar).

        Raises:
            ValueError: unknown weight mode, or a stock missing from the custom weights.
        """
        if isinstance(self.weight, str):
            if self.weight == 'avg':
                if len(next_position) == 0:
                    # nothing to weight; avoid dividing by zero
                    return 0.0
                return total_weight / len(next_position)
            raise ValueError(f"unsupported weight mode '{self.weight}'")
        if isinstance(self.weight, pd.DataFrame):
            date_weight = self.weight.loc[date].dropna().sort_index()
            try:
                weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
            except KeyError as err:
                raise ValueError(f'not found stock weight in {date}') from err
            return total_weight * weight_list / sum(weight_list)
        raise ValueError('invalid type for `weight`')

    def _interval_ratio(self, date, trade_time=None):
        """
        Return the scheduled turnover ratio for `date`.

        Uses the interval of `trade_time` when known; otherwise looks `date` up in
        the concatenated `self.interval`, taking the first value when several
        signals share the date.
        """
        by_signal = getattr(self, 'interval_by_signal', {})
        if trade_time is not None and trade_time in by_signal:
            ratio = by_signal[trade_time].loc[date]
        else:
            ratio = self.interval.loc[date]
        if isinstance(ratio, pd.Series):
            ratio = ratio.iloc[0]
        return ratio

    def get_next_position(self, date, factor, trade_time=None):
        """
        Compute the holdings for the next period.

        Stock-status codes (column 'opening_info'), per the original comments:
        0/2 = suspended or limit-down, 4/6 = limit-up (cannot buy), 3/6 = tradable
        non-ST; 5/7 are presumably the ST counterparts of the frozen states --
        NOTE(review): confirm against the opening_info definition.

        Args:
            date (str): trading date.
            factor (pd.Series): factor values for `date`, indexed by stock code.
            trade_time (str, optional): signal key being executed; selects that
                signal's interval. Defaults to the concatenated `self.interval`.

        Returns:
            tuple: (sell_list, buy_list, frozen_list, next_position), where
            next_position has columns stock_code, date, weight, margin_trade.
        """
        # maximum number of names that may be sold today
        if len(self.position) > 0:
            last_position = self.position['stock_code'].values
            last_position = factor.loc[last_position].sort_values(ascending=self.ascending)
            planned_sell = int(self._interval_ratio(date, trade_time) * self.num)
            if len(self.position) <= self.num:
                # under-filled yesterday: empty slots are filled without selling
                max_sell_num = min(planned_sell, planned_sell + len(self.position) - self.num)
            else:
                # over-filled yesterday: also sell the excess
                max_sell_num = max(planned_sell, planned_sell + len(self.position) - self.num)
        else:
            last_position = pd.Series(dtype=float)
            max_sell_num = self.num
        target_list = []

        # data used for filtering
        basic = self.today_data['basic']
        codes = factor.index.values
        stock_status = basic.get(codes, 'opening_info').sort_index()
        stock_amt20 = basic.get(codes, 'amount_20').sort_index()
        stock_ipo_days = basic.get(codes, 'ipo_days').sort_index()

        # excluded = ipo filter + turnover filter + extra exclusions; the extra
        # exclusions are enforced (held names are force-sold), so keep them apart
        force_exclude = set()
        for exclude in self.exclude_list:
            if exclude == 'abnormal':
                stock_abnormal = basic.get(codes, 'abnormal').sort_index()
                force_exclude.update(stock_abnormal.index[stock_abnormal > 0])
            if exclude == 'recession':
                stock_recession = basic.get(codes, 'recession').sort_index()
                force_exclude.update(stock_recession.index[stock_recession > 0])
        exclude_stock = set(force_exclude)
        exclude_stock.update(stock_ipo_days.index[~(stock_ipo_days > self.ipo_days)])
        amt_ok = (stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)
        exclude_stock.update(stock_amt20.index[~amt_ok])

        if self.leverage <= 1.0:
            # target list: 1) keep yesterday's holdings that cannot trade;
            # 2) fill from the filtered factor ranking
            for stock in last_position.index:
                if stock_status.loc[stock] in [0, 2]:
                    target_list.append(stock)
            after_filter_list = list(set(factor.index) - exclude_stock)
            for stock in factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.values:
                if len(target_list) >= self.num:
                    break
                status = stock_status.loc[stock]
                if self.with_st:
                    if status in [0, 2, 5, 7]:
                        # frozen names are only kept, never newly selected
                        if stock in last_position.index:
                            target_list.append(stock)
                    else:
                        target_list.append(stock)
                else:
                    if status in [3, 6]:
                        target_list.append(stock)
                    if status in [0, 2, 7]:
                        if stock in last_position.index:
                            target_list.append(stock)
            target_list = list(OrderedSet(target_list))

            buy_list = []
            sell_list = []
            # ----- sell -----
            # forced sells for names hit by exclude_list rules
            for stock in last_position.index:
                if stock in force_exclude:
                    sell_list.append(stock)
            force_sell_num = len(sell_list)
            # missing factor values rank worst, so they are sold first
            if self.ascending:
                factor = factor.fillna(factor.max() + 1)
            else:
                factor = factor.fillna(factor.min() - 1)
            # sell from the worst-ranked holding upwards
            for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
                if len(sell_list) >= max_sell_num + force_sell_num:
                    break
                if stock in target_list:
                    continue
                if stock_status.loc[stock] in [0, 2, 5, 7]:
                    continue
                sell_list.append(stock)
            sell_list = list(set(sell_list))

            # ----- buy -----
            after_sell_list = set(last_position.index) - set(sell_list)
            cant_buy_list = []  # limit-up names: their slot stays empty
            max_buy_num = max(0, self.num - len(last_position) + len(sell_list))
            for stock in target_list:
                # check the quota before buying, so a zero quota buys nothing
                if len(buy_list) >= max_buy_num:
                    break
                if stock in after_sell_list:
                    continue
                if stock_status.loc[stock] in [4, 6]:
                    cant_buy_list.append(stock)
                buy_list.append(stock)
            next_position = pd.DataFrame({'stock_code': list(after_sell_list | set(buy_list))})
            next_position['date'] = date
            next_position['weight'] = self.get_weight(date, self.leverage, next_position)
            # drop limit-up names that cannot be bought; their weight stays in cash
            next_position = next_position[~next_position['stock_code'].isin(cant_buy_list)]
            next_position['margin_trade'] = 0
        else:
            # leverage > 1: separate cash pool and margin-financed pool
            def assign_stock(normal_list, margin_list, margin_needed, stock, status):
                """Marginable names (status == 1) go to the margin pool while it has
                room; other names go to the cash pool while it has room."""
                if status == 1:
                    if len(margin_list) < margin_needed:
                        margin_list.append(stock)
                else:
                    if len(normal_list) < self.num - margin_needed:
                        normal_list.append(stock)
                return normal_list, margin_list

            # number of margin-financed names needed
            margin_ratio = max(self.leverage - 1, 0)
            margin_needed = round(self.num * margin_ratio)
            is_margin = basic.get(codes, 'margin_list').sort_index()

            normal_list = []
            margin_list = []
            if len(last_position) > 0:
                last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
                last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
            else:
                last_margin_list = []
                last_normal_list = []
            # keep frozen holdings in their current pool
            for stock in last_margin_list:
                if stock_status.loc[stock] in [0, 2]:
                    margin_list.append(stock)
            for stock in last_normal_list:
                if stock_status.loc[stock] in [0, 2]:
                    normal_list.append(stock)
            after_filter_list = list(set(factor.index) - exclude_stock)
            for stock in factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.values:
                if len(normal_list) + len(margin_list) >= self.num:
                    break
                status = stock_status.loc[stock]
                frozen = status in [0, 2, 5, 7] if self.with_st else status in [0, 2, 7]
                tradable = (not frozen) if self.with_st else status in [3, 6]
                if tradable:
                    normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed,
                                                            stock, is_margin.loc[stock])
                elif frozen:
                    # frozen names are only kept if already held (as in the unlevered branch)
                    if stock in last_margin_list:
                        margin_list.append(stock)
                    elif stock in last_normal_list:
                        normal_list.append(stock)
            margin_list = list(OrderedSet(margin_list))
            normal_list = list(OrderedSet(normal_list))
            target_list = normal_list + margin_list

            # ----- sell (each pool separately) -----
            buy_list = []
            sell_list = []
            # margin pool: forced sells first
            for stock in last_margin_list:
                if stock in force_exclude:
                    sell_list.append(stock)
            force_sell_num = len(sell_list)
            for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]:
                if len(sell_list) >= int(max_sell_num * margin_ratio) + force_sell_num + 1:
                    break
                # a holding that is still targeted (in either pool) is not sold
                if stock in target_list:
                    continue
                if stock_status.loc[stock] in [0, 2, 5, 7]:
                    continue
                sell_list.append(stock)
            sell_list = list(set(sell_list))
            next_margin_list = list(set(last_margin_list) - set(sell_list))
            # cash pool: forced sells first
            for stock in last_normal_list:
                if stock in force_exclude:
                    sell_list.append(stock)
                    force_sell_num += 1
            for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]:
                if len(sell_list) >= max_sell_num + force_sell_num:
                    break
                if stock in target_list:
                    continue
                if stock_status.loc[stock] in [0, 2, 5, 7]:
                    continue
                sell_list.append(stock)
            sell_list = list(set(sell_list))
            next_normal_list = list(set(last_normal_list) - set(sell_list))

            # ----- buy (each pool separately, quota checked before buying) -----
            after_sell_list = set(last_position.index) - set(sell_list)
            cant_buy_list = []  # limit-up names: their slot stays empty
            for stock in margin_list:
                if len(next_margin_list) >= margin_needed:
                    break
                if stock in after_sell_list:
                    continue
                if stock_status.loc[stock] in [4, 6]:
                    cant_buy_list.append(stock)
                next_margin_list.append(stock)
                buy_list.append(stock)
            for stock in normal_list:
                if len(next_normal_list) >= self.num - margin_needed:
                    break
                if stock in after_sell_list:
                    continue
                if stock_status.loc[stock] in [4, 6]:
                    cant_buy_list.append(stock)
                next_normal_list.append(stock)
                buy_list.append(stock)
            next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
            next_position['date'] = date
            margin_num = len(next_margin_list)
            next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position)
            next_position['margin_trade'] = 0
            next_position = next_position.set_index(['stock_code'])
            next_position.loc[next_margin_list, 'margin_trade'] = 1
            next_position = next_position.reset_index()
            # drop limit-up names that cannot be bought; their weight stays in cash
            next_position = next_position[~next_position['stock_code'].isin(cant_buy_list)]

        # holdings that cannot trade right now
        frozen_list = []
        if len(self.position) > 0:
            frozen_list = [s for s in next_position['stock_code'] if stock_status.loc[s] in [0, 2]]
        return sell_list, buy_list, frozen_list, next_position

    def get_price(self, trade_time: str = 'open', target: Iterable[str] = (),
                  buy_list: Iterable[str] = (), sell_list: Iterable[str] = ()):
        """
        Get prices at `trade_time`, with trading costs applied to buys and sells.

        Args:
            trade_time (str): key of `self.today_data` to read prices from.
            target (Iterable[str]): stock codes to price.
            buy_list (Iterable[str]): codes priced with the buy fee added.
            sell_list (Iterable[str]): codes priced with the sell fee deducted.

        Returns:
            pd.Series: price per target code; missing plain prices become 0.
        """
        stock_price = self.today_data[trade_time]
        target = list(target)
        target_price = pd.Series(index=target, dtype=float)
        sell_list = list(set(target) & set(sell_list))
        buy_list = list(set(target) & set(buy_list))
        target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
        target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
        target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
        return target_price

    def check_update_status(self, date: str, trade_time: str):
        """
        Return True when (date, trade_time) should be skipped: the date is earlier
        than the latest recorded date, or this exact pair is already recorded.
        """
        if len(self.account_history) == 0:
            return False
        if date < self.account_history['date'].max():
            return True
        done = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
        return f'{date}-{trade_time}' in done

    def reblance_weight(self, trade_time: str, cur_pos: pd.DataFrame, next_position: pd.DataFrame):
        """
        Rebalance from `cur_pos` towards `next_position`, keeping the weight of names
        that cannot trade in the needed direction, rescaling the rest to the target
        total, and blending the cost price of adjusted names.

        Returns:
            pd.DataFrame: columns stock_code, date, open, weight, margin_trade.
        """
        next_position = next_position.copy()
        # names that cannot be bought / sold right now
        stock_status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
        buy_frozen_list = [s for s in cur_pos['stock_code'] if stock_status.loc[s] in [0, 2, 4, 6]]
        sell_frozen_list = [s for s in cur_pos['stock_code'] if stock_status.loc[s] in [0, 2, 5, 7]]

        # target vs current weights (float columns so float values fit)
        next_position['target_weight'] = next_position['weight']
        next_position['current_weight'] = 0.0
        cur_pos = cur_pos.set_index(['stock_code'])
        next_position = next_position.set_index(['stock_code'])
        current_list = list(set(cur_pos.index) & set(next_position.index))
        next_position.loc[current_list, 'current_weight'] = cur_pos.loc[current_list, 'weight']
        next_position['open'] = 0.0
        next_position.loc[current_list, 'open'] = cur_pos.loc[current_list, 'close']
        # ideal weight change
        next_position['weight_chg'] = next_position['weight'] - next_position['current_weight']

        # frozen names keep their current weight (lists: pandas rejects set indexers)
        increasing = set(next_position.index[next_position['weight_chg'] > 0])
        decreasing = set(next_position.index[next_position['weight_chg'] < 0])
        buy_frozen = list(set(buy_frozen_list) & increasing)
        sell_frozen = list(set(sell_frozen_list) & decreasing)
        frozen = list(set(buy_frozen) | set(sell_frozen))
        next_position['final_weight'] = next_position['weight']
        next_position.loc[buy_frozen, 'final_weight'] = next_position.loc[buy_frozen, 'current_weight']
        next_position.loc[sell_frozen, 'final_weight'] = next_position.loc[sell_frozen, 'current_weight']
        # rescale so the total matches the target total weight
        next_position['final_weight'] /= next_position['final_weight'].sum()
        next_position['final_weight'] *= next_position['weight'].sum()
        # actual weight change
        next_position['weight_chg'] = next_position['final_weight'] - next_position['current_weight']
        next_position.loc[frozen, 'weight_chg'] = 0

        # trade prices of the adjusted names
        next_position['adjust_price'] = 0.0
        buy_adjust_list = next_position.index[next_position['weight_chg'] > 0].values
        sell_adjust_list = next_position.index[next_position['weight_chg'] < 0].values
        next_position.loc[buy_adjust_list, 'adjust_price'] = self.get_price(
            trade_time, buy_adjust_list, buy_adjust_list, []).values
        next_position.loc[sell_adjust_list, 'adjust_price'] = self.get_price(
            trade_time, sell_adjust_list, [], sell_adjust_list).values

        # weighted cost price after the adjustment; frozen names keep their price
        next_position['adjust_open'] = (next_position['current_weight'] * next_position['open']
                                        + next_position['weight_chg'] * next_position['adjust_price'])
        next_position['adjust_open'] = next_position['adjust_open'] / next_position['final_weight']
        next_position.loc[frozen, 'adjust_open'] = next_position.loc[frozen, 'open']
        next_position['open'] = next_position['adjust_open']
        next_position['weight'] = next_position['final_weight']
        next_position = next_position.reset_index()
        return next_position[['stock_code', 'date', 'open', 'weight', 'margin_trade']]

    def update_account(self, date: str, trade_time: str, cur_pos: pd.DataFrame, next_position: pd.DataFrame):
        """
        Record turnover, leverage and pnl for one trade time and update the account.

        Args:
            date (str): date.
            trade_time (str): trade time.
            cur_pos (DataFrame): current holdings (weight, end_weight columns).
            next_position (DataFrame): holdings for the next moment (weight column).

        Returns:
            bool: always True.
        """
        turnover = pd.concat([
            cur_pos.set_index(['stock_code'])['weight'].fillna(0).rename('cur'),
            next_position.set_index(['stock_code'])['weight'].rename('next'),
        ], axis=1)
        # names missing on either side count as weight 0 (sold names included)
        turnover = (turnover['next'].fillna(0) - turnover['cur'].fillna(0)).abs().sum()
        leverage = next_position['weight'].sum()
        if cur_pos['weight'].sum() == 0:
            pnl = 0
        else:
            pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum())
        self.account *= 1 + pnl
        record = pd.DataFrame([{
            'date': date,
            'trade_time': trade_time,
            'turnover': turnover,
            'leverage': leverage,
            'pnl': pnl,
        }])
        # DataFrame.append was removed in pandas 2.0
        self.account_history = pd.concat([self.account_history, record], ignore_index=True)
        return True

    @staticmethod
    def _gross_return(pos: pd.DataFrame) -> pd.Series:
        """close / open per row; rows with a zero open price count as unchanged (1.0)."""
        ratio = pos['close'] / pos['open'].replace(0, np.nan)
        return ratio.where(pos['open'] != 0, 1.0)

    def update_signal(self, date: str, update_type='rtn'):
        """
        Run one trading day: trade every signal in order, then settle at the close.

        Args:
            date (str): trading date.
            update_type (str): update type
                - position: only compute and store the next holdings, no returns
                - rtn: compute returns and holdings

        Returns:
            bool: always True.
        """
        # skip if this date's close is already recorded; otherwise drop the date's
        # partial records and recompute
        if len(self.account_history) > 0:
            done = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
            if f'{date}-close' in done:
                return True
            self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
        if date in self.position_history:
            self.position_history.pop(date)
        self.load_data(date, update_type)

        # fees: the latest tax rate whose effective date is on or before `date`
        current_tax = (0.001, 0.001)
        for start, tax_rate in sorted(self.tax.items()):
            if date >= start:
                current_tax = tax_rate
        self.current_fee = (self.commission + self.slippage[0] + current_tax[0],
                            self.commission + self.slippage[1] + current_tax[1])

        empty_pos = pd.DataFrame(columns=['stock_code', 'date', 'weight', 'open', 'close', 'margin_trade'])
        frozen_list = []
        for trade_time in self.signal:
            if self.check_update_status(date, trade_time):
                continue
            if date not in self.signal[trade_time].index:
                continue
            factor = self.signal[trade_time].loc[date]
            # holdings as of now (reflects any earlier signal traded today)
            cur_pos = empty_pos.copy() if len(self.position) == 0 else self.position.copy()
            sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor, trade_time)
            # position mode only records the next holdings
            if update_type == 'position':
                self.position_history[date] = next_position.copy()
                return True
            if len(cur_pos) > 0:
                cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
                # suspended names keep their price
                frozen_mask = cur_pos['stock_code'].isin(frozen_list)
                cur_pos.loc[frozen_mask, 'close'] = cur_pos.loc[frozen_mask, 'open']
                cur_pos['rtn'] = self._gross_return(cur_pos)
                cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
                self.update_account(date, trade_time, cur_pos, next_position)
                cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
                # adjust weights: buys, sells and rebalancing
                next_position = self.reblance_weight(trade_time, cur_pos, next_position)
            else:
                next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values,
                                                       buy_list, []).values
            self.position = next_position.copy()

        # settle the day's return at the close
        trade_time = 'close'
        if self.check_update_status(date, trade_time):
            return True
        cur_pos = self.position.copy()
        cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
        # suspended names and names without an open price keep their price
        frozen_mask = cur_pos['stock_code'].isin(frozen_list)
        cur_pos.loc[frozen_mask, 'close'] = cur_pos.loc[frozen_mask, 'open']
        no_open = cur_pos['open'] == 0
        cur_pos.loc[no_open, 'close'] = cur_pos.loc[no_open, 'open']
        cur_pos['rtn'] = self._gross_return(cur_pos) - 1
        cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
        position_record = cur_pos.copy()
        position_record['end_weight'] = ((position_record['end_weight'] / position_record['end_weight'].sum())
                                         * position_record['weight'].sum())
        cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
        next_position = cur_pos.copy()[['stock_code', 'date', 'weight', 'margin_trade']]
        next_position['open'] = cur_pos['close']
        self.update_account(date, trade_time, cur_pos, cur_pos)
        self.position = next_position.copy()
        self.position_history[date] = position_record.copy()
        return True