import numpy as np
import pandas as pd
from typing import Union, Iterable


class DataLoader():
    """
    Data class: data loading module.

    Holds a table in ``self.data`` and answers per-code lookups of a single
    column through :meth:`get`.
    """

    def __init__(self, path: Union[str, pd.DataFrame] = None):
        """
        Load the backing table.

        - path: either a CSV file path (the file must contain a
          ``stock_code`` column, which becomes the sorted index) or an
          existing DataFrame, which is copied so later changes made by the
          caller do not leak into the loader. ``None`` yields an empty table.

        Raises ``TypeError`` for any other argument type.
        """
        if path is None:
            # Always define the attribute so a later `get` fails with a clear
            # KeyError on the column instead of an AttributeError.
            self.data = pd.DataFrame()
        elif isinstance(path, pd.DataFrame):
            self.data = path.copy()
        elif isinstance(path, str):
            # NOTE(review): read_csv infers the dtype of `stock_code`; purely
            # numeric codes would be parsed as integers - confirm with callers.
            self.data = pd.read_csv(path).set_index(['stock_code']).sort_index()
        else:
            raise TypeError(
                "path must be a str, a pandas DataFrame or None, "
                f"got {type(path).__name__}"
            )

    def get(self, target: Iterable[str] = (), column: str = '') -> pd.Series:
        """
        Look up one column for a collection of codes.

        - target: codes to query; any iterable (consumed exactly once).
          Repeated codes are allowed and yield repeated entries.
        - column(str): name of the column to read from ``self.data``.

        Returns an unnamed Series indexed by the requested codes and sorted
        by index. Codes that are absent from the table map to a missing
        value (NaN). Raises ``KeyError`` when ``column`` does not exist.
        """
        # Materialise once: the iterable may be a generator.
        requested = pd.Index(list(target))
        column_values = self.data[column]

        # Narrow to the requested codes first so that duplicate labels among
        # rows nobody asked for cannot make the reindex below fail.
        present = column_values[column_values.index.isin(requested)]

        # reindex aligns on the requested codes and fills absent ones with NaN.
        result = present.reindex(requested)
        result.name = None
        return result.sort_index()