import numbers
import warnings
from typing import Union

import pandas as pd

try:
    from rich import print as rprint
    _WARNING_PREFIX = "[bold red]Warning:[/bold red]"
except ImportError:
    # rich is only used to colour the warning prefix, so degrade to the
    # built-in print (without markup) instead of failing the whole import.
    rprint = print
    _WARNING_PREFIX = "Warning:"


# Warning format
def custom_warnings(message, category, filename, lineno, file=None, line=None):
    """Print a warning as a single ``Warning: <message>`` line.

    Installed below as ``warnings.showwarning``, so it takes that hook's
    arguments; only ``message`` is used, the others are ignored.
    """
    rprint(_WARNING_PREFIX, message)


warnings.showwarning = custom_warnings


class CustomWarning(UserWarning):
    """Warning category of this module; the filter below always shows it."""


warnings.simplefilter('always', category=CustomWarning)


class Account:
    """
    Account: controls the account type and holds the account state.

    Args:
        is_real (bool): whether this is a real-trading account. A real
            account is tracked as an amount starting at ``init_account``;
            otherwise it is tracked as a percentage starting at 1.0.
        init_account (Union[int, float]): initial account amount (only
            used when ``is_real`` is true).
        trade_limit (float): purchase-ratio limit per stock; stored as-is
            here (described as only effective for real accounts).
        position_ratio (Union[float, pd.Series]): position ratio, either a
            single real number or a Series, with values in ``[0, 2]``.

    Raises:
        ValueError: if ``position_ratio`` is out of range or is neither a
            real number nor a ``pd.Series``.
    """

    def __init__(self,
                 is_real: bool = False,
                 init_account: Union[int, float] = 1e6,
                 trade_limit: float = 0.01,
                 position_ratio: Union[float, pd.Series] = 1.) -> None:
        if is_real:
            self.a_type = 'amount'
            self.account = init_account
        else:
            self.a_type = 'percent'
            self.account = 1.0
        self.trade_limit = trade_limit
        # Position ratio: type-checked and restricted to the range 0-2.
        self.position_ratio = self._validate_position_ratio(position_ratio)
        # Current position
        self.position = pd.DataFrame()
        # Account return updates
        self.account_history = pd.DataFrame(
            columns=['date', 'trade_time', 'turnover', 'position_ratio', 'pnl'])
        # Historical positions
        self.position_history = dict()
        # Current date
        self.date = None

    @staticmethod
    def _validate_position_ratio(position_ratio):
        """Return ``position_ratio`` if it is valid, else raise ValueError.

        Scalars are converted to ``float``; a Series is returned unchanged.
        """
        if isinstance(position_ratio, pd.Series):
            if position_ratio.min() >= 0 and position_ratio.max() <= 2:
                return position_ratio
            raise ValueError('`position_ratio` should be between 0 and 2.')
        # numbers.Real covers int/float as before and additionally NumPy
        # scalars such as np.int64 / np.float32 that pandas hands out.
        if isinstance(position_ratio, numbers.Real):
            ratio = float(position_ratio)
            if 0 <= ratio <= 2:
                return ratio
            raise ValueError('`position_ratio` should be between 0 and 2.')
        raise ValueError('`position_ratio` should be float or Series.')

    def _close_values(self, column, empty_dtype):
        """Return ``column`` of the rows whose ``trade_time`` is 'close'.

        A freshly created account holds a column-less DataFrame; in that
        case an empty array of ``empty_dtype`` is returned instead of
        failing on the missing ``trade_time`` column.
        """
        position = self.position
        if position.columns.empty:
            return pd.Series(dtype=empty_dtype).values
        at_close = position['trade_time'] == 'close'
        return position.loc[at_close, column].values

    # Positions and weights held at the close
    @property
    def close_position(self):
        """Array of ``stock_code`` values of the rows marked 'close'."""
        return self._close_values('stock_code', object)

    @property
    def close_weight(self):
        """Array of ``weight`` values of the rows marked 'close'."""
        return self._close_values('weight', float)