diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 10d38fae76d..df7754c3769 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -343,6 +343,7 @@ def patch_model_for_export( class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. DEFAULT_ONNX_OPSET = 13 + # TODO: fix inference for transformers < v4.41 for beam_search > 1 MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"