Update transformers version to <4.40 #2204

Merged: 12 commits, Apr 3, 2024
4 changes: 2 additions & 2 deletions setup.py
@@ -79,8 +79,8 @@
"opencv-python<=4.6.0.66",
]
_transformers_deps = _pytorch_deps + [
"transformers<4.35.0",
"datasets<=2.14.6",
"transformers<4.40",
"datasets<2.19",
"dvc",
"scikit-learn",
"seqeval",
@@ -11,7 +11,7 @@ quantization_modifiers:
ignore:
- LlamaRotaryEmbedding
- LlamaRMSNorm
- SiLUActivation
- SiLU
- model.layers.0.mlp.down_proj
- model.layers.1.mlp.down_proj
- model.layers.2.mlp.down_proj
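Note: recent transformers releases map the `silu` activation to `torch.nn.SiLU` rather than the older `SiLUActivation` class, which is why the `ignore` entry changes here. A minimal sketch for confirming which class name the recipe should reference, assuming a locally available LLaMA checkpoint (the path is a placeholder):

```python
# Sketch: print the activation class the loaded model actually uses, so the
# recipe's `ignore` entry matches it. The checkpoint path is a placeholder.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
act_classes = {
    type(module).__name__
    for name, module in model.named_modules()
    if name.endswith("mlp.act_fn")
}
print(act_classes)  # expected to print {'SiLU'} on the transformers range this PR targets
```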
29 changes: 1 addition & 28 deletions src/sparseml/transformers/finetune/callbacks.py
@@ -109,39 +109,13 @@ def __init__(self, trainer, *args, **kwargs):
self.on_begin_called = False
self.quant_start_epoch = math.inf

def check_disable(self, epoch: float, force: bool = False):
"""
If needed due to active quantization, disable FP16 training
"""
if (
force or hasattr(self.trainer, "scaler") and self.trainer.scaler._enabled
) and self.qat_active():
self.disable_amp(epoch)

def qat_active(self) -> bool:
"""
:return: True if a quantization modifier is active in the current session
"""
session = session_manager.active_session()
return session.state.model.qat_active()

def disable_amp(self, epoch: float):
"""
Disable FP16 training

:param epoch: epoch to disable from
"""
if not self.on_begin_called:
# disable if training loops haven't started so we don't load
# the empty scaler state dict and instead disable it from the start
self.trainer.use_cuda_amp = False

if hasattr(self.trainer, "scaler"):
self.trainer.scaler._enabled = False

self.quant_start_epoch = epoch
_LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training")

def on_epoch_begin(
self,
args: TrainingArguments,
@@ -150,8 +124,7 @@ def on_epoch_begin(
**kwargs,
):
"""
Event called at the beginning of an epoch. Disables FP16 training.
Event called at the beginning of an epoch.
"""
super().on_epoch_begin(args, state, control, **kwargs)
self.on_begin_called = True
self.check_disable(state.epoch)
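The callback's QAT check goes through the active SparseML session, as the `qat_active` helper above shows. For illustration only, a minimal standalone HF `TrainerCallback` built on the same session call; the import path is an assumption and is only known to be aliased as `session_manager` in `callbacks.py`:

```python
# Sketch: a standalone TrainerCallback that logs when a quantization modifier
# becomes active, reusing the session call shown in the code above.
# The import path is an assumption; callbacks.py aliases it as `session_manager`.
import logging

import sparseml.core.session as session_manager
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

_LOGGER = logging.getLogger(__name__)


class LogQATStart(TrainerCallback):
    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        session = session_manager.active_session()
        if session.state.model.qat_active():
            _LOGGER.info(f"quantization active at epoch {state.epoch}")
```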
7 changes: 0 additions & 7 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -363,7 +363,6 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):
"""
checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)
self.callback_disable_fp16.check_disable(epoch, force=True)
self.accelerator.wait_for_everyone()
output = super().train(*args, **kwargs)
self.accelerator.wait_for_everyone()
@@ -393,13 +392,7 @@ def evaluate(self, *args, **kwargs):
"""
self.initialize_structure()

# Always evaluate w/ fp32 to be closer to DeepSparse
use_cuda_amp = self.use_cuda_amp
if not self.args.fp16_full_eval and not self.args.bf16_full_eval:
self.use_cuda_amp = False

output = super().evaluate(*args, **kwargs)
self.use_cuda_amp = use_cuda_amp
self.finalize_session()

return output
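With the fp32 override removed, evaluation precision follows the standard Trainer arguments instead of being forced inside `evaluate()`. A short sketch of the flags that control it (both are regular `TrainingArguments` options; the output directory is a placeholder):

```python
# Sketch: evaluation precision is now governed by the usual HF flags rather
# than being forced to fp32 inside evaluate().
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",      # placeholder
    bf16_full_eval=False,  # leave False to evaluate in fp32
    fp16_full_eval=False,  # set True to cast the model to fp16 for evaluation
)
```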
4 changes: 0 additions & 4 deletions src/sparseml/transformers/finetune/trainer.py
@@ -91,10 +91,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None):
os.path.join(output_dir, "scheduler.pt"),
)
reissue_pt_warnings(caught_warnings)
if self.use_cuda_amp:
torch.save(
self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt")
)

def _save_checkpoint(self, model, trial, metrics=None):
# Call into the save checkpoint by HF Transformers, which saves the
4 changes: 2 additions & 2 deletions src/sparseml/transformers/sparsification/modification/base.py
@@ -23,8 +23,8 @@

__all__ = ["check_transformers_version"]

_TRANSFORMERS_MIN_VERSION = "4.34.1"
_TRANSFORMERS_MAX_VERSION = "4.35.0"
_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.39.2"


def check_transformers_version(
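The body of `check_transformers_version` is outside this hunk; a minimal sketch of the range check the two constants imply, assuming `packaging` for the comparison (not necessarily how the module implements it):

```python
# Sketch of the range check implied by the constants above; the actual
# check_transformers_version body is not shown in this diff.
import transformers
from packaging.version import Version

_TRANSFORMERS_MIN_VERSION = "4.39.0"
_TRANSFORMERS_MAX_VERSION = "4.39.2"


def transformers_version_in_range() -> bool:
    current = Version(transformers.__version__)
    return (
        Version(_TRANSFORMERS_MIN_VERSION)
        <= current
        <= Version(_TRANSFORMERS_MAX_VERSION)
    )
```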
@@ -17,6 +17,7 @@
context of SparseML
"""


import logging
import math
from typing import Optional, Tuple
@@ -122,22 +123,16 @@ def forward(

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor)
# of all cross attention key/value_states.
# Further calls to cross_attention
# layer can then reuse all cross-attention
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # noqa
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention
# (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states.
# Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to
# current projected key/value_states (third "elif" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # noqa
# all previous decoder key/value_states. Further calls to uni-directional self-attention # noqa
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # noqa
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key"
# to get the raw attention scores.
# Take the dot product between "query" and "key" to get the raw attention scores. # noqa
# ==== SparseML MODIFICATION ====
attention_scores = self.attention_scores_matmul(
query_layer, key_layer.transpose(-1, -2)
@@ -189,8 +184,7 @@ def forward(

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is
# (precomputed for all layers in BertModel forward() function)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) # noqa
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
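The `attention_scores_matmul` call flagged as a SparseML modification reflects a general pattern: expressing a bare `torch.matmul` as an `nn.Module` gives quantization observers a discrete point to attach to. A sketch of that pattern, as an illustration only and not the actual `modification_objects` implementation:

```python
# Sketch of the pattern behind attention_scores_matmul: a matmul wrapped in an
# nn.Module so observers/fake-quant can be attached. Illustration only; not the
# actual sparseml modification_objects implementation.
import torch
from torch import nn


class MatMulAsModule(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)
```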
@@ -23,7 +23,10 @@

import torch
from torch import nn
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention
from transformers.models.distilbert.modeling_distilbert import (
DistilBertFlashAttention2,
MultiHeadSelfAttention,
)

from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.modification_objects import (
@@ -45,6 +48,9 @@ def modify(model: nn.Module) -> nn.Module:
1. Replaces the MultiHeadSelfAttention modules with
MultiHeadSelfAttentionWithQuantizableMatmuls modules

Note: This function will not alter any of the alternatives
to the MultiHeadSelfAttention module such as DistilBertFlashAttention2

:param model: the original DistilBert model
:return: the modified DistilBert model
"""
@@ -53,6 +59,11 @@ def modify(model: nn.Module) -> nn.Module:
swap_modules(
model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule)
)
if isinstance(submodule, DistilBertFlashAttention2):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
)
return model
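Because `modify()` swaps only `MultiHeadSelfAttention` modules and leaves `DistilBertFlashAttention2` untouched, the swap only has something to act on when the model is loaded with the eager attention implementation. A sketch, using the `attn_implementation` argument of `from_pretrained` available in recent transformers releases and an illustrative checkpoint name:

```python
# Sketch: load DistilBERT with eager attention so plain MultiHeadSelfAttention
# modules (the ones modify() swaps) are instantiated rather than the
# flash-attention variant, which is left unmodified.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "distilbert-base-uncased",  # illustrative checkpoint
    attn_implementation="eager",
)
print({type(m).__name__ for m in model.modules() if "SelfAttention" in type(m).__name__})
```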


@@ -92,15 +103,12 @@ def forward(
mask: torch.tensor(bs, seq_length)

Returns:
weights: torch.tensor(bs, n_heads, seq_length, seq_length)
Attention weights context: torch.tensor(bs,
seq_length, dim) Contextualized layer.
Optional: only if `output_attentions=True`
weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, # noqa
seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` # noqa
"""
bs, q_length, dim = query.size()
k_length = key.size(1)
# assert dim == self.dim, f'Dimensions do not match:
# {dim} input vs {self.dim} configured'
# assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # noqa
# assert key.size() == value.size()

dim_per_head = self.dim // self.n_heads
@@ -25,8 +25,10 @@
import torch.nn.functional as F
from torch import nn
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -54,14 +56,15 @@ def modify(model: nn.Module) -> nn.Module:

Note: This function will not alter any of the alternatives
to the LlamaAttention module such as LlamaFlashAttention2
or LlamaSdpaAttention

:param model: the original LLaMa model
:return: the modified LLaMa model
"""
for name, submodule in model.named_modules():
if isinstance(submodule, LlamaAttention):
swap_modules(model, name, LlamaAttentionWithQuantizableMatmuls(submodule))
elif isinstance(submodule, LlamaFlashAttention2):
elif isinstance(submodule, (LlamaSdpaAttention, LlamaFlashAttention2)):
_LOGGER.debug(
f"The model contains {submodule.__class__.__name__} "
"module, which will not be modified"
@@ -121,10 +124,11 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -171,20 +175,18 @@ def forward(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
query_states, key_states, cos, sin
)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
# sin and cos are specific to RoPE models; cache_position needed for the static cache # noqa
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
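The `past_key_value.update(...)` call above follows the Cache API introduced in transformers 4.36: the cache object appends the new key/value states for a given layer and returns the full tensors to attend over. A small sketch with `DynamicCache`; the shapes are illustrative:

```python
# Sketch of the Cache API used above: DynamicCache.update appends this step's
# key/value states for a layer and returns the concatenated tensors.
import torch
from transformers.cache_utils import DynamicCache

bsz, num_kv_heads, head_dim = 1, 8, 64
cache = DynamicCache()

for step in range(3):
    key_states = torch.randn(bsz, num_kv_heads, 1, head_dim)    # one new token per step
    value_states = torch.randn(bsz, num_kv_heads, 1, head_dim)
    key_states, value_states = cache.update(key_states, value_states, layer_idx=0)

print(key_states.shape)         # torch.Size([1, 8, 3, 64]) after three updates
print(cache.get_seq_length(0))  # 3
```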
@@ -195,33 +197,24 @@
) / math.sqrt(self.head_dim)
# ==============================

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size "
f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size "
f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
# ==== SparseML MODIFICATION ====
attn_output = self.attn_output_matmul(attn_weights, value_states)
# ===============================

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size "
f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" # noqa
f" {attn_output.size()}"
)
