diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index e23716d4b74..421b7c9010a 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -600,7 +600,7 @@ def inputs_for_default_and_seq2seq_lm(self):
     def inputs_for_causal_lm(self):
         if self.use_past_in_inputs:
             common_inputs = {
-                "input_ids": {0: "batch_size"},
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
                 "attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"},
             }
             for i in range(self._normalized_config.decoder_num_layers):
@@ -645,7 +645,11 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = super(OnnxConfigWithPast, self).outputs
         if self.use_past:
             # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
-            for i in range(self._normalized_config.encoder_num_layers):
+            for i in range(
+                self._normalized_config.encoder_num_layers
+                if self.task != "text-generation"
+                else self._normalized_config.decoder_num_layers
+            ):
                 common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
                 common_outputs[f"present.{i}.value"] = {
                     0: "batch_size",
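A note on the second hunk: for decoder-only models exported with the `text-generation` task, iterating over `encoder_num_layers` can declare the wrong number of `present.*` KV-cache outputs whenever the normalized config's encoder and decoder layer counts differ; the patch switches to `decoder_num_layers` in that case. (The first hunk similarly marks axis 1 of `input_ids` as dynamic, presumably so the with-past decoder is not pinned to a single-token input.) The sketch below is a minimal, self-contained stand-in for the patched loop, not the actual `optimum` classes; `NormalizedConfigStub` and its layer counts are hypothetical:

```python
# Minimal sketch of the patched layer-count selection, assuming a
# hypothetical NormalizedConfigStub (not optimum's NormalizedConfig).
from typing import Dict


class NormalizedConfigStub:
    # Hypothetical values: a decoder-only model whose normalized config
    # reports a different (fallback) value for encoder_num_layers.
    encoder_num_layers = 12
    decoder_num_layers = 24


def present_outputs(config: NormalizedConfigStub, task: str) -> Dict[str, Dict[int, str]]:
    outputs: Dict[str, Dict[int, str]] = {}
    # Mirrors the patched loop bound: decoder-only ("text-generation")
    # exports iterate over decoder_num_layers instead of encoder_num_layers.
    num_layers = (
        config.encoder_num_layers if task != "text-generation" else config.decoder_num_layers
    )
    for i in range(num_layers):
        outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
        outputs[f"present.{i}.value"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
    return outputs


# With the pre-patch behaviour this 24-layer decoder-only model would only
# declare 12 present.{key,value} pairs; after the fix all 24 layers are covered.
assert len(present_outputs(NormalizedConfigStub(), "text-generation")) == 2 * 24
```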