spread_backtest/trader.py

790 lines
38 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import sys
import os
import copy
from typing import Union, Iterable, Dict
from account import Account
from dataloader import DataLoader
from check_funcs import check_buy_exclude
# Extend the import search path so that `db_tushare` (presumably located in this
# directory) can be imported below.
# NOTE(review): hard-coded, machine-specific absolute path - consider making it configurable.
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
# Module-level helper instance, created at import time.
# Trader.__init__ calls gft.return_factor() on every signal frame.
gft = get_factor_tools()
class Trader(Account):
"""
Trading class: controls what is traded on each day of the backtest.

Args:
signal (dict[str, pd.DataFrame]): target factors, executed in dict order.
    The key 'close' is rejected by __init__.
interval (dict[str, (int, tuple, pd.Series)]):
    trading interval for each signal.
num (int): number of holdings.
ascending (bool): factor direction (sort order used when ranking the factor).
with_st (bool): whether ST stocks may be bought.
tick (bool): whether to enable tick simulation mode (under development).
weight ([str, pd.DataFrame]): weight allocation
    - avg (str): equal weights, re-assigned at each morning open;
      intraday trades do not re-assign them.
    - (pd.DataFrame): user-defined stock weights for every day;
      normalised automatically.
amt_filter (set): filter on the 20-day average traded amount; the first value
    is the lower bound and the second the upper bound, the lower bound alone
    may be given.
    NOTE(review): not a parameter of __init__; filtering appears to be
    configured through `buy_exclude` instead - confirm.
data_root (dict): trade-price data location for each target factor; the data
    must contain `stock_code` and `price` columns. Must contain a 'basic' entry.
ipo_days (int): filter on the time since listing.
    NOTE(review): not a parameter of __init__ - see `buy_exclude`.
slippage (tuple): buy and sell slippage.
commission (float): commission rate.
tax (dict): stamp duty, keyed by the date string from which each
    (buy, sell) rate applies.
force_exclude (list): additional exclusion list; its conditions are applied
    first, the normal rebalance follows.
    - abnormal: exclusion on abnormal announcements (delisting, formal
      investigation, warning letters and similar).
    - recession: financial results down 50% or more year-on-year or
      quarter-on-quarter.
    - qualified_opinion: qualified audit opinion.
      NOTE(review): __init__ only accepts 'abnormal' and 'recession'.
account (Account): account settings, see account.Account.
    NOTE(review): __init__ annotates this as a dict and forwards it as
    keyword arguments to Account.__init__.
"""
def __init__(self,
             signal: Dict[str, pd.DataFrame]=None,
             interval: Dict[str, Union[int,tuple,pd.Series]]=None,
             num: int=100,
             ascending: bool=False,
             with_st: bool=False,
             data_root: dict=None,
             tick: bool=False,
             weight: str='avg',
             slippage: tuple=(0.001,0.001),
             commission: float=0.0001,
             buy_exclude: Dict[str, Union[set,int,float]]=None,
             force_exclude: list=None,
             account: dict=None,
             tax: dict=None,
             **kwargs) -> None:
    """
    Validate the configuration and initialise the account.

    Args:
        signal: factor frames keyed by trade time; 'close' is not allowed as a key.
            Each frame is passed through gft.return_factor() over its own index range.
        interval: per-signal trading interval, see init_interval(). Defaults to {}.
        num: number of holdings.
        ascending: factor sort direction.
        with_st: whether ST stocks may be bought.
        data_root: directory per data set; must contain 'basic' and one entry
            for every signal other than 'open'. Defaults to {} (which is rejected).
        tick: accepted for compatibility; not used by this constructor.
        weight: 'avg' or a DataFrame of custom weights.
        slippage: (buy, sell) slippage.
        commission: commission rate; ints are accepted and converted to float.
        buy_exclude: buy filters, validated by check_buy_exclude(). Defaults to {}.
        force_exclude: subset of ['abnormal', 'recession']. Defaults to [].
        account: keyword arguments forwarded to Account.__init__. Defaults to {}.
        tax: {effective date: (buy rate, sell rate)}. Defaults to the built-in
            stamp-duty schedule.

    Raises:
        ValueError: on any invalid argument or unexpected keyword argument.
    """
    # None sentinels replace the former mutable defaults: the old default
    # objects were stored on the instance (force_exclude, tax, data_root) and
    # would therefore have been shared between every Trader built with defaults.
    interval = {} if interval is None else interval
    data_root = {} if data_root is None else data_root
    buy_exclude = {} if buy_exclude is None else buy_exclude
    force_exclude = [] if force_exclude is None else force_exclude
    account = {} if account is None else account
    if tax is None:
        tax = {
            '1990-01-01': (0.001,0.001),
            '2008-04-24': (0.001,0.001),
            '2008-09-19': (0, 0.001),
            '2023-08-28': (0, 0.0005)
        }
    # Initialise the account
    super().__init__(**account)
    if not isinstance(signal, dict):
        raise ValueError('type of signal is invalid')
    if 'close' in signal:
        raise ValueError('signal key cannot be close')
    # Build a new dict so the caller's `signal` mapping is not modified in place.
    self.signal = {
        name: gft.return_factor(frame, frame.index.min(), frame.index.max(), return_type='origin')
        for name, frame in signal.items()
    }
    # --------------------
    # Argument validation
    # --------------------
    if len(kwargs) > 0:
        raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
    # interval
    self.init_interval(interval)
    # num
    if isinstance(num, int):
        self.num = int(num)
    else:
        raise ValueError('num should be int')
    # ascending
    if isinstance(ascending, bool):
        self.ascending = ascending
    else:
        raise ValueError('invalid type for `ascending`')
    # with_st
    if isinstance(with_st, bool):
        self.with_st = with_st
    else:
        raise ValueError('invalid type for `with_st`')
    # weight
    if isinstance(weight, (str, pd.DataFrame)):
        self.weight = weight
    else:
        raise ValueError('invalid type for `weight`')
    # slippage
    if isinstance(slippage, tuple) and len(slippage) == 2:
        self.slippage = slippage
    else:
        raise ValueError('slippage should be a tuple of (buy, sell).')
    # commission: accept ints such as 0 as well, but never bools
    if isinstance(commission, (int, float)) and not isinstance(commission, bool):
        self.commission = float(commission)
    else:
        raise ValueError('commission should be float.')
    # tax
    if isinstance(tax, dict):
        # copy (order preserved) so later changes by the caller do not leak in
        self.tax = dict(tax)
    else:
        raise ValueError('tax should be dict.')
    # buy exclude
    self.buy_exclude = check_buy_exclude(buy_exclude)
    # force exclude
    if isinstance(force_exclude, list):
        optional_list = ['abnormal', 'recession']
        for item in force_exclude:
            if item not in optional_list:
                raise ValueError(f"Unexpected keyword argument '{item}'")
        self.force_exclude = list(force_exclude)
    else:
        raise ValueError('force_exclude should be list.')
    # data_root
    # Must at least contain the basic-data path; the 'open' signal uses the
    # basic data by default.
    if len(data_root) <= 0:
        raise ValueError('num of data_root should be equal or greater than 1')
    if 'basic' not in data_root:
        raise ValueError('data_root should contain basic data root')
    # Executable dates: one per file found in the basic-data directory
    self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(data_root['basic'])]).sort_index()
    for s in self.signal:
        if s == 'open':
            continue
        if s not in data_root:
            raise ValueError(f"data for signal {s} is not provided")
    self.data_root = data_root
