spread_backtest/dataloader.py

31 lines
1019 B
Python
Raw Permalink Normal View History

2024-05-22 23:33:19 +08:00
import numpy as np
import pandas as pd
from typing import Union, Iterable
2024-05-22 23:33:19 +08:00
class DataLoader:
    """Data class: data-loading module.

    Holds a table in ``self.data`` and answers per-code column lookups
    through :meth:`get`.
    """

    def __init__(self, path: Union[str, pd.DataFrame, None] = None):
        """Load the backing table.

        - path: either a CSV file path, which is read, indexed by its
          ``stock_code`` column and sorted by that index; or a DataFrame,
          which is copied as-is (presumably already indexed by stock code —
          TODO confirm with callers).  For any other value, including the
          default ``None``, ``self.data`` is left unset.
        """
        # isinstance (not exact type equality) so str / DataFrame subclasses,
        # e.g. numpy.str_ paths, are accepted instead of silently ignored.
        if isinstance(path, str):
            self.data = pd.read_csv(path).set_index(['stock_code']).sort_index()
        elif isinstance(path, pd.DataFrame):
            self.data = path.copy()

    def get(self,
            target: Iterable[str] = (),
            column: str = '') -> pd.Series:
        """Look up ``column`` for every code in ``target``.

        - target: codes to query; any iterable (list, tuple, set, generator).
        - column(str): name of the column to query.

        Returns a Series indexed by the requested codes (sorted).  Codes not
        present in ``self.data`` get ``None`` when the column holds text
        (object / string dtype) and NaN otherwise.  Requires the codes in
        ``self.data``'s index to be unique (``reindex`` raises otherwise).
        """
        # Materialize once: the codes are used more than once below, which
        # would silently exhaust a generator.
        targets = list(target)
        source = self.data[column]
        # reindex aligns values by code; codes absent from the data become NaN.
        # It also tolerates duplicate codes in `targets`.
        aligned = source.reindex(targets)
        # pandas stores text as object dtype (or StringDtype), never as `str`,
        # so the dtype must be tested via the api.types helpers.
        is_text = (pd.api.types.is_object_dtype(source.dtype)
                   or pd.api.types.is_string_dtype(source.dtype))
        if is_text:
            values = aligned.to_numpy(dtype=object)
            # Only codes missing from the data become None; NaNs that are
            # genuinely stored in the data are left untouched.
            values[~aligned.index.isin(source.index)] = None
            dtype = object
        else:
            values = aligned.to_numpy()
            dtype = None
        # Rebuild an unnamed Series so index/Series names do not leak from
        # self.data.
        return pd.Series(values, index=targets, dtype=dtype).sort_index()