diff --git a/i6_models/parts/frontend/generic_frontend.py b/i6_models/parts/frontend/generic_frontend.py index 07575194..4291e2ca 100644 --- a/i6_models/parts/frontend/generic_frontend.py +++ b/i6_models/parts/frontend/generic_frontend.py @@ -206,9 +206,9 @@ def forward(self, tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[to if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.MaxPool2d): sequence_mask = mask_pool( sequence_mask, - kernel_size=layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0], - stride=layer.stride if isinstance(layer.stride, int) else layer.stride[0], - padding=layer.padding if isinstance(layer.padding, int) else layer.padding[0], + kernel_size=layer.kernel_size[0], + stride=layer.stride[0], + padding=layer.padding[0], ) tensor = torch.transpose(tensor, 1, 2) # transpose to [B,T",C,F"]