From 308d28281577dfc96578bc8c3e7fd4581e09ce67 Mon Sep 17 00:00:00 2001
From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Date: Fri, 3 Nov 2023 18:42:24 +0100
Subject: [PATCH] Deprecate ORTDecoder (#1511)

* Deprecate ORTDecoder

* add deprecation warning

* fix style

* fix style

* format
---
 optimum/onnxruntime/base.py        | 402 ++++-------------------------
 tests/onnxruntime/test_modeling.py |   6 +-
 2 files changed, 50 insertions(+), 358 deletions(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index e2746561bff..f79faf6661f 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -14,17 +14,17 @@
 """Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""

 from abc import abstractmethod
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union

 import numpy as np
 import torch
-from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

 from onnxruntime import InferenceSession

 from ..utils import NormalizedConfigManager
 from ..utils.logging import warn_once
-from .utils import MULTI_QUERY_ATTN_MODELS, get_ordered_input_names, logging
+from .utils import get_ordered_input_names, logging


 logger = logging.get_logger(__name__)
@@ -121,7 +121,7 @@ def forward(
         return BaseModelOutput(last_hidden_state=last_hidden_state)


-class ORTDecoder(ORTModelPart):
+class ORTDecoderForSeq2Seq(ORTModelPart):
     """
     Decoder model with a language modeling head on top for ONNX Runtime inference.
     """
@@ -132,6 +132,7 @@ def __init__(
         parent_model: "ORTModel",
     ):
         super().__init__(session, parent_model)
+        # TODO: make this less hacky.
         self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
         self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]

@@ -145,363 +146,13 @@ def __init__(
         if self.parent_model.use_cache is True and len(self.key_value_output_names) == 0:
             raise RuntimeError("Could not find the past key values in the provided model.")

-        if len(self.key_value_input_names) > 0:
-            self.use_past = True
-        else:
-            self.use_past = False
-
+        self.use_past = len(self.key_value_input_names) > 0
         self.use_fp16 = False
         for inp in session.get_inputs():
             if "past_key_values" in inp.name and inp.type == "tensor(float16)":
                 self.use_fp16 = True
                 break

-        if len(self.key_value_output_names) != 0:
-            # Attributes useful when computing the past key/values output shapes.
-            self.expected_key_symbolic_shape = None
-            self.expected_value_symbolic_shape = None
-            for output in self.session.get_outputs():
-                # To handle the case of multi-query attn where key and value are concatenated
-                if ".key_value" in output.name:
-                    expected_key_value_symbolic_shape = output.shape
-                    self.expected_key_symbolic_shape = (
-                        self.expected_value_symbolic_shape
-                    ) = expected_key_value_symbolic_shape[:-1] + [
-                        expected_key_value_symbolic_shape[-1] // 2,
-                    ]
-                elif ".key" in output.name:
-                    self.expected_key_symbolic_shape = output.shape
-                elif ".value" in output.name:
-                    self.expected_value_symbolic_shape = output.shape
-                # To handle the old case when past_key_values were following the format: past_key_values_{idx}
-                elif "key_values" in output.name:
-                    if self.expected_key_symbolic_shape is None:
-                        self.expected_key_symbolic_shape = output.shape
-                    else:
-                        self.expected_value_symbolic_shape = output.shape
-                if self.expected_key_symbolic_shape is not None and self.expected_value_symbolic_shape is not None:
-                    break
-
-            self.key_sequence_length_idx = -2
-            if (
-                isinstance(self.expected_key_symbolic_shape[-1], str)
-                and "sequence_length" in self.expected_key_symbolic_shape[-1]
-            ):
-                self.key_sequence_length_idx = -1
-
-            self.value_sequence_length_idx = -2
-            if (
-                isinstance(self.expected_value_symbolic_shape[-1], str)
-                and "sequence_length" in self.expected_value_symbolic_shape[-1]
-            ):
-                self.value_sequence_length_idx = -1
-
-    def prepare_inputs_for_merged(
-        self,
-        input_ids: Union[None, torch.LongTensor, np.ndarray],
-        past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
-        use_torch: bool,
-    ):
-        if self.parent_model.use_merged:
-            constructor = torch if use_torch is True else np
-            # Uses without/with branch of a merged decoder depending on whether real past key values are passed
-            use_cache_branch = constructor.full((1,), past_key_values is not None)
-        else:
-            # Uses separate decoders
-            use_cache_branch = None
-
-        if use_torch and use_cache_branch is not None:
-            use_cache_branch = use_cache_branch.to(self.device)
-
-        # Generate dummy past for the first forward if uses a merged decoder
-        if self.parent_model.use_merged and past_key_values is None:
-            batch_size = input_ids.shape[0]
-
-            if self.normalized_config.config.model_type in {"mistral", "llama"}:
-                num_attention_heads = self.normalized_config.num_key_value_heads
-            else:
-                num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
-
-            dtype = constructor.float16 if self.use_fp16 else constructor.float32
-            # TODO: find a way to better handle this controlflow, this is EXTREMELY ugly
-            # "1" is the dummy sequence length
-            if self.parent_model.config.model_type == "bloom":
-                shape_value = (batch_size * num_attention_heads, 1, embed_size_per_head)
-                shape_key = (batch_size * num_attention_heads, embed_size_per_head, 1)
-                key = constructor.zeros(shape_key, dtype=dtype)
-                value = constructor.zeros(shape_value, dtype=dtype)
-
-                if use_torch is True:
-                    key = key.to(self.device)
-                    value = value.to(self.device)
-
-                past_key_values = tuple(
-                    key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
-                )
-            elif self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
-                shape_key_and_value = (batch_size, 1, embed_size_per_head * 2)
-                key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
-
-                if use_torch is True:
-                    key_and_value = key_and_value.to(self.device)
-
-                past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
-            else:
-                shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
-                key_or_value = constructor.zeros(shape, dtype=dtype)
-
-                if use_torch is True:
-                    key_or_value = key_or_value.to(self.device)
-
-                past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
-
-        return use_cache_branch, past_key_values
-
-    def compute_past_key_values_output_shapes(
-        self,
-        input_ids: torch.Tensor,
-        use_cache_branch: Optional[bool],
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-    ) -> Dict[str, List[int]]:
-        """
-        Computes the outputs of the past key / value because it is not always easy to perform shape inference on them,
-        which is needed for creating IO binding output buffers.
-
-        Args:
-            input_ids (`torch.Tensor`):
-                The input ids that are associated with the current inputs.
-            use_cache_branch (`Optional[bool]`):
-                In the case of a merged decoder, whether the with-past branch is used. In case the decoders without and with past are
-                separate, this parameter should be None.
-            past_key_values (`Optional[Tuple[Tuple[torch.Tensor]]]`, defaults to `None`):
-                The past key values associated with the current inputs.
-
-        Returns:
-            `Dict[str, List[int]]`: The dictionary mapping each past key value output name to its corresponding shape.
-        """
-        batch_size = input_ids.size(0)
-        if self.normalized_config.config.model_type in {"mistral", "llama"}:
-            num_attention_heads = self.normalized_config.num_key_value_heads
-        else:
-            num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
-
-        sequence_length = input_ids.size(1)
-        if past_key_values is not None and use_cache_branch is not False:
-            # Here, use_cache_branch may be None in the case of separate decoder without/with past, or True if the with past branch
-            # of a merged decoder is used
-            sequence_length += past_key_values[0].size(2)
-
-        half_shape = [batch_size, num_attention_heads]
-        if len(self.expected_key_symbolic_shape) == 3:
-            half_shape[0] = batch_size * num_attention_heads
-            half_shape.pop(1)
-
-        key_shape = [sequence_length, embed_size_per_head]
-        if self.key_sequence_length_idx == -1:
-            key_shape[0], key_shape[1] = key_shape[1], key_shape[0]
-
-        value_shape = [sequence_length, embed_size_per_head]
-        if self.value_sequence_length_idx == -1:
-            value_shape[0], value_shape[1] = value_shape[1], value_shape[0]
-
-        key_shape = half_shape + key_shape
-        value_shape = half_shape + value_shape
-
-        return {name: key_shape if "key" in name else value_shape for name in self.key_value_output_names}
-
-    def compute_past_key_values_output_shapes_mqa(
-        self,
-        input_ids: torch.Tensor,
-        use_cache_branch: Optional[bool],
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-    ) -> Dict[str, List[int]]:
-        batch_size = input_ids.size(0)
-        num_attention_heads = self.normalized_config.num_attention_heads
-        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
-
-        sequence_length = input_ids.size(1)
-        if past_key_values is not None and use_cache_branch is not False:
-            sequence_length += past_key_values[0].size(-2)
-
-        key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * 2)
-
-        return {name: key_and_value_shape for name in self.key_value_output_names}
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache_branch: None = None,
-    ) -> CausalLMOutputWithCrossAttentions:
-        # adding use_cache_branch in the signature here is just a hack for IO Binding
-        use_torch = isinstance(input_ids, torch.Tensor)
-        self.parent_model.raise_on_numpy_input_io_binding(use_torch)
-
-        # Flatten the past_key_values (no need to flatten for models using multi-query attn)
-        if past_key_values is not None and (self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS):
-            past_key_values = tuple(
-                past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
-            )
-
-        # no-ops if merged decoder is not used
-        use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
-            input_ids, past_key_values, use_torch=use_torch
-        )
-
-        if self.parent_model.use_io_binding:
-            if self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
-                compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes_mqa
-            else:
-                compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes
-            known_output_shapes = compute_past_key_values_output_shapes_func(
-                input_ids,
-                use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
-                past_key_values=past_key_values,
-            )
-
-            # TODO: fix transformers generate to have contiguous input_ids, position_ids here already
-            # Calling `contiguous()` here is necessary to not have errors
-            # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
-            # I suspect the garbage collector to somehow negate `tensor = tensor.contiguous()`
-            # in modeling_ort.py, which is then never assigned anywhere.
-            model_inputs = [input_ids.contiguous()]
-
-            if "attention_mask" in self.input_names:
-                model_inputs.append(attention_mask)
-
-            if "position_ids" in self.input_names:
-                if position_ids is None:
-                    raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
-                model_inputs.append(position_ids.contiguous())
-
-            if past_key_values is not None:
-                model_inputs += past_key_values
-
-            if use_cache_branch_tensor is not None:
-                model_inputs.append(use_cache_branch_tensor)
-
-            if "labels" in self.input_names:
-                model_inputs.append(labels)
-                known_output_shapes.update({"loss": []})
-
-            io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
-                self.session,
-                *model_inputs,
-                known_output_shapes=known_output_shapes,
-                ordered_input_names=self._ordered_input_names,
-            )
-
-            if self.device.type == "cpu":
-                self.session.run_with_iobinding(io_binding)
-            else:
-                io_binding.synchronize_inputs()
-                self.session.run_with_iobinding(io_binding)
-                io_binding.synchronize_outputs()
-
-            # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2)
-            past_key_values = ()
-            for name in self.key_value_output_names:
-                past_key_values += (output_buffers[name].view(output_shapes[name]),)
-
-            # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
-            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                num_pkv = 2
-                past_key_values = tuple(
-                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
-                )
-
-            logits = output_buffers["logits"].view(output_shapes["logits"])
-
-            loss = None
-            if "loss" in self.output_names:
-                loss = output_buffers["loss"].view(output_shapes["loss"])
-        else:
-            if use_torch:
-                onnx_inputs = {
-                    "input_ids": input_ids.cpu().detach().numpy(),
-                    "attention_mask": attention_mask.cpu().detach().numpy(),
-                }
-
-                if self.parent_model.use_merged is True:
-                    onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
-
-                if past_key_values is not None:
-                    # Add the past_key_values to the decoder inputs
-                    for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
-                        onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()
-
-                if "position_ids" in self.input_names:
-                    if position_ids is None:
-                        raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
-                    onnx_inputs["position_ids"] = position_ids.cpu().detach().numpy()
-
-                if "labels" in self.input_names:
-                    onnx_inputs["labels"] = labels.cpu().detach().numpy()
-            else:
-                onnx_inputs = {
-                    "input_ids": input_ids,
-                    "attention_mask": attention_mask,
-                }
-
-                if self.parent_model.use_merged is True:
-                    onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
-
-                if past_key_values is not None:
-                    # Add the past_key_values to the decoder inputs
-                    for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
-                        onnx_inputs[input_name] = past_key_value
-
-                if "position_ids" in self.input_names:
-                    if position_ids is None:
-                        raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
-                    onnx_inputs["position_ids"] = position_ids
-
-                if "labels" in self.input_names:
-                    onnx_inputs["labels"] = labels
-
-            # Run inference
-            outputs = self.session.run(None, onnx_inputs)
-
-            # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
-            past_key_values = tuple(
-                torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
-                for key in self.key_value_output_names
-            )
-
-            # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
-            # per decoder layer
-            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                num_pkv = 2
-                past_key_values = tuple(
-                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
-                )
-
-            logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
-
-            loss = None
-            if "loss" in self.output_names:
-                loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)
-
-        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits, past_key_values=past_key_values)
-
-
-class ORTDecoderForSeq2Seq(ORTDecoder):
-    """
-    Decoder model with a language modeling head on top for ONNX Runtime inference.
-    """
-
-    def __init__(
-        self,
-        session: InferenceSession,
-        parent_model: "ORTModel",
-    ):
-        super().__init__(session, parent_model)
-
         # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
         # can be used but do not support KV caching for the cross-attention key/values, see:
         # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
@@ -808,3 +459,44 @@ def forward(
             raise ValueError("Unsupported num_pkv")

         return Seq2SeqLMOutput(loss=loss, logits=logits, past_key_values=out_past_key_values)
+
+    def prepare_inputs_for_merged(
+        self,
+        input_ids: Union[None, torch.LongTensor, np.ndarray],
+        past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
+        use_torch: bool,
+    ):
+        if self.parent_model.use_merged:
+            constructor = torch if use_torch is True else np
+            # Uses without/with branch of a merged decoder depending on whether real past key values are passed
+            use_cache_branch = constructor.full((1,), past_key_values is not None)
+        else:
+            # Uses separate decoders
+            use_cache_branch = None
+
+        if use_torch and use_cache_branch is not None:
+            use_cache_branch = use_cache_branch.to(self.device)
+
+        # Generate dummy past for the first forward if uses a merged decoder
+        if self.parent_model.use_merged and past_key_values is None:
+            batch_size = input_ids.shape[0]
+            num_attention_heads = self.normalized_config.num_attention_heads
+            embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+            dtype = constructor.float16 if self.use_fp16 else constructor.float32
+            shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
+            key_or_value = constructor.zeros(shape, dtype=dtype)
+
+            if use_torch is True:
+                key_or_value = key_or_value.to(self.device)
+
+            past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
+
+        return use_cache_branch, past_key_values
+
+
+class ORTDecoder(ORTDecoderForSeq2Seq):
+    def __init__(self, *args, **kwargs):
+        logger.warning(
+            "The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead."
+        )
+        super().__init__(*args, **kwargs)
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index be12d36cccd..baa62f1670c 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -91,7 +91,7 @@
     ORTModelForVision2Seq,
     ORTStableDiffusionPipeline,
 )
-from optimum.onnxruntime.base import ORTDecoder, ORTDecoderForSeq2Seq, ORTEncoder
+from optimum.onnxruntime.base import ORTDecoderForSeq2Seq, ORTEncoder
 from optimum.onnxruntime.modeling_diffusion import (
     ORTModelTextEncoder,
     ORTModelUnet,
@@ -4524,9 +4524,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
             self.assertTrue(has_onnx_input(model_path, "use_cache_branch"))
             self.assertEqual(onnx_model.use_merged, True)

-        self.assertIsInstance(onnx_model.decoder, ORTDecoder)
+        self.assertIsInstance(onnx_model.decoder, ORTDecoderForSeq2Seq)
         if onnx_model.use_cache is True and onnx_model.use_merged is False:
-            self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoder)
+            self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoderForSeq2Seq)
         if onnx_model.use_cache is True and onnx_model.use_merged is True:
             self.assertTrue(onnx_model.decoder_with_past is None)
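
The deprecation shim added at the end of base.py is a plain rename-with-alias pattern: the real implementation now lives in `ORTDecoderForSeq2Seq`, and `ORTDecoder` is kept only as a subclass that logs a warning before delegating to the parent constructor. The standalone Python sketch below uses hypothetical `NewDecoder`/`OldDecoder` names rather than optimum's actual classes; it is only meant to illustrate why existing call sites and `isinstance` checks keep working during the deprecation window.

import logging

logger = logging.getLogger(__name__)


class NewDecoder:
    """Stands in for the class that keeps the real implementation (here, ORTDecoderForSeq2Seq)."""

    def __init__(self, session, parent_model):
        self.session = session
        self.parent_model = parent_model


class OldDecoder(NewDecoder):
    """Stands in for the deprecated alias (here, ORTDecoder): warn on construction, then delegate."""

    def __init__(self, *args, **kwargs):
        logger.warning("`OldDecoder` is deprecated, please use `NewDecoder` instead.")
        super().__init__(*args, **kwargs)


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    decoder = OldDecoder(session="dummy_session", parent_model="dummy_model")
    # Because the alias subclasses the new class, isinstance checks against the new
    # name (as updated in test_modeling.py above) pass for both classes.
    assert isinstance(decoder, NewDecoder)

Note that a shim like this emits the warning on every instantiation, so callers that construct many decoders may prefer switching to the new name right away; per the warning message in the patch, the `ORTDecoder` alias is scheduled for removal in optimum v1.15.0.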