diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
index 581fc682a..f84cf4bfd 100644
--- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
+++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_claix2023.py
@@ -480,6 +480,49 @@ def py():
         env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
     )
 
+    from .model_ext.ctc_sep_net import ModelSepNet, FeedForwardNet, ctc_training_with_sep_net
+
+    # Time downsampling 6 (standard), spm10k.
+    # Separate FF net.
+    ctc_train_exp(
+        f"n12-spm10k-sepFf_alpha05-auxAED-b150k",
+        config_96gb_bf16_accgrad1,
+        train_def=ctc_training_with_sep_net,
+        model_config={
+            "ctc_model_cls": rf.build_dict(ModelSepNet)["class"],
+            "separate_enc_net": rf.build_dict(FeedForwardNet),
+            "enc_conformer_layer": rf.build_dict(
+                ConformerEncoderLayer,
+                ff=rf.build_dict(
+                    ConformerPositionwiseFeedForward, activation=rf.build_dict(rf.relu_square), with_bias=False
+                ),
+                num_heads=8,
+            ),
+            "feature_batch_norm": True,
+            "num_enc_layers": 12,
+        },
+        config_updates={
+            **_get_cfg_lrlin_oclr_by_bs_nep_v3(150_000, 100, batch_size_factor=_batch_size_factor),
+            "optimizer.weight_decay": 1e-2,
+            "max_seq_length_default_target": None,
+            # Note on max seq len stats: Before, when we used max_seq_length_default_target=75 with bpe10k,
+            # out of 281241 seqs in train, we removed only 71 seqs.
+            # With max seq len 19.5 secs on the audio, we also remove exactly 71 seqs.
+            "max_seq_length_default_input": 19.5 * _raw_sample_rate,
+            "__train_audio_preprocess": speed_pert_librosa_config,
+            "speed_pert_discrete_values": [0.7, 0.8, 0.9, 1.0, 1.1],
+            "aux_attention_decoder": rf.build_dict(TransformerDecoder, num_layers=6),  # purely used for training
+            "use_fixed_ctc_grad": "v2",
+            "sep_net_grad_interpolate_alpha": 0.5,
+        },
+        post_config_updates={"log_grad_norm": True, "__multi_proc_dataset_opts": {"num_workers": 25}},
+        vocab="spm10k",
+        train_vocab_opts={"other_opts": {"class": "SamplingBytePairEncoding", "breadth_prob": 0.01}},
+        dataset_train_opts={"train_epoch_split": 1, "train_epoch_wise_filter": None},
+        # avoid OOM
+        env_updates={"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"},
+    )
+
     from i6_experiments.common.setups import serialization
     from sisyphus import gs
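
Note on `sep_net_grad_interpolate_alpha`: the actual logic lives in `model_ext/ctc_sep_net.py` (`ModelSepNet`, `FeedForwardNet`, `ctc_training_with_sep_net`), which is not part of this diff. Purely as an illustration of one plausible reading — the separate FF net contributes only to the backward pass, receiving a fraction `alpha` of the encoder gradient while the forward output stays that of the main branch — here is a minimal plain-PyTorch sketch; `grad_interpolate` and its exact semantics are assumptions, not the repo's API:

```python
import torch


def grad_interpolate(main: torch.Tensor, sep: torch.Tensor, alpha: float) -> torch.Tensor:
    """Forward value equals `main`; in backward, the incoming gradient is
    scaled by (1 - alpha) towards `main` and by alpha towards `sep`.
    (Assumed semantics of sep_net_grad_interpolate_alpha; hypothetical helper.)"""
    # (1 - alpha) * main + alpha * sep carries the gradients; the detached
    # correction term restores the forward value to exactly `main`.
    return (1.0 - alpha) * main + alpha * sep + (alpha * (main - sep)).detach()


if __name__ == "__main__":
    main = torch.randn(4, requires_grad=True)
    sep = torch.randn(4, requires_grad=True)
    out = grad_interpolate(main, sep, alpha=0.5)
    assert torch.allclose(out, main)  # forward pass is unchanged
    out.sum().backward()
    print(main.grad)  # all 0.5 (= 1 - alpha)
    print(sep.grad)  # all 0.5 (= alpha)
```

With `alpha=0.5` as in the `sepFf_alpha05` experiment above, both branches would receive equal gradient weight while the forward computation still sees only the main encoder output.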