diff --git a/.github/workflows/test_offline.yml b/.github/workflows/test_offline.yml index ca90730b6bc..90b0108e512 100644 --- a/.github/workflows/test_offline.yml +++ b/.github/workflows/test_offline.yml @@ -2,9 +2,9 @@ name: Offline usage / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -15,29 +15,33 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9] + python-version: [3.8, 3.9] os: [ubuntu-20.04] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies for pytorch export - run: | - pip install .[tests,exporters,onnxruntime] - - name: Test with unittest - run: | - HF_HOME=/tmp/ huggingface-cli download hf-internal-testing/tiny-random-gpt2 + - name: Checkout code + uses: actions/checkout@v4 - HF_HOME=/tmp/ HF_HUB_OFFLINE=1 optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 gpt2_onnx --task text-generation + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - huggingface-cli download hf-internal-testing/tiny-random-gpt2 + - name: Install dependencies for pytorch export + run: | + pip install .[tests,exporters,onnxruntime] - HF_HUB_OFFLINE=1 optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 gpt2_onnx --task text-generation + - name: Test with pytest + run: | + HF_HOME=/tmp/ huggingface-cli download hf-internal-testing/tiny-random-gpt2 - pytest tests/onnxruntime/test_modeling.py -k "test_load_model_from_hub and not from_hub_onnx" -s -vvvvv + HF_HOME=/tmp/ HF_HUB_OFFLINE=1 optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 gpt2_onnx --task text-generation - HF_HUB_OFFLINE=1 pytest tests/onnxruntime/test_modeling.py -k "test_load_model_from_hub and not from_hub_onnx" -s -vvvvv \ No newline at end of file + huggingface-cli download hf-internal-testing/tiny-random-gpt2 + + HF_HUB_OFFLINE=1 optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 gpt2_onnx --task text-generation + + pytest tests/onnxruntime/test_modeling.py -k "test_load_model_from_hub and not from_hub_onnx" -s -vvvvv + + HF_HUB_OFFLINE=1 pytest tests/onnxruntime/test_modeling.py -k "test_load_model_from_hub and not from_hub_onnx" -s -vvvvv diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index 4893b681a66..291a3b08335 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -50,6 +50,8 @@ jobs: pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s - name: Test with pytest (in parallel) + env: + FXMARTYCLONE_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} working-directory: tests run: | pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 00000000000..c71afbbb459 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,17 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main + + diff --git a/docs/source/_redirects.yml 
b/docs/source/_redirects.yml index 4022ba1618d..47e3fab1db8 100644 --- a/docs/source/_redirects.yml +++ b/docs/source/_redirects.yml @@ -19,12 +19,18 @@ habana/tutorials/pretraining: habana/usage_guides/pretraining # Optimum Intel intel_index: intel/index -intel_quickstart: intel/optimization_inc -intel_configuration: intel/reference_inc -intel_optimization: intel/optimization_inc -intel_quantization: intel/optimization_inc -intel_pruning: intel/optimization_inc -intel_trainer: intel/reference_inc +intel_quickstart: intel/index +intel_configuration: intel/neural_compressor/reference +intel_optimization: intel/neural_compressor/optimization +intel_quantization: intel/neural_compressor/optimization +intel_pruning: intel/neural_compressor/optimization +intel_trainer: intel/neural_compressor/reference +intel/inference: intel/openvino/inference +intel/optimization_ov: intel/openvino/optimization +intel/reference_ov: intel/openvino/reference +intel/optimization_inc: intel/neural_compressor/optimization +intel/distributed_training: intel/neural_compressor/distributed_training +intel/reference_inc: intel/neural_compressor/reference # Optimum Neuron docs/optimum-neuron/index: /docs/optimum-neuron/index diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 540ea4dd863..8a2a276d1c5 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -15,5 +15,4 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand -from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand -from .optimum_cli import register_optimum_cli_subcommand +from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/optimum_cli.py b/optimum/commands/optimum_cli.py index 4bae9bb5f82..64a7075c6ce 100644 --- a/optimum/commands/optimum_cli.py +++ b/optimum/commands/optimum_cli.py @@ -17,16 +17,57 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Type, Union +from ..subpackages import load_subpackages from ..utils import logging from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand from .export import ExportCommand -from .onnxruntime import ONNXRuntimeCommand logger = logging.get_logger() -OPTIMUM_CLI_SUBCOMMANDS = [ExportCommand, EnvironmentCommand, ONNXRuntimeCommand] +# The table below contains the optimum-cli root subcommands provided by the optimum package +OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand] + +# The table below is dynamically populated when loading subpackages +_OPTIMUM_CLI_SUBCOMMANDS = [] + + +def optimum_cli_subcommand(parent_command: Optional[Type[BaseOptimumCLICommand]] = None): + """ + A decorator to declare optimum-cli subcommands. + + The declaration of an optimum-cli subcommand looks like this: + + ``` + @optimum_cli_subcommand() + class MySubcommand(BaseOptimumCLICommand): + + ``` + + or + + ``` + @optimum_cli_subcommand(ExportCommand) + class MySubcommand(BaseOptimumCLICommand): + + ``` + + Args: + parent_command: (`Optional[Type[BaseOptimumCLICommand]]`): + The class of the parent command or None if this is a top-level command. Defaults to None. 
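For reference, a subpackage would use the new decorator along the lines of the sketch below; `MyBackendCommand` and its body are hypothetical, while `optimum_cli_subcommand`, `BaseOptimumCLICommand` and `CommandInfo` are the objects exported by `optimum/commands/__init__.py` in this diff:

```python
from optimum.commands import BaseOptimumCLICommand, CommandInfo, optimum_cli_subcommand


@optimum_cli_subcommand()  # or @optimum_cli_subcommand(ExportCommand) to nest it under `optimum-cli export`
class MyBackendCommand(BaseOptimumCLICommand):
    COMMAND = CommandInfo(name="mybackend", help="Hypothetical subcommand contributed by a subpackage.")

    def run(self):
        # A real subcommand would act on self.args (filled in by its parse_args) here.
        print("Hello from a subpackage-provided optimum-cli subcommand")
```

The class is appended to `_OPTIMUM_CLI_SUBCOMMANDS` at import time, and `main()` registers it after `load_subpackages()` has run.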
+ + """ + + if parent_command is not None and not issubclass(parent_command, BaseOptimumCLICommand): + raise ValueError(f"The parent command {parent_command} must be a subclass of BaseOptimumCLICommand") + + def wrapper(subcommand): + if not issubclass(subcommand, BaseOptimumCLICommand): + raise ValueError(f"The subcommand {subcommand} must be a subclass of BaseOptimumCLICommand") + _OPTIMUM_CLI_SUBCOMMANDS.append((subcommand, parent_command)) + + return wrapper def resolve_command_to_command_instance( @@ -137,15 +178,19 @@ def main(): root = RootOptimumCLICommand("Optimum CLI tool", usage="optimum-cli") parser = root.parser - for subcommand_cls in OPTIMUM_CLI_SUBCOMMANDS: + for subcommand_cls in OPTIMUM_CLI_ROOT_SUBCOMMANDS: register_optimum_cli_subcommand(subcommand_cls, parent_command=root) - commands_in_register = dynamic_load_commands_in_register() + # Load subpackages to give them a chance to declare their own subcommands + load_subpackages() + + # Register subcommands declared by the subpackages or found in the register files under commands/register + commands_to_register = _OPTIMUM_CLI_SUBCOMMANDS + dynamic_load_commands_in_register() command2command_instance = resolve_command_to_command_instance( - root, [parent_command_cls for _, parent_command_cls in commands_in_register if parent_command_cls is not None] + root, [parent_command_cls for _, parent_command_cls in commands_to_register if parent_command_cls is not None] ) - for command_or_command_info, parent_command in commands_in_register: + for command_or_command_info, parent_command in commands_to_register: if parent_command is None: parent_command_instance = root else: diff --git a/optimum/configuration_utils.py b/optimum/configuration_utils.py index 3216d4a94c3..ab5d6c057f8 100644 --- a/optimum/configuration_utils.py +++ b/optimum/configuration_utils.py @@ -18,9 +18,9 @@ import json import os import re +import warnings from typing import Any, Dict, List, Tuple, Union -from huggingface_hub import HfFolder from packaging import version from transformers import PretrainedConfig from transformers import __version__ as transformers_version_str @@ -93,7 +93,19 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: repo_id = self._create_repo(repo_id, **kwargs) use_auth_token = kwargs.get("use_auth_token", None) - token = HfFolder.get_token() if use_auth_token is True else use_auth_token + token = kwargs.get("token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "You cannot use both `use_auth_token` and `token` arguments at the same time." + ) + kwargs["token"] = use_auth_token + token = use_auth_token files_timestamps = self._get_files_timestamps(save_directory) @@ -197,6 +209,7 @@ def _get_config_dict( resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) use_auth_token = kwargs.pop("use_auth_token", None) + token = kwargs.pop("token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) trust_remote_code = kwargs.pop("trust_remote_code", None) @@ -205,6 +218,15 @@ def _get_config_dict( from_auto_class = kwargs.pop("_from_auto", False) commit_hash = kwargs.pop("_commit_hash", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if trust_remote_code is True: logger.warning( "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" @@ -255,7 +277,7 @@ def _get_config_dict( proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, user_agent=user_agent, ) else: @@ -268,7 +290,7 @@ def _get_config_dict( proxies=proxies, resume_download=resume_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, user_agent=user_agent, revision=revision, subfolder=subfolder, diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 585a779c2e5..1e36af06ade 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -15,6 +15,7 @@ """Entry point to the optimum.exporters.onnx command line.""" import argparse +import warnings from pathlib import Path from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE @@ -66,6 +67,7 @@ def main_export( force_download: bool = False, local_files_only: bool = False, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, for_ort: bool = False, do_validation: bool = True, model_kwargs: Optional[Dict[str, Any]] = None, @@ -135,9 +137,11 @@ def main_export( cached versions if they exist. local_files_only (`Optional[bool]`, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). - use_auth_token (`Optional[str]`, defaults to `None`): + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during the export. This argument should be used along the `custom_onnx_configs` argument @@ -174,6 +178,15 @@ def main_export( ``` """ + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if fp16: if dtype is not None: raise ValueError( @@ -250,7 +263,7 @@ def main_export( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -283,7 +296,7 @@ def main_export( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 44021e959b1..78bdb639029 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -18,12 +18,14 @@ import inspect import itertools import os +import warnings from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import huggingface_hub from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.errors import OfflineModeIsEnabled from packaging import version from requests.exceptions import ConnectionError as RequestsConnectionError from transformers import AutoConfig, PretrainedConfig, is_tf_available, is_torch_available @@ -1392,9 +1394,19 @@ def get_model_files( model_name_or_path: Union[str, Path], subfolder: str = "", cache_dir: str = HUGGINGFACE_HUB_CACHE, - use_auth_token: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + request_exception = None full_model_path = Path(model_name_or_path) / subfolder if full_model_path.is_dir(): @@ -1410,12 +1422,12 @@ def get_model_files( all_files = huggingface_hub.list_repo_files( model_name_or_path, repo_type="model", - token=use_auth_token, + token=token, revision=revision, ) if subfolder != "": all_files = [file[len(subfolder) + 1 :] for file in all_files if file.startswith(subfolder)] - except (RequestsConnectionError, huggingface_hub.utils._http.OfflineModeIsEnabled) as e: + except (RequestsConnectionError, OfflineModeIsEnabled) as e: request_exception = e object_id = model_name_or_path.replace("/", "--") full_model_path = Path(cache_dir, f"models--{object_id}") @@ -1589,7 +1601,7 @@ def _infer_task_from_model_name_or_path( ) try: model_info = huggingface_hub.model_info(model_name_or_path, revision=revision) - except (RequestsConnectionError, huggingface_hub.utils._http.OfflineModeIsEnabled): + except (RequestsConnectionError, OfflineModeIsEnabled): raise RuntimeError( f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})." 
) @@ -1707,7 +1719,8 @@ def infer_library_from_model( revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, library_name: Optional[str] = None, - use_auth_token: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, ): """ Infers the library from the model repo. @@ -1725,16 +1738,30 @@ def infer_library_from_model( Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used. library_name (`Optional[str]`, *optional*): The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers". - use_auth_token (`Optional[str]`, defaults to `None`): - The token to use as HTTP bearer authorization for remote files. + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + Returns: `str`: The library name automatically detected from the model repo. """ + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if library_name is not None: return library_name all_files, _ = TasksManager.get_model_files( - model_name_or_path, subfolder, cache_dir, use_auth_token=use_auth_token + model_name_or_path, subfolder, cache_dir, token=token, revision=revision ) if "model_index.json" in all_files: @@ -1750,7 +1777,7 @@ def infer_library_from_model( "subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, - "use_auth_token": use_auth_token, + "token": token, } config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs) model_config = PretrainedConfig.from_dict(config_dict, **kwargs) @@ -1925,12 +1952,23 @@ def get_model_from_task( elif library_name == "sentence_transformers": cache_folder = model_kwargs.pop("cache_folder", None) use_auth_token = model_kwargs.pop("use_auth_token", None) + token = model_kwargs.pop("token", None) trust_remote_code = model_kwargs.pop("trust_remote_code", False) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model = model_class( model_name_or_path, device=device, cache_folder=cache_folder, - use_auth_token=use_auth_token, + token=token, trust_remote_code=trust_remote_code, ) else: diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 37a42714fc8..b8734da478e 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -182,40 +182,11 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train") def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="validation") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - - return dataset + raise RuntimeError("Loading the `ptb` dataset was deprecated") def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="test") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - return dataset + raise RuntimeError("Loading the `ptb` dataset was deprecated") def get_dataset( @@ -226,7 +197,7 @@ def get_dataset( Args: dataset_name (`str`): - Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb_new']`. + Dataset name. Available options are `['wikitext2', 'c4', 'c4-new']`. 
tokenizer (`Any`): Tokenizer of the model nsamples (`int`, defaults to `128`): @@ -247,11 +218,13 @@ def get_dataset( "wikitext2": get_wikitext2, "c4": get_c4, "c4-new": get_c4_new, - "ptb": get_ptb, - "ptb-new": get_ptb_new, } if split not in ["train", "validation"]: raise ValueError(f"The split need to be 'train' or 'validation' but found {split}") + if dataset_name in {"ptb", "ptb-new"}: + raise ValueError( + f"{dataset_name} dataset was deprecated, only the following dataset are supported : {list(get_dataset_map)}" + ) if dataset_name not in get_dataset_map: raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}") get_dataset_fn = get_dataset_map[dataset_name] diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 2c2c9d7e71a..902af87bbb0 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -432,7 +432,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: @@ -458,7 +461,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 9663a311692..74b05d5b151 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -17,11 +17,12 @@ import logging import os import subprocess +import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Optional, Union -from huggingface_hub import HfApi, HfFolder +from huggingface_hub import create_repo, upload_file from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from transformers import AutoConfig, PretrainedConfig, add_start_docstrings @@ -51,9 +52,11 @@ force_download (`bool`, defaults to `True`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. - use_auth_token (`Optional[str]`, defaults to `None`): + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). cache_dir (`Optional[str]`, defaults to `None`): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. 
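The `use_auth_token` → `token` deprecation shim described in this docstring is inlined verbatim in every touched method. Condensed into a standalone helper purely for illustration (no such helper exists in the diff), the pattern is:

```python
import warnings
from typing import Optional, Union


def resolve_token(
    use_auth_token: Optional[Union[bool, str]] = None,
    token: Optional[Union[bool, str]] = None,
) -> Optional[Union[bool, str]]:
    # Warn on the deprecated argument, reject ambiguous calls, then fall back to it.
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed soon. "
            "Please use the `token` argument instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
        token = use_auth_token
    return token
```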
@@ -156,33 +159,33 @@ def push_to_hub( save_directory: str, repository_id: str, private: Optional[bool] = None, - use_auth_token: Union[bool, str] = True, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, ) -> str: - if isinstance(use_auth_token, str): - huggingface_token = use_auth_token - elif use_auth_token: - huggingface_token = HfFolder.get_token() - else: - raise ValueError("You need to proivde `use_auth_token` to be able to push to the hub") - api = HfApi() - - user = api.whoami(huggingface_token) - self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) - - api.create_repo( - token=huggingface_token, + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + create_repo( + token=token, repo_id=repository_id, exist_ok=True, private=private, ) + for path, subdirs, files in os.walk(save_directory): for name in files: local_file_path = os.path.join(path, name) _, hub_file_path = os.path.split(local_file_path) # FIXME: when huggingface_hub fixes the return of upload_file try: - api.upload_file( - token=huggingface_token, + upload_file( + token=token, repo_id=f"{repository_id}", path_or_fileobj=os.path.join(os.getcwd(), local_file_path), path_in_repo=hub_file_path, @@ -222,18 +225,28 @@ def _load_config( config_name_or_path: Union[str, os.PathLike], revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, - use_auth_token: Optional[Union[bool, str]] = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, force_download: bool = False, subfolder: str = "", trust_remote_code: bool = False, ) -> PretrainedConfig: + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + try: config = AutoConfig.from_pretrained( pretrained_model_name_or_path=config_name_or_path, revision=revision, cache_dir=cache_dir, force_download=force_download, - use_auth_token=use_auth_token, + token=token, subfolder=subfolder, trust_remote_code=trust_remote_code, ) @@ -245,7 +258,7 @@ def _load_config( revision=revision, cache_dir=cache_dir, force_download=force_download, - use_auth_token=use_auth_token, + token=token, trust_remote_code=trust_remote_code, ) logger.info( @@ -261,6 +274,7 @@ def _from_pretrained( model_id: Union[str, Path], config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -277,6 +291,7 @@ def _from_transformers( model_id: Union[str, Path], config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -297,6 +312,7 @@ def _export( model_id: Union[str, Path], config: PretrainedConfig, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -317,7 +333,8 @@ def from_pretrained( model_id: Union[str, Path], export: bool = False, force_download: bool = False, - use_auth_token: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", config: Optional[PretrainedConfig] = None, @@ -330,6 +347,16 @@ def from_pretrained( Returns: `OptimizedModel`: The loaded optimized model. """ + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if isinstance(model_id, Path): model_id = model_id.as_posix() @@ -347,9 +374,7 @@ def from_pretrained( ) model_id, revision = model_id.split("@") - library_name = TasksManager.infer_library_from_model( - model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token - ) + library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token) if library_name == "timm": config = PretrainedConfig.from_pretrained(model_id, subfolder, revision) @@ -374,7 +399,7 @@ def from_pretrained( model_id, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, force_download=force_download, subfolder=subfolder, trust_remote_code=trust_remote_code, @@ -384,7 +409,7 @@ def from_pretrained( config, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, force_download=force_download, subfolder=subfolder, trust_remote_code=trust_remote_code, @@ -405,7 +430,7 @@ def from_pretrained( revision=revision, cache_dir=cache_dir, force_download=force_download, - use_auth_token=use_auth_token, + token=token, subfolder=subfolder, local_files_only=local_files_only, trust_remote_code=trust_remote_code, diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index bf9c80a86cd..16461dce957 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -14,7 +14,7 @@ """Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models.""" from abc import abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Set, Tuple, Union import numpy as np import torch @@ -24,22 +24,22 @@ from ..utils import NormalizedConfigManager from ..utils.logging import warn_once +from .modeling_ort import ORTModel from .utils import get_ordered_input_names, logging logger = logging.get_logger(__name__) -if TYPE_CHECKING: - from .modeling_ort import ORTModel - - class ORTModelPart: """ For multi-file ONNX models, such as encoder-decoder models, represents a part of the model. It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. 
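For orientation, the `input_dtypes` / `output_dtypes` attributes introduced just below (and on `ORTModel` later in this diff) are plain dicts of ONNX Runtime type strings keyed by tensor name; roughly:

```python
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")  # path is illustrative

# Same comprehension as in the __init__ methods touched by this diff:
input_dtypes = {inp.name: inp.type for inp in session.get_inputs()}
output_dtypes = {out.name: out.type for out in session.get_outputs()}

print(input_dtypes)   # e.g. {"input_ids": "tensor(int64)", "attention_mask": "tensor(int64)"}
print(output_dtypes)  # e.g. {"last_hidden_state": "tensor(float)"}
```

These mappings are what `_prepare_onnx_inputs` consults when casting numpy arrays before calling the session.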
""" + _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs + _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs + def __init__( self, session: InferenceSession, @@ -53,6 +53,8 @@ def __init__( self.main_input_name = self.parent_model.main_input_name self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} + self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()} self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) @@ -98,25 +100,13 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - else: - onnx_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # Run inference - outputs = self.session.run(None, onnx_inputs) - - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + last_hidden_state = model_outputs["last_hidden_state"] return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -350,83 +340,29 @@ def forward( else: raise ValueError("Unsupported num_pkv") else: - if use_torch: - onnx_inputs = { - "input_ids": input_ids.cpu().detach().numpy(), - } - - # Add the encoder_hidden_states inputs when needed - if "encoder_hidden_states" in self.input_names: - onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy() - - # Add the decoder_attention_mask inputs when needed - if "decoder_attention_mask" in self.input_names: - onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy() - - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() - - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() - - if "labels" in self.input_names: - # TODO: Any preprocessing like `self._shift_right(labels)`? 
- onnx_inputs["labels"] = labels.cpu().detach().numpy() - - if self.parent_model.use_merged is True: - onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy() - else: - onnx_inputs = { - "input_ids": input_ids, - } - - # Add the encoder_hidden_states inputs when needed - if "encoder_hidden_states" in self.input_names: - onnx_inputs["encoder_hidden_states"] = encoder_hidden_states - - # Add the decoder_attention_mask inputs when needed - if "decoder_attention_mask" in self.input_names: - onnx_inputs["decoder_attention_mask"] = decoder_attention_mask - - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask - - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value - - if "labels" in self.input_names: - # TODO: Any preprocessing like `self._shift_right(labels)`? - onnx_inputs["labels"] = labels - - if self.parent_model.use_merged is True: - onnx_inputs["use_cache_branch"] = use_cache_branch_tensor + model_inputs = { + "input_ids": input_ids, + "encoder_hidden_states": encoder_hidden_states, + "decoder_attention_mask": decoder_attention_mask, + "encoder_attention_mask": encoder_attention_mask, + "use_cache_branch": use_cache_branch_tensor, + "labels": labels, + } + if past_key_values is not None: + model_inputs.update(zip(self.key_value_input_names, past_key_values)) - # Run inference - outputs = self.session.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # TODO: using two loops here is probably unefficient + # TODO: using a new variable out_past_key_values is memory inefficient, + # past_key_values is not used anymore at this point # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the # self-attention layer and 2 to the cross-attention layer) - out_past_key_values = tuple( - torch.from_numpy(outputs[self.output_names[key]]).to(self.device) - for key in self.key_value_output_names - ) - - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names) - loss = None - if "loss" in self.output_names: - loss = outputs[self.output_names["loss"]] - if use_torch: - loss = torch.from_numpy(loss).to(self.device) + loss = model_outputs.get("loss", None) + logits = model_outputs["logits"] # TODO: this is extremely ugly and unreadable. What if cross-attention k/v change? 
# Tuple of tuple of length `n_layers`, with each tuple of length equal to: diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2d9be2d757f..fd7e741d7c0 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,6 +14,7 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -46,7 +47,7 @@ if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: - from transformers.generation_utils import GenerationMixin + from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401 logger = logging.getLogger(__name__) @@ -139,15 +140,16 @@ def __init__( self.num_pkv = 2 self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)] + self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)] self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] self.use_cache = len(self.key_value_input_names) > 0 if generation_config is None: generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config self.onnx_paths = [self.model_path] - self.use_merged = "use_cache_branch" in self.inputs_names + self.use_merged = "use_cache_branch" in self.input_names self.model_type = self.config.model_type self.use_fp16 = False @@ -160,7 +162,7 @@ def __init__( # Reference: https://github.com/huggingface/optimum/pull/1381 model_type = config.model_type.replace("_", "-") - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names: + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names: logger.warning( f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." @@ -202,7 +204,6 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - inputs = {} known_output_shapes = {} use_cache_branch = None loss = None @@ -226,10 +227,10 @@ def forward( # I suspect the reason is the contiguous python list that messes something up? 
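Stepping back from the forward-pass internals, the user-facing effect of the causal-LM changes (the `inputs_names` → `input_names` rename, the stored `generation_config`, and the `token` argument) can be sanity-checked with a short snippet like the following; the checkpoint is the tiny test model already used in the offline workflow above:

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# `token=...` now replaces `use_auth_token=...` for private repositories; export=True exports to ONNX on the fly.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Hello", return_tensors="pt")
print(model.generate(**inputs, max_new_tokens=5))
print(sorted(model.input_names))  # attribute renamed from `inputs_names` in this diff
```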
model_inputs = [input_ids.contiguous()] - if "attention_mask" in self.inputs_names: + if "attention_mask" in self.input_names: model_inputs.append(attention_mask) - if "position_ids" in self.inputs_names: + if "position_ids" in self.input_names: if position_ids is None: raise ValueError("position_ids was not passed but is a required input for this ONNX model.") model_inputs.append(position_ids.contiguous()) @@ -240,12 +241,11 @@ def forward( if use_cache_branch is not None: model_inputs.append(use_cache_branch) - if "labels" in self.inputs_names: + if "labels" in self.input_names: model_inputs.append(labels) known_output_shapes.update({"loss": []}) - io_binding, output_shapes, output_buffers = self._prepare_io_binding( - self.model, + io_binding, output_shapes, output_buffers = self.prepare_io_binding( *model_inputs, known_output_shapes=known_output_shapes, ordered_input_names=self._ordered_input_names, @@ -259,53 +259,41 @@ def forward( io_binding.synchronize_outputs() if self.use_cache: - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2) - past_key_values = () - for name in self.key_value_output_names: - past_key_values += (output_buffers[name].view(output_shapes[name]),) + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention) + past_key_values = tuple( + output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names + ) logits = output_buffers["logits"].view(output_shapes["logits"]) if "loss" in self.output_names: loss = output_buffers["loss"].view(output_shapes["loss"]) else: - inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids - - if "attention_mask" in self.inputs_names: - inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask - - if "labels" in self.inputs_names: - inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels - - if "position_ids" in self.inputs_names: - if position_ids is None: - raise ValueError("position_ids was not passed but is a required input for this ONNX model.") - inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids - - # Add the past_key_values to the decoder inputs + model_inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "use_cache_branch": use_cache_branch, + "labels": labels, + } if past_key_values is not None: - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value + model_inputs.update( + zip(self.key_value_input_names, past_key_values), + ) - if use_cache_branch is not None: - inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - outputs = self.model.run(None, inputs) + loss = model_outputs.get("loss", None) + logits = model_outputs["logits"] if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention) - past_key_values = tuple( - torch.from_numpy(outputs[self.output_names[key]]).to(self.device) - for key in self.key_value_output_names - ) - - logits = 
torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device) - if "loss" in self.output_names: - loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device) + past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names) if self.use_cache and self.model_type != "gpt_bigcode": - # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and - # per decoder layer + # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer past_key_values = tuple( past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) ) @@ -406,6 +394,7 @@ def _from_pretrained( model_id: Union[str, Path], config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -421,6 +410,15 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ) -> "ORTModelForCausalLM": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_id) # We do not implement the logic for use_cache=False, use_merged=True @@ -450,7 +448,7 @@ def _from_pretrained( [DECODER_MERGED_ONNX_FILE_PATTERN], argument_name=None, subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) use_merged = True @@ -472,7 +470,7 @@ def _from_pretrained( [r"^((?!decoder).)*.onnx", pattern], argument_name=None, subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) file_name = decoder_path.name @@ -494,7 +492,7 @@ def _from_pretrained( model_cache_path, preprocessors = cls._cached_file( model_path=model_path, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, @@ -597,6 +595,7 @@ def _from_transformers( model_id: str, config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: str = "main", force_download: bool = True, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -611,6 +610,15 @@ def _from_transformers( use_io_binding: Optional[bool] = None, task: Optional[str] = None, ) -> "ORTModelForCausalLM": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + file_name = ONNX_WEIGHTS_NAME if use_merged: @@ -636,7 +644,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 63360ce80a8..f4e54752115 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -16,6 +16,7 @@ import logging import os import shutil +import warnings from abc import abstractmethod from pathlib import Path from tempfile import TemporaryDirectory @@ -272,6 +273,7 @@ def _from_pretrained( model_id: Union[str, Path], config: Dict[str, Any], use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, @@ -287,6 +289,15 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if provider == "TensorrtExecutionProvider": raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported") @@ -314,7 +325,7 @@ def _from_pretrained( model_id, cache_dir=cache_dir, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, allow_patterns=allow_patterns, ignore_patterns=["*.msgpack", "*.safetensors", "*.bin", "*.xml"], @@ -376,6 +387,7 @@ def _from_transformers( model_id: str, config: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: str = "main", force_download: bool = True, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -388,6 +400,15 @@ def _from_transformers( use_io_binding: Optional[bool] = None, task: Optional[str] = None, ) -> "ORTStableDiffusionPipeline": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if task is None: task = cls._auto_model_to_task(cls.auto_model_class) @@ -403,7 +424,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index eb38a7fef12..b3bad65954d 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -16,13 +16,14 @@ import logging import re import shutil +import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch -from huggingface_hub import HfFolder, hf_hub_download +from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import EntryNotFoundError from transformers import ( @@ -267,10 +268,13 @@ def __init__( **kwargs, ) - self.inputs_names = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())} + self.input_names = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in model.get_inputs()} + self.output_names = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())} + self.output_dtypes = {output_key.name: output_key.type for output_key in model.get_outputs()} - self._ordered_input_names = get_ordered_input_names(self.inputs_names.keys(), func=self.forward) + self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) # TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value? @property @@ -410,9 +414,19 @@ def infer_onnx_filename( argument_name: str, subfolder: str = "", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, fail_if_not_found: bool = True, ) -> str: + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + onnx_files = [] for pattern in patterns: onnx_files = find_files_matching_pattern( @@ -420,7 +434,7 @@ def infer_onnx_filename( pattern, glob_pattern="**/*.onnx", subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) if onnx_files: @@ -448,6 +462,7 @@ def _from_pretrained( model_id: Union[str, Path], config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -461,6 +476,15 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ) -> "ORTModel": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_id) regular_onnx_filenames = ORTModel._generate_regular_names_for_filename(ONNX_WEIGHTS_NAME) @@ -468,13 +492,8 @@ def _from_pretrained( if model_path.is_dir(): onnx_files = list(model_path.glob("*.onnx")) else: - if isinstance(use_auth_token, bool): - token = HfFolder().get_token() - else: - token = use_auth_token - repo_files, _ = TasksManager.get_model_files( - model_id, revision=revision, cache_dir=cache_dir, use_auth_token=token + model_id, revision=revision, cache_dir=cache_dir, token=token ) repo_files = map(Path, repo_files) @@ -499,7 +518,7 @@ def _from_pretrained( model_cache_path, preprocessors = cls._cached_file( model_path=model_path, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, @@ -535,6 +554,7 @@ def _from_transformers( model_id: str, config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -548,13 +568,23 @@ def _from_transformers( task: Optional[str] = None, ) -> "ORTModel": """The method will be deprecated in future releases.""" + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + return cls._export( model_id=model_id, config=config, revision=revision, cache_dir=cache_dir, force_download=force_download, - use_auth_token=use_auth_token, + token=token, subfolder=subfolder, local_files_only=local_files_only, trust_remote_code=trust_remote_code, @@ -571,6 +601,7 @@ def _export( model_id: str, config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -583,6 +614,15 @@ def _export( use_io_binding: Optional[bool] = None, task: Optional[str] = None, ) -> "ORTModel": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if task is None: task = cls._auto_model_to_task(cls.auto_model_class) @@ -598,7 +638,7 @@ def _export( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -624,7 +664,8 @@ def from_pretrained( model_id: Union[str, Path], export: bool = False, force_download: bool = False, - use_auth_token: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", config: Optional["PretrainedConfig"] = None, @@ -666,11 +707,21 @@ def from_pretrained( Returns: `ORTModel`: The loaded ORTModel model. 
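The Hub-download path exercised by `_from_pretrained` above (file discovery through `TasksManager.get_model_files` with `token`) corresponds to user calls along these lines; the repository and `file_name` are illustrative:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

# Loads an already-exported ONNX repository from the Hub; authentication now goes through
# `token` (a string, or True for the locally saved login) instead of `use_auth_token`.
model = ORTModelForSequenceClassification.from_pretrained(
    "optimum/distilbert-base-uncased-finetuned-sst-2-english",  # public example repo with an exported model.onnx
    file_name="model.onnx",
)
```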
""" + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + return super().from_pretrained( model_id, export=export, force_download=force_download, - use_auth_token=use_auth_token, + token=token, cache_dir=cache_dir, subfolder=subfolder, config=config, @@ -736,6 +787,7 @@ def _output_shape_inference(self, axis_name: Union[str, int], dimensions: Dict[s # exception. return int(eval(" ".join(tokens))) + # TODO: this method is bloated with state arguments (that are accesible using self) why ? def _prepare_io_binding( self, model: ort.InferenceSession, @@ -833,9 +885,15 @@ def _prepare_io_binding( return io_binding, output_shapes, output_buffers - def prepare_io_binding(self, *model_inputs, ordered_input_names, known_output_shapes=None): + def prepare_io_binding( + self, *model_inputs, ordered_input_names, outputs_to_not_bind=None, known_output_shapes=None + ): return self._prepare_io_binding( - self.model, ordered_input_names=ordered_input_names, known_output_shapes=known_output_shapes, *model_inputs + self.model, + *model_inputs, + ordered_input_names=ordered_input_names, + known_output_shapes=known_output_shapes, + outputs_to_not_bind=outputs_to_not_bind, ) def raise_on_numpy_input_io_binding(self, use_torch: bool): @@ -852,10 +910,44 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool): " with model.use_io_binding = False, or pass torch.Tensor inputs instead." ) + def _prepare_onnx_inputs( + self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray] + ) -> Dict[str, np.ndarray]: + onnx_inputs = {} + + # converts pytorch inputs into numpy inputs for onnx + for input_name in self.input_names.keys(): + onnx_inputs[input_name] = inputs.pop(input_name) + + if use_torch: + onnx_inputs[input_name] = onnx_inputs[input_name].cpu().detach().numpy() + + if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: + onnx_inputs[input_name] = onnx_inputs[input_name].astype( + TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) + ) + + return onnx_inputs + + def _prepare_onnx_outputs( + self, use_torch: bool, *onnx_outputs: np.ndarray + ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: + model_outputs = {} + + # converts onnxruntime outputs into tensor for standard outputs + for output_name, idx in self.output_names.items(): + model_outputs[output_name] = onnx_outputs[idx] + + if use_torch: + model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device) + + return model_outputs + @staticmethod def _cached_file( model_path: Union[Path, str], use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -863,6 +955,15 @@ def _cached_file( subfolder: str = "", local_files_only: bool = False, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_path) # locates a file in a local folder and repo, downloads and cache it if necessary. 
@@ -874,7 +975,7 @@ def _cached_file( repo_id=model_path.as_posix(), filename=file_name, subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -886,7 +987,7 @@ def _cached_file( repo_id=model_path.as_posix(), subfolder=subfolder, filename=file_name + "_data", - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -970,9 +1071,6 @@ def forward( self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, attention_mask, @@ -985,35 +1083,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput( - last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) - ) + last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - if attention_mask is None: - attention_mask = np.ones_like(input_ids) - else: - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - outputs = self.model.run(None, onnx_inputs) - - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput(last_hidden_state=last_hidden_state) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + # TODO: why do we only return last_hidden_state? why not all outputs? + # that way, there will be less need for ORTModelForCustomTask in cases where + # we just want to extend model outputs with attentions, hidden_states, etc. + last_hidden_state = model_outputs["last_hidden_state"] + + # converts output to namedtuple for pipelines post-processing + return BaseModelOutput(last_hidden_state=last_hidden_state) @classmethod def _export( @@ -1021,6 +1105,7 @@ def _export( model_id: str, config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -1033,6 +1118,15 @@ def _export( use_io_binding: Optional[bool] = None, task: Optional[str] = None, ) -> "ORTModel": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if task is None: task = cls._auto_model_to_task(cls.auto_model_class) @@ -1049,7 +1143,7 @@ def _export( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -1144,32 +1238,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return MaskedLMOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return MaskedLMOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return MaskedLMOutput(logits=logits) QUESTION_ANSWERING_EXAMPLE = r""" @@ -1247,37 +1327,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return QuestionAnsweringModelOutput( - start_logits=output_buffers["start_logits"].view(output_shapes["start_logits"]), - end_logits=output_buffers["end_logits"].view(output_shapes["end_logits"]), - ) + # TODO: this is the same routine in all io binding branches, should we refactor it into a prepare_io_binding_outputs method? 
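# (hypothetical, not part of this patch) One way the TODO above could be resolved --
# a shared helper performing the buffer-viewing step that every IO-binding branch
# repeats after run_with_iobinding():
def prepare_io_binding_outputs(self, output_buffers, output_shapes):
    # view each pre-allocated output buffer with the shape inferred while binding
    return {name: buffer.view(output_shapes[name]) for name, buffer in output_buffers.items()}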
+ start_logits = output_buffers["start_logits"].view(output_shapes["start_logits"]) + end_logits = output_buffers["end_logits"].view(output_shapes["end_logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - - start_logits = outputs[self.output_names["start_logits"]] - end_logits = outputs[self.output_names["end_logits"]] - if use_torch: - start_logits = torch.from_numpy(start_logits).to(self.device) - end_logits = torch.from_numpy(end_logits).to(self.device) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + start_logits = model_outputs["start_logits"] + end_logits = model_outputs["end_logits"] - # converts output to namedtuple for pipelines post-processing - return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits) + # converts output to namedtuple for pipelines post-processing + return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits) SEQUENCE_CLASSIFICATION_EXAMPLE = r""" @@ -1370,30 +1434,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - outputs = self.model.run(None, onnx_inputs) + logits = model_outputs["logits"] - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=logits) TOKEN_CLASSIFICATION_EXAMPLE = r""" @@ -1472,32 +1524,17 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - 
attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=logits) + logits = model_outputs["logits"] + + return TokenClassifierOutput(logits=logits) MULTIPLE_CHOICE_EXAMPLE = r""" @@ -1570,31 +1607,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return MultipleChoiceModelOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # converts output to namedtuple for pipelines post-processing - return MultipleChoiceModelOutput(logits=logits) + logits = model_outputs["logits"] + + # converts output to namedtuple for pipelines post-processing + return MultipleChoiceModelOutput(logits=logits) IMAGE_CLASSIFICATION_EXAMPLE = r""" @@ -1662,7 +1686,8 @@ def forward( if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( - pixel_values, ordered_input_names=self._ordered_input_names + pixel_values, + ordered_input_names=self._ordered_input_names, ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -1670,25 +1695,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return ImageClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - pixel_values = pixel_values.cpu().detach().numpy() - - onnx_inputs = { - "pixel_values": pixel_values, - } + model_inputs = {"pixel_values": pixel_values} - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = 
outputs[self.output_names["logits"]] + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return ImageClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return ImageClassifierOutput(logits=logits) SEMANTIC_SEGMENTATION_EXAMPLE = r""" @@ -1746,51 +1764,37 @@ class ORTModelForSemanticSegmentation(ORTModel): checkpoint="optimum/segformer-b0-finetuned-ade-512-512", ) ) - def forward(self, **kwargs): - use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor) + def forward( + self, + pixel_values: Union[torch.Tensor, np.ndarray], + **kwargs, + ): + use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - io_binding = IOBindingHelper.prepare_io_binding( - self, - **kwargs, + io_binding, output_shapes, output_buffers = self.prepare_io_binding( + pixel_values, ordered_input_names=self._ordered_input_names, ) - # run inference with binding + # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} - for name, output in zip(self.output_names.keys(), io_binding._iobinding.get_outputs()): - outputs[name] = IOBindingHelper.to_pytorch(output) - - # converts output to namedtuple for pipelines post-processing - return SemanticSegmenterOutput(logits=outputs["logits"]) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs) + model_inputs = {"pixel_values": pixel_values} - # run inference + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = onnx_outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - - # converts output to namedtuple for pipelines post-processing - return SemanticSegmenterOutput(logits=logits) - - def _prepare_onnx_inputs(self, use_torch: bool, **kwargs): - onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx - for input in self.inputs_names.keys(): - onnx_inputs[input] = kwargs.pop(input) - - if use_torch: - onnx_inputs[input] = onnx_inputs[input].cpu().detach().numpy() + logits = model_outputs["logits"] - return onnx_inputs + # converts output to namedtuple for pipelines post-processing + return SemanticSegmenterOutput(logits=logits) AUDIO_CLASSIFICATION_EXAMPLE = r""" @@ -1878,18 +1882,28 @@ def __init__( ) def forward( self, - input_values: Optional[torch.Tensor] = None, - attenton_mask: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, + attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + input_features: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): - if input_values is None: - # Whisper uses input_features and not input_values. 
- input_values = kwargs["input_features"] - use_torch = isinstance(input_values, torch.Tensor) + if self.input_name == "input_features": + assert input_features is not None, "input_features must be provided for this model" + model_input = input_features + elif self.input_name == "input_values": + assert input_values is not None, "input_values must be provided for this model" + model_input = input_values + else: + raise ValueError(f"Input {self.input_name} not supported for Audio Classification") + + use_torch = isinstance(model_input, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, ordered_input_names=self._ordered_input_names + model_input, + attention_mask, + ordered_input_names=self._ordered_input_names, ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -1897,28 +1911,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - self.input_name: input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - self.input_name: input_values, - } + model_inputs = {self.input_name: model_input, "attention_mask": attention_mask} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=logits) CTC_EXAMPLE = r""" @@ -1966,11 +1970,12 @@ class ORTModelForCTC(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: input_size = input_values.shape[1] output_sizes = [] @@ -1985,9 +1990,7 @@ def _conv_output_size(input_size, kernel_size, stride): known_output_shapes = {"logits": [input_values.shape[0], output_sizes[-1], self.config.vocab_size]} io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, - ordered_input_names=self._ordered_input_names, - known_output_shapes=known_output_shapes, + input_values, ordered_input_names=self._ordered_input_names, known_output_shapes=known_output_shapes ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -1995,28 +1998,18 @@ def _conv_output_size(input_size, kernel_size, stride): self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} - - return CausalLMOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - # converts 
pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - "input_values": input_values, - } + model_inputs = {"input_values": input_values} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - # converts output to namedtuple for pipelines post-processing - return CausalLMOutput(logits=logits) + logits = model_outputs["logits"] + + # converts output to namedtuple for pipelines post-processing + return CausalLMOutput(logits=logits) AUDIO_XVECTOR_EXAMPLE = r""" @@ -2072,11 +2065,12 @@ class ORTModelForAudioXVector(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_values, ordered_input_names=self._ordered_input_names @@ -2087,33 +2081,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return XVectorOutput( - logits=output_buffers["logits"].view(output_shapes["logits"]), - embeddings=output_buffers["embeddings"].view(output_shapes["embeddings"]), - ) + logits = output_buffers["logits"].view(output_shapes["logits"]) + embeddings = output_buffers["embeddings"].view(output_shapes["embeddings"]) + else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - "input_values": input_values, - } + model_inputs = {"input_values": input_values} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - embeddings = outputs[self.output_names["embeddings"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - embeddings = torch.from_numpy(embeddings).to(self.device) + logits = model_outputs["logits"] + embeddings = model_outputs["embeddings"] - # converts output to namedtuple for pipelines post-processing - return XVectorOutput(logits=logits, embeddings=embeddings) + # converts output to namedtuple for pipelines post-processing + return XVectorOutput(logits=logits, embeddings=embeddings) AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r""" @@ -2161,7 +2143,7 @@ class ORTModelForAudioFrameClassification(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) @@ -2170,24 +2152,16 @@ def forward( if self.device.type == "cuda" and self.use_io_binding: raise NotImplementedError() else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { 
- "input_values": input_values, - } + model_inputs = {"input_values": input_values} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # run inference - outputs = self.model.run(None, onnx_inputs) + logits = model_outputs["logits"] - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return TokenClassifierOutput(logits=logits) CUSTOM_TASKS_EXAMPLE = r""" @@ -2236,57 +2210,27 @@ class ORTModelForCustomTasks(ORTModel): checkpoint="optimum/sbert-all-MiniLM-L6-with-pooler", ) ) - def forward(self, **kwargs): - use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor) + def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]): + use_torch = isinstance(next(iter(model_inputs.values())), torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - io_binding = IOBindingHelper.prepare_io_binding( - self, - **kwargs, - ordered_input_names=self._ordered_input_names, - ) + # TODO: should this be used in favor of `model.prepare_io_binding`? + io_binding = IOBindingHelper.prepare_io_binding(self, **model_inputs) # run inference with binding io_binding.synchronize_inputs() self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} + model_outputs = {} for name, output in zip(self.output_names.keys(), io_binding._iobinding.get_outputs()): - outputs[name] = IOBindingHelper.to_pytorch(output) + model_outputs[name] = IOBindingHelper.to_pytorch(output) - # converts output to namedtuple for pipelines post-processing - return ModelOutput(**outputs) else: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs) - - # run inference + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - outputs = self._prepare_onnx_outputs(onnx_outputs, use_torch=use_torch) - - # converts output to namedtuple for pipelines post-processing - return ModelOutput(outputs) - - def _prepare_onnx_inputs(self, use_torch: bool, **kwargs): - onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx - for input in self.inputs_names.keys(): - onnx_inputs[input] = kwargs.pop(input) - - if use_torch: - onnx_inputs[input] = onnx_inputs[input].cpu().detach().numpy() - - return onnx_inputs - - def _prepare_onnx_outputs(self, onnx_outputs, use_torch: bool): - outputs = {} - # converts onnxruntime outputs into tensor for standard outputs - for output, idx in self.output_names.items(): - outputs[output] = onnx_outputs[idx] - - if use_torch: - outputs[output] = torch.from_numpy(outputs[output]).to(self.device) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - return outputs + # converts output to namedtuple for pipelines post-processing + return ModelOutput(**model_outputs) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 2da4b4c8c45..89a0ae44d58 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -777,6 +777,7 @@ def _from_pretrained( model_id: Union[str, Path], config: "PretrainedConfig", 
use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -794,6 +795,15 @@ def _from_pretrained( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_id) # We do not implement the logic for use_cache=False, use_merged=True @@ -815,7 +825,7 @@ def _from_pretrained( [DECODER_MERGED_ONNX_FILE_PATTERN], argument_name=None, subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) use_merged = True @@ -838,7 +848,7 @@ def _from_pretrained( [DECODER_ONNX_FILE_PATTERN], "decoder_file_name", subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) else: @@ -866,7 +876,7 @@ def _from_pretrained( [DECODER_WITH_PAST_ONNX_FILE_PATTERN], "decoder_with_past_file_name", subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) except FileNotFoundError as e: @@ -896,7 +906,7 @@ def _from_pretrained( [ENCODER_ONNX_FILE_PATTERN], "encoder_file_name", subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, ) else: @@ -932,7 +942,7 @@ def _from_pretrained( repo_id=model_id, subfolder=subfolder, filename=filename, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -944,7 +954,7 @@ def _from_pretrained( repo_id=model_id, subfolder=subfolder, filename=filename + "_data", - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -989,7 +999,7 @@ def _from_pretrained( cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, subfolder=subfolder, ) @@ -1022,6 +1032,7 @@ def _from_transformers( model_id: str, config: "PretrainedConfig", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: str = "main", force_download: bool = True, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -1036,6 +1047,15 @@ def _from_transformers( use_io_binding: Optional[bool] = None, task: Optional[str] = None, ) -> "ORTModelForConditionalGeneration": + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if use_cache is False and use_merged is True: raise ValueError( "The incompatible arguments use_cache=False, use_merged=True were passed to" @@ -1062,7 +1082,7 @@ def _from_transformers( subfolder=subfolder, revision=revision, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index d93a7a31320..056123f8d8e 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -15,6 +15,7 @@ import logging import os +import warnings from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union @@ -424,7 +425,8 @@ def get_calibration_dataset( preprocess_function: Optional[Callable] = None, preprocess_batch: bool = True, seed: int = 2016, - use_auth_token: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, ) -> Dataset: """ Creates the calibration `datasets.Dataset` to use for the post-training static quantization calibration step. @@ -445,13 +447,26 @@ def get_calibration_dataset( Whether the `preprocess_function` should be batched. seed (`int`, defaults to 2016): The random seed to use when shuffling the calibration dataset. - use_auth_token (`bool`, defaults to `False`): - Whether to use the token generated when running `transformers-cli login` (necessary for some datasets - like ImageNet). + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + Returns: The calibration `datasets.Dataset` to use for the post-training static quantization calibration step. """ + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + if dataset_name is None: raise ValueError( "ORTQuantizer: Static quantization calibration step requires a dataset_name if no calib_dataset is " @@ -462,7 +477,7 @@ def get_calibration_dataset( dataset_name, name=dataset_config_name, split=dataset_split, - use_auth_token=use_auth_token, + token=token, ) if num_samples is not None: diff --git a/optimum/onnxruntime/subpackage/__init__.py b/optimum/onnxruntime/subpackage/__init__.py new file mode 100644 index 00000000000..7029af7132f --- /dev/null +++ b/optimum/onnxruntime/subpackage/__init__.py @@ -0,0 +1 @@ +from .commands import ONNXRuntimeCommand diff --git a/optimum/commands/onnxruntime/__init__.py b/optimum/onnxruntime/subpackage/commands/__init__.py similarity index 87% rename from optimum/commands/onnxruntime/__init__.py rename to optimum/onnxruntime/subpackage/commands/__init__.py index 1b9c24c3b2c..44facf5ea53 100644 --- a/optimum/commands/onnxruntime/__init__.py +++ b/optimum/onnxruntime/subpackage/commands/__init__.py @@ -14,5 +14,3 @@ # limitations under the License. from .base import ONNXRuntimeCommand -from .optimize import ONNXRuntimeOptimizeCommand -from .quantize import ONNXRuntimeQuantizeCommand diff --git a/optimum/commands/onnxruntime/base.py b/optimum/onnxruntime/subpackage/commands/base.py similarity index 91% rename from optimum/commands/onnxruntime/base.py rename to optimum/onnxruntime/subpackage/commands/base.py index 53e3245ea4d..df4414c19d5 100644 --- a/optimum/commands/onnxruntime/base.py +++ b/optimum/onnxruntime/subpackage/commands/base.py @@ -14,11 +14,13 @@ # limitations under the License. """optimum.onnxruntime command-line interface base classes.""" -from .. 
import BaseOptimumCLICommand, CommandInfo +from optimum.commands import BaseOptimumCLICommand, CommandInfo, optimum_cli_subcommand + from .optimize import ONNXRuntimeOptimizeCommand from .quantize import ONNXRuntimeQuantizeCommand +@optimum_cli_subcommand() class ONNXRuntimeCommand(BaseOptimumCLICommand): COMMAND = CommandInfo( name="onnxruntime", diff --git a/optimum/commands/onnxruntime/optimize.py b/optimum/onnxruntime/subpackage/commands/optimize.py similarity index 96% rename from optimum/commands/onnxruntime/optimize.py rename to optimum/onnxruntime/subpackage/commands/optimize.py index 5890e0a07c7..1dd82f0ee22 100644 --- a/optimum/commands/onnxruntime/optimize.py +++ b/optimum/onnxruntime/subpackage/commands/optimize.py @@ -75,8 +75,8 @@ def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_optimize(parser) def run(self): - from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig - from ...onnxruntime.optimization import ORTOptimizer + from ...configuration import AutoOptimizationConfig, ORTConfig + from ...optimization import ORTOptimizer if self.args.output == self.args.onnx_model: raise ValueError("The output directory must be different than the directory hosting the ONNX model.") diff --git a/optimum/commands/onnxruntime/quantize.py b/optimum/onnxruntime/subpackage/commands/quantize.py similarity index 87% rename from optimum/commands/onnxruntime/quantize.py rename to optimum/onnxruntime/subpackage/commands/quantize.py index 2613cb33ba6..45df903e0c2 100644 --- a/optimum/commands/onnxruntime/quantize.py +++ b/optimum/onnxruntime/subpackage/commands/quantize.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from .. import BaseOptimumCLICommand +from optimum.commands import BaseOptimumCLICommand if TYPE_CHECKING: @@ -69,14 +69,15 @@ def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_quantize(parser) def run(self): - from ...onnxruntime.configuration import AutoQuantizationConfig, ORTConfig - from ...onnxruntime.quantization import ORTQuantizer + from ...configuration import AutoQuantizationConfig, ORTConfig + from ...quantization import ORTQuantizer if self.args.output == self.args.onnx_model: raise ValueError("The output directory must be different than the directory hosting the ONNX model.") save_dir = self.args.output quantizers = [] + use_external_data_format = False quantizers = [ ORTQuantizer.from_pretrained(self.args.onnx_model, file_name=model.name) @@ -96,7 +97,11 @@ def run(self): "TensorRT quantization relies on static quantization that requires calibration, which is currently not supported through optimum-cli. 
Please adapt Optimum static quantization examples to run static quantization for TensorRT: https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/quantization" ) else: - qconfig = ORTConfig.from_pretrained(self.args.config).quantization + config = ORTConfig.from_pretrained(self.args.config) + qconfig = config.quantization + use_external_data_format = config.use_external_data_format for q in quantizers: - q.quantize(save_dir=save_dir, quantization_config=qconfig) + q.quantize( + save_dir=save_dir, quantization_config=qconfig, use_external_data_format=use_external_data_format + ) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 0e1da447a64..37d0feefcc4 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -128,6 +128,7 @@ class ORTConfigManager: "nystromformer": "bert", "pegasus": "bert", "roberta": "bert", + "segformer": "vit", "t5": "bert", "vit": "vit", "whisper": "bart", diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py index cc36e94ef5d..a08ab8782a3 100644 --- a/optimum/pipelines/pipelines_base.py +++ b/optimum/pipelines/pipelines_base.py @@ -251,7 +251,7 @@ def load_ort_pipeline( pattern, glob_pattern="**/*.onnx", subfolder=subfolder, - use_auth_token=token, + token=token, revision=revision, ) export = len(onnx_files) == 0 diff --git a/optimum/subpackages.py b/optimum/subpackages.py new file mode 100644 index 00000000000..8729581521a --- /dev/null +++ b/optimum/subpackages.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +import sys + + +if sys.version_info >= (3, 8): + from importlib import metadata as importlib_metadata +else: + import importlib_metadata +from importlib.util import find_spec, module_from_spec + +from .utils import is_onnxruntime_available + + +logger = logging.getLogger(__name__) + + +def load_namespace_modules(namespace: str, module: str): + """Load modules with a specific name inside a namespace + + This method operates on namespace packages: + https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ + + For each package inside the specified `namespace`, it looks for the specified `module` and loads it. + + Args: + namespace (`str`): + The namespace containing modules to be loaded. + module (`str`): + The name of the module to load in each namespace package. 
+ """ + for dist in importlib_metadata.distributions(): + dist_name = dist.metadata["Name"] + if not dist_name.startswith(f"{namespace}-"): + continue + package_import_name = dist_name.replace("-", ".") + module_import_name = f"{package_import_name}.{module}" + if module_import_name in sys.modules: + # Module already loaded + continue + backend_spec = find_spec(module_import_name) + if backend_spec is None: + continue + try: + imported_module = module_from_spec(backend_spec) + sys.modules[module_import_name] = imported_module + backend_spec.loader.exec_module(imported_module) + logger.debug(f"Successfully loaded {module_import_name}") + except Exception as e: + logger.error(f"An exception occured while loading {module_import_name}: {e}.") + + +def load_subpackages(): + """Load optimum subpackages + + This method goes through packages inside the `optimum` namespace and loads the `subpackage` module if it exists. + + This module is then in charge of registering the subpackage commands. + """ + SUBPACKAGE_LOADER = "subpackage" + load_namespace_modules("optimum", SUBPACKAGE_LOADER) + + # Load subpackages from internal modules not explicitly defined as namespace packages + loader_name = "." + SUBPACKAGE_LOADER + if is_onnxruntime_available(): + importlib.import_module(loader_name, package="optimum.onnxruntime") diff --git a/optimum/utils/file_utils.py b/optimum/utils/file_utils.py index 3afa5cea81e..16190709f83 100644 --- a/optimum/utils/file_utils.py +++ b/optimum/utils/file_utils.py @@ -15,10 +15,17 @@ """Utility functions related to both local files and files on the Hugging Face Hub.""" import re +import warnings from pathlib import Path from typing import List, Optional, Union -from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_url +import huggingface_hub +from huggingface_hub import get_hf_file_metadata, hf_hub_url + +from ..utils import logging + + +logger = logging.get_logger(__name__) def validate_file_exists( @@ -44,6 +51,7 @@ def find_files_matching_pattern( glob_pattern: str = "**/*", subfolder: str = "", use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, ) -> List[Path]: """ @@ -59,7 +67,12 @@ def find_files_matching_pattern( subfolder (`str`, defaults to `""`): In case the model files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here. - use_auth_token (`Optional[bool, str]`, *optional*): + use_auth_token (`Optional[Union[bool,str]]`, defaults to `None`): + Deprecated. Please use the `token` argument instead. + token (`Optional[Union[bool,str]]`, defaults to `None`): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). + token (`Optional[Union[bool, str]]`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). revision (`Optional[str]`, defaults to `None`): @@ -68,6 +81,16 @@ def find_files_matching_pattern( Returns: `List[Path]` """ + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) if model_path.is_dir(): @@ -76,11 +99,7 @@ def find_files_matching_pattern( files = [p for p in files if re.search(pattern, str(p))] else: path = model_name_or_path - if isinstance(use_auth_token, bool): - token = HfFolder().get_token() - else: - token = use_auth_token - repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token)) + repo_files = map(Path, huggingface_hub.list_repo_files(model_name_or_path, revision=revision, token=token)) if subfolder != "": path = f"{path}/{subfolder}" files = [Path(p) for p in repo_files if re.match(pattern, str(p))] diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 795f3d57597..5ae5310ab13 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -102,6 +102,19 @@ class NormalizedVisionConfig(NormalizedConfig): INPUT_SIZE = "input_size" +class NormalizedSegformerConfig(NormalizedVisionConfig): + NUM_ATTENTION_HEADS = "num_attention_heads" + HIDDEN_SIZE = "hidden_sizes" + + # If the attribute is a list, return 0 + # 0 means let the optimizer infer the correct value based on the model graph + def __getattr__(self, attr_name): + attr_value = super().__getattr__(attr_name) + if isinstance(attr_value, list): + attr_value = 0 + return attr_value + + class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig): TEXT_CONFIG = None VISION_CONFIG = None @@ -203,7 +216,6 @@ class NormalizedConfigManager: 'owlvit', 'perceiver', 'roformer', - 'segformer', 'squeezebert', 'table-transformer', """ @@ -259,6 +271,7 @@ class NormalizedConfigManager: "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, + "segformer": NormalizedSegformerConfig, "speech-to-text": SpeechToTextLikeNormalizedTextConfig, "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index f1c2f668e3c..76fe9a05b13 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -16,6 +16,7 @@ import importlib.util import itertools import os +import shutil import subprocess import sys import unittest @@ -36,9 +37,6 @@ # Used to test the hub USER = "__DUMMY_OPTIMUM_USER__" -# Not critical, only usable on the sandboxed CI instance. -TOKEN = "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" - def flatten_dict(dictionary: Dict): """ @@ -90,8 +88,9 @@ def require_hf_token(test_case): """ Decorator marking a test that requires huggingface hub token. """ - use_auth_token = os.environ.get("HF_AUTH_TOKEN", None) - if use_auth_token is None: + # is HF_AUTH_TOKEN used instead of HF_TOKEN to avoid huggigface_hub picking it up ? + hf_token = os.environ.get("HF_AUTH_TOKEN", None) + if hf_token is None: return unittest.skip("test requires hf token as `HF_AUTH_TOKEN` environment variable")(test_case) else: return test_case @@ -101,9 +100,9 @@ def require_sigopt_token_and_project(test_case): """ Decorator marking a test that requires sigopt API token. 
""" - use_auth_token = os.environ.get("SIGOPT_API_TOKEN", None) + sigopt_api_token = os.environ.get("SIGOPT_API_TOKEN", None) has_sigopt_project = os.environ.get("SIGOPT_PROJECT", None) - if use_auth_token is None or has_sigopt_project is None: + if sigopt_api_token is None or has_sigopt_project is None: return unittest.skip("test requires an environment variable `SIGOPT_API_TOKEN` and `SIGOPT_PROJECT`")( test_case ) @@ -184,3 +183,16 @@ def grid_parameters( else: returned_list = [test_name] + list(params) if add_test_name is True else list(params) yield returned_list + + +def remove_directory(dirpath): + """ + Remove a directory and its content. + This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows. + Reference: https://github.com/python/cpython/issues/107408 + """ + if os.path.exists(dirpath) and os.path.isdir(dirpath): + if os.name == "nt": + os.system(f"rmdir /S /Q {dirpath}") + else: + shutil.rmtree(dirpath) diff --git a/setup.py b/setup.py index b40eba068d5..6b28fb696be 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "transformers[sentencepiece]>=4.26.0,<4.42.0", "torch>=1.11", "packaging", - "numpy", + "numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569 "huggingface_hub>=0.8.0", "datasets", ] @@ -79,10 +79,10 @@ "openvino": "optimum-intel[openvino]>=1.16.0", "nncf": "optimum-intel[nncf]>=1.16.0", "neural-compressor": "optimum-intel[neural-compressor]>=1.16.0", - "graphcore": "optimum-graphcore", "habana": ["optimum-habana", "transformers >= 4.38.0, < 4.39.0"], "neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], "neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], + "graphcore": "optimum-graphcore", "furiosa": "optimum-furiosa", "amd": "optimum-amd", "dev": TESTS_REQUIRE + QUALITY_REQUIRE, diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py index 0c070f8c9e4..5ed1619fde3 100644 --- a/tests/gptq/test_quantization.py +++ b/tests/gptq/test_quantization.py @@ -394,7 +394,7 @@ class GPTQDataTest(unittest.TestCase): def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) - @parameterized.expand(["wikitext2", "c4", "ptb", "c4-new", "ptb-new"]) + @parameterized.expand(["wikitext2", "c4", "c4-new"]) def test_dataset(self, dataset): train_dataset = get_dataset( dataset, self.tokenizer, nsamples=self.NBSAMPLES, seqlen=self.SEQLEN, split="train" diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 39cd4fb4cb4..c2c9e0c9e9f 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -14,7 +14,6 @@ # limitations under the License. 
import gc import os -import shutil import subprocess import tempfile import time @@ -109,7 +108,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, require_hf_token, require_ort_rocm +from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm logger = logging.get_logger() @@ -184,9 +183,8 @@ def test_load_model_from_cache(self): def test_load_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True) @@ -202,9 +200,8 @@ def test_load_seq2seq_model_from_cache(self): def test_load_seq2seq_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_SEQ2SEQ_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True) @@ -225,9 +222,8 @@ def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--") ) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, local_files_only=True @@ -938,11 +934,13 @@ def test_stable_diffusion_model_on_rocm_ep_str(self): self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) def test_load_model_from_hub_private(self): - subprocess.run("huggingface-cli logout", shell=True) - # Read token of fxmartyclone (dummy user). 
- token = "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" + token = os.environ.get("HF_HUB_READ_TOKEN", None) + + if token is None: + self.skipTest("Test requires a token for fxmartyclone in the environment variable `HF_HUB_READ_TOKEN`.") + + model = ORTModelForCustomTasks.from_pretrained("optimum-internal-testing/tiny-random-phi-private", token=token) - model = ORTModelForCustomTasks.from_pretrained("fxmartyclone/tiny-onnx-private-2", use_auth_token=token) self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) @@ -1005,6 +1003,7 @@ def test_save_load_ort_model_with_external_data(self): # verify loading from local folder works model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @pytest.mark.run_slow @@ -1012,11 +1011,7 @@ def test_save_load_ort_model_with_external_data(self): def test_save_load_decoder_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: model = ORTModelForCausalLM.from_pretrained( - "gpt2-large", - use_cache=use_cache, - export=True, - use_merged=False, - use_io_binding=False, + "gpt2-large", use_cache=use_cache, export=True, use_merged=False, use_io_binding=False ) model.save_pretrained(tmpdirname) @@ -1030,6 +1025,7 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): model = ORTModelForCausalLM.from_pretrained( tmpdirname, use_cache=use_cache, export=False, use_io_binding=False ) + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): @@ -1052,6 +1048,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): # verify loading from local folder works model = ORTModelForSeq2SeqLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) def test_save_load_stable_diffusion_model_with_external_data(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -1073,6 +1070,7 @@ def test_save_load_stable_diffusion_model_with_external_data(self): # verify loading from local folder works model = ORTStableDiffusionPipeline.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @unittest.skip("Skipping as this test consumes too much memory") @@ -1113,7 +1111,7 @@ def test_save_model_from_hub(self): model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) model.save_pretrained( tmpdirname, - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), push_to_hub=True, repository_id=self.HUB_REPOSITORY, private=True, @@ -1126,7 +1124,7 @@ def test_push_ort_model_with_external_data_to_hub(self): model = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["bert"], export=True) model.save_pretrained( tmpdirname + "/onnx", - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), repository_id=MODEL_NAMES["bert"].split("/")[-1] + "-onnx", private=True, push_to_hub=True, @@ -1136,7 +1134,7 @@ def test_push_ort_model_with_external_data_to_hub(self): model = ORTModelForSequenceClassification.from_pretrained( MODEL_NAMES["bert"] + "-onnx", export=False, - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + 
token=os.environ.get("HF_AUTH_TOKEN", None), ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @@ -1147,7 +1145,7 @@ def test_push_decoder_model_with_external_data_to_hub(self): model = ORTModelForCausalLM.from_pretrained(MODEL_NAMES["gpt2"], export=True) model.save_pretrained( tmpdirname + "/onnx", - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), repository_id=MODEL_NAMES["gpt2"].split("/")[-1] + "-onnx", private=True, push_to_hub=True, @@ -1157,7 +1155,7 @@ def test_push_decoder_model_with_external_data_to_hub(self): model = ORTModelForCausalLM.from_pretrained( MODEL_NAMES["gpt2"] + "-onnx", export=False, - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @@ -1168,7 +1166,7 @@ def test_push_seq2seq_model_with_external_data_to_hub(self): model = ORTModelForSeq2SeqLM.from_pretrained(MODEL_NAMES["mbart"], export=True) model.save_pretrained( tmpdirname + "/onnx", - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), repository_id=MODEL_NAMES["mbart"].split("/")[-1] + "-onnx", private=True, push_to_hub=True, @@ -1178,7 +1176,7 @@ def test_push_seq2seq_model_with_external_data_to_hub(self): model = ORTModelForSeq2SeqLM.from_pretrained( MODEL_NAMES["mbart"] + "-onnx", export=False, - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @@ -1189,7 +1187,7 @@ def test_push_stable_diffusion_model_with_external_data_to_hub(self): model = ORTStableDiffusionPipeline.from_pretrained(MODEL_NAMES["stable-diffusion"], export=True) model.save_pretrained( tmpdirname + "/onnx", - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), repository_id=MODEL_NAMES["stable-diffusion"].split("/")[-1] + "-onnx", private=True, push_to_hub=True, @@ -1199,7 +1197,7 @@ def test_push_stable_diffusion_model_with_external_data_to_hub(self): model = ORTStableDiffusionPipeline.from_pretrained( MODEL_NAMES["stable-diffusion"] + "-onnx", export=False, - use_auth_token=os.environ.get("HF_AUTH_TOKEN", None), + token=os.environ.get("HF_AUTH_TOKEN", None), ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @@ -2276,6 +2274,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): @parameterized.expand([(False,), (True,)]) @pytest.mark.run_in_series + # TODO: still gotta find out why this needs to be ran in series / why it fails in parallel + # my guess is that the model surgery is happening in parallel and that's causing the issue def test_inference_old_onnx_model(self, use_cache): tokenizer = get_preprocessor("gpt2") model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -2288,9 +2288,9 @@ def test_inference_old_onnx_model(self, use_cache): tokens = tokenizer(text, return_tensors="pt") onnx_outputs = onnx_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10 + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 ) - outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) onnx_text_outputs = tokenizer.decode(onnx_outputs[0], skip_special_tokens=True) text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True) self.assertEqual(onnx_text_outputs, text_outputs) @@ -3603,13 
@@ -3603,13 +3603,20 @@ def _get_onnx_model_dir(self, model_id, model_arch, test_name):
 
     @pytest.mark.run_in_series
     def test_inference_old_onnx_model(self):
-        model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
+        tokenizer = get_preprocessor("t5-small")
+        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
+        onnx_model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small")
 
-        tokenizer = get_preprocessor("optimum/t5-small")
         text = "This is a sample output"
         tokens = tokenizer(text, return_tensors="pt")
-        model.generate(**tokens)
+        outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30)
+        onnx_outputs = onnx_model.generate(
+            **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30
+        )
+        onnx_text_outputs = tokenizer.decode(onnx_outputs[0], skip_special_tokens=True)
+        text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        self.assertEqual(onnx_text_outputs, text_outputs)
 
     def test_load_vanilla_transformers_which_is_not_supported(self):
         with self.assertRaises(Exception) as context:
@@ -4758,6 +4765,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
             self.assertTrue("logits" in onnx_outputs)
             self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+            self.assertTrue(
+                torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
+            )
 
             if use_cache:
                 self.assertEqual(
@@ -4766,19 +4776,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
                 self.assertEqual(
                     len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
                 )
-                for i, _ in enumerate(onnx_outputs["past_key_values"]):
-                    for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]):
-                        trfs_pkv = transformers_outputs["past_key_values"][i][j]
+                for i in range(len(onnx_outputs["past_key_values"])):
+                    print(onnx_outputs["past_key_values"][i])
+                    for ort_pkv, trfs_pkv in zip(
+                        onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
+                    ):
+                        ort_pkv = torch.Tensor(ort_pkv)
                         self.assertTrue(
                             torch.allclose(ort_pkv, trfs_pkv, atol=1e-3),
                             f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}",
                         )
 
-            # Compare tensor outputs
-            self.assertTrue(
-                torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
-            )
-
         gc.collect()
 
     @parameterized.expand(grid_parameters(FULL_GRID))
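Editor's note: the reordered comparison above checks the logits first and then walks the KV cache pair-wise. Restated as a standalone helper (the function name is hypothetical; the body mirrors the assertions in the hunk rather than any function that exists in the repository):

```python
import torch


def assert_ort_matches_transformers(onnx_outputs, transformers_outputs, atol=1e-3):
    # Logits first; ONNX Runtime may hand back numpy arrays, hence the torch.Tensor() conversion.
    assert torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=atol)

    # Then the KV cache, layer by layer and tensor by tensor.
    for onnx_layer, trfs_layer in zip(onnx_outputs.past_key_values, transformers_outputs.past_key_values):
        for ort_pkv, trfs_pkv in zip(onnx_layer, trfs_layer):
            ort_pkv = torch.Tensor(ort_pkv)
            max_diff = torch.abs(ort_pkv - trfs_pkv).max()
            assert torch.allclose(ort_pkv, trfs_pkv, atol=atol), f"Maxdiff: {max_diff}"
```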
diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py
index c9cadbaa825..82109fcd11f 100644
--- a/tests/onnxruntime/test_optimization.py
+++ b/tests/onnxruntime/test_optimization.py
@@ -36,6 +36,7 @@
     AutoOptimizationConfig,
     ORTConfig,
     ORTModelForImageClassification,
+    ORTModelForSemanticSegmentation,
     ORTModelForSequenceClassification,
     ORTOptimizer,
 )
@@ -171,6 +172,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo
     # Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
     SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = (
+        (ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"),
         (ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"),
     )
diff --git a/tests/test_modeling_base.py b/tests/test_modeling_base.py
index 4bee079fbb8..34e66927632 100644
--- a/tests/test_modeling_base.py
+++ b/tests/test_modeling_base.py
@@ -48,7 +48,7 @@ def test_push_to_hub(self):
 
             model.save_pretrained(
                 tmpdirname,
-                use_auth_token=os.environ.get("HF_AUTH_TOKEN", None),
+                token=os.environ.get("HF_AUTH_TOKEN", None),
                 push_to_hub=True,
                 repository_id="unit_test_save_model",
             )
diff --git a/tests/utils/test_task_processors.py b/tests/utils/test_task_processors.py
index af89aec2b90..16567048073 100644
--- a/tests/utils/test_task_processors.py
+++ b/tests/utils/test_task_processors.py
@@ -50,7 +50,7 @@
         "dataset_data_keys": {"question": "question", "context": "answer"},
     },
     "image-classification": {
-        "dataset_args": "mnist",
+        "dataset_args": "sasha/dog-food",
         "dataset_data_keys": {"image": "image"},
     },
 }
@@ -232,6 +232,11 @@ def test_load_dataset_with_max_length(self):
         input_ids = dataset[0]["input_ids"]
         self.assertEqual(len(input_ids), max_length)
 
+    def test_load_default_dataset(self):
+        self.skipTest(
+            "Skipping so as not to execute conll2003 remote code (test would require trust_remote_code=True)"
+        )
+
 
 class QuestionAnsweringProcessorTest(TestCase, TaskProcessorTestBase):
     TASK_NAME = "question-answering"
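Editor's note on the dataset changes above: `mnist` is replaced by `sasha/dog-food` as the default image-classification dataset, and the default-dataset test is skipped because `conll2003` relies on a dataset loading script. A quick way to inspect the replacement dataset is sketched below; only the `image` column is guaranteed by the hunk, so the split name and any other columns printed here are assumptions.

```python
from datasets import load_dataset

# "sasha/dog-food" exposes a plain `image` column and needs no remote code.
dataset = load_dataset("sasha/dog-food")
print(dataset)
print(dataset["train"].features)  # assumes a "train" split exists

# "conll2003", by contrast, runs a dataset script, so loading it would require an explicit opt-in:
# conll = load_dataset("conll2003", trust_remote_code=True)
```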