2024-05-22 23:33:19 +08:00
|
|
|
|
import warnings
|
|
|
|
|
import pandas as pd
|
|
|
|
|
from rich import print as rprint
|
|
|
|
|
from typing import Union, Iterable, Optional, Dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Warning display format: render warnings through rich as a bold red
# "Warning:" prefix followed by the message.
def custom_warnings(message, category, filename, lineno, file=None, line=None):
    """Replacement for ``warnings.showwarning`` that prints via rich.

    Only the message is displayed. ``category``, ``filename``, ``lineno`` and
    ``line`` are accepted so the signature matches ``warnings.showwarning``,
    but they are not shown.

    Args:
        message: Warning instance (or text) to display.
        category: Warning class; not displayed.
        filename: File that issued the warning; not displayed.
        lineno: Line number that issued the warning; not displayed.
        file: Stream to write to. ``None`` uses rich's default console (the
            previous behaviour); an explicitly supplied stream is honoured, as
            the ``showwarning`` contract requires.
        line: Source line text; not displayed.
    """
    # rich's print() falls back to its global console when file is None, so the
    # default path is unchanged; only explicit streams are newly respected.
    rprint("[bold red]Warning:[/bold red]", message, file=file)


# Installed at import time: this replaces the display of every warning shown in
# the process, not only CustomWarning.
warnings.showwarning = custom_warnings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomWarning(UserWarning):
    """Project-specific warning category.

    Derives from ``UserWarning`` so it can be filtered separately from other
    user-level warnings.
    """


# 'always': display every CustomWarning occurrence instead of suppressing
# repeats.
warnings.simplefilter(action='always', category=CustomWarning)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Account():
|
|
|
|
|
"""
|
|
|
|
|
账户类:用于控制账户类型
|
|
|
|
|
Arguments:
|
|
|
|
|
- is_real(bool): 是否为真实交易
|
|
|
|
|
- init_account(Union[int,float]): 初始账户金额
|
|
|
|
|
- trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效)
|
|
|
|
|
- leverage(float): 杠杆比例
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self,
|
|
|
|
|
is_real: bool = False,
|
|
|
|
|
init_account: Union[int,float] = 1e6,
|
|
|
|
|
trade_limit: float = 0.01,
|
|
|
|
|
leverage: Union[int,float] = 1.) -> None:
|
|
|
|
|
if is_real:
|
|
|
|
|
self.a_type = 'amount'
|
|
|
|
|
self.account = init_account
|
|
|
|
|
else:
|
|
|
|
|
self.a_type = 'percent'
|
|
|
|
|
self.account = 1.0
|
|
|
|
|
self.trade_limit = trade_limit
|
|
|
|
|
if isinstance(leverage, (int,float)):
|
|
|
|
|
if leverage > 1.5:
|
|
|
|
|
raise ValueError('leverage should less than 1.5')
|
2024-05-28 00:03:38 +08:00
|
|
|
|
# if leverage > 1:
|
|
|
|
|
# warnings.warn('leverage(杠杆率)大于1时,优先满足杠杆,会出现高于指定换手率的情况!', category=CustomWarning)
|
2024-05-22 23:33:19 +08:00
|
|
|
|
else:
|
|
|
|
|
raise ValueError('leverage should be int or float')
|
|
|
|
|
self.leverage = leverage
|
|
|
|
|
# 当前持仓
|
|
|
|
|
self.position = pd.DataFrame()
|
|
|
|
|
# 账户收益更新
|
|
|
|
|
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
|
|
|
|
|
# 历史持仓
|
|
|
|
|
self.position_history = dict()
|
|
|
|
|
# 当前日期
|
|
|
|
|
self.date = None
|
|
|
|
|
|
|
|
|
|
# 获取收盘时持仓和仓位
|
|
|
|
|
@property
|
|
|
|
|
def close_position(self):
|
|
|
|
|
return self.position[self.position['trade_time'] == 'close']['stock_code'].values
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def close_weight(self):
|
|
|
|
|
return self.position[self.position['trade_time'] == 'close'].weight.values
|