def init_interval(self, interval):
    """
    Build `self.interval` from the per-signal interval specification.

    For every signal name in `self.signal` the matching entry of `interval`
    is turned into a Series:
        - int n: 0 everywhere, 1 on every n-th row of the signal's index
        - tuple (n, v): 0 everywhere, v on every n-th row of the signal's index
        - pd.Series: used as given
    The per-signal Series are concatenated, in signal order, into `self.interval`
    (get_next_position multiplies its value by the number of holdings to cap sells).

    Raises:
        ValueError: if a signal has no interval entry or the entry has an
            unsupported type.
    """
    schedules = []
    for name in self.signal:
        if name not in interval:
            raise ValueError(f'not found interval for signal {name}')
        spec = interval[name]
        if isinstance(spec, pd.Series):
            schedule = spec
        elif isinstance(spec, (int, tuple)):
            # int -> flag value 1; tuple -> (step, flag value)
            step, flag = (spec, 1) if isinstance(spec, int) else (spec[0], spec[1])
            dates = self.signal[name].index
            schedule = pd.Series(0, index=dates)
            schedule[::step] = flag
        else:
            raise ValueError('invalid interval type')
        schedules.append(schedule)
    self.interval = pd.concat(schedules)
def load_data(self,
              date: str,
              update_type: str='rtn'):
    """
    Load the data files of one day into `self.today_data`.

    Args:
        date (str): day to load; files are looked up as `<root>/<date>.csv`.
        update_type (str): update mode
            - rtn: load the data of every signal
            - position: only load the files listed in data_root (used for
              position checks) and return True

    Returns:
        True in 'position' mode, otherwise None.
    """
    day_file = f'{date}.csv'
    self.today_data = dict()
    today = self.today_data
    today['basic'] = DataLoader(os.path.join(self.data_root['basic'], day_file))
    # Load every other configured data set for the day
    for name, root in self.data_root.items():
        if name != 'basic':
            today[name] = DataLoader(os.path.join(root, day_file))
    if update_type == 'position':
        return True
    for s in self.signal:
        if s != 'open':
            today[s] = DataLoader(os.path.join(self.data_root[s], day_file))
            continue
        # A user-supplied 'open' data set becomes the trade price source ...
        if s in self.data_root:
            custom_open = today['open'].data[['open_post']]
            today[s+'_trade'] = DataLoader(custom_open.rename(columns={'open_post':'price'}))
        # ... while 'open' itself always comes from the basic data
        basic_open = today['basic'].data[['open_post']]
        today['open'] = DataLoader(basic_open.rename(columns={'open_post':'price'}))
    if 'close' not in self.signal:
        basic_close = today['basic'].data[['close_post']]
        today['close'] = DataLoader(basic_close.rename(columns={'close_post':'price'}))
def update_avaliable_date(self):
    """
    Refresh `self.avaliable_date` from the files in the basic-data directory.

    Every file name is cut at its first '.'; the resulting labels become the
    sorted index of an otherwise empty Series.
    """
    basic_dir = self.data_root['basic']
    labels = [fname.split('.')[0] for fname in os.listdir(basic_dir)]
    calendar = pd.Series(index=labels)
    self.avaliable_date = calendar.sort_index()
def get_weight(self, date, account_weight, untradable_list, next_position):
    """
    Compute the per-stock weights of `next_position`.

    Args:
        date: trading day, used to look up the custom weights.
        account_weight (float): total weight to distribute, i.e. the current
            position ratio.
        untradable_list (list): stocks that cannot be traded today; in custom
            weight mode they keep the weight recorded in `self.position`.
        next_position (pd.DataFrame): target holdings with a `stock_code` column.

    Returns:
        - 'avg' mode: a single float, `account_weight / len(next_position)`
          (0.0 when `next_position` is empty, instead of dividing by zero).
        - DataFrame mode: an array of weights ordered like `next_position`.

    Raises:
        ValueError: unsupported weight setting, missing weights, or a total
            weight above 1.
    """
    if isinstance(self.weight, str):
        if self.weight != 'avg':
            # Previously fell through and returned None, which only failed later.
            raise ValueError(f"unsupported weight scheme '{self.weight}'")
        if len(next_position) == 0:
            # Nothing to hold: avoid ZeroDivisionError
            return 0.0
        return account_weight / len(next_position)
    if not isinstance(self.weight, pd.DataFrame):
        raise ValueError('invalid type for `weight`')

    date_weight = self.weight.loc[date].dropna().sort_index()
    untradable_list = list(untradable_list)
    # Untradable stocks need no custom weight; they keep the previous one.
    weight_list = pd.Series(index=next_position['stock_code'], dtype=float)
    if untradable_list:
        try:
            weight_list.loc[untradable_list] = self.position.set_index('stock_code').loc[untradable_list, 'weight']
        except Exception as err:
            raise ValueError('not found stock weight for untradable stocks in last position.') from err
    try:
        tradable_list = list(set(next_position['stock_code']) - set(untradable_list))
        # What is left after the untradable stocks is shared out according to
        # the (re-normalised) custom weights.
        custom = date_weight.loc[tradable_list]
        weight_list.loc[tradable_list] = custom.values / custom.sum() * (account_weight - weight_list.loc[untradable_list].sum())
    except Exception as err:
        raise ValueError(f'not found specified stock weight in {date}') from err
    weights = weight_list.values
    # Small tolerance for floating-point error.
    # NOTE(review): the bound is 1, not account_weight, so custom weights
    # combined with leverage (account_weight > 1) are always rejected - confirm
    # that this is intended.
    if sum(weights) > 1 + 1e-5:
        # Raised directly; it used to be reported as a missing-weight error.
        raise ValueError(f"total weight of {date} is larger than 1.")
    return weights
