spread_backtest/trader.py

550 lines
26 KiB
Python
Raw Normal View History

2024-05-22 23:33:19 +08:00
import pandas as pd
import numpy as np
import sys, os
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
from typing import Union, Iterable, Dict
from account import Account
from dataloader import DataLoader
class Trader(Account):
    """
    Trader: controls the day-by-day trading of the backtest.

    Args:
        signal (dict[str, pd.DataFrame]): target factors, keyed by trade time and executed in order
            (the key 'close' is reserved)
        interval (int, tuple, pd.Series): trading interval for each signal
        num (int): number of holdings
        ascending (bool): factor direction (True ranks smaller values first)
        fee (tuple): (buy cost, sell cost)
        with_st (bool): whether ST stocks are included
        tick (bool): whether to enable tick simulation mode (under development)
        weight (str): weight allocation
            - avg: equal weights, re-allocated every morning; not re-allocated for intraday trades
        amt_filter (set): 20-day average amount filter; the first value is the lower bound and the
            second the upper bound; the lower bound alone may be given
        data_root (dict): directory of trading-price data for each target factor (must contain the
            stock_code and price columns); a 'basic' entry is required
        ipo_days (int): listing-age filter
    """
def __init__(self,
             signal: Dict[str, pd.DataFrame]=None,
             interval: Union[int, tuple, pd.Series, Dict[str, Union[int, tuple, pd.Series]]]=1,
             num: int=100,
             ascending: bool=False,
             fee: tuple=(0.001, 0.002),
             with_st: bool=False,
             data_root: dict=None,
             tick: bool=False,
             weight: str='avg',
             amt_filter: Union[set, list, tuple]=(0,),
             ipo_days: int=20,
             **kwargs) -> None:
    """
    Validate the configuration and prepare signals, trade intervals and data roots.

    Args:
        signal (dict[str, pd.DataFrame]): factor per trade time, processed in dict
            order; the key 'close' is reserved. Each frame is passed through
            ``gft.return_factor(..., return_type='origin')``; the caller's dict is
            not modified.
        interval: trade interval, either a dict keyed like ``signal`` or a single
            int / tuple applied to every signal. Per signal:
            - int ``n``: trade ratio 1 on every n-th date of the signal index, 0 elsewhere
            - tuple ``(n, ratio)``: ``ratio`` on every n-th date, 0 elsewhere
            - pd.Series: used as-is
            The ratio times ``num`` is the number of names that may be replaced.
        num (int): number of holdings.
        ascending (bool): factor direction; True ranks smaller values first.
        fee (tuple): (buy cost, sell cost) as fractions of the price.
        with_st (bool): whether ST stocks are included.
        data_root (dict): directory of daily ``<date>.csv`` price files per signal
            (stock_code and price columns); must contain 'basic'. 'open' may be
            omitted.
        tick (bool): tick simulation mode (under development; not used here).
        weight (str): weight allocation (not used here)
            - avg: equal weights, re-allocated every morning, not re-allocated intraday
        amt_filter (set | list | tuple): 20-day average amount bounds as
            (lower[, upper]); a set is read in sorted order; the upper bound
            defaults to +inf.
        ipo_days (int): minimum listing age.
        **kwargs: ``account`` (dict) is forwarded to ``Account.__init__``.

    Raises:
        ValueError: invalid signal, interval, num, ascending, fee, with_st,
            empty amt_filter or invalid data_root.
        Exception: wrong type for amt_filter or ipo_days.
    """
    # Account state is initialised by the parent class from kwargs['account'].
    super().__init__(**kwargs.get('account', {}))
    # signal
    if not isinstance(signal, dict):
        raise ValueError('type of signal is invalid')
    if 'close' in signal:
        raise ValueError('signal key cannot be close')
    # Build a new dict so the caller's signal dict is left untouched.
    self.signal = {
        s: gft.return_factor(df, df.index.min(), df.index.max(), return_type='origin')
        for s, df in signal.items()
    }
    # interval
    if isinstance(interval, (int, tuple)):
        # A single scalar spec (e.g. the default interval=1) applies to every signal.
        interval = {s: interval for s in self.signal}
    intervals = []
    for s, factor in self.signal.items():
        if s not in interval:
            raise ValueError(f'not found interval for signal {s}')
        s_interval = interval[s]
        if isinstance(s_interval, int):
            # Float zeros so fractional ratios can be stored without dtype upcasts.
            df_interval = pd.Series(0.0, index=factor.index)
            df_interval.iloc[::s_interval] = 1
        elif isinstance(s_interval, tuple):
            df_interval = pd.Series(0.0, index=factor.index)
            df_interval.iloc[::s_interval[0]] = s_interval[1]
        elif isinstance(s_interval, pd.Series):
            df_interval = s_interval
        else:
            raise ValueError('invalid interval type')
        intervals.append(df_interval)
    # NOTE(review): with several signals sharing dates this index has duplicates, and
    # get_next_position's `self.interval.loc[date]` then returns a Series - confirm intended.
    self.interval = pd.concat(intervals)
    # num
    if not isinstance(num, int):
        raise ValueError('num should be int')
    self.num = int(num)
    # ascending
    if not isinstance(ascending, bool):
        raise ValueError('invalid type for ascending')
    self.ascending = ascending
    # fee: (buy cost, sell cost)
    if not (isinstance(fee, tuple) and len(fee) == 2):
        raise ValueError('invalid input for fee')
    self.fee = fee
    # with_st
    if not isinstance(with_st, bool):
        raise ValueError('invalid type for with_st')
    self.with_st = with_st
    # amt_filter: (lower[, upper])
    if not isinstance(amt_filter, (set, list, tuple)):
        raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
    # Sets are unordered and not indexable, so read them in sorted order.
    bounds = sorted(amt_filter) if isinstance(amt_filter, set) else list(amt_filter)
    if len(bounds) == 0:
        raise ValueError('amt_filter should contain at least a lower bound')
    self.amt_filter_min = bounds[0]
    self.amt_filter_max = bounds[1] if len(bounds) > 1 else np.inf
    # ipo_days
    if not isinstance(ipo_days, int):
        raise Exception('wrong type for ipo_days, `int` is required')
    self.ipo_days = ipo_days
    # data_root: at least the basic data directory; 'open' may fall back to basic data.
    if data_root is None:
        data_root = {}
    if len(data_root) <= 0:
        raise ValueError('num of data_root should be equal or greater than 1')
    if 'basic' not in data_root:
        raise ValueError('data_root should contain basic data root')
    # Executable dates: the '<date>.csv' files present in the basic data directory.
    basic_dates = [os.path.splitext(f)[0] for f in os.listdir(data_root['basic']) if f.endswith('.csv')]
    self.avaliable_date = pd.Series(index=basic_dates, dtype=float).sort_index()
    for s in self.signal:
        if s != 'open' and s not in data_root:
            raise ValueError(f'data for signal {s} is not provided')
    self.data_root = data_root
