Compare commits
7 Commits
Author | SHA1 | Date |
---|---|---|
binz | 0977e8c9fd | |
binz | d671f820d4 | |
binz | 721e2a42f5 | |
binz | 30a998b58a | |
binz | 411e1a3f78 | |
binz | d5a2d9cb9c | |
binz | 28e1bb27c6 |
|
@ -6,4 +6,9 @@
|
|||
## Version 0.02
|
||||
1. 优化非满仓情况下现金比例对收益的计算问题
|
||||
2. 优化自定义权重下的判断逻辑
|
||||
3. 修复没有满仓有现金比例时的收益计算问题
|
||||
3. 修复没有满仓有现金比例时的收益计算问题
|
||||
|
||||
## Version 0.03
|
||||
1. 新增分钟价格计算函数
|
||||
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
|
||||
3. 增加信号更新函数
|
|
@ -0,0 +1,33 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
def amount_specified(
|
||||
min_data: pd.DataFrame,
|
||||
post_adj_factor: pd.DataFrame,
|
||||
min_amount: float=0.
|
||||
):
|
||||
"""
|
||||
计算指定金额下的平均后复权价格
|
||||
|
||||
Args:
|
||||
min_data (pd.DataFrame): 指定日期分钟数据,需至少包含代码(stock_code)、分钟(Time)、价格(price)、成交量(vol)、成交金额(amount)列
|
||||
post_adj_factor (pd.DataFrame): 后复权因子数据,需至少包含股票、后复权因子
|
||||
min_amount (float): 指定最小金额下的平均价
|
||||
"""
|
||||
|
||||
# 按照指定最小量获取平均价格
|
||||
stock_amt = min_data.pivot_table(index='Time', columns='stock_code', values='amount').cumsum()
|
||||
stock_vol = min_data.pivot_table(index='Time', columns='stock_code', values='vol').cumsum()
|
||||
amount_price = stock_amt / stock_vol
|
||||
amount_price.iloc[1:] = amount_price.iloc[1:].where(stock_amt <= min_amount, np.nan)
|
||||
amount_price = amount_price.unstack().reset_index()
|
||||
amount_price.columns = ['stock_code', 'time', 'price']
|
||||
amount_price = amount_price.dropna(subset=['price']).drop_duplicates(subset='stock_code', keep='last')
|
||||
# 计算后复权价格
|
||||
amount_price = amount_price.merge(post_adj_factor, on=['stock_code'], how='left').dropna(subset=['factor'])
|
||||
amount_price['open_post'] = amount_price['price'] * amount_price['factor']
|
||||
|
||||
return amount_price[['stock_code','open_post']]
|
||||
|
||||
|
|
@ -2,4 +2,5 @@ from account import Account
|
|||
from trader import Trader
|
||||
from spread_backtest import Spread_Backtest
|
||||
|
||||
__all__ = ['Account', 'Trader', 'Spread_Backtest']
|
||||
|
||||
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']
|
|
@ -0,0 +1,68 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def check_buy_exclude(buy_exclude):
|
||||
buy_exclude_dict = dict()
|
||||
# 检查剔除条件
|
||||
for cond in ['amt_20', 'list_days', 'mkt', 'price']:
|
||||
if cond == 'amt_20':
|
||||
# 20日成交量均值
|
||||
if cond in buy_exclude:
|
||||
amt_exclude = buy_exclude[cond]
|
||||
else:
|
||||
buy_exclude[cond] = (0, np.inf)
|
||||
continue
|
||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
||||
if len(amt_exclude) == 1:
|
||||
buy_exclude_dict[cond] = (amt_exclude[0], np.inf)
|
||||
else:
|
||||
buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1])
|
||||
else:
|
||||
raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required')
|
||||
|
||||
if cond == 'list_days':
|
||||
# 上市时间
|
||||
if cond in buy_exclude:
|
||||
list_exclude = buy_exclude[cond]
|
||||
else:
|
||||
buy_exclude[cond] = (0, np.inf)
|
||||
continue
|
||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
||||
if len(list_exclude) == 1:
|
||||
buy_exclude_dict[cond] = (list_exclude[0], np.inf)
|
||||
else:
|
||||
buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1])
|
||||
else:
|
||||
raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required')
|
||||
|
||||
if cond == 'mkt':
|
||||
# 市值
|
||||
if cond in buy_exclude:
|
||||
mkt_exclude = buy_exclude[cond]
|
||||
else:
|
||||
buy_exclude[cond] = (0, np.inf)
|
||||
continue
|
||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
||||
if len(mkt_exclude) == 1:
|
||||
buy_exclude_dict[cond] = (mkt_exclude[0], np.inf)
|
||||
else:
|
||||
buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1])
|
||||
else:
|
||||
raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required')
|
||||
|
||||
if cond == 'price':
|
||||
# 价格
|
||||
if cond in buy_exclude:
|
||||
price_exclude = buy_exclude[cond]
|
||||
else:
|
||||
buy_exclude[cond] = (0, np.inf)
|
||||
continue
|
||||
if isinstance(buy_exclude[cond], (set,list,tuple)):
|
||||
if len(price_exclude) == 1:
|
||||
buy_exclude_dict[cond] = (price_exclude[0], np.inf)
|
||||
else:
|
||||
buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1])
|
||||
else:
|
||||
raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required')
|
||||
|
||||
return buy_exclude_dict
|
|
@ -10,8 +10,8 @@ if __name__ == '__main__':
|
|||
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
|
||||
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
|
||||
|
||||
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
|
||||
'abnormal', 'recession']):
|
||||
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20',
|
||||
'opening_info','ipo_days','margin_list','abnormal', 'recession']):
|
||||
if f in ['margin_list']:
|
||||
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
|
||||
else:
|
||||
|
@ -34,7 +34,7 @@ if __name__ == '__main__':
|
|||
# 更新下一日的数据用于筛选
|
||||
next_date = gft.days_after(df.index.max(), 1)
|
||||
next_list = []
|
||||
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
|
||||
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
|
||||
if f in ['margin_list']:
|
||||
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Iterable, Optional, Dict
|
||||
from typing import Union, Iterable
|
||||
|
||||
class DataLoader():
|
||||
"""
|
||||
|
|
|
@ -5,6 +5,7 @@ import pandas as pd
|
|||
import numpy as np
|
||||
import time
|
||||
from trader import Trader
|
||||
from typing import Union, Dict
|
||||
from rich import print as rprint
|
||||
from rich.table import Table
|
||||
|
||||
|
@ -60,6 +61,26 @@ class SpreadBacktest():
|
|||
else:
|
||||
self.trader.update_signal(date, update_type='position')
|
||||
|
||||
# 更新数据
|
||||
def update_signal(self,
|
||||
trade_time: str,
|
||||
new_signal: pd.DataFrame):
|
||||
"""
|
||||
更新信号因子
|
||||
|
||||
Args:
|
||||
trade_time (str): 信号时间
|
||||
new_signal (pd.DataFrame): 新的更新信号
|
||||
"""
|
||||
self.trader.signal[trade_time] = new_signal
|
||||
|
||||
def update_interval(self,
|
||||
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
||||
):
|
||||
# 更新interval和weight
|
||||
# 如果interval为固定比例则更新
|
||||
self.trader.init_interval(interval)
|
||||
|
||||
@property
|
||||
def account_history(self):
|
||||
return self.trader.account_history
|
||||
|
|
213
trader.py
213
trader.py
|
@ -1,5 +1,4 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
import copy
|
||||
|
@ -8,59 +7,60 @@ from typing import Union, Iterable, Dict
|
|||
|
||||
from account import Account
|
||||
from dataloader import DataLoader
|
||||
from check_funcs import check_buy_exclude
|
||||
|
||||
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
|
||||
from db_tushare import get_factor_tools
|
||||
gft = get_factor_tools()
|
||||
|
||||
|
||||
|
||||
class Trader(Account):
|
||||
"""
|
||||
交易类: 用于控制每日交易情况
|
||||
|
||||
Args:
|
||||
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
||||
interval (int, tuple, pd.Series): 交易间隔
|
||||
num (int): 持仓数量
|
||||
ascending (bool): 因子方向
|
||||
with_st (bool): 是否包含st
|
||||
tick (bool): 是否开始tick模拟模式(开发中)
|
||||
weight ([str, pd.DataFrame]): 权重分配
|
||||
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
||||
interval (dict[str, (int, tuple, pd.Series)]):
|
||||
交易间隔
|
||||
num (int): 持仓数量
|
||||
ascending (bool): 因子方向
|
||||
with_st (bool): 是否包含st
|
||||
tick (bool): 是否开始tick模拟模式(开发中)
|
||||
weight ([str, pd.DataFrame]): 权重分配
|
||||
- avg (str): 平均分配,每天早盘重新分配,日中交易不重新分配
|
||||
- (pd.DataFrame): 自定义股票权重,包含每天个股指定的权重,会自动归一化
|
||||
amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限
|
||||
data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列
|
||||
ipo_days (int): 筛选上市时间
|
||||
slippage (tuple): 买入和卖出滑点
|
||||
commission (float): 佣金
|
||||
tax (dict): 印花税
|
||||
exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓
|
||||
amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限
|
||||
data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列
|
||||
ipo_days (int): 筛选上市时间
|
||||
slippage (tuple): 买入和卖出滑点
|
||||
commission (float): 佣金
|
||||
tax (dict): 印花税
|
||||
force_exclude (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓
|
||||
- abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除
|
||||
- receesion: 财报同比或环比下降50%以上
|
||||
- qualified_opinion: 会计保留意见
|
||||
account (Account): 账户设置,account.Account
|
||||
account (Account): 账户设置,account.Account
|
||||
"""
|
||||
def __init__(self,
|
||||
signal: Dict[str, pd.DataFrame]=None,
|
||||
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
|
||||
interval: Dict[str, Union[int,tuple,pd.Series]]={},
|
||||
num: int=100,
|
||||
ascending: bool=False,
|
||||
with_st: bool=False,
|
||||
data_root:dict={},
|
||||
tick: bool=False,
|
||||
weight: str='avg',
|
||||
amt_filter: set=(0,),
|
||||
ipo_days: int=20,
|
||||
slippage :tuple=(0.001,0.001),
|
||||
commission: float=0.0001,
|
||||
buy_exclude: Dict[str, Union[set,int,float]]={},
|
||||
force_exclude: list=[],
|
||||
account: dict={},
|
||||
tax: dict={
|
||||
'1990-01-01': (0.001,0.001),
|
||||
'2008-04-24': (0.001,0.001),
|
||||
'2008-09-19': (0, 0.001),
|
||||
'2023-08-28': (0, 0.0005)
|
||||
},
|
||||
exclude_list: list=[],
|
||||
account: dict={},
|
||||
**kwargs) -> None:
|
||||
# 初始化账户
|
||||
super().__init__(**account)
|
||||
|
@ -78,24 +78,7 @@ class Trader(Account):
|
|||
if len(kwargs) > 0:
|
||||
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
|
||||
# interval
|
||||
self.interval = []
|
||||
for s in signal:
|
||||
if s in interval:
|
||||
s_interval = interval[s]
|
||||
if isinstance(s_interval, int):
|
||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||
df_interval[::s_interval] = 1
|
||||
elif isinstance(s_interval, tuple):
|
||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||
df_interval[::s_interval[0]] = s_interval[1]
|
||||
elif isinstance(s_interval, pd.Series):
|
||||
df_interval = s_interval
|
||||
else:
|
||||
raise ValueError('invalid interval type')
|
||||
self.interval.append(df_interval)
|
||||
else:
|
||||
raise ValueError(f'not found interval for signal {s}')
|
||||
self.interval = pd.concat(self.interval)
|
||||
self.init_interval(interval)
|
||||
# num
|
||||
if isinstance(num, int):
|
||||
self.num = int(num)
|
||||
|
@ -116,21 +99,6 @@ class Trader(Account):
|
|||
self.weight = weight
|
||||
else:
|
||||
raise ValueError('invalid type for `weight`')
|
||||
# amt_filter
|
||||
if isinstance(amt_filter, (set,list,tuple)):
|
||||
if len(amt_filter) == 1:
|
||||
self.amt_filter_min = amt_filter[0]
|
||||
self.amt_filter_max = np.inf
|
||||
else:
|
||||
self.amt_filter_min = amt_filter[0]
|
||||
self.amt_filter_max = amt_filter[1]
|
||||
else:
|
||||
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
|
||||
# ipo_days
|
||||
if isinstance(ipo_days, int):
|
||||
self.ipo_days = ipo_days
|
||||
else:
|
||||
raise Exception('wrong type for ipo_days, `int` is required')
|
||||
# slippage
|
||||
if isinstance(slippage, tuple) and len(slippage) == 2:
|
||||
self.slippage = slippage
|
||||
|
@ -146,17 +114,19 @@ class Trader(Account):
|
|||
self.tax = tax
|
||||
else:
|
||||
raise ValueError('tax should be dict.')
|
||||
# exclude
|
||||
if isinstance(exclude_list, list):
|
||||
self.exclude_list = exclude_list
|
||||
# buy exclude
|
||||
self.buy_exclude = check_buy_exclude(buy_exclude)
|
||||
# force exclude
|
||||
if isinstance(force_exclude, list):
|
||||
self.force_exclude = force_exclude
|
||||
optional_list = ['abnormal', 'recession']
|
||||
for item in exclude_list:
|
||||
for item in force_exclude:
|
||||
if item in optional_list:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unexpected keyword argument '{item}'")
|
||||
else:
|
||||
raise ValueError('exclude_list should be list.')
|
||||
raise ValueError('force_exclude should be list.')
|
||||
# data_root
|
||||
# 至少包含basic data路径,open信号默认使用basic_data
|
||||
if len(data_root) <= 0:
|
||||
|
@ -173,7 +143,30 @@ class Trader(Account):
|
|||
if s not in data_root:
|
||||
raise ValueError(f"data for signal {s} is not provided")
|
||||
self.data_root = data_root
|
||||
|
||||
|
||||
def init_interval(self, interval):
|
||||
"""
|
||||
初始化interval
|
||||
"""
|
||||
interval_list = []
|
||||
for s in self.signal:
|
||||
if s in interval:
|
||||
s_interval = interval[s]
|
||||
if isinstance(s_interval, int):
|
||||
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
||||
df_interval[::s_interval] = 1
|
||||
elif isinstance(s_interval, tuple):
|
||||
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
|
||||
df_interval[::s_interval[0]] = s_interval[1]
|
||||
elif isinstance(s_interval, pd.Series):
|
||||
df_interval = s_interval
|
||||
else:
|
||||
raise ValueError('invalid interval type')
|
||||
interval_list.append(df_interval)
|
||||
else:
|
||||
raise ValueError(f'not found interval for signal {s}')
|
||||
self.interval = pd.concat(interval_list)
|
||||
|
||||
def load_data(self,
|
||||
date: str,
|
||||
update_type: str='rtn'):
|
||||
|
@ -187,14 +180,19 @@ class Trader(Account):
|
|||
"""
|
||||
self.today_data = dict()
|
||||
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
|
||||
# 遍历从data_root中读取当日数据
|
||||
for path in self.data_root:
|
||||
if path == 'basic':
|
||||
continue
|
||||
else:
|
||||
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
|
||||
if update_type == 'position':
|
||||
return True
|
||||
for s in self.signal:
|
||||
if s == 'open':
|
||||
if s in self.data_root:
|
||||
continue
|
||||
else:
|
||||
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
else:
|
||||
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
|
||||
if 'close' in self.signal:
|
||||
|
@ -254,43 +252,68 @@ class Trader(Account):
|
|||
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
|
||||
# 不足的数量通过买入列表自适应调整
|
||||
# 这样能实现在因子值不足时也正常换仓
|
||||
max_sell_num = self.interval.loc[date]*len(last_position)
|
||||
try:
|
||||
max_sell_num = self.interval.loc[date]*len(last_position)
|
||||
except Exception:
|
||||
raise ValueError(f'not found interval in {date}')
|
||||
else:
|
||||
last_position = pd.Series()
|
||||
max_sell_num = self.num
|
||||
|
||||
# 获取用于筛选的数据
|
||||
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
|
||||
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
|
||||
stock_amt_filter = stock_amt20.copy()
|
||||
stock_amt_filter.loc[:] = 0
|
||||
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
|
||||
stock_amt_filter = stock_amt_filter.sort_index()
|
||||
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
|
||||
stock_ipo_filter = stock_ipo_days.copy()
|
||||
stock_ipo_filter.loc[:] = 0
|
||||
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
|
||||
exclude_data = dict()
|
||||
for cond in self.buy_exclude:
|
||||
if cond == 'amt_20':
|
||||
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
|
||||
stock_amt_exclude = stock_amt20.copy()
|
||||
stock_amt_exclude.loc[:] = 0
|
||||
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1
|
||||
stock_amt_exclude = stock_amt_exclude.sort_index()
|
||||
exclude_data[cond] = stock_amt_exclude
|
||||
if cond == 'list_days':
|
||||
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
|
||||
stock_ipo_exclude = stock_ipo_days.copy()
|
||||
stock_ipo_exclude.loc[:] = 0
|
||||
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1
|
||||
stock_ipo_exclude = stock_ipo_exclude.sort_index()
|
||||
exclude_data[cond] = stock_ipo_exclude
|
||||
if cond == 'mkt':
|
||||
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
|
||||
stock_size_exclude = stock_size.copy()
|
||||
stock_size_exclude.loc[:] = 0
|
||||
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
|
||||
stock_size_exclude = stock_size_exclude.sort_index()
|
||||
exclude_data[cond] = stock_size_exclude
|
||||
if cond == 'price':
|
||||
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
|
||||
stock_price_exclude = stock_price.copy()
|
||||
stock_price_exclude.loc[:] = 0
|
||||
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
|
||||
stock_price_exclude = stock_price_exclude.sort_index()
|
||||
exclude_data[cond] = stock_price_exclude
|
||||
|
||||
# 剔除列表
|
||||
# 包含强制列表和普通列表:
|
||||
# 包含强制剔除列表和买入剔除列表:
|
||||
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
|
||||
# 普通列表如果已经持有不会过滤只对新买入的过滤
|
||||
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤
|
||||
|
||||
# 强制过滤列表
|
||||
exclude_stock = []
|
||||
for exclude in self.exclude_list:
|
||||
if exclude == 'abnormal':
|
||||
for cond in self.force_exclude:
|
||||
if cond == 'abnormal':
|
||||
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
|
||||
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
|
||||
if exclude == 'recession':
|
||||
if cond == 'recession':
|
||||
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
|
||||
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
|
||||
exclude_stock = list(set(exclude_stock))
|
||||
force_exclude = copy.deepcopy(exclude_stock)
|
||||
# 普通过滤列表
|
||||
normal_exclude = []
|
||||
normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
|
||||
normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
|
||||
normal_exclude = list(set(normal_exclude))
|
||||
# 买入过滤列表
|
||||
buy_exclude = []
|
||||
for cond in self.buy_exclude:
|
||||
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list()
|
||||
buy_exclude = list(set(buy_exclude))
|
||||
|
||||
# 交易列表
|
||||
# 仓位判断给与计算误差冗余
|
||||
|
@ -345,7 +368,7 @@ class Trader(Account):
|
|||
buy_list = []
|
||||
|
||||
# 剔除过滤条件后可买入列表
|
||||
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
|
||||
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
|
||||
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
||||
|
||||
# 更新卖出后的持仓列表
|
||||
|
@ -454,7 +477,7 @@ class Trader(Account):
|
|||
buy_list = []
|
||||
|
||||
# 剔除过滤条件后可买入列表
|
||||
after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
|
||||
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude))
|
||||
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
|
||||
|
||||
# 更新卖出后的持仓列表
|
||||
|
@ -529,13 +552,19 @@ class Trader(Account):
|
|||
buy_list (Iterable[str]): 买入目标
|
||||
sell_list (Iterable[str]): 卖出目标
|
||||
"""
|
||||
stock_price = self.today_data[trade_time]
|
||||
if trade_time+'_trade' in self.today_data:
|
||||
trade_price = self.today_data[trade_time+'_trade']
|
||||
basic_price = self.today_data[trade_time]
|
||||
else:
|
||||
basic_price = self.today_data[trade_time]
|
||||
trade_price = basic_price
|
||||
target_price = pd.Series(index=target)
|
||||
sell_list = list(set(target) & set(sell_list))
|
||||
buy_list = list(set(target) & set(buy_list))
|
||||
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
|
||||
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
||||
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
||||
# 根据交易和非交易标的分别获取目标价格
|
||||
target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
|
||||
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
||||
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
||||
return target_price
|
||||
|
||||
def check_update_status(self,
|
||||
|
@ -724,6 +753,8 @@ class Trader(Account):
|
|||
# 计算收益
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
|
||||
# 价格缺失用初始weight填充
|
||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
||||
self.update_account(date, trade_time, cur_pos, next_position)
|
||||
# 更新仓位
|
||||
cur_pos['weight'] = self.update_next_weight(cur_pos)
|
||||
|
@ -744,6 +775,8 @@ class Trader(Account):
|
|||
# 更新当日收益
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
|
||||
# 价格缺失用初始weight填充
|
||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
||||
position_record = cur_pos.copy()
|
||||
position_record['end_weight'] = self.update_next_weight(position_record)
|
||||
self.position_history[date] = position_record.copy()[position_fields]
|
||||
|
|
Loading…
Reference in New Issue