def get_next_position(self, date, factor):
"""
Compute the holdings for the next moment.

Args:
date: trading day; used to look up `self.interval` and to stamp the result.
factor: factor values of the day, indexed by stock code.

Returns:
(sell_list, buy_list, frozen_list, next_position) where next_position is a
DataFrame with columns stock_code, date, weight and margin_trade.

Status codes read from the 'opening_info' column, as used below
(meanings inferred from the surrounding names/comments - TODO confirm):
0, 2 -> cannot be traded (suspended); 1, 2 -> skipped when with_st is False;
4, 6 -> limit-up (cannot buy); 5, 7 -> limit-down (cannot sell).
"""
# Work out the previous holdings and the maximum number of stocks to sell
if len(self.position) > 0:
last_position = pd.Series(index=self.position['stock_code'])
if len(self.position) <= self.num:
# If yesterday's holdings are already below the target count, yesterday's
# holding count is used as the base for the number of stocks to rotate.
# The shortfall is made up adaptively through the buy list.
# This keeps rebalancing working even when there are too few factor values.
# NOTE(review): nothing visible here assigns max_sell_num when
# len(self.position) > self.num - confirm that case cannot occur.
try:
max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else:
last_position = pd.Series()
max_sell_num = self.num
# Fetch the data used for screening
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
exclude_data = dict()
# For every buy filter build a 0/1 Series: 1 = strictly inside the configured
# (lower, upper) range, 0 = outside.
for cond in self.buy_exclude:
if cond == 'amt_20':
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_exclude = stock_amt20.copy()
stock_amt_exclude.loc[:] = 0
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
stock_amt_exclude = stock_amt_exclude.sort_index()
exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
# Exclusion lists
# There are two of them, the forced exclusion list and the buy exclusion list:
# the forced list also removes stocks that are already held and does not count
# against the turnover limit;
# the buy exclusion list leaves existing holdings alone and only filters new buys.
# Forced exclusion list
exclude_stock = []
for cond in self.force_exclude:
if cond == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if cond == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock)
# Buy exclusion list: every stock whose filter flag is not 1
buy_exclude = []
for cond in self.buy_exclude:
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
buy_exclude = list(set(buy_exclude))
# Trade lists
# The position-ratio test allows for a small floating-point error
if self.today_position_ratio <= 1.0 + 1e-5:
# Without leverage:
# Trading logic:
# 1 decide the sells; a stock that is limit-down today reduces the number actually sold
# 2 decide the buys: the buy count follows from the number actually sold and the gap to the target holding count; a stock that is limit-up today reduces the number actually bought
untradable_list = []
# ----- sell -----
sell_list = []
limit_down_list = [] # limit-down stocks
# Walk through yesterday's holdings:
# 1 record the status of every holding
# 2 collect the suspended stocks
# 3 collect the forced sells for abnormal stocks
last_position_status = pd.Series()
for stock in last_position.index:
last_position_status.loc[stock] = stock_status.loc[stock]
if last_position_status.loc[stock] in [0,2]:
untradable_list.append(stock)
else:
if last_position_status.loc[stock] in [5,7]:
continue
else:
if stock in force_exclude:
sell_list.append(stock)
force_sell_num = len(sell_list)
# With the untradable stocks removed, sell one by one in reverse factor rank.
# Abnormal stocks without a factor value are filled so that they rank last for the sell decision; the buy list below still uses the raw factor.
if self.ascending:
factor_filled = factor.fillna(factor.max()+1)
else:
factor_filled = factor.fillna(factor.min()-1)
for stock in factor_filled.loc[list(set(last_position.index)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
if len(sell_list) >= max_sell_num + force_sell_num:
break
if last_position_status.loc[stock] in [0,2]:
continue
else:
if last_position_status.loc[stock] in [5,7]:
limit_down_list.append(stock)
sell_list.append(stock)
sell_list = list(set(sell_list))
# Actual sell list = sell list - limit-down list
sell_list = list(set(sell_list) - set(limit_down_list))
# ----- buy -----
buy_list = []
# Candidates left after the exclusion filters
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# Holdings that remain after the sells
after_sell_list = set(last_position.index) - set(sell_list)
limit_up_list = [] # limit-up stocks
max_buy_num = max(0, self.num-len(last_position)+len(sell_list))
for stock in target_list:
if len(buy_list) == max_buy_num:
break
if stock in after_sell_list:
continue
else:
if self.with_st:
pass
else:
if stock_status.loc[stock] in [1,2]:
continue
if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock)
buy_list.append(stock)
buy_list = list(set(buy_list))
# Drop stocks that ended up in both sell_list and buy_list
duplicate_stock = set(sell_list) & set(buy_list)
sell_list = list(set(sell_list) - duplicate_stock)
buy_list = list(set(buy_list) - duplicate_stock)
# Build the next position
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
next_position['date'] = date
next_position['weight'] = self.get_weight(date, self.today_position_ratio, untradable_list+limit_down_list, next_position)
# Drop limit-up stocks that could not be bought and were not held yesterday; their share of the position stays empty
next_position = next_position[~next_position['stock_code'].isin(list(set(limit_up_list)-set(last_position.index)))]
next_position['margin_trade'] = 0
else:
# With leverage:
# work out how many margin-trading names are needed
margin_ratio = max(self.today_position_ratio-1, 0)
margin_needed = round(self.num * margin_ratio)
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
# Previous margin pool
if len(last_position) > 0:
last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
else:
last_margin_list = []
# Previous non-margin holdings
if len(last_position) > 0:
last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
else:
last_normal_list = []
# ----- sell -----
sell_list = []
untradable_list = []
# The margin pool and the non-margin pool are updated separately
# Update the margin pool
# Forced sells for abnormal stocks
for stock in last_margin_list:
if stock_status.loc[stock] in [0,2,5,7]:
untradable_list.append(stock)
continue
else:
if stock in force_exclude:
sell_list.append(stock)
force_sell_num = len(sell_list)
# Abnormal stocks without a factor value are filled so that they rank last for the sell decision; the buy list below still uses the raw factor.
if self.ascending:
factor_filled = factor.fillna(factor.max()+1)
else:
factor_filled = factor.fillna(factor.min()-1)
for stock in factor_filled.loc[list(set(last_margin_list)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
if len(sell_list) >= int(max_sell_num * margin_ratio) + force_sell_num + 1:
break
if stock_status.loc[stock] in [0,2,5,7]:
continue
else:
sell_list.append(stock)
sell_list = list(set(sell_list))
next_margin_list = list(set(last_margin_list) - set(sell_list))
# Update the non-margin pool
# Forced sells for abnormal stocks
# NOTE(review): untradable_list is reset here, so the untradable margin
# stocks collected above are not passed on to get_weight() below - confirm.
untradable_list = []
for stock in last_normal_list:
if stock_status.loc[stock] in [0,2,5,7]:
untradable_list.append(stock)
continue
else:
if stock in force_exclude:
sell_list.append(stock)
force_sell_num += 1
for stock in factor_filled.loc[list(set(last_normal_list)-set(untradable_list)-set(sell_list))].sort_values(ascending=self.ascending).index.values[::-1]:
if len(sell_list) >= max_sell_num + force_sell_num:
break
if stock_status.loc[stock] in [0,2,5,7]:
continue
else:
sell_list.append(stock)
sell_list = list(set(sell_list))
next_normal_list = list(set(last_normal_list) - set(sell_list))
# ----- buy -----
buy_list = []
# Candidates left after the exclusion filters
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# Holdings that remain after the sells
after_sell_list = set(last_position.index) - set(sell_list)
limit_up_list = [] # limit-up stocks
# The margin pool and the non-margin pool are filled separately
# Fill the margin pool
for stock in target_list:
if len(next_margin_list) >= margin_needed:
break
if stock in after_sell_list:
continue
else:
if self.with_st:
pass
else:
if stock_status.loc[stock] in [1,2]:
continue
if is_margin.loc[stock]:
if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock)
next_margin_list.append(stock)
next_margin_list = list(set(next_margin_list))
# Fill the non-margin pool
for stock in target_list:
if len(next_normal_list) >= self.num - len(next_margin_list):
break
if stock in (set(after_sell_list) | set(next_margin_list)):
continue
else:
if self.with_st:
pass
else:
if stock_status.loc[stock] in [1,2]:
continue
if stock_status.loc[stock] in [4,6]:
limit_up_list.append(stock)
next_normal_list.append(stock)
next_normal_list = list(set(next_normal_list))
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
next_position['date'] = date
# Number of margin-trading names
margin_num = len(next_margin_list)
next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), untradable_list, next_position)
next_position['margin_trade'] = 0
next_position = next_position.set_index(['stock_code'])
next_position.loc[next_margin_list, 'margin_trade'] = 1
next_position = next_position.reset_index()
# Drop limit-up stocks that cannot be bought; their share of the position stays empty
next_position = next_position[~next_position['stock_code'].isin(limit_up_list)]
# Check whether the resulting holdings can be traded today
frozen_list = []
if len(self.position) > 0:
for stock in next_position['stock_code']:
if stock_status.loc[stock] in [0,2]:
frozen_list.append(stock)
return sell_list, buy_list, frozen_list, next_position
def get_price(self,
              trade_time: str='open',
              target: Iterable[str]=(),
              buy_list: Iterable[str]=(),
              sell_list: Iterable[str]=()):
    """
    Return the price of every stock in `target` at `trade_time`.

    Stocks that are neither bought nor sold get the plain price of
    `self.today_data[trade_time]` (missing values become 0). Stocks being sold
    get the trade price reduced by the sell fee, stocks being bought the trade
    price increased by the buy fee. The trade price comes from
    `self.today_data[trade_time + '_trade']` when that entry exists, otherwise
    from the same source as the plain price.

    Args:
        trade_time (str): trade time, a key of `self.today_data`.
        target (Iterable[str]): stock codes to price.
        buy_list (Iterable[str]): stocks being bought; only those in `target` count.
        sell_list (Iterable[str]): stocks being sold; only those in `target` count.

    Returns:
        pd.Series: prices indexed by `target`, in the given order.
    """
    # Materialise once: the arguments are read several times below, which a
    # one-shot iterator would not survive.
    target = list(target)
    wanted = set(target)
    sell_list = list(wanted & set(sell_list))
    buy_list = list(wanted & set(buy_list))

    basic_price = self.today_data[trade_time]
    trade_price = self.today_data.get(trade_time+'_trade', basic_price)

    buy_fee, sell_fee = self.current_fee[0], self.current_fee[1]
    target_price = pd.Series(index=target, dtype=float)
    # Plain price first, then overwrite the traded stocks with fee-adjusted prices
    target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
    target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - sell_fee)
    target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + buy_fee)
    return target_price