def load_data(self,
              date: str,
              update_type: str='rtn'):
    """
    Load the day's basic data and, in 'rtn' mode, the price data of every signal.

    Args:
        date (str): trading date; files are read as ``<root>/<date>.csv``.
        update_type (str): update mode
            - rtn: load the data of every signal
            - position: load only basic data (used for position decisions)

    Returns:
        bool: always True.
    """
    self.today_data = dict()
    self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'], f'{date}.csv'))
    if update_type == 'position':
        return True
    for s in self.signal:
        if s == 'open' and s not in self.data_root:
            # No dedicated 'open' directory: use basic data's open_post column as the price.
            self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post': 'price'}))
        else:
            # Any signal with a configured directory - including 'open' - is loaded from it.
            self.today_data[s] = DataLoader(os.path.join(self.data_root[s], f'{date}.csv'))
    # Close prices come from basic data's close_post column ('close' cannot be a signal key).
    if 'close' not in self.signal:
        self.today_data['close'] = DataLoader(self.today_data['basic'].data[['close_post']].rename(columns={'close_post': 'price'}))
    return True
def get_next_position(self, date, factor):
    """
    Compute the holdings for the next period from today's factor values.

    Candidates are visited in factor order (``self.ascending``) and must pass
    the 20-day-amount filter (``amount_20`` strictly between the configured
    bounds), the listing-age filter (``ipo_days > self.ipo_days``) and the
    ``opening_info`` status rule. Held stocks that are no longer targeted are
    sold worst-ranked first unless their status is 0/2/5/7, and freed slots are
    filled from the target list. Sells and buys are capped by the per-period
    budget ``max(int(interval[date] * num), num - len(position))`` and are
    never allowed to overshoot it.

    When ``self.leverage > 1`` the book is split into a margin pool
    (``margin_trade == 1``) of up to ``round(num * (leverage - 1))`` stocks whose
    basic-data ``margin_list`` flag is 1, and a normal pool for the rest.

    Args:
        date: trade date; looked up in ``self.interval`` and written to the
            ``date`` column of the result.
        factor (pd.Series): today's factor values indexed by stock code; held
            stocks are assumed to appear in this index.

    Returns:
        tuple: ``(sell_list, buy_list, frozen_list, next_position)``;
        ``next_position`` has columns ``stock_code``, ``date``, ``weight`` and
        ``margin_trade``; ``frozen_list`` holds the next-period stocks whose
        ``opening_info`` is 0 or 2 (always empty when nothing was held).
    """
    # ----- previous holdings and per-period trade budget -----
    if len(self.position) > 0:
        # Held stocks ranked by today's factor.
        last_position = factor.loc[self.position['stock_code'].values].sort_values(ascending=self.ascending)
        # Replace at most interval * num names, but always allow refilling empty slots.
        max_trade_num = max(int(self.interval.loc[date] * self.num), self.num - len(self.position))
    else:
        last_position = pd.Series(dtype=float)
        max_trade_num = self.num
    held = set(last_position.index)

    # ----- screening data -----
    codes = factor.index.values
    basic = self.today_data['basic']
    stock_status = basic.get(codes, 'opening_info').sort_index()
    stock_amt20 = basic.get(codes, 'amount_20').sort_index()
    stock_ipo_days = basic.get(codes, 'ipo_days').sort_index()
    # NaN amounts / listing ages compare False and are therefore filtered out.
    amt_ok = (stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)
    ipo_ok = stock_ipo_days > self.ipo_days

    def passes_filters(stock):
        # Amount and listing-age screens, applied in both branches.
        return bool(amt_ok.loc[stock]) and bool(ipo_ok.loc[stock])

    def is_candidate(stock):
        # Status rule for entering the target list.
        status = stock_status.loc[stock]
        if self.with_st:
            # Status 0/2 is only kept if already held.
            return status not in [0, 2] or stock in held
        if status in [3, 7]:
            return True
        # Original note: suspended or limit-down stocks (0/2/6) keep being held.
        return status in [0, 2, 6] and stock in held

    if self.leverage <= 1.0:
        # ===== no leverage =====
        target_list = []
        for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
            if len(target_list) >= self.num:
                break
            if passes_filters(stock) and is_candidate(stock):
                target_list.append(stock)
        target_set = set(target_list)
        # ----- sell: worst-ranked holdings first -----
        # Missing factor values are pushed to the bad end so they are sold first.
        if self.ascending:
            factor = factor.fillna(factor.max() + 1)
        else:
            factor = factor.fillna(factor.min() - 1)
        sell_list = []
        for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= max_trade_num:
                break
            # Keep names still targeted, and names that cannot be sold today (status 0/2/5/7).
            if stock in target_set or stock_status.loc[stock] in [0, 2, 5, 7]:
                continue
            sell_list.append(stock)
        # ----- buy: fill freed slots from the target list -----
        after_sell = held - set(sell_list)
        max_trade_num = min(max_trade_num, self.num - len(last_position) + len(sell_list))
        buy_list = []
        for stock in target_list:
            if len(buy_list) >= max_trade_num:
                break
            if stock not in after_sell:
                buy_list.append(stock)
        next_position = pd.DataFrame({'stock_code': list(after_sell | set(buy_list))})
        next_position['date'] = date
        # Equal weights summing to the leverage; max(..., 1) guards an empty book.
        next_position['weight'] = self.leverage / max(len(next_position), 1)
        next_position['margin_trade'] = 0
    else:
        # ===== leverage: margin pool + normal pool =====
        margin_ratio = max(self.leverage - 1, 0)
        margin_needed = round(self.num * margin_ratio)
        normal_cap = self.num - margin_needed
        is_margin = basic.get(codes, 'margin_list').sort_index()
        normal_list = []
        margin_list = []
        for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
            if len(normal_list) + len(margin_list) >= self.num:
                break
            if not (passes_filters(stock) and is_candidate(stock)):
                continue
            # Stocks flagged 1 can only fill the margin pool; others fill the normal pool.
            if is_margin.loc[stock] == 1:
                if len(margin_list) < margin_needed:
                    margin_list.append(stock)
            elif len(normal_list) < normal_cap:
                normal_list.append(stock)
        target_set = set(normal_list) | set(margin_list)
        normal_set = set(normal_list)
        # ----- sell: each pool separately -----
        buy_list = []
        sell_list = []
        if len(last_position) > 0:
            last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
            last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
        else:
            last_margin_list = []
            last_normal_list = []
        # Margin pool: a holding still targeted in either pool is kept.
        margin_sell_cap = int(max_trade_num * margin_ratio) + 1
        for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= margin_sell_cap:
                break
            if stock in target_set or stock_status.loc[stock] in [0, 2, 5, 7]:
                continue
            sell_list.append(stock)
        next_margin_list = list(set(last_margin_list) - set(sell_list))
        # Normal pool: total sells capped at max_trade_num. A holding now targeted only
        # by the margin pool is sold here and may be re-bought into that pool below.
        for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]:
            if len(sell_list) >= max_trade_num:
                break
            if stock in normal_set or stock_status.loc[stock] in [0, 2, 5, 7]:
                continue
            sell_list.append(stock)
        next_normal_list = list(set(last_normal_list) - set(sell_list))
        # ----- buy: fill each pool up to its capacity -----
        after_sell = held - set(sell_list)
        for stock in margin_list:
            if len(next_margin_list) >= margin_needed:
                break
            if stock not in after_sell:
                next_margin_list.append(stock)
                buy_list.append(stock)
        for stock in normal_list:
            if len(next_normal_list) >= normal_cap:
                break
            if stock not in after_sell:
                next_normal_list.append(stock)
                buy_list.append(stock)
        next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
        next_position['date'] = date
        margin_num = len(next_margin_list)
        # Every name gets (1 + margin_num / num) / num.
        next_position['weight'] = (1 + (margin_num / self.num)) / self.num
        next_position['margin_trade'] = 0
        next_position.loc[next_position['stock_code'].isin(next_margin_list), 'margin_trade'] = 1
    # ----- next-period names that cannot trade today (status 0/2) -----
    frozen_list = []
    if len(self.position) > 0:
        frozen_list = [stock for stock in next_position['stock_code'] if stock_status.loc[stock] in [0, 2]]
    return sell_list, buy_list, frozen_list, next_position
