From 0c7bb2f4d9be004ecc1ed7966e0ca9615a03dd11 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 06:24:13 -0500 Subject: [PATCH 1/9] extend feature extraction --- i6_models/primitives/feature_extraction.py | 54 ++++++++++++++-------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ead52dd5..806a8570 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -1,7 +1,7 @@ __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Any, Dict from librosa import filters import torch @@ -22,6 +22,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): min_amp: minimum amplitude for safe log num_filters: number of mel windows center: centered STFT with automatic padding + periodic: whether the window is assumed to be periodic + mel_options: extra options for mel filters + rasr_compatible: apply FFT to make features compatible to RASR's """ sample_rate: int @@ -33,6 +36,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): num_filters: int center: bool n_fft: Optional[int] = None + periodic: Optional[bool] = True + mel_options: Optional[Dict[str, Any]] = None + rasr_compatible: Optional[bool] = False def __post_init__(self) -> None: super().__post_init__() @@ -62,6 +68,8 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): self.min_amp = cfg.min_amp self.n_fft = cfg.n_fft self.win_length = int(cfg.win_size * cfg.sample_rate) + self.mel_options = cfg.mel_options or {} + self.rasr_compatible = cfg.rasr_compatible self.register_buffer( "mel_basis", @@ -72,10 +80,11 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, + **self.mel_options, ) ), ) - self.register_buffer("window", torch.hann_window(self.win_length)) + self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic)) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -83,25 +92,34 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - power_spectrum = ( - torch.abs( - torch.stft( - raw_audio, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="constant", - return_complex=True, + if not self.rasr_compatible: + power_spectrogram = ( + torch.abs( + torch.stft( + raw_audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="constant", + return_complex=True, + ) ) + ** 2 ) - ** 2 - ) - if len(power_spectrum.size()) == 2: + else: + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) + smoothed = windowed * self.window.unsqueeze(0) + + # Compute power spectrogram using torch.fft.rfftn + power_spectrogram = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] + power_spectrogram = power_spectrogram.transpose(1, 2) # [B, T, F] + + if len(power_spectrogram.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again - power_spectrum = torch.unsqueeze(power_spectrum, 0) - melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) + power_spectrogram = torch.unsqueeze(power_spectrogram, 0) + melspec = torch.einsum("...ft,mf->...mt", power_spectrogram, self.mel_basis) log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) From 91e8e0bdd8ccba06d3296e286a5f54a80815b30c Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 06:28:25 -0500 Subject: [PATCH 2/9] restore names --- i6_models/primitives/feature_extraction.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 806a8570..6df10574 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -93,7 +93,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :return features as [B,T,F] and length in frames [B] """ if not self.rasr_compatible: - power_spectrogram = ( + power_spectrum = ( torch.abs( torch.stft( raw_audio, @@ -112,14 +112,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) smoothed = windowed * self.window.unsqueeze(0) - # Compute power spectrogram using torch.fft.rfftn - power_spectrogram = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] - power_spectrogram = power_spectrogram.transpose(1, 2) # [B, T, F] + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] + power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] - if len(power_spectrogram.size()) == 2: + if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again - power_spectrogram = torch.unsqueeze(power_spectrogram, 0) - melspec = torch.einsum("...ft,mf->...mt", power_spectrogram, self.mel_basis) + power_spectrum = torch.unsqueeze(power_spectrum, 0) + melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) From 1559a09df0776642bbde7baffc1dc97baa2a70f1 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 09:16:50 -0500 Subject: [PATCH 3/9] reorder --- i6_models/primitives/feature_extraction.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 6df10574..dadf6e41 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -92,7 +92,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - if not self.rasr_compatible: + if self.rasr_compatible: + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) + smoothed = windowed * self.window.unsqueeze(0) + + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] + power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] + else: power_spectrum = ( torch.abs( torch.stft( @@ -108,13 +115,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: ) ** 2 ) - else: - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) - smoothed = windowed * self.window.unsqueeze(0) - - # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] - power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again From e3101cf89968c42b393cbe0913debe6a3ec3db1c Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 12:07:16 -0500 Subject: [PATCH 4/9] fix type annotations, add shape annotations --- i6_models/primitives/feature_extraction.py | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index dadf6e41..824267c3 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -24,7 +24,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): center: centered STFT with automatic padding periodic: whether the window is assumed to be periodic mel_options: extra options for mel filters - rasr_compatible: apply FFT to make features compatible to RASR's + rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (defalut) apply STFT """ sample_rate: int @@ -36,9 +36,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): num_filters: int center: bool n_fft: Optional[int] = None - periodic: Optional[bool] = True + periodic: bool = True mel_options: Optional[Dict[str, Any]] = None - rasr_compatible: Optional[bool] = False + rasr_compatible: bool = False def __post_init__(self) -> None: super().__post_init__() @@ -93,12 +93,12 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :return features as [B,T,F] and length in frames [B] """ if self.rasr_compatible: - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) - smoothed = windowed * self.window.unsqueeze(0) + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] + smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] - power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] + power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] else: power_spectrum = ( torch.abs( @@ -118,14 +118,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again - power_spectrum = torch.unsqueeze(power_spectrum, 0) - melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) + power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T'] + melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T'] log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) - feature_data = torch.transpose(log_melspec, 1, 2) + feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - if self.center: + if self.center and not rasr_compatible: length = (length // self.hop_length) + 1 else: - length = ((length - self.n_fft) // self.hop_length) + 1 + length = ((length - self.win_length) // self.hop_length) + 1 return feature_data, length.int() From cf72c0a3b2f1981390b6c13aab69e61ac76b271e Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 12:12:48 -0500 Subject: [PATCH 5/9] fix --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 824267c3..ca716b37 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -123,7 +123,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - if self.center and not rasr_compatible: + if self.center and not self.rasr_compatible: length = (length // self.hop_length) + 1 else: length = ((length - self.win_length) // self.hop_length) + 1 From 221a81389d8fb6800482820d59740b3b755eb320 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 12:36:35 -0500 Subject: [PATCH 6/9] fix typo --- i6_models/primitives/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ca716b37..296068ad 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -24,7 +24,7 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): center: centered STFT with automatic padding periodic: whether the window is assumed to be periodic mel_options: extra options for mel filters - rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (defalut) apply STFT + rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT """ sample_rate: int From cfb3f5661c9d9941c43314ec7b6f482b80fb0935 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Fri, 24 Nov 2023 10:55:36 -0500 Subject: [PATCH 7/9] pass mel filter parameters directly & make spectrum type enum & fix length computation --- i6_models/primitives/feature_extraction.py | 55 ++++++++++++++-------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 296068ad..c1042030 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -1,15 +1,23 @@ __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] from dataclasses import dataclass -from typing import Optional, Tuple, Any, Dict +from typing import Optional, Tuple, Union, Literal +from enum import Enum from librosa import filters import torch from torch import nn +import numpy as np +from numpy.typing import DTypeLike from i6_models.config import ModelConfiguration +class SpectrumType(Enum): + STFT = 1 + RFFTN = 2 + + @dataclass class LogMelFeatureExtractionV1Config(ModelConfiguration): """ @@ -23,8 +31,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): num_filters: number of mel windows center: centered STFT with automatic padding periodic: whether the window is assumed to be periodic - mel_options: extra options for mel filters - rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT + htk: whether use HTK formula instead of Slaney + norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html + spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's """ sample_rate: int @@ -37,8 +46,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): center: bool n_fft: Optional[int] = None periodic: bool = True - mel_options: Optional[Dict[str, Any]] = None - rasr_compatible: bool = False + htk: bool = False + norm: Optional[Union[Literal["slaney"], float]] = "slaney" + dtype: DTypeLike = np.float32 + spectrum_type: SpectrumType = SpectrumType.STFT def __post_init__(self) -> None: super().__post_init__() @@ -68,8 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): self.min_amp = cfg.min_amp self.n_fft = cfg.n_fft self.win_length = int(cfg.win_size * cfg.sample_rate) - self.mel_options = cfg.mel_options or {} - self.rasr_compatible = cfg.rasr_compatible + self.spectrum_type = cfg.spectrum_type self.register_buffer( "mel_basis", @@ -80,8 +90,10 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, - **self.mel_options, - ) + htk=cfg.htk, + norm=cfg.norm, + dtype=cfg.dtype, + ), ), ) self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic)) @@ -92,14 +104,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - if self.rasr_compatible: - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] - smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] - - # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] - power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] - else: + if self.spectrum_type == SpectrumType.STFT: power_spectrum = ( torch.abs( torch.stft( @@ -115,6 +120,13 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: ) ** 2 ) + elif self.spectrum_type == SpectrumType.RFFTN: + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] + smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] + + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] + power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again @@ -123,9 +135,12 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - if self.center and not self.rasr_compatible: - length = (length // self.hop_length) + 1 - else: + if self.spectrum_type == SpectrumType.STFT: + if self.center: + length = (length // self.hop_length) + 1 + else: + length = ((length - self.n_fft) // self.hop_length) + 1 + elif self.spectrum_type == SpectrumType.RFFTN: length = ((length - self.win_length) // self.hop_length) + 1 return feature_data, length.int() From f1f8081ae7401f3009d377f2ca068f1ebb346061 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Tue, 28 Nov 2023 10:34:50 -0500 Subject: [PATCH 8/9] add else branches --- i6_models/primitives/feature_extraction.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index c1042030..7c0a7058 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -127,6 +127,8 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: # Compute power spectrum using torch.fft.rfftn power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] + else: + raise ValueError("Invalid spectrum type.") if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again @@ -142,5 +144,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: length = ((length - self.n_fft) // self.hop_length) + 1 elif self.spectrum_type == SpectrumType.RFFTN: length = ((length - self.win_length) // self.hop_length) + 1 - + else: + raise ValueError("Invalid spectrum type.") return feature_data, length.int() From 7bd8e6ccc17db4bd8697fda807e791d8da710ca6 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Tue, 28 Nov 2023 10:59:24 -0500 Subject: [PATCH 9/9] better error message --- i6_models/primitives/feature_extraction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 7c0a7058..ccb2476e 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -128,7 +128,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] else: - raise ValueError("Invalid spectrum type.") + raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again @@ -145,5 +145,5 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: elif self.spectrum_type == SpectrumType.RFFTN: length = ((length - self.win_length) // self.hop_length) + 1 else: - raise ValueError("Invalid spectrum type.") + raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") return feature_data, length.int()