Fix: 修复自定义价格时,非买卖标的的价格选取 (#24)

This commit is contained in:
binz 2024-06-17 20:35:25 +08:00
parent 411e1a3f78
commit 30a998b58a
2 changed files with 23 additions and 8 deletions

View File

@ -1,6 +1,6 @@
import numpy as np
import pandas as pd
from typing import Union, Iterable, Optional, Dict
from typing import Union, Iterable
class DataLoader():
"""

View File

@ -187,14 +187,19 @@ class Trader(Account):
"""
self.today_data = dict()
self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv'))
# 遍历从data_root中读取当日数据
for path in self.data_root:
if path == 'basic':
continue
else:
self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv'))
if update_type == 'position':
return True
for s in self.signal:
if s == 'open':
if s in self.data_root:
continue
else:
self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'}))
self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'}))
else:
self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv'))
if 'close' in self.signal:
@ -529,13 +534,19 @@ class Trader(Account):
buy_list (Iterable[str]): 买入目标
sell_list (Iterable[str]): 卖出目标
"""
stock_price = self.today_data[trade_time]
if trade_time+'_trade' in self.today_data:
trade_price = self.today_data[trade_time+'_trade']
basic_price = self.today_data[trade_time]
else:
basic_price = self.today_data[trade_time]
trade_price = basic_price
target_price = pd.Series(index=target)
sell_list = list(set(target) & set(sell_list))
buy_list = list(set(target) & set(buy_list))
target_price.loc[target] = stock_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0])
# 根据交易和非交易标的分别获取目标价格
target_price.loc[target] = basic_price.get(target, 'price').fillna(0)
target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1])
target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0])
return target_price
def check_update_status(self,
@ -724,6 +735,8 @@ class Trader(Account):
# 计算收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open'])
cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn']
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
self.update_account(date, trade_time, cur_pos, next_position)
# 更新仓位
cur_pos['weight'] = self.update_next_weight(cur_pos)
@ -744,6 +757,8 @@ class Trader(Account):
# 更新当日收益
cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1
cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1)
# 价格缺失用初始weight填充
cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight']
position_record = cur_pos.copy()
position_record['end_weight'] = self.update_next_weight(position_record)
self.position_history[date] = position_record.copy()[position_fields]