From d20987052a94273844d8ae8cf54724af5b8c4edf Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Mon, 4 Aug 2025 17:39:22 +0200 Subject: [PATCH] update --- ...t_epochs_90_pct_0.15_iters_1_0.0004.config | 1350 +++++++++++++++++ ...t_epochs_90_pct_0.15_iters_1_0.0004.config | 1304 ++++++++++++++++ .../README.md | 12 + ...pt_epochs_50_pct_0.2_iters_1_0.0005.config | 976 ++++++++++++ 4 files changed, 3642 insertions(+) create mode 100644 2025-dynamic-model-architecture-optimization/LBS-960/conformer_double_prune_model_size_512_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config create mode 100644 2025-dynamic-model-architecture-optimization/LBS-960/ebranchformer_double_prune_model_size_384_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config create mode 100644 2025-dynamic-model-architecture-optimization/TED-LIUM-v2/ebranchformer_double_prune_model_size_512_adjust_dropout_True_adapt_epochs_50_pct_0.2_iters_1_0.0005.config diff --git a/2025-dynamic-model-architecture-optimization/LBS-960/conformer_double_prune_model_size_512_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config b/2025-dynamic-model-architecture-optimization/LBS-960/conformer_double_prune_model_size_512_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config new file mode 100644 index 00000000..cb53c572 --- /dev/null +++ b/2025-dynamic-model-architecture-optimization/LBS-960/conformer_double_prune_model_size_512_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config @@ -0,0 +1,1350 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.K32gY5lpp41M/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.YH5GlkJveVPE/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.932372881355932e-06, + 3.8647457627118645e-06, + 3.7971186440677964e-06, + 3.7294915254237288e-06, + 3.661864406779661e-06, + 3.594237288135593e-06, + 3.5266101694915254e-06, + 3.4589830508474577e-06, + 3.3913559322033896e-06, + 3.323728813559322e-06, + 3.2561016949152543e-06, + 3.1884745762711862e-06, + 3.1208474576271186e-06, + 3.053220338983051e-06, + 2.985593220338983e-06, + 2.917966101694915e-06, + 2.8503389830508475e-06, + 2.7827118644067794e-06, + 2.7150847457627118e-06, + 2.647457627118644e-06, + 2.579830508474576e-06, + 2.5122033898305084e-06, + 2.4445762711864407e-06, + 2.3769491525423726e-06, + 2.309322033898305e-06, + 2.2416949152542373e-06, + 2.1740677966101692e-06, + 2.1064406779661016e-06, + 2.038813559322034e-06, + 1.9711864406779662e-06, + 1.9035593220338982e-06, + 1.8359322033898305e-06, + 1.7683050847457628e-06, + 1.7006779661016948e-06, + 1.633050847457627e-06, + 1.5654237288135594e-06, + 1.4977966101694914e-06, + 1.4301694915254237e-06, + 1.362542372881356e-06, + 1.2949152542372884e-06, + 1.2272881355932203e-06, + 1.1596610169491526e-06, + 1.092033898305085e-06, + 1.024406779661017e-06, + 9.567796610169492e-07, + 8.891525423728816e-07, + 8.215254237288135e-07, + 7.538983050847458e-07, + 6.862711864406782e-07, + 6.186440677966101e-07, + 5.510169491525424e-07, + 4.833898305084748e-07, + 4.157627118644067e-07, + 3.4813559322033905e-07, + 2.805084745762714e-07, + 2.128813559322033e-07, + 1.4525423728813607e-07, + 7.762711864406799e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/training/ReturnnTrainingJob.Gml7clbIF93o/output/models/epoch" +num_epochs = 600 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.J6Jzn8HrWB9r/output/audio.hdf" + ], + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.MHO753hGyR2z/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_conformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + ConformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_conformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + ConformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_models.assemblies.dynamic_adaptable_conformer.conformer_v1_adapt_dropout import ( + ConformerEncoderV1Config, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_conformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + train_step, +) +from i6_models.assemblies.dynamic_adaptable_conformer.conformer_v1_adapt_dropout import ( + ConformerBlockV1Config, +) +from torch.nn import SiLU +from i6_models.parts.dynamic_adaptable_conformer import ConformerConvolutionV2Config +from i6_models.parts.dynamic_adaptable_conformer import ConformerMHSAV1Config +from i6_models.parts.dynamic_adaptable_conformer import ( + ConformerPositionwiseFeedForwardV1Config, +) +from i6_models_repo.i6_models.parts.conformer.norm import LayerNormNC + +cfg = ConformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaug_args={ + "time_min_num_masks": 2, + "time_max_mask_per_n_frames": 25, + "time_mask_max_size": 20, + "freq_min_num_masks": 2, + "freq_mask_max_size": 5, + "freq_max_num_masks": 8, + }, + conformer_cfg=ConformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfgs=[ + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ConformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + conv_cfg=ConformerConvolutionV2Config( + input_dim=512, + channels=512, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + norm=LayerNormNC((512,), eps=1e-05, elementwise_affine=True), + ), + adjust_dropout=True, + modules=["ff1", "conv", "mhsa", "ff2"], + scales=[0.5, 1.0, 1.0, 0.5], + ), + ], + ), + final_dropout=0.1, + target_size=79, + grad_score_opts={ + "grad_score_update_steps": 1000, + "grad_score_metric": "first_taylor", + "grad_update_beta": 0.1, + }, + adaptation_opts={ + "adaptation_global_step": [109800], + "total_cost": 74165392, + "rest_cost": 1451152, + "lst_replace_pct": [0.15], + "dict_module_cost": { + "ff1": 525184, + "conv": 201600, + "mhsa": 131456, + "ff2": 525184, + }, + }, + component_dist={ + "ff1": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "conv": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "mhsa": [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + "ff2": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + }, + total_num_components=240, + lst_cmp_cost=[ + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 201600, + 201600, + 201600, + 201600, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + ], +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return ConformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) diff --git a/2025-dynamic-model-architecture-optimization/LBS-960/ebranchformer_double_prune_model_size_384_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config b/2025-dynamic-model-architecture-optimization/LBS-960/ebranchformer_double_prune_model_size_384_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config new file mode 100644 index 00000000..78f7b6da --- /dev/null +++ b/2025-dynamic-model-architecture-optimization/LBS-960/ebranchformer_double_prune_model_size_384_freq_mask_50_adapt_epochs_90_pct_0.15_iters_1_0.0004.config @@ -0,0 +1,1304 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2400000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.KErFrKsP3fTh/output/audio.hdf", + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.Clwnntg2nopq/output/audio.hdf", + ], + "partition_epoch": 1, + "seq_ordering": "laplace:.1000", + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.K32gY5lpp41M/output/targets.hdf", + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.YH5GlkJveVPE/output/targets.hdf", + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 4e-06, + 5.466666666666666e-06, + 6.933333333333334e-06, + 8.400000000000001e-06, + 9.866666666666668e-06, + 1.1333333333333336e-05, + 1.28e-05, + 1.4266666666666667e-05, + 1.5733333333333334e-05, + 1.72e-05, + 1.866666666666667e-05, + 2.0133333333333336e-05, + 2.16e-05, + 2.3066666666666667e-05, + 2.4533333333333334e-05, + 2.6000000000000002e-05, + 2.746666666666667e-05, + 2.8933333333333336e-05, + 3.0400000000000004e-05, + 3.186666666666667e-05, + 3.333333333333334e-05, + 3.4800000000000006e-05, + 3.6266666666666676e-05, + 3.773333333333334e-05, + 3.9200000000000004e-05, + 4.0666666666666675e-05, + 4.213333333333334e-05, + 4.360000000000001e-05, + 4.506666666666667e-05, + 4.6533333333333344e-05, + 4.800000000000001e-05, + 4.946666666666668e-05, + 5.093333333333334e-05, + 5.2400000000000007e-05, + 5.386666666666668e-05, + 5.533333333333334e-05, + 5.680000000000001e-05, + 5.8266666666666676e-05, + 5.9733333333333346e-05, + 6.120000000000001e-05, + 6.266666666666668e-05, + 6.413333333333334e-05, + 6.560000000000001e-05, + 6.706666666666668e-05, + 6.853333333333335e-05, + 7.000000000000001e-05, + 7.146666666666668e-05, + 7.293333333333335e-05, + 7.44e-05, + 7.586666666666668e-05, + 7.733333333333335e-05, + 7.880000000000002e-05, + 8.026666666666668e-05, + 8.173333333333335e-05, + 8.320000000000002e-05, + 8.466666666666669e-05, + 8.613333333333334e-05, + 8.760000000000002e-05, + 8.906666666666669e-05, + 9.053333333333334e-05, + 9.200000000000001e-05, + 9.346666666666668e-05, + 9.493333333333336e-05, + 9.640000000000001e-05, + 9.786666666666668e-05, + 9.933333333333335e-05, + 0.00010080000000000001, + 0.00010226666666666668, + 0.00010373333333333335, + 0.00010520000000000002, + 0.00010666666666666668, + 0.00010813333333333335, + 0.00010960000000000002, + 0.00011106666666666668, + 0.00011253333333333335, + 0.00011400000000000002, + 0.00011546666666666669, + 0.00011693333333333335, + 0.00011840000000000002, + 0.00011986666666666669, + 0.00012133333333333336, + 0.0001228, + 0.0001242666666666667, + 0.00012573333333333334, + 0.0001272, + 0.00012866666666666669, + 0.00013013333333333334, + 0.0001316, + 0.00013306666666666668, + 0.00013453333333333334, + 0.000136, + 0.00013746666666666668, + 0.00013893333333333334, + 0.0001404, + 0.00014186666666666668, + 0.00014333333333333334, + 0.0001448, + 0.00014626666666666668, + 0.00014773333333333334, + 0.00014920000000000002, + 0.00015066666666666668, + 0.00015213333333333334, + 0.00015360000000000002, + 0.00015506666666666668, + 0.00015653333333333333, + 0.00015800000000000002, + 0.00015946666666666668, + 0.00016093333333333333, + 0.00016240000000000002, + 0.00016386666666666667, + 0.00016533333333333336, + 0.00016680000000000002, + 0.00016826666666666667, + 0.00016973333333333336, + 0.00017120000000000001, + 0.00017266666666666667, + 0.00017413333333333336, + 0.0001756, + 0.00017706666666666667, + 0.00017853333333333335, + 0.00018, + 0.00018146666666666667, + 0.00018293333333333335, + 0.0001844, + 0.0001858666666666667, + 0.00018733333333333335, + 0.0001888, + 0.0001902666666666667, + 0.00019173333333333335, + 0.0001932, + 0.0001946666666666667, + 0.00019613333333333335, + 0.0001976, + 0.0001990666666666667, + 0.00020053333333333335, + 0.00020200000000000003, + 0.0002034666666666667, + 0.00020493333333333335, + 0.00020640000000000003, + 0.0002078666666666667, + 0.00020933333333333334, + 0.00021080000000000003, + 0.00021226666666666669, + 0.00021373333333333334, + 0.00021520000000000003, + 0.00021666666666666668, + 0.00021813333333333334, + 0.00021960000000000003, + 0.00022106666666666668, + 0.00022253333333333337, + 0.00022400000000000002, + 0.00022546666666666668, + 0.00022693333333333337, + 0.00022840000000000002, + 0.00022986666666666668, + 0.00023133333333333336, + 0.00023280000000000002, + 0.00023426666666666668, + 0.00023573333333333336, + 0.00023720000000000002, + 0.0002386666666666667, + 0.00024013333333333336, + 0.00024160000000000002, + 0.0002430666666666667, + 0.0002445333333333334, + 0.000246, + 0.0002474666666666667, + 0.0002489333333333334, + 0.0002504, + 0.0002518666666666667, + 0.0002533333333333334, + 0.0002548, + 0.0002562666666666667, + 0.0002577333333333334, + 0.0002592, + 0.0002606666666666667, + 0.0002621333333333334, + 0.0002636, + 0.0002650666666666667, + 0.0002665333333333334, + 0.000268, + 0.0002694666666666667, + 0.0002709333333333334, + 0.0002724, + 0.0002738666666666667, + 0.0002753333333333334, + 0.0002768, + 0.0002782666666666667, + 0.0002797333333333334, + 0.0002812, + 0.0002826666666666667, + 0.0002841333333333334, + 0.0002856, + 0.0002870666666666667, + 0.00028853333333333337, + 0.00029000000000000006, + 0.0002914666666666667, + 0.00029293333333333337, + 0.00029440000000000005, + 0.0002958666666666667, + 0.00029733333333333337, + 0.00029880000000000005, + 0.0003002666666666667, + 0.00030173333333333337, + 0.00030320000000000005, + 0.0003046666666666667, + 0.00030613333333333337, + 0.00030760000000000005, + 0.0003090666666666667, + 0.00031053333333333336, + 0.00031200000000000005, + 0.0003134666666666667, + 0.00031493333333333336, + 0.00031640000000000005, + 0.0003178666666666667, + 0.00031933333333333336, + 0.00032080000000000005, + 0.0003222666666666667, + 0.00032373333333333336, + 0.00032520000000000004, + 0.00032666666666666673, + 0.00032813333333333336, + 0.00032960000000000004, + 0.0003310666666666667, + 0.00033253333333333336, + 0.00033400000000000004, + 0.0003354666666666667, + 0.00033693333333333336, + 0.00033840000000000004, + 0.0003398666666666667, + 0.00034133333333333335, + 0.00034280000000000004, + 0.0003442666666666667, + 0.00034573333333333335, + 0.00034720000000000004, + 0.0003486666666666667, + 0.00035013333333333335, + 0.00035160000000000004, + 0.0003530666666666667, + 0.00035453333333333335, + 0.00035600000000000003, + 0.0003574666666666667, + 0.00035893333333333335, + 0.00036040000000000003, + 0.0003618666666666667, + 0.0003633333333333334, + 0.00036480000000000003, + 0.0003662666666666667, + 0.0003677333333333334, + 0.00036920000000000003, + 0.0003706666666666667, + 0.0003721333333333334, + 0.00037360000000000003, + 0.0003750666666666667, + 0.0003765333333333334, + 0.000378, + 0.0003794666666666667, + 0.0003809333333333334, + 0.0003824, + 0.0003838666666666667, + 0.0003853333333333334, + 0.0003868, + 0.0003882666666666667, + 0.0003897333333333334, + 0.0003912, + 0.0003926666666666667, + 0.0003941333333333334, + 0.0003956, + 0.0003970666666666667, + 0.0003985333333333334, + 0.0004, + 0.00039853333333333333, + 0.0003970666666666667, + 0.0003956, + 0.00039413333333333334, + 0.0003926666666666667, + 0.0003912, + 0.00038973333333333334, + 0.0003882666666666667, + 0.0003868, + 0.00038533333333333334, + 0.0003838666666666667, + 0.0003824, + 0.00038093333333333334, + 0.0003794666666666667, + 0.000378, + 0.00037653333333333334, + 0.00037506666666666666, + 0.00037360000000000003, + 0.00037213333333333334, + 0.00037066666666666666, + 0.00036920000000000003, + 0.00036773333333333335, + 0.00036626666666666666, + 0.00036480000000000003, + 0.00036333333333333335, + 0.00036186666666666666, + 0.00036040000000000003, + 0.00035893333333333335, + 0.00035746666666666666, + 0.00035600000000000003, + 0.00035453333333333335, + 0.00035306666666666667, + 0.00035160000000000004, + 0.00035013333333333335, + 0.00034866666666666667, + 0.0003472, + 0.00034573333333333335, + 0.00034426666666666667, + 0.0003428, + 0.00034133333333333335, + 0.00033986666666666667, + 0.0003384, + 0.00033693333333333336, + 0.00033546666666666667, + 0.000334, + 0.00033253333333333336, + 0.00033106666666666667, + 0.0003296, + 0.00032813333333333336, + 0.0003266666666666667, + 0.0003252, + 0.00032373333333333336, + 0.0003222666666666667, + 0.0003208, + 0.00031933333333333336, + 0.0003178666666666667, + 0.0003164, + 0.00031493333333333336, + 0.0003134666666666667, + 0.000312, + 0.00031053333333333336, + 0.0003090666666666667, + 0.0003076, + 0.00030613333333333337, + 0.0003046666666666667, + 0.0003032, + 0.00030173333333333337, + 0.0003002666666666667, + 0.0002988, + 0.00029733333333333337, + 0.0002958666666666667, + 0.0002944, + 0.00029293333333333337, + 0.0002914666666666667, + 0.00029, + 0.0002885333333333333, + 0.0002870666666666667, + 0.0002856, + 0.0002841333333333333, + 0.00028266666666666663, + 0.0002812, + 0.0002797333333333333, + 0.00027826666666666664, + 0.0002768, + 0.0002753333333333333, + 0.00027386666666666664, + 0.0002724, + 0.0002709333333333333, + 0.00026946666666666664, + 0.000268, + 0.0002665333333333333, + 0.00026506666666666664, + 0.0002636, + 0.0002621333333333333, + 0.00026066666666666664, + 0.0002592, + 0.00025773333333333333, + 0.00025626666666666664, + 0.0002548, + 0.00025333333333333333, + 0.00025186666666666664, + 0.0002504, + 0.00024893333333333333, + 0.00024746666666666665, + 0.000246, + 0.00024453333333333333, + 0.00024306666666666668, + 0.0002416, + 0.00024013333333333333, + 0.00023866666666666665, + 0.0002372, + 0.00023573333333333334, + 0.00023426666666666665, + 0.0002328, + 0.00023133333333333334, + 0.00022986666666666665, + 0.0002284, + 0.00022693333333333334, + 0.00022546666666666665, + 0.000224, + 0.00022253333333333334, + 0.00022106666666666666, + 0.0002196, + 0.00021813333333333331, + 0.00021666666666666666, + 0.0002152, + 0.00021373333333333332, + 0.00021226666666666666, + 0.0002108, + 0.00020933333333333332, + 0.00020786666666666666, + 0.0002064, + 0.00020493333333333332, + 0.00020346666666666666, + 0.00020199999999999998, + 0.00020053333333333332, + 0.00019906666666666666, + 0.00019759999999999998, + 0.00019613333333333332, + 0.00019466666666666666, + 0.00019319999999999998, + 0.00019173333333333332, + 0.00019026666666666667, + 0.00018879999999999998, + 0.00018733333333333332, + 0.00018586666666666667, + 0.00018439999999999998, + 0.00018293333333333333, + 0.00018146666666666664, + 0.00017999999999999998, + 0.00017853333333333333, + 0.00017706666666666664, + 0.00017559999999999999, + 0.00017413333333333333, + 0.00017266666666666664, + 0.0001712, + 0.00016973333333333333, + 0.00016826666666666665, + 0.0001668, + 0.0001653333333333333, + 0.00016386666666666665, + 0.0001624, + 0.0001609333333333333, + 0.00015946666666666665, + 0.000158, + 0.0001565333333333333, + 0.00015506666666666662, + 0.0001536, + 0.0001521333333333333, + 0.00015066666666666662, + 0.0001492, + 0.0001477333333333333, + 0.00014626666666666663, + 0.0001448, + 0.0001433333333333333, + 0.00014186666666666663, + 0.0001404, + 0.0001389333333333333, + 0.00013746666666666663, + 0.000136, + 0.00013453333333333331, + 0.00013306666666666663, + 0.0001316, + 0.00013013333333333332, + 0.00012866666666666663, + 0.0001272, + 0.00012573333333333332, + 0.00012426666666666663, + 0.0001228, + 0.00012133333333333332, + 0.00011986666666666663, + 0.0001184, + 0.00011693333333333332, + 0.00011546666666666664, + 0.00011399999999999995, + 0.00011253333333333332, + 0.00011106666666666664, + 0.00010959999999999995, + 0.00010813333333333332, + 0.00010666666666666664, + 0.00010519999999999996, + 0.00010373333333333332, + 0.00010226666666666664, + 0.00010079999999999996, + 9.933333333333333e-05, + 9.786666666666664e-05, + 9.639999999999996e-05, + 9.493333333333333e-05, + 9.346666666666664e-05, + 9.199999999999996e-05, + 9.053333333333333e-05, + 8.906666666666665e-05, + 8.759999999999996e-05, + 8.613333333333333e-05, + 8.466666666666665e-05, + 8.319999999999996e-05, + 8.173333333333333e-05, + 8.026666666666665e-05, + 7.879999999999996e-05, + 7.733333333333328e-05, + 7.586666666666665e-05, + 7.439999999999997e-05, + 7.293333333333328e-05, + 7.146666666666665e-05, + 6.999999999999997e-05, + 6.853333333333328e-05, + 6.706666666666665e-05, + 6.559999999999997e-05, + 6.413333333333328e-05, + 6.266666666666665e-05, + 6.119999999999997e-05, + 5.9733333333333285e-05, + 5.8266666666666655e-05, + 5.679999999999997e-05, + 5.533333333333329e-05, + 5.386666666666666e-05, + 5.239999999999997e-05, + 5.093333333333329e-05, + 4.946666666666666e-05, + 4.7999999999999974e-05, + 4.653333333333329e-05, + 4.506666666666666e-05, + 4.3599999999999976e-05, + 4.213333333333329e-05, + 4.066666666666661e-05, + 3.919999999999998e-05, + 3.773333333333329e-05, + 3.626666666666661e-05, + 3.479999999999998e-05, + 3.3333333333333294e-05, + 3.186666666666661e-05, + 3.039999999999998e-05, + 2.8933333333333296e-05, + 2.746666666666661e-05, + 2.599999999999998e-05, + 2.4533333333333297e-05, + 2.3066666666666613e-05, + 2.1599999999999983e-05, + 2.01333333333333e-05, + 1.8666666666666614e-05, + 1.7199999999999984e-05, + 1.57333333333333e-05, + 1.4266666666666616e-05, + 1.2799999999999986e-05, + 1.1333333333333302e-05, + 9.866666666666617e-06, + 8.399999999999987e-06, + 6.933333333333303e-06, + 5.466666666666619e-06, + 4e-06, + 3.932372881355932e-06, + 3.8647457627118645e-06, + 3.7971186440677964e-06, + 3.7294915254237288e-06, + 3.661864406779661e-06, + 3.594237288135593e-06, + 3.5266101694915254e-06, + 3.4589830508474577e-06, + 3.3913559322033896e-06, + 3.323728813559322e-06, + 3.2561016949152543e-06, + 3.1884745762711862e-06, + 3.1208474576271186e-06, + 3.053220338983051e-06, + 2.985593220338983e-06, + 2.917966101694915e-06, + 2.8503389830508475e-06, + 2.7827118644067794e-06, + 2.7150847457627118e-06, + 2.647457627118644e-06, + 2.579830508474576e-06, + 2.5122033898305084e-06, + 2.4445762711864407e-06, + 2.3769491525423726e-06, + 2.309322033898305e-06, + 2.2416949152542373e-06, + 2.1740677966101692e-06, + 2.1064406779661016e-06, + 2.038813559322034e-06, + 1.9711864406779662e-06, + 1.9035593220338982e-06, + 1.8359322033898305e-06, + 1.7683050847457628e-06, + 1.7006779661016948e-06, + 1.633050847457627e-06, + 1.5654237288135594e-06, + 1.4977966101694914e-06, + 1.4301694915254237e-06, + 1.362542372881356e-06, + 1.2949152542372884e-06, + 1.2272881355932203e-06, + 1.1596610169491526e-06, + 1.092033898305085e-06, + 1.024406779661017e-06, + 9.567796610169492e-07, + 8.891525423728816e-07, + 8.215254237288135e-07, + 7.538983050847458e-07, + 6.862711864406782e-07, + 6.186440677966101e-07, + 5.510169491525424e-07, + 4.833898305084748e-07, + 4.157627118644067e-07, + 3.4813559322033905e-07, + 2.805084745762714e-07, + 2.128813559322033e-07, + 1.4525423728813607e-07, + 7.762711864406799e-08, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/training/ReturnnTrainingJob.f9Ss2t9rtsQn/output/models/epoch" +num_epochs = 600 +num_inputs = 80 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.J6Jzn8HrWB9r/output/audio.hdf" + ], + "partition_epoch": 20, + "seq_ordering": "laplace:.1000", + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.MHO753hGyR2z/output/targets.hdf" + ], + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "features", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/librispeech-960/2024-11-19--construct-better-neural-blocks/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune_no_input_norm import ( + EbranchformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune_no_input_norm import ( + EbranchformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_models.assemblies.dynamic_adaptable_e_branchformer.e_branchformer_v1 import ( + EbranchformerEncoderV1Config, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune_no_input_norm import ( + train_step, +) +from i6_models.assemblies.dynamic_adaptable_e_branchformer.e_branchformer_v1 import ( + EbranchformerBlockV1Config, +) +from torch.nn import SiLU +from i6_models.parts.dynamic_adaptable_conformer import ConformerMHSAV1Config +from i6_models.parts.dynamic_adaptable_conformer import ( + ConformerPositionwiseFeedForwardV1Config, +) +from i6_models.parts.dynamic_adaptable_e_branchformer import ( + ConvolutionalGatingMLPV1Config, +) +from i6_models.parts.dynamic_adaptable_e_branchformer import MergerV1Config + +cfg = EbranchformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaug_args={ + "time_min_num_masks": 2, + "time_max_mask_per_n_frames": 25, + "time_mask_max_size": 20, + "freq_min_num_masks": 2, + "freq_mask_max_size": 5, + "freq_max_num_masks": 8, + }, + e_branchformer_cfg=EbranchformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=384, + ), + ), + block_cfgs=[ + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=384, hidden_dim=1536, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=384, + num_att_heads=6, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=384, + hidden_dim=2304, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=384, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + ], + ), + final_dropout=0.1, + target_size=79, + grad_score_opts={ + "grad_score_update_steps": 1000, + "grad_score_metric": "first_taylor", + "grad_update_beta": 0.1, + }, + adaptation_opts={ + "adaptation_global_step": [108000], + "total_cost": 56865808, + "rest_cost": 4947472, + "lst_replace_pct": [0.15], + "dict_module_cost": { + "ff1": 295584, + "cgmlp": 342432, + "mhsa": 98688, + "ff2": 295584, + }, + }, + component_dist={ + "ff1": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "cgmlp": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "mhsa": [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], + "ff2": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + }, + total_num_components=216, + lst_cmp_cost=[ + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 295584, + 342432, + 342432, + 342432, + 342432, + 98688, + 98688, + 98688, + 98688, + 98688, + 98688, + 295584, + 295584, + 295584, + 295584, + ], +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return EbranchformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs) diff --git a/2025-dynamic-model-architecture-optimization/README.md b/2025-dynamic-model-architecture-optimization/README.md index 88e4d252..46e8631e 100644 --- a/2025-dynamic-model-architecture-optimization/README.md +++ b/2025-dynamic-model-architecture-optimization/README.md @@ -1,3 +1,15 @@ This folder contains configs and code related to the publication: paper [Dynamic Acoustic Model Architecture Optimization in Training for ASR]() + +We use [RETURNN](https://github.com/rwth-i6/returnn) for training and our setups are based on [Sisyphus](https://github.com/rwth-i6/sisyphus). + +We use models parts from [i6-models](https://github.com/rwth-i6/i6_models/tree/jing-dynamic-encoder-size) + +### DMAO with Conformer CTC + +ConformerCTCModel, ConformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/tedlium2/pytorch_networks/neural_block/dynamic_adaptable_conformer/adapt_based_on_gradient_ranking_finer_granul_double_and_prune.py) + +### DMAO with Ebranchformer CTC + +EbranchformerCTCModel, EbranchformerCTCConfig and train_step in returnn config is defined in [here](https://github.com/rwth-i6/i6_experiments/blob/main/users/jxu/experiments/ctc/tedlium2/pytorch_networks/neural_block/dynamic_adaptable_e_branchformer/adapt_based_on_gradient_ranking_finer_granul_double_and_prune.py) \ No newline at end of file diff --git a/2025-dynamic-model-architecture-optimization/TED-LIUM-v2/ebranchformer_double_prune_model_size_512_adjust_dropout_True_adapt_epochs_50_pct_0.2_iters_1_0.0005.config b/2025-dynamic-model-architecture-optimization/TED-LIUM-v2/ebranchformer_double_prune_model_size_512_adjust_dropout_True_adapt_epochs_50_pct_0.2_iters_1_0.0005.config new file mode 100644 index 00000000..49e9b7fd --- /dev/null +++ b/2025-dynamic-model-architecture-optimization/TED-LIUM-v2/ebranchformer_double_prune_model_size_512_adjust_dropout_True_adapt_epochs_50_pct_0.2_iters_1_0.0005.config @@ -0,0 +1,976 @@ +#!rnn.py + + +import numpy as np + +backend = "torch" +batch_size = 2880000 +batching = "random" +cache_size = "0" +cleanup_old_models = True +dev = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.AltkEvXwM3dF/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.qGsTLu0lnU5n/output/targets.hdf" + ], + "partition_epoch": 1, + "seq_ordering": "sorted", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +device = "gpu" +extern_data = {"data": {"dim": 1}, "targets": {"dim": 79, "sparse": True}} +gradient_clip = 0.0 +gradient_noise = 0.0 +learning_rate_file = "learning_rates" +learning_rates = [ + 5e-06, + 9.5e-06, + 1.4000000000000001e-05, + 1.85e-05, + 2.3e-05, + 2.75e-05, + 3.2e-05, + 3.65e-05, + 4.1e-05, + 4.55e-05, + 5e-05, + 5.45e-05, + 5.9e-05, + 6.35e-05, + 6.8e-05, + 7.25e-05, + 7.7e-05, + 8.15e-05, + 8.6e-05, + 9.05e-05, + 9.5e-05, + 9.95e-05, + 0.00010400000000000001, + 0.00010850000000000001, + 0.000113, + 0.0001175, + 0.000122, + 0.0001265, + 0.000131, + 0.00013550000000000001, + 0.00014000000000000001, + 0.00014450000000000002, + 0.00014900000000000002, + 0.00015350000000000002, + 0.00015800000000000002, + 0.00016250000000000002, + 0.00016700000000000002, + 0.00017150000000000002, + 0.00017600000000000002, + 0.00018050000000000002, + 0.00018500000000000002, + 0.00018950000000000003, + 0.00019400000000000003, + 0.00019850000000000003, + 0.00020300000000000003, + 0.00020750000000000003, + 0.00021200000000000003, + 0.00021650000000000003, + 0.000221, + 0.0002255, + 0.00023, + 0.0002345, + 0.000239, + 0.0002435, + 0.000248, + 0.0002525, + 0.000257, + 0.0002615, + 0.000266, + 0.0002705, + 0.000275, + 0.0002795, + 0.000284, + 0.0002885, + 0.000293, + 0.0002975, + 0.000302, + 0.0003065, + 0.000311, + 0.0003155, + 0.00032, + 0.00032450000000000003, + 0.00032900000000000003, + 0.00033350000000000003, + 0.00033800000000000003, + 0.00034250000000000003, + 0.00034700000000000003, + 0.00035150000000000003, + 0.00035600000000000003, + 0.00036050000000000003, + 0.00036500000000000004, + 0.00036950000000000004, + 0.00037400000000000004, + 0.00037850000000000004, + 0.00038300000000000004, + 0.00038750000000000004, + 0.00039200000000000004, + 0.00039650000000000004, + 0.00040100000000000004, + 0.00040550000000000004, + 0.00041000000000000005, + 0.00041450000000000005, + 0.00041900000000000005, + 0.00042350000000000005, + 0.00042800000000000005, + 0.00043250000000000005, + 0.000437, + 0.0004415, + 0.000446, + 0.0004505, + 0.000455, + 0.0004595, + 0.000464, + 0.0004685, + 0.000473, + 0.0004775, + 0.000482, + 0.0004865, + 0.000491, + 0.0004955000000000001, + 0.0005, + 0.0004959090909090909, + 0.0004918181818181818, + 0.00048772727272727276, + 0.00048363636363636366, + 0.00047954545454545456, + 0.00047545454545454545, + 0.00047136363636363635, + 0.0004672727272727273, + 0.0004631818181818182, + 0.0004590909090909091, + 0.000455, + 0.0004509090909090909, + 0.00044681818181818185, + 0.00044272727272727275, + 0.00043863636363636365, + 0.00043454545454545455, + 0.0004304545454545455, + 0.0004263636363636364, + 0.0004222727272727273, + 0.0004181818181818182, + 0.0004140909090909091, + 0.00041, + 0.00040590909090909094, + 0.00040181818181818184, + 0.00039772727272727274, + 0.00039363636363636364, + 0.0003895454545454546, + 0.0003854545454545455, + 0.0003813636363636364, + 0.0003772727272727273, + 0.0003731818181818182, + 0.0003690909090909091, + 0.000365, + 0.00036090909090909093, + 0.00035681818181818183, + 0.0003527272727272728, + 0.0003486363636363637, + 0.0003445454545454546, + 0.0003404545454545455, + 0.0003363636363636364, + 0.0003322727272727273, + 0.0003281818181818182, + 0.0003240909090909091, + 0.00032, + 0.0003159090909090909, + 0.0003118181818181819, + 0.0003077272727272728, + 0.0003036363636363637, + 0.00029954545454545457, + 0.00029545454545454547, + 0.00029136363636363637, + 0.00028727272727272727, + 0.0002831818181818182, + 0.0002790909090909091, + 0.000275, + 0.00027090909090909097, + 0.00026681818181818187, + 0.00026272727272727277, + 0.00025863636363636366, + 0.00025454545454545456, + 0.00025045454545454546, + 0.0002463636363636364, + 0.0002422727272727273, + 0.0002381818181818182, + 0.0002340909090909091, + 0.00023, + 0.00022590909090909096, + 0.00022181818181818186, + 0.00021772727272727276, + 0.00021363636363636365, + 0.00020954545454545455, + 0.0002054545454545455, + 0.0002013636363636364, + 0.0001972727272727273, + 0.0001931818181818182, + 0.0001890909090909091, + 0.00018500000000000005, + 0.00018090909090909095, + 0.00017681818181818185, + 0.00017272727272727275, + 0.00016863636363636364, + 0.0001645454545454546, + 0.0001604545454545455, + 0.0001563636363636364, + 0.0001522727272727273, + 0.0001481818181818182, + 0.00014409090909090914, + 0.00014000000000000004, + 0.00013590909090909094, + 0.00013181818181818184, + 0.00012772727272727274, + 0.0001236363636363637, + 0.00011954545454545459, + 0.00011545454545454549, + 0.00011136363636363638, + 0.00010727272727272734, + 0.00010318181818181824, + 9.909090909090913e-05, + 9.500000000000003e-05, + 9.090909090909093e-05, + 8.681818181818188e-05, + 8.272727272727278e-05, + 7.863636363636368e-05, + 7.454545454545458e-05, + 7.045454545454548e-05, + 6.636363636363643e-05, + 6.227272727272733e-05, + 5.8181818181818226e-05, + 5.4090909090909124e-05, + 5e-05, + 4.827620689655173e-05, + 4.655241379310345e-05, + 4.4828620689655175e-05, + 4.31048275862069e-05, + 4.1381034482758624e-05, + 3.965724137931035e-05, + 3.793344827586207e-05, + 3.62096551724138e-05, + 3.448586206896552e-05, + 3.276206896551724e-05, + 3.103827586206897e-05, + 2.931448275862069e-05, + 2.7590689655172415e-05, + 2.586689655172414e-05, + 2.4143103448275864e-05, + 2.241931034482759e-05, + 2.0695517241379313e-05, + 1.8971724137931034e-05, + 1.7247931034482758e-05, + 1.5524137931034483e-05, + 1.3800344827586207e-05, + 1.2076551724137931e-05, + 1.0352758620689656e-05, + 8.62896551724138e-06, + 6.905172413793104e-06, + 5.1813793103448286e-06, + 3.457586206896553e-06, + 1.7337931034482773e-06, + 1e-08, +] +log = ["./returnn.log"] +log_batch_size = True +log_verbosity = 5 +max_seqs = 128 +model = "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/work/i6_core/returnn/training/ReturnnTrainingJob.2mX3iCXw4Jip/output/models/epoch" +num_epochs = 250 +num_inputs = 50 +num_outputs = {"targets": 79} +optimizer = {"class": "adamw", "epsilon": 1e-16, "weight_decay": 0.001} +save_interval = 1 +target = "targets" +task = "train" +tf_log_memory_usage = True +train = { + "class": "MetaDataset", + "datasets": { + "features": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/work/i6_core/returnn/hdf/BlissToPcmHDFJob.T3qQ5mfrQwlw/output/audio.hdf" + ], + }, + "targets": { + "class": "HDFDataset", + "use_cache_manager": True, + "files": [ + "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/work/i6_experiments/users/berger/recipe/returnn/hdf/BlissCorpusToTargetHdfJob.bAoRzty8czAI/output/targets.hdf" + ], + "partition_epoch": 5, + "seq_ordering": "laplace:.1000", + }, + }, + "data_map": {"data": ("features", "data"), "targets": ("targets", "data")}, + "seq_order_control_dataset": "targets", +} +update_on_device = True +window = 1 +config = {} + +locals().update(**config) + +import os +import sys + +sys.path.insert( + 0, "/u/jxu/setups/tedlium2/2024-09-10--construct-better-neural-blocks/recipe" +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + EbranchformerCTCModel, +) +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + EbranchformerCTCConfig, +) +from i6_models.primitives.feature_extraction import LogMelFeatureExtractionV1Config +from i6_models.assemblies.dynamic_adaptable_e_branchformer.e_branchformer_v1 import ( + EbranchformerEncoderV1Config, +) +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1Config +from torch.nn.modules.activation import ReLU +from i6_models.parts.frontend.vgg_act import VGG4LayerActFrontendV1 +from i6_models.config import ModuleFactoryV1 +from i6_experiments.users.jxu.experiments.ctc.tedlium2.pytorch_networks.neural_block.dynamic_adaptable_e_branchformer.adapt_based_on_gradient_ranking_finer_granul_double_and_prune import ( + train_step, +) +from i6_models.assemblies.dynamic_adaptable_e_branchformer.e_branchformer_v1 import ( + EbranchformerBlockV1Config, +) +from torch.nn import SiLU +from i6_models.parts.dynamic_adaptable_conformer import ConformerMHSAV1Config +from i6_models.parts.dynamic_adaptable_conformer import ( + ConformerPositionwiseFeedForwardV1Config, +) +from i6_models.parts.dynamic_adaptable_e_branchformer import ( + ConvolutionalGatingMLPV1Config, +) +from i6_models.parts.dynamic_adaptable_e_branchformer import MergerV1Config + +cfg = EbranchformerCTCConfig( + feature_extraction_cfg=LogMelFeatureExtractionV1Config( + sample_rate=16000, + win_size=0.025, + hop_size=0.01, + f_min=60, + f_max=7600, + min_amp=1e-10, + num_filters=80, + center=False, + n_fft=400, + ), + specaug_args={ + "time_min_num_masks": 2, + "time_max_mask_per_n_frames": 25, + "time_mask_max_size": 20, + "freq_min_num_masks": 2, + "freq_mask_max_size": 5, + "freq_max_num_masks": 16, + }, + e_branchformer_cfg=EbranchformerEncoderV1Config( + num_layers=12, + frontend=ModuleFactoryV1( + module_class=VGG4LayerActFrontendV1, + cfg=VGG4LayerActFrontendV1Config( + in_features=80, + conv1_channels=32, + conv2_channels=64, + conv3_channels=64, + conv4_channels=32, + conv_kernel_size=(3, 3), + conv_padding=None, + pool1_kernel_size=(2, 1), + pool1_stride=(2, 1), + pool1_padding=None, + pool2_kernel_size=(2, 1), + pool2_stride=(2, 1), + pool2_padding=None, + activation=ReLU(), + out_features=512, + ), + ), + block_cfgs=[ + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + EbranchformerBlockV1Config( + ff1_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + ff2_cfg=ConformerPositionwiseFeedForwardV1Config( + input_dim=512, hidden_dim=2048, dropout=0.1, activation=SiLU() + ), + mhsa_cfg=ConformerMHSAV1Config( + input_dim=512, + num_att_heads=8, + att_head_dim=64, + att_weights_dropout=0.1, + dropout=0.1, + ), + cgmlp_cfg=ConvolutionalGatingMLPV1Config( + input_dim=512, + hidden_dim=3072, + kernel_size=31, + dropout=0.1, + activation=SiLU(), + ), + merger_cfg=MergerV1Config(input_dim=512, kernel_size=31, dropout=0.1), + adjust_dropout=True, + ), + ], + ), + final_dropout=0.2, + target_size=79, + grad_score_opts={ + "grad_score_update_steps": 1000, + "grad_score_metric": "first_taylor", + "grad_update_beta": 0.1, + }, + adaptation_opts={ + "adaptation_global_step": [60000], + "total_cost": 100176527, + "rest_cost": 8145551, + "lst_replace_pct": [0.2], + "dict_module_cost": { + "ff1": 525184, + "cgmlp": 604032, + "mhsa": 131456, + "ff2": 525184, + }, + }, + component_dist={ + "ff1": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "cgmlp": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + "mhsa": [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + "ff2": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + }, + total_num_components=240, + lst_cmp_cost=[ + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 525184, + 604032, + 604032, + 604032, + 604032, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 131456, + 525184, + 525184, + 525184, + 525184, + ], +) + +model_kwargs = {"cfg": cfg} + + +def get_model(epoch, step, **kwargs): + return EbranchformerCTCModel(epoch=epoch, step=step, **model_kwargs, **kwargs)