diff --git a/requirements.txt b/requirements.txt index f1afabb1..ba83326e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ black==22.3.0 h5py +librosa matplotlib numpy soundfile diff --git a/returnn/hdf.py b/returnn/hdf.py index e75b5736..8a4bbd63 100644 --- a/returnn/hdf.py +++ b/returnn/hdf.py @@ -4,6 +4,7 @@ from enum import Enum, auto import glob import math +import librosa import numpy as np import os import shutil @@ -230,6 +231,7 @@ class RoundingScheme(Enum): "multi_channel_strategy": BaseStrategy(), "rounding": RoundingScheme.start_and_duration, "round_factor": 1, + "target_sampling_rate": None, } def __init__( @@ -241,6 +243,7 @@ def __init__( returnn_root: Optional[tk.Path] = None, rounding: RoundingScheme = RoundingScheme.start_and_duration, round_factor: int = 1, + target_sampling_rate: Optional[int] = None, ): """ @@ -256,6 +259,7 @@ def __init__( start_and_duration will round down the start time and the duration of the segment rasr_compatible will round up the start time and round down the end time :param round_factor: do the rounding based on a sampling rate that is scaled down by this factor + :param target_sampling_rate: desired sampling rate for the HDF, data will be resampled to this rate if needed """ self.set_vis_name("Dump audio to HDF") assert output_dtype in ["float64", "float32", "int32", "int16"] @@ -267,6 +271,7 @@ def __init__( self.returnn_root = returnn_root self.rounding = rounding self.round_factor = round_factor + self.target_sampling_rate = target_sampling_rate self.out_hdf = self.output_path("audio.hdf") @@ -284,7 +289,7 @@ def run(self): if self.segment_file: with uopen(self.segment_file, "rt") as f: - segments_whitelist = set(l.strip() for l in f.readlines() if len(l.strip()) > 0) + segments_whitelist = {line.strip() for line in f.readlines() if len(line.strip()) > 0} else: segments_whitelist = None @@ -295,35 +300,46 @@ def run(self): audio = sf.SoundFile(audio_file) for segment in recording.segments: - if (not segments_whitelist) or (segment.fullname() in segments_whitelist): - if self.rounding == self.RoundingScheme.start_and_duration: - start = int(segment.start * audio.samplerate / self.round_factor) * self.round_factor - duration = ( - int((segment.end - segment.start) * audio.samplerate / self.round_factor) - * self.round_factor - ) - elif self.rounding == self.RoundingScheme.rasr_compatible: - start = math.ceil(segment.start * audio.samplerate / self.round_factor) * self.round_factor - duration = ( - math.floor(segment.end * audio.samplerate / self.round_factor) * self.round_factor - start - ) - else: - raise NotImplementedError(f"RoundingScheme {self.rounding} not implemented.") - audio.seek(start) - data = audio.read( - duration, - always_2d=True, - dtype=self.output_dtype, + if (segments_whitelist is not None) and (segment.fullname() not in segments_whitelist): + continue + + # determine correct start and duration values + if self.rounding == self.RoundingScheme.start_and_duration: + start = int(segment.start * audio.samplerate / self.round_factor) * self.round_factor + duration = ( + int((segment.end - segment.start) * audio.samplerate / self.round_factor) * self.round_factor ) - if isinstance(self.multi_channel_strategy, self.PickNth): - data = data[:, self.multi_channel_strategy.channel] - else: - assert data.shape[-1] == 1, "Audio has more than one channel, choose a multi_channel_strategy" - out_hdf.insert_batch( - inputs=data.reshape(1, -1, 1), - seq_len=[data.shape[0]], - seq_tag=[segment.fullname()], + elif self.rounding == self.RoundingScheme.rasr_compatible: + start = math.ceil(segment.start * audio.samplerate / self.round_factor) * self.round_factor + duration = ( + math.floor(segment.end * audio.samplerate / self.round_factor) * self.round_factor - start ) + else: + raise NotImplementedError(f"RoundingScheme {self.rounding} not implemented.") + + # read audio data + audio.seek(start) + data = audio.read(duration, always_2d=True, dtype=self.output_dtype) + if isinstance(self.multi_channel_strategy, self.PickNth): + data = data[:, self.multi_channel_strategy.channel] + else: + assert data.shape[-1] == 1, "Audio has more than one channel, choose a multi_channel_strategy" + + # resample if necessary + if (sr := self.target_sampling_rate) is not None and sr != audio.samplerate: + data = librosa.resample( + y=data.astype(float), + orig_sr=audio.samplerate, + target_sr=sr, + axis=0, + ).astype(self.output_dtype) + + # add audio to hdf + out_hdf.insert_batch( + inputs=data.reshape(1, -1, 1), + seq_len=[data.shape[0]], + seq_tag=[segment.fullname()], + ) audio.close()