Update transformer version to <4.40 (#2204)
* initial commit

* initial commit

* fixing tests

* Update max transformers version

* Update min transformers version

* initial commit

* fixing tests 1

* fixing tests 2

* quality

* fix bad rebase & quality

* Update setup.py

---------

Co-authored-by: Michael Goin <[email protected]>
dbogunowicz and mgoin authored Apr 3, 2024
1 parent 72001e8 commit 8ba1dba
Showing 19 changed files with 111 additions and 185 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -79,8 +79,8 @@
"opencv-python<=4.6.0.66",
]
_transformers_deps = _pytorch_deps + [
"transformers<4.35.0",
"datasets<=2.14.6",
"transformers<4.40",
"datasets<2.19",
"dvc",
"scikit-learn",
"seqeval",
@@ -11,7 +11,7 @@ quantization_modifiers:
ignore:
- LlamaRotaryEmbedding
- LlamaRMSNorm
- SiLUActivation
- SiLU
- model.layers.0.mlp.down_proj
- model.layers.1.mlp.down_proj
- model.layers.2.mlp.down_proj
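The rename above tracks a change in how transformers builds LLaMA's activation: newer releases instantiate torch.nn.SiLU directly rather than a SiLUActivation wrapper, so the class name listed in a quantization recipe's ignore list has to match the new name. A quick, hedged way to confirm which class name a given transformers install produces:

from transformers.activations import ACT2FN

# Print the activation class name the installed transformers version maps
# "silu" to; this is the name a quantization recipe's ignore list must match.
# Expected output on transformers ~4.39 is "SiLU"; older releases that still
# shipped the wrapper returned "SiLUActivation".
print(type(ACT2FN["silu"]).__name__)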
29 changes: 1 addition & 28 deletions src/sparseml/transformers/finetune/callbacks.py
@@ -109,39 +109,13 @@ def __init__(self, trainer, *args, **kwargs):
self.on_begin_called = False
self.quant_start_epoch = math.inf

def check_disable(self, epoch: float, force: bool = False):
"""
If needed due to active quantization, disable FP16 training
"""
if (
force or hasattr(self.trainer, "scaler") and self.trainer.scaler._enabled
) and self.qat_active():
self.disable_amp(epoch)

def qat_active(self) -> bool:
"""
:return: True if a quantization modifier is active in the current session
"""
session = session_manager.active_session()
return session.state.model.qat_active()

def disable_amp(self, epoch: float):
"""
Disable FP16 training
:param epoch: epoch to disable from
"""
if not self.on_begin_called:
# disable if training loops haven't started so we don't load
# the empty scaler state dict and instead disable it from the start
self.trainer.use_cuda_amp = False

if hasattr(self.trainer, "scaler"):
self.trainer.scaler._enabled = False

self.quant_start_epoch = epoch
_LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training")

def on_epoch_begin(
self,
args: TrainingArguments,
@@ -150,8 +124,7 @@ def on_epoch_begin(
**kwargs,
):
"""
Event called at the beginning of an epoch. Disables FP16 training.
Event called at the beginning of an epoch.
"""
super().on_epoch_begin(args, state, control, **kwargs)
self.on_begin_called = True
self.check_disable(state.epoch)
7 changes: 0 additions & 7 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -363,7 +363,6 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
"""
checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)
self.callback_disable_fp16.check_disable(epoch, force=True)
self.accelerator.wait_for_everyone()
output = super().train(*args, **kwargs)
self.accelerator.wait_for_everyone()
@@ -393,13 +392,7 @@ def evaluate(self, *args, **kwargs):
"""
self.initialize_structure()

# Always evaluate w/ fp32 to be closer to DeepSparse
use_cuda_amp = self.use_cuda_amp
if not self.args.fp16_full_eval and not self.args.bf16_full_eval:
self.use_cuda_amp = False

output = super().evaluate(*args, **kwargs)
self.use_cuda_amp = use_cuda_amp
self.finalize_session()

return output
4 changes: 0 additions & 4 deletions src/sparseml/transformers/finetune/trainer.py
@@ -91,10 +91,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None):
os.path.join(output_dir, "scheduler.pt"),
)
reissue_pt_warnings(caught_warnings)
if self.use_cuda_amp:
torch.save(
self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt")
)

def _save_checkpoint(self, model, trial, metrics=None):
# Call into the save checkpoint by HF Transformers, which saves the
4 changes: 2 additions & 2 deletions src/sparseml/transformers/sparsification/modification/base.py
@@ -23,8 +23,8 @@

__all__ = ["check_transformers_version"]

_TRANSFORMERS_MIN_VERSION = "4.34.1"
_TRANSFORMERS_MAX_VERSION = "4.35.0"
_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.39.2"


def check_transformers_version(
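The body of check_transformers_version is not part of this diff. As a rough sketch only, a gate over the new bounds could look like the following, assuming the helper simply warns when the installed version falls outside the supported window:

import logging

import transformers
from packaging import version

_LOGGER = logging.getLogger(__name__)

_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.39.2"


def check_transformers_version() -> bool:
    """
    Warn (rather than fail) when the installed transformers version is
    outside the range the modification code was tested against.
    """
    current = version.parse(transformers.__version__)
    supported = (
        version.parse(_TRANSFORMERS_MIN_VERSION)
        <= current
        <= version.parse(_TRANSFORMERS_MAX_VERSION)
    )
    if not supported:
        _LOGGER.warning(
            "Detected transformers==%s; the attention modifications were "
            "tested against versions %s-%s and may not apply cleanly",
            transformers.__version__,
            _TRANSFORMERS_MIN_VERSION,
            _TRANSFORMERS_MAX_VERSION,
        )
    return supported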
@@ -17,6 +17,7 @@
context of SparseML
"""


import logging
import math
from typing import Optional, Tuple
@@ -122,22 +123,16 @@ def forward(

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor)
# of all cross attention key/value_states.
# Further calls to cross_attention
# layer can then reuse all cross-attention
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # noqa
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention
# (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states.
# Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to
# current projected key/value_states (third "elif" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # noqa
# all previous decoder key/value_states. Further calls to uni-directional self-attention # noqa
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # noqa
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key"
# to get the raw attention scores.
# Take the dot product between "query" and "key" to get the raw attention scores. # noqa
# ==== SparseML MODIFICATION ====
attention_scores = self.attention_scores_matmul(
query_layer, key_layer.transpose(-1, -2)
@@ -189,8 +184,7 @@ def forward(

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is
# (precomputed for all layers in BertModel forward() function)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) # noqa
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
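The attention_scores_matmul call flagged above as a SparseML modification is a helper whose definition is not shown in this diff. As an illustrative sketch only, the idea is to expose torch.matmul as a submodule so quantization tooling, which hooks modules rather than bare functional calls, can observe and fake-quantize the attention matmuls:

import torch
from torch import nn


# Illustrative sketch, not the actual SparseML helper: wrapping torch.matmul
# in an nn.Module lets quantization frameworks attach observers or fake-quant
# stubs to the attention-score and attention-output matmuls, which a bare
# torch.matmul call would not expose.
class QuantizableMatMul(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)


# e.g. installed in __init__ as: self.attention_scores_matmul = QuantizableMatMul()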
@@ -23,7 +23,10 @@

import torch
from torch import nn
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention
from transformers.models.distilbert.modeling_distilbert import (
DistilBertFlashAttention2,
MultiHeadSelfAttention,
)

from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
@@ -45,6 +48,9 @@ def modify(model: nn.Module) -> nn.Module:
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules
Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as DistilBertFlashAttention2
:param model: the original DistilBert model
:return: the modified DistilBert model
"""
@@ -53,6 +59,11 @@ def modify(model: nn.Module) -> nn.Module:
swap_modules(
model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule)
)
if isinstance(submodule, DistilBertFlashAttention2):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model


@@ -92,15 +103,12 @@ def forward(
mask: torch.tensor(bs, seq_length)
Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length)
Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer.
Optional: only if `output_attentions=True`
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, # noqa
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` # noqa
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
# assert dim == self.dim, f'Dimensions do not match:
# {dim} input vs {self.dim} configured'
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # noqa
# assert key.size() == value.size()

dim_per_head = self.dim // self.n_heads
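A hedged usage note for the swap above: only the eager MultiHeadSelfAttention modules are replaced, so a DistilBert loaded with flash attention keeps DistilBertFlashAttention2 and is merely logged. Loading with attn_implementation="eager" keeps the attention layers in the form the swap targets; the checkpoint name below is an illustrative assumption.

from transformers import AutoModel

# Loading with the eager attention implementation yields MultiHeadSelfAttention
# modules that modify() can swap; "flash_attention_2" would yield
# DistilBertFlashAttention2 modules, which are left untouched.
model = AutoModel.from_pretrained(
    "distilbert-base-uncased", attn_implementation="eager"
)
attention_classes = {
    type(module).__name__
    for name, module in model.named_modules()
    if "attention" in name.lower()
}
print(attention_classes)  # expect MultiHeadSelfAttention among the entries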
@@ -25,8 +25,10 @@
import torch.nn.functional as F
from torch import nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -54,14 +56,15 @@ def modify(model: nn.Module) -> nn.Module:
Note: This function will not alter any of the alternatives
to the LlamaAttention module such as LlamaFlashAttention2
or LlamaSdpaAttention
:param model: the original LLaMa model
:return: the modified LLaMa model
"""
for name, submodule in model.named_modules():
if isinstance(submodule, LlamaAttention):
swap_modules(model, name, LlamaAttentionWithQuantizableMatmuls(submodule))
elif isinstance(submodule, LlamaFlashAttention2):
elif isinstance(submodule, (LlamaSdpaAttention, LlamaFlashAttention2)):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
@@ -121,10 +124,11 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -171,20 +175,18 @@
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
query_states, key_states, cos, sin
)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
# sin and cos are specific to RoPE models; cache_position needed for the static cache # noqa
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -195,33 +197,24 @@
) / math.sqrt(self.head_dim)
# ==============================

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size "
f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
# ==== SparseML MODIFICATION ====
attn_output = self.attn_output_matmul(attn_weights, value_states)
# ===============================

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" # noqa
f" {attn_output.size()}"
)

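For context on the largest change above: the forward pass now follows the Cache API that transformers 4.39 uses for LLaMA, where past_key_value.update() appends a layer's new key/value states and returns the full cached tensors, replacing the old manual tuple concatenation and kv_seq_len bookkeeping. A standalone sketch using DynamicCache (one concrete Cache implementation; shapes and values are illustrative):

import torch
from transformers.cache_utils import DynamicCache

bsz, num_heads, head_dim = 1, 8, 64
cache = DynamicCache()

# prefill: 5 prompt tokens for layer 0
k = torch.randn(bsz, num_heads, 5, head_dim)
v = torch.randn(bsz, num_heads, 5, head_dim)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape)  # torch.Size([1, 8, 5, 64])

# decode step: one new token is appended to the cached states for that layer
k = torch.randn(bsz, num_heads, 1, head_dim)
v = torch.randn(bsz, num_heads, 1, head_dim)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(k_all.shape)  # torch.Size([1, 8, 6, 64])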