From 605ed7e36496822b82922d0d840a9c914d367e33 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:05:23 +0100 Subject: [PATCH] Replace check_if_xxx_greater with is_xxx_version (#2152) * add version and avaibility check utils * replace check_if_transformers_greater with is_transformers_version * fix style * fix style * fix --- optimum/exporters/executorch/__main__.py | 4 +- optimum/exporters/executorch/convert.py | 4 +- optimum/exporters/onnx/base.py | 6 +- optimum/exporters/onnx/convert.py | 6 +- optimum/exporters/onnx/model_configs.py | 24 ++--- optimum/exporters/onnx/utils.py | 15 ++- optimum/onnxruntime/modeling_decoder.py | 8 +- optimum/onnxruntime/modeling_diffusion.py | 8 +- optimum/onnxruntime/modeling_seq2seq.py | 8 +- optimum/onnxruntime/trainer.py | 6 +- optimum/onnxruntime/trainer_seq2seq.py | 4 +- optimum/onnxruntime/training_args.py | 6 +- optimum/pipelines/pipelines_base.py | 4 +- optimum/utils/__init__.py | 4 + optimum/utils/import_utils.py | 119 ++++++++++++++-------- optimum/utils/input_generators.py | 8 +- tests/onnxruntime/test_diffusion.py | 13 +-- tests/onnxruntime/test_modeling.py | 12 +-- 18 files changed, 148 insertions(+), 111 deletions(-) diff --git a/optimum/exporters/executorch/__main__.py b/optimum/exporters/executorch/__main__.py index 33a668b0674..e3b561f0f06 100644 --- a/optimum/exporters/executorch/__main__.py +++ b/optimum/exporters/executorch/__main__.py @@ -20,7 +20,7 @@ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers.utils import is_torch_available -from optimum.utils.import_utils import check_if_transformers_greater +from optimum.utils.import_utils import is_transformers_version from ...commands.export.executorch import parse_args_executorch from .convert import export_to_executorch @@ -95,7 +95,7 @@ def main_export( ``` """ - if not check_if_transformers_greater("4.46"): + if is_transformers_version("<", "4.46"): raise ValueError( "The minimum Transformers version compatible with ExecuTorch is 4.46.0. Please upgrade to Transformers 4.46.0 or later." ) diff --git a/optimum/exporters/executorch/convert.py b/optimum/exporters/executorch/convert.py index f50a4b54a96..aceb733d529 100644 --- a/optimum/exporters/executorch/convert.py +++ b/optimum/exporters/executorch/convert.py @@ -19,7 +19,7 @@ from transformers.utils import is_torch_available -from optimum.utils.import_utils import check_if_transformers_greater +from optimum.utils.import_utils import is_transformers_version from .recipe_registry import discover_recipes, recipe_registry @@ -27,7 +27,7 @@ if is_torch_available(): from transformers.modeling_utils import PreTrainedModel -if check_if_transformers_greater("4.46"): +if is_transformers_version(">=", "4.46"): from transformers.integrations.executorch import ( TorchExportableModuleWithStaticCache, ) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index b5adb4522a2..85cadcdbc4c 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -44,7 +44,7 @@ from ...utils import TORCH_MINIMUM_VERSION as GLOBAL_MIN_TORCH_VERSION from ...utils import TRANSFORMERS_MINIMUM_VERSION as GLOBAL_MIN_TRANSFORMERS_VERSION from ...utils.doc import add_dynamic_docstring -from ...utils.import_utils import check_if_transformers_greater, is_onnx_available, is_onnxruntime_available +from ...utils.import_utils import is_onnx_available, is_onnxruntime_available, is_transformers_version from ..base import ExportConfig from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ModelPatcher, Seq2SeqModelPatcher @@ -156,7 +156,7 @@ class OnnxConfig(ExportConfig, ABC): ), "mask-generation": OrderedDict({"logits": {0: "batch_size"}}), "masked-im": OrderedDict( - {"reconstruction" if check_if_transformers_greater("4.29.0") else "logits": {0: "batch_size"}} + {"reconstruction" if is_transformers_version(">=", "4.29.0") else "logits": {0: "batch_size"}} ), "multiple-choice": OrderedDict({"logits": {0: "batch_size", 1: "num_choices"}}), "object-detection": OrderedDict( @@ -375,7 +375,7 @@ def is_transformers_support_available(self) -> bool: `bool`: Whether the install version of Transformers is compatible with the model. """ - return check_if_transformers_greater(self.MIN_TRANSFORMERS_VERSION) + return is_transformers_version(">=", self.MIN_TRANSFORMERS_VERSION.base_version) @property def is_torch_support_available(self) -> bool: diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 80d945580c7..ebcc8e07c12 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -35,9 +35,9 @@ DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, TORCH_MINIMUM_VERSION, - check_if_transformers_greater, is_diffusers_available, is_torch_onnx_support_available, + is_transformers_version, logging, require_numpy_strictly_lower, ) @@ -512,7 +512,7 @@ def export_pytorch( model_kwargs = model_kwargs or {} # num_logits_to_keep was added in transformers 4.45 and isn't added as inputs when exporting the model - if check_if_transformers_greater("4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys(): + if is_transformers_version(">=", "4.44.99") and "num_logits_to_keep" in signature(model.forward).parameters.keys(): model_kwargs["num_logits_to_keep"] = 0 with torch.no_grad(): @@ -1105,7 +1105,7 @@ def onnx_export_from_model( if isinstance(atol, dict): atol = atol[task.replace("-with-past", "")] - if check_if_transformers_greater("4.44.99"): + if is_transformers_version(">=", "4.44.99"): misplaced_generation_parameters = model.config._get_non_default_generation_parameters() if ( isinstance(model, GenerationMixin) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 8966a1b1a33..63bf220ca92 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -59,9 +59,9 @@ NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedVisionConfig, - check_if_diffusers_greater, - check_if_transformers_greater, is_diffusers_available, + is_diffusers_version, + is_transformers_version, logging, ) from ...utils.normalized_config import NormalizedConfigManager @@ -310,7 +310,7 @@ class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # OPT does not take position_ids as input for transfomers < v4.46, needs it for transformers >= v4.46 -if check_if_transformers_greater("4.45.99"): +if is_transformers_version(">=", "4.45.99"): class OPTOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. @@ -370,8 +370,7 @@ class Phi3OnnxConfig(PhiOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.41.0") def __init__(self, *args, **kwargs): - # TODO : replace check_if_transformers_greater with is_transformers_available - if check_if_transformers_greater("4.46.0") and not check_if_transformers_greater("4.46.1"): + if is_transformers_version("==", "4.46.0"): logger.error( "Found transformers v4.46.0 while trying to exporting a Phi3 model, this specific version of transformers is not supported. " "Please upgrade to v4.46.1 or higher, or downgrade your transformers version" @@ -417,7 +416,7 @@ class BloomOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): - if check_if_transformers_greater("4.44"): + if is_transformers_version(">=", "4.44"): super().add_past_key_values(inputs_or_outputs, direction) else: if direction not in ["inputs", "outputs"]: @@ -1437,11 +1436,11 @@ def inputs(self): common_inputs = super().inputs common_inputs["hidden_states"] = {0: "batch_size", 1: "packed_height_width"} common_inputs["txt_ids"] = ( - {0: "sequence_length"} if check_if_diffusers_greater("0.31.0") else {0: "batch_size", 1: "sequence_length"} + {0: "sequence_length"} if is_diffusers_version(">=", "0.31.0") else {0: "batch_size", 1: "sequence_length"} ) common_inputs["img_ids"] = ( {0: "packed_height_width"} - if check_if_diffusers_greater("0.31.0") + if is_diffusers_version(">=", "0.31.0") else {0: "batch_size", 1: "packed_height_width"} ) @@ -1774,7 +1773,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis. if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs: - if check_if_transformers_greater("4.43.0"): + if is_transformers_version(">=", "4.43.0"): # since https://github.com/huggingface/transformers/pull/31166 common_inputs["cache_position"] = {0: "decoder_sequence_length"} @@ -2461,12 +2460,7 @@ class Pix2StructOnnxConfig(OnnxSeq2SeqConfigWithPast): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO : replace check_if_transformers_greater with is_transformers_available - if ( - check_if_transformers_greater("4.46.0") - and not check_if_transformers_greater("4.46.1") - and self._behavior is ConfigBehavior.DECODER - ): + if is_transformers_version("==", "4.46.0") and self._behavior is ConfigBehavior.DECODER: logger.error( "Found transformers v4.46.0 while trying to exporting a Pix2Struct model, this specific version of transformers is not supported. " "Please upgrade to v4.46.1 or higher, or downgrade your transformers version" diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 19e24f88743..3659480abfe 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -20,14 +20,13 @@ from packaging import version from transformers.utils import is_tf_available, is_torch_available -from ...utils import ( - DIFFUSERS_MINIMUM_VERSION, - ORT_QUANTIZE_MINIMUM_VERSION, - check_if_diffusers_greater, +from ...utils import DIFFUSERS_MINIMUM_VERSION, ORT_QUANTIZE_MINIMUM_VERSION, logging +from ...utils.import_utils import ( + _diffusers_version, is_diffusers_available, - logging, + is_diffusers_version, + is_transformers_version, ) -from ...utils.import_utils import _diffusers_version, check_if_transformers_greater from ..utils import ( _get_submodels_and_export_configs, ) @@ -52,7 +51,7 @@ if is_diffusers_available(): - if not check_if_diffusers_greater(DIFFUSERS_MINIMUM_VERSION.base_version): + if not is_diffusers_version(">=", DIFFUSERS_MINIMUM_VERSION.base_version): raise ImportError( f"We found an older version of diffusers {_diffusers_version} but we require diffusers to be >= {DIFFUSERS_MINIMUM_VERSION}. " "Please update diffusers by running `pip install --upgrade diffusers`" @@ -90,7 +89,7 @@ } -if check_if_transformers_greater("4.45.99"): +if is_transformers_version(">=", "4.45.99"): MODEL_TYPES_REQUIRING_POSITION_IDS.add("opt") diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 8f1d062221a..7c7a8fb8399 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -31,7 +31,7 @@ from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export from ..onnx.utils import check_model_uses_external_data -from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils import NormalizedConfigManager, is_transformers_version from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST from ..utils.save_utils import maybe_save_preprocessors from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN @@ -43,7 +43,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig -if check_if_transformers_greater("4.25.0"): +if is_transformers_version(">=", "4.25.0"): from transformers.generation import GenerationMixin else: from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401 @@ -149,7 +149,7 @@ def __init__( self.generation_config = generation_config - if check_if_transformers_greater("4.44.99"): + if is_transformers_version(">=", "4.44.99"): misplaced_generation_parameters = self.config._get_non_default_generation_parameters() if len(misplaced_generation_parameters) > 0: logger.warning( @@ -562,7 +562,7 @@ def _from_pretrained( ) # Since transformers 4.44, the bloom model has been updated to use the standard cache format - use_old_bloom_modeling = not check_if_transformers_greater("4.44") + use_old_bloom_modeling = not is_transformers_version(">=", "4.44") for input_name in input_dims.keys(): if input_dims[input_name][0] == "batch_size x num_heads": use_old_bloom_modeling = True diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 66b08e1ef66..193d75e0d44 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -51,7 +51,7 @@ from transformers.modeling_outputs import ModelOutput import onnxruntime as ort -from optimum.utils import check_if_diffusers_greater +from optimum.utils import is_diffusers_version from ..exporters.onnx import main_export from ..onnx.utils import _get_model_external_data_paths @@ -75,7 +75,7 @@ ) -if check_if_diffusers_greater("0.25.0"): +if is_diffusers_version(">=", "0.25.0"): from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution else: from diffusers.models.vae import DiagonalGaussianDistribution # type: ignore @@ -974,7 +974,7 @@ def __init__(self, *args, **kwargs): ) -if check_if_diffusers_greater("0.29.0"): +if is_diffusers_version(">=", "0.29.0"): from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) @@ -1006,7 +1006,7 @@ class ORTStableDiffusion3Img2ImgPipeline(ORTUnavailablePipeline): MIN_VERSION = "0.29.0" -if check_if_diffusers_greater("0.30.0"): +if is_diffusers_version(">=", "0.30.0"): from diffusers import FluxPipeline, StableDiffusion3InpaintPipeline @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 27e0dc01b4c..fba81525824 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -43,7 +43,7 @@ from ..exporters.onnx import main_export from ..onnx.utils import _get_external_data_paths -from ..utils import check_if_transformers_greater +from ..utils import is_transformers_version from ..utils.file_utils import validate_file_exists from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoderForSeq2Seq, ORTEncoder @@ -64,13 +64,13 @@ ) -if check_if_transformers_greater("4.25.0"): +if is_transformers_version(">=", "4.25.0"): from transformers.generation import GenerationMixin else: from transformers.generation_utils import GenerationMixin # type: ignore -if check_if_transformers_greater("4.43.0"): +if is_transformers_version(">=", "4.43.0"): from transformers.cache_utils import EncoderDecoderCache else: EncoderDecoderCache = dict @@ -705,7 +705,7 @@ def show_deprecated_argument(arg_name): generation_config = GenerationConfig.from_model_config(config) self.generation_config = generation_config - if check_if_transformers_greater("4.44.99"): + if is_transformers_version(">=", "4.44.99"): misplaced_generation_parameters = self.config._get_non_default_generation_parameters() if len(misplaced_generation_parameters) > 0: logger.warning( diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py index 66273cbcf96..4c6ad2553dd 100644 --- a/optimum/onnxruntime/trainer.py +++ b/optimum/onnxruntime/trainer.py @@ -83,7 +83,7 @@ ) from ..utils import logging -from ..utils.import_utils import check_if_transformers_greater +from ..utils.import_utils import is_transformers_version from .training_args import ORTOptimizerNames, ORTTrainingArguments from .utils import ( is_onnxruntime_training_available, @@ -93,7 +93,7 @@ if is_apex_available(): from apex import amp -if check_if_transformers_greater("4.33"): +if is_transformers_version(">=", "4.33"): from transformers.integrations.deepspeed import ( deepspeed_init, deepspeed_load_checkpoint, @@ -102,7 +102,7 @@ else: from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled -if check_if_transformers_greater("4.39"): +if is_transformers_version(">=", "4.39"): from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available if is_torch_tpu_xla_available(): diff --git a/optimum/onnxruntime/trainer_seq2seq.py b/optimum/onnxruntime/trainer_seq2seq.py index 1565ffa6acb..a76374a5ec1 100644 --- a/optimum/onnxruntime/trainer_seq2seq.py +++ b/optimum/onnxruntime/trainer_seq2seq.py @@ -22,7 +22,7 @@ from transformers.trainer_utils import PredictionOutput from transformers.utils import is_accelerate_available, logging -from ..utils.import_utils import check_if_transformers_greater +from ..utils.import_utils import is_transformers_version from .trainer import ORTTrainer @@ -33,7 +33,7 @@ "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install." ) -if check_if_transformers_greater("4.33"): +if is_transformers_version(">=", "4.33"): from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled else: from transformers.deepspeed import is_deepspeed_zero3_enabled diff --git a/optimum/onnxruntime/training_args.py b/optimum/onnxruntime/training_args.py index 6135abc1376..6eb2bb49044 100644 --- a/optimum/onnxruntime/training_args.py +++ b/optimum/onnxruntime/training_args.py @@ -44,13 +44,13 @@ ) from transformers.utils.generic import strtobool -from ..utils.import_utils import check_if_transformers_greater +from ..utils.import_utils import is_transformers_version if is_torch_available(): import torch -if is_accelerate_available() and check_if_transformers_greater("4.38.0"): +if is_accelerate_available() and is_transformers_version(">=", "4.38.0"): from transformers.trainer_pt_utils import AcceleratorConfig @@ -481,7 +481,7 @@ def __post_init__(self): os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true") os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false") - if is_accelerate_available() and check_if_transformers_greater("4.38.0"): + if is_accelerate_available() and is_transformers_version(">=", "4.38.0"): if not isinstance(self.accelerator_config, (AcceleratorConfig)): if self.accelerator_config is None: self.accelerator_config = AcceleratorConfig() diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py index 7690143f13f..0016c73ff05 100644 --- a/optimum/pipelines/pipelines_base.py +++ b/optimum/pipelines/pipelines_base.py @@ -46,7 +46,7 @@ from transformers.pipelines import infer_framework_load_model from ..bettertransformer import BetterTransformer -from ..utils import check_if_transformers_greater, is_onnxruntime_available +from ..utils import is_onnxruntime_available, is_transformers_version from ..utils.file_utils import find_files_matching_pattern @@ -189,7 +189,7 @@ def load_bettertransformer( if model_kwargs is None: # the argument was first introduced in 4.36.0 but most models didn't have an sdpa implementation then # see https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/modeling_utils.py#L1258 - if check_if_transformers_greater("4.36.0"): + if is_transformers_version(">=", "4.36.0"): model_kwargs = {"attn_implementation": "eager"} else: model_kwargs = {} diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index e2b53a7dbc7..a975b1e118a 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -37,6 +37,7 @@ is_auto_gptq_available, is_datasets_available, is_diffusers_available, + is_diffusers_version, is_gptqmodel_available, is_onnx_available, is_onnxruntime_available, @@ -44,6 +45,9 @@ is_sentence_transformers_available, is_timm_available, is_torch_onnx_support_available, + is_torch_version, + is_transformers_available, + is_transformers_version, require_numpy_strictly_lower, torch_version, ) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 7174084a453..457d7ab5f3a 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -13,9 +13,10 @@ # limitations under the License. """Import utilities.""" +import importlib.metadata as importlib_metadata import importlib.util import inspect -import sys +import operator as op from collections import OrderedDict from contextlib import contextmanager from typing import Tuple, Union @@ -25,6 +26,18 @@ from transformers.utils import is_torch_available +TORCH_MINIMUM_VERSION = version.parse("1.11.0") +TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") +DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") +AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 +GPTQMODEL_MINIMUM_VERSION = version.parse("1.6.0") + +# This is the minimal required version to support some ONNX Runtime features +ORT_QUANTIZE_MINIMUM_VERSION = version.parse("1.4.0") + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + + def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version package_exists = importlib.util.find_spec(pkg_name) is not None @@ -41,57 +54,78 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ return package_exists -# The package importlib_metadata is in a different place, depending on the python version. -if sys.version_info < (3, 8): - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - - -TORCH_MINIMUM_VERSION = version.parse("1.11.0") -TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") -DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") -AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 -GPTQMODEL_MINIMUM_VERSION = version.parse("1.6.0") - - -# This is the minimal required version to support some ONNX Runtime features -ORT_QUANTIZE_MINIMUM_VERSION = version.parse("1.4.0") - - _onnx_available = _is_package_available("onnx") - -# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.) -_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None - _pydantic_available = _is_package_available("pydantic") _accelerate_available = _is_package_available("accelerate") -_diffusers_available = _is_package_available("diffusers") _auto_gptq_available = _is_package_available("auto_gptq") _gptqmodel_available = _is_package_available("gptqmodel") _timm_available = _is_package_available("timm") _sentence_transformers_available = _is_package_available("sentence_transformers") _datasets_available = _is_package_available("datasets") +_diffusers_available, _diffusers_version = _is_package_available("diffusers", return_version=True) +_transformers_available, _transformers_version = _is_package_available("transformers", return_version=True) +# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.) +_onnxruntime_available = _is_package_available("onnxruntime", return_version=False) + +# TODO : Remove torch_version = None if is_torch_available(): torch_version = version.parse(importlib_metadata.version("torch")) -_is_torch_onnx_support_available = is_torch_available() and ( - TORCH_MINIMUM_VERSION.major, - TORCH_MINIMUM_VERSION.minor, -) <= ( - torch_version.major, - torch_version.minor, -) +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, version.Version], operation: str, requirement_version: str): + """ + Compare a library version to some requirement using a given operation. + + Arguments: + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = version.parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, version.parse(requirement_version)) -_diffusers_version = None -if _diffusers_available: - try: - _diffusers_version = importlib_metadata.version("diffusers") - except importlib_metadata.PackageNotFoundError: - _diffusers_available = False + +def is_transformers_version(operation: str, reference_version: str): + """ + Compare the current Transformers version to a given reference with an operation. + """ + if not _transformers_available: + return False + return compare_versions(version.parse(_transformers_version), operation, reference_version) + + +def is_diffusers_version(operation: str, reference_version: str): + """ + Compare the current diffusers version to a given reference with an operation. + """ + if not _diffusers_available: + return False + return compare_versions(version.parse(_diffusers_version), operation, reference_version) + + +def is_torch_version(operation: str, reference_version: str): + """ + Compare the current torch version to a given reference with an operation. + """ + if not is_torch_available(): + return False + + import torch + + return compare_versions(version.parse(version.parse(torch.__version__).base_version), operation, reference_version) + + +_is_torch_onnx_support_available = is_torch_available() and is_torch_version(">=", TORCH_MINIMUM_VERSION.base_version) def is_torch_onnx_support_available(): @@ -138,6 +172,10 @@ def is_datasets_available(): return _datasets_available +def is_transformers_available(): + return _transformers_available + + def is_auto_gptq_available(): if _auto_gptq_available: v = version.parse(importlib_metadata.version("auto_gptq")) @@ -177,6 +215,7 @@ def check_if_pytorch_greater(target_version: str, message: str): pass +# TODO : Remove check_if_transformers_greater, check_if_diffusers_greater, check_if_torch_greater def check_if_transformers_greater(target_version: Union[str, version.Version]) -> bool: """ Checks whether the current install of transformers is greater than or equal to the target version. @@ -259,15 +298,15 @@ def require_numpy_strictly_lower(package_version: str, message: str): ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)), ( "transformers_431", - (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")), + (lambda: is_transformers_version(">=", "4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")), ), ( "transformers_432", - (lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")), + (lambda: is_transformers_version(">=", "4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")), ), ( "transformers_434", - (lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")), + (lambda: is_transformers_version(">=", "4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")), ), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ] diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 18a2a5a3fd1..553795e74e5 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -22,7 +22,7 @@ import numpy as np from transformers.utils import is_tf_available, is_torch_available -from ..utils import check_if_diffusers_greater, check_if_transformers_greater +from ..utils import is_diffusers_version, is_transformers_version from .normalized_config import ( NormalizedConfig, NormalizedEncoderDecoderConfig, @@ -1072,7 +1072,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - if check_if_transformers_greater("4.44"): + if is_transformers_version(">=", "4.44"): return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) else: past_key_shape = ( @@ -1504,7 +1504,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int elif input_name == "img_ids": shape = ( [(self.height // 2) * (self.width // 2), 3] - if check_if_diffusers_greater("0.31.0") + if is_diffusers_version(">=", "0.31.0") else [self.batch_size, (self.height // 2) * (self.width // 2), 3] ) return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) @@ -1524,7 +1524,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int if input_name == "txt_ids": shape = ( [self.sequence_length, 3] - if check_if_diffusers_greater("0.31.0") + if is_diffusers_version(">=", "0.31.0") else [self.batch_size, self.sequence_length, 3] ) return self.random_int_tensor(shape, max_value=1, framework=framework, dtype=int_dtype) diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py index 07f90e8984e..a2df69077e1 100644 --- a/tests/onnxruntime/test_diffusion.py +++ b/tests/onnxruntime/test_diffusion.py @@ -34,7 +34,7 @@ ORTPipelineForInpainting, ORTPipelineForText2Image, ) -from optimum.utils import check_if_transformers_greater +from optimum.utils import is_transformers_version from optimum.utils.testing_utils import grid_parameters, require_diffusers @@ -77,7 +77,7 @@ class ORTPipelineForText2ImageTest(ORTModelTestMixin): "stable-diffusion-xl", "latent-consistency", ] - if check_if_transformers_greater("4.45"): + if is_transformers_version(">=", "4.45"): SUPPORTED_ARCHITECTURES += ["stable-diffusion-3", "flux"] NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES = [ @@ -85,7 +85,8 @@ class ORTPipelineForText2ImageTest(ORTModelTestMixin): "stable-diffusion-xl", "latent-consistency", ] - if check_if_transformers_greater("4.45"): + + if is_transformers_version(">=", "4.45"): NEGATIVE_PROMPT_SUPPORTED_ARCHITECTURES += ["stable-diffusion-3"] CALLBACK_SUPPORTED_ARCHITECTURES = [ @@ -93,7 +94,7 @@ class ORTPipelineForText2ImageTest(ORTModelTestMixin): "stable-diffusion-xl", "latent-consistency", ] - if check_if_transformers_greater("4.45"): + if is_transformers_version(">=", "4.45"): CALLBACK_SUPPORTED_ARCHITECTURES += ["flux"] ORTMODEL_CLASS = ORTPipelineForText2Image @@ -341,7 +342,7 @@ class ORTPipelineForImage2ImageTest(ORTModelTestMixin): "stable-diffusion-xl", "latent-consistency", ] - if check_if_transformers_greater("4.45"): + if is_transformers_version(">=", "4.45"): SUPPORTED_ARCHITECTURES += ["stable-diffusion-3"] CALLBACK_SUPPORTED_ARCHITECTURES = [ @@ -578,7 +579,7 @@ class ORTPipelineForInpaintingTest(ORTModelTestMixin): "stable-diffusion", "stable-diffusion-xl", ] - if check_if_transformers_greater("4.45"): + if is_transformers_version(">=", "4.45"): SUPPORTED_ARCHITECTURES += ["stable-diffusion-3"] CALLBACK_SUPPORTED_ARCHITECTURES = [ diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 456ad73505e..d92888a8dd6 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -107,7 +107,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.import_utils import check_if_transformers_greater, is_diffusers_available +from optimum.utils.import_utils import is_diffusers_available, is_transformers_version from optimum.utils.testing_utils import ( grid_parameters, remove_directory, @@ -2333,17 +2333,17 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): "opt", ] - if check_if_transformers_greater("4.37"): + if is_transformers_version(">=", "4.37"): SUPPORTED_ARCHITECTURES.append("qwen2") - if check_if_transformers_greater("4.38"): + if is_transformers_version(">=", "4.38"): SUPPORTED_ARCHITECTURES.append("gemma") # TODO: fix "mpt" for which inference fails for transformers < v4.41 - if check_if_transformers_greater("4.41"): + if is_transformers_version(">=", "4.41"): SUPPORTED_ARCHITECTURES.extend(["phi3", "mpt"]) - if check_if_transformers_greater("4.45"): + if is_transformers_version(">=", "4.45"): SUPPORTED_ARCHITECTURES.append("granite") FULL_GRID = { @@ -4612,7 +4612,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - if model_arch == "whisper" and check_if_transformers_greater("4.43"): + if model_arch == "whisper" and is_transformers_version(">=", "4.43"): gen_length = self.GENERATION_LENGTH + 2 else: gen_length = self.GENERATION_LENGTH + 1