def get_price(self,
              trade_time: str='open',
              target: Union[Iterable[str], None]=None,
              buy_list: Union[Iterable[str], None]=None,
              sell_list: Union[Iterable[str], None]=None):
    """
    Return trade prices for ``target`` at ``trade_time``, fee-adjusted for buys and sells.

    Args:
        trade_time (str): key into ``self.today_data`` ('close' or a signal name).
        target (Iterable[str]): stock codes to price.
        buy_list (Iterable[str]): codes bought now; price * (1 + buy fee).
        sell_list (Iterable[str]): codes sold now; price * (1 - sell fee).

    Returns:
        pd.Series: float prices indexed by ``target``. Missing base prices become 0.
        Codes present in both lists end up with the buy adjustment (applied last).
    """
    # None defaults instead of shared mutable [] defaults.
    if target is None:
        target = []
    if buy_list is None:
        buy_list = []
    if sell_list is None:
        sell_list = []
    stock_price = self.today_data[trade_time]
    # Explicit float dtype: an index-only Series would otherwise be object dtype on pandas 2.
    target_price = pd.Series(index=target, dtype=float)
    sell_list = list(set(target) & set(sell_list))
    buy_list = list(set(target) & set(buy_list))
    target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
    # NOTE(review): unlike the base price, fee-adjusted prices are not NaN-filled,
    # so a missing price for a traded code stays NaN - confirm intended.
    target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.fee[1])
    target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.fee[0])
    return target_price
