
3. Models

Félix Voituret edited this page Oct 28, 2019 · 15 revisions

Pretrained model

Spleeter is provided with 3 pretrained models:

| Model | Description |
| --- | --- |
| 2 stems model | Performs vocals + accompaniment separation |
| 4 stems model [1] | Performs vocals + drums + bass + other separation |
| 5 stems model | Performs vocals + piano + drums + bass + other separation |

[1] The pretrained 4 stems model performs very well on the standard musDB benchmark, as shown in the associated paper.

Model version

We use GitHub Releases to distribute the pretrained models. Every change in the model training will result in a new release of Spleeter. By default, the model provider will download models for the latest available version.

If you intend to use Spleeter for your research and want to enable full reproducibility, you should specify which version you used [2].

[2] A specific model version can be set using the GITHUB_RELEASE environment variable.
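As a minimal sketch of pinning a version from Python, the variable can be set in the process environment before Spleeter is used. The release tag below is only a placeholder chosen for illustration; check the GitHub Releases page for tags that actually exist.

```python
import os

# Placeholder release tag, used only for illustration. Replace it with the
# release you actually used in your experiments.
PINNED_RELEASE = "v1.4.0"

# Set the variable before any Spleeter code runs in this process, so that the
# model provider can read it when it downloads the pretrained models.
os.environ["GITHUB_RELEASE"] = PINNED_RELEASE

print("Pinned model release:", os.environ["GITHUB_RELEASE"])
```

Exporting the same variable in the shell before launching Spleeter should have the same effect.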

Configuration file

To set up a model for training, you must create a JSON configuration file containing all the parameters needed for training. The same file is then used at prediction time. The parameters that can be set in this file are described below.

Model description

The following parameters describe the model:

| Key | Description |
| --- | --- |
| instrument_list | A list of instrument names (example: ['vocals', 'accompaniment']) |
| mix_name | Prefix used for the mixture filename (default is 'mix') |
| model_dir | Path where the trained model is stored |
| model | Dictionary that describes the model |

The last parameter, model, should contain the following elements:

  • The model type, under the type key. The model builder is dynamically loaded from the spleeter.model.functions package, and you can add your own model builder to this package. Modules for building a U-Net and a Bi-LSTM are already provided (the default is 'unet.unet').
  • A dictionary of model parameters, under the params key. Its content depends on the model: the dictionary is passed to the model builder at build time and can store model-specific parameters such as the number of layers or the number of units per layer. Check the code of the provided models for more details.
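The model description keys can be sketched as a Python dictionary serialized to JSON. The directory name is hypothetical, and params is left empty because its content depends on the chosen model builder.

```python
import json

# Illustrative model description; only the keys come from the tables above.
model_config = {
    "instrument_list": ["vocals", "accompaniment"],
    "mix_name": "mix",
    "model_dir": "my_model",  # hypothetical output directory
    "model": {
        "type": "unet.unet",  # default model builder named above
        "params": {},         # model-specific parameters, left empty here
    },
}

# Serialize to the JSON text that would go into the configuration file.
config_json = json.dumps(model_config, indent=2)
print(config_json)
```

A complete configuration file would also contain the audio, separation, and training keys in the same JSON object.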

Audio parameters

The following parameters control audio signal processing:

| Key | Description |
| --- | --- |
| sample_rate | Sample rate used for loading the audio files (files are resampled to this sample rate) |
| frame_length | Frame length of the short-time Fourier transform |
| frame_step | Frame step of the short-time Fourier transform |
| T | Time length of the input spectrogram segment, expressed in short-time Fourier transform frames [3] |
| F | Number of frequency bins to be processed (frequencies above F are not processed) |
| n_channels | Number of channels of the input audio [4] |
| n_chunks_per_song | Number of audio chunks to be used per song |

[3] Must be a power of 2.

[4] Audio with a different number of channels is discarded at training / validation time and remixed to the provided number of channels at prediction time.
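To get a feel for what these values mean, the short sketch below converts T and F into seconds and hertz. The numbers are illustrative assumptions rather than values taken from this page, and the conversion assumes the usual STFT convention in which bin k sits at k * sample_rate / frame_length.

```python
# Illustrative audio parameters (assumed values, not taken from this page).
audio_params = {
    "sample_rate": 44100,
    "frame_length": 4096,
    "frame_step": 1024,
    "T": 512,
    "F": 1024,
    "n_channels": 2,
}


def is_power_of_two(n):
    """Return True if n is a positive power of two (the constraint on T)."""
    return n > 0 and (n & (n - 1)) == 0


# Approximate duration of one input segment: T frames spaced frame_step apart.
segment_seconds = (
    audio_params["T"] * audio_params["frame_step"] / audio_params["sample_rate"]
)

# Frequency of bin F, above which the model does not process the spectrogram.
cutoff_hz = (
    audio_params["F"] * audio_params["sample_rate"] / audio_params["frame_length"]
)

print("T is a power of two:", is_power_of_two(audio_params["T"]))
print("Segment duration (s):", round(segment_seconds, 2))
print("Cutoff frequency (Hz):", cutoff_hz)
```

With these assumed values, each segment covers roughly 12 seconds of audio and bins above about 11 kHz are left to the mask extension step.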

Separation parameters

The separation process requires the following parameters:

| Key | Description |
| --- | --- |
| separation_exponent | Exponent applied to the magnitude spectrogram to compute ratio masks (defaults to 2, which corresponds to basic Wiener filtering) |
| mask_extension | As the model estimates instrument spectrograms only up to frequency bin F (see the audio parameters above), ratio masks have to be extended for separation. mask_extension can be 'zeros' (masks are set to zero above F, thus discarding these frequencies) or 'average' (for each time frame, masks above F are set to the average value of the estimated mask below F) |
| MWF | Whether to use multi-channel Wiener filtering with Norbert [5] |

[5] Note that multi-channel Wiener filtering is not implemented in TensorFlow and thus might be slow.
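The roles of separation_exponent and mask_extension can be illustrated with a small pure-Python sketch that works on a single time frame. This is not Spleeter's implementation: the function names and the eps stabilizer are assumptions made for the example.

```python
def ratio_masks(magnitudes, exponent=2.0, eps=1e-10):
    """Compute one ratio mask per instrument for a single time frame.

    magnitudes maps each instrument name to its estimated magnitudes,
    one value per frequency bin (F bins in total).
    """
    powered = {
        name: [m ** exponent for m in mags] for name, mags in magnitudes.items()
    }
    n_bins = len(next(iter(powered.values())))
    # Sum over instruments for each bin; eps avoids a division by zero.
    totals = [sum(p[k] for p in powered.values()) + eps for k in range(n_bins)]
    return {
        name: [p[k] / totals[k] for k in range(n_bins)]
        for name, p in powered.items()
    }


def extend_mask(mask, n_total_bins, mode="zeros"):
    """Extend a mask estimated on F bins to the full number of STFT bins."""
    if mode == "zeros":
        fill = 0.0                    # discard frequencies above F
    elif mode == "average":
        fill = sum(mask) / len(mask)  # reuse the mean mask value below F
    else:
        raise ValueError("mode must be 'zeros' or 'average'")
    return mask + [fill] * (n_total_bins - len(mask))


# Toy frame with F = 2 estimated bins, extended to 4 bins.
frame = {"vocals": [3.0, 1.0], "accompaniment": [1.0, 1.0]}
masks = ratio_masks(frame, exponent=2.0)
print(extend_mask(masks["vocals"], 4, mode="zeros"))
print(extend_mask(masks["vocals"], 4, mode="average"))
```

With an exponent of 2, each mask is the instrument's share of the total estimated power in a bin, which is why this setting corresponds to basic Wiener filtering.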

Training parameters

| Key | Description |
| --- | --- |
| train_csv | Path to a CSV file that describes the training dataset [6] |
| validation_csv | Path to a CSV file that describes the validation dataset [6] |
| learning_rate | Learning rate of the optimizer |
| random_seed | Seed used for random generators (randomness is used in model initialization, data providing, and possibly dropout) |
| batch_size | Size of the mini-batch used in optimization |
| training_cache | Path to the location where features are cached for training [7] |
| validation_cache | Path to the location where features are cached for validation [7] |
| train_max_steps | Number of training steps (see the TensorFlow estimators documentation) |
| throttle_secs | See the TensorFlow estimators documentation |
| save_checkpoints_steps | See the TensorFlow estimators documentation |
| save_summary_steps | See the TensorFlow estimators documentation |

[6] The CSV file should contain the following columns: <mix_name>_path (path to the mixture audio file), one <instrument_name>_path column per instrument (path to the audio file of the corresponding isolated instrument), and duration (duration in seconds of the audio files, which should be the same for the mixture and the isolated instrument audio files).

[7] Caching features (spectrograms of the mix and of the isolated instruments) avoids repeated computation and thus improves performance.
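The expected CSV layout can be sketched with the standard csv module. The file paths and durations below are made up for illustration; only the column naming scheme comes from the description above.

```python
import csv
import io

mix_name = "mix"
instrument_list = ["vocals", "accompaniment"]

# Column names: <mix_name>_path, one <instrument_name>_path per instrument,
# then duration.
columns = (
    [f"{mix_name}_path"]
    + [f"{name}_path" for name in instrument_list]
    + ["duration"]
)

# Hypothetical dataset entries: paths and durations are placeholders.
rows = [
    {
        "mix_path": "song1/mix.wav",
        "vocals_path": "song1/vocals.wav",
        "accompaniment_path": "song1/accompaniment.wav",
        "duration": 215.3,
    },
    {
        "mix_path": "song2/mix.wav",
        "vocals_path": "song2/vocals.wav",
        "accompaniment_path": "song2/accompaniment.wav",
        "duration": 187.0,
    },
]

# Write to an in-memory buffer; a real dataset would be written to a file.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=columns)
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)
```

Changing mix_name or instrument_list in the configuration changes the expected column names accordingly.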

Implementation details

The pre-processing pipeline is implemented with the TensorFlow data API, which enables caching of features and labels (spectrograms) and asynchronous data loading. Training uses the TensorFlow estimator API with a custom estimator.

For efficiency reasons, the estimator does not behave the same way in train/evaluate mode and in predict mode:

  • In train/evaluate mode, the estimator takes spectrograms as input (the mix spectrogram as features and the target instrument spectrograms as labels) and outputs estimated spectrograms. This makes it possible to leverage spectrogram caching and avoid recomputing features and labels throughout training.
  • In predict mode, the estimator takes the mixture waveform as input and outputs the estimated isolated instrument waveforms. This makes it possible to run the whole separation graph on GPU, which makes separation very fast (more than 100x real time).
