Commit
feat: integrate leaf-audio into pipeline
TQZhang04 committed Aug 6, 2024
1 parent 4401ffe commit bbabeb0
Showing 20 changed files with 4,870 additions and 13 deletions.
3,000 changes: 3,000 additions & 0 deletions inference.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyha_analyzer/config.py
@@ -29,7 +29,7 @@ def __init__(self):
         self.required_checks("dataframe_csv")
         self.required_checks("data_path")
         self.get_git_hash()
-        self.cli_values()
+        #self.cli_values()
         self.get_device()

     def __new__(cls):
35 changes: 35 additions & 0 deletions pyha_analyzer/efficientnet-b0-leaf-default.cfg
@@ -0,0 +1,35 @@
frontend:
  name: leaf
  default_args: True
  use_legacy_complex: True
model:
  arch: efficientnet
  num_classes: 133
  model_depth: b0
  pool: avgpool
  type: multiclass
opt:
  optimizer: Adam
  lr: 1e-3
  momentum: 0.9
  scheduler: warmupcosine
  warmup_epochs: 10
  weight_decay: 1e-4
  batch_size: 32
audio_config:
  feature: raw
  normalize: False
  sample_rate: 32000
  min_duration: 1
  random_clip_size: 1
  val_clip_size: 1
  mixup: False
data:
  meta_root: "132_peru_xc_BC_2020_meta"
  is_lmdb: False
  in_memory: False
  train_manifest: "train.csv"
  val_manifest: "test.csv"
  test_manifest: None
  label_map: species_labels.json
  cw: cw_2.pth
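
A minimal sketch of how a config like this could reach the frontend, assuming the .cfg file parses as YAML (the pipeline's actual config loader is not part of this diff). Note that with default_args: True, get_frontend builds Leaf with its defaults and only honors use_legacy_complex and initializer:

import yaml
from pyha_analyzer.leaf_pytorch import get_frontend

# Hypothetical path; parse the config into the dict get_frontend expects.
with open("pyha_analyzer/efficientnet-b0-leaf-default.cfg") as f:
    opt = yaml.safe_load(f)

frontend = get_frontend(opt)  # reads opt['frontend'] and opt['audio_config']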
1 change: 1 addition & 0 deletions pyha_analyzer/leaf_pytorch/__init__.py
@@ -0,0 +1 @@
from .frontend_helper import get_frontend
101 changes: 101 additions & 0 deletions pyha_analyzer/leaf_pytorch/convolution.py
@@ -0,0 +1,101 @@
import math
from typing import Callable

import torch
from torch import nn

from pyha_analyzer.leaf_pytorch.impulse_responses import gabor_filters
from pyha_analyzer.leaf_pytorch.utils import get_padding_value


class GaborConstraint(nn.Module):
    """Clamp Gabor kernel parameters to valid ranges: center frequency
    within [0, pi] rad/sample, bandwidth between one FWHM-derived bin and
    the full kernel width."""

    def __init__(self, kernel_size):
        super(GaborConstraint, self).__init__()
        self._kernel_size = kernel_size

    def forward(self, kernel_data):
        mu_lower = 0.
        mu_upper = math.pi
        sigma_lower = 4 * torch.sqrt(2. * torch.log(torch.tensor(2., device=kernel_data.device))) / math.pi
        sigma_upper = self._kernel_size * torch.sqrt(2. * torch.log(torch.tensor(2., device=kernel_data.device))) / math.pi
        clipped_mu = torch.clamp(kernel_data[:, 0], mu_lower, mu_upper).unsqueeze(1)
        clipped_sigma = torch.clamp(kernel_data[:, 1], sigma_lower, sigma_upper).unsqueeze(1)
        return torch.cat([clipped_mu, clipped_sigma], dim=-1)


class GaborConv1d(nn.Module):
    def __init__(self, filters, kernel_size,
                 strides, padding,
                 initializer=None,
                 use_bias=False,
                 sort_filters=False,
                 use_legacy_complex=False):
        super(GaborConv1d, self).__init__()
        self._filters = filters // 2
        self._kernel_size = kernel_size
        self._strides = strides
        self._padding = padding
        self._use_bias = use_bias
        self._sort_filters = sort_filters
        if isinstance(initializer, Callable):
            init_weights = initializer((self._filters, 2))
        elif initializer == "random":
            init_weights = torch.randn(self._filters, 2)
        elif initializer == "xavier_normal":
            print("Using xavier_normal init scheme..")
            init_weights = torch.nn.init.xavier_normal_(torch.randn(self._filters, 2))
        elif initializer == "kaiming_normal":
            init_weights = torch.nn.init.kaiming_normal_(torch.randn(self._filters, 2))
        else:
            raise ValueError("unsupported initializer")
        self.constraint = GaborConstraint(self._kernel_size)
        self._kernel = nn.Parameter(init_weights)
        if self._padding.lower() == "same":
            self._pad_value = get_padding_value(self._kernel_size)
        else:
            self._pad_value = self._padding
        if self._use_bias:
            self._bias = torch.nn.Parameter(torch.ones(self._filters * 2,))
        else:
            self._bias = None
        self.use_legacy_complex = use_legacy_complex
        if self.use_legacy_complex:
            print("ATTENTION: Using legacy_complex format for gabor filter estimation.")

    def forward(self, x):
        # Apply the Gabor constraint, then materialize the complex filterbank.
        kernel = self.constraint(self._kernel)
        if self._sort_filters:
            raise NotImplementedError("sort filter functionality not yet implemented")
        filters = gabor_filters(kernel, self._kernel_size, legacy_complex=self.use_legacy_complex)
        if not self.use_legacy_complex:
            temp = torch.view_as_real(filters)
            real_filters = temp[:, :, 0]
            img_filters = temp[:, :, 1]
        else:
            real_filters = filters[:, :, 0]
            img_filters = filters[:, :, 1]
        # Interleave real and imaginary responses into 2 * self._filters channels.
        stacked_filters = torch.cat([real_filters.unsqueeze(1), img_filters.unsqueeze(1)], dim=1)
        stacked_filters = torch.reshape(stacked_filters, (2 * self._filters, self._kernel_size))
        stacked_filters = stacked_filters.unsqueeze(1)
        if self._padding.lower() == "same":
            x = nn.functional.pad(x, self._pad_value, mode='constant', value=0)
            pad_val = 0
        else:
            pad_val = self._pad_value
        x = x[:, None, :]  # (batch, samples) -> (batch, 1, samples)
        output = nn.functional.conv1d(x, stacked_filters,
                                      bias=self._bias, stride=self._strides, padding=pad_val)
        return output
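
A hypothetical smoke test for GaborConv1d (not part of this commit), constructing GaborInit the same way frontend.py does below; shapes follow the forward pass above:

import torch
from pyha_analyzer.leaf_pytorch.convolution import GaborConv1d
from pyha_analyzer.leaf_pytorch.initializers import GaborInit

init = GaborInit(default_window_len=401, sample_rate=16000,
                 min_freq=60.0, max_freq=7800.0)
conv = GaborConv1d(filters=80, kernel_size=401, strides=1,
                   padding="same", initializer=init)
wav = torch.randn(4, 16000)   # (batch, samples)
out = conv(wav)               # (batch, 80, 16000): interleaved real/imag responses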
65 changes: 65 additions & 0 deletions pyha_analyzer/leaf_pytorch/filters.py
@@ -0,0 +1,65 @@
import math

import torch
import torchaudio
import numpy as np

from pyha_analyzer.leaf_pytorch import impulse_responses


class GaborFilter:
    def __init__(self,
                 n_filters: int = 40,
                 min_freq: float = 0.,
                 max_freq: float = 8000.,
                 sample_rate: int = 16000,
                 window_len: int = 401,
                 n_fft: int = 512,
                 normalize_energy: bool = False):
        self.n_filters = n_filters
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.sample_rate = sample_rate
        self.window_len = window_len
        self.n_fft = n_fft
        self.normalize_energy = normalize_energy

    def gabor_params_from_mels(self):
        # Derive each filter's center frequency and bandwidth term from the
        # FWHM of the corresponding (square-rooted) mel filter.
        coeff = torch.sqrt(2. * torch.log(torch.tensor(2.))) * self.n_fft
        sqrt_filters = torch.sqrt(self.mel_filters())
        center_frequencies = torch.argmax(sqrt_filters, dim=1)
        peaks, _ = torch.max(sqrt_filters, dim=1, keepdim=True)
        half_magnitudes = peaks / 2.
        fwhms = torch.sum((sqrt_filters >= half_magnitudes).float(), dim=1)
        output = torch.cat([
            (center_frequencies * 2 * np.pi / self.n_fft).unsqueeze(1),
            (coeff / (np.pi * fwhms)).unsqueeze(1)
        ], dim=-1)
        return output

    def _mel_filters_areas(self, filters):
        peaks, _ = torch.max(filters, dim=1, keepdim=True)
        return peaks * (torch.sum((filters > 0).float(), dim=1, keepdim=True) + 2) * np.pi / self.n_fft

    def mel_filters(self):
        mel_filters = torchaudio.functional.melscale_fbanks(
            n_freqs=self.n_fft // 2 + 1,
            f_min=self.min_freq,
            f_max=self.max_freq,
            n_mels=self.n_filters,
            sample_rate=self.sample_rate
        )
        mel_filters = mel_filters.transpose(1, 0)
        if self.normalize_energy:
            mel_filters = mel_filters / self._mel_filters_areas(mel_filters)
        return mel_filters

    def gabor_filters(self):
        # gabor_params_from_mels and mel_filters are methods and must be
        # called; the sqrt term rescales each response by its mel band area.
        gabor_params = self.gabor_params_from_mels()
        filters = impulse_responses.gabor_filters(gabor_params, size=self.window_len)
        output = filters * torch.sqrt(
            self._mel_filters_areas(self.mel_filters()) * 2 * math.sqrt(math.pi) * gabor_params[:, 1:2]
        ).type(torch.complex64)
        return output
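
A short, assumed usage sketch for GaborFilter (not part of this commit), taking impulse_responses.gabor_filters to return one length-window_len response per parameter row:

from pyha_analyzer.leaf_pytorch.filters import GaborFilter

gf = GaborFilter(n_filters=40, min_freq=60., max_freq=7800.,
                 sample_rate=16000, window_len=401)
params = gf.gabor_params_from_mels()  # (40, 2): center freq (rad/sample), bandwidth term
bank = gf.gabor_filters()             # complex64 impulse responses, (40, 401)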
97 changes: 97 additions & 0 deletions pyha_analyzer/leaf_pytorch/frontend.py
@@ -0,0 +1,97 @@
import torch
from torch import nn

from pyha_analyzer.leaf_pytorch import convolution
from pyha_analyzer.leaf_pytorch import initializers
from pyha_analyzer.leaf_pytorch import pooling
from pyha_analyzer.leaf_pytorch import postprocessing


class SquaredModulus(nn.Module):
    """Squared modulus over interleaved real/imaginary channel pairs:
    average-pooling the squared channels in pairs and doubling gives
    re^2 + im^2 for each filter."""

    def __init__(self):
        super(SquaredModulus, self).__init__()
        self._pool = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, x):
        # (batch, 2 * n_filters, time): pool adjacent channel pairs.
        x = x.transpose(1, 2)
        output = 2 * self._pool(x ** 2.)
        output = output.transpose(1, 2)
        return output


class Leaf(nn.Module):
    def __init__(
            self,
            n_filters: int = 40,
            sample_rate: int = 16000,
            window_len: float = 25.,
            window_stride: float = 10.,
            preemp: bool = False,
            init_min_freq=60.0,
            init_max_freq=7800.0,
            mean_var_norm: bool = False,
            pcen_compression: bool = True,
            use_legacy_complex=False,
            initializer="default"
    ):
        super(Leaf, self).__init__()
        window_size = int(sample_rate * window_len // 1000 + 1)
        window_stride = int(sample_rate * window_stride // 1000)
        if preemp:
            raise NotImplementedError("Pre-emp functionality not implemented yet..")
        else:
            self._preemp = None
        if initializer == "default":
            initializer = initializers.GaborInit(
                default_window_len=window_size, sample_rate=sample_rate,
                min_freq=init_min_freq, max_freq=init_max_freq
            )
        self._complex_conv = convolution.GaborConv1d(
            filters=2 * n_filters,
            kernel_size=window_size,
            strides=1,
            padding="same",
            use_bias=False,
            initializer=initializer,
            use_legacy_complex=use_legacy_complex
        )
        self._activation = SquaredModulus()
        self._pooling = pooling.GaussianLowPass(n_filters, kernel_size=window_size,
                                                strides=window_stride, padding="same")
        self._instance_norm = None
        if mean_var_norm:
            raise NotImplementedError("Instance Norm functionality not added yet..")
        if pcen_compression:
            self._compression = postprocessing.PCENLayer(
                n_filters,
                alpha=0.96,
                smooth_coef=0.04,
                delta=2.0,
                floor=1e-12,
                trainable=True,
                learn_smooth_coef=True,
                per_channel_smooth_coef=True)
        else:
            self._compression = None
        self._maximum_val = torch.tensor(1e-5)

    def forward(self, x):
        if self._preemp:
            x = self._preemp(x)
        outputs = self._complex_conv(x)
        outputs = self._activation(outputs)
        outputs = self._pooling(outputs)
        # Floor the filterbank energies before optional PCEN compression.
        outputs = torch.maximum(outputs, self._maximum_val.to(outputs.device))
        if self._compression:
            outputs = self._compression(outputs)
        if self._instance_norm is not None:
            outputs = self._instance_norm(outputs)
        return outputs
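
A minimal usage sketch for the default Leaf frontend (not part of this commit); the frame count is approximate and depends on GaussianLowPass's "same" padding at a 10 ms hop:

import torch
from pyha_analyzer.leaf_pytorch.frontend import Leaf

leaf = Leaf()                # 40 filters, 16 kHz, 25 ms window, 10 ms hop
wav = torch.randn(8, 16000)  # (batch, samples): one second of audio
feats = leaf(wav)            # PCEN-compressed filterbank, roughly (8, 40, 100)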
54 changes: 54 additions & 0 deletions pyha_analyzer/leaf_pytorch/frontend_helper.py
@@ -0,0 +1,54 @@
import os

import torch

from pyha_analyzer.leaf_pytorch.frontend import Leaf


def get_frontend(opt):
    front_end_config = opt['frontend']
    audio_config = opt['audio_config']

    pretrained = front_end_config.get("pretrained", "")
    if os.path.isfile(pretrained):
        pretrained_flag = True
        ckpt = torch.load(pretrained)
    else:
        pretrained_flag = False

    if "leaf" in front_end_config['name'].lower():
        default_args = front_end_config.get("default_args", False)
        use_legacy_complex = front_end_config.get("use_legacy_complex", False)
        initializer = front_end_config.get("initializer", "default")
        if default_args:
            print("Using default Leaf arguments..")
            fe = Leaf(use_legacy_complex=use_legacy_complex, initializer=initializer)
        else:
            sr = int(audio_config.get("sample_rate", 16000))
            window_len_ms = float(audio_config.get("window_len", 25.))
            window_stride_ms = float(audio_config.get("window_stride", 10.))

            n_filters = int(front_end_config.get("n_filters", 40))
            min_freq = float(front_end_config.get("min_freq", 60.0))
            max_freq = float(front_end_config.get("max_freq", 7800.0))
            pcen_compress = bool(front_end_config.get("pcen_compress", True))
            mean_var_norm = bool(front_end_config.get("mean_var_norm", False))
            preemp = bool(front_end_config.get("preemp", False))
            fe = Leaf(
                n_filters=n_filters,
                sample_rate=sr,
                window_len=window_len_ms,
                window_stride=window_stride_ms,
                preemp=preemp,
                init_min_freq=min_freq,
                init_max_freq=max_freq,
                mean_var_norm=mean_var_norm,
                pcen_compression=pcen_compress,
                use_legacy_complex=use_legacy_complex,
                initializer=initializer
            )
    else:
        raise NotImplementedError("Other front ends not implemented yet.")
    if pretrained_flag:
        print("attempting to load pretrained frontend weights..", fe.load_state_dict(ckpt))
    return fe