diff --git a/2024-orthogonal-softmax/LBS-960/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config b/2024-orthogonal-softmax/LBS-960/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config new file mode 100644 index 00000000..ecca68a5 --- /dev/null +++ b/2024-orthogonal-softmax/LBS-960/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config @@ -0,0 +1,1024 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +debug_print_layer_output_template = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.h7DH1ILPAElF/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.vESciFN5It1u/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.9335e-06, + 3.8669999999999996e-06, + 3.8005e-06, + 3.7339999999999997e-06, + 3.6675e-06, + 3.601e-06, + 3.5344999999999998e-06, + 3.4679999999999997e-06, + 3.4015e-06, + 3.335e-06, + 3.2685e-06, + 3.202e-06, + 3.1355e-06, + 3.0689999999999998e-06, + 3.0024999999999996e-06, + 2.936e-06, + 2.8695000000000002e-06, + 2.803e-06, + 2.7365e-06, + 2.67e-06, + 2.6034999999999997e-06, + 2.537e-06, + 2.4705e-06, + 2.404e-06, + 2.3375e-06, + 2.271e-06, + 2.2045e-06, + 2.138e-06, + 2.0715e-06, + 2.005e-06, + 1.9385e-06, + 1.872e-06, + 1.8055e-06, + 1.7390000000000002e-06, + 1.6725e-06, + 1.606e-06, + 1.5395000000000003e-06, + 1.4730000000000001e-06, + 1.4065e-06, + 1.3399999999999999e-06, + 1.2735000000000002e-06, + 1.207e-06, + 1.1405e-06, + 1.0740000000000002e-06, + 1.0075e-06, + 9.41e-07, + 8.745000000000003e-07, + 8.080000000000001e-07, + 7.415e-07, + 6.750000000000003e-07, + 6.085000000000002e-07, + 5.420000000000001e-07, + 4.7550000000000036e-07, + 4.0900000000000023e-07, + 3.425000000000001e-07, + 2.760000000000004e-07, + 2.0950000000000028e-07, + 1.4300000000000016e-07, + 7.650000000000003e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 60 +model = "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/training/ReturnnTrainingJob.ctyBeRAC6BBa/output/models/epoch" +num_epochs = 600 +num_inputs = 1 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.VZM5dHZhqlnJ/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.SYt8A5fOy2ta/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC(512), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + ), + pct_params_set=[0.314, 0.451, 0.588, 0.793, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 2, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 366000, + }, + ), + target_size=79, + recog_param_pct=1, + stage_1_global_steps=439200, + params_kwargs={ + "num_params": [ + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + ], + "rest_params": 110, + "total_params": 620, + }, + aux_loss_scales={ + "0.314": "focal_loss", + "0.451": "focal_loss", + "0.588": "focal_loss", + "0.793": "focal_loss", + "1": 1, + }, + final_dropout=0, +) +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + train_step, +) diff --git a/2024-orthogonal-softmax/LBS-960/Tab5_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config b/2024-orthogonal-softmax/LBS-960/Tab5_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config new file mode 100644 index 00000000..ed3aa669 --- /dev/null +++ b/2024-orthogonal-softmax/LBS-960/Tab5_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config @@ -0,0 +1,1018 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +debug_print_layer_output_template = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.h7DH1ILPAElF/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.vESciFN5It1u/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.9335e-06, + 3.8669999999999996e-06, + 3.8005e-06, + 3.7339999999999997e-06, + 3.6675e-06, + 3.601e-06, + 3.5344999999999998e-06, + 3.4679999999999997e-06, + 3.4015e-06, + 3.335e-06, + 3.2685e-06, + 3.202e-06, + 3.1355e-06, + 3.0689999999999998e-06, + 3.0024999999999996e-06, + 2.936e-06, + 2.8695000000000002e-06, + 2.803e-06, + 2.7365e-06, + 2.67e-06, + 2.6034999999999997e-06, + 2.537e-06, + 2.4705e-06, + 2.404e-06, + 2.3375e-06, + 2.271e-06, + 2.2045e-06, + 2.138e-06, + 2.0715e-06, + 2.005e-06, + 1.9385e-06, + 1.872e-06, + 1.8055e-06, + 1.7390000000000002e-06, + 1.6725e-06, + 1.606e-06, + 1.5395000000000003e-06, + 1.4730000000000001e-06, + 1.4065e-06, + 1.3399999999999999e-06, + 1.2735000000000002e-06, + 1.207e-06, + 1.1405e-06, + 1.0740000000000002e-06, + 1.0075e-06, + 9.41e-07, + 8.745000000000003e-07, + 8.080000000000001e-07, + 7.415e-07, + 6.750000000000003e-07, + 6.085000000000002e-07, + 5.420000000000001e-07, + 4.7550000000000036e-07, + 4.0900000000000023e-07, + 3.425000000000001e-07, + 2.760000000000004e-07, + 2.0950000000000028e-07, + 1.4300000000000016e-07, + 7.650000000000003e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 60 +model = "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/training/ReturnnTrainingJob.uhAaZznpfnbW/output/models/epoch" +num_epochs = 600 +num_inputs = 1 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_core/returnn/hdf/BlissToPcmHDFJob.VZM5dHZhqlnJ/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.SYt8A5fOy2ta/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2023-10-17-torch-conformer-ctc/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC(512), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + ), + pct_params_set=[0.45, 0.72, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 2, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 439200, + }, + ), + target_size=79, + recog_param_pct=1, + stage_1_global_steps=439200, + params_kwargs={ + "num_params": [ + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 3.675, + 5.65, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 0.92625, + 3.675, + 3.675, + 3.675, + 3.675, + ], + "rest_params": 110, + "total_params": 620, + }, + aux_loss_scales={"0.45": "focal_loss", "0.72": "focal_loss", "1": 1}, + final_dropout=0, +) +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.lbs_960.pytorch_networks.dynamic_encoder_size.independent_softmax.joint_train_num_params_random_pct import ( + train_step, +) diff --git a/2024-orthogonal-softmax/README.md b/2024-orthogonal-softmax/README.md index 4552316c..5eff80ca 100644 --- a/2024-orthogonal-softmax/README.md +++ b/2024-orthogonal-softmax/README.md @@ -4,4 +4,5 @@ Efficient Supernet Training with Orthogonal Softmax for Scalable ASR Model Compr We use [RETURNN](https://github.com/rwth-i6/returnn) for training and our setups are based on [Sisyphus](https://github.com/rwth-i6/sisyphus). +We use models parts from [i6-models](https://github.com/rwth-i6/i6_models/tree/jing-dynamic-encoder-size) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config new file mode 100644 index 00000000..0fdbfd9d --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Figure1_orthosoftmax_jointly_train_five_models_FLOPs_aware_cmp_wise.config @@ -0,0 +1,646 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.X6hEjDbuRROU/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + ), + pct_params_set=[0.337, 0.457, 0.578, 0.759, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + ), + target_size=79, + recog_param_pct=1, + stage_1_global_steps=225000, + params_kwargs={ + "num_params": [ + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + ], + "rest_params": 110, + "total_params": 400, + }, + aux_loss_scales={ + "0.337": "focal_loss", + "0.457": "focal_loss", + "0.578": "focal_loss", + "0.759": "focal_loss", + "1": 1, + }, + final_dropout=0.2, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_cmp_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_cmp_wise.config new file mode 100644 index 00000000..a420f283 --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_cmp_wise.config @@ -0,0 +1,645 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.b0OtMsOnoNGK/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_based_on_num_params import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_based_on_num_params import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=5, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + apply_mhsa_adaptive_dropout=False, + ), + pct_params_set=[0.63, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + component_dropout=None, + ), + target_size=79, + recog_param_pct=1, + param_pct_anneal_kwargs={ + "pct_anneal_num_steps_per_iter": 225000.0, + "pct_reduction_per_iter": 0.5, + }, + params_kwargs={ + "num_params": [ + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + ], + "rest_params": 110, + "total_params": 400, + }, + aux_loss_scales={"0.63": "focal_loss", "1": 1}, + final_dropout=0.2, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_layer_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_layer_wise.config new file mode 100644 index 00000000..87db0a7b --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_FLOPs_aware_layer_wise.config @@ -0,0 +1,511 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.AqL5F3LvpxEM/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_layerwise_num_params_aware import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_layerwise_num_params_aware import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer.convolution import ConformerConvolutionV1Config +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=5, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0, + modules=["ff", "conv", "mhsa", "ff"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + pct_params_set=[0.63, 1], + layer_dropout_kwargs={"layer_dropout_stage_1": 0, "layer_dropout_stage_2": 0.2}, + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + ), + target_size=79, + recog_param_pct=1, + param_pct_anneal_kwargs={ + "pct_anneal_num_steps_per_iter": 225000.0, + "pct_reduction_per_iter": 0.5, + }, + params_kwargs={ + "num_params": [ + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + 8.27, + 3.21, + 4.18, + 8.27, + ], + "rest_params": 110, + "total_params": 400, + }, + aux_loss_scales={"0.63": "focal_loss", "1": 1}, + final_dropout=0, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_cmp_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_cmp_wise.config new file mode 100644 index 00000000..ec1ee136 --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_cmp_wise.config @@ -0,0 +1,645 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.BLZiCTx0OG4e/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_based_on_num_params import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_based_on_num_params import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=5, + freq_mask_max_size=8, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + apply_mhsa_adaptive_dropout=False, + ), + pct_params_set=[0.5, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + component_dropout=None, + ), + target_size=79, + recog_param_pct=1, + param_pct_anneal_kwargs={ + "pct_anneal_num_steps_per_iter": 225000.0, + "pct_reduction_per_iter": 0.5, + }, + params_kwargs={ + "num_params": [ + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + ], + "rest_params": 1167342, + "total_params": 42129806, + }, + aux_loss_scales={"0.5": "focal_loss", "1": 1}, + final_dropout=0.2, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_layer_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_layer_wise.config new file mode 100644 index 00000000..270d361a --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab2_orthosoftmax_jointly_train_two_models_sparsity_aware_layer_wise.config @@ -0,0 +1,511 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.akSFAZxotWPD/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_layerwise_num_params_aware import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_independent_softmax_layerwise_num_params_aware import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer.convolution import ConformerConvolutionV1Config +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=5, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0, + modules=["ff", "conv", "mhsa", "ff"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + pct_params_set=[0.5, 1], + layer_dropout_kwargs={"layer_dropout_stage_1": 0, "layer_dropout_stage_2": 0.3}, + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + ), + target_size=79, + recog_param_pct=1, + param_pct_anneal_kwargs={ + "pct_anneal_num_steps_per_iter": 225000.0, + "pct_reduction_per_iter": 0.5, + }, + params_kwargs={ + "num_params": [ + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + 1182336, + 458112, + 592128, + 1182336, + ], + "rest_params": 1167342, + "total_params": 42129806, + }, + aux_loss_scales={"0.5": "focal_loss", "1": 1}, + final_dropout=0, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_layers.joint_train_two_models_layerwise_num_params_aware import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config new file mode 100644 index 00000000..9badbab5 --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_FLOPs_aware_cmp_wise.config @@ -0,0 +1,640 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.0JAIDWPfXDah/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=5, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.2, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.2 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.2, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + ), + pct_params_set=[0.52, 0.75, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + ), + target_size=79, + recog_param_pct=1, + stage_1_global_steps=225000, + params_kwargs={ + "num_params": [ + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + 3.21, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 0.6966666667, + 2.0675, + 2.0675, + 2.0675, + 2.0675, + ], + "rest_params": 110, + "total_params": 400, + }, + aux_loss_scales={"0.52": "focal_loss", "0.75": "focal_loss", "1": 1}, + final_dropout=0.2, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + train_step, +) diff --git a/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_sparsity_aware_cmp_wise.config b/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_sparsity_aware_cmp_wise.config new file mode 100644 index 00000000..10546d85 --- /dev/null +++ b/2024-orthogonal-softmax/TED-LIUM-v2/Tab4_orthosoftmax_jointly_train_three_models_sparsity_aware_cmp_wise.config @@ -0,0 +1,640 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2240000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 7.6e-06, + 1.1200000000000001e-05, + 1.48e-05, + 1.84e-05, + 2.2e-05, + 2.5600000000000002e-05, + 2.92e-05, + 3.2800000000000004e-05, + 3.6400000000000004e-05, + 4e-05, + 4.36e-05, + 4.720000000000001e-05, + 5.080000000000001e-05, + 5.440000000000001e-05, + 5.800000000000001e-05, + 6.16e-05, + 6.520000000000001e-05, + 6.88e-05, + 7.240000000000001e-05, + 7.6e-05, + 7.960000000000001e-05, + 8.32e-05, + 8.680000000000001e-05, + 9.040000000000002e-05, + 9.400000000000001e-05, + 9.760000000000001e-05, + 0.00010120000000000001, + 0.00010480000000000001, + 0.0001084, + 0.00011200000000000001, + 0.0001156, + 0.00011920000000000001, + 0.0001228, + 0.0001264, + 0.00013, + 0.0001336, + 0.0001372, + 0.0001408, + 0.00014439999999999999, + 0.000148, + 0.0001516, + 0.0001552, + 0.0001588, + 0.0001624, + 0.000166, + 0.0001696, + 0.0001732, + 0.00017680000000000001, + 0.0001804, + 0.000184, + 0.0001876, + 0.0001912, + 0.0001948, + 0.0001984, + 0.000202, + 0.0002056, + 0.00020920000000000002, + 0.0002128, + 0.0002164, + 0.00022, + 0.00022360000000000001, + 0.0002272, + 0.0002308, + 0.0002344, + 0.000238, + 0.00024160000000000002, + 0.0002452, + 0.00024880000000000003, + 0.0002524, + 0.000256, + 0.0002596, + 0.0002632, + 0.00026680000000000003, + 0.0002704, + 0.000274, + 0.0002776, + 0.0002812, + 0.0002848, + 0.0002884, + 0.000292, + 0.00029560000000000003, + 0.0002992, + 0.0003028, + 0.0003064, + 0.00031, + 0.00031360000000000003, + 0.0003172, + 0.0003208, + 0.0003244, + 0.000328, + 0.00033160000000000004, + 0.0003352, + 0.0003388, + 0.00034240000000000003, + 0.000346, + 0.00034960000000000004, + 0.0003532, + 0.0003568, + 0.00036040000000000003, + 0.000364, + 0.0003676, + 0.0003712, + 0.0003748, + 0.00037840000000000004, + 0.000382, + 0.0003856, + 0.00038920000000000003, + 0.0003928, + 0.00039640000000000004, + 0.0004, + 0.00039672727272727277, + 0.00039345454545454547, + 0.0003901818181818182, + 0.0003869090909090909, + 0.00038363636363636367, + 0.00038036363636363636, + 0.0003770909090909091, + 0.0003738181818181818, + 0.00037054545454545456, + 0.00036727272727272726, + 0.000364, + 0.00036072727272727276, + 0.00035745454545454546, + 0.0003541818181818182, + 0.0003509090909090909, + 0.00034763636363636366, + 0.00034436363636363636, + 0.0003410909090909091, + 0.00033781818181818186, + 0.00033454545454545456, + 0.0003312727272727273, + 0.000328, + 0.00032472727272727276, + 0.00032145454545454545, + 0.0003181818181818182, + 0.00031490909090909095, + 0.00031163636363636365, + 0.00030836363636363635, + 0.0003050909090909091, + 0.00030181818181818185, + 0.00029854545454545455, + 0.0002952727272727273, + 0.000292, + 0.00028872727272727275, + 0.00028545454545454544, + 0.0002821818181818182, + 0.00027890909090909095, + 0.00027563636363636364, + 0.00027236363636363634, + 0.0002690909090909091, + 0.00026581818181818184, + 0.00026254545454545454, + 0.0002592727272727273, + 0.00025600000000000004, + 0.00025272727272727274, + 0.00024945454545454544, + 0.0002461818181818182, + 0.0002429090909090909, + 0.00023963636363636364, + 0.00023636363636363636, + 0.0002330909090909091, + 0.00022981818181818184, + 0.00022654545454545456, + 0.00022327272727272728, + 0.00022, + 0.00021672727272727273, + 0.00021345454545454546, + 0.00021018181818181818, + 0.0002069090909090909, + 0.00020363636363636366, + 0.00020036363636363638, + 0.0001970909090909091, + 0.00019381818181818183, + 0.00019054545454545455, + 0.00018727272727272728, + 0.000184, + 0.00018072727272727272, + 0.00017745454545454545, + 0.0001741818181818182, + 0.00017090909090909092, + 0.00016763636363636365, + 0.00016436363636363637, + 0.0001610909090909091, + 0.00015781818181818182, + 0.00015454545454545457, + 0.00015127272727272727, + 0.00014800000000000002, + 0.00014472727272727272, + 0.00014145454545454547, + 0.00013818181818181816, + 0.00013490909090909092, + 0.0001316363636363636, + 0.00012836363636363636, + 0.00012509090909090912, + 0.00012181818181818181, + 0.00011854545454545456, + 0.00011527272727272726, + 0.00011200000000000001, + 0.00010872727272727271, + 0.00010545454545454546, + 0.00010218181818181816, + 9.890909090909091e-05, + 9.563636363636366e-05, + 9.236363636363636e-05, + 8.909090909090911e-05, + 8.58181818181818e-05, + 8.254545454545456e-05, + 7.927272727272725e-05, + 7.6e-05, + 7.27272727272727e-05, + 6.945454545454545e-05, + 6.61818181818182e-05, + 6.29090909090909e-05, + 5.963636363636365e-05, + 5.636363636363635e-05, + 5.30909090909091e-05, + 4.98181818181818e-05, + 4.654545454545455e-05, + 4.3272727272727245e-05, + 4e-05, + 3.8621034482758623e-05, + 3.7242068965517244e-05, + 3.5863103448275864e-05, + 3.4484137931034484e-05, + 3.3105172413793104e-05, + 3.172620689655173e-05, + 3.0347241379310348e-05, + 2.8968275862068968e-05, + 2.7589310344827588e-05, + 2.621034482758621e-05, + 2.4831379310344832e-05, + 2.3452413793103452e-05, + 2.2073448275862072e-05, + 2.0694482758620692e-05, + 1.9315517241379313e-05, + 1.7936551724137933e-05, + 1.6557586206896553e-05, + 1.5178620689655173e-05, + 1.3799655172413793e-05, + 1.2420689655172413e-05, + 1.1041724137931037e-05, + 9.662758620689657e-06, + 8.283793103448274e-06, + 6.904827586206901e-06, + 5.525862068965521e-06, + 4.146896551724141e-06, + 2.7679310344827613e-06, + 1.3889655172413814e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/training/ReturnnTrainingJob.dDPnfoIO1pmi/output/models/epoch" +num_epochs = 250 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert(0, "/u/jxu/setups/tedlium2/2024-05-14--independent-softmax/recipe") +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_experiments.users.berger.pytorch.custom_parts.specaugment import ( + SpecaugmentByLengthConfigV1, +) +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerEncoderConfig, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_models.assemblies.conformer_with_dynamic_model_size.selection_with_ortho_softmax_cmp_wise import ( + ConformerBlockConfig, +) +from i6_models.parts.conformer_structure_prune.feedforward import ( + ConformerPositionwiseFeedForwardV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models.parts.conformer_structure_prune.mhsa import ConformerMHSAV1Config +from i6_models.parts.conformer_structure_prune.convolution import ( + ConformerConvolutionV1Config, +) +from torch.nn.modules.activation import SiLU +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaugment_cfg=SpecaugmentByLengthConfigV1( + time_min_num_masks=2, + time_max_mask_per_n_frames=25, + time_mask_max_size=20, + freq_min_num_masks=2, + freq_max_num_masks=16, + freq_mask_max_size=5, + ), + conformer_cfg=ConformerEncoderConfig( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfg=ConformerBlockConfig( + ff_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, num_att_heads=6, att_weights_dropout=0.1, dropout=0.1 + ), + conv_cfg=ConformerConvolutionV1Config( + channels=384, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((384,), eps=1e-05, elementwise_affine=True), + ), + layer_dropout=0.1, + apply_ff_adaptive_dropout=True, + ), + pct_params_set=[0.352, 0.67, 1], + softmax_kwargs={ + "softmax_constraint_norm": "L2_norm_sqrt", + "initial_tau": 1, + "min_tau": 0.01, + "tau_annealing": 0.999992, + "softmax_constraint_loss_scale": "linear_increase", + "max_softmax_constraint_loss_scale": 1, + "softmax_constraint_warmup_steps": 225000, + }, + ), + target_size=79, + recog_param_pct=1, + stage_1_global_steps=225000, + params_kwargs={ + "num_params": [ + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 458112, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + ], + "rest_params": 1167342, + "total_params": 42129806, + }, + aux_loss_scales={"0.352": "focal_loss", "0.67": "focal_loss", "1": 1}, + final_dropout=0.2, +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) + + +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.independent_softmax.jointly_train_two_models.num_params.conformer_size_384_log_mel_ffn_dim_conv_attn_heads_decomposable_random_pct import ( + train_step, +)