From 4b6e4ef6496c1545e359f8c0242335505b73396f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?= Date: Thu, 19 Dec 2024 14:17:26 +0100 Subject: [PATCH] typing Co-authored-by: Albert Zeyer --- i6_models/parts/lstm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/i6_models/parts/lstm.py b/i6_models/parts/lstm.py index b1109fe1..4d0b3575 100644 --- a/i6_models/parts/lstm.py +++ b/i6_models/parts/lstm.py @@ -19,13 +19,13 @@ class LstmBlockV1Config(ModelConfiguration): enforce_sorted: bool @classmethod - def from_dict(cls, model_cfg_dict: Dict): + def from_dict(cls, model_cfg_dict: Dict[str, Any]): model_cfg_dict = model_cfg_dict.copy() return cls(**model_cfg_dict) class LstmBlockV1(nn.Module): - def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs): + def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict[str, Any]], **kwargs): """ Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call. @@ -34,7 +34,7 @@ def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs): """ super().__init__() - self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg + self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, dict) else model_cfg self.dropout = self.cfg.dropout self.enforce_sorted = self.cgf.enforce_sorted