Compare commits

..

No commits in common. "main" and "0.0.2" have entirely different histories.
main ... 0.0.2

8 changed files with 96 additions and 257 deletions

View File

@ -7,8 +7,3 @@
1. 优化非满仓情况下现金比例对收益的计算问题 1. 优化非满仓情况下现金比例对收益的计算问题
2. 优化自定义权重下的判断逻辑 2. 优化自定义权重下的判断逻辑
3. 修复没有满仓有现金比例时的收益计算问题 3. 修复没有满仓有现金比例时的收益计算问题
## Version 0.03
1. 新增分钟价格计算函数
2. 修复自定义价格时,买卖价格和非买卖价格的选取问题
3. 增加信号更新函数

View File

@ -1,33 +0,0 @@
import pandas as pd
import numpy as np
def amount_specified(
min_data: pd.DataFrame,
post_adj_factor: pd.DataFrame,
min_amount: float=0.
):
"""
计算指定金额下的平均后复权价格
Args:
min_data (pd.DataFrame): 指定日期分钟数据需至少包含代码(stock_code)分钟(Time)价格(price)成交量(vol)成交金额(amount)
post_adj_factor (pd.DataFrame): 后复权因子数据需至少包含股票后复权因子
min_amount (float): 指定最小金额下的平均价
"""
# 按照指定最小量获取平均价格
stock_amt = min_data.pivot_table(index='Time', columns='stock_code', values='amount').cumsum()
stock_vol = min_data.pivot_table(index='Time', columns='stock_code', values='vol').cumsum()
amount_price = stock_amt / stock_vol
amount_price.iloc[1:] = amount_price.iloc[1:].where(stock_amt <= min_amount, np.nan)
amount_price = amount_price.unstack().reset_index()
amount_price.columns = ['stock_code', 'time', 'price']
amount_price = amount_price.dropna(subset=['price']).drop_duplicates(subset='stock_code', keep='last')
# 计算后复权价格
amount_price = amount_price.merge(post_adj_factor, on=['stock_code'], how='left').dropna(subset=['factor'])
amount_price['open_post'] = amount_price['price'] * amount_price['factor']
return amount_price[['stock_code','open_post']]

View File

@ -2,5 +2,4 @@ from account import Account
from trader import Trader from trader import Trader
from spread_backtest import Spread_Backtest from spread_backtest import Spread_Backtest
__all__ = ['Account', 'Trader', 'Spread_Backtest']
__all__ = ['Account', 'Trader', 'Spread_Backtest', 'Specified_Price']

View File

@ -1,68 +0,0 @@
import numpy as np
def check_buy_exclude(buy_exclude):
buy_exclude_dict = dict()
# 检查剔除条件
for cond in ['amt_20', 'list_days', 'mkt', 'price']:
if cond == 'amt_20':
# 20日成交量均值
if cond in buy_exclude:
amt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(amt_exclude) == 1:
buy_exclude_dict[cond] = (amt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (amt_exclude[0], amt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: amt_20, `set`, `list` or `tuple` is required')
if cond == 'list_days':
# 上市时间
if cond in buy_exclude:
list_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(list_exclude) == 1:
buy_exclude_dict[cond] = (list_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (list_exclude[0], list_exclude[1])
else:
raise Exception('wrong input type for buy exclude: list_days, `set`, `list` or `tuple` is required')
if cond == 'mkt':
# 市值
if cond in buy_exclude:
mkt_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(mkt_exclude) == 1:
buy_exclude_dict[cond] = (mkt_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (mkt_exclude[0], mkt_exclude[1])
else:
raise Exception('wrong input type for buy exclude: mkt, `set`, `list` or `tuple` is required')
if cond == 'price':
# 价格
if cond in buy_exclude:
price_exclude = buy_exclude[cond]
else:
buy_exclude[cond] = (0, np.inf)
continue
if isinstance(buy_exclude[cond], (set,list,tuple)):
if len(price_exclude) == 1:
buy_exclude_dict[cond] = (price_exclude[0], np.inf)
else:
buy_exclude_dict[cond] = (price_exclude[0], price_exclude[1])
else:
raise Exception('wrong input type for buy exclude: price, `set`, `list` or `tuple` is required')
return buy_exclude_dict

View File

@ -10,8 +10,8 @@ if __name__ == '__main__':
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data' data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
save_dir = '/home/lenovo/quant/data/backtest/basic_data' save_dir = '/home/lenovo/quant/data/backtest/basic_data'
for i,f in enumerate(['open_post','close_post','open_pre','close_pre','down_limit','up_limit','size','amount_20', for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list',
'opening_info','ipo_days','margin_list','abnormal', 'recession']): 'abnormal', 'recession']):
if f in ['margin_list']: if f in ['margin_list']:
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0) tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
else: else:
@ -34,7 +34,7 @@ if __name__ == '__main__':
# 更新下一日的数据用于筛选 # 更新下一日的数据用于筛选
next_date = gft.days_after(df.index.max(), 1) next_date = gft.days_after(df.index.max(), 1)
next_list = [] next_list = []
for i,f in enumerate(['close_pre','size','amount_20','opening_info','ipo_days','margin_list','abnormal','recession']): for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list','abnormal','recession']):
if f in ['margin_list']: if f in ['margin_list']:
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f)) next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
else: else:

