# -*- coding: UTF-8 -*-
import sys
import os
import pandas as pd
import numpy as np
import time
from trader import Trader
from rich import print as rprint
from rich.table import Table


class SpreadBacktest:
    """Drive a ``Trader`` over a date range and summarise the resulting returns."""

    def __init__(self, trader: Trader):
        """Store the trader whose signals and account history are backtested."""
        self.trader = trader

    def run(self, start: str, end: str):
        """Replay the trader's signals between ``start`` and ``end``.

        The effective start is the latest of ``start``, the first available
        date and (when account history already exists) the last recorded
        account date.  The effective end is clamped to the last available date;
        if the first signal extends beyond that, one extra signal date is
        included so the next position can be recorded without computing PnL.

        Args:
            start: Requested first date (compared against the index labels of
                ``trader.avaliable_date`` — presumably sortable date strings).
            end: Requested last date, same format as ``start``.
        """
        # Refresh the range of dates that can be backtested.
        self.trader.update_avaliable_date()

        # Determine the backtest start and end.
        first_available = self.trader.avaliable_date.index.min()
        if len(self.trader.account_history) > 0:
            # Resume from the last recorded account date if it is later.
            bkt_start = max(
                start,
                first_available,
                self.trader.account_history['date'].max(),
            )
        else:
            bkt_start = max(start, first_available)
        rec_end = min(end, self.trader.avaliable_date.index.max())

        # If the first signal has values beyond the last available data date,
        # only the position for the next timestamp is updated (no PnL).
        first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
        if rec_end < first_signal.index.max():
            following = first_signal.loc[rec_end:].index.to_list()
            # When rec_end itself is a signal date the "next" one is element 1.
            bkt_end = following[1] if rec_end in first_signal.index else following[0]
        else:
            bkt_end = rec_end

        print(f'回测区间: {bkt_start} - {bkt_end}')
        sys.stdout.flush()

        start_time = time.time()
        target_list = first_signal.loc[bkt_start:bkt_end].index
        total = len(target_list)
        for idx, date in enumerate(sorted(target_list), start=1):
            if idx > 1:
                elapsed = time.time() - start_time
                # Linear extrapolation from the average time per processed date.
                need_time = int((elapsed / idx) * (total - idx))
                used_time = int(elapsed)
                progress = "回测进度: {:>3d}% [{} -> {}] 用时: {:0>2d}: {:0>2d} / 需要: {:0>2d}: {:0>2d}".format(
                    idx * 100 // total,
                    date,
                    bkt_end,
                    used_time // 60,
                    used_time % 60,
                    need_time // 60,
                    need_time % 60,
                )
                # "\r" rewinds to the line start so the progress line is
                # overwritten in place; flush right away so it is visible
                # while the (potentially slow) update below runs.
                print("\r" + progress, end="")
                sys.stdout.flush()

            # Data of the last available date can only be used to record positions.
            if (date <= rec_end) and (date < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(date)
            else:
                self.trader.update_signal(date, update_type='position')

    @property
    def account_history(self):
        """Account history held by the underlying trader."""
        return self.trader.account_history

    @property
    def position_history(self):
        """Position history held by the underlying trader."""
        return self.trader.position_history

    def analyze(self):
        """Print a per-year return table plus annualised return and turnover.

        Builds a cumulative net-value series from ``account_history['pnl']``
        (per-date product of ``1 + pnl``, then a running product), computes
        each calendar year's return from it, appends an annualised figure
        (244 periods per year) and the mean per-date turnover, and renders the
        result with ``print_year_rtn``.
        """
        rtn_stat = self.trader.account_history.copy()
        rtn_stat['pnl'] += 1
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()

        # Trading days are taken from the file names under the "basic" data root
        # (file stem = date label; presumably 'YYYY-MM-DD' — verify against data).
        # Only the sorted index is used, so the dtype is made explicit to avoid
        # the empty-Series dtype warning on older pandas.
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()

        # Per-year statistics.
        year_rtn = pd.DataFrame(columns=['收益'])
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        for year in sorted(years.unique()):
            # Last trading day on or before Dec 31, capped at the last net-value date.
            end_date = min(
                rtn_stat.index.max(),
                trading_day.loc[:'{}-12-31'.format(year)].index.values[-1],
            )
            if year == first_year:
                # Net value starts at 1, so the first year's return is measured
                # from inception rather than from a prior year-end.
                year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
            else:
                # Base is the last trading day on or before Jan 1 (i.e. the
                # previous year-end).  Computed only here: for the first year
                # there may be no earlier trading day, and indexing [-1] on an
                # empty slice would raise.
                start_date = max(
                    rtn_stat.index.min(),
                    trading_day.loc[:'{}-01-01'.format(year)].index.values[-1],
                )
                year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1

        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / 244)) - 1
        year_rtn.loc['Avg.Turnover'] = self.trader.account_history.groupby('date')['turnover'].sum().mean()

        # Format the single value column as percentages.  Series.map is used
        # instead of DataFrame.applymap, which is deprecated in recent pandas.
        year_rtn['收益'] = year_rtn['收益'].map('{:.2%}'.format)
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']

        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)


def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Render a yearly-return frame as a rich table.

    Cells in the return column (named ``'收益'`` or ``'Rtn'``) are coloured
    green when the formatted value starts with ``'-'`` and red otherwise;
    every other column is printed as plain text.

    Args:
        year_rtn: Frame whose cells are already formatted for display.
    """
    # The caller in this file renames the value column to 'Rtn' before calling;
    # accept both names so the colouring actually applies.
    return_columns = ('收益', 'Rtn')

    table = Table(show_header=True, header_style='bold')
    for col in year_rtn:
        table.add_column(col, justify='center', width=10, no_wrap=True)

    for _, row in year_rtn.iterrows():
        cells = []
        for col in year_rtn.columns:
            text = str(row[col])
            if col in return_columns:
                color = "green" if text.startswith('-') else "red"
                text = f"[{color}]{text}[/{color}]"
            cells.append(text)
        table.add_row(*cells)

    rprint(table)