This commit is contained in:
binz 2024-05-22 23:33:19 +08:00
commit e0d025f71c
17 changed files with 790 additions and 0 deletions

0
.gitignore vendored Normal file
View File

0
README.md Normal file
View File

5
__init__.py Normal file
View File

@ -0,0 +1,5 @@
from account import Account
from trader import Trader
from spread_backtest import Spread_Backtest
__all__ = ['Account', 'Trader', 'Spread_Backtest']

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

63
account.py Normal file
View File

@ -0,0 +1,63 @@
import warnings
import pandas as pd
from rich import print as rprint
from typing import Union, Iterable, Optional, Dict
# 警告格式
def custom_warnings(message, category, filename, lineno, file=None, line=None):
rprint("[bold red]Warning:[/bold red]", message)
warnings.showwarning = custom_warnings
class CustomWarning(UserWarning):
pass
warnings.simplefilter('always', category=CustomWarning)
class Account():
"""
账户类用于控制账户类型
Arguments:
- is_real(bool): 是否为真实交易
- init_account(Union[int,float]): 初始账户金额
- trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效)
- leverage(float): 杠杆比例
"""
def __init__(self,
is_real: bool = False,
init_account: Union[int,float] = 1e6,
trade_limit: float = 0.01,
leverage: Union[int,float] = 1.) -> None:
if is_real:
self.a_type = 'amount'
self.account = init_account
else:
self.a_type = 'percent'
self.account = 1.0
self.trade_limit = trade_limit
if isinstance(leverage, (int,float)):
if leverage > 1.5:
raise ValueError('leverage should less than 1.5')
if leverage > 1:
warnings.warn('leverage(杠杆率)大于1时优先满足杠杆会出现高于指定换手率的情况', category=CustomWarning)
else:
raise ValueError('leverage should be int or float')
self.leverage = leverage
# 当前持仓
self.position = pd.DataFrame()
# 账户收益更新
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
# 历史持仓
self.position_history = dict()
# 当前日期
self.date = None
# 获取收盘时持仓和仓位
@property
def close_position(self):
return self.position[self.position['trade_time'] == 'close']['stock_code'].values
@property
def close_weight(self):
return self.position[self.position['trade_time'] == 'close'].weight.values

44
data_handler.py Normal file
View File

@ -0,0 +1,44 @@
import pandas as pd
import os, sys
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list']):
if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else:
tmp = pd.read_csv(f'{data_dir}/{f}.csv', index_col=0)
tmp = tmp.unstack().reset_index()
tmp.columns = ['stock_code', 'date', f]
if i == 0:
df = tmp
else:
df = df.merge(tmp, on=['stock_code', 'date'], how="left")
df = df.set_index(['date']).sort_index()
existed = os.listdir(save_dir)
for d in sorted(df.index.unique()):
if (d+'.csv' in existed) and (d+'.csv' != max(existed)):
continue
else:
df.loc[d].sort_values(by=['stock_code']).to_csv(f'{save_dir}/{d}.csv', index=False)
# 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1)
next_list = []
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list']):
if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else:
next_list.append(pd.Series(pd.read_csv(f'{data_dir}/{f}.csv', index_col=0).iloc[-1], name=f))
df = pd.concat(next_list, axis=1)
df.index.name = 'stock_code'
df = df.reset_index()
df.sort_values(by=['stock_code']).to_csv(f'{save_dir}/{next_date}.csv', index=False)

32
dataloader.py Normal file
View File

@ -0,0 +1,32 @@
import numpy as np
import pandas as pd
from typing import Union, Iterable, Optional, Dict
class DataLoader():
"""
数据类: 数据加载模块
"""
def __init__(self, path: Union[str, pd.DataFrame]=None):
if type(path) == str:
self.data = pd.read_csv(path).set_index(['stock_code']).sort_index()
if type(path) == pd.DataFrame:
self.data = path.copy()
def get(self,
target: Iterable[str]=[],
column: str=''):
"""
- target: 查询目标代码
- column(str): 查询列
"""
res = pd.Series(index=target)
column_map = dict(zip(self.data.index.values, self.data[column].values))
column_type = self.data[column].dtype
stock_list = list(set(self.data.index.values) & set(target))
res.loc[stock_list] = self.data.loc[stock_list, column].values
stock_list = list(set(target) - set(stock_list))
if column_type == str:
res.loc[stock_list] = None
else:
res.loc[stock_list] = np.nan
return res.sort_index()

