Commit
add back ort support
fxmarty committed Oct 17, 2023
1 parent b46e146 commit 244a985
Showing 2 changed files with 100 additions and 13 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -34,8 +34,8 @@
DummyVisionEmbeddingsGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
- MultiQueryPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
+ MultiQueryPastKeyValuesGenerator,
NormalizedConfig,
NormalizedEncoderDecoderConfig,
NormalizedSeq2SeqConfig,
111 changes: 99 additions & 12 deletions optimum/onnxruntime/modeling_decoder.py
Expand Up @@ -14,7 +14,6 @@
"""Classes handling causal-lm related architectures in ONNX Runtime."""

import logging
- import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -115,7 +114,7 @@
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForCausalLM(ORTModel, GenerationMixin):
"""
- ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.
+ ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, falcon, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama.
"""

auto_model_class = AutoModelForCausalLM
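
For orientation only (not part of this commit), a minimal usage sketch of the class documented above; the checkpoint id, prompt, and generation settings are illustrative assumptions:

    # Hedged sketch: assumes optimum is installed with the onnxruntime extra and that the
    # (illustrative) checkpoint "tiiuae/falcon-7b" is reachable locally or on the Hub.
    from transformers import AutoTokenizer
    from optimum.onnxruntime import ORTModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
    # export=True converts the PyTorch weights to ONNX on the fly; use_cache=True keeps past key/values.
    model = ORTModelForCausalLM.from_pretrained("tiiuae/falcon-7b", export=True, use_cache=True)

    inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))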
@@ -468,7 +467,7 @@ def _from_pretrained(
f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False."
)
else:
- decoder_without_past_path = model_path / subfolder / decoder_file_name
+ pass

regular_file_names = []
for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]:
@@ -480,15 +479,6 @@ def _from_pretrained(
f"{cls.__name__} might not behave as expected."
)

if config.model_type == "bloom":
init_cls = ORTBloomForCausalLM
elif config.model_type == "mpt":
init_cls = ORTMPTForCausalLM
elif config.model_type == "opt":
init_cls = ORTOPTForCausalLM
else:
init_cls = ORTModelForCausalLM

model_cache_path, preprocessors = cls._cached_file(
model_path=model_path,
use_auth_token=use_auth_token,
@@ -544,6 +534,17 @@ def _from_pretrained(
provider_options=provider_options,
)

if config.model_type == "bloom":
init_cls = ORTBloomForCausalLM
elif config.model_type == "falcon":
init_cls = ORTFalconForCausalLM
elif config.model_type == "mpt":
init_cls = ORTMPTForCausalLM
elif config.model_type == "opt":
init_cls = ORTOPTForCausalLM
else:
init_cls = ORTModelForCausalLM

return init_cls(
model=model,
config=config,
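
As a quick sanity check of the dispatch added above, a hedged snippet (the tiny test checkpoint id is an assumption) that loads a Falcon model through the base class and inspects the concrete type it returns:

    from optimum.onnxruntime import ORTModelForCausalLM

    # from_pretrained routes through _from_pretrained, which picks init_cls from config.model_type,
    # so a Falcon checkpoint should come back as the ORTFalconForCausalLM subclass defined below.
    model = ORTModelForCausalLM.from_pretrained("fxmarty/really-tiny-falcon-testing", export=True)
    print(type(model).__name__)  # expected: ORTFalconForCausalLM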
@@ -723,3 +724,89 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
"position_ids": None,
"attention_mask": attention_mask,
}


class ORTFalconForCausalLM(ORTModelForCausalLM):
def __init__(
self,
decoder_session: onnxruntime.InferenceSession,
config: "PretrainedConfig",
onnx_paths: List[str],
decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None,
use_cache: bool = True,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
super().__init__(
decoder_session=decoder_session,
config=config,
onnx_paths=onnx_paths,
decoder_with_past_session=decoder_with_past_session,
use_cache=use_cache,
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
**kwargs,
)
# self.num_kv_heads = config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1

# Copied from https://github.com/huggingface/transformers/pull/26199
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
standardized_past = self._convert_cache_to_standard_format(past, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)
return self._convert_to_rw_cache(reordered_past)

# Copied from https://github.com/huggingface/transformers/pull/26199
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
if past_key_values is not None:
input_ids = input_ids[:, -1:]

# the cache may be in the standard format (e.g. in contrastive search), convert to falcon's format if needed
if len(past_key_values[0][0].shape) == 4:
past_key_values = self._convert_to_rw_cache(past_key_values)

# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.config.alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

return {
"input_ids": input_ids,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
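
To make the position_ids handling in prepare_inputs_for_generation above concrete, a small self-contained sketch (the attention mask values are made up) showing what the cumsum/masked_fill_ lines compute for a left-padded batch:

    import torch

    # Left-padded batch: 0 marks padding, 1 marks real tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    # Cumulative sum gives 0-based positions for the real tokens; padding slots get a dummy value of 1.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

    # When past_key_values are present, only the position of the newest token is passed to the model.
    print(position_ids[:, -1].unsqueeze(-1))
    # tensor([[2],
    #         [4]])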
