spread_backtest/spread_backtest.py

98 lines
4.0 KiB
Python
Raw Normal View History

2024-05-22 23:33:19 +08:00
# -*- coding: UTF-8 -*-
import os
import pandas as pd
import numpy as np
from trader import Trader
from rich.progress import track
from rich import print as rprint
from rich import pretty, text
from rich.table import Column, Table
from rich.style import Style
class Spread_Backtest:
    """Drive a ``Trader`` through a backtest and summarize its returns.

    Signal handling and account/position bookkeeping live in the wrapped
    ``Trader``; this class only chooses which dates to iterate, hands each
    date to ``Trader.update_signal`` and reports yearly returns afterwards.
    """

    def __init__(self,
                 trader: Trader):
        """
        Parameters
        ----------
        trader : Trader
            Trader whose ``signal``, ``avaliable_date``, ``account_history``,
            ``position_history`` and ``data_root`` attributes are read here.
        """
        self.trader = trader

    def run(self,
            start: str,
            end: str
            ):
        """Feed every signal date of the backtest window to the trader.

        Parameters
        ----------
        start, end : str
            Requested window. ``start`` is raised to the first available
            date (and, if the trader already has account history, to the
            latest recorded date); ``end`` is lowered to the last available
            date. If the first signal has dates beyond that, the window is
            extended by one signal date so the next position gets recorded.
        """
        # Determine the start and end of the backtest.
        if len(self.trader.account_history) > 0:
            # Resume from the latest date already present in the account history.
            bkt_start = max(start, self.trader.avaliable_date.index.min(), self.trader.account_history['date'].max())
        else:
            bkt_start = max(start, self.trader.avaliable_date.index.min())
        rec_end = min(end, self.trader.avaliable_date.index.max())
        # If the first factor has values beyond the latest available data, only
        # update the position for the next signal date and do not process returns.
        first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
        if rec_end < first_signal.index.max():
            later_dates = first_signal.loc[rec_end:].index.to_list()
            # .loc[rec_end:] includes rec_end itself when it is a signal date,
            # so step past it to reach the next signal date.
            bkt_end = later_dates[1] if rec_end in first_signal.index else later_dates[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}')
        for d in track(first_signal.loc[bkt_start:bkt_end].index,
                       description='Backtesting...',
                       update_period=0.5):
            # Data on the last day of avaliable_date can only be used to record positions.
            if (d <= rec_end) and (d < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(d)
            else:
                self.trader.update_signal(d, update_type='position')

    @property
    def account_history(self):
        """The wrapped trader's ``account_history`` (the same object, not a copy)."""
        return self.trader.account_history

    @property
    def position_history(self):
        """The wrapped trader's ``position_history`` (the same object, not a copy)."""
        return self.trader.position_history

    def analyze(self):
        """Print yearly and annualized returns of the backtest.

        ``pnl + 1`` is multiplied across all account-history rows of a date and
        then cumulated over dates, giving a growth-factor curve (this assumes
        ``pnl`` holds simple fractional returns — confirm against ``Trader``).

        - The first year's return is measured against the starting value 1.
        - Later years run from the last trading day on or before Jan 1 to the
          last trading day on or before Dec 31.
        - Both year boundaries are clipped to the recorded date range.

        Trading days are the file stems found in ``data_root['basic']``.
        """
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        # groupby sorts the dates, so the curve's index is monotonic and
        # supports the as-of lookups (.loc[:date].iloc[-1]) below.
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()
        # Determine trading days from the basic-data directory (only the index is used).
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()
        # Yearly statistics
        years = sorted(pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year.unique())
        year_rtn = pd.DataFrame(columns=['收益'])
        for i, year in enumerate(years):
            end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
            # As-of value: identical to .loc[end_date] when the date was recorded,
            # and no KeyError when it is missing from the account history.
            end_value = rtn_stat.loc[:end_date].iloc[-1]
            if i == 0:
                # The curve implicitly starts at 1, so no prior-year trading day is
                # needed (and none may exist in the basic-data directory).
                year_rtn.loc[year, '收益'] = end_value - 1
            else:
                start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
                start_value = rtn_stat.loc[:start_date].iloc[-1]
                year_rtn.loc[year, '收益'] = (end_value / start_value) - 1
        # Annualize over the number of recorded dates, assuming 244 trading days per year.
        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / 244)) - 1
        # Series.map replaces the deprecated DataFrame.applymap (single column).
        year_rtn['收益'] = year_rtn['收益'].map('{:.2%}'.format)
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']
        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Render a yearly-return table to the console with rich.

    Parameters
    ----------
    year_rtn : pd.DataFrame
        One row per period; every cell is printed via ``str``. The return
        column is coloured: green when the value starts with ``'-'``, red
        otherwise. The return column is ``'Rtn'``, as produced by
        ``Spread_Backtest.analyze``; the legacy ``'收益'`` is also accepted.
    """
    # analyze() renames '收益' to 'Rtn' before calling this function, so the
    # old check against '收益' alone never fired and nothing was coloured.
    rtn_columns = ('Rtn', '收益')
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn:
        table.add_column(col, justify='center', width=10, no_wrap=True)
    for _, row in year_rtn.iterrows():
        new_row = []
        for col in year_rtn.columns:
            cell = str(row[col])
            if col in rtn_columns:
                # Negative -> green, otherwise red (presumably the CN-market
                # convention where red means a gain). startswith avoids an
                # IndexError on empty cells; 'red' fixes the old 'sred' typo.
                color = "green" if cell.startswith('-') else "red"
                cell = f"[{color}]{cell}[/{color}]"
            new_row.append(cell)
        table.add_row(*new_row)
    rprint(table)