
Commit

fix tests ortmodel
fxmarty committed Nov 6, 2023
1 parent 554e312 commit 4096b55
Showing 3 changed files with 69 additions and 21 deletions.
4 changes: 2 additions & 2 deletions optimum/exporters/onnx/convert.py
@@ -635,7 +635,7 @@ def export_tensorflow(
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.
- device (`str`, *optional*, defaults to `cpu`):
+ device (`Optional[str]`, defaults to `"cpu"`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
@@ -790,7 +790,7 @@ def export(
Directory to store the exported ONNX model.
opset (`Optional[int]`, defaults to `None`):
The version of the ONNX operator set to use.
- device (`str`, *optional*, defaults to `cpu`):
+ device (`Optional[str]`, defaults to `"cpu"`):
The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for
export on CUDA devices.
input_shapes (`Optional[Dict]`, defaults to `None`):
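A quick illustration of the documented `device` argument in action, assuming the public `main_export` helper from `optimum.exporters.onnx`; the argument names and the exact task string are taken to match this optimum version and may differ between releases:

# Hypothetical usage sketch; checkpoint and output directory are illustrative.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="gpt2",          # any decoder checkpoint on the Hub
    output="gpt2_onnx",                 # directory that will receive the ONNX files
    task="text-generation-with-past",
    device="cpu",                       # "cuda" requires a PyTorch model and a GPU
)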
11 changes: 7 additions & 4 deletions optimum/onnxruntime/base.py
@@ -146,7 +146,7 @@ def __init__(
if self.parent_model.use_cache is True and len(self.key_value_output_names) == 0:
raise RuntimeError("Could not find the past key values in the provided model.")

- self.use_past = len(self.key_value_input_names) > 0
+ self.use_past = len(self.key_value_output_names) > 0
self.use_fp16 = False
for inp in session.get_inputs():
if "past_key_values" in inp.name and inp.type == "tensor(float16)":
@@ -312,8 +312,9 @@ def forward(
if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])

- # IO Binding does not support 0-dim output with null pointer, so handle this case here
- if self.use_past is False or use_merged_no_cache:
+ if not self.use_past:
+ out_past_key_values = None
+ elif use_merged_no_cache:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
)
@@ -427,7 +428,9 @@ def forward(
# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
# * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
# * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
- if not self.use_past or use_merged_no_cache or self.no_cross_attention_cache:
+ if not self.use_past:
+ out_past_key_values = None
+ elif use_merged_no_cache or self.no_cross_attention_cache:
out_past_key_values = tuple(
out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
)
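The rewritten branches return `None` when the model exposes no cache outputs, and otherwise regroup the flat run of present key/values into one tuple per decoder layer. A simplified, self-contained sketch of that regrouping; the function name and plain-tuple inputs are illustrative:

from typing import Optional, Tuple

def regroup_present_key_values(flat_outputs: Tuple, num_pkv: int, use_past: bool) -> Optional[Tuple]:
    # No cache outputs: return None rather than an empty tuple, matching the
    # `if not self.use_past` branch above.
    if not use_past:
        return None
    # Group the flat sequence into one tuple per layer: num_pkv is 2 for the
    # decoder with cache, 4 when cross-attention key/values are also returned.
    return tuple(
        flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv)
    )

# Toy usage with strings standing in for tensors: 2 layers, key/value each.
flat = ("k0", "v0", "k1", "v1")
assert regroup_present_key_values(flat, num_pkv=2, use_past=True) == (("k0", "v0"), ("k1", "v1"))
assert regroup_present_key_values(flat, num_pkv=2, use_past=False) is None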
75 changes: 60 additions & 15 deletions optimum/onnxruntime/modeling_decoder.py
@@ -147,6 +147,7 @@ def __init__(
self.generation_config = generation_config
self.onnx_paths = [self.model_path]
self.use_merged = "use_cache_branch" in self.inputs_names
+ self.model_type = self.config.model_type

self.use_fp16 = False
for inp in model.get_inputs():
@@ -204,8 +205,8 @@ def forward(
loss = None
if self.use_cache:
if past_key_values is not None:
- # Flatten the past_key_values (no need to flatten for models using multi-query attn)
- if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+ # Flatten the past_key_values (gpt_bigcode has fused key/value cache, so no need to flatten it)
+ if self.model_type != "gpt_bigcode":
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
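The comprehension above flattens the per-layer `(key, value)` pairs coming from `transformers` into the flat ordering the ONNX `past_key_values.*` inputs expect; gpt_bigcode is exempt because its fused cache is already one tensor per layer. A toy round-trip with strings standing in for tensors:

# Per-layer (key, value) pairs as produced by transformers...
nested = (("k0", "v0"), ("k1", "v1"))

# ...flattened to match the ordering of the ONNX inputs, as in the comprehension above.
flat = tuple(pkv for per_layer in nested for pkv in per_layer)
assert flat == ("k0", "v0", "k1", "v1")

# gpt_bigcode exposes a single fused key/value tensor per layer, so its cache
# is already flat and is passed through unchanged.
fused = ("kv_layer0", "kv_layer1")
assert len(fused) == 2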
@@ -299,7 +300,7 @@ def forward(
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)

- if self.use_cache and self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+ if self.use_cache and self.model_type != "gpt_bigcode":
# Tuple of tuples of length `n_layers`, with each inner tuple holding the self-attention key and value
# for one decoder layer
past_key_values = tuple(
@@ -330,7 +331,7 @@ def prepare_past_key_values(
# Generate a dummy past for the first forward pass when the model uses a merged decoder
if past_key_values is None:
batch_size = input_ids.shape[0]
if self.config.model_type in {"mistral", "llama"}:
if self.model_type in {"mistral", "llama"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
@@ -339,7 +340,7 @@
dtype = constructor.float16 if self.use_fp16 else constructor.float32
# TODO: find a way to better handle this control flow, this is EXTREMELY UGLY.
# "1" is the dummy sequence length
if self.config.model_type == "bloom":
if self.model_type == "bloom":
shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
key = constructor.zeros(shape_key, dtype=dtype)
@@ -352,7 +353,7 @@ def prepare_past_key_values(
past_key_values = tuple(
key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
)
elif self.config.model_type == "gpt_bigcode":
elif self.model_type == "gpt_bigcode":
# GPT BigCode uses multi-query attention and stores both key and value in the same cache tensor.
shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
@@ -361,7 +362,7 @@ def prepare_past_key_values(
key_and_value = key_and_value.to(self.device)

past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
elif self.config.model_type == "falcon":
elif self.model_type == "falcon":
shape = (batch_size * self.num_key_value_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)

@@ -383,8 +384,7 @@ def prepare_past_key_values(
shape = [*value.shape]
index = (
1
- if self.config.model_type in MULTI_QUERY_ATTN_MODELS
- or (self.config.model_type == "bloom" and "value" in name)
+ if self.model_type in MULTI_QUERY_ATTN_MODELS or (self.model_type == "bloom" and "value" in name)
else 2
)
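The branches above differ only in the shape of the zero-length dummy cache used to prime a merged decoder on its first call. A rough summary of those shapes with illustrative dimensions; the final 4D layout for the remaining architectures is an assumption, since that branch falls outside this hunk:

import torch

batch_size, num_heads, num_kv_heads, head_dim = 2, 12, 1, 64

# bloom: 3D key and value, with the key transposed relative to the value.
bloom_key = torch.zeros(batch_size * num_heads, head_dim, 0)
bloom_value = torch.zeros(batch_size * num_heads, 0, head_dim)

# gpt_bigcode: multi-query attention, key and value fused into one tensor.
bigcode_key_and_value = torch.zeros(batch_size, 0, head_dim * 2)

# falcon: key/value heads folded into the batch dimension.
falcon_key_or_value = torch.zeros(batch_size * num_kv_heads, 0, head_dim)

# Remaining architectures (not shown in this hunk) are assumed to use the usual
# (batch_size, num_heads, 0, head_dim) layout.
default_key_or_value = torch.zeros(batch_size, num_heads, 0, head_dim)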

@@ -548,6 +548,8 @@ def _from_pretrained(
init_cls = ORTMPTForCausalLM
elif config.model_type == "opt":
init_cls = ORTOPTForCausalLM
elif config.model_type == "gpt_bigcode":
init_cls = ORTGPTBigCodeForCausalLM
else:
init_cls = ORTModelForCausalLM
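With the new dispatch branch, loading a gpt_bigcode checkpoint through the generic entry point should hand back the dedicated subclass. A hedged usage sketch; the checkpoint name is illustrative and `export=True` converts the PyTorch weights to ONNX on the fly:

from optimum.onnxruntime import ORTModelForCausalLM

# Any model whose config has model_type == "gpt_bigcode" should be routed to the new subclass.
model = ORTModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", export=True)
print(type(model).__name__)  # expected: ORTGPTBigCodeForCausalLM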

@@ -672,6 +674,49 @@ def can_generate(self):
return True


class ORTGPTBigCodeForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
# Omit tokens covered by past_key_values
if past_key_values:
if self.config.multi_query:
past_length = past_key_values[0].shape[1]
else:
past_length = past_key_values[0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
else:
position_ids = None

model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
)
return model_inputs
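The new `prepare_inputs_for_generation` above trims the prompt to the tokens not yet covered by the cache and rebuilds `position_ids` from the attention mask. A toy walk-through of those two steps with made-up values; `past_length` stands in for the cache length read from `past_key_values`:

import torch

input_ids = torch.tensor([[5, 6, 7, 8]])        # 4 tokens handed over by generate()
attention_mask = torch.tensor([[0, 1, 1, 1]])   # left padding in the first slot
past_length = 3                                 # the cache already covers 3 positions

# Keep only the tokens the cache has not seen yet.
remove_prefix_length = past_length if input_ids.shape[1] > past_length else input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]          # tensor([[8]])

# Rebuild position_ids from the mask, then keep the positions of the kept tokens.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids[:, -input_ids.shape[1]:]     # tensor([[2]])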


class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
@@ -832,7 +877,11 @@ def prepare_inputs_for_generation(
**kwargs,
) -> dict:
if past_key_values is not None:
- past_length = past_key_values[0][0].shape[2]
+ if past_key_values[0][0].ndim != 3:
+ # Compared to transformers, we do not use _convert_cache_to_standard_format in the model itself, hence the 3D cache.
+ raise ValueError("Falcon uses 3D KV cache.")

+ past_length = past_key_values[0][0].shape[1]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
@@ -841,17 +890,13 @@
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]

- # the cache may be in the stardard format (e.g. in contrastive search), convert to falcon's format if needed
- if len(past_key_values[0][0].shape) == 4:
- past_key_values = self._convert_to_rw_cache(past_key_values)

# Note: Falcon variants that use alibi do not use position_ids; position_ids are only used with RoPE.
if not self.config.alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[1] :]

return {
"input_ids": input_ids,
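The Falcon path now assumes the 3D cache layout produced by the exported model (key/value heads folded into the batch dimension), so the past length is read from dimension 1 instead of converting to and from the 4D standard format. A small sketch with illustrative shapes:

import torch

batch_size, num_kv_heads, head_dim, cached_tokens = 2, 1, 64, 7

# Falcon-style 3D cache entry: (batch_size * num_kv_heads, seq_len, head_dim).
past_key = torch.zeros(batch_size * num_kv_heads, cached_tokens, head_dim)
past_key_values = ((past_key, past_key.clone()),)   # a single layer: (key, value)

assert past_key_values[0][0].ndim == 3              # the shape check added above
past_length = past_key_values[0][0].shape[1]        # 7, read from dimension 1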
