From 5ea9ae46227b5acebc8768c96db9fef4bc9f4b14 Mon Sep 17 00:00:00 2001
From: Heinrich Dinkel
Date: Wed, 18 Sep 2024 01:32:10 -0500
Subject: [PATCH 1/3] Added support for variable length inputs larger than 10s

---
 dasheng_model/modeling_dasheng.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/dasheng_model/modeling_dasheng.py b/dasheng_model/modeling_dasheng.py
index cc83461..97bccae 100644
--- a/dasheng_model/modeling_dasheng.py
+++ b/dasheng_model/modeling_dasheng.py
@@ -367,7 +367,8 @@ def forward_features(self, x):
 
     def forward(self, x):
         x = self.init_bn(x) if self.init_bn is not None else x
-
+        # Remember starting position if we pad
+        padding_start_in_chunks = 0
         if x.shape[-1] > self.target_length:
             splits = x.split(self.target_length, -1)
 
@@ -375,6 +376,7 @@
             if self.pad_last:
                 pad = torch.zeros(*x.shape[:-1], self.target_length, device=x.device)
                 pad[..., : splits[-1].shape[-1]] = splits[-1]
+                padding_start_in_chunks = splits[-1].shape[-1] // self.patch_stride[-1]
                 splits = torch.stack((*splits[:-1], pad), dim=0)
             else:
                 splits = torch.stack(splits[:-1], dim=0)
@@ -387,6 +389,8 @@
 
         x = self.forward_features(x)
         x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
+        if padding_start_in_chunks:
+            x = x[:,:-padding_start_in_chunks, :]
 
         return x
 

From 7c59b4e64d9823b23b993df6cfb4fc42b227c0d8 Mon Sep 17 00:00:00 2001
From: Heinrich Dinkel
Date: Wed, 18 Sep 2024 05:13:38 -0500
Subject: [PATCH 2/3] Padding as default

---
 dasheng_model/modeling_dasheng.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dasheng_model/modeling_dasheng.py b/dasheng_model/modeling_dasheng.py
index 97bccae..6c90204 100644
--- a/dasheng_model/modeling_dasheng.py
+++ b/dasheng_model/modeling_dasheng.py
@@ -291,7 +291,7 @@ def __init__(
         self.eval_avg = eval_avg
         self.time_patch_out = time_patch_out
         self.freq_patch_out = freq_patch_out
-        self.pad_last = kwargs.get("pad_last", False)
+        self.pad_last = kwargs.get("pad_last", True)
 
         if init_bn:
             self.init_bn = nn.Sequential(

From e2d96c9de0b4d4dfeb76f3905d05aad9a86187a3 Mon Sep 17 00:00:00 2001
From: Heinrich Dinkel
Date: Wed, 18 Sep 2024 18:28:34 +0800
Subject: [PATCH 3/3] Added max length

---
 dasheng_model/modeling_dasheng.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/dasheng_model/modeling_dasheng.py b/dasheng_model/modeling_dasheng.py
index 6c90204..11204c4 100644
--- a/dasheng_model/modeling_dasheng.py
+++ b/dasheng_model/modeling_dasheng.py
@@ -368,7 +368,7 @@ def forward_features(self, x):
     def forward(self, x):
         x = self.init_bn(x) if self.init_bn is not None else x
         # Remember starting position if we pad
-        padding_start_in_chunks = 0
+        padding_start = 0
         if x.shape[-1] > self.target_length:
             splits = x.split(self.target_length, -1)
 
@@ -376,7 +376,7 @@
         if self.pad_last:
             pad = torch.zeros(*x.shape[:-1], self.target_length, device=x.device)
             pad[..., : splits[-1].shape[-1]] = splits[-1]
-            padding_start_in_chunks = splits[-1].shape[-1] // self.patch_stride[-1]
+            padding_start = x.shape[-1] // self.patch_stride[-1]
             splits = torch.stack((*splits[:-1], pad), dim=0)
         else:
             splits = torch.stack(splits[:-1], dim=0)
@@ -389,8 +389,8 @@
 
         x = self.forward_features(x)
         x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
-        if padding_start_in_chunks:
-            x = x[:,:-padding_start_in_chunks, :]
+        if padding_start:
+            x = x[:,:padding_start, :]
 
         return x
 