Fix: 修复自定义价格时,非买卖标的的价格选取 (#24)
This commit is contained in:
parent
411e1a3f78
commit
30a998b58a
|
@ -1,6 +1,6 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union, Iterable, Optional, Dict
|
||||
from typing import Union, Iterable
|
||||
|
||||
class DataLoader():
|
||||
"""
|
||||
|
|
29
trader.py
29
trader.py
|
@ -187,14 +187,19 @@ class Trader(Account):
|
|||
"""
|
||||
self.today_data = dict()
|
||||
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
|
||||
# 遍历从data_root中读取当日数据
|
||||
for path in self.data_root:
|
||||
if path == 'basic':
|
||||
continue
|
||||
else:
|
||||
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
|
||||
if update_type == 'position':
|
||||
return True
|
||||
for s in self.signal:
|
||||
if s == 'open':
|
||||
if s in self.data_root:
|
||||
continue
|
||||
else:
|
||||
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
|
||||
else:
|
||||
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
|
||||
if 'close' in self.signal:
|
||||
|
@ -529,13 +534,19 @@ class Trader(Account):
|
|||
buy_list (Iterable[str]): 买入目标
|
||||
sell_list (Iterable[str]): 卖出目标
|
||||
"""
|
||||
stock_price = self.today_data[trade_time]
|
||||
if trade_time+'_trade' in self.today_data:
|
||||
trade_price = self.today_data[trade_time+'_trade']
|
||||
basic_price = self.today_data[trade_time]
|
||||
else:
|
||||
basic_price = self.today_data[trade_time]
|
||||
trade_price = basic_price
|
||||
target_price = pd.Series(index=target)
|
||||
sell_list = list(set(target) & set(sell_list))
|
||||
buy_list = list(set(target) & set(buy_list))
|
||||
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
|
||||
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
||||
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
||||
# 根据交易和非交易标的分别获取目标价格
|
||||
target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
|
||||
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
|
||||
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
|
||||
return target_price
|
||||
|
||||
def check_update_status(self,
|
||||
|
@ -724,6 +735,8 @@ class Trader(Account):
|
|||
# 计算收益
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
|
||||
# 价格缺失用初始weight填充
|
||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
||||
self.update_account(date, trade_time, cur_pos, next_position)
|
||||
# 更新仓位
|
||||
cur_pos['weight'] = self.update_next_weight(cur_pos)
|
||||
|
@ -744,6 +757,8 @@ class Trader(Account):
|
|||
# 更新当日收益
|
||||
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
|
||||
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
|
||||
# 价格缺失用初始weight填充
|
||||
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
|
||||
position_record = cur_pos.copy()
|
||||
position_record['end_weight'] = self.update_next_weight(position_record)
|
||||
self.position_history[date] = position_record.copy()[position_fields]
|
||||
|
|
Loading…
Reference in New Issue