init
This commit is contained in:
commit
e0d025f71c
|
@ -0,0 +1,5 @@
|
|||
from account import Account
|
||||
from trader import Trader
|
||||
from spread_backtest import Spread_Backtest
|
||||
|
||||
__all__ = ['Account', 'Trader', 'Spread_Backtest']
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,63 @@
|
|||
import warnings
|
||||
import pandas as pd
|
||||
from rich import print as rprint
|
||||
from typing import Union, Iterable, Optional, Dict
|
||||
|
||||
|
||||
# 警告格式
|
||||
def custom_warnings(message, category, filename, lineno, file=None, line=None):
|
||||
rprint("[bold red]Warning:[/bold red]", message)
|
||||
warnings.showwarning = custom_warnings
|
||||
|
||||
|
||||
class CustomWarning(UserWarning):
|
||||
pass
|
||||
warnings.simplefilter('always', category=CustomWarning)
|
||||
|
||||
|
||||
class Account():
|
||||
"""
|
||||
账户类:用于控制账户类型
|
||||
Arguments:
|
||||
- is_real(bool): 是否为真实交易
|
||||
- init_account(Union[int,float]): 初始账户金额
|
||||
- trade_limit(float): 是否限制每个股票的购买比例(只在真实账户中有效)
|
||||
- leverage(float): 杠杆比例
|
||||
"""
|
||||
def __init__(self,
|
||||
is_real: bool = False,
|
||||
init_account: Union[int,float] = 1e6,
|
||||
trade_limit: float = 0.01,
|
||||
leverage: Union[int,float] = 1.) -> None:
|
||||
if is_real:
|
||||
self.a_type = 'amount'
|
||||
self.account = init_account
|
||||
else:
|
||||
self.a_type = 'percent'
|
||||
self.account = 1.0
|
||||
self.trade_limit = trade_limit
|
||||
if isinstance(leverage, (int,float)):
|
||||
if leverage > 1.5:
|
||||
raise ValueError('leverage should less than 1.5')
|
||||
if leverage > 1:
|
||||
warnings.warn('leverage(杠杆率)大于1时,优先满足杠杆,会出现高于指定换手率的情况!', category=CustomWarning)
|
||||
else:
|
||||
raise ValueError('leverage should be int or float')
|
||||
self.leverage = leverage
|
||||
# 当前持仓
|
||||
self.position = pd.DataFrame()
|
||||
# 账户收益更新
|
||||
self.account_history = pd.DataFrame(columns=['date','trade_time','turnover','leverage','pnl'])
|
||||
# 历史持仓
|
||||
self.position_history = dict()
|
||||
# 当前日期
|
||||
self.date = None
|
||||
|
||||
# 获取收盘时持仓和仓位
|
||||
@property
|
||||
def close_position(self):
|
||||
return self.position[self.position['trade_time'] == 'close']['stock_code'].values
|
||||
|
||||
@property
|
||||
def close_weight(self):
|
||||
return self.position[self.position['trade_time'] == 'close'].weight.values
|
|
@ -0,0 +1,44 @@
|
|||
import pandas as pd
|
||||
import os, sys
|
||||
|
||||
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
|
||||
from db_tushare import get_factor_tools
|
||||
gft = get_factor_tools()
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_dir = '/home/lenovo/quant/tools/detail_testing/basic_data'
|
||||
save_dir = '/home/lenovo/quant/data/backtest/basic_data'
|
||||
|
||||
for i,f in enumerate(['open_post','close_post','down_limit','up_limit','size','amount_20','opening_info','ipo_days','margin_list']):
|
||||
if f in ['margin_list']:
|
||||
tmp = gft.get_stock_factor(f, start='2012-01-01').fillna(0)
|
||||
else:
|
||||
tmp = pd.read_csv(f'{data_dir}/{f}.csv', index_col=0)
|
||||
tmp = tmp.unstack().reset_index()
|
||||
tmp.columns = ['stock_code', 'date', f]
|
||||
if i == 0:
|
||||
df = tmp
|
||||
else:
|
||||
df = df.merge(tmp, on=['stock_code', 'date'], how="left")
|
||||
df = df.set_index(['date']).sort_index()
|
||||
|
||||
existed = os.listdir(save_dir)
|
||||
for d in sorted(df.index.unique()):
|
||||
if (d+'.csv' in existed) and (d+'.csv' != max(existed)):
|
||||
continue
|
||||
else:
|
||||
df.loc[d].sort_values(by=['stock_code']).to_csv(f'{save_dir}/{d}.csv', index=False)
|
||||
|
||||
# 更新下一日的数据用于筛选
|
||||
next_date = gft.days_after(df.index.max(), 1)
|
||||
next_list = []
|
||||
for i,f in enumerate(['amount_20','opening_info','ipo_days','margin_list']):
|
||||
if f in ['margin_list']:
|
||||
next_list.append(pd.Series(gft.get_stock_factor(f, start='2012-01-01').fillna(0).iloc[-1], name=f))
|
||||
else:
|
||||
next_list.append(pd.Series(pd.read_csv(f'{data_dir}/{f}.csv', index_col=0).iloc[-1], name=f))
|
||||
df = pd.concat(next_list, axis=1)
|
||||
df.index.name = 'stock_code'
|
||||
df = df.reset_index()
|
||||
|
||||
df.sort_values(by=['stock_code']).to_csv(f'{save_dir}/{next_date}.csv', index=False)
|
|
@ -0,0 +1,32 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Iterable, Optional, Dict
|
||||
|
||||
class DataLoader():
|
||||
"""
|
||||
数据类: 数据加载模块
|
||||
"""
|
||||
def __init__(self, path: Union[str, pd.DataFrame]=None):
|
||||
if type(path) == str:
|
||||
self.data = pd.read_csv(path).set_index(['stock_code']).sort_index()
|
||||
if type(path) == pd.DataFrame:
|
||||
self.data = path.copy()
|
||||
|
||||
def get(self,
|
||||
target: Iterable[str]=[],
|
||||
column: str=''):
|
||||
"""
|
||||
- target: 查询目标代码
|
||||
- column(str): 查询列
|
||||
"""
|
||||
res = pd.Series(index=target)
|
||||
column_map = dict(zip(self.data.index.values, self.data[column].values))
|
||||
column_type = self.data[column].dtype
|
||||
stock_list = list(set(self.data.index.values) & set(target))
|
||||
res.loc[stock_list] = self.data.loc[stock_list, column].values
|
||||
stock_list = list(set(target) - set(stock_list))
|
||||
if column_type == str:
|
||||
res.loc[stock_list] = None
|
||||
else:
|
||||
res.loc[stock_list] = np.nan
|
||||
return res.sort_index()
|
|
@ -0,0 +1,96 @@
|
|||
# -*- coding: UTF-8 -*-
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from trader import Trader
|
||||
from rich.progress import track
|
||||
from rich import print as rprint
|
||||
from rich import pretty, text
|
||||
from rich.table import Column, Table
|
||||
from rich.style import Style
|
||||
|
||||
|
||||
class Spread_Backtest():
|
||||
def __init__(self,
|
||||
trader: Trader):
|
||||
self.trader = trader
|
||||
|
||||
def run(self,
|
||||
start: str,
|
||||
end: str
|
||||
):
|
||||
# 确定回测的开始和结束时间
|
||||
if len(self.trader.account_history) > 0:
|
||||
bkt_start = max(start, self.trader.avaliable_date.index.min(), self.trader.account_history['date'].max())
|
||||
else:
|
||||
bkt_start = max(start, self.trader.avaliable_date.index.min())
|
||||
rec_end = min(end, self.trader.avaliable_date.index.max())
|
||||
# 如果传入的第一个因子时间范围大于现有数据最大时间,则只更新下一时刻的持仓不处理收益
|
||||
first_signal = self.trader.signal[list(self.trader.signal.keys())[0]]
|
||||
if rec_end < first_signal.index.max():
|
||||
bkt_end = first_signal.loc[rec_end:].index.to_list()[1]
|
||||
else:
|
||||
bkt_end = rec_end
|
||||
print(f'回测区间: {bkt_start} - {bkt_end}')
|
||||
for d in track(self.trader.signal[list(self.trader.signal.keys())[0]].loc[bkt_start:bkt_end].index,
|
||||
description='Backtesting...',
|
||||
update_period=0.5):
|
||||
# avaliable_date最后一天的数据只能用于记录持仓
|
||||
if (d <= rec_end) and (d < self.trader.avaliable_date.index.max()):
|
||||
self.trader.update_signal(d)
|
||||
else:
|
||||
self.trader.update_signal(d, update_type='position')
|
||||
|
||||
@property
|
||||
def account_history(self):
|
||||
return self.trader.account_history
|
||||
|
||||
@property
|
||||
def position_history(self):
|
||||
return self.trader.position_history
|
||||
|
||||
def analyze(self):
|
||||
"""
|
||||
分析统计
|
||||
"""
|
||||
rtn_stat = self.trader.account_history.copy()
|
||||
rtn_stat['pnl'] += 1
|
||||
rtn_stat = rtn_stat.groupby(['date'])['pnl'].prod().cumprod()
|
||||
|
||||
# 根据basic data路径确定交易日
|
||||
trading_day = pd.Series(index=[f.split('.')[0] for f in os.listdir(self.trader.data_root['basic'])]).sort_index()
|
||||
# 按年统计
|
||||
year_rtn = pd.DataFrame(columns=['收益'])
|
||||
for year in sorted(pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year.unique()):
|
||||
start_date = max(rtn_stat.index.min(), trading_day.loc[:'{}-01-01'.format(year)].index.values[-1])
|
||||
end_date = min(rtn_stat.index.max(), trading_day.loc[:'{}-12-31'.format(year)].index.values[-1])
|
||||
if year == pd.to_datetime(pd.Series(rtn_stat.index.values)).dt.year.min():
|
||||
year_rtn.loc[year, '收益'] = rtn_stat.loc[end_date] - 1
|
||||
else:
|
||||
year_rtn.loc[year, '收益'] = (rtn_stat.loc[end_date] / rtn_stat.loc[start_date]) - 1
|
||||
year_rtn.loc['Annualized'] = np.power(rtn_stat.values[-1], 1 / (len(rtn_stat) / 244)) - 1
|
||||
year_rtn = year_rtn.applymap(lambda x: '{:.2%}'.format(x))
|
||||
year_rtn = year_rtn.reset_index()
|
||||
year_rtn.columns = ['Year', 'Rtn']
|
||||
rprint("[bold black]1 收益统计[/bold black]")
|
||||
print_year_rtn(year_rtn)
|
||||
|
||||
|
||||
def print_year_rtn(year_rtn: pd.DataFrame) -> None:
|
||||
table = Table(show_header=True, header_style='bold')
|
||||
for col in year_rtn:
|
||||
table.add_column(col, justify='center', width=10, no_wrap=True)
|
||||
for _,row in year_rtn.iterrows():
|
||||
new_row = []
|
||||
for col in year_rtn.columns:
|
||||
if col == '收益':
|
||||
color = "green" if row['收益'][0] == '-' else "sred"
|
||||
value = f"[{color}]{row[col]}[/{color}]"
|
||||
new_row.append(str(value))
|
||||
else:
|
||||
new_row.append(str(row[col]))
|
||||
table.add_row(*new_row)
|
||||
|
||||
rprint(table)
|
||||
|
||||
|
|
@ -0,0 +1,550 @@
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
import sys, os
|
||||
sys.path.append("/home/lenovo/quant/tools/get_factor_tools/")
|
||||
from db_tushare import get_factor_tools
|
||||
gft = get_factor_tools()
|
||||
from typing import Union, Iterable, Dict
|
||||
from account import Account
|
||||
from dataloader import DataLoader
|
||||
|
||||
|
||||
class Trader(Account):
|
||||
"""
|
||||
交易类: 用于控制每日交易情况
|
||||
|
||||
Args:
|
||||
signal (dict[str, pd.DataFrame]): 目标因子,按顺序执行
|
||||
interval (int, tuple, pd.Series): 交易间隔
|
||||
num (int): 持仓数量
|
||||
ascending (bool): 因子方向
|
||||
fee (tuple): 买入成本和卖出成本
|
||||
with_st (bool): 是否包含st
|
||||
tick (bool): 是否开始tick模拟模式(开发中)
|
||||
weight (str): 权重分配
|
||||
- avg: 平均分配,每天早盘重新分配,日中交易不重新分配
|
||||
amt_filter (set): 20日均成交额筛选,第一个参数是筛选下限,第二个参数是筛选上限,可以只提供下限
|
||||
data_root (dict): 对应各个目标因子的交易价格数据,必须包含stock_code和price列
|
||||
ipo_days (int): 筛选上市时间
|
||||
"""
|
||||
def __init__(self,
|
||||
signal: Dict[str, pd.DataFrame]=None,
|
||||
interval: Dict[str, Union[int,tuple,pd.Series]]=1,
|
||||
num: int=100,
|
||||
ascending: bool=False,
|
||||
fee :tuple=(0.001,0.002),
|
||||
with_st: bool=False,
|
||||
data_root:dict={},
|
||||
tick: bool=False,
|
||||
weight: str='avg',
|
||||
amt_filter: set=(0,),
|
||||
ipo_days: int=20,
|
||||
**kwargs) -> None:
|
||||
# 初始化账户
|
||||
super().__init__(**kwargs.get('account', {}))
|
||||
if isinstance(signal, dict):
|
||||
self.signal = signal
|
||||
if 'close' in signal:
|
||||
raise ValueError('signal key cannot be close')
|
||||
for s in self.signal:
|
||||
self.signal[s] = gft.return_factor(self.signal[s], self.signal[s].index.min(), self.signal[s].index.max(), return_type='origin')
|
||||
else:
|
||||
raise ValueError('type of signal is invalid')
|
||||
# interval
|
||||
self.interval = []
|
||||
for s in signal:
|
||||
if s in interval:
|
||||
s_interval = interval[s]
|
||||
if isinstance(s_interval, int):
|
||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||
df_interval[::s_interval] = 1
|
||||
elif isinstance(s_interval, tuple):
|
||||
df_interval = pd.Series(index=signal[s].index, data=[0]*len(signal[s].index))
|
||||
df_interval[::s_interval[0]] = s_interval[1]
|
||||
elif isinstance(s_interval, pd.Series):
|
||||
df_interval = s_interval
|
||||
else:
|
||||
raise ValueError('invalid interval type')
|
||||
self.interval.append(df_interval)
|
||||
else:
|
||||
raise ValueError(f'not found interval for signal {s}')
|
||||
self.interval = pd.concat(self.interval)
|
||||
# num
|
||||
if isinstance(num, int):
|
||||
self.num = int(num)
|
||||
else:
|
||||
raise ValueError('num should be int')
|
||||
# ascending
|
||||
if isinstance(ascending, bool):
|
||||
self.ascending = ascending
|
||||
else:
|
||||
raise ValueError('invalid type for ascending')
|
||||
# fee
|
||||
if isinstance(fee, tuple) and len(fee) == 2:
|
||||
self.fee = fee
|
||||
else:
|
||||
raise ValueError('invalid input for fee')
|
||||
# with_st
|
||||
if isinstance(with_st, bool):
|
||||
self.with_st = with_st
|
||||
else:
|
||||
raise ValueError('invalid type for with_st')
|
||||
# amt_filter
|
||||
if isinstance(amt_filter, (set,list,tuple)):
|
||||
if len(amt_filter) == 1:
|
||||
self.amt_filter_min = amt_filter[0]
|
||||
self.amt_filter_max = np.inf
|
||||
else:
|
||||
self.amt_filter_min = amt_filter[0]
|
||||
self.amt_filter_max = amt_filter[1]
|
||||
else:
|
||||
raise Exception('wrong type for amt_filter, `set` `list` or `tuple` is required')
|
||||
# ipo_days
|
||||
if isinstance(ipo_days, int):
|
||||
self.ipo_days = ipo_days
|
||||
else:
|
||||
raise Exception('wrong type for ipo_days, `int` is required')
|
||||
# data_root
|
||||
# 至少包含basic data路径,open信号默认使用basic_data
|
||||
if len(data_root) <= 0:
|
||||
raise ValueError('num of data_root should be equal or greater than 1')
|
||||
if 'basic' in data_root:
|
||||
# 可执行日期
|
||||
self.avaliable_date = pd.Series(index=[f.split('.')[0] for f in os.listdir(data_root['basic'])]).sort_index()
|
||||
else:
|
||||
raise ValueError('data_root should contain basic data root')
|
||||
for s in self.signal:
|
||||
if s == 'open':
|
||||
continue
|
||||
else:
|
||||
if s not in data_root:
|
||||
raise ValueError(f'data for signal {s} is not provided')
|
||||
self.data_root = data_root
|
||||
|
||||
def load_data(self,
|
||||
date: str,
|
||||
update_type: str='rtn'):
|
||||
"""
|
||||
加载每日基础数据
|
||||
|
||||
Args:
|
||||
update_type (str): 更新模式
|
||||
- rtn: 更新所有信号数据
|
||||
- position: 只更新basic数据,用于持仓判断
|
||||
"""
|
||||
self.today_data = dict()
|
||||
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
|
||||
if update_type == 'position':
|
||||
return True
|
||||
for s in self.signal:
|
||||
if s == 'open':
|
||||
if s in self.data_root:
|
||||
continue
|
||||
else:
|
||||
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
else:
|
||||
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
|
||||
if 'close' in self.signal:
|
||||
pass
|
||||
else:
|
||||
self.today_data['close'] = DataLoader(self.today_data['basic'].data[['close_post']].rename(columns={'close_post':'price'}))
|
||||
|
||||
def get_next_position(self, date, factor):
|
||||
"""
|
||||
计算下一时刻持仓
|
||||
"""
|
||||
# 计算持仓和最大可交易数量
|
||||
if len(self.position) > 0:
|
||||
last_position = self.position['stock_code'].values
|
||||
last_position = factor.loc[last_position].sort_values(ascending=self.ascending)
|
||||
max_trade_num = max(int(self.interval.loc[date]*self.num), self.num-len(self.position))
|
||||
else:
|
||||
last_position = pd.Series()
|
||||
max_trade_num = self.num
|
||||
target_list = []
|
||||
# 获取用于筛选的数据
|
||||
stock_status = self.today_data['basic'].get(factor.index.values, 'opening_info').sort_index()
|
||||
stock_amt20 = self.today_data['basic'].get(factor.index.values, 'amount_20').sort_index()
|
||||
stock_amt_filter = stock_amt20.copy()
|
||||
stock_amt_filter.loc[:] = 0
|
||||
stock_amt_filter.loc[(stock_amt20 > self.amt_filter_min) & (stock_amt20 < self.amt_filter_max)] = 1
|
||||
stock_amt_filter = stock_amt_filter.sort_index()
|
||||
stock_ipo_days = self.today_data['basic'].get(factor.index.values, 'ipo_days').sort_index()
|
||||
stock_ipo_filter = stock_ipo_days.copy()
|
||||
stock_ipo_filter.loc[:] = 0
|
||||
stock_ipo_filter.loc[stock_ipo_days > self.ipo_days] = 1
|
||||
# 交易列表
|
||||
if self.leverage <= 1.0:
|
||||
# 获取当前时间目标列表和冻结(无法交易)列表
|
||||
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
|
||||
if (stock_amt_filter.loc[stock] != 1) or (stock_ipo_filter.loc[stock] != 1):
|
||||
continue
|
||||
if self.with_st:
|
||||
if stock_status.loc[stock] in [0,2]:
|
||||
if stock in last_position.index:
|
||||
target_list.append(stock)
|
||||
else:
|
||||
target_list.append(stock)
|
||||
else:
|
||||
# 非ST
|
||||
if stock_status.loc[stock] in [3,7]:
|
||||
target_list.append(stock)
|
||||
else:
|
||||
# 如果停牌或者跌停继续持有
|
||||
if stock_status.loc[stock] in [0,2,6]:
|
||||
if stock in last_position.index:
|
||||
target_list.append(stock)
|
||||
if len(target_list) == self.num:
|
||||
break
|
||||
# 如果没有杠杆
|
||||
buy_list = []
|
||||
sell_list = []
|
||||
# ----- 卖出 -----
|
||||
# 按照反向排名逐个卖出
|
||||
if self.ascending:
|
||||
factor = factor.fillna(factor.max()+1)
|
||||
else:
|
||||
factor = factor.fillna(factor.min()-1)
|
||||
for stock in factor.loc[last_position.index].sort_values(ascending=self.ascending).index.values[::-1]:
|
||||
if stock in target_list:
|
||||
continue
|
||||
else:
|
||||
if stock_status.loc[stock] in [0,2,5,7]:
|
||||
continue
|
||||
else:
|
||||
sell_list.append(stock)
|
||||
if len(sell_list) == max_trade_num:
|
||||
break
|
||||
# ----- 买入 -----
|
||||
# 卖出后持仓列表
|
||||
after_sell_list = set(last_position.index) - set(sell_list)
|
||||
max_trade_num = min(max_trade_num, self.num-len(last_position)+len(sell_list))
|
||||
for stock in target_list:
|
||||
if stock in after_sell_list:
|
||||
continue
|
||||
else:
|
||||
buy_list.append(stock)
|
||||
if len(buy_list) == max_trade_num:
|
||||
break
|
||||
next_position = pd.DataFrame({'stock_code': list((set(last_position.index) - set(sell_list)) | set(buy_list))})
|
||||
next_position['date'] = date
|
||||
next_position['weight'] = self.leverage / len(next_position)
|
||||
next_position['margin_trade'] = 0
|
||||
else:
|
||||
# 如果有杠杆
|
||||
def assign_stock(normal_list, margin_list, margin_needed, stock, status):
|
||||
if status == 1:
|
||||
if len(margin_list) < margin_needed:
|
||||
margin_list.append(stock)
|
||||
else:
|
||||
if len(normal_list) < self.num - margin_needed:
|
||||
normal_list.append(stock)
|
||||
return normal_list, margin_list
|
||||
# 计算需要融资融券标的数量
|
||||
margin_ratio = max(self.leverage-1, 0)
|
||||
margin_needed = round(self.num * margin_ratio)
|
||||
is_margin = self.today_data['basic'].get(factor.index.values, 'margin_list').sort_index()
|
||||
# 获取当前时间目标列表和冻结(无法交易)列表
|
||||
normal_list = []
|
||||
margin_list = []
|
||||
for stock in factor.dropna().sort_values(ascending=self.ascending).index.values:
|
||||
if stock_amt_filter.loc[stock] != 1:
|
||||
continue
|
||||
if self.with_st:
|
||||
if stock_status.loc[stock] in [0,2]:
|
||||
if stock in last_position.index:
|
||||
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
|
||||
else:
|
||||
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
|
||||
else:
|
||||
# 非ST
|
||||
if stock_status.loc[stock] in [3,7]:
|
||||
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
|
||||
else:
|
||||
# 如果停牌或者跌停继续持有
|
||||
if stock_status.loc[stock] in [0,2,6]:
|
||||
if stock in last_position.index:
|
||||
normal_list, margin_list = assign_stock(normal_list, margin_list, margin_needed, stock, is_margin.loc[stock])
|
||||
if len(normal_list + margin_list) == self.num:
|
||||
break
|
||||
target_list = normal_list + margin_list
|
||||
# ----- 卖出 -----
|
||||
buy_list = []
|
||||
sell_list = []
|
||||
# 融资融券池的和非融资融券池的分开更新
|
||||
# 更新融资融券池
|
||||
if len(last_position) > 0:
|
||||
last_margin_list = self.position.loc[self.position['margin_trade'] == 1, 'stock_code'].to_list()
|
||||
else:
|
||||
last_margin_list = []
|
||||
for stock in factor.loc[last_margin_list].sort_values(ascending=self.ascending).index.values[::-1]:
|
||||
if stock in normal_list:
|
||||
continue
|
||||
else:
|
||||
if stock_status.loc[stock] in [0,2,5,7]:
|
||||
continue
|
||||
else:
|
||||
sell_list.append(stock)
|
||||
if len(sell_list) >= int(max_trade_num * margin_ratio) + 1:
|
||||
break
|
||||
next_margin_list = list(set(last_margin_list) - set(sell_list))
|
||||
# 更新非融资融券池
|
||||
if len(last_position) > 0:
|
||||
last_normal_list = self.position.loc[self.position['margin_trade'] == 0, 'stock_code'].to_list()
|
||||
else:
|
||||
last_normal_list = []
|
||||
for stock in factor.loc[last_normal_list].sort_values(ascending=self.ascending).index.values[::-1]:
|
||||
if stock in normal_list:
|
||||
continue
|
||||
else:
|
||||
if stock_status.loc[stock] in [0,2,5,7]:
|
||||
continue
|
||||
else:
|
||||
sell_list.append(stock)
|
||||
if len(sell_list) >= max_trade_num:
|
||||
break
|
||||
next_normal_list = list(set(last_normal_list) - set(sell_list))
|
||||
# ----- 买入 -----
|
||||
# 卖出后持仓列表
|
||||
after_sell_list = set(last_position.index) - set(sell_list)
|
||||
max_trade_num = min(max_trade_num, self.num-len(last_position)+len(sell_list))
|
||||
# 融资融券池的和非融资融券池的分开更新
|
||||
# 更新融资融券池
|
||||
for stock in margin_list:
|
||||
if stock in after_sell_list:
|
||||
continue
|
||||
else:
|
||||
next_margin_list.append(stock)
|
||||
if len(next_margin_list) == margin_needed:
|
||||
break
|
||||
# 更新非融资融券池
|
||||
for stock in normal_list:
|
||||
if stock in after_sell_list:
|
||||
continue
|
||||
else:
|
||||
next_normal_list.append(stock)
|
||||
if len(next_normal_list) >= self.num - margin_needed:
|
||||
break
|
||||
next_position = pd.DataFrame({'stock_code': next_margin_list + next_normal_list})
|
||||
next_position['date'] = date
|
||||
# 融资融券数量
|
||||
margin_num = len(next_margin_list)
|
||||
# next_position['weight'] = self.leverage*((margin_needed-margin_num)/margin_needed) / len(next_position)
|
||||
next_position['weight'] = (1 + (margin_num / self.num)) / self.num
|
||||
next_position['margin_trade'] = 0
|
||||
next_position = next_position.set_index(['stock_code'])
|
||||
next_position.loc[next_margin_list, 'margin_trade'] = 1
|
||||
next_position = next_position.reset_index()
|
||||
# 检测当前持仓是否可以交易
|
||||
frozen_list = []
|
||||
if len(self.position) > 0:
|
||||
for stock in next_position['stock_code']:
|
||||
if stock_status.loc[stock] in [0,2]:
|
||||
frozen_list.append(stock)
|
||||
return sell_list, buy_list, frozen_list, next_position
|
||||
|
||||
def get_price(self,
|
||||
trade_time: str='open',
|
||||
target: Iterable[str]=[],
|
||||
buy_list: Iterable[str] = [],
|
||||
sell_list: Iterable[str] = []):
|
||||
"""
|
||||
获取价格
|
||||
Args:
|
||||
trade_time (str): 交易时间
|
||||
target (Iterable[float]): 目标
|
||||
buy_list (Iterable[str]): 买入目标
|
||||
sell_list (Iterable[str]): 卖出目标
|
||||
"""
|
||||
stock_price = self.today_data[trade_time]
|
||||
target_price = pd.Series(index=target)
|
||||
sell_list = list(set(target) & set(sell_list))
|
||||
buy_list = list(set(target) & set(buy_list))
|
||||
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
|
||||
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.fee[1])
|
||||
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.fee[0])
|
||||
return target_price
|
||||
|
||||
def check_update_status(self,
|
||||
date: str,
|
||||
trade_time: str):
|
||||
# 判断当前更新状态
|
||||
# 如果日期和交易时间已经存在则返回True
|
||||
if len(self.account_history) == 0:
|
||||
return False
|
||||
elif date < self.account_history['date'].max():
|
||||
return True
|
||||
else:
|
||||
exist_list = self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values
|
||||
if f'{date}-{trade_time}' in exist_list:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def reblance_weight(self,
|
||||
trade_time: str,
|
||||
cur_pos: pd.DataFrame,
|
||||
next_position: pd.DataFrame):
|
||||
"""
|
||||
动态平衡权重
|
||||
"""
|
||||
# 判断冻结列表
|
||||
stock_status = self.today_data['basic'].get(cur_pos['stock_code'].values, 'opening_info')
|
||||
buy_frozen_list = []
|
||||
for stock in cur_pos['stock_code']:
|
||||
if stock_status.loc[stock] in [0,2,4,6]:
|
||||
buy_frozen_list.append(stock)
|
||||
sell_frozen_list = []
|
||||
for stock in cur_pos['stock_code']:
|
||||
if stock_status.loc[stock] in [0,2,5,7]:
|
||||
sell_frozen_list.append(stock)
|
||||
# 设定目标仓位
|
||||
next_position['target_weight'] = next_position['weight']
|
||||
next_position['current_weight'] = 0
|
||||
cur_pos = cur_pos.set_index(['stock_code'])
|
||||
next_position = next_position.set_index(['stock_code'])
|
||||
current_list = list(set(cur_pos.index) & set(next_position.index))
|
||||
next_position.loc[current_list, 'current_weight'] = cur_pos.loc[current_list, 'weight']
|
||||
next_position['open'] = 0
|
||||
next_position.loc[current_list, 'open'] = cur_pos.loc[current_list, 'close']
|
||||
# 计算理想仓位变动
|
||||
next_position['weight_chg'] = next_position['weight'] - next_position['current_weight']
|
||||
# 根据冻结判断是否能够变动
|
||||
next_position['final_weight'] = 0
|
||||
buy_frozen_list = set(buy_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] > 0].index)
|
||||
sell_frozen_list = set(sell_frozen_list) & set(next_position.index) & set(next_position.loc[next_position['weight_chg'] < 0].index)
|
||||
next_position.loc[next_position.index, 'final_weight'] = next_position['weight']
|
||||
next_position.loc[buy_frozen_list, 'final_weight'] = next_position.loc[buy_frozen_list, 'current_weight']
|
||||
next_position.loc[sell_frozen_list, 'final_weight'] = next_position.loc[sell_frozen_list, 'current_weight']
|
||||
# 动态平衡仓位
|
||||
next_position['final_weight'] /= next_position['final_weight'].sum()
|
||||
next_position['final_weight'] *= next_position['weight'].sum()
|
||||
# 计算理想仓位变动
|
||||
next_position['weight_chg'] = next_position['final_weight'] - next_position['current_weight']
|
||||
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'weight_chg'] = 0
|
||||
# 动态平衡价格
|
||||
next_position['adjust_price'] = 0
|
||||
buy_adjust_list = next_position[next_position['weight_chg'] > 0].index.values
|
||||
sell_adjust_list = next_position[next_position['weight_chg'] < 0].index.values
|
||||
next_position.loc[buy_adjust_list, 'adjust_price'] = self.get_price(trade_time, buy_adjust_list, buy_adjust_list, []).values
|
||||
next_position.loc[sell_adjust_list, 'adjust_price'] = self.get_price(trade_time, sell_adjust_list, [], sell_adjust_list).values
|
||||
# 价格调整
|
||||
next_position['adjust_open'] = (next_position['current_weight']*next_position['open'] + next_position['weight_chg']*next_position['adjust_price'])
|
||||
next_position['adjust_open'] = next_position['adjust_open'] / next_position['final_weight']
|
||||
next_position.loc[list(buy_frozen_list | sell_frozen_list), 'adjust_open'] = next_position.loc[list(buy_frozen_list | sell_frozen_list), 'open']
|
||||
# 当日买入不调整
|
||||
next_position['open'] = next_position['adjust_open']
|
||||
next_position['weight'] = next_position['final_weight']
|
||||
next_position = next_position.reset_index()
|
||||
next_position = next_position[['stock_code','date','open','weight','margin_trade']]
|
||||
return next_position
|
||||
|
||||
def update_account(self,
|
||||
date: str,
|
||||
trade_time: str,
|
||||
cur_pos: pd.DataFrame,
|
||||
next_position: pd.DataFrame):
|
||||
"""
|
||||
更新账户
|
||||
Args:
|
||||
date (str): 日期
|
||||
trade_time (str): 交易时间
|
||||
cur_pos (DataFrame): 当前持仓
|
||||
next_position (Iterable[str]): 下一刻持仓
|
||||
"""
|
||||
turnover = pd.concat([
|
||||
cur_pos.set_index(['stock_code'])['weight'].fillna(0).rename('cur'),
|
||||
next_position.set_index(['stock_code'])['weight'].rename('next'),
|
||||
], axis=1)
|
||||
turnover = (turnover['next'] - turnover['cur'].fillna(0)).abs().sum()
|
||||
leverage = next_position['weight'].sum()
|
||||
if cur_pos['weight'].sum() == 0:
|
||||
pnl = 0
|
||||
else:
|
||||
pnl = (cur_pos['end_weight'].sum() - cur_pos['weight'].sum())
|
||||
self.account *= 1+pnl
|
||||
self.account_history = self.account_history.append({
|
||||
'date': date,
|
||||
'trade_time': trade_time,
|
||||
'turnover': turnover,
|
||||
'leverage': leverage,
|
||||
'pnl': pnl
|
||||
}, ignore_index=True)
|
||||
return True
|
||||
|
||||
def update_signal(self,
|
||||
date:str,
|
||||
update_type='rtn'):
|
||||
"""
|
||||
更新信号收益
|
||||
|
||||
Args:
|
||||
update_type (str): 更新类型
|
||||
- position: 只更新持仓不更新收益
|
||||
- rtn: 更新收益和持仓
|
||||
"""
|
||||
# 如果更新日期的close已经记录,则跳过,否则删除现有日期相关记录继续更新
|
||||
if f'{date}-close' in self.account_history['date'].str.cat(self.account_history['trade_time'], sep='-').values:
|
||||
return True
|
||||
else:
|
||||
self.account_history = self.account_history.query(f'date != "{date}" ', engine='python')
|
||||
if date in self.position_history:
|
||||
self.position_history.pop(date)
|
||||
# 更新持仓信号
|
||||
self.load_data(date, update_type)
|
||||
# 如果当前持仓不空,添加隔夜收益,否则直接买入
|
||||
if len(self.position) == 0:
|
||||
cur_pos = pd.DataFrame(columns=['stock_code','date','weight','open','close','margin_trade'])
|
||||
else:
|
||||
cur_pos = self.position.copy()
|
||||
# 冻结列表
|
||||
frozen_list = []
|
||||
# 遍历各个交易时间的信号
|
||||
for idx,trade_time in enumerate(self.signal):
|
||||
if self.check_update_status(date, trade_time):
|
||||
continue
|
||||
if date in self.signal[trade_time].index:
|
||||
factor = self.signal[trade_time].loc[date]
|
||||
else:
|
||||
continue
|
||||
factor = self.signal[trade_time].loc[date]
|
||||
# 获取当前、持仓
|
||||
sell_list, buy_list, frozen_list, next_position = self.get_next_position(date, factor)
|
||||
if update_type == 'position':
|
||||
self.position_history[date] = next_position.copy()
|
||||
return True
|
||||
if len(cur_pos) > 0:
|
||||
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], sell_list).values
|
||||
# 停牌股价格不变
|
||||
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
|
||||
# 计算收益
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
|
||||
self.update_account(date, trade_time, cur_pos, next_position)
|
||||
# 更新仓位
|
||||
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
|
||||
# 调整权重:买入、卖出、仓位再平衡
|
||||
next_position = self.reblance_weight(trade_time, cur_pos, next_position)
|
||||
else:
|
||||
next_position['open'] = self.get_price(trade_time, next_position['stock_code'].values, buy_list, []).values
|
||||
self.position = next_position.copy()
|
||||
# 收盘统计当日收益
|
||||
trade_time = 'close'
|
||||
if self.check_update_status(date, trade_time):
|
||||
return True
|
||||
cur_pos = self.position.copy()
|
||||
cur_pos['close'] = self.get_price(trade_time, cur_pos['stock_code'].values, [], []).values
|
||||
# 停牌价格不变
|
||||
cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'close'] = cur_pos.loc[cur_pos['stock_code'].isin(frozen_list), 'open']
|
||||
cur_pos.loc[cur_pos['open'] == 0, 'close'] = cur_pos.loc[cur_pos['open'] == 0, 'open']
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
|
||||
position_record = cur_pos.copy()
|
||||
position_record['end_weight'] = (position_record['end_weight'] / position_record['end_weight'].sum()) * position_record['weight'].sum()
|
||||
cur_pos['weight'] = (cur_pos['end_weight'] / cur_pos['end_weight'].sum()) * cur_pos['weight'].sum()
|
||||
next_position = cur_pos.copy()[['stock_code','date','weight','margin_trade']]
|
||||
next_position['open'] = cur_pos['close']
|
||||
self.update_account(date, trade_time, cur_pos, cur_pos)
|
||||
self.position = next_position.copy()
|
||||
self.position_history[date] = position_record.copy()
|
||||
return True
|
Loading…
Reference in New Issue