Skip to content

Commit

Permalink
Expose BandPassFilter, HighPassFilter and LowPassFilter in the public…
Browse files Browse the repository at this point in the history
… API
  • Loading branch information
iver56 committed Aug 5, 2021
1 parent cbcb8e1 commit 528f0ca
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 33 deletions.
3 changes: 3 additions & 0 deletions audiomentations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
AddImpulseResponse,
ApplyImpulseResponse,
AddShortNoises,
BandPassFilter,
Clip,
ClippingDistortion,
FrequencyMask,
Gain,
HighPassFilter,
LoudnessNormalization,
LowPassFilter,
Mp3Compression,
Normalize,
PitchShift,
Expand Down
2 changes: 0 additions & 2 deletions demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
Mp3Compression,
LoudnessNormalization,
Trim,
)
from audiomentations.augmentations.transforms import (
LowPassFilter,
HighPassFilter,
BandPassFilter,
Expand Down
23 changes: 10 additions & 13 deletions tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,28 @@

import matplotlib.pyplot as plt

from audiomentations.augmentations.transforms import BandPassFilter
from audiomentations import BandPassFilter

DEBUG = False


class TestBandPassFilter(unittest.TestCase):
def test_band_pass_filter(self):
sample_rate = 16000
t = .25 # signal duration in sec
f = 500 # signal frequency in Hz
t = 0.25 # signal duration in sec
f = 500 # signal frequency in Hz
samples = np.arange(t * f, dtype=np.float32) / sample_rate
samples = np.sin(2 * np.pi * f * samples)

augment = BandPassFilter(min_center_freq=100,
max_center_freq=5000,
min_q=1.0,
max_q=2.0,
p=1.0)
processed_samples = augment(
samples=samples, sample_rate=sample_rate
augment = BandPassFilter(
min_center_freq=100, max_center_freq=5000, min_q=1.0, max_q=2.0, p=1.0
)
processed_samples = augment(samples=samples, sample_rate=sample_rate)

self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
if DEBUG:

if DEBUG:
plt.plot(samples)
plt.plot(processed_samples, '-.')
plt.plot(processed_samples, "-.")
plt.show()
17 changes: 8 additions & 9 deletions tests/test_high_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,26 @@

import matplotlib.pyplot as plt

from audiomentations.augmentations.transforms import HighPassFilter
from audiomentations import HighPassFilter

DEBUG = False


class TestHighPassFilter(unittest.TestCase):
def test_high_pass_filter(self):
sample_rate = 16000
t = .25 # signal duration in sec
f = 500 # signal frequency in Hz
t = 0.25 # signal duration in sec
f = 500 # signal frequency in Hz
samples = np.arange(t * f, dtype=np.float32) / sample_rate
samples = np.sin(2 * np.pi * f * samples)

augment = HighPassFilter(min_cutoff_freq=100, max_cutoff_freq=200, p=1.0)
processed_samples = augment(
samples=samples, sample_rate=sample_rate
)
processed_samples = augment(samples=samples, sample_rate=sample_rate)

self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
if DEBUG:

if DEBUG:
plt.plot(samples)
plt.plot(processed_samples, '-.')
plt.plot(processed_samples, "-.")
plt.show()
17 changes: 8 additions & 9 deletions tests/test_low_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,26 @@

import matplotlib.pyplot as plt

from audiomentations.augmentations.transforms import LowPassFilter
from audiomentations import LowPassFilter

DEBUG = False


class TestLowPassFilter(unittest.TestCase):
def test_low_pass_filter(self):
sample_rate = 16000
t = .25 # signal duration in sec
f = 500 # signal frequency in Hz
t = 0.25 # signal duration in sec
f = 500 # signal frequency in Hz
samples = np.arange(t * f, dtype=np.float32) / sample_rate
samples = np.sin(2 * np.pi * f * samples)

augment = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=200, p=1.0)
processed_samples = augment(
samples=samples, sample_rate=sample_rate
)
processed_samples = augment(samples=samples, sample_rate=sample_rate)

self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
if DEBUG:

if DEBUG:
plt.plot(samples)
plt.plot(processed_samples, '-.')
plt.plot(processed_samples, "-.")
plt.show()

0 comments on commit 528f0ca

Please sign in to comment.