diff --git a/i6_models/parts/lstm.py b/i6_models/parts/lstm.py index 898c0755..b1109fe1 100644 --- a/i6_models/parts/lstm.py +++ b/i6_models/parts/lstm.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch from torch import nn -from typing import Dict, Union +from typing import Dict, Tuple, Union from i6_models.config import ModelConfiguration @@ -26,12 +26,18 @@ def from_dict(cls, model_cfg_dict: Dict): class LstmBlockV1(nn.Module): def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs): + """ + Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call. + + :param model_cfg: holds model configuration as dataclass or dict instance. + :param kwargs: + """ super().__init__() self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg self.dropout = self.cfg.dropout - self.enforce_sorted = None + self.enforce_sorted = self.cgf.enforce_sorted self.lstm_stack = nn.LSTM( input_size=self.cfg.input_dim, hidden_size=self.cfg.hidden_dim,