Compare commits

..

No commits in common. "main" and "0.0.2" have entirely different histories.
main ... 0.0.2

8 changed files with 96 additions and 257 deletions

View File

@ -6,9 +6,4 @@
## Version 0.02
1. 优化非满仓情况下现金比例对收益的计算问题
2. 优化自定义权重下的判断逻辑
3. 修复没有满仓有现金比例时的收益计算问题
## Version 0.03
1. 新增分钟价格计算函数
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
3. 增加信号更新函数
3. 修复没有满仓有现金比例时的收益计算问题

View File

@ -1,33 +0,0 @@
import pandas as pd
import numpy as np
def amount_specified(
    min_data: pd.DataFrame,
    post_adj_factor: pd.DataFrame,
    min_amount: float=0.
):
    """
    Compute a post-adjusted average price per stock, limited by a traded amount.

    For every stock the cumulative traded amount is divided by the cumulative
    traded volume, minute by minute. The price taken is the one at the last
    minute whose cumulative amount is still `<= min_amount`; the first minute
    is always eligible. That price is then multiplied by the stock's factor.

    Args:
        min_data (pd.DataFrame): minute data; the columns read here are
            `stock_code`, `Time`, `amount` and `vol`.
        post_adj_factor (pd.DataFrame): must contain `stock_code` and `factor`.
        min_amount (float): cap on the cumulative traded amount.

    Returns:
        pd.DataFrame: columns `stock_code` and `open_post`; stocks without a
        usable price or without a factor are dropped.
    """
    def _running_total(field):
        # Wide table (rows: Time, columns: stock_code) accumulated over the day.
        return min_data.pivot_table(index='Time', columns='stock_code', values=field).cumsum()

    cum_amount = _running_total('amount')
    cum_volume = _running_total('vol')
    avg_price = cum_amount / cum_volume

    # Blank out every minute after the first whose cumulative amount exceeds the cap.
    later_minutes = avg_price.iloc[1:]
    avg_price.iloc[1:] = later_minutes.where(cum_amount <= min_amount, np.nan)

    # Long format, then keep the latest surviving minute of each stock.
    long_form = avg_price.unstack().reset_index()
    long_form.columns = ['stock_code', 'time', 'price']
    latest = long_form.dropna(subset=['price'])
    latest = latest.drop_duplicates(subset='stock_code', keep='last')

    # Apply the post-adjustment factor; stocks without one are discarded.
    priced = latest.merge(post_adj_factor, on=['stock_code'], how='left')
    priced = priced.dropna(subset=['factor'])
    priced['open_post'] = priced['price'] * priced['factor']
    return priced[['stock_code', 'open_post']]

View File

@ -2,5 +2,4 @@ from account import Account
from trader import Trader
from spread_backtest import Spread_Backtest
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']
__all__ = ['Account', 'Trader', 'Spread_Backtest']

View File

@ -1,68 +0,0 @@
import numpy as np
def check_buy_exclude(buy_exclude):
    """
    Validate the buy-exclusion settings and normalise them to (lower, upper) pairs.

    The recognised conditions are `amt_20`, `list_days`, `mkt` and `price`.
    Each supplied condition must be a `set`, `list` or `tuple` holding a lower
    bound and, optionally, an upper bound; a missing upper bound becomes
    `np.inf`. Conditions that are not supplied are left out of the result.

    The input mapping is never modified. Writing defaults back into it would
    leak state between callers that share one dict object, for example a
    mutable default argument.

    Args:
        buy_exclude (dict): condition name -> bounds collection.

    Returns:
        dict: condition name -> `(lower, upper)` tuple, in the fixed order
        `amt_20`, `list_days`, `mkt`, `price`.

    Raises:
        TypeError: a supplied condition is not a `set`, `list` or `tuple`
            (TypeError is an `Exception` subclass, so `except Exception`
            handlers keep working).
    """
    buy_exclude_dict = dict()
    for cond in ('amt_20', 'list_days', 'mkt', 'price'):
        if cond not in buy_exclude:
            continue
        bounds = buy_exclude[cond]
        if not isinstance(bounds, (set, list, tuple)):
            raise TypeError(f'wrong input type for buy exclude: {cond}, `set`, `list` or `tuple` is required')
        if isinstance(bounds, set):
            # Sets cannot be indexed and carry no order: read them as (min, max).
            bounds = sorted(bounds)
        if len(bounds) == 1:
            buy_exclude_dict[cond] = (bounds[0], np.inf)
        else:
            buy_exclude_dict[cond] = (bounds[0], bounds[1])
    return buy_exclude_dict

