Skip to content

Commit

Permalink
add import, set var correctly, add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
christophmluscher committed Dec 19, 2024
1 parent 624f6c1 commit 6bf9e2e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions i6_models/parts/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
import torch
from torch import nn
from typing import Dict, Union
from typing import Dict, Tuple, Union

from i6_models.config import ModelConfiguration

Expand All @@ -26,12 +26,18 @@ def from_dict(cls, model_cfg_dict: Dict):

class LstmBlockV1(nn.Module):
def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs):
"""
Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call.
:param model_cfg: holds model configuration as dataclass or dict instance.
:param kwargs:
"""
super().__init__()

self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg

self.dropout = self.cfg.dropout
self.enforce_sorted = None
self.enforce_sorted = self.cgf.enforce_sorted
self.lstm_stack = nn.LSTM(
input_size=self.cfg.input_dim,
hidden_size=self.cfg.hidden_dim,
Expand Down

0 comments on commit 6bf9e2e

Please sign in to comment.