diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index b3bad65954d..bfdfbff1b11 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1091,10 +1091,10 @@ def forward( onnx_outputs = self.model.run(None, onnx_inputs) model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # TODO: why do we only return last_hidden_state? why not all outputs? - # that way, there will be less need for ORTModelForCustomTask in cases where - # we just want to extend model outputs with attentions, hidden_states, etc. - last_hidden_state = model_outputs["last_hidden_state"] + if "last_hidden_state" in self.output_names: + last_hidden_state = model_outputs[self.output_names["last_hidden_state"]] + else: + last_hidden_state = model_outputs[0] # converts output to namedtuple for pipelines post-processing return BaseModelOutput(last_hidden_state=last_hidden_state)