# spread_backtest/trader.py
# Exported from a repository web view (704 lines, 34 KiB, Python; "Raw / Normal View / History").
# Blame timestamp: 2024-05-22 23:33:19 +08:00
import pandas as pd
import numpy as np
import sys, os
# blame timestamp: 2024-05-28 00:03:38 +08:00
import copy
# blame timestamp: 2024-05-22 23:33:19 +08:00
# Make the project-local factor tools importable. NOTE: this absolute home-directory
# path is machine-specific and must exist for the import below to succeed.
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
# Module-level factor-tools instance, created at import time; Trader.__init__ uses
# gft.return_factor to normalise every signal frame.
gft = get_factor_tools()
# blame timestamp: 2024-05-22 23:33:19 +08:00
from typing import Union, Iterable, Dict
from ordered_set import OrderedSet
# blame timestamp: 2024-05-22 23:33:19 +08:00
from account import Account
from dataloader import DataLoader
class Trader(Account):
    """
    Trading class: controls the trading performed on each day.

    On every date the signals are processed in dict order. Each signal key is a trade
    time whose price data (``data_root[key]``) is used to execute that signal. After all
    signals have traded, the position is marked to the close.

    Args:
        signal (dict[str, pd.DataFrame]): target factors, executed in order. Each frame
            has dates as its index and stock codes as its columns. It is normalised with
            ``gft.return_factor(..., return_type='origin')``. The key 'close' is reserved.
        interval (int, tuple, pd.Series, or a dict of these keyed by signal): trading interval.
            - int k: trade on every k-th date of the signal with a turnover fraction of 1.
            - tuple (k, r): trade on every k-th date with a turnover fraction of r.
            - pd.Series: explicit turnover fraction per date.
            A non-dict value is applied to every signal. On each date, the fraction times
            ``num`` is the number of holdings that may be sold by ranking.
        num (int): number of holdings.
        ascending (bool): factor direction. The factor is sorted with this flag and the
            head of the sorted order is bought.
        with_st (bool): whether ST stocks are eligible.
        tick (bool): tick simulation mode (under development; not used by this class).
        weight (str or pd.DataFrame): weight allocation.
            - 'avg': equal weight; reallocated at the open, not reallocated intraday.
            - pd.DataFrame: custom per-date, per-stock weights, normalised automatically.
        amt_filter (set, list or tuple): 20-day average turnover filter. The first value
            is the (exclusive) lower bound; the optional second value is the (exclusive)
            upper bound. A set is read in sorted order.
        data_root (dict): price data directory for each signal, holding one '{date}.csv'
            per date with stock_code and price columns. Must contain 'basic'. The 'open'
            signal falls back to the basic data's open_post column when it has no entry.
        ipo_days (int): minimum number of days since listing (exclusive).
        slippage (tuple): (buy, sell) slippage.
        commission (float): commission rate, added to both the buy and the sell cost.
        tax (dict): stamp-tax schedule, effective date -> (buy rate, sell rate).
        exclude_list (list): extra exclusion rules. These are enforced (held stocks are
            sold) before the normal rebalance.
            - abnormal: exclude stocks whose basic-data `abnormal` value is > 0 (abnormal
              announcements such as delisting suspension, investigations, warning letters).
            - report: documented as "year-on-year earnings drop of more than 50%", but not
              implemented in this class.
        account (dict): account settings forwarded to account.Account.
    """

    # Opening-status codes from the basic data's `opening_info` column, grouped by how
    # this class treats them (their exact meaning is defined by the data source).
    _STATUS_HOLD = (0, 2)              # suspended or limit-down: keep holding
    _STATUS_LIMIT_UP = (4, 6)          # limit-up: cannot be bought
    _STATUS_FROZEN = (0, 2, 5, 7)      # cannot be sold; hold-only when ST stocks are allowed
    _STATUS_BUYABLE_NON_ST = (3, 6)    # selectable when ST stocks are excluded
    _STATUS_HOLD_NON_ST = (0, 2, 7)    # hold-only when ST stocks are excluded

    # Default stamp-tax schedule: effective date -> (buy rate, sell rate).
    DEFAULT_TAX = {
        '1990-01-01': (0.001, 0.001),
        '2008-04-24': (0.001, 0.001),
        '2008-09-19': (0, 0.001),
        '2023-08-28': (0, 0.0005),
    }

    def __init__(self,
                 signal: Dict[str, pd.DataFrame] = None,
                 interval: Union[int, tuple, pd.Series, Dict[str, Union[int, tuple, pd.Series]]] = 1,
                 num: int = 100,
                 ascending: bool = False,
                 with_st: bool = False,
                 data_root: dict = None,
                 tick: bool = False,
                 weight: Union[str, pd.DataFrame] = 'avg',
                 amt_filter: Union[set, list, tuple] = (0,),
                 ipo_days: int = 20,
                 slippage: tuple = (0.001, 0.001),
                 commission: float = 0.0001,
                 tax: dict = None,
                 exclude_list: list = None,
                 account: dict = None,
                 **kwargs) -> None:
        # Initialise the account with a fresh dict per instance (no shared mutable default).
        super().__init__({} if account is None else account)
        # --------------------
        # Parameter validation
        # --------------------
        if len(kwargs) > 0:
            raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")

        # signal: build a new dict so the caller's dict is not overwritten in place.
        if not isinstance(signal, dict):
            raise ValueError('type of signal is invalid')
        if 'close' in signal:
            raise ValueError('signal key cannot be close')
        if len(signal) == 0:
            raise ValueError('signal should contain at least one factor')
        self.signal = {
            s: gft.return_factor(f, f.index.min(), f.index.max(), return_type='origin')
            for s, f in signal.items()
        }

        # interval: one schedule per signal. The concatenated `self.interval` is kept for
        # backward compatibility, but lookups prefer the per-signal schedule.
        self.signal_interval = {}
        for s, factor in self.signal.items():
            if isinstance(interval, dict):
                if s not in interval:
                    raise ValueError(f'not found interval for signal {s}')
                s_interval = interval[s]
            else:
                s_interval = interval
            self.signal_interval[s] = self._build_interval(s_interval, factor.index)
        self.interval = pd.concat(list(self.signal_interval.values()))

        # num
        if not isinstance(num, int):
            raise ValueError('num should be int')
        self.num = int(num)
        # ascending
        if not isinstance(ascending, bool):
            raise ValueError('invalid type for `ascending`')
        self.ascending = ascending
        # with_st
        if not isinstance(with_st, bool):
            raise ValueError('invalid type for `with_st`')
        self.with_st = with_st
        # weight: only 'avg' is implemented for string modes.
        if isinstance(weight, str):
            if weight != 'avg':
                raise ValueError(f"unsupported weight mode '{weight}', only 'avg' is supported")
        elif not isinstance(weight, pd.DataFrame):
            raise ValueError('invalid type for `weight`')
        self.weight = weight
        # amt_filter: a set has no order, so read it sorted (lower bound first).
        if not isinstance(amt_filter, (set, list, tuple)):
            raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
        bounds = sorted(amt_filter) if isinstance(amt_filter, set) else list(amt_filter)
        if len(bounds) == 0:
            raise ValueError('amt_filter should contain at least a lower bound')
        self.amt_filter_min = bounds[0]
        self.amt_filter_max = bounds[1] if len(bounds) > 1 else np.inf
        # ipo_days
        if not isinstance(ipo_days, int):
            raise Exception('wrong type for ipo_days, `int` is required')
        self.ipo_days = ipo_days
        # slippage
        if not (isinstance(slippage, tuple) and len(slippage) == 2):
            raise ValueError('slippage should be set.')
        self.slippage = slippage
        # commission: accept ints such as 0 as well as floats.
        if isinstance(commission, bool) or not isinstance(commission, (int, float)):
            raise ValueError('commission should be float.')
        self.commission = float(commission)
        # tax
        if tax is None:
            tax = dict(self.DEFAULT_TAX)
        if not isinstance(tax, dict):
            raise ValueError('tax should be dict.')
        self.tax = tax
        # exclude_list
        if exclude_list is None:
            exclude_list = []
        if not isinstance(exclude_list, list):
            raise ValueError('exclude_list should be list.')
        self.exclude_list = list(exclude_list)

        # data_root: must contain the basic data. Every signal except 'open' needs its
        # own price data ('open' falls back to the basic data).
        if data_root is None:
            data_root = {}
        if len(data_root) <= 0:
            raise ValueError('num of data_root should be equal or greater than 1')
        if 'basic' not in data_root:
            raise ValueError('data_root should contain basic data root')
        for s in self.signal:
            if s != 'open' and s not in data_root:
                raise ValueError(f'data for signal {s} is not provided')
        self.data_root = data_root
        # Tradable dates = files present in the basic data directory.
        self.update_avaliable_date()

    @staticmethod
    def _build_interval(s_interval, dates) -> pd.Series:
        """Expand one interval spec into a per-date turnover-fraction Series over `dates`."""
        if isinstance(s_interval, pd.Series):
            return s_interval
        if isinstance(s_interval, int):
            step, ratio = s_interval, 1
        elif isinstance(s_interval, tuple):
            step, ratio = s_interval[0], s_interval[1]
        else:
            raise ValueError('invalid interval type')
        # Float dtype so fractional ratios can be stored without upcasting.
        schedule = pd.Series(0.0, index=dates)
        schedule.iloc[::step] = ratio
        return schedule

    def load_data(self,
                  date: str,
                  update_type: str = 'rtn'):
        """
        Load the day's data into `self.today_data`.

        Args:
            date (str): date, used as the '{date}.csv' file name.
            update_type (str): update mode.
                - rtn: load the basic data plus the price data of every signal and of 'close'.
                - position: load only the basic data (enough to decide holdings).
        Returns:
            True
        """
        self.today_data = dict()
        basic = DataLoader(os.path.join(self.data_root['basic'], f'{date}.csv'))
        self.today_data['basic'] = basic
        if update_type == 'position':
            return True
        for s in self.signal:
            if s == 'open' and s not in self.data_root:
                # No dedicated open-price data: use the basic data's open_post column.
                self.today_data[s] = DataLoader(basic.data[['open_post']].rename(columns={'open_post': 'price'}))
            else:
                self.today_data[s] = DataLoader(os.path.join(self.data_root[s], f'{date}.csv'))
        if 'close' not in self.signal:
            self.today_data['close'] = DataLoader(basic.data[['close_post']].rename(columns={'close_post': 'price'}))
        return True

    def update_avaliable_date(self):
        """Refresh the tradable dates from the file names in the basic data directory."""
        dates = [f.split('.')[0] for f in os.listdir(self.data_root['basic'])]
        self.avaliable_date = pd.Series(index=dates, dtype='object').sort_index()

    def get_weight(self, date, total_weight, next_position):
        """
        Compute the weight of each stock in `next_position`.

        Args:
            date: trading date (row label of a custom weight frame).
            total_weight (float): total weight to distribute.
            next_position (pd.DataFrame): must contain a stock_code column.
        Returns:
            A scalar ('avg'; 0.0 for an empty position) or an array aligned with the rows
            of `next_position` (custom weights, rescaled to sum to `total_weight`).
        Raises:
            ValueError: a stock has no weight on `date`, or the weight mode is unknown.
        """
        if isinstance(self.weight, pd.DataFrame):
            date_weight = self.weight.loc[date].dropna().sort_index()
            try:
                weight_list = date_weight.loc[next_position['stock_code'].to_list()].values
            except KeyError as err:
                raise ValueError(f'not found stock weight in {date}') from err
            return total_weight * weight_list / sum(weight_list)
        if self.weight == 'avg':
            if len(next_position) == 0:
                return 0.0
            return total_weight / len(next_position)
        raise ValueError(f'invalid weight mode {self.weight!r}')

    def _get_interval(self, date, trade_time=None):
        """
        Return the turnover fraction scheduled for `date`.

        Uses the schedule of `trade_time` when one exists. Otherwise it falls back to the
        combined schedule, taking the first entry when several signals share the date.
        """
        schedules = getattr(self, 'signal_interval', {})
        if trade_time in schedules:
            value = schedules[trade_time].loc[date]
        else:
            value = self.interval.loc[date]
        if isinstance(value, pd.Series):
            value = value.iloc[0]
        return value

    def _screen_universe(self, factor):
        """
        Apply today's screens to the factor universe.

        Returns:
            stock_status (pd.Series): `opening_info` per stock.
            exclude_stock (set): stocks that may not be selected (IPO-age and 20-day
                turnover screens, plus the forced exclusions).
            force_exclude (set): stocks hit by `exclude_list` rules; held ones are sold.
        """
        basic = self.today_data['basic']
        codes = factor.index.values
        stock_status = basic.get(codes, 'opening_info').sort_index()
        stock_amt20 = basic.get(codes, 'amount_20').sort_index()
        stock_ipo_days = basic.get(codes, 'ipo_days').sort_index()
        force_exclude = set()
        for exclude in self.exclude_list:
            if exclude == 'abnormal':
                stock_abnormal = basic.get(codes, 'abnormal').sort_index()
                force_exclude.update(stock_abnormal.index[stock_abnormal > 0])
        # NaN values fail these comparisons, so stocks without data are excluded as well.
        amt_ok = (stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)
        ipo_ok = stock_ipo_days > self.ipo_days
        exclude_stock = force_exclude | set(amt_ok.index[~amt_ok]) | set(ipo_ok.index[~ipo_ok])
        return stock_status, exclude_stock, force_exclude

    def _rank_candidates(self, factor, exclude_stock):
        """Return the non-excluded, non-NaN stock codes in factor order (ties keep factor order)."""
        candidates = factor[~factor.index.isin(list(exclude_stock))].dropna()
        return candidates.sort_values(ascending=self.ascending, kind='mergesort').index.values

    def _next_position_unlevered(self, date, factor, last_position, max_sell_num,
                                 stock_status, exclude_stock, force_exclude):
        """Select targets, sells and buys for an account without leverage (leverage <= 1)."""
        held = set(last_position.index)
        # 1. Held stocks that are suspended / limit-down stay in the target list.
        target_list = [s for s in last_position.index if stock_status.loc[s] in self._STATUS_HOLD]
        target_set = set(target_list)
        # 2. Fill the remaining slots in factor order after the screens.
        for stock in self._rank_candidates(factor, exclude_stock):
            if len(target_list) >= self.num:
                break
            if stock in target_set:
                # Already kept in step 1; do not count it twice against `num`.
                continue
            status = stock_status.loc[stock]
            if self.with_st:
                # Frozen stocks may only be kept if already held.
                keep = (stock in held) if status in self._STATUS_FROZEN else True
            elif status in self._STATUS_BUYABLE_NON_ST:
                keep = True
            else:
                keep = status in self._STATUS_HOLD_NON_ST and stock in held
            if keep:
                target_list.append(stock)
                target_set.add(stock)

        # ----- sell -----
        # Forced exclusions are sold first; they do not use up the ranking quota.
        # NOTE(review): this also "sells" frozen holdings, which are then re-bought below.
        sell_list = [s for s in last_position.index if s in force_exclude]
        sold = set(sell_list)
        force_sell_num = len(sell_list)
        # Missing factor values rank worst, so such holdings are sold first.
        if self.ascending:
            factor = factor.fillna(factor.max() + 1)
        else:
            factor = factor.fillna(factor.min() - 1)
        # Walk holdings from worst to best factor value.
        for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= max_sell_num + force_sell_num:
                break
            if stock in sold or stock in target_set or stock_status.loc[stock] in self._STATUS_FROZEN:
                continue
            sell_list.append(stock)
            sold.add(stock)

        # ----- buy -----
        kept = [s for s in last_position.index if s not in sold]
        kept_set = set(kept)
        max_buy_num = max(0, self.num - len(last_position) + len(sell_list))
        buy_list = []
        cant_buy_list = []  # limit-up stocks
        for stock in target_list:
            # Check before appending so a zero quota buys nothing.
            if len(buy_list) >= max_buy_num:
                break
            if stock in kept_set:
                continue
            if stock_status.loc[stock] in self._STATUS_LIMIT_UP:
                cant_buy_list.append(stock)
            buy_list.append(stock)

        next_position = pd.DataFrame({'stock_code': kept + buy_list})
        next_position['date'] = date
        next_position['weight'] = self.get_weight(date, self.leverage, next_position)
        # Limit-up stocks cannot be bought: drop them and leave their weight in cash.
        next_position = next_position[~next_position['stock_code'].isin(cant_buy_list)].copy()
        next_position['margin_trade'] = 0
        return sell_list, buy_list, next_position

    def _next_position_levered(self, date, factor, last_position, max_sell_num,
                               stock_status, exclude_stock, force_exclude):
        """Select targets, sells and buys with leverage, split into margin and normal pools."""
        # Number of margin-pool slots implied by the leverage above 1.
        margin_ratio = max(self.leverage - 1, 0)
        margin_needed = round(self.num * margin_ratio)
        normal_needed = self.num - margin_needed
        is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()

        if len(last_position) > 0:
            last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
            last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
        else:
            last_margin_list, last_normal_list = [], []
        last_margin_set, last_normal_set = set(last_margin_list), set(last_normal_list)

        # 1. Held stocks that are suspended / limit-down stay in their pool.
        margin_list = [s for s in last_margin_list if stock_status.loc[s] in self._STATUS_HOLD]
        normal_list = [s for s in last_normal_list if stock_status.loc[s] in self._STATUS_HOLD]
        selected = set(margin_list) | set(normal_list)

        # 2. Fill both pools in factor order after the screens.
        for stock in self._rank_candidates(factor, exclude_stock):
            if len(normal_list) + len(margin_list) >= self.num:
                break
            if stock in selected:
                continue
            status = stock_status.loc[stock]
            if self.with_st:
                buyable = status not in self._STATUS_FROZEN
                hold_only = not buyable
            else:
                buyable = status in self._STATUS_BUYABLE_NON_ST
                hold_only = status in self._STATUS_HOLD_NON_ST
            if buyable:
                # Margin-eligible stocks fill the margin pool while it has room (they do not
                # overflow into the normal pool); other stocks fill the normal pool.
                if is_margin.loc[stock] == 1:
                    if len(margin_list) < margin_needed:
                        margin_list.append(stock)
                        selected.add(stock)
                elif len(normal_list) < normal_needed:
                    normal_list.append(stock)
                    selected.add(stock)
            elif hold_only:
                # Frozen stocks are only kept if already held, in the pool they were held in.
                if stock in last_margin_set:
                    margin_list.append(stock)
                    selected.add(stock)
                elif stock in last_normal_set:
                    normal_list.append(stock)
                    selected.add(stock)
        target_set = selected

        # ----- sell: margin pool -----
        sell_list = [s for s in last_margin_list if s in force_exclude]
        sold = set(sell_list)
        force_sell_num = len(sell_list)
        # NOTE(review): the "+ 1" allowance is kept as-is; its intent is not documented.
        margin_sell_cap = int(max_sell_num * margin_ratio) + force_sell_num + 1
        for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= margin_sell_cap:
                break
            if stock in sold or stock in target_set or stock_status.loc[stock] in self._STATUS_FROZEN:
                continue
            sell_list.append(stock)
            sold.add(stock)
        next_margin_list = [s for s in last_margin_list if s not in sold]

        # ----- sell: normal pool -----
        for stock in last_normal_list:
            if stock in force_exclude and stock not in sold:
                sell_list.append(stock)
                sold.add(stock)
                force_sell_num += 1
        for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= max_sell_num + force_sell_num:
                break
            if stock in sold or stock in target_set or stock_status.loc[stock] in self._STATUS_FROZEN:
                continue
            sell_list.append(stock)
            sold.add(stock)
        next_normal_list = [s for s in last_normal_list if s not in sold]

        # ----- buy: each pool up to its own capacity -----
        after_sell = set(last_position.index) - sold
        buy_list = []
        cant_buy_list = []  # limit-up stocks
        for pool, next_pool, capacity in ((margin_list, next_margin_list, margin_needed),
                                          (normal_list, next_normal_list, normal_needed)):
            for stock in pool:
                # Check before appending so a full pool is not overfilled.
                if len(next_pool) >= capacity:
                    break
                if stock in after_sell:
                    continue
                if stock_status.loc[stock] in self._STATUS_LIMIT_UP:
                    cant_buy_list.append(stock)
                next_pool.append(stock)
                buy_list.append(stock)

        next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
        next_position['date'] = date
        # Total weight = 1 + the share of margin holdings.
        margin_num = len(next_margin_list)
        next_position['weight'] = self.get_weight(date, 1 + (margin_num / self.num), next_position)
        next_position['margin_trade'] = next_position['stock_code'].isin(next_margin_list).astype(int)
        # Limit-up stocks cannot be bought: drop them and leave their weight in cash.
        next_position = next_position[~next_position['stock_code'].isin(cant_buy_list)].copy()
        return sell_list, buy_list, next_position

    def get_next_position(self, date, factor, trade_time=None):
        """
        Compute the position for the next moment.

        Args:
            date: trading date.
            factor (pd.Series): today's factor values indexed by stock code.
            trade_time (str, optional): signal key whose interval schedule applies.
        Returns:
            (sell_list, buy_list, frozen_list, next_position). frozen_list holds the stocks
            in next_position whose status is suspended / limit-down (only when a position
            existed). next_position has columns stock_code, date, weight, margin_trade.
        """
        if len(self.position) > 0:
            last_position = factor.loc[self.position['stock_code'].values].sort_values(ascending=self.ascending)
            turnover_num = int(self._get_interval(date, trade_time) * self.num)
            # Sell up to `turnover_num` by ranking, adjusted by the gap to `num`: fewer
            # when the book is under-filled, more when it holds too many stocks.
            max_sell_num = turnover_num + len(self.position) - self.num
        else:
            last_position = pd.Series(dtype='float64')
            max_sell_num = self.num

        stock_status, exclude_stock, force_exclude = self._screen_universe(factor)
        args = (date, factor, last_position, max_sell_num, stock_status, exclude_stock, force_exclude)
        if self.leverage <= 1.0:
            sell_list, buy_list, next_position = self._next_position_unlevered(*args)
        else:
            sell_list, buy_list, next_position = self._next_position_levered(*args)

        # Holdings that cannot trade right now.
        frozen_list = []
        if len(self.position) > 0:
            frozen_list = [s for s in next_position['stock_code'] if stock_status.loc[s] in self._STATUS_HOLD]
        return sell_list, buy_list, frozen_list, next_position

    def get_price(self,
                  trade_time: str = 'open',
                  target: Iterable[str] = (),
                  buy_list: Iterable[str] = (),
                  sell_list: Iterable[str] = ()):
        """
        Get execution prices.

        Args:
            trade_time (str): trade time (key of `self.today_data`).
            target (Iterable[str]): stocks to price; the result is in this order.
            buy_list (Iterable[str]): stocks priced as buys (price * (1 + buy cost)).
            sell_list (Iterable[str]): stocks priced as sells (price * (1 - sell cost)).
        Returns:
            pd.Series of prices indexed by `target`. Plain prices that are missing in the
            price data's result are filled with 0.
        """
        stock_price = self.today_data[trade_time]
        target = list(target)
        target_price = pd.Series(index=target, dtype='float64')
        target_set = set(target)
        sell_list = list(target_set & set(sell_list))
        buy_list = list(target_set & set(buy_list))
        if target:
            target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
        if sell_list:
            target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
        if buy_list:
            target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
        return target_price

    def check_update_status(self,
                            date: str,
                            trade_time: str):
        """
        Return True if (date, trade_time) needs no update.

        That is the case when `date` is earlier than the latest recorded date, or when
        the pair is already recorded in the account history.
        """
        history = self.account_history
        if len(history) == 0:
            return False
        if date < history['date'].max():
            return True
        recorded = history['date'].str.cat(history['trade_time'], sep='-').values
        return f'{date}-{trade_time}' in recorded

    def reblance_weight(self,
                        trade_time: str,
                        cur_pos: pd.DataFrame,
                        next_position: pd.DataFrame):
        """
        Dynamically rebalance from the current holdings towards `next_position`.

        Stocks that cannot trade in the needed direction keep their current weight.
        The other weights are rescaled so the total equals the target total. Each
        stock's cost price ('open') is blended with the price paid/received today.

        Returns:
            pd.DataFrame with columns stock_code, date, open, weight, margin_trade.
        """
        status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
        buy_frozen = {s for s in cur_pos['stock_code']
                      if status.loc[s] in self._STATUS_HOLD + self._STATUS_LIMIT_UP}
        sell_frozen = {s for s in cur_pos['stock_code'] if status.loc[s] in self._STATUS_FROZEN}

        cur = cur_pos.set_index('stock_code')
        nxt = next_position.set_index('stock_code')
        # Current weight and cost price for stocks held both now and next.
        nxt['current_weight'] = 0.0
        nxt['open'] = 0.0
        held = list(set(cur.index) & set(nxt.index))
        nxt.loc[held, 'current_weight'] = cur.loc[held, 'weight']
        nxt.loc[held, 'open'] = cur.loc[held, 'close']

        # Stocks whose ideal change is blocked keep their current weight.
        # Lists (not sets) are used as indexers: pandas >= 2 rejects sets in .loc.
        ideal_chg = nxt['weight'] - nxt['current_weight']
        frozen = list((buy_frozen & set(ideal_chg.index[ideal_chg > 0]))
                      | (sell_frozen & set(ideal_chg.index[ideal_chg < 0])))
        nxt['final_weight'] = nxt['weight']
        nxt.loc[frozen, 'final_weight'] = nxt.loc[frozen, 'current_weight']
        # Rescale so the total weight matches the target total.
        nxt['final_weight'] = nxt['final_weight'] / nxt['final_weight'].sum() * nxt['weight'].sum()

        # Actual changes and the (cost-adjusted) prices they trade at.
        nxt['weight_chg'] = nxt['final_weight'] - nxt['current_weight']
        nxt.loc[frozen, 'weight_chg'] = 0.0
        nxt['adjust_price'] = 0.0
        buy_adjust = nxt.index[nxt['weight_chg'] > 0].values
        sell_adjust = nxt.index[nxt['weight_chg'] < 0].values
        nxt.loc[buy_adjust, 'adjust_price'] = self.get_price(trade_time, buy_adjust, buy_adjust, []).values
        nxt.loc[sell_adjust, 'adjust_price'] = self.get_price(trade_time, sell_adjust, [], sell_adjust).values

        # Weighted-average cost price; frozen stocks keep their previous cost price.
        adjust_open = (nxt['current_weight'] * nxt['open'] + nxt['weight_chg'] * nxt['adjust_price']) / nxt['final_weight']
        adjust_open.loc[frozen] = nxt.loc[frozen, 'open']
        nxt['open'] = adjust_open
        nxt['weight'] = nxt['final_weight']
        return nxt.reset_index()[['stock_code', 'date', 'open', 'weight', 'margin_trade']]

    def update_account(self,
                       date: str,
                       trade_time: str,
                       cur_pos: pd.DataFrame,
                       next_position: pd.DataFrame):
        """
        Update the account value and append one record to the account history.

        Args:
            date (str): date.
            trade_time (str): trade time.
            cur_pos (DataFrame): current holdings with weight and end_weight columns.
            next_position (DataFrame): holdings for the next moment with a weight column.
        Returns:
            True
        """
        # Stocks missing on either side count as weight 0 (fully bought / fully sold).
        weights = pd.concat([
            cur_pos.set_index('stock_code')['weight'].rename('cur'),
            next_position.set_index('stock_code')['weight'].rename('next'),
        ], axis=1).fillna(0)
        turnover = (weights['next'] - weights['cur']).abs().sum()
        leverage = next_position['weight'].sum()
        if cur_pos['weight'].sum() == 0:
            pnl = 0
        else:
            pnl = cur_pos['end_weight'].sum() - cur_pos['weight'].sum()
        self.account *= 1 + pnl
        # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
        record = pd.DataFrame([{
            'date': date,
            'trade_time': trade_time,
            'turnover': turnover,
            'leverage': leverage,
            'pnl': pnl,
        }])
        self.account_history = pd.concat([self.account_history, record], ignore_index=True)
        return True

    def _update_fee(self, date: str):
        """Set `self.current_fee` = (buy, sell) cost rate for `date`: commission + slippage + stamp tax."""
        current_tax = (0.001, 0.001)
        # Use the latest schedule entry effective on or before `date`. Sorting makes this
        # independent of dict order. Assumes the keys use the same string format as
        # `date` (e.g. 'YYYY-MM-DD') — TODO confirm against the data file names.
        for effective_date, tax_rate in sorted(self.tax.items()):
            if date >= effective_date:
                current_tax = tax_rate
        self.current_fee = (self.commission + self.slippage[0] + current_tax[0],
                            self.commission + self.slippage[1] + current_tax[1])

    @staticmethod
    def _price_ratio(pos: pd.DataFrame) -> pd.Series:
        """close / open per row; rows with open == 0 (no usable entry price) get 1 (unchanged)."""
        return (pos['close'] / pos['open']).mask(pos['open'] == 0, 1.0)

    def update_signal(self,
                      date: str,
                      update_type='rtn'):
        """
        Update the signal returns for one date.

        Args:
            date (str): date to process.
            update_type (str): update type.
                - position: only compute the next position (stored in position_history).
                - rtn: trade every signal, record returns and mark to the close.
        Returns:
            True
        """
        # If this date's close is already recorded, skip it. Otherwise drop any partial
        # records for the date and recompute it.
        if len(self.account_history) > 0:
            recorded = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
            if f'{date}-close' in recorded:
                return True
            self.account_history = self.account_history[self.account_history['date'] != date]
        if date in self.position_history:
            self.position_history.pop(date)
        self.load_data(date, update_type)
        self._update_fee(date)

        # With an existing position, accrue returns before trading; otherwise buy directly.
        if len(self.position) == 0:
            cur_pos = pd.DataFrame(columns=['stock_code', 'date', 'weight', 'open', 'close', 'margin_trade'])
        else:
            cur_pos = self.position.copy()
        frozen_list = []
        for trade_time, factor_frame in self.signal.items():
            if self.check_update_status(date, trade_time):
                continue
            if date not in factor_frame.index:
                continue
            factor = factor_frame.loc[date]
            sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor, trade_time)
            # Position mode only records the next position and stops here.
            if update_type == 'position':
                self.position_history[date] = next_position.copy()
                return True
            if len(cur_pos) > 0:
                # Mark holdings to this trade time (sales net of sell costs).
                cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
                # Suspended stocks keep their price.
                frozen = cur_pos['stock_code'].isin(frozen_list)
                cur_pos.loc[frozen, 'close'] = cur_pos.loc[frozen, 'open']
                cur_pos['rtn'] = self._price_ratio(cur_pos)
                cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
                self.update_account(date, trade_time, cur_pos, next_position)
                cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
                # Adjust weights: buys, sells and rebalancing.
                next_position = self.reblance_weight(trade_time, cur_pos, next_position)
            else:
                next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values, buy_list, []).values
            self.position = next_position.copy()
            # The next trade time starts from the position just established.
            cur_pos = self.position.copy()

        # Position mode never marks to the close (its 'close' data was not loaded).
        if update_type == 'position':
            return True

        # Record the day's return at the close.
        trade_time = 'close'
        if self.check_update_status(date, trade_time):
            return True
        cur_pos = self.position.copy()
        cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
        # Suspended stocks and stocks without an entry price keep their price.
        frozen = cur_pos['stock_code'].isin(frozen_list)
        cur_pos.loc[frozen, 'close'] = cur_pos.loc[frozen, 'open']
        no_entry = cur_pos['open'] == 0
        cur_pos.loc[no_entry, 'close'] = cur_pos.loc[no_entry, 'open']
        cur_pos['rtn'] = self._price_ratio(cur_pos) - 1
        cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
        position_record = cur_pos.copy()
        position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
        cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
        next_position = cur_pos[['stock_code', 'date', 'weight', 'margin_trade']].copy()
        next_position['open'] = cur_pos['close']
        self.update_account(date, trade_time, cur_pos, cur_pos)
        self.position = next_position.copy()
        self.position_history[date] = position_record.copy()
        return True