View File

@ -1,6 +1,6 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from typing import Union, Iterable from typing import Union, Iterable, Optional, Dict
class DataLoader(): class DataLoader():
""" """

View File

@ -5,7 +5,6 @@ import pandas as pd
import numpy as np import numpy as np
import time import time
from trader import Trader from trader import Trader
from typing import Union, Dict
from rich import print as rprint from rich import print as rprint
from rich.table import Table from rich.table import Table
@ -61,26 +60,6 @@ class SpreadBacktest():
else: else:
self.trader.update_signal(date, update_type='position') self.trader.update_signal(date, update_type='position')
# 更新数据
def update_signal(self,
trade_time: str,
new_signal: pd.DataFrame):
"""
更新信号因子
Args:
trade_time (str): 信号时间
new_signal (pd.DataFrame): 新的更新信号
"""
self.trader.signal[trade_time] = new_signal
def update_interval(self,
interval: Dict[str, Union[int,tuple,pd.Series]]={},
):
# 更新interval和weight
# 如果interval为固定比例则更新
self.trader.init_interval(interval)
@property @property
def account_history(self): def account_history(self):
return self.trader.account_history return self.trader.account_history

177
trader.py
View File

@ -1,4 +1,5 @@
import pandas as pd import pandas as pd
import numpy as np
import sys import sys
import os import os
import copy import copy
@ -7,7 +8,6 @@ from typing import Union, Iterable, Dict
from account import Account from account import Account
from dataloader import DataLoader from dataloader import DataLoader
from check_funcs import check_buy_exclude
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/") sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
from db_tushare import get_factor_tools from db_tushare import get_factor_tools
@ -20,8 +20,7 @@ class Trader(Account):
Args: Args:
signal (dict[str, pd.DataFrame]): 目标因子按顺序执行 signal (dict[str, pd.DataFrame]): 目标因子按顺序执行
interval (dict[str, (int, tuple, pd.Series)]): interval (int, tuple, pd.Series): 交易间隔
交易间隔
num (int): 持仓数量 num (int): 持仓数量
ascending (bool): 因子方向 ascending (bool): 因子方向
with_st (bool): 是否包含st with_st (bool): 是否包含st
@ -35,7 +34,7 @@ class Trader(Account):
slippage (tuple): 买入和卖出滑点 slippage (tuple): 买入和卖出滑点
commission (float): 佣金 commission (float): 佣金
tax (dict): 印花税 tax (dict): 印花税
force_exclude (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓 exclude_list (list): 额外的剔除列表会优先满足该剔除列表中的条件之后再进行正常的调仓
- abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除 - abnormal: 异常公告剔除包含中止上市立案调查警示函等异常情况的剔除
- receesion: 财报同比或环比下降50%以上 - receesion: 财报同比或环比下降50%以上
- qualified_opinion: 会计保留意见 - qualified_opinion: 会计保留意见
@ -43,24 +42,25 @@ class Trader(Account):
""" """
def __init__(self, def __init__(self,
signal: Dict[str, pd.DataFrame]=None, signal: Dict[str, pd.DataFrame]=None,
interval: Dict[str, Union[int,tuple,pd.Series]]={}, interval: Dict[str, Union[int,tuple,pd.Series]]=1,
num: int=100, num: int=100,
ascending: bool=False, ascending: bool=False,
with_st: bool=False, with_st: bool=False,
data_root:dict={}, data_root:dict={},
tick: bool=False, tick: bool=False,
weight: str='avg', weight: str='avg',
amt_filter: set=(0,),
ipo_days: int=20,
slippage :tuple=(0.001,0.001), slippage :tuple=(0.001,0.001),
commission: float=0.0001, commission: float=0.0001,
buy_exclude: Dict[str, Union[set,int,float]]={},
force_exclude: list=[],
account: dict={},
tax: dict={ tax: dict={
'1990-01-01': (0.001,0.001), '1990-01-01': (0.001,0.001),
'2008-04-24': (0.001,0.001), '2008-04-24': (0.001,0.001),
'2008-09-19': (0, 0.001), '2008-09-19': (0, 0.001),
'2023-08-28': (0, 0.0005) '2023-08-28': (0, 0.0005)
}, },
exclude_list: list=[],
account: dict={},
**kwargs) -> None: **kwargs) -> None:
# 初始化账户 # 初始化账户
super().__init__(**account) super().__init__(**account)
@ -78,7 +78,24 @@ class Trader(Account):
if len(kwargs) > 0: if len(kwargs) > 0:
raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'")
# interval # interval
self.init_interval(interval) self.interval = []
for s in signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
self.interval.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(self.interval)
# num # num
if isinstance(num, int): if isinstance(num, int):
self.num = int(num) self.num = int(num)
@ -99,6 +116,21 @@ class Trader(Account):
self.weight = weight self.weight = weight
else: else:
raise ValueError('invalid type for `weight`') raise ValueError('invalid type for `weight`')
# amt_filter
if isinstance(amt_filter, (set,list,tuple)):
if len(amt_filter) == 1:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = np.inf
else:
self.amt_filter_min = amt_filter[0]
self.amt_filter_max = amt_filter[1]
else:
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
# ipo_days
if isinstance(ipo_days, int):
self.ipo_days = ipo_days
else:
raise Exception('wrong type for ipo_days, `int` is required')
# slippage # slippage
if isinstance(slippage, tuple) and len(slippage) == 2: if isinstance(slippage, tuple) and len(slippage) == 2:
self.slippage = slippage self.slippage = slippage
@ -114,19 +146,17 @@ class Trader(Account):
self.tax = tax self.tax = tax
else: else:
raise ValueError('tax should be dict.') raise ValueError('tax should be dict.')
# buy exclude # exclude
self.buy_exclude = check_buy_exclude(buy_exclude) if isinstance(exclude_list, list):
# force exclude self.exclude_list = exclude_list
if isinstance(force_exclude, list):
self.force_exclude = force_exclude
optional_list = ['abnormal', 'recession'] optional_list = ['abnormal', 'recession']
for item in force_exclude: for item in exclude_list:
if item in optional_list: if item in optional_list:
pass pass
else: else:
raise ValueError(f"Unexpected keyword argument '{item}'") raise ValueError(f"Unexpected keyword argument '{item}'")
else: else:
raise ValueError('force_exclude should be list.') raise ValueError('exclude_list should be list.')
# data_root # data_root
# 至少包含basic data路径open信号默认使用basic_data # 至少包含basic data路径open信号默认使用basic_data
if len(data_root) <= 0: if len(data_root) <= 0:
@ -144,29 +174,6 @@ class Trader(Account):
raise ValueError(f"data for signal {s} is not provided") raise ValueError(f"data for signal {s} is not provided")
self.data_root = data_root self.data_root = data_root
def init_interval(self, interval):
"""
初始化interval
"""
interval_list = []
for s in self.signal:
if s in interval:
s_interval = interval[s]
if isinstance(s_interval, int):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval] = 1
elif isinstance(s_interval, tuple):
df_interval = pd.Series(index=self.signal[s].index, data=[0]*len(self.signal[s].index))
df_interval[::s_interval[0]] = s_interval[1]
elif isinstance(s_interval, pd.Series):
df_interval = s_interval
else:
raise ValueError('invalid interval type')
interval_list.append(df_interval)
else:
raise ValueError(f'not found interval for signal {s}')
self.interval = pd.concat(interval_list)
def load_data(self, def load_data(self,
date: str, date: str,
update_type: str='rtn'): update_type: str='rtn'):
@ -180,19 +187,14 @@ class Trader(Account):
""" """
self.today_data = dict() self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv')) self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
# 遍历从data_root中读取当日数据
for path in self.data_root:
if path == 'basic':
continue
else:
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
if update_type == 'position': if update_type == 'position':
return True return True
for s in self.signal: for s in self.signal:
if s == 'open': if s == 'open':
if s in self.data_root: if s in self.data_root:
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'})) continue
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'})) else:
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else: else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv')) self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal: if 'close' in self.signal:
@ -252,68 +254,43 @@ class Trader(Account):
# 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量 # 如果昨日持仓本身就不足持仓数量则按照昨日持仓数量作为基准计算换仓数量
# 不足的数量通过买入列表自适应调整 # 不足的数量通过买入列表自适应调整
# 这样能实现在因子值不足时也正常换仓 # 这样能实现在因子值不足时也正常换仓
try:
max_sell_num = self.interval.loc[date]*len(last_position) max_sell_num = self.interval.loc[date]*len(last_position)
except Exception:
raise ValueError(f'not found interval in {date}')
else: else:
last_position = pd.Series() last_position = pd.Series()
max_sell_num = self.num max_sell_num = self.num
# 获取用于筛选的数据 # 获取用于筛选的数据
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index() stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
exclude_data = dict()
for cond in self.buy_exclude:
if cond == 'amt_20':
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index() stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
stock_amt_exclude = stock_amt20.copy() stock_amt_filter = stock_amt20.copy()
stock_amt_exclude.loc[:] = 0 stock_amt_filter.loc[:] = 0
stock_amt_exclude.loc[(stock_amt20 > self.buy_exclude[cond][0]) & (stock_amt20 < self.buy_exclude[cond][1])] = 1 stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
stock_amt_exclude = stock_amt_exclude.sort_index() stock_amt_filter = stock_amt_filter.sort_index()
exclude_data[cond] = stock_amt_exclude
if cond == 'list_days':
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index() stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
stock_ipo_exclude = stock_ipo_days.copy() stock_ipo_filter = stock_ipo_days.copy()
stock_ipo_exclude.loc[:] = 0 stock_ipo_filter.loc[:] = 0
stock_ipo_exclude.loc[(stock_ipo_days > self.buy_exclude[cond][0]) & (stock_ipo_days < self.buy_exclude[cond][1])] = 1 stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
stock_ipo_exclude = stock_ipo_exclude.sort_index()
exclude_data[cond] = stock_ipo_exclude
if cond == 'mkt':
stock_size = self.today_data['basic'].get(factor.index.values, 'size').sort_index()
stock_size_exclude = stock_size.copy()
stock_size_exclude.loc[:] = 0
stock_size_exclude.loc[(stock_size > self.buy_exclude[cond][0]) & (stock_size < self.buy_exclude[cond][1])] = 1
stock_size_exclude = stock_size_exclude.sort_index()
exclude_data[cond] = stock_size_exclude
if cond == 'price':
stock_price = self.today_data['basic'].get(factor.index.values, 'close_pre').sort_index()
stock_price_exclude = stock_price.copy()
stock_price_exclude.loc[:] = 0
stock_price_exclude.loc[(stock_price > self.buy_exclude[cond][0]) & (stock_price < self.buy_exclude[cond][1])] = 1
stock_price_exclude = stock_price_exclude.sort_index()
exclude_data[cond] = stock_price_exclude
# 剔除列表 # 剔除列表
# 包含强制剔除列表和买入剔除列表: # 包含强制列表和普通列表:
# 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中 # 强制列表会将已经持仓的也强制剔除并且不算在换手率限制中
# 买入剔除列表如果已经持有不会过滤只对新买入的过滤 # 普通列表如果已经持有不会过滤只对新买入的过滤
# 强制过滤列表 # 强制过滤列表
exclude_stock = [] exclude_stock = []
for cond in self.force_exclude: for exclude in self.exclude_list:
if cond == 'abnormal': if exclude == 'abnormal':
stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index() stock_abnormal = self.today_data['basic'].get(factor.index.values, 'abnormal').sort_index()
exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list() exclude_stock += stock_abnormal.loc[stock_abnormal > 0].index.to_list()
if cond == 'recession': if exclude == 'recession':
stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index() stock_recession = self.today_data['basic'].get(factor.index.values, 'recession').sort_index()
exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list() exclude_stock += stock_recession.loc[stock_recession > 0].index.to_list()
exclude_stock = list(set(exclude_stock)) exclude_stock = list(set(exclude_stock))
force_exclude = copy.deepcopy(exclude_stock) force_exclude = copy.deepcopy(exclude_stock)
# 买入过滤列表 # 普通过滤列表
buy_exclude = [] normal_exclude = []
for cond in self.buy_exclude: normal_exclude += stock_ipo_filter.loc[stock_ipo_filter != 1].index.to_list()
buy_exclude += exclude_data[cond].loc[exclude_data[cond] != 1].index.to_list() normal_exclude += stock_amt_filter.loc[stock_amt_filter != 1].index.to_list()
buy_exclude = list(set(buy_exclude)) normal_exclude = list(set(normal_exclude))
# 交易列表 # 交易列表
# 仓位判断给与计算误差冗余 # 仓位判断给与计算误差冗余
@ -368,7 +345,7 @@ class Trader(Account):
buy_list = [] buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -477,7 +454,7 @@ class Trader(Account):
buy_list = [] buy_list = []
# 剔除过滤条件后可买入列表 # 剔除过滤条件后可买入列表
after_filter_list = list(set(factor.index) - set(buy_exclude) - set(force_exclude)) after_filter_list = list(set(factor.index) - set(normal_exclude) - set(force_exclude))
target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list() target_list = factor.loc[after_filter_list].dropna().sort_values(ascending=self.ascending).index.to_list()
# 更新卖出后的持仓列表 # 更新卖出后的持仓列表
@ -552,19 +529,13 @@ class Trader(Account):
buy_list (Iterable[str]): 买入目标 buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标 sell_list (Iterable[str]): 卖出目标
""" """
if trade_time+'_trade' in self.today_data: stock_price = self.today_data[trade_time]
trade_price = self.today_data[trade_time+'_trade']
basic_price = self.today_data[trade_time]
else:
basic_price = self.today_data[trade_time]
trade_price = basic_price
target_price = pd.Series(index=target) target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list)) sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list)) buy_list = list(set(target) & set(buy_list))
# 根据交易和非交易标的分别获取目标价格 target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[target] = basic_price.get(target, 'price').fillna(0) target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1]) target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price return target_price
def check_update_status(self, def check_update_status(self,
@ -753,8 +724,6 @@ class Trader(Account):
# 计算收益 # 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn'] cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position) self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位 # 更新仓位
cur_pos['weight'] = self.update_next_weight(cur_pos) cur_pos['weight'] = self.update_next_weight(cur_pos)
@ -775,8 +744,6 @@ class Trader(Account):
# 更新当日收益 # 更新当日收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1 cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1) cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy() position_record = cur_pos.copy()
position_record['end_weight'] = self.update_next_weight(position_record) position_record['end_weight'] = self.update_next_weight(position_record)
self.position_history[date] = position_record.copy()[position_fields] self.position_history[date] = position_record.copy()[position_fields]