63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
import warnings
|
||
import pandas as pd
|
||
from rich import print as rprint
|
||
from typing import Union, Iterable, Optional, Dict
|
||
|
||
|
||
# Warning format: warnings are printed through rich with a bold red prefix.
def custom_warnings(message, category, filename, lineno, file=None, line=None):
    """Print a warning as a bold red ``Warning:`` prefix followed by the message.

    The parameter list mirrors ``warnings.showwarning`` so this function can
    be installed in its place. Only ``message`` is used; ``category``,
    ``filename``, ``lineno``, ``file`` and ``line`` are accepted and ignored.
    """
    prefix = "[bold red]Warning:[/bold red]"
    rprint(prefix, message)


# Install the hook so the ``warnings`` module displays warnings with the function above.
warnings.showwarning = custom_warnings
|
||
|
||
|
||
class CustomWarning(UserWarning):
    """Project-specific warning category derived from ``UserWarning``."""


# Register an "always" filter so a CustomWarning is shown every time it is
# raised, rather than only once per location.
warnings.simplefilter(action='always', category=CustomWarning)
|
||
|
||
|
||
class Account:
    """
    Account class: controls the account type.

    Arguments:
        - is_real(bool): whether this is a real-trading account. When True the
          account is tracked as an amount, otherwise as a percentage.
        - init_account(Union[int,float]): initial account amount. Only used
          when ``is_real`` is True; otherwise the account starts at 1.0.
        - trade_limit(float): limit on the purchase proportion of each stock
          (per the original documentation, only effective for real accounts).
        - leverage(float): leverage ratio; must not exceed 1.5.

    Raises:
        ValueError: if ``leverage`` is not an int/float, or is greater than 1.5.
    """

    def __init__(self,
                 is_real: bool = False,
                 init_account: Union[int, float] = 1e6,
                 trade_limit: float = 0.01,
                 leverage: Union[int, float] = 1.) -> None:
        if is_real:
            self.a_type = 'amount'
            self.account = init_account
        else:
            self.a_type = 'percent'
            self.account = 1.0
        self.trade_limit = trade_limit

        # Validate leverage with guard clauses: type first, then upper bound.
        if not isinstance(leverage, (int, float)):
            raise ValueError('leverage should be int or float')
        if leverage > 1.5:
            raise ValueError('leverage should less than 1.5')
        # NOTE: a CustomWarning for leverage > 1 (leverage is satisfied first,
        # so turnover may exceed the requested rate) is currently disabled.
        self.leverage = leverage

        # Current positions
        self.position = pd.DataFrame()
        # Account return updates
        self.account_history = pd.DataFrame(
            columns=['date', 'trade_time', 'turnover', 'leverage', 'pnl'])
        # Historical positions
        self.position_history = dict()
        # Current date
        self.date = None

    def _close_values(self, column: str):
        """Return ``column`` as an array for rows whose ``trade_time`` is 'close'.

        A freshly constructed account holds a column-less, empty ``position``
        frame; indexing it by 'trade_time' would raise ``KeyError``, so that
        case yields an empty array instead.
        """
        position = self.position
        if position.empty and not {'trade_time', column}.issubset(position.columns):
            return pd.Series(dtype=object).values
        return position.loc[position['trade_time'] == 'close', column].values

    # Positions and weights held at close
    @property
    def close_position(self):
        """Stock codes of the positions whose ``trade_time`` is 'close'."""
        return self._close_values('stock_code')

    @property
    def close_weight(self):
        """Weights of the positions whose ``trade_time`` is 'close'."""
        return self._close_values('weight')