diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 5b1268f3c64..2da3f5bea6b 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1291,7 +1291,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]: class WhisperOnnxConfig(AudioToTextOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig + NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( + encoder_num_layers="encoder_layers", + decoder_num_layers="decoder_layers", + ) ATOL_FOR_VALIDATION = 1e-3 @property