Update: 支持自定义每日仓位 (#18)

This commit is contained in:
binz 2024-06-06 00:11:51 +08:00
parent 1bb0114e37
commit 740a64ef54
2 changed files with 42 additions and 22 deletions

View File

@ -18,17 +18,18 @@ warnings.simplefilter('always', category=CustomWarning)
class Account():
"""
账户类用于控制账户类型
Arguments:
- is_real(bool): 是否为真实交易
- init_account(Union[int,float]): 初始账户金额
- trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效)
- leverage(float): 杠杆比例
Args:
is_real (bool): 是否为真实交易
init_account (Union[int, float]): 初始账户金额
trade_limit (float): 是否限制每个股票的购买比例(只在真实账户中有效)
position_ratio (Union[float, pd.Series]): 持仓比例
"""
def __init__(self,
is_real: bool = False,
init_account: Union[int,float] = 1e6,
init_account: Union[int, float] = 1e6,
trade_limit: float = 0.01,
leverage: Union[int,float] = 1.) -> None:
position_ratio: Union[float, pd.Series] = 1.) -> None:
if is_real:
self.a_type = 'amount'
self.account = init_account
@ -36,16 +37,25 @@ class Account():
self.a_type = 'percent'
self.account = 1.0
self.trade_limit = trade_limit
if isinstance(leverage, (int,float)):
if leverage > 1.5:
raise ValueError('leverage should less than 1.5')
# 持仓比例
# 判断类型以及持仓比例范围为0-2之间
if isinstance(position_ratio, (int, float)):
position_ratio = float(position_ratio)
if position_ratio >= 0 and position_ratio <= 2:
self.position_ratio = position_ratio
else:
raise ValueError('`position_ratio` should less than 2.')
elif isinstance(position_ratio, pd.Series):
if position_ratio.min() >= 0 and position_ratio.max() <= 2:
self.position_ratio = position_ratio
else:
raise ValueError('`position_ratio` should be position and less than 2.')
else:
raise ValueError('leverage should be int or float')
self.leverage = leverage
raise ValueError('`position_ratio` should be float or Series.')
# 当前持仓
self.position = pd.DataFrame()
# 账户收益更新
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','position_ratio','pnl'])
# 历史持仓
self.position_history = dict()
# 当前日期

View File

@ -209,18 +209,23 @@ class Trader(Account):
# 可执行日期
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
def get_weight(self, date, total_weight, next_position):
def get_weight(self, date, account_weight, next_position):
"""
计算个股仓位
Args:
account_weight (float): 总权重即当前持仓比例
"""
if isinstance(self.weight, str):
if self.weight == 'avg':
return total_weight / len(next_position)
return account_weight / len(next_position)
if isinstance(self.weight, pd.DataFrame):
date_weight = self.weight.loc[date].dropna().sort_index()
try:
weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
weight_list = total_weight * weight_list / sum(weight_list)
if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松
raise Exception(f"total weight of {date} is larger then 1.")
weight_list = account_weight * weight_list
return weight_list
except Exception:
raise ValueError(f'not found stock weight in {date}')
@ -275,7 +280,7 @@ class Trader(Account):
normal_exclude = list(set(normal_exclude))
# 交易列表
if self.leverage <= 1.0:
if self.today_position_ratio <= 1.0:
# 如果没有杠杆:
buy_list = []
sell_list = []
@ -336,7 +341,7 @@ class Trader(Account):
# 生成下一期持仓
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
next_position['date'] = date
next_position['weight'] = self.get_weight(date, self.leverage, next_position)
next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position)
# 剔除无法买入的涨停股,这部分仓位空出
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
next_position['margin_trade'] = 0
@ -351,7 +356,7 @@ class Trader(Account):
normal_list.append(stock)
return normal_list, margin_list
# 计算需要融资融券标的数量
margin_ratio = max(self.leverage-1, 0)
margin_ratio = max(self.today_position_ratio-1, 0)
margin_needed = round(self.num * margin_ratio)
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
@ -592,7 +597,7 @@ class Trader(Account):
next_position.set_index(['stock_code'])['weight'].rename('next'),
], axis=1)
turnover = (turnover['next'] - turnover['cur'].fillna(0)).abs().sum()
leverage = next_position['weight'].sum()
position_ratio = next_position['weight'].sum()
if cur_pos['weight'].sum() == 0:
pnl = 0
else:
@ -602,7 +607,7 @@ class Trader(Account):
'date': date,
'trade_time': trade_time,
'turnover': turnover,
'leverage': leverage,
'position_ratio': position_ratio,
'pnl': pnl
}, ignore_index=True)
return True
@ -627,6 +632,11 @@ class Trader(Account):
self.position_history.pop(date)
# 更新持仓信号
self.load_data(date, update_type)
# 更新当日持仓比例
if isinstance(self.position_ratio, float):
self.today_position_ratio = self.position_ratio
if isinstance(self.position_ratio, pd.Series):
self.today_position_ratio = self.position_ratio.loc[date]
# 更新费用
fee = (self.commission + self.slippage[0], self.commission + self.slippage[1])
current_tax = (0.001, 0.001)
@ -643,7 +653,7 @@ class Trader(Account):
# 冻结列表
frozen_list = []
# 遍历各个交易时间的信号
for idx,trade_time in enumerate(self.signal):
for _,trade_time in enumerate(self.signal):
if self.check_update_status(date, trade_time):
continue
if date in self.signal[trade_time].index: