From 6c4e2fc61f6fb74a8c1cf60efada751c797d267d Mon Sep 17 00:00:00 2001 From: Ella Charlaix Date: Thu, 2 Nov 2023 16:57:17 +0100 Subject: [PATCH] add depreciation warning --- optimum/onnxruntime/base.py | 12 ++++++++---- tests/onnxruntime/test_modeling.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index c87d02ea557..1a80152d471 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -14,11 +14,11 @@ """Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models.""" from abc import abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union import numpy as np import torch -from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput +from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from onnxruntime import InferenceSession @@ -495,5 +495,9 @@ def prepare_inputs_for_merged( class ORTDecoder(ORTDecoderForSeq2Seq): - # TODO : add warning message - pass + + def __init__(self, *args, **kwargs): + logger.warning( + f"The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead." + ) + super().__init__(*args, **kwargs) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 09a322b39e1..baa62f1670c 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -91,7 +91,7 @@ ORTModelForVision2Seq, ORTStableDiffusionPipeline, ) -from optimum.onnxruntime.base import ORTDecoder, ORTDecoderForSeq2Seq, ORTEncoder +from optimum.onnxruntime.base import ORTDecoderForSeq2Seq, ORTEncoder from optimum.onnxruntime.modeling_diffusion import ( ORTModelTextEncoder, ORTModelUnet,