From ec3356fac4a5d701627799f274266080868ff96d Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Fri, 10 Nov 2023 15:23:06 +0100 Subject: [PATCH] Update i6_models/parts/frontend/generic_frontend.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- i6_models/parts/frontend/generic_frontend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/i6_models/parts/frontend/generic_frontend.py b/i6_models/parts/frontend/generic_frontend.py index 07575194..4291e2ca 100644 --- a/i6_models/parts/frontend/generic_frontend.py +++ b/i6_models/parts/frontend/generic_frontend.py @@ -206,9 +206,9 @@ def forward(self, tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[to if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.MaxPool2d): sequence_mask = mask_pool( sequence_mask, - kernel_size=layer.kernel_size if isinstance(layer.kernel_size, int) else layer.kernel_size[0], - stride=layer.stride if isinstance(layer.stride, int) else layer.stride[0], - padding=layer.padding if isinstance(layer.padding, int) else layer.padding[0], + kernel_size=layer.kernel_size[0], + stride=layer.stride[0], + padding=layer.padding[0], ) tensor = torch.transpose(tensor, 1, 2) # transpose to [B,T",C,F"]