diff --git a/spread_backtest.py b/spread_backtest.py index 86aaaa4..bf2c999 100644 --- a/spread_backtest.py +++ b/spread_backtest.py @@ -8,7 +8,7 @@ from rich import print as rprint from rich.table import Table -class Spread_Backtest(): +class SpreadBacktest(): def __init__( self, trader: Trader diff --git a/trader.py b/trader.py index 8f580cf..1bac6b2 100644 --- a/trader.py +++ b/trader.py @@ -36,6 +36,7 @@ class Trader(Account): exclude_list (list): 额外的剔除列表,会优先满足该剔除列表中的条件,之后再进行正常的调仓 - abnormal: 异常公告剔除,包含中止上市、立案调查、警示函等异常情况的剔除 - report: 财报同比下降50%以上剔除 + account (Account): 账户设置,account.Account """ def __init__(self, signal: Dict[str, pd.DataFrame]=None, @@ -57,9 +58,10 @@ class Trader(Account): '2023-08-28': (0, 0.0005) }, exclude_list: list=[], + account: dict={}, **kwargs) -> None: # 初始化账户 - super().__init__(**kwargs.get('account', {})) + super().__init__(account) if isinstance(signal, dict): self.signal = signal if 'close' in signal: @@ -71,6 +73,8 @@ class Trader(Account): # -------------------- # 参数检验 # -------------------- + if len(kwargs) > 0: + raise ValueError(f"Unexpected keyword argument '{','.join(kwargs.keys())}'") # interval self.interval = [] for s in signal: