Skip to content

Commit

Permalink
Specifiy number of filter is cosine filterbank (#72)
Browse files Browse the repository at this point in the history
* n_filt argument for cos_filterbank

* correct number of filters

* improve _center_freqs test

* debug cos_filterbank test

* n_filt argument for cochleagram
  • Loading branch information
OleBialas authored Dec 18, 2023
1 parent 1f5e23f commit 4dfb6ed
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 32 deletions.
25 changes: 15 additions & 10 deletions slab/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,22 +268,23 @@ def tf(self, channels='all', n_bins=None, show=True, axis=None):

@staticmethod
# TODO: oversampling factor needed for cochleagram!
def cos_filterbank(length=5000, bandwidth=1/3, low_cutoff=0, high_cutoff=None, pass_bands=False, samplerate=None):
def cos_filterbank(length=5000, bandwidth=1/3, low_cutoff=0, high_cutoff=None, pass_bands=False, n_filters=None, samplerate=None):
"""
Generate a set of Fourier filters. Each filter's transfer function is given by the positive phase of a
cosine wave. The amplitude of the cosine is that filters central frequency. Following the organization of the
cochlea, the width of the filter increases in proportion to it's center frequency. This increase is defined
by Moore & Glasberg's formula for the equivalent rectangular bandwidth (ERB) of auditory filters. This
functions is used for example to divide a sound into bands for equalization.
by Moore & Glasberg's formula for the equivalent rectangular bandwidth (ERB) of auditory filters.
The number of filters is either determined by the `n_filters` argument or calculated based on the desired
`bandwidth` or. This function is used for example to divide a sound into sub-bands for equalization.
Attributes:
length (int): The number of bins in each filter, determines the frequency resolution.
bandwidth (float): Width of the sub-filters in octaves. The smaller the bandwidth, the more filters
will be generated.
bandwidth (float): Width of the sub-filters in octaves. The smaller the bandwidth, the more filters will be generated.
low_cutoff (int | float): The lower limit of frequency range in Hz.
high_cutoff (int | float): The upper limit of frequency range in Hz. If None, use the Nyquist frequency.
pass_bands (bool): Whether to include a half cosine at the filter bank's lower and upper edge frequency.
If True, allows reconstruction of original bandwidth when collapsing subbands.
n_filters (int | None): Number of filters. When this is not None, the `bandwidth` argument is ignored.
samplerate (int | None): the samplerate of the sound that the filter shall be applied to.
If None, use the default samplerate.s
Examples::
Expand All @@ -304,7 +305,7 @@ def cos_filterbank(length=5000, bandwidth=1/3, low_cutoff=0, high_cutoff=None, p
freq_bins = numpy.fft.rfftfreq(length, d=1/samplerate)
n_freqs = len(freq_bins)
center_freqs, bandwidth, erb_spacing = Filter._center_freqs(
low_cutoff=low_cutoff, high_cutoff=high_cutoff, bandwidth=bandwidth, pass_bands=pass_bands)
low_cutoff=low_cutoff, high_cutoff=high_cutoff, bandwidth=bandwidth, pass_bands=pass_bands, n_filters=n_filters)
n_filters = len(center_freqs)
filts = numpy.zeros((n_freqs, n_filters))
freqs_erb = Filter._freq2erb(freq_bins)
Expand All @@ -318,13 +319,17 @@ def cos_filterbank(length=5000, bandwidth=1/3, low_cutoff=0, high_cutoff=None, p
return Filter(data=filts, samplerate=samplerate, fir=False)

@staticmethod
def _center_freqs(low_cutoff, high_cutoff, bandwidth=1/3, pass_bands=False):
def _center_freqs(low_cutoff, high_cutoff, bandwidth=1/3, pass_bands=False, n_filters=None):
ref_freq = 1000 # Hz, reference for conversion between oct and erb bandwidth
ref_erb = Filter._freq2erb(ref_freq)
erb_spacing = Filter._freq2erb(ref_freq*2**bandwidth) - ref_erb
h = Filter._freq2erb(high_cutoff)
l = Filter._freq2erb(low_cutoff)
n_filters = int(numpy.round((h - l) / erb_spacing))
ref_erb = Filter._freq2erb(ref_freq)
if n_filters is None:
erb_spacing = Filter._freq2erb(ref_freq*2**bandwidth) - ref_erb
n_filters = int(numpy.round((h - l) / erb_spacing))
elif n_filters is not None and pass_bands is False:
# add 2 so that after omitting pass_bands we get the desired n_filt
n_filters += 2
center_freqs, erb_spacing = numpy.linspace(l, h, n_filters, retstep=True)
if not pass_bands:
center_freqs = center_freqs[1:-1] # exclude low and highpass filters
Expand Down
18 changes: 11 additions & 7 deletions slab/sound.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,13 +1075,17 @@ def spectrogram(self, window_dur=0.005, dyn_range=120, upper_frequency=None, oth
else:
return freqs, times, power

def cochleagram(self, bandwidth=1 / 5, show=True, axis=None):
def cochleagram(self, bandwidth=1 / 5, n_bands=None, show=True, axis=None):
"""
Computes a cochleagram of the sound by filtering with a bank of cosine-shaped filters with given bandwidth
and applying a cube-root compression to the resulting envelopes.
Computes a cochleagram of the sound by filtering with a bank of cosine-shaped filters
and applying a cube-root compression to the resulting envelopes. The number of bands
is either calculated based on the desired `bandwidth` or specified by the `n_bands`
argument.
Arguments:
bandwidth (float): filter bandwidth in octaves.
n_bands (int | None): number of bands in the cochleagram. If this is not
None, the `bandwidth` argument is ignored.
show (bool): whether to show the plot right after drawing. Note that if show is False and no `axis` is
passed, no plot will be created
axis (matplotlib.axes.Axes | None): axis to plot to. If None create a new plot.
Expand All @@ -1090,7 +1094,8 @@ def cochleagram(self, bandwidth=1 / 5, show=True, axis=None):
Else, an array with the envelope is returned.
"""
fbank = Filter.cos_filterbank(bandwidth=bandwidth, low_cutoff=20,
high_cutoff=None, samplerate=self.samplerate)
high_cutoff=None, n_filters=n_bands,
samplerate=self.samplerate)
freqs = fbank.filter_bank_center_freqs()
subbands = fbank.apply(self.channel(0))
envs = subbands.envelope()
Expand All @@ -1103,9 +1108,8 @@ def cochleagram(self, bandwidth=1 / 5, show=True, axis=None):
if axis is None:
_, axis = plt.subplots()
axis.imshow(envs.T, origin='lower', aspect='auto', cmap=cmap)
#labels = list(freqs.astype(int))
#axis.yaxis.set_major_formatter(matplotlib.ticker.IndexFormatter(
# labels)) # centre frequencies as ticks -> commented because IndexFomatter deprecated in matplotlib 3.3
labels = list(freqs.astype(int))
axis.set_yticks(ticks=range(fbank.n_filters), labels=labels)
axis.set_xlim([0, self.duration])
axis.set(title='Cochleagram', xlabel='Time [sec]', ylabel='Frequency [Hz]')
if show:
Expand Down
33 changes: 18 additions & 15 deletions tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,44 @@ def test_custom_band():


def test_cos_filterbank():
for i in range(10):
for _ in range(10):
sound = slab.Sound.whitenoise(duration=1.0, samplerate=44100)
length = numpy.random.randint(1000, 5000)
low_cutoff = numpy.random.randint(0, 500)
high_cutoff = numpy.random.choice([numpy.random.randint(5000, 15000), None])
pass_bands = False
n_filters = []
for bandwidth in numpy.linspace(0.1, 0.9, 9):
fbank = slab.Filter.cos_filterbank(length, bandwidth, low_cutoff, high_cutoff, pass_bands, sound.samplerate)
fbank = slab.Filter.cos_filterbank(length=length, bandwidth=bandwidth, low_cutoff=low_cutoff, high_cutoff=high_cutoff, pass_bands=False, samplerate=sound.samplerate)
n_filters.append(fbank.n_filters)
filtsound = fbank.apply(sound)
assert filtsound.n_channels == fbank.n_filters
assert filtsound.n_samples == sound.n_samples
assert all([n_filters[i] >= n_filters[i+1] for i in range(len(n_filters)-1)])
bandwidth = numpy.random.uniform(0.1, 0.9)
pass_bands = True
fbank = slab.Filter.cos_filterbank(sound.n_samples, bandwidth, low_cutoff, high_cutoff, pass_bands,
sound.samplerate)
fbank = slab.Filter.cos_filterbank(
length=sound.n_samples, bandwidth=bandwidth, low_cutoff=low_cutoff,
high_cutoff=high_cutoff, pass_bands=True, samplerate=sound.samplerate)
filtsound = fbank.apply(sound)
collapsed = slab.Filter.collapse_subbands(filtsound, fbank)
numpy.testing.assert_almost_equal(sound.data, collapsed.data, decimal=-1)


def test_center_freqs():
for i in range(100):
for _ in range(100):
low_cutoff = numpy.random.randint(0, 500)
high_cutoff = numpy.random.choice([numpy.random.randint(5000, 20000)])
bandwidth1 = numpy.random.uniform(0.1, 0.7)
pass_bands = False
center_freqs1, bandwidth2, _ = slab.Filter._center_freqs(low_cutoff, high_cutoff, bandwidth1, pass_bands)
assert numpy.abs(bandwidth1 - bandwidth2) < 0.3
fbank = slab.Filter.cos_filterbank(5000, bandwidth1, low_cutoff, high_cutoff, pass_bands, 44100)
center_freqs2 = fbank.filter_bank_center_freqs()
assert numpy.abs(slab.Filter._erb2freq(center_freqs1[1:]) - center_freqs2[1:]).max() < 40
assert numpy.abs(center_freqs1 - slab.Filter._freq2erb(center_freqs2)).max() < 1
bandwidth = numpy.random.uniform(0.1, 0.7)
center_freqs1, bandwidth1, erb_spacing1 = slab.Filter._center_freqs(low_cutoff, high_cutoff, bandwidth=bandwidth, pass_bands=False)
center_freqs2, bandwidth2, erb_spacing2 = slab.Filter._center_freqs(low_cutoff, high_cutoff, bandwidth=bandwidth, pass_bands=True)
assert all(center_freqs1 == center_freqs2[1:-1])
assert len(center_freqs1) == len(center_freqs2)-2
assert bandwidth1 == bandwidth2
assert erb_spacing1 == erb_spacing2
n_filters = len(center_freqs1)
center_freqs3, bandwidth3, erb_spacing3 = slab.Filter._center_freqs(low_cutoff, high_cutoff, n_filters=n_filters, pass_bands=False)
assert all(center_freqs1 == center_freqs3)
assert bandwidth1 == bandwidth3
assert erb_spacing1 == erb_spacing3


def test_equalization():
Expand Down

0 comments on commit 4dfb6ed

Please sign in to comment.