31 lines
1019 B
Python
31 lines
1019 B
Python
import numpy as np
|
|
import pandas as pd
|
|
from typing import Union, Iterable
|
|
|
|
class DataLoader():
    """
    Data class: data loading module.

    Keeps a table in ``self.data`` and answers single-column lookups keyed
    by the labels of that table's index.
    """

    def __init__(self, path: Union[str, pd.DataFrame, None] = None):
        """
        - path: either a CSV file path (str) or an existing DataFrame.
            * str: the file is read with ``pd.read_csv``, indexed by its
              ``stock_code`` column and sorted by that index.
            * DataFrame: a copy is stored as-is (index and order untouched),
              so later changes to the caller's frame do not leak in.
            * anything else (including the default ``None``): no data is
              loaded and ``self.data`` is ``None``.
        """
        # Always define the attribute so an empty loader is detectable
        # instead of failing later with a missing-attribute error.
        self.data = None
        # isinstance (rather than an exact type comparison) also accepts
        # subclasses of str / DataFrame.
        if isinstance(path, str):
            self.data = pd.read_csv(path).set_index(['stock_code']).sort_index()
        elif isinstance(path, pd.DataFrame):
            self.data = path.copy()

    def get(self,
            target: Union[Iterable[str], None] = None,
            column: str = ''):
        """
        Look up one column for a collection of target codes.

        - target: target codes to query (any iterable, including one-shot
          iterators; ``None`` means "no codes").
        - column(str): column to query.

        Returns an object-dtype ``pd.Series`` indexed by the requested codes
        and sorted by index. Codes that are absent from the loaded data are
        filled with ``None`` when the column holds strings (object/string
        dtype) and with ``np.nan`` otherwise.

        Raises ``AttributeError`` if no data has been loaded, ``KeyError`` if
        ``column`` does not exist, and ``ValueError`` if a requested code
        occurs more than once in the data index.
        """
        if self.data is None:
            # AttributeError is kept on purpose: it is what accessing the
            # never-assigned attribute used to raise.
            raise AttributeError(
                "DataLoader has no data loaded; construct it with a CSV path "
                "or a DataFrame before calling get()"
            )

        # Materialise once: the codes are used several times below, which
        # would silently exhaust a generator/iterator.
        codes = [] if target is None else list(target)
        wanted = pd.Index(codes)

        column_values = self.data[column]

        # A pandas column dtype never compares equal to the builtin ``str``;
        # string columns are stored as object (or a dedicated string dtype),
        # so ask pandas instead.
        holds_text = pd.api.types.is_string_dtype(column_values.dtype)
        placeholder = None if holds_text else np.nan

        # Build the values in a plain object array so the result dtype and
        # the None/NaN placeholder do not depend on pandas' setitem rules or
        # on the pandas version's default dtype for empty Series.
        values = np.empty(len(wanted), dtype=object)
        values[:] = placeholder

        # Boolean mask (aligned with ``wanted``) of codes present in the data.
        known = wanted.isin(column_values.index)
        if known.any():
            values[known] = column_values.loc[wanted[known]].to_numpy()

        return pd.Series(values, index=wanted).sort_index()