From e7f86e7d80a7c9c1ed1114094e75b5e6189aa90d Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Tue, 26 Sep 2023 17:22:28 -0400
Subject: [PATCH 1/7] [QuantizationModifier] initialize per-channel scales and zero points to the correct shape

---
 .../sparsification/quantization/helpers.py    | 40 +++++++++++++++++++
 .../quantization/modifier_quantization.py     |  5 +++
 2 files changed, 45 insertions(+)

diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index 1e9e9eb79b1..6c569ee8b81 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -48,6 +48,7 @@
     "freeze_bn_stats",
     "fuse_module_conv_bn_relus",
     "prepare_embeddings_qat",
+    "initialize_channel_wise_scale_zp",
     "QConfigProperties",
     "LINEAR_ACTIVATION_NAMES",
     "CONV_ACTIVATION_NAMES",
@@ -710,6 +711,45 @@ def prepare_embeddings_qat(
         _prepare_qat_embedding(submodule, submodule_qconfig)


+def initialize_channel_wise_scale_zp(module: Module):
+    """
+    On torch channel-wise quantization, zero points and scales are
+    initialized to a default size of (1,) instead of their true size
+    of (num_output_channels,). This can cause issues on reloading
+    of saved checkpoints due to shape mismatch. This function expands
+    these initial scales and zero points to match the true expected
+    shape.
+
+    :param module: QAT-ready, uncalibrated model
+    """
+    for name, submodule in module.named_modules():
+        weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
+        if not weight_fake_quant or (
+            getattr(weight_fake_quant, "qscheme", None) is not torch.per_channel_affine
+        ):
+            # only consider modules with channel-wise quantized weights
+            continue
+        num_channels = None
+        if hasattr(submodule, "out_features"):
+            # matmul layers
+            num_channels = submodule.out_features
+        elif hasattr(submodule, "out_channels"):
+            num_channels = submodule.out_channels
+
+        if not num_channels:
+            # unable to infer num_channels or num_channels is 0
+            continue
+
+        # update scale and zero point if they are initialized to a size of 1
+        scale = weight_fake_quant.scale
+        if scale.numel() == 1:
+            weight_fake_quant.scale = scale.reshape(-1).expand(num_channels)
+
+        zero_point = weight_fake_quant.zero_point
+        if zero_point.numel() == 1:
+            weight_fake_quant.zero_point = zero_point.reshape(-1).expand(num_channels)
+
+
 def _delete_get_block_hooks(
     module: Module,
     fuse_blocks: List[List[str]],
diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
index a604cfef44f..aa749c17a29 100644
--- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
+++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -38,6 +38,7 @@
     configure_module_bn_wrappers,
     freeze_bn_stats,
     fuse_module_conv_bn_relus,
+    initialize_channel_wise_scale_zp,
 )
 from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
     QuantizationModifier as LegacyQuantizationModifier,
@@ -516,6 +517,10 @@ def _enable_module_qat(self, module: Module):

         self._calibrate_if_possible(module)

+        # if channel-wise quantization is targeted, properly initialize
+        # the scale and zp shapes
+        initialize_channel_wise_scale_zp(module)
+
     def _fuse(self, module: Module):
         if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
             self._model_fuse_fn_kwargs["inplace"] = True
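As context for the helper introduced above, a minimal sketch (not part of the patch series) of the shape problem it addresses. It assumes torch's eager-mode QAT utilities and the import path added by this patch; the toy Linear model and the explicit per_channel_affine weight qconfig are illustrative assumptions only.

    import torch
    from torch import nn
    from torch.ao.quantization import (
        FakeQuantize,
        MovingAveragePerChannelMinMaxObserver,
        QConfig,
        default_qat_qconfig,
        prepare_qat,
    )

    from sparseml.pytorch.sparsification.quantization.helpers import (
        initialize_channel_wise_scale_zp,
    )

    # channel-wise (per_channel_affine) weight fake-quant, the case the helper targets
    per_channel_weight_fq = FakeQuantize.with_args(
        observer=MovingAveragePerChannelMinMaxObserver,
        quant_min=-128,
        quant_max=127,
        dtype=torch.qint8,
        qscheme=torch.per_channel_affine,
    )

    model = nn.Sequential(nn.Linear(16, 8)).train()
    model.qconfig = QConfig(
        activation=default_qat_qconfig.activation, weight=per_channel_weight_fq
    )
    qat_model = prepare_qat(model)

    # before calibration the per-channel params are placeholders of size (1,)
    print(qat_model[0].weight_fake_quant.scale.shape)       # torch.Size([1])
    print(qat_model[0].weight_fake_quant.zero_point.shape)  # torch.Size([1])

    # expanding them to (out_features,) lets a calibrated per-channel checkpoint
    # load without a size-mismatch error in load_state_dict
    initialize_channel_wise_scale_zp(qat_model)
    print(qat_model[0].weight_fake_quant.scale.shape)       # torch.Size([8])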
From 1c2d80de1b463a579f3678946f417b72458e9307 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 27 Sep 2023 16:54:15 -0400
Subject: [PATCH 2/7] add adjustment for observer min/max vals

---
 .../pytorch/sparsification/quantization/helpers.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index 6c569ee8b81..c1bd131ac4a 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -749,6 +749,16 @@ def initialize_channel_wise_scale_zp(module: Module):
         if zero_point.numel() == 1:
             weight_fake_quant.zero_point = zero_point.reshape(-1).expand(num_channels)

+        # update the observer min and max vals
+        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
+            weight_fake_quant.activation_post_process.min_val = torch.empty_like(
+                weight_fake_quant.scale
+            )
+        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
+            weight_fake_quant.activation_post_process.max_val = torch.empty_like(
+                weight_fake_quant.scale
+            )
+

 def _delete_get_block_hooks(
     module: Module,

From 779af694f08ee603f6a5eff5c05b4ec123241904 Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Wed, 27 Sep 2023 18:32:16 -0400
Subject: [PATCH 3/7] fix typo: check max_val, not min_val, when initializing the observer max_val

---
 src/sparseml/pytorch/sparsification/quantization/helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index c1bd131ac4a..a9337f7979b 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -754,7 +754,7 @@ def initialize_channel_wise_scale_zp(module: Module):
             weight_fake_quant.activation_post_process.min_val = torch.empty_like(
                 weight_fake_quant.scale
             )
-        if weight_fake_quant.activation_post_process.min_val.numel() == 0:
+        if weight_fake_quant.activation_post_process.max_val.numel() == 0:
             weight_fake_quant.activation_post_process.max_val = torch.empty_like(
                 weight_fake_quant.scale
             )

From 7cfad324fceac41fe6b3b9acc8f279f2e4d83642 Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Mon, 2 Oct 2023 15:58:46 -0400
Subject: [PATCH 4/7] Fixes to load pre-trained model with channel-wise quantization

---
 src/sparseml/pytorch/sparsification/quantization/helpers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index a9337f7979b..91c294f2f8c 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -725,7 +725,7 @@ def initialize_channel_wise_scale_zp(module: Module):
     for name, submodule in module.named_modules():
         weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
         if not weight_fake_quant or (
-            getattr(weight_fake_quant, "qscheme", None) is not torch.per_channel_affine
+            getattr(weight_fake_quant, "qscheme", None) not in [torch.per_channel_affine, torch.per_channel_symmetric]
         ):
             # only consider modules with channel-wise quantized weights
             continue
@@ -743,11 +743,11 @@ def initialize_channel_wise_scale_zp(module: Module):
         # update scale and zero point if they are initialized to a size of 1
         scale = weight_fake_quant.scale
         if scale.numel() == 1:
-            weight_fake_quant.scale = scale.reshape(-1).expand(num_channels)
+            weight_fake_quant.scale = torch.ones(num_channels, dtype=scale.dtype)

         zero_point = weight_fake_quant.zero_point
         if zero_point.numel() == 1:
-            weight_fake_quant.zero_point = zero_point.reshape(-1).expand(num_channels)
+            weight_fake_quant.scale = torch.ones(num_channels, dtype=zero_point.dtype)

         # update the observer min and max vals
         if weight_fake_quant.activation_post_process.min_val.numel() == 0:
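Patches 2 through 4 above also touch the weight observer because, until it sees data, a per-channel observer's min_val/max_val buffers are empty and cannot receive the calibrated per-channel tensors stored in a checkpoint. A small sketch of that behavior, assuming torch.ao.quantization's MovingAveragePerChannelMinMaxObserver (not code from the patches):

    import torch
    from torch.ao.quantization import MovingAveragePerChannelMinMaxObserver

    obs = MovingAveragePerChannelMinMaxObserver()  # ch_axis=0 by default

    # uninitialized observer: empty buffers, nothing to load a checkpoint into
    print(obs.min_val.numel(), obs.max_val.numel())  # 0 0

    # observing a weight of shape (out_channels=8, in_features=16) produces
    # per-channel statistics of shape (8,), matching what a calibrated
    # checkpoint stores
    obs(torch.randn(8, 16))
    print(obs.min_val.shape, obs.max_val.shape)  # torch.Size([8]) torch.Size([8])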
From e7ff212592b60b27b990fdc47d9182260496ab1a Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Wed, 4 Oct 2023 10:21:49 -0400
Subject: [PATCH 5/7] Quality fixes

---
 .../pytorch/sparsification/quantization/helpers.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py
index 91c294f2f8c..0ecc8a90b7f 100644
--- a/src/sparseml/pytorch/sparsification/quantization/helpers.py
+++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py
@@ -725,7 +725,8 @@ def initialize_channel_wise_scale_zp(module: Module):
     for name, submodule in module.named_modules():
         weight_fake_quant = getattr(submodule, "weight_fake_quant", None)
         if not weight_fake_quant or (
-            getattr(weight_fake_quant, "qscheme", None) not in [torch.per_channel_affine, torch.per_channel_symmetric]
+            getattr(weight_fake_quant, "qscheme", None)
+            not in [torch.per_channel_affine, torch.per_channel_symmetric]
         ):
             # only consider modules with channel-wise quantized weights
             continue
@@ -747,7 +748,9 @@ def initialize_channel_wise_scale_zp(module: Module):

         zero_point = weight_fake_quant.zero_point
         if zero_point.numel() == 1:
-            weight_fake_quant.scale = torch.ones(num_channels, dtype=zero_point.dtype)
+            weight_fake_quant.zero_point = torch.ones(
+                num_channels, dtype=zero_point.dtype
+            )

         # update the observer min and max vals
         if weight_fake_quant.activation_post_process.min_val.numel() == 0:
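One note on the qscheme membership check reformatted above: torch's stock per-channel QAT qconfigs typically use torch.per_channel_symmetric for weights, which the original per_channel_affine-only comparison would have skipped; that is why the check was broadened in the previous patch. A short sketch under that assumption, using the fbgemm default QAT qconfig purely as an example:

    import torch
    from torch.ao.quantization import get_default_qat_qconfig

    qat_qconfig = get_default_qat_qconfig("fbgemm")
    weight_fq = qat_qconfig.weight()  # instantiate the weight fake-quant module

    print(weight_fq.qscheme)  # torch.per_channel_symmetric
    print(weight_fq.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric])  # True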
From 6b41c1ff1393c64bb326895c0b39188ba2b89f51 Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Thu, 12 Oct 2023 17:57:19 -0400
Subject: [PATCH 6/7] Switch fake-quantization initialization to just prior to loading model weights

---
 .../sparsification/quantization/modifier_quantization.py | 5 -----
 src/sparseml/transformers/sparsification/trainer.py      | 9 ++++++++-
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
index aa749c17a29..a604cfef44f 100644
--- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
+++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -38,7 +38,6 @@
     configure_module_bn_wrappers,
     freeze_bn_stats,
     fuse_module_conv_bn_relus,
-    initialize_channel_wise_scale_zp,
 )
 from sparseml.pytorch.sparsification.quantization.legacy_modifier_quantization import (
     QuantizationModifier as LegacyQuantizationModifier,
@@ -517,10 +516,6 @@ def _enable_module_qat(self, module: Module):

         self._calibrate_if_possible(module)

-        # if channel-wise quantization is targeted, properly initialize
-        # the scale and zp shapes
-        initialize_channel_wise_scale_zp(module)
-
     def _fuse(self, module: Module):
         if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
             self._model_fuse_fn_kwargs["inplace"] = True
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 36f8b3bb0fe..b13dfff4be9 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -48,7 +48,7 @@
 )
 from sparseml.transformers.utils import SparseAutoModel
 from sparseml.transformers.utils.helpers import RECIPE_NAME
-
+from sparseml.pytorch.sparsification.quantization.helpers import initialize_channel_wise_scale_zp

 __all__ = [
     "RecipeManagerTrainerInterface",
@@ -671,6 +671,13 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             )
             return False

+        # PerChannel quantization observers initialize their variables
+        # to dummy shapes that do not match the shapes saved in the
+        # state_dict.
+        # Reshape these variables here so that the state_dict can be
+        # loaded properly.
+        initialize_channel_wise_scale_zp(self.model)
+
         current_state_dict = self.model.state_dict()

         if set(orig_state_dict.keys()) == set(current_state_dict):

From 3244084df0bd11e3b11aebe4fa3918e0dc4e6a54 Mon Sep 17 00:00:00 2001
From: Alexandre Marques
Date: Thu, 12 Oct 2023 18:02:02 -0400
Subject: [PATCH 7/7] Style and quality fixes

---
 src/sparseml/transformers/sparsification/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index b13dfff4be9..f28f04cc560 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -40,6 +40,9 @@
 from transformers.trainer_utils import ShardedDDPOption, get_last_checkpoint

 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
+from sparseml.pytorch.sparsification.quantization.helpers import (
+    initialize_channel_wise_scale_zp,
+)
 from sparseml.pytorch.utils import (
     LoggerManager,
     ModuleSparsificationInfo,
@@ -48,7 +51,7 @@
 )
 from sparseml.transformers.utils import SparseAutoModel
 from sparseml.transformers.utils.helpers import RECIPE_NAME
-from sparseml.pytorch.sparsification.quantization.helpers import initialize_channel_wise_scale_zp
+

 __all__ = [
     "RecipeManagerTrainerInterface",
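Taken together, patches 6 and 7 move the reshaping from QAT setup into the checkpoint-reload path. A hedged sketch of that ordering outside the trainer, using a hypothetical load_qat_checkpoint helper (the real call site is the trainer's _reload_model_state above):

    import torch

    from sparseml.pytorch.sparsification.quantization.helpers import (
        initialize_channel_wise_scale_zp,
    )


    def load_qat_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> None:
        """Hypothetical helper mirroring the ordering the trainer patch establishes."""
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # per-channel fake-quant scale/zero_point (and observer min/max) buffers
        # are still dummy-shaped on a freshly prepared QAT model; reshape them
        # first so load_state_dict does not fail with a size mismatch
        initialize_channel_wise_scale_zp(model)
        model.load_state_dict(state_dict)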