Skip to content

Commit

Permalink
resolve
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenglongjiepheonix committed Sep 3, 2024
2 parents c752e29 + 8cb6832 commit 82d1cf9
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 134 deletions.
218 changes: 146 additions & 72 deletions optimum/bettertransformer/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from ...utils import check_if_transformers_greater


# TODO (CRITICAL): Layer-wise attention scaling is broken for several archs.
Expand All @@ -23,7 +26,7 @@
def raise_on_head_mask(head_mask: Optional[torch.Tensor]):
if head_mask is not None:
raise ValueError(
"layer_head_mask different than None is unsupported for now with BetterTransformer, please"
"layer_head_mask (or head_mask) different than None is unsupported for now with BetterTransformer, please"
"open a PR or an issue at https://github.com/huggingface/optimum."
)

Expand Down Expand Up @@ -534,88 +537,159 @@ def bart_forward(
return attn_output, None, past_key_value


# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
def bloom_forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
raise_on_head_mask(head_mask)
if check_if_transformers_greater("4.44"):
from transformers.cache_utils import Cache
from transformers.models.bloom.modeling_bloom import dropout_add

# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
def bloom_forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Cache] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
):
raise_on_head_mask(head_mask)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

batch_size, q_length, _ = hidden_states.shape
# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)
# 3 x [batch_size, num_heads, seq_length, head_dim]
query_layer, key_layer, value_layer = self._reshape(fused_qkv)

if layer_past is not None:
cache_kwargs = {"cache_position": cache_position}
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

if attention_mask is not None: # no matter the length, we just slice it
kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1
causal_mask = attention_mask[:, :, :, :kv_length]
alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.dropout_prob_attn if self.training else 0.0,
)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
# Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size)

# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)

fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
outputs = (output_tensor, layer_past)

batch_size, q_length, _, _ = query_layer.shape
return outputs

# Permute to [batch_size, num_heads, seq_length, head_dim]
query_layer = query_layer.transpose(1, 2)
else:
# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
def bloom_forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
**kwargs,
):
raise_on_head_mask(head_mask)

if layer_past is not None:
past_key, past_value = layer_past
past_key = past_key.transpose(1, 2)
if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
# [batch_size, seq_length, 3 x hidden_size]
fused_qkv = self.query_key_value(hidden_states)

# concatenate along seq_length dimension
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

# untangle batch_size from self.num_heads
key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:])
value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:])
else:
key_layer = key_layer.transpose(1, 2)
value_layer = value_layer.transpose(1, 2)

alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.dropout_prob_attn if self.training else 0.0,
)
batch_size, q_length, _, _ = query_layer.shape

# Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(*context_layer.shape[:2], -1)

# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + torch.nn.functional.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# Permute to [batch_size, num_heads, seq_length, head_dim]
query_layer = query_layer.transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
past_key = past_key.transpose(1, 2)

output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training)
output_tensor = residual + output_tensor
key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

if use_cache is True:
present = (
key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2),
value_layer.reshape(-1, *value_layer.shape[2:]),
# concatenate along seq_length dimension
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

# untangle batch_size from self.num_heads
key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:])
value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:])
else:
key_layer = key_layer.transpose(1, 2)
value_layer = value_layer.transpose(1, 2)

alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.dropout_prob_attn if self.training else 0.0,
)
else:
present = None

return (output_tensor, present)
# Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
context_layer = context_layer.transpose(1, 2)
context_layer = context_layer.reshape(*context_layer.shape[:2], -1)

# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + torch.nn.functional.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)

output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training)
output_tensor = residual + output_tensor

if use_cache is True:
present = (
key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2),
value_layer.reshape(-1, *value_layer.shape[2:]),
)
else:
present = None

return (output_tensor, present)
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
self.dropout_prob_attn = config.attention_dropout

self.module_mapping = None
self.layer_idx = getattr(layer, "layer_idx", None)

submodules = ["query_key_value", "dense", "attention_dropout"]
for attr in submodules:
setattr(self, attr, getattr(layer, attr))
Expand Down
38 changes: 21 additions & 17 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,27 +338,31 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
if check_if_transformers_greater("4.44"):
super().add_past_key_values(inputs_or_outputs, direction)
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {
0: "batch_size x num_heads",
2: decoder_sequence_name,
}
inputs_or_outputs[f"{name}.{i}.value"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}
if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {
0: "batch_size x num_heads",
2: decoder_sequence_name,
}
inputs_or_outputs[f"{name}.{i}.value"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
Expand Down
19 changes: 10 additions & 9 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,12 +1937,6 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai
if inferred_model_type is not None:
break

if inferred_model_type is None:
raise ValueError(
f"The export of a DiffusionPipeline model with the class name {model.__class__.__name__} is currently not supported in Optimum. "
"Please open an issue or submit a PR to add the support."
)

# `model_type` is a class attribute in Transformers, let's avoid modifying it.
model.config.export_model_type = inferred_model_type

Expand Down Expand Up @@ -2068,9 +2062,16 @@ def get_model_from_task(
if original_task == "auto" and config.architectures is not None:
model_class_name = config.architectures[0]

model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)
if library_name == "diffusers":
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
class_name = config.get("_class_name", None)
loaded_library = importlib.import_module(library_name)
model_class = getattr(loaded_library, class_name)
else:
model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
model = model.to(torch_dtype).to(device)
Expand Down
4 changes: 2 additions & 2 deletions optimum/fx/parallelization/op_registry/op_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@

class Registry:
"""
Registry class handles registration of parallel axis propagation handlers of different aten ops, to support a new
aten op, you need to register the corresponding handler class by decorating it with `register` function.
Registry class handles registration of parallel axis propagation handlers of different aten ops.
To support a new aten op, you need to register the corresponding handler class by decorating it with `register` function.
"""

def __init__(self) -> None:
Expand Down
Loading

0 comments on commit 82d1cf9

Please sign in to comment.