def check_update_status(self,
                        date: str,
                        trade_time: str):
    """
    Tell whether the (date, trade_time) slot has already been processed.

    A date earlier than the latest recorded date always counts as processed.

    Args:
        date (str): trading date.
        trade_time (str): trade time / signal name (e.g. 'open', 'close').

    Returns:
        bool: True if already processed; False otherwise, including for an empty history.
    """
    history = self.account_history
    if len(history) == 0:
        return False
    if date < history['date'].max():
        return True
    recorded = history['date'].str.cat(history['trade_time'], sep='-').values
    return f'{date}-{trade_time}' in recorded
def reblance_weight(self,
                    trade_time: str,
                    cur_pos: pd.DataFrame,
                    next_position: pd.DataFrame):
    """
    Dynamically rebalance weights towards the target, respecting untradable names.

    Held names whose ``opening_info`` is 0/2/4/6 cannot increase weight, and
    names whose status is 0/2/5/7 cannot decrease it. Such names keep their
    current weight. The remaining weights are then rescaled so that the total
    equals the target total. The new ``open`` (cost basis) is
    ``(current_weight * open + weight_chg * trade_price) / final_weight``.
    In this formula, ``open`` is the holding's latest mark (``cur_pos['close']``)
    and ``trade_price`` comes from ``get_price`` at ``trade_time``, including fees.

    Args:
        trade_time (str): trade time used for pricing.
        cur_pos (pd.DataFrame): current holdings with ``stock_code``, ``weight``, ``close``.
        next_position (pd.DataFrame): target holdings with ``stock_code``, ``weight``,
            ``date``, ``margin_trade``; the caller's frame is not modified.

    Returns:
        pd.DataFrame: columns ``stock_code``, ``date``, ``open``, ``weight``, ``margin_trade``.
    """
    # Frozen lists among the current holdings
    stock_status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
    buy_frozen_list = [stock for stock in cur_pos['stock_code'] if stock_status.loc[stock] in [0, 2, 4, 6]]
    sell_frozen_list = [stock for stock in cur_pos['stock_code'] if stock_status.loc[stock] in [0, 2, 5, 7]]
    # Work on a copy so the caller's frame is not modified.
    next_position = next_position.copy()
    # Float zeros avoid incompatible-dtype assignments below.
    next_position['current_weight'] = 0.0
    cur_pos = cur_pos.set_index(['stock_code'])
    next_position = next_position.set_index(['stock_code'])
    current_list = list(set(cur_pos.index) & set(next_position.index))
    next_position.loc[current_list, 'current_weight'] = cur_pos.loc[current_list, 'weight']
    # Held names start from their latest mark, new names from 0.
    next_position['open'] = 0.0
    next_position.loc[current_list, 'open'] = cur_pos.loc[current_list, 'close']
    # Ideal weight change
    next_position['weight_chg'] = next_position['weight'] - next_position['current_weight']
    # A freeze only matters when the name would have to trade in the blocked direction.
    # pandas >= 2.0 rejects sets as .loc indexers, hence the list() conversions.
    buy_frozen_list = list(set(buy_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] > 0].index))
    sell_frozen_list = list(set(sell_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] < 0].index))
    frozen_list = list(set(buy_frozen_list) | set(sell_frozen_list))
    next_position['final_weight'] = next_position['weight']
    next_position.loc[buy_frozen_list, 'final_weight'] = next_position.loc[buy_frozen_list, 'current_weight']
    next_position.loc[sell_frozen_list, 'final_weight'] = next_position.loc[sell_frozen_list, 'current_weight']
    # Rescale so the final weights sum to the target total.
    # NOTE(review): this also rescales frozen names, whose weight_chg is forced to 0 below - confirm intended.
    next_position['final_weight'] /= next_position['final_weight'].sum()
    next_position['final_weight'] *= next_position['weight'].sum()
    # Actual weight change (none for frozen names)
    next_position['weight_chg'] = next_position['final_weight'] - next_position['current_weight']
    next_position.loc[frozen_list, 'weight_chg'] = 0
    # Trade prices for the changes (buy side incl. buy fee, sell side net of sell fee)
    next_position['adjust_price'] = 0.0
    buy_adjust_list = next_position[next_position['weight_chg'] > 0].index.values
    sell_adjust_list = next_position[next_position['weight_chg'] < 0].index.values
    next_position.loc[buy_adjust_list, 'adjust_price'] = self.get_price(trade_time, buy_adjust_list, buy_adjust_list, []).values
    next_position.loc[sell_adjust_list, 'adjust_price'] = self.get_price(trade_time, sell_adjust_list, [], sell_adjust_list).values
    # New cost basis
    next_position['adjust_open'] = (next_position['current_weight'] * next_position['open'] + next_position['weight_chg'] * next_position['adjust_price'])
    next_position['adjust_open'] = next_position['adjust_open'] / next_position['final_weight']
    # Frozen names keep their previous cost basis.
    next_position.loc[frozen_list, 'adjust_open'] = next_position.loc[frozen_list, 'open']
    next_position['open'] = next_position['adjust_open']
    next_position['weight'] = next_position['final_weight']
    next_position = next_position.reset_index()
    return next_position[['stock_code', 'date', 'open', 'weight', 'margin_trade']]
