Skip to content

Commit

Permalink
Merge pull request #2 from RicherMans/main
Browse files Browse the repository at this point in the history
Added variable length inference support
  • Loading branch information
jimbozhang authored Sep 19, 2024
2 parents 5af488c + e2d96c9 commit 58eb3fb
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions dasheng_model/modeling_dasheng.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def __init__(
self.eval_avg = eval_avg
self.time_patch_out = time_patch_out
self.freq_patch_out = freq_patch_out
self.pad_last = kwargs.get("pad_last", False)
self.pad_last = kwargs.get("pad_last", True)

if init_bn:
self.init_bn = nn.Sequential(
Expand Down Expand Up @@ -367,14 +367,16 @@ def forward_features(self, x):

def forward(self, x):
x = self.init_bn(x) if self.init_bn is not None else x

# Remember starting position if we pad
padding_start = 0
if x.shape[-1] > self.target_length:
splits = x.split(self.target_length, -1)

if splits[-1].shape[-1] < self.target_length:
if self.pad_last:
pad = torch.zeros(*x.shape[:-1], self.target_length, device=x.device)
pad[..., : splits[-1].shape[-1]] = splits[-1]
padding_start = x.shape[-1] // self.patch_stride[-1]
splits = torch.stack((*splits[:-1], pad), dim=0)
else:
splits = torch.stack(splits[:-1], dim=0)
Expand All @@ -387,6 +389,8 @@ def forward(self, x):

x = self.forward_features(x)
x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
if padding_start:
x = x[:,:padding_start, :]
return x


Expand Down

0 comments on commit 58eb3fb

Please sign in to comment.