# spread_backtest/spread_backtest.py
#
# 136 lines
# 5.4 KiB
# Python

# -*- coding: UTF-8 -*-
import sys
import os
import pandas as pd
import numpy as np
import time
from trader import Trader
from typing import Union, Dict
from rich import print as rprint
from rich.table import Table
class SpreadBacktest:
    """Drive a `Trader` over a date range and summarise its results.

    All state (signals, available dates, account/position history) lives on
    the wrapped ``trader``; this class decides which dates to feed it and
    reports statistics computed from ``trader.account_history``.
    """

    def __init__(
        self,
        trader: Trader
    ):
        self.trader = trader

    def run(
        self,
        start: str,
        end: str
    ):
        """Run the backtest from ``start`` to ``end``.

        The effective start is the latest of ``start``, the first available
        date and (when an account history exists) its latest recorded date.
        The effective end is clipped to the last available date, but may be
        extended by one signal date so that positions for the next period can
        be recorded.

        Args:
            start (str): requested start date.
            end (str): requested end date.

        Dates are compared with ``max``/``min`` against index values, so they
        are presumably sortable strings such as ``'2023-01-31'`` — TODO confirm.
        """
        # Refresh the dates that can be backtested.
        self.trader.update_avaliable_date()
        avaliable_index = self.trader.avaliable_date.index
        # Determine the backtest start and end.
        if len(self.trader.account_history) > 0:
            bkt_start = max(start, avaliable_index.min(), self.trader.account_history['date'].max())
        else:
            bkt_start = max(start, avaliable_index.min())
        rec_end = min(end, avaliable_index.max())
        # If the first signal has values beyond the latest available data date,
        # extend to the next signal date: that date only updates positions and
        # does not process returns.
        first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
        if rec_end < first_signal.index.max():
            later_dates = first_signal.loc[rec_end:].index.to_list()
            # When rec_end itself is a signal date, later_dates[0] == rec_end,
            # so the next signal date is at position 1.
            bkt_end = later_dates[1] if rec_end in first_signal.index else later_dates[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}')
        start_time = time.time()
        target_list = sorted(first_signal.loc[bkt_start:bkt_end].index)
        total = len(target_list)
        for idx, date in enumerate(target_list, start=1):
            if idx > 1:
                elapsed = time.time() - start_time
                # Progress is printed before processing `date`, so only
                # idx - 1 dates have completed so far.
                done = idx - 1
                need_time = int(elapsed / done * (total - done))
                used_time = int(elapsed)
                print("\r", end="")
                print("回测进度: {:>3d}% [{} -> {}] 用时: {:0>2d}: {:0>2d} / 需要: {:0>2d}: {:0>2d}".format(
                    idx*100//total, date, bkt_end,
                    used_time//60, used_time%60,
                    need_time//60, need_time%60
                ), end="", flush=True)
            # Data on the last available date can only be used to record positions.
            if (date <= rec_end) and (date < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(date)
            else:
                self.trader.update_signal(date, update_type='position')
        if total > 1:
            # Terminate the '\r' progress line so later output starts cleanly.
            print()

    # Data updates
    def update_signal(self,
                      trade_time: str,
                      new_signal: pd.DataFrame):
        """
        Update a signal factor.

        Args:
            trade_time (str): signal time; used as the key into ``trader.signal``.
            new_signal (pd.DataFrame): the new signal to store under that key.
        """
        self.trader.signal[trade_time] = new_signal

    def update_interval(self,
                        interval: Union[Dict[str, Union[int, tuple, pd.Series]], None] = None,
                        ):
        """Re-initialise the trader's interval (and weight) settings.

        Args:
            interval: mapping forwarded to ``trader.init_interval``. ``None``
                (the default) means an empty dict; a fresh dict is created per
                call so no state is shared between calls.
        """
        if interval is None:
            interval = {}
        # Update interval and weight (updated when interval is a fixed ratio).
        self.trader.init_interval(interval)

    @property
    def account_history(self):
        """The trader's account history (read-through)."""
        return self.trader.account_history

    @property
    def position_history(self):
        """The trader's position history (read-through)."""
        return self.trader.position_history

    def analyze(self, trading_days_per_year: int = 244):
        """
        Print per-year returns, the annualized return and the average turnover.

        Reads the ``date``, ``pnl`` and ``turnover`` columns of
        ``trader.account_history``; the trading calendar is taken from the
        file names (stem before the first '.') under
        ``trader.data_root['basic']``.

        Args:
            trading_days_per_year (int): number of periods per year used to
                annualize the final net value (default 244).
        """
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        # Compound per-row pnl into a per-date cumulative net value starting from 1.
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()
        # Determine trading days from the basic data directory (only the index is used).
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()
        # Per-year statistics
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        year_rtn = pd.DataFrame(columns=['收益'])
        for year in sorted(years.unique()):
            end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
            if year == first_year:
                # The net value starts at 1, so the first year needs no base date
                # (and the calendar may have no date before that Jan 1).
                year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
            else:
                # Base is the last trading day on or before Jan 1 of `year`.
                start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
                year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1
        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / trading_days_per_year)) - 1
        year_rtn.loc['Avg.Turnover'] = self.trader.account_history.groupby('date')['turnover'].sum().mean()
        # Series.map avoids DataFrame.applymap, deprecated since pandas 2.1.
        year_rtn['收益'] = year_rtn['收益'].map('{:.2%}'.format)
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']
        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Render a yearly-return table to the terminal with rich.

    Args:
        year_rtn (pd.DataFrame): table whose cells are already formatted
            strings. The return column may be named '收益' or 'Rtn'
            (SpreadBacktest.analyze renames its columns to ['Year', 'Rtn']
            before calling); values in that column are coloured red when
            non-negative and green when they start with '-'.
    """
    rtn_columns = ('收益', 'Rtn')
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn.columns:
        table.add_column(str(col), justify='center', width=10, no_wrap=True)
    for _, row in year_rtn.iterrows():
        new_row = []
        for col in year_rtn.columns:
            value = str(row[col])
            if col in rtn_columns:
                # Market convention used here: green for losses, red for gains.
                color = "green" if value.startswith('-') else "red"
                value = f"[{color}]{value}[/{color}]"
            new_row.append(value)
        table.add_row(*new_row)
    rprint(table)