def update_account(self,
                   date: str,
                   trade_time: str,
                   cur_pos: pd.DataFrame,
                   next_position: pd.DataFrame):
    """
    Record turnover, leverage and PnL for one trade time and compound the account.

    Args:
        date (str): trading date.
        trade_time (str): trade time.
        cur_pos (DataFrame): holdings before the trade; needs ``stock_code`` and
            ``weight``, plus ``end_weight`` (weight after the price move) whenever
            the weights are non-zero.
        next_position (DataFrame): holdings after the trade; needs
            ``stock_code`` and ``weight``.

    Returns:
        bool: always True.
    """
    weights = pd.concat([
        cur_pos.set_index(['stock_code'])['weight'].rename('cur'),
        next_position.set_index(['stock_code'])['weight'].rename('next'),
    ], axis=1).fillna(0)
    # Names missing on one side (new buys / full exits) count as weight 0,
    # so full exits are included in the turnover.
    turnover = (weights['next'] - weights['cur']).abs().sum()
    leverage = next_position['weight'].sum()
    if cur_pos['weight'].sum() == 0:
        pnl = 0
    else:
        pnl = cur_pos['end_weight'].sum() - cur_pos['weight'].sum()
    self.account *= 1 + pnl
    record = pd.DataFrame([{
        'date': date,
        'trade_time': trade_time,
        'turnover': turnover,
        'leverage': leverage,
        'pnl': pnl
    }])
    # DataFrame.append was removed in pandas 2.0; concat works on 1.x and 2.x.
    self.account_history = pd.concat([self.account_history, record], ignore_index=True)
    return True