def check_update_status(self,
                        date: str,
                        trade_time: str):
    """
    Tell whether (date, trade_time) has already been processed.

    Returns:
        False when `self.account_history` is empty;
        True when `date` is earlier than the latest recorded date;
        otherwise whether a record for exactly this date and trade time exists.
    """
    history = self.account_history
    if len(history) == 0:
        return False
    if date < history['date'].max():
        return True
    recorded = history['date'].str.cat(history['trade_time'], sep='-').values
    return f'{date}-{trade_time}' in recorded
def reblance_weight(self,
trade_time: str,
cur_pos: pd.DataFrame,
next_position: pd.DataFrame):
"""
Dynamically rebalance the target weights against the current holdings.

Args:
trade_time (str): trade time, passed on to get_price().
cur_pos (pd.DataFrame): current holdings; the stock_code, weight and close
columns are read.
next_position (pd.DataFrame): target holdings; the stock_code, date, weight
and margin_trade columns are read. Helper columns are added to the
passed frame before it is re-indexed.

Returns:
pd.DataFrame with columns stock_code, date, open, weight, margin_trade,
where `weight` is the achievable weight and `open` the blended cost price.
"""
# Work out which holdings are frozen
# Status codes 0/2/4/6 block weight increases, 0/2/5/7 block decreases
# (presumably suspended / limit-up / limit-down - TODO confirm the code table).
stock_status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
buy_frozen_list = []
for stock in cur_pos['stock_code']:
if stock_status.loc[stock] in [0,2,4,6]:
buy_frozen_list.append(stock)
sell_frozen_list = []
for stock in cur_pos['stock_code']:
if stock_status.loc[stock] in [0,2,5,7]:
sell_frozen_list.append(stock)
# Set the target weights
next_position['target_weight'] = next_position['weight']
next_position['current_weight'] = 0
cur_pos = cur_pos.set_index(['stock_code'])
next_position = next_position.set_index(['stock_code'])
current_list = list(set(cur_pos.index) & set(next_position.index))
next_position.loc[current_list, 'current_weight'] = cur_pos.loc[current_list, 'weight']
# Cost price of an existing holding is its latest `close`; 0 for new names
next_position['open'] = 0
next_position.loc[current_list, 'open'] = cur_pos.loc[current_list, 'close']
# Ideal weight change
next_position['weight_chg'] = next_position['weight'] - next_position['current_weight']
# Decide from the frozen lists whether the change can actually happen
next_position['final_weight'] = 0
# NOTE(review): from here on the frozen lists are sets and are used as .loc
# indexers; recent pandas versions reject set indexers - verify against the
# pandas version in use.
buy_frozen_list = set(buy_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] > 0].index)
sell_frozen_list = set(sell_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] < 0].index)
next_position.loc[next_position.index, 'final_weight'] = next_position['weight']
# Frozen holdings keep their current weight
next_position.loc[buy_frozen_list, 'final_weight'] = next_position.loc[buy_frozen_list, 'current_weight']
next_position.loc[sell_frozen_list, 'final_weight'] = next_position.loc[sell_frozen_list, 'current_weight']
# Rebalance: rescale so that the total equals the total target weight
# NOTE(review): divides by the sum of final_weight; a zero sum yields NaN.
next_position['final_weight'] /= next_position['final_weight'].sum()
next_position['final_weight'] *= next_position['weight'].sum()
# Weight change after rebalancing; frozen holdings do not change
next_position['weight_chg'] = next_position['final_weight'] - next_position['current_weight']
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'weight_chg'] = 0
# Prices used for the adjustment: buy price for increases, sell price for decreases
next_position['adjust_price'] = 0
buy_adjust_list = next_position[next_position['weight_chg'] > 0].index.values
sell_adjust_list = next_position[next_position['weight_chg'] < 0].index.values
next_position.loc[buy_adjust_list, 'adjust_price'] = self.get_price(trade_time, buy_adjust_list, buy_adjust_list, []).values
next_position.loc[sell_adjust_list, 'adjust_price'] = self.get_price(trade_time, sell_adjust_list, [], sell_adjust_list).values
# Price adjustment: weight-blended cost of the old holding and the traded part
next_position['adjust_open'] = (next_position['current_weight']*next_position['open'] + next_position['weight_chg']*next_position['adjust_price'])
next_position['adjust_open'] = next_position['adjust_open'] / next_position['final_weight']
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'adjust_open'] = next_position.loc[list(buy_frozen_list | sell_frozen_list), 'open']
# Stocks bought today are not adjusted
next_position['open'] = next_position['adjust_open']
next_position['weight'] = next_position['final_weight']
next_position = next_position.reset_index()
next_position = next_position[['stock_code','date','open','weight','margin_trade']]
return next_position
def update_account(self,
                   date: str,
                   trade_time: str,
                   cur_pos: pd.DataFrame,
                   next_position: pd.DataFrame):
    """
    Apply the period's profit to the account and log one history record.

    Args:
        date (str): trading day.
        trade_time (str): trade time.
        cur_pos (pd.DataFrame): current holdings; the stock_code, weight and
            end_weight columns are read.
        next_position (pd.DataFrame): holdings for the next moment; the
            stock_code and weight columns are read.

    Returns:
        True. `self.account` is multiplied by (1 + pnl) and a row with date,
        trade_time, turnover, position_ratio and pnl is appended to
        `self.account_history`.
    """
    weights = pd.concat([
        cur_pos.set_index(['stock_code'])['weight'].fillna(0).rename('cur'),
        next_position.set_index(['stock_code'])['weight'].rename('next'),
    ], axis=1)
    # NOTE(review): 'next' is not zero-filled, so a stock that leaves the
    # portfolio yields NaN here and is skipped by sum(); fully sold names are
    # therefore not counted in turnover. Kept as is - confirm it is intended.
    turnover = (weights['next'] - weights['cur'].fillna(0)).abs().sum()
    position_ratio = next_position['weight'].sum()

    held = cur_pos['weight'].sum()
    if held == 0:
        pnl = 0
    else:
        # Un-invested cash earns nothing but is part of the account value
        cash = max(0, 1 - held)
        pnl = ((cur_pos['end_weight'].sum() + cash) / (held + cash)) - 1
    self.account *= 1+pnl

    # DataFrame.append was removed in pandas 2.0; pd.concat works on every version.
    record = pd.DataFrame([{
        'date': date,
        'trade_time': trade_time,
        'turnover': turnover,
        'position_ratio': position_ratio,
        'pnl': pnl
    }])
    self.account_history = pd.concat([self.account_history, record], ignore_index=True)
    return True
