diff --git a/account.py b/account.py index 2f4bb27..3f12650 100644 --- a/account.py +++ b/account.py @@ -18,17 +18,18 @@ warnings.simplefilter('always', category=CustomWarning) class Account(): """ 账户类:用于控制账户类型 - Arguments: - - is_real(bool): 是否为真实交易 - - init_account(Union[int,float]): 初始账户金额 - - trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效) - - leverage(float): 杠杆比例 + + Args: + is_real (bool): 是否为真实交易 + init_account (Union[int, float]): 初始账户金额 + trade_limit (float): 是否限制每个股票的购买比例(只在真实账户中有效) + position_ratio (Union[float, pd.Series]): 持仓比例 """ def __init__(self, is_real: bool = False, - init_account: Union[int,float] = 1e6, + init_account: Union[int, float] = 1e6, trade_limit: float = 0.01, - leverage: Union[int,float] = 1.) -> None: + position_ratio: Union[float, pd.Series] = 1.) -> None: if is_real: self.a_type = 'amount' self.account = init_account @@ -36,16 +37,25 @@ class Account(): self.a_type = 'percent' self.account = 1.0 self.trade_limit = trade_limit - if isinstance(leverage, (int,float)): - if leverage > 1.5: - raise ValueError('leverage should less than 1.5') + # 持仓比例 + # 判断类型以及持仓比例范围为0-2之间 + if isinstance(position_ratio, (int, float)): + position_ratio = float(position_ratio) + if position_ratio >= 0 and position_ratio <= 2: + self.position_ratio = position_ratio + else: + raise ValueError('`position_ratio` should less than 2.') + elif isinstance(position_ratio, pd.Series): + if position_ratio.min() >= 0 and position_ratio.max() <= 2: + self.position_ratio = position_ratio + else: + raise ValueError('`position_ratio` should be position and less than 2.') else: - raise ValueError('leverage should be int or float') - self.leverage = leverage + raise ValueError('`position_ratio` should be float or Series.') # 当前持仓 self.position = pd.DataFrame() # 账户收益更新 - self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl']) + self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','position_ratio','pnl']) # 历史持仓 self.position_history = dict() # 当前日期 diff --git a/trader.py b/trader.py index ea56214..7cf45ed 100644 --- a/trader.py +++ b/trader.py @@ -209,18 +209,23 @@ class Trader(Account): # 可执行日期 self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index() - def get_weight(self, date, total_weight, next_position): + def get_weight(self, date, account_weight, next_position): """ 计算个股仓位 + + Args: + account_weight (float): 总权重,即当前持仓比例 """ if isinstance(self.weight, str): if self.weight == 'avg': - return total_weight / len(next_position) + return account_weight / len(next_position) if isinstance(self.weight, pd.DataFrame): date_weight = self.weight.loc[date].dropna().sort_index() try: weight_list = date_weight.loc[next_position['stock_code'].to_list()].values - weight_list = total_weight * weight_list / sum(weight_list) + if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松 + raise Exception(f"total weight of {date} is larger then 1.") + weight_list = account_weight * weight_list return weight_list except Exception: raise ValueError(f'not found stock weight in {date}') @@ -275,7 +280,7 @@ class Trader(Account): normal_exclude = list(set(normal_exclude)) # 交易列表 - if self.leverage <= 1.0: + if self.today_position_ratio <= 1.0: # 如果没有杠杆: buy_list = [] sell_list = [] @@ -336,7 +341,7 @@ class Trader(Account): # 生成下一期持仓 next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))}) next_position['date'] = date - next_position['weight'] = self.get_weight(date, self.leverage, next_position) + next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position) # 剔除无法买入的涨停股,这部分仓位空出 next_position = next_position[~next_position['stock_code'].isin(limit_up_list)] next_position['margin_trade'] = 0 @@ -351,7 +356,7 @@ class Trader(Account): normal_list.append(stock) return normal_list, margin_list # 计算需要融资融券标的数量 - margin_ratio = max(self.leverage-1, 0) + margin_ratio = max(self.today_position_ratio-1, 0) margin_needed = round(self.num * margin_ratio) is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index() @@ -592,7 +597,7 @@ class Trader(Account): next_position.set_index(['stock_code'])['weight'].rename('next'), ], axis=1) turnover = (turnover['next'] - turnover['cur'].fillna(0)).abs().sum() - leverage = next_position['weight'].sum() + position_ratio = next_position['weight'].sum() if cur_pos['weight'].sum() == 0: pnl = 0 else: @@ -602,7 +607,7 @@ class Trader(Account): 'date': date, 'trade_time': trade_time, 'turnover': turnover, - 'leverage': leverage, + 'position_ratio': position_ratio, 'pnl': pnl }, ignore_index=True) return True @@ -627,6 +632,11 @@ class Trader(Account): self.position_history.pop(date) # 更新持仓信号 self.load_data(date, update_type) + # 更新当日持仓比例 + if isinstance(self.position_ratio, float): + self.today_position_ratio = self.position_ratio + if isinstance(self.position_ratio, pd.Series): + self.today_position_ratio = self.position_ratio.loc[date] # 更新费用 fee = (self.commission + self.slippage[0], self.commission + self.slippage[1]) current_tax = (0.001, 0.001) @@ -643,7 +653,7 @@ class Trader(Account): # 冻结列表 frozen_list = [] # 遍历各个交易时间的信号 - for idx,trade_time in enumerate(self.signal): + for _,trade_time in enumerate(self.signal): if self.check_update_status(date, trade_time): continue if date in self.signal[trade_time].index: