diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 4192892f6eb..44549f9add3 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -175,9 +175,9 @@ class OnnxConfig(ExportConfig, ABC): ), "reinforcement-learning": OrderedDict( { - "return_preds": {0: "batch_size", 1: "sequence_length"}, - "action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"}, "state_preds": {0: "batch_size", 1: "sequence_length", 2: "state_dim"}, + "action_preds": {0: "batch_size", 1: "sequence_length", 2: "act_dim"}, + "return_preds": {0: "batch_size", 1: "sequence_length"}, "last_hidden_state": {0: "batch_size", 1: "sequence_length", 2: "last_hidden_state"}, } ),