@staticmethod
def update_next_weight(position):
"""
根据收盘权重计算下一时刻新的个股权重
"""
if position['weight'].sum() <= 1 + 1e-5:
# 非融资情况
cash = max(0, 1 - position['weight'].sum())
return position['end_weight'] / (position['end_weight'].sum() + cash)
else:
return position['weight'].sum() * (position['end_weight'] / position['end_weight'].sum())
def update_signal(self,
date:str,
update_type='rtn'):
"""
Run one day of the backtest: trade on every signal, then mark to the close.

Args:
date (str): day to process.
update_type (str): update type
- position: only compute the next holdings, no returns are recorded
- rtn: update returns and holdings

Returns:
True.
"""
# Skip the day if its close has already been recorded; otherwise drop any existing records of the day and process it again
if f'{date}-close' in self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values:
return True
else:
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
if date in self.position_history:
self.position_history.pop(date)
# ----- Load the day's backtest data ------
# Update the current date and load the data
self.current_date = date
self.load_data(date, update_type)
# Update the day's position ratio
# NOTE(review): today_position_ratio is only assigned when position_ratio is
# a float or a pd.Series (an int, for example, leaves it untouched).
if isinstance(self.position_ratio, float):
self.today_position_ratio = self.position_ratio
if isinstance(self.position_ratio, pd.Series):
self.today_position_ratio = self.position_ratio.loc[date]
# Update the fees: (buy, sell) = commission + slippage + stamp duty
fee = (self.commission + self.slippage[0], self.commission + self.slippage[1])
current_tax = (0.001, 0.001)
# The last schedule entry (in dict order) whose key is strictly smaller than
# `date` as a string wins.
# NOTE(review): string comparison - assumes `date` and the tax keys share one
# date format; the effective date itself still uses the previous rate.
for time,tax_rate in self.tax.items():
if date > time:
current_tax = tax_rate
fee = (fee[0] + current_tax[0], fee[1] + current_tax[1])
self.current_fee = fee
# If there are holdings they carry the overnight return, otherwise buy directly
position_fields = ['stock_code','date','weight','margin_trade','open','close','end_weight']
if len(self.position) == 0:
cur_pos = pd.DataFrame(columns=position_fields)
else:
cur_pos = self.position.copy()
# Frozen list
frozen_list = []
# ----- Walk through the signal of every trade time -----
for _,trade_time in enumerate(self.signal):
if self.check_update_status(date, trade_time):
continue
if date in self.signal[trade_time].index:
factor = self.signal[trade_time].loc[date]
else:
continue
# NOTE(review): repeats the lookup made a few lines above
factor = self.signal[trade_time].loc[date]
# Get the trades and the next holdings
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
# Backtest mode vs position mode: backtest mode records returns, position mode only records the next day's holdings and stops
if update_type == 'position':
self.position_history[date] = next_position.copy()
return True
if len(cur_pos) > 0:
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
# Suspended stocks keep their price
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
# Compute the return (here `rtn` is the gross ratio close / open)
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# Missing prices: fall back to the starting weight
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position)
# Update the weights
cur_pos['weight'] = self.update_next_weight(cur_pos)
# Adjust the weights: buys, sells and rebalancing
next_position = self.reblance_weight(trade_time, cur_pos, next_position)
else:
next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values, buy_list, []).values
self.position = next_position.copy()
# Day's return as of the close
trade_time = 'close'
if self.check_update_status(date, trade_time):
return True
cur_pos = self.position.copy()
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
# Suspended stocks keep their price
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
# Update the day's return (here `rtn` is the net return close / open - 1)
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# Missing prices: fall back to the starting weight
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy()
position_record['end_weight'] = self.update_next_weight(position_record)
self.position_history[date] = position_record.copy()[position_fields]
# The post-close weights become the opening weights of the next period
cur_pos['weight'] = self.update_next_weight(cur_pos)
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
next_position['open'] = cur_pos['close']
self.update_account(date, trade_time, cur_pos, cur_pos)
# Record the final holdings and stock weights of this moment
self.position = next_position.copy()
return True