# -*- coding: UTF-8 -*-
import os
import sys
import time
from typing import Dict, Union

import numpy as np
import pandas as pd
from rich import print as rprint
from rich.table import Table

from trader import Trader


class SpreadBacktest:
    """Drive a ``Trader`` through a date range and summarize its results."""

    def __init__(self, trader: Trader):
        self.trader = trader

    def run(self, start: str, end: str):
        """Run the backtest over ``[start, end]``, clipped to the available data.

        Args:
            start (str): requested start date; raised to the first available date
                (and to the last recorded account date when history exists).
            end (str): requested end date; lowered to the last available date.
        """
        # Refresh the dates for which backtest data is available
        self.trader.update_avaliable_date()

        # Determine the effective start / end of the backtest
        if len(self.trader.account_history) > 0:
            bkt_start = max(start,
                            self.trader.avaliable_date.index.min(),
                            self.trader.account_history['date'].max())
        else:
            bkt_start = max(start, self.trader.avaliable_date.index.min())
        rec_end = min(end, self.trader.avaliable_date.index.max())

        # If the first factor has values beyond the latest available data, extend the
        # run by one signal date: that date only updates the next positions and does
        # not book returns.
        first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
        if rec_end < first_signal.index.max():
            later_dates = first_signal.loc[rec_end:].index.to_list()
            # When rec_end itself is a signal date, the *next* date is the extra one
            bkt_end = later_dates[1] if rec_end in first_signal.index else later_dates[0]
        else:
            bkt_end = rec_end
        print(f'回测区间: {bkt_start} - {bkt_end}')

        start_time = time.time()
        target_list = first_signal.loc[bkt_start:bkt_end].index
        total = len(target_list)
        for idx, date in enumerate(sorted(target_list), start=1):
            if idx > 1:
                used_time = time.time() - start_time
                need_time = int((used_time / idx) * (total - idx))
                used_time = int(used_time)
                print("\r", end="")
                print("回测进度: {:>3d}% [{} -> {}] 用时: {:0>2d}: {:0>2d} / 需要: {:0>2d}: {:0>2d}".format(
                    idx * 100 // total, date, bkt_end,
                    used_time // 60, used_time % 60,
                    need_time // 60, need_time % 60
                ), end="")
                sys.stdout.flush()
            # The last day of avaliable_date can only be used to record positions
            if (date <= rec_end) and (date < self.trader.avaliable_date.index.max()):
                self.trader.update_signal(date)
            else:
                self.trader.update_signal(date, update_type='position')
        if total > 1:
            # Terminate the in-place ("\r") progress line so later output starts fresh
            print()

    # Data updates
    def update_signal(self, trade_time: str, new_signal: pd.DataFrame):
        """Store a signal factor on the trader.

        Args:
            trade_time (str): signal time, used as the key into ``trader.signal``.
            new_signal (pd.DataFrame): the new signal to store.
        """
        self.trader.signal[trade_time] = new_signal

    def update_interval(self,
                        interval: Union[Dict[str, Union[int, tuple, pd.Series]], None] = None,
                        ):
        """Update the trader's interval and weights.

        Args:
            interval: forwarded unchanged to ``trader.init_interval``;
                ``None`` (the default) means an empty dict.
        """
        # Build a fresh dict per call; a ``{}`` default would be shared (and could be
        # mutated by init_interval) across every call.
        if interval is None:
            interval = {}
        self.trader.init_interval(interval)

    @property
    def account_history(self):
        return self.trader.account_history

    @property
    def position_history(self):
        return self.trader.position_history

    def analyze(self, days_per_year: int = 244):
        """Print yearly returns, the annualized return and the average turnover.

        Args:
            days_per_year (int): trading days per year used to annualize the
                cumulative return (default 244, as before).
        """
        account_history = self.trader.account_history
        if len(account_history) == 0:
            # Nothing to summarize; the statistics below need at least one day
            rprint("[bold black]1 收益统计[/bold black]")
            print('无回测记录')
            return

        rtn_stat = account_history.copy()
        rtn_stat['pnl'] += 1
        # Compound per-date gross returns into a cumulative NAV curve
        rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()

        # Trading calendar: file names under the basic data directory
        trading_day = pd.Series(
            index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])],
            dtype=object,
        ).sort_index()

        # Yearly statistics
        years = pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year
        first_year = years.min()
        year_rtn = pd.DataFrame(columns=['收益'])
        for year in sorted(years.unique()):
            end_date = min(rtn_stat.index.max(),
                           trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
            if year == first_year:
                # First year: measured from the NAV base of 1 at the backtest start
                year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
            else:
                # Later years: measured from the last trading day on or before Jan 1
                start_date = max(rtn_stat.index.min(),
                                 trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
                year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1

        year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / days_per_year)) - 1
        year_rtn.loc['Avg.Turnover'] = account_history.groupby('date')['turnover'].sum().mean()
        # Series.map avoids the DataFrame.applymap deprecation (pandas >= 2.1)
        year_rtn['收益'] = year_rtn['收益'].map(lambda x: '{:.2%}'.format(x))
        year_rtn = year_rtn.reset_index()
        year_rtn.columns = ['Year', 'Rtn']

        rprint("[bold black]1 收益统计[/bold black]")
        print_year_rtn(year_rtn)


def print_year_rtn(year_rtn: pd.DataFrame) -> None:
    """Render the yearly-return table with rich.

    Cells of the return column (``'Rtn'``, or ``'收益'`` for un-renamed frames)
    are colored green when negative and red otherwise (presumably the Chinese
    market convention where red means up).
    """
    rtn_columns = ('Rtn', '收益')
    table = Table(show_header=True, header_style='bold')
    for col in year_rtn.columns:
        table.add_column(col, justify='center', width=10, no_wrap=True)
    for _, row in year_rtn.iterrows():
        cells = []
        for col in year_rtn.columns:
            text = str(row[col])
            if col in rtn_columns:
                color = "green" if text.startswith('-') else "red"
                text = f"[{color}]{text}[/{color}]"
            cells.append(text)
        table.add_row(*cells)
    rprint(table)