spread_backtest/spread_backtest.py

115 lines
4.7 KiB
Python
Raw Normal View History

2024-05-22 23:33:19 +08:00
# -*- coding: UTF-8 -*-
2024-06-04 22:29:56 +08:00
import sys
import os
2024-05-22 23:33:19 +08:00
import pandas as pd
import numpy as np
2024-05-28 20:03:20 +08:00
import time
2024-05-22 23:33:19 +08:00
from trader import Trader
from rich import print as rprint
2024-05-28 20:03:20 +08:00
from rich.table import Table
2024-05-22 23:33:19 +08:00
class SpreadBacktest():
    """
    Drive a ``Trader`` through a date range and summarize the results.

    Signal handling, position and account bookkeeping are delegated to the
    wrapped trader; this class only chooses the backtest window, iterates over
    it and reports statistics.
    """

    def __init__(
        self,
        trader: "Trader"
    ):
        """
        Args:
            trader: the trader whose signals, data and history are backtested.
        """
        self.trader = trader

    def run(
        self,
        start: str,
        end: str
    ):
        """
        Step the trader through every signal date between ``start`` and ``end``.

        The window is clipped to the dates for which data is available. When
        account history already exists, it starts no earlier than the last
        recorded date. Dates past the usable data range are only used to
        record positions (``update_type='position'``).

        Args:
            start: requested first date, in the same label format as the
                trader's date indexes (presumably 'YYYY-MM-DD' strings).
            end: requested last date, same format as ``start``.
        """
        trader = self.trader
        # Refresh the range of dates that can be backtested
        trader.update_avaliable_date()

        # Determine backtest start and end
        if len(trader.account_history) > 0:
            bkt_start = max(start, trader.avaliable_date.index.min(), trader.account_history['date'].max())
        else:
            bkt_start = max(start, trader.avaliable_date.index.min())
        rec_end = min(end, trader.avaliable_date.index.max())

        # If the first factor has values beyond the latest available data date,
        # extend to the next signal date only to record the next position
        # (returns are not processed for it).
        first_signal = trader.signal[next(iter(trader.signal.keys()))]
        if rec_end < first_signal.index.max():
            later_dates = first_signal.loc[rec_end:].index.to_list()
            # rec_end itself is a signal date -> take the one after it
            bkt_end = later_dates[1] if rec_end in first_signal.index else later_dates[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}', flush=True)

        target_list = sorted(first_signal.loc[bkt_start:bkt_end].index)
        total = len(target_list)
        start_time = time.time()
        for idx, date in enumerate(target_list, start=1):
            if idx > 1:
                self._print_progress(idx, total, date, bkt_end, time.time() - start_time)
            # Data of the last avaliable_date day can only be used to record positions
            if (date <= rec_end) and (date < trader.avaliable_date.index.max()):
                trader.update_signal(date)
            else:
                trader.update_signal(date, update_type='position')
        if total > 1:
            # Terminate the in-place ("\r") progress line so later output starts cleanly
            print()

    @staticmethod
    def _print_progress(idx, total, date, bkt_end, elapsed):
        """Overwrite the console progress line before processing date number ``idx`` (1-based)."""
        # idx-1 dates have been fully processed when this is called
        done = idx - 1
        need_time = int(elapsed / done * (total - done))
        used_time = int(elapsed)
        message = "回测进度: {:>3d}% [{} -> {}] 用时: {:0>2d}: {:0>2d} / 需要: {:0>2d}: {:0>2d}".format(
            idx * 100 // total, date, bkt_end,
            used_time // 60, used_time % 60,
            need_time // 60, need_time % 60
        )
        print("\r" + message, end="", flush=True)

    @property
    def account_history(self):
        """The wrapped trader's account history."""
        return self.trader.account_history

    @property
    def position_history(self):
        """The wrapped trader's position history."""
        return self.trader.position_history

    def analyze(self, trading_days_per_year: int = 244):
        """
        Print return statistics of the backtest.

        Builds a cumulative net-value curve from ``account_history['pnl']``.
        Records sharing a date are compounded together, then compounded
        across dates. The method then prints the return of each calendar
        year, the annualized return and the average per-date turnover.
        Trading days are taken from the file names in
        ``trader.data_root['basic']``.

        Args:
            trading_days_per_year: periods per year used to annualize the
                return (244 was previously hard-coded).
        """
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()

        # Trading days are inferred from the basic-data file names (extension stripped)
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=float,
        ).sort_index()

        # Per-year statistics
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        year_rtn = pd.DataFrame(columns=['收益'])
        for year in sorted(years.unique()):
            end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
            # As-of lookup: last net value on or before end_date (tolerates days without pnl)
            end_value = rtn_stat.loc[:end_date].iloc[-1]
            if year == first_year:
                # The curve implicitly starts from a net value of 1
                year_rtn.loc[year, '收益'] = end_value - 1
            else:
                start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
                year_rtn.loc[year, '收益'] = (end_value / rtn_stat.loc[:start_date].iloc[-1]) - 1
        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / trading_days_per_year)) - 1
        year_rtn.loc['Avg.Turnover'] = self.trader.account_history.groupby('date')['turnover'].sum().mean()

        # Single column, so Series.map replaces the deprecated DataFrame.applymap
        year_rtn['收益'] = year_rtn['收益'].map('{:.2%}'.format)
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']
        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """
    Render the yearly return table to the console with rich.

    Cells of the return column are colored green when they start with '-'
    and red otherwise (red for gains / green for losses). All other cells
    are printed as plain text.

    Args:
        year_rtn: table of already formatted cells. The return column is
            named 'Rtn' (the name ``SpreadBacktest.analyze`` assigns) or '收益'.
    """
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn.columns:
        table.add_column(str(col), justify='center', width=10, no_wrap=True)
    for _, row in year_rtn.iterrows():
        table.add_row(*(_format_year_rtn_cell(col, row[col]) for col in year_rtn.columns))
    rprint(table)


def _format_year_rtn_cell(col, value) -> str:
    """Return the rich-markup text for one table cell (colors only the return column)."""
    text = str(value)
    if col not in ('Rtn', '收益'):
        return text
    color = "green" if text.startswith('-') else "red"
    return f"[{color}]{text}[/{color}]"