spread_backtest/account.py

63 lines
2.2 KiB
Python
Raw Normal View History

2024-05-22 23:33:19 +08:00
import warnings
import pandas as pd
from rich import print as rprint
from typing import Union, Iterable, Optional, Dict
# Warning format: every warning is rendered as a single rich-styled line.
def custom_warnings(message, category, filename, lineno, file=None, line=None):
    """Print a warning through rich with a bold red ``Warning:`` prefix.

    The signature mirrors ``warnings.showwarning`` so the function can be
    installed as its replacement. Only ``message`` is displayed; ``category``,
    ``filename``, ``lineno``, ``file`` and ``line`` are accepted and ignored.
    """
    prefix = "[bold red]Warning:[/bold red]"
    rprint(prefix, message)


# Route all warnings shown by the ``warnings`` module through the printer above.
warnings.showwarning = custom_warnings
class CustomWarning(UserWarning):
    """Project-specific warning category, derived from ``UserWarning``."""


# Always show CustomWarning, even when the same warning was already emitted
# from the same location.
warnings.simplefilter(action='always', category=CustomWarning)
class Account():
    """
    Trading account whose bookkeeping mode depends on the account type.

    A real account (``is_real=True``) is tracked in absolute amounts and
    starts from ``init_account``; any other account is tracked as a fraction
    and starts from 1.0.

    Arguments:
        - is_real(bool): whether this is a real trading account
        - init_account(Union[int,float]): initial account balance; only used
          when ``is_real`` is true
        - trade_limit(float): limit on the purchase ratio of each stock. It is
          only stored here; the original notes say it takes effect for real
          accounts only.
        - leverage(Union[int,float]): leverage ratio, at most 1.5

    Raises:
        ValueError: if ``leverage`` is not an int/float, or is greater than 1.5.
    """

    def __init__(self,
                 is_real: bool = False,
                 init_account: Union[int,float] = 1e6,
                 trade_limit: float = 0.01,
                 leverage: Union[int,float] = 1.) -> None:
        # Validate leverage first so an invalid value never yields a
        # half-initialised account.
        if not isinstance(leverage, (int, float)):
            raise ValueError('leverage should be int or float')
        if leverage > 1.5:
            raise ValueError('leverage should less than 1.5')
        # NOTE: the original author remarked that with leverage > 1 the
        # leverage target is satisfied first, so turnover can end up higher
        # than the configured turnover rate.

        if is_real:
            # Real account: values are absolute amounts.
            self.a_type = 'amount'
            self.account = init_account
        else:
            # Otherwise values are fractions of a unit account.
            self.a_type = 'percent'
            self.account = 1.0
        self.trade_limit = trade_limit
        self.leverage = leverage

        # Current holdings.
        self.position = pd.DataFrame()
        # Running record of account returns.
        self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
        # Historical holdings.
        self.position_history = dict()
        # Current date.
        self.date = None

    def _close_rows(self) -> pd.DataFrame:
        """Return the rows of the current position whose trade_time is 'close'.

        A position frame without a ``trade_time`` column (for example the empty
        frame assigned in ``__init__``) yields an empty, typed frame instead of
        raising ``KeyError``.
        """
        if 'trade_time' not in self.position.columns:
            return pd.DataFrame({
                'stock_code': pd.Series(dtype=object),
                'trade_time': pd.Series(dtype=object),
                'weight': pd.Series(dtype=float),
            })
        return self.position[self.position['trade_time'] == 'close']

    # Holdings and weights at the close.
    @property
    def close_position(self):
        """Array of ``stock_code`` values for the rows marked 'close'."""
        return self._close_rows()['stock_code'].values

    @property
    def close_weight(self):
        """Array of ``weight`` values for the rows marked 'close'."""
        return self._close_rows().weight.values