diff --git a/dasheng_model/modeling_dasheng.py b/dasheng_model/modeling_dasheng.py
index cc83461..11204c4 100644
--- a/dasheng_model/modeling_dasheng.py
+++ b/dasheng_model/modeling_dasheng.py
@@ -291,7 +291,7 @@ def __init__(
         self.eval_avg = eval_avg
         self.time_patch_out = time_patch_out
         self.freq_patch_out = freq_patch_out
-        self.pad_last = kwargs.get("pad_last", False)
+        self.pad_last = kwargs.get("pad_last", True)
 
         if init_bn:
             self.init_bn = nn.Sequential(
@@ -367,7 +367,8 @@ def forward_features(self, x):
 
     def forward(self, x):
         x = self.init_bn(x) if self.init_bn is not None else x
-
+        # Remember starting position if we pad
+        padding_start = 0
         if x.shape[-1] > self.target_length:
             splits = x.split(self.target_length, -1)
 
@@ -375,6 +376,7 @@ def forward(self, x):
             if self.pad_last:
                 pad = torch.zeros(*x.shape[:-1], self.target_length, device=x.device)
                 pad[..., : splits[-1].shape[-1]] = splits[-1]
+                padding_start = x.shape[-1] // self.patch_stride[-1]
                 splits = torch.stack((*splits[:-1], pad), dim=0)
             else:
                 splits = torch.stack(splits[:-1], dim=0)
@@ -387,6 +389,8 @@ def forward(self, x):
 
         x = self.forward_features(x)
         x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
+        if padding_start:
+            x = x[:,:padding_start, :]
         return x
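
For context, a minimal standalone sketch of the pad-then-trim behavior this patch enables: zero-pad the last time chunk to `target_length`, record how many output frames correspond to real input (`input_frames // patch_stride`), and slice the concatenated features back to that length. The shapes, `target_length`, `patch_stride`, and embedding size below are hypothetical stand-ins, not the actual DashengModel configuration or API.

```python
import torch

# Hypothetical values for illustration only.
target_length = 1008          # frames per chunk the encoder expects
patch_stride = 4              # time stride of the patch embedding
x = torch.randn(2, 64, 2500)  # (batch, mel_bins, frames); 2500 is not a multiple of 1008

padding_start = 0
if x.shape[-1] > target_length:
    splits = list(x.split(target_length, -1))
    if splits[-1].shape[-1] < target_length:
        # Zero-pad the last (short) chunk instead of dropping it ...
        pad = torch.zeros(*x.shape[:-1], target_length)
        pad[..., : splits[-1].shape[-1]] = splits[-1]
        splits[-1] = pad
        # ... and remember how many output frames come from real input.
        padding_start = x.shape[-1] // patch_stride

# Stand-in for forward_features: one embedding per patch-strided frame of each chunk.
chunks = torch.stack(splits, dim=0)                 # (n_splits, batch, mel, target_length)
n_splits, batch = chunks.shape[0], chunks.shape[1]
feats = torch.randn(n_splits * batch, target_length // patch_stride, 768)
feats = feats.reshape(batch, -1, feats.shape[-1])   # regroup to (batch, frames, dim)

if padding_start:
    feats = feats[:, :padding_start, :]             # drop embeddings of padded frames

print(feats.shape)  # torch.Size([2, 625, 768]), since 2500 // 4 == 625
```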