View File

@ -10,8 +10,8 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20',
'opening_info','ipo_days','margin_list','abnormal', 'recession']):
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
'abnormal', 'recession']):
if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else:
@ -34,7 +34,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1)
next_list = []
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else:

View File

@ -1,6 +1,6 @@
import numpy as np
import pandas as pd
from typing import Union, Iterable
from typing import Union, Iterable, Optional, Dict
class DataLoader():
"""

View File

@ -5,7 +5,6 @@ import pandas as pd
import numpy as np
import time
from trader import Trader
from typing import Union, Dict
from rich import print as rprint
from rich.table import Table
@ -61,26 +60,6 @@ class SpreadBacktest():
else:
self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
                  trade_time: str,
                  new_signal: pd.DataFrame):
    """
    Store a new factor signal on the trader under the given trade time.

    An existing entry for `trade_time` is overwritten.

    Args:
        trade_time (str): key of the signal to set
        new_signal (pd.DataFrame): the signal that replaces the stored one
    """
    signals = self.trader.signal
    signals[trade_time] = new_signal
def update_interval(self,
                    interval: Dict[str, Union[int,tuple,pd.Series]]={},
                    ):
    """
    Rebuild the trader's interval from a new per-signal specification.

    This simply forwards `interval` to `self.trader.init_interval`.

    Args:
        interval (dict): signal name -> int, tuple or pd.Series
    """
    # NOTE(review): the `{}` default is a shared object; it is only forwarded
    # here, never modified.
    trader = self.trader
    trader.init_interval(interval)
@property
def account_history(self):
return self.trader.account_history

213
trader.py
View File

@ -1,4 +1,5 @@
import pandas as pd
import numpy as np
import sys
import os
import copy
@ -7,60 +8,59 @@ from typing import Union, Iterable, Dict
from account import Account
from dataloader import DataLoader
from check_funcs import check_buy_exclude
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools
gft = get_factor_tools()
class Trader(Account):
"""
交易类: 用于控制每日交易情况
Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (dict[str, (int, tuple, pd.Series)]):
交易间隔
num (int): 持仓数量
ascending (bool): 因子方向
with_st (bool): 是否包含st
tick (bool): 是否开始tick模拟模式(开发中)
weight ([str, pd.DataFrame]): 权重分配
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (int, tuple, pd.Series): 交易间隔
num (int): 持仓数量
ascending (bool): 因子方向
with_st (bool): 是否包含st
tick (bool): 是否开始tick模拟模式(开发中)
weight ([str, pd.DataFrame]): 权重分配
- avg (str): 平均分配每天早盘重新分配日中交易不重新分配
- (pd.DataFrame): 自定义股票权重包含每天个股指定的权重会自动归一化
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间
slippage (tuple): 买入和卖出滑点
commission (float): 佣金
tax (dict): 印花税
force_exclude (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
amt_filter (set): 20日均成交额筛选第一个参数是筛选下限第二个参数是筛选上限可以只提供下限
data_root (dict): 对应各个目标因子的交易价格数据必须包含stock_code和price列
ipo_days (int): 筛选上市时间
slippage (tuple): 买入和卖出滑点
commission (float): 佣金
tax (dict): 印花税
exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- recession: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见
account (Account): 账户设置account.Account
account (Account): 账户设置account.Account
"""
def __init__(self,
signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
num: int=100,
ascending: bool=False,
with_st: bool=False,
data_root:dict={},
tick: bool=False,
weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001),
commission: float=0.0001,
buy_exclude: Dict[str, Union[set,int,float]]={},
force_exclude: list=[],
account: dict={},
tax: dict={
'1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005)
},
exclude_list: list=[],
account: dict={},
**kwargs) -> None:
# 初始化账户
super().__init__(**account)
@ -78,7 +78,24 @@ class Trader(Account):
if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval
self.init_interval(interval)
self.interval = []
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num
if isinstance(num, int):
self.num = int(num)
@ -99,6 +116,21 @@ class Trader(Account):
self.weight = weight
else:
raise ValueError('invalid type for `weight`')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage
if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage
@ -114,19 +146,17 @@ class Trader(Account):
self.tax = tax
else:
raise ValueError('tax should be dict.')
# buy exclude
self.buy_exclude = check_buy_exclude(buy_exclude)
# force exclude
if isinstance(force_exclude, list):
self.force_exclude = force_exclude
# exclude
if isinstance(exclude_list, list):
self.exclude_list = exclude_list
optional_list = ['abnormal', 'recession']
for item in force_exclude:
for item in exclude_list:
if item in optional_list:
pass
else:
raise ValueError(f"Unexpected keyword argument '{item}'")
else:
raise ValueError('force_exclude should be list.')
raise ValueError('exclude_list should be list.')
# data_root
# 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0:
@ -143,30 +173,7 @@ class Trader(Account):
if s not in data_root:
raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root
def init_interval(self, interval):
    """
    Build `self.interval` from the per-signal interval specification.

    For every signal in `self.signal` the matching entry of `interval` is
    turned into a Series:
      - int n: zeros over the signal's index, with 1 at every n-th position
      - tuple (n, v): zeros over the signal's index, with v at every n-th position
      - pd.Series: used as given
    The per-signal Series are concatenated, in signal order, into `self.interval`.

    Raises:
        ValueError: a signal has no entry in `interval`, or the entry has an
            unsupported type.
    """
    per_signal = []
    for name in self.signal:
        if name not in interval:
            raise ValueError(f'not found interval for signal {name}')
        spec = interval[name]
        if isinstance(spec, pd.Series):
            schedule = spec
        elif isinstance(spec, (int, tuple)):
            dates = self.signal[name].index
            schedule = pd.Series(index=dates, data=[0]*len(dates))
            # A bare int marks every n-th position with 1; a tuple supplies its own value.
            step, value = (spec, 1) if isinstance(spec, int) else (spec[0], spec[1])
            schedule[::step] = value
        else:
            raise ValueError('invalid interval type')
        per_signal.append(schedule)
    self.interval = pd.concat(per_signal)
def load_data(self,
date: str,
update_type: str='rtn'):
@ -180,19 +187,14 @@ class Trader(Account):
"""
self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
# 遍历从data_root中读取当日数据
for path in self.data_root:
if path == 'basic':
continue
else:
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
if update_type == 'position':
return True
for s in self.signal:
if s == 'open':
if s in self.data_root:
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
continue
else:
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal:
@ -252,68 +254,43 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓
try:
max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
max_sell_num = self.interval.loc[date]*len(last_position)
else:
last_position = pd.Series()
max_sell_num = self.num
# 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
exclude_data = dict()
for cond in self.buy_exclude:
if cond == 'amt_20':
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_exclude = stock_amt20.copy()
stock_amt_exclude.loc[:] = 0
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
stock_amt_exclude = stock_amt_exclude.sort_index()
exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_filter = stock_amt20.copy()
stock_amt_filter.loc[:] = 0
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
stock_amt_filter = stock_amt_filter.sort_index()
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_filter = stock_ipo_days.copy()
stock_ipo_filter.loc[:] = 0
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
# 剔除列表
# 包含强制剔除列表和买入剔除列表:
# 包含强制列表和普通列表:
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤
# 普通列表如果已经持有不会过滤只对新买入的过滤
# 强制过滤列表
exclude_stock = []
for cond in self.force_exclude:
if cond == 'abnormal':
for exclude in self.exclude_list:
if exclude == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if cond == 'recession':
if exclude == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock)
# 买入过滤列表
buy_exclude = []
for cond in self.buy_exclude:
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
buy_exclude = list(set(buy_exclude))
# 普通过滤列表
normal_exclude = []
normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
normal_exclude = list(set(normal_exclude))
# 交易列表
# 仓位判断给与计算误差冗余
@ -368,7 +345,7 @@ class Trader(Account):
buy_list = []
# 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表
@ -477,7 +454,7 @@ class Trader(Account):
buy_list = []
# 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表
@ -552,19 +529,13 @@ class Trader(Account):
buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标
"""
if trade_time+'_trade' in self.today_data:
trade_price = self.today_data[trade_time+'_trade']
basic_price = self.today_data[trade_time]
else:
basic_price = self.today_data[trade_time]
trade_price = basic_price
stock_price = self.today_data[trade_time]
target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list))
# 根据交易和非交易标的分别获取目标价格
target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price
def check_update_status(self,
@ -753,8 +724,6 @@ class Trader(Account):
# 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位
cur_pos['weight'] = self.update_next_weight(cur_pos)
@ -775,8 +744,6 @@ class Trader(Account):
# 更新当日收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy()
position_record['end_weight'] = self.update_next_weight(position_record)
self.position_history[date] = position_record.copy()[position_fields]