From 0c7bb2f4d9be004ecc1ed7966e0ca9615a03dd11 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 06:24:13 -0500
Subject: [PATCH 1/9] extend feature extraction
---
i6_models/primitives/feature_extraction.py | 54 ++++++++++++++--------
1 file changed, 36 insertions(+), 18 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index ead52dd5..806a8570 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -1,7 +1,7 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Any, Dict
from librosa import filters
import torch
@@ -22,6 +22,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
+ periodic: whether the window is assumed to be periodic
+ mel_options: extra options for mel filters
+ rasr_compatible: apply FFT to make features compatible to RASR's
"""
sample_rate: int
@@ -33,6 +36,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
+ periodic: Optional[bool] = True
+ mel_options: Optional[Dict[str, Any]] = None
+ rasr_compatible: Optional[bool] = False
def __post_init__(self) -> None:
super().__post_init__()
@@ -62,6 +68,8 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
+ self.mel_options = cfg.mel_options or {}
+ self.rasr_compatible = cfg.rasr_compatible
self.register_buffer(
"mel_basis",
@@ -72,10 +80,11 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
+ **self.mel_options,
)
),
)
- self.register_buffer("window", torch.hann_window(self.win_length))
+ self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))
def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -83,25 +92,34 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- power_spectrum = (
- torch.abs(
- torch.stft(
- raw_audio,
- n_fft=self.n_fft,
- hop_length=self.hop_length,
- win_length=self.win_length,
- window=self.window,
- center=self.center,
- pad_mode="constant",
- return_complex=True,
+ if not self.rasr_compatible:
+ power_spectrogram = (
+ torch.abs(
+ torch.stft(
+ raw_audio,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=self.window,
+ center=self.center,
+ pad_mode="constant",
+ return_complex=True,
+ )
)
+ ** 2
)
- ** 2
- )
- if len(power_spectrum.size()) == 2:
+ else:
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
+ smoothed = windowed * self.window.unsqueeze(0)
+
+ # Compute power spectrogram using torch.fft.rfftn
+ power_spectrogram = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
+ power_spectrogram = power_spectrogram.transpose(1, 2) # [B, T, F]
+
+ if len(power_spectrogram.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
- power_spectrum = torch.unsqueeze(power_spectrum, 0)
- melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
+ power_spectrogram = torch.unsqueeze(power_spectrogram, 0)
+ melspec = torch.einsum("...ft,mf->...mt", power_spectrogram, self.mel_basis)
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2)
From 91e8e0bdd8ccba06d3296e286a5f54a80815b30c Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 06:28:25 -0500
Subject: [PATCH 2/9] restore power_spectrum variable name
---
i6_models/primitives/feature_extraction.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 806a8570..6df10574 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -93,7 +93,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:return features as [B,T,F] and length in frames [B]
"""
if not self.rasr_compatible:
- power_spectrogram = (
+ power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
@@ -112,14 +112,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
smoothed = windowed * self.window.unsqueeze(0)
- # Compute power spectrogram using torch.fft.rfftn
- power_spectrogram = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
- power_spectrogram = power_spectrogram.transpose(1, 2) # [B, T, F]
+ # Compute power spectrum using torch.fft.rfftn
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
- if len(power_spectrogram.size()) == 2:
+ if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
- power_spectrogram = torch.unsqueeze(power_spectrogram, 0)
- melspec = torch.einsum("...ft,mf->...mt", power_spectrogram, self.mel_basis)
+ power_spectrum = torch.unsqueeze(power_spectrum, 0)
+ melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2)
From 1559a09df0776642bbde7baffc1dc97baa2a70f1 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 09:16:50 -0500
Subject: [PATCH 3/9] reorder branches: handle rasr_compatible case first
---
i6_models/primitives/feature_extraction.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 6df10574..dadf6e41 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -92,7 +92,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- if not self.rasr_compatible:
+ if self.rasr_compatible:
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
+ smoothed = windowed * self.window.unsqueeze(0)
+
+ # Compute power spectrum using torch.fft.rfftn
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
+ else:
power_spectrum = (
torch.abs(
torch.stft(
@@ -108,13 +115,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
** 2
)
- else:
- windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
- smoothed = windowed * self.window.unsqueeze(0)
-
- # Compute power spectrum using torch.fft.rfftn
- power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
- power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
From e3101cf89968c42b393cbe0913debe6a3ec3db1c Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 12:07:16 -0500
Subject: [PATCH 4/9] fix type annotations, add shape annotations
---
i6_models/primitives/feature_extraction.py | 24 +++++++++++-----------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index dadf6e41..824267c3 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -24,7 +24,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
mel_options: extra options for mel filters
- rasr_compatible: apply FFT to make features compatible to RASR's
+ rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (defalut) apply STFT
"""
sample_rate: int
@@ -36,9 +36,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
- periodic: Optional[bool] = True
+ periodic: bool = True
mel_options: Optional[Dict[str, Any]] = None
- rasr_compatible: Optional[bool] = False
+ rasr_compatible: bool = False
def __post_init__(self) -> None:
super().__post_init__()
@@ -93,12 +93,12 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:return features as [B,T,F] and length in frames [B]
"""
if self.rasr_compatible:
- windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
- smoothed = windowed * self.window.unsqueeze(0)
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
+ smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
# Compute power spectrum using torch.fft.rfftn
- power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
- power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
power_spectrum = (
torch.abs(
@@ -118,14 +118,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
- power_spectrum = torch.unsqueeze(power_spectrum, 0)
- melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
+ power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T']
+ melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
- feature_data = torch.transpose(log_melspec, 1, 2)
+ feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- if self.center:
+ if self.center and not rasr_compatible:
length = (length // self.hop_length) + 1
else:
- length = ((length - self.n_fft) // self.hop_length) + 1
+ length = ((length - self.win_length) // self.hop_length) + 1
return feature_data, length.int()
From cf72c0a3b2f1981390b6c13aab69e61ac76b271e Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 12:12:48 -0500
Subject: [PATCH 5/9] fix missing self. on rasr_compatible in length computation
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 824267c3..ca716b37 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -123,7 +123,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- if self.center and not rasr_compatible:
+ if self.center and not self.rasr_compatible:
length = (length // self.hop_length) + 1
else:
length = ((length - self.win_length) // self.hop_length) + 1
From 221a81389d8fb6800482820d59740b3b755eb320 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 12:36:35 -0500
Subject: [PATCH 6/9] fix typo in rasr_compatible docstring (defalut -> default)
---
i6_models/primitives/feature_extraction.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index ca716b37..296068ad 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -24,7 +24,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
mel_options: extra options for mel filters
- rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (defalut) apply STFT
+ rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT
"""
sample_rate: int
From cfb3f5661c9d9941c43314ec7b6f482b80fb0935 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Fri, 24 Nov 2023 10:55:36 -0500
Subject: [PATCH 7/9] pass mel filter parameters directly & make spectrum type
enum & fix length computation
---
i6_models/primitives/feature_extraction.py | 55 ++++++++++++++--------
1 file changed, 35 insertions(+), 20 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 296068ad..c1042030 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -1,15 +1,23 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]
from dataclasses import dataclass
-from typing import Optional, Tuple, Any, Dict
+from typing import Optional, Tuple, Union, Literal
+from enum import Enum
from librosa import filters
import torch
from torch import nn
+import numpy as np
+from numpy.typing import DTypeLike
from i6_models.config import ModelConfiguration
+class SpectrumType(Enum):
+ STFT = 1
+ RFFTN = 2
+
+
@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
@@ -23,8 +31,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
- mel_options: extra options for mel filters
- rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT
+ htk: whether use HTK formula instead of Slaney
+ norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
+ spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
"""
sample_rate: int
@@ -37,8 +46,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
center: bool
n_fft: Optional[int] = None
periodic: bool = True
- mel_options: Optional[Dict[str, Any]] = None
- rasr_compatible: bool = False
+ htk: bool = False
+ norm: Optional[Union[Literal["slaney"], float]] = "slaney"
+ dtype: DTypeLike = np.float32
+ spectrum_type: SpectrumType = SpectrumType.STFT
def __post_init__(self) -> None:
super().__post_init__()
@@ -68,8 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
- self.mel_options = cfg.mel_options or {}
- self.rasr_compatible = cfg.rasr_compatible
+ self.spectrum_type = cfg.spectrum_type
self.register_buffer(
"mel_basis",
@@ -80,8 +90,10 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
- **self.mel_options,
- )
+ htk=cfg.htk,
+ norm=cfg.norm,
+ dtype=cfg.dtype,
+ ),
),
)
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))
@@ -92,14 +104,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- if self.rasr_compatible:
- windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
- smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
-
- # Compute power spectrum using torch.fft.rfftn
- power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
- power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
- else:
+ if self.spectrum_type == SpectrumType.STFT:
power_spectrum = (
torch.abs(
torch.stft(
@@ -115,6 +120,13 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
** 2
)
+ elif self.spectrum_type == SpectrumType.RFFTN:
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
+ smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]
+
+ # Compute power spectrum using torch.fft.rfftn
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
@@ -123,9 +135,12 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
- if self.center and not self.rasr_compatible:
- length = (length // self.hop_length) + 1
- else:
+ if self.spectrum_type == SpectrumType.STFT:
+ if self.center:
+ length = (length // self.hop_length) + 1
+ else:
+ length = ((length - self.n_fft) // self.hop_length) + 1
+ elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
return feature_data, length.int()
From f1f8081ae7401f3009d377f2ca068f1ebb346061 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 28 Nov 2023 10:34:50 -0500
Subject: [PATCH 8/9] add else branches
---
i6_models/primitives/feature_extraction.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index c1042030..7c0a7058 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -127,6 +127,8 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
+ else:
+ raise ValueError("Invalid spectrum type.")
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
@@ -142,5 +144,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
length = ((length - self.n_fft) // self.hop_length) + 1
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
-
+ else:
+ raise ValueError("Invalid spectrum type.")
return feature_data, length.int()
From 7bd8e6ccc17db4bd8697fda807e791d8da710ca6 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Tue, 28 Nov 2023 10:59:24 -0500
Subject: [PATCH 9/9] better error message
---
i6_models/primitives/feature_extraction.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 7c0a7058..ccb2476e 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -128,7 +128,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
- raise ValueError("Invalid spectrum type.")
+ raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
@@ -145,5 +145,5 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
else:
- raise ValueError("Invalid spectrum type.")
+ raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
return feature_data, length.int()