diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 3f47f9c7570..474628d546b 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -41,6 +41,7 @@ Supported architectures: - Donut-Swin - Electra - Encoder Decoder +- Falcon - Flaubert - GPT-2 - GPT-BigCode diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 3aca641513c..dd11032eccb 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -163,7 +163,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # Decoders based on GPT2 require a position_ids input to avoid # generating wrong position_ids in the model itself: # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 - if not self.no_position_ids and self.task == "text-generation": + if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]: common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 1d4c49f2912..4c96bbdbe9a 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -35,6 +35,7 @@ DummyVisionInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator, MistralDummyPastKeyValuesGenerator, + MultiQueryPastKeyValuesGenerator, NormalizedConfig, NormalizedEncoderDecoderConfig, NormalizedSeq2SeqConfig, @@ -59,6 +60,7 @@ from .model_patcher import ( BartModelPatcher, BloomModelPatcher, + FalconModelPatcher, LlamaModelPatcher, MistralModelPatcher, OPTModelPatcher, @@ -279,9 +281,6 @@ class BloomOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): - """ - Refer to OnnxConfigWithPast in base.py - """ if direction not in ["inputs", "outputs"]: raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') @@ -337,6 +336,87 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.key_value"] = t +class FalconOnnxConfig(TextDecoderOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + MultiQueryPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + DUMMY_PKV_GENERATOR_CLASS = MultiQueryPastKeyValuesGenerator + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + preprocessors: Optional[List[Any]] = None, + no_position_ids: bool = False, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + preprocessors=preprocessors, + no_position_ids=no_position_ids, + ) + # For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers: + # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337 + self._normalized_config.num_kv_heads = ( + self._normalized_config.num_kv_heads + if (self._normalized_config.new_decoder_architecture or not 
self._normalized_config.multi_query) + else 1 + ) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = super().inputs + + if ( + not self.no_position_ids + and not self._config.alibi + and self.task in ["text-generation", "feature-extraction"] + ): + # When alibi is used, position_ids are not used in Falcon. + # Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116 + common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} + + return common_inputs + + # we need to set output_attentions=True in the model input to avoid calling + # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return FalconModelPatcher(self, model, model_kwargs=model_kwargs) + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = { + 0: "batch_size x num_heads", + 1: decoder_sequence_name, + } + inputs_or_outputs[f"{name}.{i}.value"] = { + 0: "batch_size x num_heads", + 1: decoder_sequence_name, + } + + class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): encoder_shape = ( diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index aa14526bd8c..b6f5a4dcd82 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -15,11 +15,16 @@ import dataclasses import functools import inspect -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +import types +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union +import transformers +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor from transformers.utils import is_torch_available from ...utils.modeling_utils import ( + _falcon_prepare_attn_mask, _prepare_attn_mask, _prepare_decoder_attention_mask, _prepare_decoder_sliding_window_attention_mask, @@ -229,6 +234,237 @@ def patched_forward(*args, **kwargs): self.patched_forward = patched_forward +def _make_causal_mask_falcon_patched( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it + just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1, + target_length, target_length+past_key_values_length]`. + """ + batch_size, target_length = input_ids_shape + + # NOTE: ONNX Runtime is not able to run ONNX Trilu node with bool input. As a workaround, we pass a float input + # and cast to bool here. 
Reference: https://github.com/microsoft/onnxruntime/issues/16189 + mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float, device=device), diagonal=1).to( + torch.bool + ) + + # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op. + # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this + # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later. + past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device) + mask = torch.cat([past_mask, mask], dim=-1) + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + return expanded_mask + + +def falcon_model_forward_without_kv_reformatting( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # NOTE: here we removed the _convert_to_rw_cache call + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = inputs_embeds + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Compute alibi tensor: check build_alibi_tensor documentation + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + if self.use_alibi: + alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + else: + alibi = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + 
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + # NOTE: here we use expand(batch_size, -1) instead of transformers view(-1, seq_length) that is bugged + position_ids = position_ids.unsqueeze(0).expand(batch_size, -1) + else: + position_ids = position_ids.view(-1, seq_length).long() + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # NOTE: here we removed the _convert_cache_to_standard_format call + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class FalconModelPatcher(ModelPatcher): + def __enter__(self): + self.patch_ops() + + transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched + + if self.real_config.task == "text-generation": + self._model.transformer.forward = types.MethodType( + falcon_model_forward_without_kv_reformatting, self._model.transformer + ) + + # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. + if isinstance(self._model, FalconModel): + self._model._prepare_attn_mask = _falcon_prepare_attn_mask + else: + self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask + + setattr(self._model, self.orig_forward_name, self.patched_forward) + + def __exit__(self, exc_type, exc_value, traceback): + self.restore_ops() + + setattr(self._model, self.orig_forward_name, self.orig_forward) + + if self.real_config.task == "text-generation": + self._model.transformer.forward = types.MethodType( + self.original_model_transformer_forward, self._model.transformer + ) + + transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal + + # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. 
+ if isinstance(self._model, FalconModel): + self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask + else: + self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask + + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if config.task == "text-generation": + self.original_model_transformer_forward = model.transformer.forward + + self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask + + if isinstance(model, FalconModel): + self.original_falcon_prepare_attn_mask = model._prepare_attn_mask + else: + self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask + + self._model = model + + self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call" + self.orig_forward = getattr(self._model, self.orig_forward_name) + + allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past + + @functools.wraps(self.orig_forward) + def patched_forward(*args, **kwargs): + model_kwargs = self.model_kwargs + # setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention + # in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/falcon/modeling_falcon.py#L425 + model_kwargs["output_attentions"] = True + signature = inspect.signature(self.orig_forward) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs) + + outputs = self.orig_forward(*args, **kwargs) + + filterd_outputs = {} + for name, value in outputs.items(): + onnx_output_name = config.torch_to_onnx_output_map.get(name, name) + if ( + onnx_output_name in config.outputs + or (allow_past_in_outputs and name.startswith("past_key_values")) + or any(key.startswith(onnx_output_name) for key in config.outputs.keys()) + ): + filterd_outputs[name] = value + return filterd_outputs + + self.patched_forward = patched_forward + + class WavLMModelPatcher(ModelPatcher): def __init__( self, diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 2dda5594a66..6fce8b4f2d8 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -68,6 +68,7 @@ MODEL_TYPES_REQUIRING_POSITION_IDS = { "codegen", + "falcon", "gpt2", "gpt-bigcode", "gpt-neo", diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2841383eb96..89107fa053a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -511,6 +511,15 @@ class TasksManager: "text2text-generation-with-past", onnx="EncoderDecoderOnnxConfig", ), + "falcon": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "question-answering", + "text-generation", + "text-generation-with-past", + "token-classification", + onnx="FalconOnnxConfig", + ), "flaubert": supported_tasks_mapping( "feature-extraction", "fill-mask", diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2707c6eeab2..13aef3546a5 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -114,7 +114,7 @@ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForCausalLM(ORTModel, GenerationMixin): """ - ONNX model with a causal language modeling head for ONNX Runtime inference. 
This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.
+    ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, falcon, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.
     """

     auto_model_class = AutoModelForCausalLM

@@ -265,7 +265,6 @@ def forward(

             if "loss" in self.output_names:
                 loss = output_buffers["loss"].view(output_shapes["loss"])
-
         else:
             inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids

@@ -337,8 +336,9 @@ def prepare_past_key_values(
         else:
             num_attention_heads = self.normalized_config.num_attention_heads
             embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+
         dtype = constructor.float16 if self.use_fp16 else constructor.float32
-        # TODO: find a way to better handle this controlflow
+        # TODO: find a way to better handle this control flow; this is EXTREMELY UGLY.
         # "1" is the dummy sequence length
         if self.config.model_type == "bloom":
             shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
@@ -353,7 +353,8 @@
             past_key_values = tuple(
                 key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
             )
-        elif self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+        elif self.config.model_type == "gpt_bigcode":
+            # GPT BigCode uses multi-query attention and stores both key and value in the same cache tensor.
             shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
             key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)

@@ -361,6 +362,14 @@
                 key_and_value = key_and_value.to(self.device)

             past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
+        elif self.config.model_type == "falcon":
+            shape = (batch_size * self.num_key_value_heads, 0, embed_size_per_head)
+            key_or_value = constructor.zeros(shape, dtype=dtype)
+
+            if use_torch:
+                key_or_value = key_or_value.to(self.device)
+
+            past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
         else:
             shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
             key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -477,15 +486,6 @@
             f"{cls.__name__} might not behave as expected."
) - if config.model_type == "bloom": - init_cls = ORTBloomForCausalLM - elif config.model_type == "mpt": - init_cls = ORTMPTForCausalLM - elif config.model_type == "opt": - init_cls = ORTOPTForCausalLM - else: - init_cls = ORTModelForCausalLM - model_cache_path, preprocessors = cls._cached_file( model_path=model_path, use_auth_token=use_auth_token, @@ -541,6 +541,17 @@ def _from_pretrained( provider_options=provider_options, ) + if config.model_type == "bloom": + init_cls = ORTBloomForCausalLM + elif config.model_type == "falcon": + init_cls = ORTFalconForCausalLM + elif config.model_type == "mpt": + init_cls = ORTMPTForCausalLM + elif config.model_type == "opt": + init_cls = ORTOPTForCausalLM + else: + init_cls = ORTModelForCausalLM + return init_cls( model=model, config=config, @@ -720,3 +731,87 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "position_ids": None, "attention_mask": attention_mask, } + + +class ORTFalconForCausalLM(ORTModelForCausalLM): + def __init__( + self, + model: onnxruntime.InferenceSession, + config: "PretrainedConfig", + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, + use_cache: Optional[bool] = None, + **kwargs, + ): + super().__init__( + model=model, + config=config, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + generation_config=generation_config, + use_cache=use_cache, + **kwargs, + ) + self.num_key_value_heads = ( + config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1 + ) + + # Copied from https://github.com/huggingface/transformers/pull/26199 + def _reorder_cache( + self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + Output shares the same memory storage as `past`. + """ + standardized_past = self._convert_cache_to_standard_format(past, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return self._convert_to_rw_cache(reordered_past) + + # Copied from https://github.com/huggingface/transformers/pull/26199 + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + # the cache may be in the stardard format (e.g. in contrastive search), convert to falcon's format if needed + if len(past_key_values[0][0].shape) == 4: + past_key_values = self._convert_to_rw_cache(past_key_values) + + # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. 
+ if not self.config.alibi and attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index e8d1588dd7a..cdd4460fb0f 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -54,7 +54,8 @@ "tensor(double)": np.float64, } -MULTI_QUERY_ATTN_MODELS = {"gpt_bigcode"} +# TODO: this is likely bugged as Falcon handles both the MQA and non-MQA implem +MULTI_QUERY_ATTN_MODELS = {"falcon", "gpt_bigcode"} def _is_gpu_available(): diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 03a6c0bdec3..1555f846f32 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -62,6 +62,7 @@ FalconDummyPastKeyValuesGenerator, GPTBigCodeDummyPastKeyValuesGenerator, MistralDummyPastKeyValuesGenerator, + MultiQueryPastKeyValuesGenerator, ) from .modeling_utils import recurse_getattr, recurse_setattr from .normalized_config import ( diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index c444c913cac..765f489a341 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -832,7 +832,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int past_key_value_shape = ( self.batch_size, self.sequence_length, - self.hidden_size // self.num_attention_heads * 2, + self.hidden_size // self.num_attention_heads * 2, # GPT BigCode has a fused KV cache. 
) return [ self.random_float_tensor(past_key_value_shape, framework=framework, dtype=float_dtype) @@ -861,6 +861,43 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ] +class MultiQueryPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + **kwargs, + ) + self.num_kv_heads = normalized_config.num_kv_heads + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + past_shape = ( + self.batch_size * self.num_kv_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + class FalconDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def __init__( self, diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 67e12861eb5..336ad31e5a7 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -178,3 +178,41 @@ def _prepare_decoder_sliding_window_attention_mask( ) return combined_attention_mask + + +def _falcon_prepare_attn_mask( + attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int +) -> torch.BoolTensor: + from transformers.models.falcon.modeling_falcon import ( + _expand_mask, + ) + + # NOTE: there is no "copied from" for falcon in transformers which makes no sense to me. + + # Create a causal mask + # The attention mask we receive as input should cover the whole extended sequence, including any past + # cache, so its shape should be [batch_size, seq_length + past_key_values_length] + # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length] + if input_shape[1] + past_key_values_length != attention_mask.shape[1]: + raise ValueError( + "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" + f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" + f" {past_key_values_length}." + ) + combined_attention_mask = None + device = attention_mask.device + _, seq_length = input_shape + + # if seq_length > 1: + # NOTE: we remove here the `if seq_length > 1` to allow to use a single decoder. 
+ combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + + # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] + expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 335bb4dabcf..eb4f5659e8a 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -213,9 +213,7 @@ class NormalizedConfigManager: "blenderbot": BartLikeNormalizedTextConfig, "blenderbot-small": BartLikeNormalizedTextConfig, "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"), - "falcon": NormalizedTextConfig.with_args( - num_layers="num_hidden_layers", num_attention_heads="num_attention_heads" - ), + "falcon": NormalizedTextConfig, "camembert": NormalizedTextConfig, "codegen": GPT2LikeNormalizedTextConfig, "cvt": NormalizedVisionConfig, diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 253f3cdac41..ba2030d6742 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -72,6 +72,7 @@ ], "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"], }, + "falcon": "fxmarty/really-tiny-falcon-testing", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 1281d2a5606..7794bd7b2d4 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1936,6 +1936,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES = [ "bloom", "codegen", + "falcon", "gpt2", "gpt_bigcode", "gpt_neo", @@ -2056,14 +2057,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) onnx_outputs = onnx_model(**tokens, position_ids=position_ids) - self.assertTrue("logits" in onnx_outputs) - self.assertIsInstance(onnx_outputs.logits, torch.Tensor) - with torch.no_grad(): transformers_outputs = transformers_model(**tokens) + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, torch.Tensor) + # compare tensor outputs - self.assertTrue(torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4)) + self.assertTrue( + torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4), + f"Maxdiff: {(onnx_outputs.logits - transformers_outputs.logits).abs()}", + ) # Compare batched generation. 
tokenizer.pad_token_id = tokenizer.eos_token_id @@ -2074,11 +2078,25 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach onnx_model.config.eos_token_id = None transformers_model.config.eos_token_id = None + new_tokens = 30 + if model_arch == "falcon": + # TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers + new_tokens = 5 onnx_outputs = onnx_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30, eos_token_id=None + **tokens, + num_beams=1, + do_sample=False, + min_new_tokens=new_tokens, + max_new_tokens=new_tokens, + eos_token_id=None, ) transformers_outputs = transformers_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30, eos_token_id=None + **tokens, + num_beams=1, + do_sample=False, + min_new_tokens=new_tokens, + max_new_tokens=new_tokens, + eos_token_id=None, ) self.assertTrue(torch.allclose(onnx_outputs, transformers_outputs)) @@ -2256,12 +2274,13 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode text = "My Name is Philipp and i live" tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"] + model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir) not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME) self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch")) self.assertFalse(model_not_merged.use_merged) - model_merged_dir = Path(model_not_merged_dir) / "merged" + model_merged_dir = Path(Path(model_not_merged_dir).parents[0], "merged") task = model_not_merged.export_feature if use_cache: task += "-with-past" diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 696479c6c33..f8a1c36e2b5 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -57,6 +57,7 @@ ], "mohitsha/tiny-random-testing-bert2gpt2": ["text2text-generation", "text2text-generation-with-past"], }, + "falcon": "fxmarty/really-tiny-falcon-testing", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
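
The changes above wire Falcon into both the ONNX exporter (FalconOnnxConfig, FalconModelPatcher) and ONNX Runtime inference (ORTFalconForCausalLM, dispatched from ORTModelForCausalLM._from_pretrained). The sketch below shows the resulting user-facing flow; it is illustrative only and assumes the tiny Falcon test checkpoint registered in the test utilities above, which can be swapped for any supported Falcon checkpoint.

# Illustrative sketch only: export a Falcon checkpoint to ONNX on the fly and generate with it.
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "fxmarty/really-tiny-falcon-testing"  # tiny test model referenced in the tests above
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True runs the ONNX export (task "text-generation-with-past" when use_cache=True);
# _from_pretrained then instantiates ORTFalconForCausalLM because config.model_type == "falcon".
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("My name is Philipp and I live", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=5, do_sample=False)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))

The same export can also be produced ahead of time with the CLI, e.g. `optimum-cli export onnx --model <falcon_checkpoint> <output_dir>`, and loaded back with `ORTModelForCausalLM.from_pretrained(<output_dir>)`.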
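
FalconOnnxConfig.add_past_key_values, MultiQueryPastKeyValuesGenerator and ORTFalconForCausalLM.prepare_past_key_values all rely on the same cache layout: one key and one value tensor per layer, with the key/value heads folded into the batch dimension. The sketch below only restates those shapes; the config values are made up for illustration.

# Shape sketch of the Falcon ONNX KV cache used above; the config numbers are illustrative.
import torch

batch_size = 2
past_sequence_length = 0  # empty cache on the first forward pass
hidden_size, num_attention_heads = 4544, 71
new_decoder_architecture, multi_query = False, True
config_num_kv_heads = 8  # config.num_kv_heads, ignored in the multi-query case below

# Mirrors FalconOnnxConfig.__init__ / ORTFalconForCausalLM.__init__: config.num_kv_heads is only
# trusted for the new decoder architecture or when multi_query is disabled, otherwise 1 is used.
num_kv_heads = config_num_kv_heads if (new_decoder_architecture or not multi_query) else 1
head_dim = hidden_size // num_attention_heads

# One (key, value) pair per layer, each of shape ("batch_size x num_heads", past_sequence_length, head_dim).
past_key = torch.zeros(batch_size * num_kv_heads, past_sequence_length, head_dim)
past_value = torch.zeros(batch_size * num_kv_heads, past_sequence_length, head_dim)
print(past_key.shape)  # torch.Size([2, 0, 64])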
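
Because "falcon" is added to MODEL_TYPES_REQUIRING_POSITION_IDS, exported RoPE (non-alibi) Falcon models take an explicit position_ids input, which ORTFalconForCausalLM.prepare_inputs_for_generation rebuilds from the attention mask at generation time. A standalone sketch of that computation:

# Recomputes position_ids from a left-padded attention mask, as done in
# ORTFalconForCausalLM.prepare_inputs_for_generation for non-alibi (RoPE) Falcon variants.
import torch

attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]])  # first sequence is left-padded by one token

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 0, 1, 2], [0, 1, 2, 3]])

# With a non-empty cache, only the last position is fed to the decoder.
cache_is_populated = True  # stand-in for past_key_values being set
if cache_is_populated:
    position_ids = position_ids[:, -1].unsqueeze(-1)
print(position_ids)  # tensor([[2], [3]])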
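
_make_causal_mask_falcon_patched works around the fact that ONNX Runtime cannot run the Trilu operator on boolean inputs: the mask is built on a float tensor and only cast to bool afterwards. A standalone sketch of the pattern:

# Builds the causal mask the way _make_causal_mask_falcon_patched does: torch.triu on a float
# tensor, then a cast to bool, so the exported Trilu node never receives a boolean input.
import torch

target_length, past_key_values_length, batch_size = 3, 2, 1

mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float), diagonal=1).to(torch.bool)
past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool)
mask = torch.cat([past_mask, mask], dim=-1)  # True marks future positions that must be masked out

expanded = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
print(expanded.shape)  # torch.Size([1, 1, 3, 5])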