feat: integrate leaf-audio into pipeline
Showing 20 changed files with 4,870 additions and 13 deletions.
@@ -0,0 +1,35 @@
frontend:
  name: leaf
  default_args: True
  use_legacy_complex: True
model:
  arch: efficientnet
  num_classes: 133
  model_depth: b0
  pool: avgpool
  type: multiclass
opt:
  optimizer: Adam
  lr: 1e-3
  momentum: 0.9
  scheduler: warmupcosine
  warmup_epochs: 10
  weight_decay: 1e-4
  batch_size: 32
audio_config:
  feature: raw
  normalize: False
  sample_rate: 32000
  min_duration: 1
  random_clip_size: 1
  val_clip_size: 1
  mixup: False
data:
  meta_root: "132_peru_xc_BC_2020_meta"
  is_lmdb: False
  in_memory: False
  train_manifest: "train.csv"
  val_manifest: "test.csv"
  test_manifest: None
  label_map: species_labels.json
  cw: cw_2.pth
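For orientation, here is a minimal sketch of how a config like this might be consumed, assuming it is parsed with PyYAML; the import path and filename are hypothetical, but the 'frontend' and 'audio_config' keys match what get_frontend (added below) reads:

import yaml  # assumes PyYAML is available

from pyha_analyzer.frontend_helper import get_frontend  # hypothetical import path

# Parse the experiment config into the nested dict whose 'frontend' and
# 'audio_config' sections get_frontend reads. The filename is hypothetical.
with open("configs/leaf.yaml") as f:
    opt = yaml.safe_load(f)

frontend = get_frontend(opt)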
@@ -0,0 +1 @@
from .frontend_helper import get_frontend
@@ -0,0 +1,101 @@ | ||
import torch | ||
import math | ||
from typing import Tuple, Callable | ||
from torch import nn | ||
from pyha_analyzer.leaf_pytorch.initializers import GaborInit | ||
from pyha_analyzer.leaf_pytorch.impulse_responses import gabor_filters | ||
from pyha_analyzer.leaf_pytorch.utils import get_padding_value | ||
|
||
|
||
class GaborConstraint(nn.Module): | ||
def __init__(self, kernel_size): | ||
super(GaborConstraint, self).__init__() | ||
self._kernel_size = kernel_size | ||
|
||
def forward(self, kernel_data): | ||
mu_lower = 0. | ||
mu_upper = math.pi | ||
sigma_lower = 4 * torch.sqrt(2. * torch.log(torch.tensor(2., device=kernel_data.device))) / math.pi | ||
sigma_upper = self._kernel_size * torch.sqrt(2. * torch.log(torch.tensor(2., device=kernel_data.device))) / math.pi | ||
clipped_mu = torch.clamp(kernel_data[:, 0], mu_lower, mu_upper).unsqueeze(1) | ||
clipped_sigma = torch.clamp(kernel_data[:, 1], sigma_lower, sigma_upper).unsqueeze(1) | ||
return torch.cat([clipped_mu, clipped_sigma], dim=-1) | ||
|
||
|
||
class GaborConv1d(nn.Module): | ||
def __init__(self, filters, kernel_size, | ||
strides, padding, | ||
initializer=None, | ||
use_bias=False, | ||
sort_filters=False, | ||
use_legacy_complex=False): | ||
super(GaborConv1d, self).__init__() | ||
self._filters = filters // 2 | ||
self._kernel_size = kernel_size | ||
self._strides = strides | ||
self._padding = padding | ||
self._use_bias = use_bias | ||
self._sort_filters = sort_filters | ||
# initializer = override_initializer | ||
# else: | ||
|
||
# initializer = GaborInit(self._filters, default_window_len=self._kernel_size, | ||
# sample_rate=16000, min_freq=60.0, max_freq=7800.0) | ||
if isinstance(initializer, Callable): | ||
init_weights = initializer((self._filters, 2)) | ||
elif initializer == "random": | ||
init_weights = torch.randn(self._filters, 2) | ||
elif initializer == "xavier_normal": | ||
print("Using xavier_normal init scheme..") | ||
init_weights = torch.randn(self._filters, 2) | ||
init_weights = torch.nn.init.xavier_normal_(init_weights) | ||
elif initializer == "kaiming_normal": | ||
init_weights = torch.randn(self._filters, 2) | ||
init_weights = torch.nn.init.kaiming_normal_(init_weights) | ||
else: | ||
raise ValueError("unsupported initializer") | ||
self.constraint = GaborConstraint(self._kernel_size) | ||
self._kernel = nn.Parameter(init_weights) | ||
if self._padding.lower() == "same": | ||
self._pad_value = get_padding_value(self._kernel_size) | ||
else: | ||
self._pad_value = self._padding | ||
if self._use_bias: | ||
self._bias = torch.nn.Parameter(torch.ones(self._filters*2,)) | ||
else: | ||
self._bias = None | ||
self.use_legacy_complex = use_legacy_complex | ||
if self.use_legacy_complex: | ||
print("ATTENTION: Using legacy_complex format for gabor filter estimation.") | ||
|
||
def forward(self, x): | ||
# apply Gabor constraint | ||
kernel = self.constraint(self._kernel) | ||
if self._sort_filters: | ||
raise NotImplementedError("sort filter functionality not yet implemented") | ||
filters = gabor_filters(kernel, self._kernel_size, legacy_complex=self.use_legacy_complex) | ||
if not self.use_legacy_complex: | ||
temp = torch.view_as_real(filters) | ||
real_filters = temp[:, :, 0] | ||
img_filters = temp[:, :, 1] | ||
else: | ||
real_filters = filters[:, :, 0] | ||
img_filters = filters[:, :, 1] | ||
# img_filters = filters.imag | ||
# print(real_filters.shape) | ||
# print(img_filters.shape) | ||
# print(torch.view_as_real(filters).shape) | ||
stacked_filters = torch.cat([real_filters.unsqueeze(1), img_filters.unsqueeze(1)], dim=1) | ||
stacked_filters = torch.reshape(stacked_filters, (2 * self._filters, self._kernel_size)) | ||
stacked_filters = stacked_filters.unsqueeze(1) | ||
if self._padding.lower() == "same": | ||
x = nn.functional.pad(x, self._pad_value, mode='constant', value=0) | ||
pad_val = 0 | ||
else: | ||
pad_val = self._pad_value | ||
print('CONVOLUTION') | ||
print(x.shape) | ||
x = x[:, None, :] | ||
output = nn.functional.conv1d(x, stacked_filters, | ||
bias=self._bias, stride=self._strides, padding=pad_val) | ||
return output |
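A quick hedged sanity check of the layer on its own, using the built-in "random" initializer; the shapes are illustrative, not from the commit, and assume get_padding_value keeps the sequence length at stride 1:

import torch

conv = GaborConv1d(filters=80, kernel_size=401, strides=1,
                   padding="same", initializer="random")
wav = torch.randn(8, 16000)   # (batch, samples) raw waveform
out = conv(wav)               # (batch, 80, 16000): interleaved real/imag responses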
@@ -0,0 +1,65 @@
import math

import numpy as np
import torch
import torchaudio

from pyha_analyzer.leaf_pytorch import impulse_responses


class GaborFilter:
    """Derives initial Gabor parameters (center frequency, bandwidth) from a
    mel filterbank."""

    def __init__(self,
                 n_filters: int = 40,
                 min_freq: float = 0.,
                 max_freq: float = 8000.,
                 sample_rate: int = 16000,
                 window_len: int = 401,
                 n_fft: int = 512,
                 normalize_energy: bool = False):
        super().__init__()
        self.n_filters = n_filters
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.sample_rate = sample_rate
        self.window_len = window_len
        self.n_fft = n_fft
        self.normalize_energy = normalize_energy

    def gabor_params_from_mels(self):
        # Fit one (center frequency, bandwidth) pair per mel filter: the peak
        # bin gives the center, the full width at half maximum the bandwidth.
        coeff = torch.sqrt(2. * torch.log(torch.tensor(2.))) * self.n_fft
        sqrt_filters = torch.sqrt(self.mel_filters())
        center_frequencies = torch.argmax(sqrt_filters, dim=1)
        peaks, _ = torch.max(sqrt_filters, dim=1, keepdim=True)
        half_magnitudes = peaks / 2.
        fwhms = torch.sum((sqrt_filters >= half_magnitudes).float(), dim=1)
        output = torch.cat([
            (center_frequencies * 2 * np.pi / self.n_fft).unsqueeze(1),
            (coeff / (np.pi * fwhms)).unsqueeze(1)
        ], dim=-1)
        return output

    def _mel_filters_areas(self, filters):
        peaks, _ = torch.max(filters, dim=1, keepdim=True)
        return peaks * (torch.sum((filters > 0).float(), dim=1, keepdim=True) + 2) * np.pi / self.n_fft

    def mel_filters(self):
        mel_filters = torchaudio.functional.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.min_freq,
            f_max=self.max_freq,
            n_mels=self.n_filters,
            sample_rate=self.sample_rate
        )
        mel_filters = mel_filters.transpose(1, 0)
        if self.normalize_energy:
            mel_filters = mel_filters / self._mel_filters_areas(mel_filters)
        return mel_filters

    def gabor_filters(self):
        # Note: gabor_params_from_mels and mel_filters are methods and must be
        # called; rescale each impulse response by its mel filter's area.
        gabor_params = self.gabor_params_from_mels()
        filters = impulse_responses.gabor_filters(gabor_params, size=self.window_len)
        output = filters * torch.sqrt(
            self._mel_filters_areas(self.mel_filters()) * 2 * math.sqrt(math.pi) * gabor_params[:, 1:2]
        ).type(torch.complex64)
        return output
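A brief hedged example of exercising this helper directly; the output shapes assume impulse_responses.gabor_filters (not shown in this diff) returns one complex impulse response of length window_len per filter:

gf = GaborFilter(n_filters=40, sample_rate=16000, window_len=401)
params = gf.gabor_params_from_mels()   # (40, 2): per-filter center frequency and bandwidth
filters = gf.gabor_filters()           # (40, 401) complex64 impulse responses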
@@ -0,0 +1,97 @@ | ||
import torch | ||
from torch import nn | ||
from pyha_analyzer.leaf_pytorch import convolution | ||
from pyha_analyzer.leaf_pytorch import initializers | ||
from pyha_analyzer.leaf_pytorch import pooling | ||
from pyha_analyzer.leaf_pytorch import postprocessing | ||
from pyha_analyzer.leaf_pytorch import utils | ||
|
||
|
||
class SquaredModulus(nn.Module): | ||
def __init__(self): | ||
super(SquaredModulus, self).__init__() | ||
self._pool = nn.AvgPool1d(kernel_size=2, stride=2) | ||
|
||
def forward(self, x): | ||
# print(x) | ||
# print(x.shape) | ||
x = x[:, None, :] | ||
# print(x.shape) | ||
x = x.transpose(1, 2) | ||
# print(x.shape) | ||
output = 2 * self._pool(x ** 2.) | ||
output = output.transpose(1, 2) | ||
return output | ||
|
||
|
||
class Leaf(nn.Module): | ||
def __init__( | ||
self, | ||
n_filters: int = 40, | ||
sample_rate: int = 16000, | ||
window_len: float = 25., | ||
window_stride: float = 10., | ||
preemp: bool = False, | ||
init_min_freq = 60.0, | ||
init_max_freq = 7800.0, | ||
mean_var_norm: bool = False, | ||
pcen_compression: bool = True, | ||
use_legacy_complex=False, | ||
initializer="default" | ||
): | ||
super(Leaf, self).__init__() | ||
window_size = int(sample_rate * window_len // 1000 + 1) | ||
window_stride = int(sample_rate * window_stride // 1000) | ||
if preemp: | ||
raise NotImplementedError("Pre-emp functionality not implemented yet..") | ||
else: | ||
self._preemp = None | ||
if initializer == "default": | ||
initializer = initializers.GaborInit( | ||
default_window_len=window_size, sample_rate=sample_rate, | ||
min_freq=init_min_freq, max_freq=init_max_freq | ||
) | ||
self._complex_conv = convolution.GaborConv1d( | ||
filters=2 * n_filters, | ||
kernel_size=window_size, | ||
strides=1, | ||
padding="same", | ||
use_bias=False, | ||
initializer=initializer, | ||
use_legacy_complex=use_legacy_complex | ||
) | ||
self._activation = SquaredModulus() | ||
self._pooling = pooling.GaussianLowPass(n_filters, kernel_size=window_size, | ||
strides=window_stride, padding="same") | ||
self._instance_norm = None | ||
if mean_var_norm: | ||
raise NotImplementedError("Instance Norm functionality not added yet..") | ||
if pcen_compression: | ||
self._compression = postprocessing.PCENLayer( | ||
n_filters, | ||
alpha=0.96, | ||
smooth_coef=0.04, | ||
delta=2.0, | ||
floor=1e-12, | ||
trainable=True, | ||
learn_smooth_coef=True, | ||
per_channel_smooth_coef=True) | ||
else: | ||
self._compression = None | ||
self._maximum_val = torch.tensor(1e-5) | ||
|
||
def forward(self, x): | ||
if self._preemp: | ||
x = self._preemp(x) | ||
print(x.shape) | ||
# x = x.transpose(0, 1) | ||
# print(x.shape) | ||
outputs = self._complex_conv(x) | ||
outputs = self._activation(outputs) | ||
outputs = self._pooling(outputs) | ||
outputs = torch.maximum(outputs, torch.tensor(1e-5, device=outputs.device)) | ||
if self._compression: | ||
outputs = self._compression(outputs) | ||
if self._instance_norm is not None: | ||
outputs = self._instance_norm(outputs) | ||
return outputs |
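A minimal hedged sketch of the frontend end to end with its defaults (40 filters, 16 kHz, 25 ms windows, 10 ms hop); the input shape is illustrative:

import torch

leaf = Leaf()
wav = torch.randn(4, 16000)   # (batch, samples): one second at 16 kHz
feats = leaf(wav)             # (batch, 40, ~100): 10 ms hop gives ~100 frames/second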
@@ -0,0 +1,54 @@ | ||
import os | ||
import torch | ||
from torch import nn | ||
from pyha_analyzer.leaf_pytorch.frontend import Leaf | ||
|
||
|
||
def get_frontend(opt): | ||
|
||
front_end_config = opt['frontend'] | ||
audio_config = opt['audio_config'] | ||
|
||
pretrained = front_end_config.get("pretrained", "") | ||
if os.path.isfile(pretrained): | ||
pretrained_flag = True | ||
ckpt = torch.load(pretrained) | ||
else: | ||
pretrained_flag = False | ||
|
||
if "leaf" in front_end_config['name'].lower(): | ||
default_args = front_end_config.get("default_args", False) | ||
use_legacy_complex = front_end_config.get("use_legacy_complex", False) | ||
initializer = front_end_config.get("initializer", "default") | ||
if default_args: | ||
print("Using default Leaf arguments..") | ||
fe = Leaf(use_legacy_complex=use_legacy_complex, initializer=initializer) | ||
else: | ||
sr = int(audio_config.get("sample_rate", 16000)) | ||
window_len_ms = float(audio_config.get("window_len", 25.)) | ||
window_stride_ms = float(audio_config.get("window_stride", 10.)) | ||
|
||
n_filters = int(front_end_config.get("n_filters", 40.0)) | ||
min_freq = float(front_end_config.get("min_freq", 60.0)) | ||
max_freq = float(front_end_config.get("max_freq", 7800.0)) | ||
pcen_compress = bool(front_end_config.get("pcen_compress", True)) | ||
mean_var_norm = bool(front_end_config.get("mean_var_norm", False)) | ||
preemp = bool(front_end_config.get("preemp", False)) | ||
fe = Leaf( | ||
n_filters=n_filters, | ||
sample_rate=sr, | ||
window_len=window_len_ms, | ||
window_stride=window_stride_ms, | ||
preemp=preemp, | ||
init_min_freq=min_freq, | ||
init_max_freq=max_freq, | ||
mean_var_norm=mean_var_norm, | ||
pcen_compression=pcen_compress, | ||
use_legacy_complex=use_legacy_complex, | ||
initializer=initializer | ||
) | ||
else: | ||
raise NotImplementedError("Other front ends not implemented yet.") | ||
if pretrained_flag: | ||
print("attempting to load pretrained frontend weights..", fe.load_state_dict(ckpt)) | ||
return fe |
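Tying this back to the config at the top of the diff, a hedged sketch of the dict get_frontend receives when default_args is True (the dict literal mirrors the YAML; only the keys read above are shown):

opt = {
    "frontend": {"name": "leaf", "default_args": True, "use_legacy_complex": True},
    "audio_config": {"feature": "raw", "sample_rate": 32000},
}
fe = get_frontend(opt)  # prints "Using default Leaf arguments.." and builds a Leaf

Note that with default_args: True the sample rate from audio_config is not forwarded, so the Leaf module keeps its 16 kHz default even though the dataset config specifies 32 kHz.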