Update: 支持自定义每日仓位 (#18)
This commit is contained in:
parent
1bb0114e37
commit
740a64ef54
36
account.py
36
account.py
|
@ -18,17 +18,18 @@ warnings.simplefilter('always', category=CustomWarning)
|
|||
class Account():
|
||||
"""
|
||||
账户类:用于控制账户类型
|
||||
Arguments:
|
||||
- is_real(bool): 是否为真实交易
|
||||
- init_account(Union[int,float]): 初始账户金额
|
||||
- trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效)
|
||||
- leverage(float): 杠杆比例
|
||||
|
||||
Args:
|
||||
is_real (bool): 是否为真实交易
|
||||
init_account (Union[int, float]): 初始账户金额
|
||||
trade_limit (float): 是否限制每个股票的购买比例(只在真实账户中有效)
|
||||
position_ratio (Union[float, pd.Series]): 持仓比例
|
||||
"""
|
||||
def __init__(self,
|
||||
is_real: bool = False,
|
||||
init_account: Union[int,float] = 1e6,
|
||||
init_account: Union[int, float] = 1e6,
|
||||
trade_limit: float = 0.01,
|
||||
leverage: Union[int,float] = 1.) -> None:
|
||||
position_ratio: Union[float, pd.Series] = 1.) -> None:
|
||||
if is_real:
|
||||
self.a_type = 'amount'
|
||||
self.account = init_account
|
||||
|
@ -36,16 +37,25 @@ class Account():
|
|||
self.a_type = 'percent'
|
||||
self.account = 1.0
|
||||
self.trade_limit = trade_limit
|
||||
if isinstance(leverage, (int,float)):
|
||||
if leverage > 1.5:
|
||||
raise ValueError('leverage should less than 1.5')
|
||||
# 持仓比例
|
||||
# 判断类型以及持仓比例范围为0-2之间
|
||||
if isinstance(position_ratio, (int, float)):
|
||||
position_ratio = float(position_ratio)
|
||||
if position_ratio >= 0 and position_ratio <= 2:
|
||||
self.position_ratio = position_ratio
|
||||
else:
|
||||
raise ValueError('`position_ratio` should less than 2.')
|
||||
elif isinstance(position_ratio, pd.Series):
|
||||
if position_ratio.min() >= 0 and position_ratio.max() <= 2:
|
||||
self.position_ratio = position_ratio
|
||||
else:
|
||||
raise ValueError('`position_ratio` should be position and less than 2.')
|
||||
else:
|
||||
raise ValueError('leverage should be int or float')
|
||||
self.leverage = leverage
|
||||
raise ValueError('`position_ratio` should be float or Series.')
|
||||
# 当前持仓
|
||||
self.position = pd.DataFrame()
|
||||
# 账户收益更新
|
||||
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
|
||||
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','position_ratio','pnl'])
|
||||
# 历史持仓
|
||||
self.position_history = dict()
|
||||
# 当前日期
|
||||
|
|
28
trader.py
28
trader.py
|
@ -209,18 +209,23 @@ class Trader(Account):
|
|||
# 可执行日期
|
||||
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.data_root['basic'])]).sort_index()
|
||||
|
||||
def get_weight(self, date, total_weight, next_position):
|
||||
def get_weight(self, date, account_weight, next_position):
|
||||
"""
|
||||
计算个股仓位
|
||||
|
||||
Args:
|
||||
account_weight (float): 总权重,即当前持仓比例
|
||||
"""
|
||||
if isinstance(self.weight, str):
|
||||
if self.weight == 'avg':
|
||||
return total_weight / len(next_position)
|
||||
return account_weight / len(next_position)
|
||||
if isinstance(self.weight, pd.DataFrame):
|
||||
date_weight = self.weight.loc[date].dropna().sort_index()
|
||||
try:
|
||||
weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
|
||||
weight_list = total_weight * weight_list / sum(weight_list)
|
||||
if weight_list.sum() > 1 + 1e5: # 防止数据精度的影响,给与一定的宽松
|
||||
raise Exception(f"total weight of {date} is larger then 1.")
|
||||
weight_list = account_weight * weight_list
|
||||
return weight_list
|
||||
except Exception:
|
||||
raise ValueError(f'not found stock weight in {date}')
|
||||
|
@ -275,7 +280,7 @@ class Trader(Account):
|
|||
normal_exclude = list(set(normal_exclude))
|
||||
|
||||
# 交易列表
|
||||
if self.leverage <= 1.0:
|
||||
if self.today_position_ratio <= 1.0:
|
||||
# 如果没有杠杆:
|
||||
buy_list = []
|
||||
sell_list = []
|
||||
|
@ -336,7 +341,7 @@ class Trader(Account):
|
|||
# 生成下一期持仓
|
||||
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
|
||||
next_position['date'] = date
|
||||
next_position['weight'] = self.get_weight(date, self.leverage, next_position)
|
||||
next_position['weight'] = self.get_weight(date, self.today_position_ratio, next_position)
|
||||
# 剔除无法买入的涨停股,这部分仓位空出
|
||||
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
|
||||
next_position['margin_trade'] = 0
|
||||
|
@ -351,7 +356,7 @@ class Trader(Account):
|
|||
normal_list.append(stock)
|
||||
return normal_list, margin_list
|
||||
# 计算需要融资融券标的数量
|
||||
margin_ratio = max(self.leverage-1, 0)
|
||||
margin_ratio = max(self.today_position_ratio-1, 0)
|
||||
margin_needed = round(self.num * margin_ratio)
|
||||
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
|
||||
|
||||
|
@ -592,7 +597,7 @@ class Trader(Account):
|
|||
next_position.set_index(['stock_code'])['weight'].rename('next'),
|
||||
], axis=1)
|
||||
turnover = (turnover['next'] - turnover['cur'].fillna(0)).abs().sum()
|
||||
leverage = next_position['weight'].sum()
|
||||
position_ratio = next_position['weight'].sum()
|
||||
if cur_pos['weight'].sum() == 0:
|
||||
pnl = 0
|
||||
else:
|
||||
|
@ -602,7 +607,7 @@ class Trader(Account):
|
|||
'date': date,
|
||||
'trade_time': trade_time,
|
||||
'turnover': turnover,
|
||||
'leverage': leverage,
|
||||
'position_ratio': position_ratio,
|
||||
'pnl': pnl
|
||||
}, ignore_index=True)
|
||||
return True
|
||||
|
@ -627,6 +632,11 @@ class Trader(Account):
|
|||
self.position_history.pop(date)
|
||||
# 更新持仓信号
|
||||
self.load_data(date, update_type)
|
||||
# 更新当日持仓比例
|
||||
if isinstance(self.position_ratio, float):
|
||||
self.today_position_ratio = self.position_ratio
|
||||
if isinstance(self.position_ratio, pd.Series):
|
||||
self.today_position_ratio = self.position_ratio.loc[date]
|
||||
# 更新费用
|
||||
fee = (self.commission + self.slippage[0], self.commission + self.slippage[1])
|
||||
current_tax = (0.001, 0.001)
|
||||
|
@ -643,7 +653,7 @@ class Trader(Account):
|
|||
# 冻结列表
|
||||
frozen_list = []
|
||||
# 遍历各个交易时间的信号
|
||||
for idx,trade_time in enumerate(self.signal):
|
||||
for _,trade_time in enumerate(self.signal):
|
||||
if self.check_update_status(date, trade_time):
|
||||
continue
|
||||
if date in self.signal[trade_time].index:
|
||||
|
|
Loading…
Reference in New Issue