Skip to content

Commit

Permalink
add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
christophmluscher committed Dec 19, 2024
1 parent be231ac commit 531b2f9
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions i6_models/assemblies/lstm/lstm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

@dataclass
class LstmEncoderV1Config(ModelConfiguration):
"""
:param init_args: used to initialize parameters of modules, example:
```
{
"init_args_w": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
"init_args_b": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
}
```
"""

input_dim: int
embed_dim: int
embed_dropout: float
Expand All @@ -30,6 +40,13 @@ def from_dict(cls, model_cfg_dict: Dict[str, Any]):

class LstmEncoderV1(nn.Module):
def __init__(self, model_cfg: Union[LstmEncoderV1Config, Dict[str, Any]], **kwargs):
"""
Model definition of LSTM encoder. Contains embedding layer followed by single lstm stack, dropout after both.
Padding sequence in forward call.
:param model_cfg: holds model configuration as dataclass or dict instance.
:param kwargs:
"""
super().__init__()

self.cfg = LstmEncoderV1Config.from_dict(model_cfg) if isinstance(model_cfg, dict) else model_cfg
Expand Down

0 comments on commit 531b2f9

Please sign in to comment.