Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Zeyer <[email protected]>
  • Loading branch information
christophmluscher and albertz authored Dec 19, 2024
1 parent 6bf9e2e commit 4b6e4ef
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions i6_models/parts/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ class LstmBlockV1Config(ModelConfiguration):
enforce_sorted: bool

@classmethod
def from_dict(cls, model_cfg_dict: Dict):
def from_dict(cls, model_cfg_dict: Dict[str, Any]):
model_cfg_dict = model_cfg_dict.copy()
return cls(**model_cfg_dict)


class LstmBlockV1(nn.Module):
def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs):
def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict[str, Any]], **kwargs):
"""
Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call.
Expand All @@ -34,7 +34,7 @@ def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs):
"""
super().__init__()

self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg
self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, dict) else model_cfg

self.dropout = self.cfg.dropout
self.enforce_sorted = self.cgf.enforce_sorted
Expand Down

0 comments on commit 4b6e4ef

Please sign in to comment.