96
spread_backtest.py Normal file
View File

@ -0,0 +1,96 @@
# -*- coding: UTF-8 -*-
import os
import pandas as pd
import numpy as np
from trader import Trader
from rich.progress import track
from rich import print as rprint
from rich import pretty, text
from rich.table import Column, Table
from rich.style import Style
class Spread_Backtest():
def __init__(self,
trader: Trader):
self.trader = trader
def run(self,
start: str,
end: str
):
# 确定回测的开始和结束时间
if len(self.trader.account_history) > 0:
bkt_start = max(start, self.trader.avaliable_date.index.min(), self.trader.account_history['date'].max())
else:
bkt_start = max(start, self.trader.avaliable_date.index.min())
rec_end = min(end, self.trader.avaliable_date.index.max())
# 如果传入的第一个因子时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益
first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
if rec_end < first_signal.index.max():
bkt_end = first_signal.loc[rec_end:].index.to_list()[1]
else:
bkt_end = rec_end
print(f'回测区间: {bkt_start} - {bkt_end}')
for d in track(self.trader.signal[list(self.trader.signal.keys())[0]].loc[bkt_start:bkt_end].index,
description='Backtesting...',
update_period=0.5):
# avaliable_date最后一天的数据只能用于记录持仓
if (d <= rec_end) and (d < self.trader.avaliable_date.index.max()):
self.trader.update_signal(d)
else:
self.trader.update_signal(d, update_type='position')
@property
def account_history(self):
return self.trader.account_history
@property
def position_history(self):
return self.trader.position_history
def analyze(self):
"""
分析统计
"""
rtn_stat = self.trader.account_history.copy()
rtn_stat['pnl'] += 1
rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()
# 根据basic data路径确定交易日
trading_day = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])]).sort_index()
# 按年统计
year_rtn = pd.DataFrame(columns=['收益'])
for year in sorted(pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year.unique()):
start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
if year == pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year.min():
year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
else:
year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1
year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / 244)) - 1
year_rtn = year_rtn.applymap(lambda x: '{:.2%}'.format(x))
year_rtn = year_rtn.reset_index()
year_rtn.columns = ['Year', 'Rtn']
rprint("[bold black]1 收益统计[/bold black]")
print_year_rtn(year_rtn)
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
table = Table(show_header=True, header_style='bold')
for col in year_rtn:
table.add_column(col, justify='center', width=10, no_wrap=True)
for _,row in year_rtn.iterrows():
new_row = []
for col in year_rtn.columns:
if col == '收益':
color = "green" if row['收益'][0] == '-' else "sred"
value = f"[{color}]{row[col]}[/{color}]"
new_row.append(str(value))
else:
new_row.append(str(row[col]))
table.add_row(*new_row)
rprint(table)

550
trader.py Normal file
View File

