Skip to content

Commit

Permalink
feat: support multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
pprobst committed Oct 20, 2023
1 parent 864c854 commit 99dc0d8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 44 deletions.
22 changes: 0 additions & 22 deletions audio/aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
import audiomentations as AA
import numpy as np

# import os

from audiomentations import Compose
from typing import List, Tuple

# from multiprocessing import Pool

AUG_PARAMS = {
# See a list of possible transforms here: https://iver56.github.io/audiomentations/
Expand Down Expand Up @@ -66,22 +63,3 @@ def apply_augmentation(
transforms_used.append(transform.__class__.__name__)

return augmented_samples, transforms_used


"""
def apply_augmentation_batch(
samples: List[np.ndarray], sample_rates: List[float], augmentations: List[str]
):
if isinstance(samples, list) and len(samples) > 1:
with Pool(processes=os.cpu_count()) as pool:
augmented_samples = pool.starmap(
apply_augmentation,
[
(sample, augmentations, sr)
for sample, sr in zip(samples, sample_rates)
],
)
return augmented_samples
else:
return apply_augmentation(samples[0], sample_rates[0], augmentations)
"""
43 changes: 21 additions & 22 deletions run_audio_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,21 @@
import random
import numpy as np

from multiprocessing import Pool, cpu_count
from audio.aug import apply_augmentation, AUG_PARAMS
from utils.files import load_audio, save_audio
from typing import List


def process_audio(input_file: str, augmentations: List[str], output_format: str):
audio, sr = load_audio(input_file)
output_filename = os.path.splitext(os.path.basename(input_file))[0]
augmented_audio, transforms_used = apply_augmentation(audio, sr, augmentations)

if len(transforms_used) > 0:
output_filename = output_filename + "_" + "_".join(transforms_used)
save_audio(augmented_audio, output_filename, sr, output_format)
print(f"Augmented audio saved to {output_filename}.{output_format}")


if __name__ == "__main__":
Expand Down Expand Up @@ -42,26 +55,12 @@
random.seed(args.seed)
np.random.seed(args.seed)

audios = []
srs = []
output_filenames = []
for input_file in args.input_file:
audio, sr = load_audio(input_file)
output_filename = os.path.splitext(os.path.basename(input_file))[0]
audios.append(audio)
srs.append(sr)
output_filenames.append(output_filename)
input_args = [
(input_file, args.augmentations, args.output_format)
for input_file in args.input_file
]

for audio, sr, output_filename in zip(audios, srs, output_filenames):
augmented_audio, transforms_used = apply_augmentation(
audio, sr, args.augmentations
)
if len(transforms_used) > 0:
output_filename = output_filename + "_" + "_".join(transforms_used)
save_audio(
augmented_audio,
output_filename,
sr,
args.output_format,
)
print(f"Augmented audio saved to {output_filename}.{args.output_format}")
pool = Pool(processes=cpu_count())
pool.starmap(process_audio, input_args)
pool.close()
pool.join()

0 comments on commit 99dc0d8

Please sign in to comment.