def update_signal(self,
                  date: str,
                  update_type='rtn'):
    """
    Process one trading day: trade at each signal time, then mark to the close.

    For every signal (in ``self.signal`` order) that has a row for ``date``, the
    target holdings are computed. In 'rtn' mode, the current holdings are then
    marked at the signal's trade price, the account is updated and weights are
    rebalanced. At the close, the holdings are marked again, the PnL is booked
    and the position snapshot is stored in ``position_history``.

    Args:
        date (str): trading date.
        update_type (str): update type
            - position: only compute the first available signal's positions
              (stored in position_history), no returns
            - rtn: update returns and positions

    Returns:
        bool: always True.
    """
    # If this date's close is already recorded, skip; otherwise drop any partial
    # records for the date and recompute it.
    recorded = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
    if f'{date}-close' in recorded:
        return True
    # Boolean mask instead of a string-built query expression.
    self.account_history = self.account_history[self.account_history['date'] != date]
    if date in self.position_history:
        self.position_history.pop(date)
    self.load_data(date, update_type)
    # With existing holdings, the overnight return is booked at the first trade; otherwise buy directly.
    if len(self.position) == 0:
        cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
    else:
        cur_pos = self.position.copy()
    # Names that cannot trade (from the latest get_next_position call)
    frozen_list = []
    for trade_time, signal_df in self.signal.items():
        if self.check_update_status(date, trade_time):
            continue
        if date not in signal_df.index:
            continue
        factor = signal_df.loc[date]
        sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
        if update_type == 'position':
            self.position_history[date] = next_position.copy()
            return True
        if len(cur_pos) > 0:
            # Mark current holdings at this trade time (sold names net of the sell fee).
            cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
            # Suspended names keep their price.
            frozen = cur_pos['stock_code'].isin(frozen_list)
            cur_pos.loc[frozen, 'close'] = cur_pos.loc[frozen, 'open']
            # Return since the last mark
            cur_pos['rtn'] = cur_pos['close'] / cur_pos['open']
            cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
            self.update_account(date, trade_time, cur_pos, next_position)
            # Drifted weights, re-normalised to the same total
            cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
            # Adjust weights: buys, sells and rebalancing
            next_position = self.reblance_weight(trade_time, cur_pos, next_position)
        else:
            next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values, buy_list, []).values
        self.position = next_position.copy()
        # A later signal must start from the book just traded into, not the pre-trade one.
        cur_pos = self.position.copy()
    # ----- mark to the close and book the day's return -----
    trade_time = 'close'
    if self.check_update_status(date, trade_time):
        return True
    cur_pos = self.position.copy()
    cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
    # Suspended names keep their price.
    frozen = cur_pos['stock_code'].isin(frozen_list)
    cur_pos.loc[frozen, 'close'] = cur_pos.loc[frozen, 'open']
    # NOTE(review): names with open == 0 get close = 0, so their rtn is 0/0 = NaN - confirm intended.
    unpriced = cur_pos['open'] == 0
    cur_pos.loc[unpriced, 'close'] = cur_pos.loc[unpriced, 'open']
    cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
    cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
    position_record = cur_pos.copy()
    position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
    cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
    # Tomorrow starts from today's close as the cost basis.
    next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
    next_position['open'] = cur_pos['close']
    self.update_account(date, trade_time, cur_pos, cur_pos)
    self.position = next_position.copy()
    self.position_history[date] = position_record.copy()
    return True