@ -0,0 +1,550 @@
import pandas as pd
import numpy as np
import sys, os
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
from typing import Union, Iterable, Dict
from account import Account
from dataloader import DataLoader
class Trader(Account):
"""
交易类: 用于控制每日交易情况
Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (int, tuple, pd.Series): 交易间隔
num (int): 持仓数量
ascending (bool): 因子方向
fee (tuple): 买入成本和卖出成本
with_st (bool): 是否包含st
tick (bool): 是否开始tick模拟模式(开发中)
weight (str): 权重分配
- avg: 平均分配每天早盘重新分配日中交易不重新分配
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间
"""
def __init__(self,
signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
num: int=100,
ascending: bool=False,
fee :tuple=(0.001,0.002),
with_st: bool=False,
data_root:dict={},
tick: bool=False,
weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
**kwargs) -> None:
# 初始化账户
super().__init__(**kwargs.get('account', {}))
if isinstance(signal, dict):
self.signal = signal
if 'close' in signal:
raise ValueError('signal key cannot be close')
for s in self.signal:
self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(), self.signal[s].index.max(), return_type='origin')
else:
raise ValueError('type of signal is invalid')
# interval
self.interval = []
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num
if isinstance(num, int):
self.num = int(num)
else:
raise ValueError('num should be int')
# ascending
if isinstance(ascending, bool):
self.ascending = ascending
else:
raise ValueError('invalid type for ascending')
# fee
if isinstance(fee, tuple) and len(fee) == 2:
self.fee = fee
else:
raise ValueError('invalid input for fee')
# with_st
if isinstance(with_st, bool):
self.with_st = with_st
else:
raise ValueError('invalid type for with_st')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# data_root
# 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0:
raise ValueError('num of data_root should be equal or greater than 1')
if 'basic' in data_root:
# 可执行日期
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(data_root['basic'])]).sort_index()
else:
raise ValueError('data_root should contain basic data root')
for s in self.signal:
if s == 'open':
continue
else:
if s not in data_root:
raise ValueError(f'data for signal {s} is not provided')
self.data_root = data_root
def load_data(self,
date: str,
update_type: str='rtn'):
"""
加载每日基础数据
Args:
update_type (str): 更新模式
- rtn: 更新所有信号数据
- position: 只更新basic数据用于持仓判断
"""
self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
if update_type == 'position':
return True
for s in self.signal:
if s == 'open':
if s in self.data_root:
continue
else:
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal:
pass
else:
self.today_data['close'] = DataLoader(self.today_data['basic'].data[['close_post']].rename(columns={'close_post':'price'}))
def get_next_position(self, date, factor):
"""
计算下一时刻持仓
"""
# 计算持仓和最大可交易数量
if len(self.position) > 0:
last_position = self.position['stock_code'].values
last_position = factor.loc[last_position].sort_values(ascending=self.ascending)
max_trade_num = max(int(self.interval.loc[date]*self.num), self.num-len(self.position))
else:
last_position = pd.Series()
max_trade_num = self.num
target_list = []
# 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_filter = stock_amt20.copy()
stock_amt_filter.loc[:] = 0
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
stock_amt_filter = stock_amt_filter.sort_index()
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_filter = stock_ipo_days.copy()
stock_ipo_filter.loc[:] = 0
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
# 交易列表
if self.leverage <= 1.0:
# 获取当前时间目标列表和冻结(无法交易)列表
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
if (stock_amt_filter.loc[stock] != 1) or (stock_ipo_filter.loc[stock] != 1):
continue
if self.with_st:
if stock_status.loc[stock] in [0,2]:
if stock in last_position.index:
target_list.append(stock)
else:
target_list.append(stock)
else:
# 非ST
if stock_status.loc[stock] in [3,7]:
target_list.append(stock)
else:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2,6]:
if stock in last_position.index:
target_list.append(stock)
if len(target_list) == self.num:
break
# 如果没有杠杆
buy_list = []
sell_list = []
# ----- 卖出 -----
# 按照反向排名逐个卖出
if self.ascending:
factor = factor.fillna(factor.max()+1)
else:
factor = factor.fillna(factor.min()-1)
for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
if stock in target_list:
continue
else:
if stock_status.loc[stock] in [0,2,5,7]:
continue
else:
sell_list.append(stock)
if len(sell_list) == max_trade_num:
break
# ----- 买入 -----
# 卖出后持仓列表
after_sell_list = set(last_position.index) - set(sell_list)
max_trade_num = min(max_trade_num, self.num-len(last_position)+len(sell_list))
for stock in target_list:
if stock in after_sell_list:
continue
else:
buy_list.append(stock)
if len(buy_list) == max_trade_num:
break
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
next_position['date'] = date
next_position['weight'] = self.leverage / len(next_position)
next_position['margin_trade'] = 0
else:
# 如果有杠杆
def assign_stock(normal_list, margin_list, margin_needed, stock, status):
if status == 1:
if len(margin_list) < margin_needed:
margin_list.append(stock)
else:
if len(normal_list) < self.num - margin_needed:
normal_list.append(stock)
return normal_list, margin_list
# 计算需要融资融券标的数量
margin_ratio = max(self.leverage-1, 0)
margin_needed = round(self.num * margin_ratio)
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
# 获取当前时间目标列表和冻结(无法交易)列表
normal_list = []
margin_list = []
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
if stock_amt_filter.loc[stock] != 1:
continue
if self.with_st:
if stock_status.loc[stock] in [0,2]:
if stock in last_position.index:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
# 非ST
if stock_status.loc[stock] in [3,7]:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
else:
# 如果停牌或者跌停继续持有
if stock_status.loc[stock] in [0,2,6]:
if stock in last_position.index:
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
if len(normal_list + margin_list) == self.num:
break
target_list = normal_list + margin_list
# ----- 卖出 -----
buy_list = []
sell_list = []
# 融资融券池的和非融资融券池的分开更新
# 更新融资融券池
if len(last_position) > 0:
last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
else:
last_margin_list = []
for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]:
if stock in normal_list:
continue
else:
if stock_status.loc[stock] in [0,2,5,7]:
continue
else:
sell_list.append(stock)
if len(sell_list) >= int(max_trade_num * margin_ratio) + 1:
break
next_margin_list = list(set(last_margin_list) - set(sell_list))
# 更新非融资融券池
if len(last_position) > 0:
last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
else:
last_normal_list = []
for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]:
if stock in normal_list:
continue
else:
if stock_status.loc[stock] in [0,2,5,7]:
continue
else:
sell_list.append(stock)
if len(sell_list) >= max_trade_num:
break
next_normal_list = list(set(last_normal_list) - set(sell_list))
# ----- 买入 -----
# 卖出后持仓列表
after_sell_list = set(last_position.index) - set(sell_list)
max_trade_num = min(max_trade_num, self.num-len(last_position)+len(sell_list))
# 融资融券池的和非融资融券池的分开更新
# 更新融资融券池
for stock in margin_list:
if stock in after_sell_list:
continue
else:
next_margin_list.append(stock)
if len(next_margin_list) == margin_needed:
break
# 更新非融资融券池
for stock in normal_list:
if stock in after_sell_list:
continue
else:
next_normal_list.append(stock)
if len(next_normal_list) >= self.num - margin_needed:
break
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
next_position['date'] = date
# 融资融券数量
margin_num = len(next_margin_list)
# next_position['weight'] = self.leverage*((margin_needed-margin_num)/margin_needed) / len(next_position)
next_position['weight'] = (1 + (margin_num / self.num)) / self.num
next_position['margin_trade'] = 0
next_position = next_position.set_index(['stock_code'])
next_position.loc[next_margin_list, 'margin_trade'] = 1
next_position = next_position.reset_index()
# 检测当前持仓是否可以交易
frozen_list = []
if len(self.position) > 0:
for stock in next_position['stock_code']:
if stock_status.loc[stock] in [0,2]:
frozen_list.append(stock)
return sell_list, buy_list, frozen_list, next_position
def get_price(self,
trade_time: str='open',
target: Iterable[str]=[],
buy_list: Iterable[str] = [],
sell_list: Iterable[str] = []):
"""
获取价格
Args:
trade_time (str): 交易时间
target (Iterable[float]): 目标
buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标
"""
stock_price = self.today_data[trade_time]
target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list))
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.fee[1])
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.fee[0])
return target_price
def check_update_status(self,
date: str,
trade_time: str):
# 判断当前更新状态
# 如果日期和交易时间已经存在则返回True
if len(self.account_history) == 0:
return False
elif date < self.account_history['date'].max():
return True
else:
exist_list = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
if f'{date}-{trade_time}' in exist_list:
return True
else:
return False
def reblance_weight(self,
trade_time: str,
cur_pos: pd.DataFrame,
next_position: pd.DataFrame):
"""
动态平衡权重
"""
# 判断冻结列表
stock_status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
buy_frozen_list = []
for stock in cur_pos['stock_code']:
if stock_status.loc[stock] in [0,2,4,6]:
buy_frozen_list.append(stock)
sell_frozen_list = []
for stock in cur_pos['stock_code']:
if stock_status.loc[stock] in [0,2,5,7]:
sell_frozen_list.append(stock)
# 设定目标仓位
next_position['target_weight'] = next_position['weight']
next_position['current_weight'] = 0
cur_pos = cur_pos.set_index(['stock_code'])
next_position = next_position.set_index(['stock_code'])
current_list = list(set(cur_pos.index) & set(next_position.index))
next_position.loc[current_list, 'current_weight'] = cur_pos.loc[current_list, 'weight']
next_position['open'] = 0
next_position.loc[current_list, 'open'] = cur_pos.loc[current_list, 'close']
# 计算理想仓位变动
next_position['weight_chg'] = next_position['weight'] - next_position['current_weight']
# 根据冻结判断是否能够变动
next_position['final_weight'] = 0
buy_frozen_list = set(buy_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] > 0].index)
sell_frozen_list = set(sell_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] < 0].index)
next_position.loc[next_position.index, 'final_weight'] = next_position['weight']
next_position.loc[buy_frozen_list, 'final_weight'] = next_position.loc[buy_frozen_list, 'current_weight']
next_position.loc[sell_frozen_list, 'final_weight'] = next_position.loc[sell_frozen_list, 'current_weight']
# 动态平衡仓位
next_position['final_weight'] /= next_position['final_weight'].sum()
next_position['final_weight'] *= next_position['weight'].sum()
# 计算理想仓位变动
next_position['weight_chg'] = next_position['final_weight'] - next_position['current_weight']
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'weight_chg'] = 0
# 动态平衡价格
next_position['adjust_price'] = 0
buy_adjust_list = next_position[next_position['weight_chg'] > 0].index.values
sell_adjust_list = next_position[next_position['weight_chg'] < 0].index.values
next_position.loc[buy_adjust_list, 'adjust_price'] = self.get_price(trade_time, buy_adjust_list, buy_adjust_list, []).values
next_position.loc[sell_adjust_list, 'adjust_price'] = self.get_price(trade_time, sell_adjust_list, [], sell_adjust_list).values
# 价格调整
next_position['adjust_open'] = (next_position['current_weight']*next_position['open'] + next_position['weight_chg']*next_position['adjust_price'])
next_position['adjust_open'] = next_position['adjust_open'] / next_position['final_weight']
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'adjust_open'] = next_position.loc[list(buy_frozen_list | sell_frozen_list), 'open']
# 当日买入不调整
next_position['open'] = next_position['adjust_open']
next_position['weight'] = next_position['final_weight']
next_position = next_position.reset_index()
next_position = next_position[['stock_code','date','open','weight','margin_trade']]
return next_position
def update_account(self,
date: str,
trade_time: str,
cur_pos: pd.DataFrame,
next_position: pd.DataFrame):
"""
更新账户
Args:
date (str): 日期
trade_time (str): 交易时间
cur_pos (DataFrame): 当前持仓
next_position (Iterable[str]): 下一刻持仓
"""
turnover = pd.concat([
cur_pos.set_index(['stock_code'])['weight'].fillna(0).rename('cur'),
next_position.set_index(['stock_code'])['weight'].rename('next'),
], axis=1)
turnover = (turnover['next'] - turnover['cur'].fillna(0)).abs().sum()
leverage = next_position['weight'].sum()
if cur_pos['weight'].sum() == 0:
pnl = 0
else:
pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum())
self.account *= 1+pnl
self.account_history = self.account_history.append({
'date': date,
'trade_time': trade_time,
'turnover': turnover,
'leverage': leverage,
'pnl': pnl
}, ignore_index=True)
return True
def update_signal(self,
date:str,
update_type='rtn'):
"""
更新信号收益
Args:
update_type (str): 更新类型
- position: 只更新持仓不更新收益
- rtn: 更新收益和持仓
"""
# 如果更新日期的close已经记录则跳过否则删除现有日期相关记录继续更新
if f'{date}-close' in self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values:
return True
else:
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
if date in self.position_history:
self.position_history.pop(date)
# 更新持仓信号
self.load_data(date, update_type)
# 如果当前持仓不空,添加隔夜收益,否则直接买入
if len(self.position) == 0:
cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
else:
cur_pos = self.position.copy()
# 冻结列表
frozen_list = []
# 遍历各个交易时间的信号
for idx,trade_time in enumerate(self.signal):
if self.check_update_status(date, trade_time):
continue
if date in self.signal[trade_time].index:
factor = self.signal[trade_time].loc[date]
else:
continue
factor = self.signal[trade_time].loc[date]
# 获取当前、持仓
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
if update_type == 'position':
self.position_history[date] = next_position.copy()
return True
if len(cur_pos) > 0:
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
# 停牌股价格不变
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
# 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
# 调整权重:买入、卖出、仓位再平衡
next_position = self.reblance_weight(trade_time, cur_pos, next_position)
else:
next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values, buy_list, []).values
self.position = next_position.copy()
# 收盘统计当日收益
trade_time = 'close'
if self.check_update_status(date, trade_time):
return True
cur_pos = self.position.copy()
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
# 停牌价格不变
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
position_record = cur_pos.copy()
position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
next_position['open'] = cur_pos['close']
self.update_account(date, trade_time, cur_pos, cur_pos)
self.position = next_position.copy()
self.position_history[date] = position_record.copy()
return True