# -*- coding: UTF-8 -*-
import os

import numpy as np
import pandas as pd
from rich import pretty, text
from rich import print as rprint
from rich.progress import track
from rich.style import Style
from rich.table import Column, Table

from trader import Trader

# Column labels that hold formatted return strings and therefore get coloured
# by ``print_year_rtn``. 'Rtn' is what ``Spread_Backtest.analyze`` produces;
# '收益' is kept so frames using the original label are still coloured.
_RETURN_COLUMNS = ('Rtn', '收益')


class Spread_Backtest:
    """Drive a ``Trader`` over a date range and summarise the resulting PnL.

    The class is a thin orchestration layer: all state (signals, available
    dates, account and position history) lives on the wrapped ``Trader``.
    """

    def __init__(self, trader: Trader):
        """Store the trader that owns the signals and the account state."""
        self.trader = trader

    def run(self, start: str, end: str) -> None:
        """Feed every signal timestamp in the backtest window to the trader.

        Args:
            start: Requested first timestamp. It is compared with ``max``
                against the trader's available-date index (and the latest
                recorded account date, if any), so it presumably uses the
                same sortable string format as those indexes — TODO confirm.
            end: Requested last timestamp, clamped with ``min`` to the last
                available date.
        """
        available_index = self.trader.avaliable_date.index

        # Determine the start and end of the backtest. When there is already
        # account history, resume from its most recent date.
        if len(self.trader.account_history) > 0:
            bkt_start = max(
                start,
                available_index.min(),
                self.trader.account_history['date'].max(),
            )
        else:
            bkt_start = max(start, available_index.min())
        rec_end = min(end, available_index.max())

        # If the first signal extends beyond the available data, only the
        # position for the next timestamp is updated; no return is processed.
        first_key = list(self.trader.signal.keys())[0]
        first_signal = self.trader.signal[first_key]
        if rec_end < first_signal.index.max():
            later = first_signal.loc[rec_end:].index.to_list()
            # later[0] is rec_end itself when rec_end is a signal timestamp,
            # so the "next" timestamp is later[1]. When only one timestamp is
            # left (rec_end is not in the signal index), use it instead of
            # failing with IndexError.
            bkt_end = later[1] if len(later) > 1 else later[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}')

        for d in track(first_signal.loc[bkt_start:bkt_end].index,
                       description='Backtesting...',
                       update_period=0.5):
            # Data for the last available date can only be used to record
            # positions, not to settle returns.
            if (d <= rec_end) and (d < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(d)
            else:
                self.trader.update_signal(d, update_type='position')

    @property
    def account_history(self):
        """Account history as recorded by the wrapped trader."""
        return self.trader.account_history

    @property
    def position_history(self):
        """Position history as recorded by the wrapped trader."""
        return self.trader.position_history

    def analyze(self, trading_days_per_year: int = 244) -> None:
        """Print a per-year and annualised return table for the backtest.

        The account history's ``pnl`` column is turned into a cumulative net
        value per date (product of ``1 + pnl`` within a date, then a running
        product across dates). Yearly returns are read off that curve.

        Args:
            trading_days_per_year: Number of periods treated as one year when
                annualising the final net value. Defaults to 244, the value
                that was previously hard-coded.
        """
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()

        # Trading days are derived from the file names found under the
        # 'basic' data root (the part of each name before the first dot).
        # Only the sorted index is used; an explicit dtype avoids the
        # empty-Series dtype warning on older pandas versions.
        trading_day = pd.Series(
            index=[f.split('.')[0]
                   for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()

        # Per-year statistics.
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        last_date = rtn_stat.index.max()
        year_rtn = pd.DataFrame(columns=['收益'])
        for year in sorted(years.unique()):
            year_close = trading_day.loc[:'{}-12-31'.format(year)].index.values[-1]
            end_date = min(last_date, year_close)
            if year == first_year:
                # The first year is measured from the initial net value of 1.
                year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
            else:
                # Later years are measured from the last trading day on or
                # before Jan 1. This lookup is done only here because it
                # fails when no trading day precedes the first year.
                prev_close = trading_day.loc[:'{}-01-01'.format(year)].index.values[-1]
                start_date = max(rtn_stat.index.min(), prev_close)
                year_rtn.loc[year, '收益'] = (
                    rtn_stat.loc[end_date] / rtn_stat.loc[start_date]
                ) - 1

        year_rtn.loc['Annualized'] = np.power(
            rtn_stat.values[-1],
            1 / (len(rtn_stat) / trading_days_per_year),
        ) - 1
        # '收益' is the only column, so mapping it is equivalent to the
        # deprecated element-wise DataFrame.applymap.
        year_rtn['收益'] = year_rtn['收益'].map('{:.2%}'.format)
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']
        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)


def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Render a table of formatted yearly returns with rich.

    Args:
        year_rtn: Frame whose cells are printed via ``str``. Cells in a
            return column ('Rtn' or '收益') are coloured green when the text
            starts with '-' and red otherwise; other columns are plain.
    """
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn.columns:
        table.add_column(col, justify='center', width=10, no_wrap=True)

    for _, row in year_rtn.iterrows():
        cells = []
        for col in year_rtn.columns:
            cell = str(row[col])
            if col in _RETURN_COLUMNS:
                color = "green" if cell.startswith('-') else "red"
                cell = f"[{color}]{cell}[/{color}]"
            cells.append(cell)
        table.add_row(*cells)
    rprint(table)