diff --git a/setup.py b/setup.py index 4af84f957b0..27802de438c 100644 --- a/setup.py +++ b/setup.py @@ -79,8 +79,8 @@ "opencv-python<=4.6.0.66", ] _transformers_deps = _pytorch_deps + [ - "transformers<4.35.0", - "datasets<=2.14.6", + "transformers<4.40", + "datasets<2.19", "dvc", "scikit-learn", "seqeval", diff --git a/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml b/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml index 41513e49946..3056735e040 100644 --- a/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml +++ b/src/sparseml/experimental/sparsegpt/examples/llama2/recipes/llama_recipe.yaml @@ -11,7 +11,7 @@ quantization_modifiers: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - model.layers.0.mlp.down_proj - model.layers.1.mlp.down_proj - model.layers.2.mlp.down_proj diff --git a/src/sparseml/transformers/finetune/callbacks.py b/src/sparseml/transformers/finetune/callbacks.py index 1c483b5d99f..196240a2b8f 100644 --- a/src/sparseml/transformers/finetune/callbacks.py +++ b/src/sparseml/transformers/finetune/callbacks.py @@ -109,15 +109,6 @@ def __init__(self, trainer, *args, **kwargs): self.on_begin_called = False self.quant_start_epoch = math.inf - def check_disable(self, epoch: float, force: bool = False): - """ - If needed due to active quantization, disable FP16 training - """ - if ( - force or hasattr(self.trainer, "scaler") and self.trainer.scaler._enabled - ) and self.qat_active(): - self.disable_amp(epoch) - def qat_active(self) -> bool: """ :return: True if a quantization modifier is active in the current session @@ -125,23 +116,6 @@ def qat_active(self) -> bool: session = session_manager.active_session() return session.state.model.qat_active() - def disable_amp(self, epoch: float): - """ - Disable FP16 training - - :param epoch: epoch to disable from - """ - if not self.on_begin_called: - # disable if training loops haven't started so we don't load - # the empty scaler state dict and instead disable it from the start - self.trainer.use_cuda_amp = False - - if hasattr(self.trainer, "scaler"): - self.trainer.scaler._enabled = False - - self.quant_start_epoch = epoch - _LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training") - def on_epoch_begin( self, args: TrainingArguments, @@ -150,8 +124,7 @@ def on_epoch_begin( **kwargs, ): """ - Event called at the beginning of an epoch. Disables FP16 training. + Event called at the beginning of an epoch. 
""" super().on_epoch_begin(args, state, control, **kwargs) self.on_begin_called = True - self.check_disable(state.epoch) diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py index 72d18d98a9b..3971b1c0a02 100644 --- a/src/sparseml/transformers/finetune/session_mixin.py +++ b/src/sparseml/transformers/finetune/session_mixin.py @@ -363,7 +363,6 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): """ checkpoint, epoch = self._calculate_checkpoint_info(kwargs) self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage) - self.callback_disable_fp16.check_disable(epoch, force=True) self.accelerator.wait_for_everyone() output = super().train(*args, **kwargs) self.accelerator.wait_for_everyone() @@ -393,13 +392,7 @@ def evaluate(self, *args, **kwargs): """ self.initialize_structure() - # Always evaluate w/ fp32 to be closer to DeepSparse - use_cuda_amp = self.use_cuda_amp - if not self.args.fp16_full_eval and not self.args.bf16_full_eval: - self.use_cuda_amp = False - output = super().evaluate(*args, **kwargs) - self.use_cuda_amp = use_cuda_amp self.finalize_session() return output diff --git a/src/sparseml/transformers/finetune/trainer.py b/src/sparseml/transformers/finetune/trainer.py index cf920e1feb6..36a850f251b 100644 --- a/src/sparseml/transformers/finetune/trainer.py +++ b/src/sparseml/transformers/finetune/trainer.py @@ -91,10 +91,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): os.path.join(output_dir, "scheduler.pt"), ) reissue_pt_warnings(caught_warnings) - if self.use_cuda_amp: - torch.save( - self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") - ) def _save_checkpoint(self, model, trial, metrics=None): # Call into the save checkpoint by HF Transformers, which saves the diff --git a/src/sparseml/transformers/sparsification/modification/base.py b/src/sparseml/transformers/sparsification/modification/base.py index 946e6499851..6d9435b8b8b 100644 --- a/src/sparseml/transformers/sparsification/modification/base.py +++ b/src/sparseml/transformers/sparsification/modification/base.py @@ -23,8 +23,8 @@ __all__ = ["check_transformers_version"] -_TRANSFORMERS_MIN_VERSION = "4.34.1" -_TRANSFORMERS_MAX_VERSION = "4.35.0" +_TRANSFORMERS_MIN_VERSION = "4.39.0" +_TRANSFORMERS_MAX_VERSION = "4.39.2" def check_transformers_version( diff --git a/src/sparseml/transformers/sparsification/modification/modifying_bert.py b/src/sparseml/transformers/sparsification/modification/modifying_bert.py index d53046abd03..20e2e8ded4e 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_bert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_bert.py @@ -17,6 +17,7 @@ context of SparseML """ + import logging import math from typing import Optional, Tuple @@ -122,22 +123,16 @@ def forward( use_cache = past_key_value is not None if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) - # of all cross attention key/value_states. - # Further calls to cross_attention - # layer can then reuse all cross-attention + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # noqa + # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) - # if uni-directional self-attention - # (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
- # Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to - # current projected key/value_states (third "elif" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # noqa + # all previous decoder key/value_states. Further calls to uni-directional self-attention # noqa + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # noqa # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_layer, value_layer) - # Take the dot product between "query" and "key" - # to get the raw attention scores. + # Take the dot product between "query" and "key" to get the raw attention scores. # noqa # ==== SparseML MODIFICATION ==== attention_scores = self.attention_scores_matmul( query_layer, key_layer.transpose(-1, -2) @@ -189,8 +184,7 @@ def forward( attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: - # Apply the attention mask is - # (precomputed for all layers in BertModel forward() function) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) # noqa attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. diff --git a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py index 0312f5c6bac..c37da2cbdd0 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py @@ -23,7 +23,10 @@ import torch from torch import nn -from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention +from transformers.models.distilbert.modeling_distilbert import ( + DistilBertFlashAttention2, + MultiHeadSelfAttention, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.modification_objects import ( @@ -45,6 +48,9 @@ def modify(model: nn.Module) -> nn.Module: 1. Replaces the MultiHeadSelfAttention modules with MultiHeadSelfAttentionWithQuantizableMatmuls modules + Note: This function will not alter any of the alternatives + to the MultiHeadSelfAttention module such as DistilBertFlashAttention2 + :param model: the original DistilBert model :return: the modified DistilBert model """ @@ -53,6 +59,11 @@ def modify(model: nn.Module) -> nn.Module: swap_modules( model, name, MultiHeadSelfAttentionWithQuantizableMatmuls(submodule) ) + if isinstance(submodule, DistilBertFlashAttention2): + _LOGGER.debug( + f"The model contains {submodule.__class__.__name__} " + "module, which will not be modified" + ) return model @@ -92,15 +103,12 @@ def forward( mask: torch.tensor(bs, seq_length) Returns: - weights: torch.tensor(bs, n_heads, seq_length, seq_length) - Attention weights context: torch.tensor(bs, - seq_length, dim) Contextualized layer. - Optional: only if `output_attentions=True` + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, # noqa + seq_length, dim) Contextualized layer. 
Optional: only if `output_attentions=True` # noqa """ bs, q_length, dim = query.size() k_length = key.size(1) - # assert dim == self.dim, f'Dimensions do not match: - # {dim} input vs {self.dim} configured' + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' # noqa # assert key.size() == value.size() dim_per_head = self.dim // self.n_heads diff --git a/src/sparseml/transformers/sparsification/modification/modifying_llama.py b/src/sparseml/transformers/sparsification/modification/modifying_llama.py index ae5998ed69f..6c89469f524 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_llama.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_llama.py @@ -25,8 +25,10 @@ import torch.nn.functional as F from torch import nn from transformers.models.llama.modeling_llama import ( + Cache, LlamaAttention, LlamaFlashAttention2, + LlamaSdpaAttention, apply_rotary_pos_emb, repeat_kv, ) @@ -54,6 +56,7 @@ def modify(model: nn.Module) -> nn.Module: Note: This function will not alter any of the alternatives to the LlamaAttention module such as LlamaFlashAttention2 + or LlamaSdpaAttention :param model: the original LLaMa model :return: the modified LLaMa model @@ -61,7 +64,7 @@ def modify(model: nn.Module) -> nn.Module: for name, submodule in model.named_modules(): if isinstance(submodule, LlamaAttention): swap_modules(model, name, LlamaAttentionWithQuantizableMatmuls(submodule)) - elif isinstance(submodule, LlamaFlashAttention2): + elif isinstance(submodule, (LlamaSdpaAttention, LlamaFlashAttention2)): _LOGGER.debug( f"The model contains {submodule.__class__.__name__} " "module, which will not be modified" @@ -121,10 +124,11 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -171,20 +175,18 @@ def forward( bsz, q_len, self.num_key_value_heads, self.head_dim ).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids + query_states, key_states, cos, sin ) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + # sin and cos are specific to RoPE models; cache_position needed for the static cache # noqa + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -195,33 +197,24 @@ def forward( ) / math.sqrt(self.head_dim) # 
============================== - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size " - f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) # ==== SparseML MODIFICATION ==== attn_output = self.attn_output_matmul(attn_weights, value_states) # =============================== if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" # noqa f" {attn_output.size()}" ) diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py index 2c206cce091..28d9d7f109f 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py @@ -18,13 +18,16 @@ """ import logging import math +import warnings from typing import Optional, Tuple import torch from torch import nn from transformers.models.mistral.modeling_mistral import ( + Cache, MistralAttention, MistralFlashAttention2, + MistralSdpaAttention, apply_rotary_pos_emb, repeat_kv, ) @@ -52,6 +55,7 @@ def modify(model: torch.nn.Module) -> torch.nn.Module: Note: This function will not alter any of the alternatives to the MistralAttention module such as MistralFlashAttention2 + or MistralSdpaAttention :param model: the original Mistral model :return: the modified Mistral model @@ -59,7 +63,7 @@ def modify(model: torch.nn.Module) -> torch.nn.Module: for name, submodule in model.named_modules(): if isinstance(submodule, MistralAttention): swap_modules(model, name, MistralAttentionWithQuantizableMatmuls(submodule)) - if isinstance(submodule, MistralFlashAttention2): + if isinstance(submodule, (MistralSdpaAttention, MistralFlashAttention2)): _LOGGER.debug( f"The model contains {submodule.__class__.__name__} " "module, which will not be modified" @@ -112,11 +116,15 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" # noqa + ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -135,18 +143,23 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " # noqa + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " # noqa + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -160,16 +173,14 @@ def forward( if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" # noqa f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size " - f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" # noqa ) attn_weights = attn_weights + attention_mask @@ -178,14 +189,16 @@ def forward( attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) # ==== SparseML MODIFICATION ==== attn_output = self.attn_output_matmul(attn_weights, value_states) # =============================== if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" # noqa f" {attn_output.size()}" ) diff --git a/src/sparseml/transformers/sparsification/modification/modifying_opt.py b/src/sparseml/transformers/sparsification/modification/modifying_opt.py index fad36f05f96..4d2fd58c4f2 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_opt.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_opt.py @@ -22,7 +22,7 @@ import torch from torch import nn -from transformers.models.opt.modeling_opt import OPTAttention +from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2 from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.modification_objects import ( @@ -45,12 +45,20 @@ def modify(model: nn.Module) -> nn.Module: 1. 
Replaces the OPTAttention modules with OPTAttentionWithQuantizableMatmuls modules + Note: This function will not alter any of the alternatives + to the OPTAttention module such as OptFlashAttention2 + :param model: the original LLaMa model :return: the modified LLaMa model """ for name, submodule in model.named_modules(): if isinstance(submodule, OPTAttention): swap_modules(model, name, OPTAttentionWithQuantizableMatmuls(submodule)) + elif isinstance(submodule, OptFlashAttention2): + _LOGGER.debug( + f"The model contains {submodule.__class__.__name__} " + "module, which will not be modified" + ) return model @@ -141,19 +149,13 @@ def forward( value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) - # of all cross attention key/value_states. - # Further calls to cross_attention layer - # can then reuse all cross-attention + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # noqa + # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) - # save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. - # Further calls to uni-directional self-attention - # can concat previous decoder key/value_states - # to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention - # `past_key_value` is always `None` + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # noqa + # all previous decoder key/value_states. Further calls to uni-directional self-attention # noqa + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # noqa + # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) @@ -168,16 +170,14 @@ def forward( if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( - f"Attention weights should be of size " - f"{(bsz * self.num_heads, tgt_len, src_len)}, but is" + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" # noqa f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( - f"Attention mask should be of size " - f"{(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" # noqa ) attn_weights = ( attn_weights.view(bsz, self.num_heads, tgt_len, src_len) @@ -191,8 +191,7 @@ def forward( ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - # upcast to fp32 if the weights are in fp16. - # Please see https://github.com/huggingface/transformers/pull/17437 + # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437 # noqa if attn_weights.dtype == torch.float16: attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 @@ -203,8 +202,7 @@ def forward( if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( - f"Head mask for a single layer " - f"should be of size {(self.num_heads,)}, but is" + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" # noqa f" {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( @@ -236,16 +234,14 @@ def forward( if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size " - f"{(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" # noqa f" {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) - # rather than `hidden_state` because `attn_output` can be + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # noqa # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) diff --git a/src/sparseml/transformers/sparsification/obcq/README.md b/src/sparseml/transformers/sparsification/obcq/README.md index b1d2ee6c34b..28f686f5afd 100644 --- a/src/sparseml/transformers/sparsification/obcq/README.md +++ b/src/sparseml/transformers/sparsification/obcq/README.md @@ -194,7 +194,7 @@ test_stage: # These operations don't make sense to quantize - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU # Skip quantizing the BMMs - QuantizableMatMul # Skip quantizing the layers with the most sensitive activations @@ -242,7 +242,7 @@ test_stage: # These operations don't make sense to quantize - MistralRotaryEmbedding - MistralRMSNorm - - SiLUActivation + - SiLU # Skip quantizing the layers with the most sensitive activations - model.layers.1.mlp.down_proj - model.layers.31.mlp.down_proj diff --git a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml index db22f39ad0e..da265bf7d27 100644 --- a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml @@ -10,7 +10,7 @@ test_stage: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - model.layers.0.mlp.down_proj - model.layers.1.mlp.down_proj - model.layers.2.mlp.down_proj diff --git a/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml b/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml index ba9c4124c1b..7800c9b9b09 100644 --- a/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example_mistral.yaml @@ -4,7 +4,7 @@ test_stage: ignore: - MistralRotaryEmbedding - MistralRMSNorm - - SiLUActivation + - SiLU - model.layers.1.mlp.down_proj - model.layers.31.mlp.down_proj - model.layers.30.mlp.down_proj diff --git a/src/sparseml/transformers/sparsification/question_answering.py b/src/sparseml/transformers/sparsification/question_answering.py index a681122b5d0..ea933b92705 100644 --- 
a/src/sparseml/transformers/sparsification/question_answering.py +++ b/src/sparseml/transformers/sparsification/question_answering.py @@ -79,11 +79,6 @@ def evaluate( eval_dataloader = self.get_eval_dataloader(eval_dataset) eval_examples = self.eval_examples if eval_examples is None else eval_examples - # Always evaluate w/ fp32 to be closer to DeepSparse - use_cuda_amp = self.use_cuda_amp - if not self.args.fp16_full_eval and not self.args.bf16_full_eval: - self.use_cuda_amp = False - # Temporarily disable metric computation, we will do it in the loop here. compute_metrics = self.compute_metrics self.compute_metrics = None @@ -129,8 +124,6 @@ def evaluate( self.args, self.state, self.control, metrics ) - self.use_cuda_amp = use_cuda_amp - return metrics def predict( diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py index bc45bec6d97..035c0215e1f 100644 --- a/src/sparseml/transformers/sparsification/trainer.py +++ b/src/sparseml/transformers/sparsification/trainer.py @@ -35,7 +35,7 @@ from transformers.integrations import TensorBoardCallback from transformers.trainer_callback import TrainerState from transformers.trainer_pt_utils import reissue_pt_warnings -from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint +from transformers.trainer_utils import get_last_checkpoint from sparseml.pytorch.model_load.helpers import log_model_load from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer @@ -787,7 +787,6 @@ def train(self, *args, **kwargs): """ checkpoint, epoch = self._generate_apply_manager_params(kwargs) applied = self.apply_manager(epoch=epoch, checkpoint=checkpoint) - self.callback_disable_fp16.check_disable(epoch, force=True) output = None if not self.one_shot: output = super().train(*args, **kwargs) @@ -811,13 +810,7 @@ def evaluate(self, *args, **kwargs): """ applied = self.apply_manager(epoch=math.inf, checkpoint=None) - # Always evaluate w/ fp32 to be closer to DeepSparse - use_cuda_amp = self.use_cuda_amp - if not self.args.fp16_full_eval and not self.args.bf16_full_eval: - self.use_cuda_amp = False - output = super().evaluate(*args, **kwargs) - self.use_cuda_amp = use_cuda_amp if applied: self.finalize_manager() @@ -894,9 +887,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): if output_dir is None: output_dir = self.args.output_dir - if self.sharded_ddp == ShardedDDPOption.SIMPLE and self.optimizer is not None: - self.optimizer.consolidate_state_dict() - if self.is_world_process_zero(): if self.optimizer is not None: torch.save( @@ -910,10 +900,6 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None): os.path.join(output_dir, "scheduler.pt"), ) reissue_pt_warnings(caught_warnings) - if self.use_cuda_amp: - torch.save( - self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt") - ) def _load_optimizer_and_scheduler(self, checkpoint): """ @@ -1027,12 +1013,6 @@ def __init__(self, trainer: RecipeManagerTrainerInterface, *args, **kwargs): self.on_begin_called = False self.quant_start_epoch = math.inf - def check_disable(self, epoch: float, force: bool = False): - if ( - force or hasattr(self.trainer, "scaler") and self.trainer.scaler._enabled - ) and self.qat_active(epoch): - self.disable_amp(epoch) - def qat_active(self, epoch: float) -> bool: manager_q_active = arch_manager_q_active = False if self.trainer.manager: @@ -1043,18 +1023,6 @@ def qat_active(self, epoch: float) -> bool: ) return manager_q_active or 
arch_manager_q_active - def disable_amp(self, epoch: float): - if not self.on_begin_called: - # disable if training loops haven't started so we don't load - # the empty scaler state dict and instead disable it from the start - self.trainer.use_cuda_amp = False - - if hasattr(self.trainer, "scaler"): - self.trainer.scaler._enabled = False - - self.quant_start_epoch = epoch - _LOGGER.info(f"entering QAT phase at epoch {epoch}, disabling FP16 training") - def on_epoch_begin( self, args: TrainingArguments, @@ -1067,7 +1035,6 @@ def on_epoch_begin( """ super().on_epoch_begin(args, state, control, **kwargs) self.on_begin_called = True - self.check_disable(state.epoch) if state.epoch > self.quant_start_epoch: _LOGGER.info(self.trainer.model) diff --git a/tests/sparseml/transformers/finetune/test_quantization.yaml b/tests/sparseml/transformers/finetune/test_quantization.yaml index 825074e227d..89381c31006 100644 --- a/tests/sparseml/transformers/finetune/test_quantization.yaml +++ b/tests/sparseml/transformers/finetune/test_quantization.yaml @@ -4,7 +4,7 @@ test_stage: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - model.layers.0.mlp.down_proj - model.layers.1.mlp.down_proj - model.layers.2.mlp.down_proj diff --git a/tests/sparseml/transformers/obcq/test_repeats.py b/tests/sparseml/transformers/obcq/test_repeats.py index 93cd7667841..d4b2d2ee5a0 100644 --- a/tests/sparseml/transformers/obcq/test_repeats.py +++ b/tests/sparseml/transformers/obcq/test_repeats.py @@ -97,7 +97,7 @@ def test_fail_on_repeated_quant(tmp_path): ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU scheme_overrides: Embedding: input_activations: null @@ -110,7 +110,7 @@ def test_fail_on_repeated_quant(tmp_path): ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - Embedding """ @@ -152,7 +152,7 @@ def test_separate_quants_allowed(tmp_path): ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - Linear scheme_overrides: Embedding: @@ -166,7 +166,7 @@ def test_separate_quants_allowed(tmp_path): ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - Embedding - MatMulLeftInput_QK - MatMulRightInput_QK diff --git a/tests/sparseml/transformers/obcq/test_tiny.yaml b/tests/sparseml/transformers/obcq/test_tiny.yaml index 7949b454d90..422baf87580 100644 --- a/tests/sparseml/transformers/obcq/test_tiny.yaml +++ b/tests/sparseml/transformers/obcq/test_tiny.yaml @@ -10,7 +10,7 @@ test_stage: ignore: - LlamaRotaryEmbedding - LlamaRMSNorm - - SiLUActivation + - SiLU - model.layers.0.mlp.down_proj - model.layers.1.mlp.down_proj - model.layers.2.mlp.down_proj
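
Note: base.py above narrows the validated transformers window to 4.39.0-4.39.2, while setup.py only requires transformers<4.40. The sketch below is a minimal, hypothetical illustration of the kind of guard check_transformers_version can perform under that assumption; only the constants and the function name come from this diff, and the warning text and return value are illustrative rather than the actual implementation.

    import logging

    import transformers
    from packaging import version

    _LOGGER = logging.getLogger(__name__)

    _TRANSFORMERS_MIN_VERSION = "4.39.0"
    _TRANSFORMERS_MAX_VERSION = "4.39.2"


    def check_transformers_version() -> bool:
        """
        Warn if the installed transformers version falls outside the range
        validated for the SparseML attention modifications.
        """
        current = version.parse(transformers.__version__)
        in_range = (
            version.parse(_TRANSFORMERS_MIN_VERSION)
            <= current
            <= version.parse(_TRANSFORMERS_MAX_VERSION)
        )
        if not in_range:
            _LOGGER.warning(
                "transformers %s is outside the validated range [%s, %s]; "
                "the modified attention modules may not match the upstream "
                "implementations",
                transformers.__version__,
                _TRANSFORMERS_MIN_VERSION,
                _TRANSFORMERS_MAX_VERSION,
            )
        return in_range

If the modifying_* modules invoke this guard before swapping attention implementations (as the shared base module suggests), a mismatched transformers install surfaces as a logged warning rather than a silent API mismatch when the new Cache-based forward signatures are exercised.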