From 7bd8e6ccc17db4bd8697fda807e791d8da710ca6 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 28 Nov 2023 10:59:24 -0500
Subject: [PATCH] better error message
---
i6_models/primitives/feature_extraction.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 7c0a7058..ccb2476e 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -128,7 +128,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
- raise ValueError("Invalid spectrum type.")
+ raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
@@ -145,5 +145,5 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
else:
- raise ValueError("Invalid spectrum type.")
+ raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
return feature_data, length.int()