2024-05-22 23:33:19 +08:00
|
|
|
# -*- coding: UTF-8 -*-
|
2024-06-04 22:29:56 +08:00
|
|
|
import sys
|
|
|
|
import os
|
2024-05-22 23:33:19 +08:00
|
|
|
import pandas as pd
|
|
|
|
import numpy as np
|
2024-05-28 20:03:20 +08:00
|
|
|
import time
|
2024-05-22 23:33:19 +08:00
|
|
|
from trader import Trader
|
2024-06-24 21:59:34 +08:00
|
|
|
from typing import Union, Dict
|
2024-05-22 23:33:19 +08:00
|
|
|
from rich import print as rprint
|
2024-05-28 20:03:20 +08:00
|
|
|
from rich.table import Table
|
2024-05-22 23:33:19 +08:00
|
|
|
|
|
|
|
|
2024-06-01 16:32:51 +08:00
|
|
|
class SpreadBacktest():
    """Run a backtest through a ``Trader`` object and summarise the results.

    The class itself holds no market data: every piece of state (signals,
    available dates, account and position history, data paths) is read from
    or written to ``self.trader``.
    """

    # Number of backtest dates treated as one year when annualising the
    # cumulative return in ``analyze``.
    PERIODS_PER_YEAR = 244

    def __init__(
        self,
        trader: "Trader"
    ):
        """Keep a reference to the trader that executes the backtest.

        Args:
            trader: Object that stores signals/history and performs the
                per-date updates.
        """
        self.trader = trader

    def run(
        self,
        start: str,
        end: str
    ):
        """Walk through the signal dates between ``start`` and ``end``.

        The requested range is clipped to the dates the trader reports as
        available and, when account history already exists, to the latest
        recorded date, so a rerun continues from where it stopped.

        Args:
            start: Earliest date requested; compared directly with the index
                labels of ``trader.avaliable_date``.
            end: Latest date requested; compared the same way.

        Raises:
            IndexError: If ``trader.signal`` holds no signal at all.
        """
        # Refresh the range of dates that can be backtested.
        self.trader.update_avaliable_date()

        # Determine where the backtest starts and where recording ends.
        start_candidates = [start, self.trader.avaliable_date.index.min()]
        if len(self.trader.account_history) > 0:
            start_candidates.append(self.trader.account_history['date'].max())
        bkt_start = max(start_candidates)
        rec_end = min(end, self.trader.avaliable_date.index.max())

        # If the first signal extends beyond the latest available data, step
        # one signal date further: that extra date only updates the next
        # position and produces no return.
        first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
        if rec_end < first_signal.index.max():
            following = first_signal.loc[rec_end:].index.to_list()
            # When rec_end itself is a signal date it is following[0], so the
            # next date is following[1]; otherwise following[0] is already
            # the first date after rec_end.
            bkt_end = following[1] if rec_end in first_signal.index else following[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}')

        start_time = time.time()
        target_list = first_signal.loc[bkt_start:bkt_end].index
        total = len(target_list)
        # NOTE(review): the nesting below was reconstructed from a listing
        # with its indentation stripped -- confirm that the signal update is
        # meant to run for every date, including the first one.
        for idx, date in enumerate(sorted(target_list), start=1):
            sys.stdout.flush()
            if idx > 1:
                elapsed = time.time() - start_time
                # Remaining time estimated from the average time per date.
                need_time = int((elapsed / idx) * (total - idx))
                used_time = int(elapsed)
                # Carriage return so the progress line overwrites itself.
                print("\r", end="")
                print("回测进度: {:>3d}% [{} -> {}] 用时: {:0>2d}: {:0>2d} / 需要: {:0>2d}: {:0>2d}".format(
                    idx * 100 // total, date, bkt_end,
                    used_time // 60, used_time % 60,
                    need_time // 60, need_time % 60
                ), end="")

            # The last day in avaliable_date can only be used to record
            # positions, so it (and anything after rec_end) takes the
            # position-only update.
            if (date <= rec_end) and (date < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(date)
            else:
                self.trader.update_signal(date, update_type='position')

    # Data updates
    def update_signal(self,
                      trade_time: str,
                      new_signal: pd.DataFrame):
        """Store a signal on the trader.

        Args:
            trade_time (str): Signal time; used as the key under which the
                signal is stored in ``trader.signal``.
            new_signal (pd.DataFrame): The new signal to store.
        """
        self.trader.signal[trade_time] = new_signal

    def update_interval(self,
                        interval: Union[Dict[str, Union[int, tuple, pd.Series]], None] = None,
                        ):
        """Forward new interval settings to ``trader.init_interval``.

        Args:
            interval: Mapping handed to the trader. When omitted, a fresh
                empty dict is passed on every call, so the trader never
                receives a dict shared with an earlier call.
        """
        # Original author's note: updates interval and weight; applies when
        # the interval is a fixed ratio.
        if interval is None:
            interval = {}
        self.trader.init_interval(interval)

    @property
    def account_history(self):
        """The trader's ``account_history``, returned as-is."""
        return self.trader.account_history

    @property
    def position_history(self):
        """The trader's ``position_history``, returned as-is."""
        return self.trader.position_history

    def analyze(self):
        """Print per-year, annualised and turnover statistics.

        Reads ``trader.account_history`` (columns ``date``, ``pnl`` and
        ``turnover`` are used) and the file names under
        ``trader.data_root['basic']``, then prints a table; nothing is
        returned.
        """
        # Cumulative net value per date: compound (1 + pnl) within each
        # date, then across dates.
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()

        # Derive the trading days from the file names in the basic-data
        # directory (the part of each name before the first dot).
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()

        # Per-year statistics.
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        year_rtn = pd.DataFrame(columns=['收益'])
        for year in sorted(years.unique()):
            end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
            if year == first_year:
                # The first year is measured from the initial net value of 1.
                year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
            else:
                # Only later years need a base date. Looking it up for the
                # first year as well would fail whenever no trading-day file
                # precedes 1 January of that year.
                start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
                year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1

        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / self.PERIODS_PER_YEAR)) - 1
        year_rtn.loc['Avg.Turnover'] = self.trader.account_history.groupby('date')['turnover'].sum().mean()

        # DataFrame.applymap is deprecated from pandas 2.1 in favour of
        # DataFrame.map; use whichever this pandas version provides.
        format_cells = getattr(year_rtn, 'map', None) or year_rtn.applymap
        year_rtn = format_cells(lambda x: '{:.2%}'.format(x))
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']
        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)
|
|
|
|
|
|
|
|
|
|
|
|
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Print a DataFrame as a centred rich table, one table row per DataFrame row.

    Cells in a column named ``'收益'`` are coloured: green when the text
    starts with ``-``, red otherwise. All other columns are printed as plain
    text.

    Args:
        year_rtn: Table to print; each cell is converted with ``str``.
    """
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn:
        table.add_column(col, justify='center', width=10, no_wrap=True)

    for _, row in year_rtn.iterrows():
        cells = []
        for col in year_rtn.columns:
            text = str(row[col])
            # NOTE(review): the caller visible in this file renames the
            # column to 'Rtn' before calling, so this branch does not fire
            # for it -- confirm whether 'Rtn' should be coloured as well.
            if col == '收益':
                # startswith also copes with empty cells; "red" replaces the
                # former "sred", which is not a colour name rich recognises.
                color = "green" if text.startswith('-') else "red"
                text = f"[{color}]{text}[/{color}]"
            cells.append(text)
        table.add_row(*cells)

    rprint(table)
|
|
|
|
|