spread_backtest/account.py

71 lines
2.6 KiB
Python
Raw Permalink Normal View History

2024-05-22 23:33:19 +08:00
import numbers
import warnings
from typing import Union

import pandas as pd
from rich import print as rprint
2024-05-22 23:33:19 +08:00
# Warning format: print warnings as a bold red "Warning:" prefix followed by the message.
def custom_warnings(message, category, filename, lineno, file=None, line=None):
    """Render a warning with ``rich`` instead of the default ``warnings`` format.

    Signature-compatible replacement for :func:`warnings.showwarning`. Only
    ``message`` is displayed; ``category``, ``filename``, ``lineno`` and
    ``line`` are accepted for compatibility but not shown.

    Args:
        message: The warning instance (or text) to display.
        category: Warning class (not displayed).
        filename: Source file that issued the warning (not displayed).
        lineno: Line number that issued the warning (not displayed).
        file: Stream to write to. ``None`` (the value ``warnings.warn``
            normally passes) keeps rich's global console.
        line: Source line text (not displayed).
    """
    # Honour the ``file`` argument of the showwarning contract; rich.print
    # falls back to its global console when file is None.
    rprint("[bold red]Warning:[/bold red]", message, file=file)


# Installed at import time: this replaces the formatter for *every* warning
# issued in the process, not only CustomWarning.
warnings.showwarning = custom_warnings
class CustomWarning(UserWarning):
    """Project-specific warning category (a ``UserWarning`` subclass)."""


# Show every CustomWarning occurrence instead of only the first one per location.
warnings.simplefilter(action='always', category=CustomWarning)
class Account:
    """Account state for a backtest run.

    The account is either "real" (tracked in absolute amounts) or simulated
    (tracked as a fraction of a unit account that starts at 1.0).

    Args:
        is_real (bool): Whether this is a real trading account. If True the
            account is tracked in amounts starting at ``init_account``;
            otherwise it is tracked as a percentage starting at 1.0 and
            ``init_account`` is ignored.
        init_account (Union[int, float]): Initial account value; only used
            when ``is_real`` is True.
        trade_limit (float): Per-stock purchase ratio limit. Per the original
            design it only applies to real accounts; this class just stores it.
        position_ratio (Union[float, pd.Series]): Position ratio, either a
            single number (any real scalar, including numpy scalars) or a
            Series of ratios (presumably indexed by date -- confirm against
            callers). Every value must lie in [0, 2].

    Raises:
        ValueError: If ``position_ratio`` is out of range or has an
            unsupported type.
    """

    def __init__(self,
                 is_real: bool = False,
                 init_account: Union[int, float] = 1e6,
                 trade_limit: float = 0.01,
                 position_ratio: Union[float, pd.Series] = 1.) -> None:
        if is_real:
            self.a_type = 'amount'
            self.account = init_account
        else:
            # Simulated accounts track returns on a unit (100%) account.
            self.a_type = 'percent'
            self.account = 1.0
        self.trade_limit = trade_limit
        # Position ratio: a scalar or a Series, every value within [0, 2].
        self.position_ratio = self._validate_position_ratio(position_ratio)
        # Current holdings.
        self.position = pd.DataFrame()
        # Account updates, one row per recorded update.
        self.account_history = pd.DataFrame(
            columns=['date', 'trade_time', 'turnover', 'position_ratio', 'pnl'])
        # Historical holdings.
        self.position_history = dict()
        # Current date.
        self.date = None

    @staticmethod
    def _validate_position_ratio(position_ratio):
        """Return ``position_ratio`` if all its values are in [0, 2], else raise.

        Real scalars (``int``, ``float``, numpy numeric scalars, ...) are
        converted to ``float``; Series are returned unchanged.

        Raises:
            ValueError: If a value is outside [0, 2] (NaN scalars included),
                or the argument is neither a real number nor a Series.
        """
        if isinstance(position_ratio, numbers.Real):
            position_ratio = float(position_ratio)
            # NaN fails both comparisons, so it is rejected as well.
            if 0 <= position_ratio <= 2:
                return position_ratio
            raise ValueError('`position_ratio` should be between 0 and 2.')
        if isinstance(position_ratio, pd.Series):
            # NOTE(review): Series.min()/max() skip NaN, so NaN entries pass
            # this check; an empty Series yields NaN bounds and is rejected.
            if position_ratio.min() >= 0 and position_ratio.max() <= 2:
                return position_ratio
            raise ValueError('`position_ratio` values should be between 0 and 2.')
        raise ValueError('`position_ratio` should be float or Series.')

    def _close_rows(self) -> pd.DataFrame:
        """Return the rows of the current position whose ``trade_time`` is ``'close'``."""
        position = self.position
        if 'trade_time' not in position.columns:
            # Nothing recorded yet (e.g. the column-less frame set in __init__):
            # return an empty frame so the properties yield empty arrays
            # instead of raising KeyError.
            return pd.DataFrame(columns=['stock_code', 'weight'])
        return position[position['trade_time'] == 'close']

    # Holdings and weights at the close.
    @property
    def close_position(self):
        """Stock codes held at the close (empty array if nothing is held)."""
        return self._close_rows()['stock_code'].values

    @property
    def close_weight(self):
        """Weights of the stocks held at the close, aligned with ``close_position``."""
        return self._close_rows()['weight'].values