diff --git a/dataloader.py b/dataloader.py index 0d115af..90f8267 100644 --- a/dataloader.py +++ b/dataloader.py @@ -1,6 +1,6 @@ import numpy as np import pandas as pd -from typing import Union, Iterable, Optional, Dict +from typing import Union, Iterable class DataLoader(): """ diff --git a/trader.py b/trader.py index d085f5b..4f71eb0 100644 --- a/trader.py +++ b/trader.py @@ -187,14 +187,19 @@ class Trader(Account): """ self.today_data = dict() self.today_data['basic'] = DataLoader(os.path.join(self.data_root['basic'],f'{date}.csv')) + # 遍历从data_root中读取当日数据 + for path in self.data_root: + if path == 'basic': + continue + else: + self.today_data[path] = DataLoader(os.path.join(self.data_root[path],f'{date}.csv')) if update_type == 'position': return True for s in self.signal: if s == 'open': if s in self.data_root: - continue - else: - self.today_data[s] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'})) + self.today_data[s+'_trade'] = DataLoader(self.today_data['open'].data[['open_post']].rename(columns={'open_post':'price'})) + self.today_data['open'] = DataLoader(self.today_data['basic'].data[['open_post']].rename(columns={'open_post':'price'})) else: self.today_data[s] = DataLoader(os.path.join(self.data_root[s],f'{date}.csv')) if 'close' in self.signal: @@ -529,13 +534,19 @@ class Trader(Account): buy_list (Iterable[str]): 买入目标 sell_list (Iterable[str]): 卖出目标 """ - stock_price = self.today_data[trade_time] + if trade_time+'_trade' in self.today_data: + trade_price = self.today_data[trade_time+'_trade'] + basic_price = self.today_data[trade_time] + else: + basic_price = self.today_data[trade_time] + trade_price = basic_price target_price = pd.Series(index=target) sell_list = list(set(target) & set(sell_list)) buy_list = list(set(target) & set(buy_list)) - target_price.loc[target] = stock_price.get(target, 'price').fillna(0) - target_price.loc[sell_list] = stock_price.get(sell_list, 'price') * (1 - self.current_fee[1]) - target_price.loc[buy_list] = stock_price.get(buy_list, 'price') * (1 + self.current_fee[0]) + # 根据交易和非交易标的分别获取目标价格 + target_price.loc[target] = basic_price.get(target, 'price').fillna(0) + target_price.loc[sell_list] = trade_price.get(sell_list, 'price') * (1 - self.current_fee[1]) + target_price.loc[buy_list] = trade_price.get(buy_list, 'price') * (1 + self.current_fee[0]) return target_price def check_update_status(self, @@ -724,6 +735,8 @@ class Trader(Account): # 计算收益 cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) cur_pos['end_weight'] = cur_pos['weight'] * cur_pos['rtn'] + # 价格缺失用初始weight填充 + cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight'] self.update_account(date, trade_time, cur_pos, next_position) # 更新仓位 cur_pos['weight'] = self.update_next_weight(cur_pos) @@ -744,6 +757,8 @@ class Trader(Account): # 更新当日收益 cur_pos['rtn'] = (cur_pos['close'] / cur_pos['open']) - 1 cur_pos['end_weight'] = cur_pos['weight'] * (cur_pos['rtn'] + 1) + # 价格缺失用初始weight填充 + cur_pos.loc[pd.isnull(cur_pos['rtn']), 'end_weight'] = cur_pos.loc[pd.isnull(cur_pos['rtn']), 'weight'] position_record = cur_pos.copy() position_record['end_weight'] = self.update_next_weight(position_record) self.position_history[date] = position_record.copy()[position_fields]