diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index 1e9e9eb79b1..0ecc8a90b7f 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -48,6 +48,7 @@
     "freeze_bn_stats",
     "fuse_module_conv_bn_relus",
     "prepare_embeddings_qat",
+    "initialize_channel_wise_scale_zp",
    "QConfigProperties",
     "LINEAR_ACTIVATION_NAMES",
     "CONV_ACTIVATION_NAMES",
@@ -710,6 +711,58 @@ def prepare_embeddings_qat(
         _prepare_qat_embedding(submodule, submodule_qconfig)
 
 
+def initialize_channel_wise_scale_zp(module: Module):
+    """
+    On torch channel-wise quantization, zero points and scales are
+    initialized to a default size of (1,) instead of their true size
+    of (num_output_channels,). This can cause issues on reloading
+    of saved checkpoints due to shape mismatch. This function expands
+    these initial scales and zero points to match the true expected
+    shape
+
+    :param module: qat ready, uncalibrated model
+    """
+    for name, submodule in module.named_modules():
+        weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
+        if not weight_fake_quant or (
+            getattr(weight_fake_quant, "qscheme", None)
+            not in [torch.per_channel_affine, torch.per_channel_symmetric]
+        ):
+            # only consider modules with channel-wise quantized weights
+            continue
+        num_channels = None
+        if hasattr(submodule, "out_features"):
+            # matmul layers
+            num_channels = submodule.out_features
+        elif hasattr(submodule, "out_channels"):
+            num_channels = submodule.out_channels
+
+        if not num_channels:
+            # unable to infer num_channels or num_channels is 0
+            continue
+
+        # update scale and zero point if they are initialized to a size of 1
+        scale = weight_fake_quant.scale
+        if scale.numel() == 1:
+            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)
+
+        zero_point = weight_fake_quant.zero_point
+        if zero_point.numel() == 1:
+            weight_fake_quant.zero_point = torch.ones(
+                num_channels, dtype=zero_point.dtype
+            )
+
+        # update the observer min and max vals
+        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
+            weight_fake_quant.activation_post_process.min_val = torch.empty_like(
+                weight_fake_quant.scale
+            )
+        if weight_fake_quant.activation_post_process.max_val.numel() == 0:
+            weight_fake_quant.activation_post_process.max_val = torch.empty_like(
+                weight_fake_quant.scale
+            )
+
+
 def _delete_get_block_hooks(
     module: Module,
     fuse_blocks: List[List[str]],
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 36f8b3bb0fe..f28f04cc560 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -40,6 +40,9 @@
 from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint
 
 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
+from sparseml.pytorch.sparsification.quantization.helpers import (
+    initialize_channel_wise_scale_zp,
+)
 from sparseml.pytorch.utils import (
     LoggerManager,
     ModuleSparsificationInfo,
@@ -671,6 +674,13 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             )
             return False
 
+        # PerChannel quantization observers initialize variables
+        # to dummy shapes that do not match the ones saved in
+        # state_dict.
+        # Need to reshape these variables in order to load state_dict
+        # properly.
+        initialize_channel_wise_scale_zp(self.model)
+
         current_state_dict = self.model.state_dict()
 
         if set(orig_state_dict.keys()) == set(current_state_dict):
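
For reviewers who want to see the shape expansion in isolation, below is a minimal sketch (not part of the diff) that prepares a toy QAT model with per-channel weight quantization, applies the new `initialize_channel_wise_scale_zp` helper to an uncalibrated copy, and then reloads a calibrated checkpoint into it. The toy `nn.Sequential` model, the fbgemm qconfig, and the random calibration batch are illustrative assumptions rather than code from this PR; exact reload behavior without the helper depends on the torch version, so the sketch simply shows the before/after buffer shapes.

```python
# Illustrative sketch only (not part of this PR); assumes a stock
# torch.quantization QAT flow with the fbgemm qconfig, which quantizes
# Linear/Conv weights per output channel.
import torch
from torch import nn

from sparseml.pytorch.sparsification.quantization.helpers import (
    initialize_channel_wise_scale_zp,
)


def make_qat_model() -> nn.Module:
    # toy model chosen purely for illustration
    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    return torch.quantization.prepare_qat(model.train())


# calibrate one copy so its per-channel scales and zero points take their
# true (num_output_channels,) shape, then grab a checkpoint
calibrated = make_qat_model()
calibrated(torch.randn(2, 16))
checkpoint = calibrated.state_dict()

# a freshly prepared, uncalibrated copy still holds size-(1,) buffers, which
# can make a direct load_state_dict fail with a size mismatch depending on
# the torch version / fake-quantize class in use
fresh = make_qat_model()
fq = fresh[0].weight_fake_quant
print("before:", tuple(fq.scale.shape), tuple(fq.zero_point.shape))  # (1,) (1,)

# expand the dummy scales / zero points, as the trainer now does inside
# _reload_model_state before reloading weights
initialize_channel_wise_scale_zp(fresh)
print("after: ", tuple(fq.scale.shape), tuple(fq.zero_point.shape))  # (8,) (8,)

# with the shapes expanded, the calibrated checkpoint loads cleanly
fresh.load_state_dict(checkpoint)
```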