diff --git a/i6_models/assemblies/lstm/lstm_v1.py b/i6_models/assemblies/lstm/lstm_v1.py index 015381ab..c382c33a 100644 --- a/i6_models/assemblies/lstm/lstm_v1.py +++ b/i6_models/assemblies/lstm/lstm_v1.py @@ -14,6 +14,16 @@ @dataclass class LstmEncoderV1Config(ModelConfiguration): + """ + :param init_args: used to initialize parameters of modules, example: + ``` + { + "init_args_w": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}}, + "init_args_b": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}}, + } + ``` + """ + input_dim: int embed_dim: int embed_dropout: float @@ -30,6 +40,13 @@ def from_dict(cls, model_cfg_dict: Dict[str, Any]): class LstmEncoderV1(nn.Module): def __init__(self, model_cfg: Union[LstmEncoderV1Config, Dict[str, Any]], **kwargs): + """ + Model definition of LSTM encoder. Contains embedding layer followed by single lstm stack, dropout after both. + Padding sequence in forward call. + + :param model_cfg: holds model configuration as dataclass or dict instance. + :param kwargs: + """ super().__init__() self.cfg = LstmEncoderV1Config.from_dict(model_cfg) if isinstance(model_cfg, dict) else model_cfg