diff --git a/src/sparseml/onnx/utils/graph_optimizer.py b/src/sparseml/onnx/utils/graph_optimizer.py
index cc7b9250936..d91f900c898 100644
--- a/src/sparseml/onnx/utils/graph_optimizer.py
+++ b/src/sparseml/onnx/utils/graph_optimizer.py
@@ -165,8 +165,12 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> bool
         quantize_node = get_quantize_parent_for_dequantize_node(
             quantized_model, dequantize_node
         )
+        # check that the quantize block takes input from the same relu
-        if quantize_node.input[0] != other_input_node.output[0]:
+        if (
+            quantize_node is None
+            or quantize_node.input[0] != other_input_node.output[0]
+        ):
             continue
 
         # create de-quantize node for identity
diff --git a/src/sparseml/pytorch/optim/modifier_quantization.py b/src/sparseml/pytorch/optim/modifier_quantization.py
index 8a0f004ce86..1db0c54bf48 100644
--- a/src/sparseml/pytorch/optim/modifier_quantization.py
+++ b/src/sparseml/pytorch/optim/modifier_quantization.py
@@ -41,6 +41,7 @@
     configure_module_qat_wrappers,
     fuse_module_conv_bn_relus,
     get_qat_qconfig,
+    prepare_embeddings_qat,
 )
 
 
@@ -80,6 +81,10 @@ class QuantizationModifier(ScheduledModifier):
         exception. For compatibility with YAML serialization only.
     :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
         to the model fusing function
+    :param quantize_embeddings: if True, will perform QAT on torch.nn.Embedding layers
+        using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
+        quantize embedding weights. Default is True. Models without embedding layers
+        will be unaffected
     """
 
     def __init__(
@@ -91,6 +96,7 @@ def __init__(
         freeze_bn_stats_epoch: Union[float, None] = None,
         end_epoch: float = -1,
         model_fuse_fn_kwargs: Dict[str, Any] = None,
+        quantize_embeddings: bool = True,
     ):
         if torch_quantization is None or torch_intrinsic is None:
             raise RuntimeError(
@@ -112,6 +118,7 @@ def __init__(
         self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
         self._disable_quantization_observer_epoch = disable_quantization_observer_epoch
         self._freeze_bn_stats_epoch = freeze_bn_stats_epoch
+        self._quantize_embeddings = quantize_embeddings
 
         self._modules_to_quantize = None
         self._qat_enabled = False
@@ -140,7 +147,7 @@ def submodules(self) -> Union[List[str], None]:
     def submodules(self, value: Union[List[str], None]):
         """
         :params value: List of submodule names to perform QAT on. Set None to quantize
-        entire model
+            entire model
         """
         self._submodules = value
         if isinstance(self._submodules, list):
@@ -151,8 +158,8 @@ def submodules(self, value: Union[List[str], None]):
     def model_fuse_fn_name(self) -> Union[str, None]:
         """
         :return: Name of model function to fuse the model in place prior
-        to performing QAT. None to uses the default function
-        `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
+            to performing QAT. None to use the default function
+            `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
         """
         return self._model_fuse_fn_name
 
@@ -160,9 +167,9 @@ def model_fuse_fn_name(self) -> Union[str, None]:
     def model_fuse_fn_name(self, value: Union[str, None]):
         """
         :params value: Name of model function to fuse the model in place prior
-        to performing QAT. Set None to use the default function
-        `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
-        to skip module fusing.
+            to performing QAT. Set None to use the default function
+            `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
+            to skip module fusing.
         """
         self._model_fuse_fn_name = value
         if (
@@ -176,8 +183,8 @@ def model_fuse_fn_name(self, value: Union[str, None]):
     def disable_quantization_observer_epoch(self) -> Union[float, None]:
         """
         :return: Epoch to disable updates to the module's
-        quantization observers. After this point, quantized weights and zero points will
-        not be updated. When None, observers never disabled during QAT
+            quantization observers. After this point, quantized weights and zero points
+            will not be updated. When None, observers never disabled during QAT
         """
         return self._disable_quantization_observer_epoch
 
@@ -185,8 +192,8 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]:
     def disable_quantization_observer_epoch(self, value: Union[float, None]):
         """
         :params value: Epoch to disable updates to the module's
-        quantization observers. After this point, quantized weights and zero points will
-        not be updated. Set None to not disable observers during QAT
+            quantization observers. After this point, quantized weights and zero points
+            will not be updated. Set None to not disable observers during QAT
         """
         self._disable_quantization_observer_epoch = value
         self._validate_params()
 
@@ -195,7 +202,7 @@ def disable_quantization_observer_epoch(self, value: Union[float, None]):
     def freeze_bn_stats_epoch(self) -> Union[float, None]:
         """
         :return: Epoch to stop the tracking of batch norm stats. When
-        None, batch norm stats are track for all of training
+            None, batch norm stats are tracked for all of training
         """
         return self._freeze_bn_stats_epoch
 
@@ -203,11 +210,29 @@ def freeze_bn_stats_epoch(self) -> Union[float, None]:
     def freeze_bn_stats_epoch(self, value: Union[float, None]):
         """
         :params value: Epoch to stop the tracking of batch norm stats. Set
-        None to not stop tracking batch norm stats during QAT
+            None to not stop tracking batch norm stats during QAT
         """
         self._freeze_bn_stats_epoch = value
         self._validate_params()
 
+    @ModifierProp()
+    def quantize_embeddings(self) -> bool:
+        """
+        :return: if True, will perform QAT on torch.nn.Embedding layers
+            using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
+            quantize embedding weights
+        """
+        return self._quantize_embeddings
+
+    @quantize_embeddings.setter
+    def quantize_embeddings(self, value: bool):
+        """
+        :params value: if True, will perform QAT on torch.nn.Embedding layers
+            using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
+            quantize embedding weights
+        """
+        self._quantize_embeddings = value
+
     def initialize(
         self,
         module: Module,
@@ -350,6 +375,8 @@ def _enable_module_qat(self, module: Module):
             add_quant_dequant(quant_module)
             # set model to QAT mode
             torch_quantization.prepare_qat(quant_module, inplace=True)
+            if self._quantize_embeddings:
+                prepare_embeddings_qat(quant_module)
         self._qat_enabled = True
 
     def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
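Usage sketch (illustrative only, not part of the patch): the new flag can also be set when constructing the modifier directly in Python. The import path matches the module edited above; the remaining constructor arguments are assumed to keep their defaults.

    from sparseml.pytorch.optim.modifier_quantization import QuantizationModifier

    # quantize_embeddings defaults to True; pass False to leave Embedding layers alone
    modifier = QuantizationModifier(start_epoch=0.0, quantize_embeddings=False)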
""" self._model_fuse_fn_name = value if ( @@ -176,8 +183,8 @@ def model_fuse_fn_name(self, value: Union[str, None]): def disable_quantization_observer_epoch(self) -> Union[float, None]: """ :return: Epoch to disable updates to the module's - quantization observers. After this point, quantized weights and zero points will - not be updated. When None, observers never disabled during QAT + quantization observers. After this point, quantized weights and zero points + will not be updated. When None, observers never disabled during QAT """ return self._disable_quantization_observer_epoch @@ -185,8 +192,8 @@ def disable_quantization_observer_epoch(self) -> Union[float, None]: def disable_quantization_observer_epoch(self, value: Union[float, None]): """ :params value: Epoch to disable updates to the module's - quantization observers. After this point, quantized weights and zero points will - not be updated. Set None to not disable observers during QAT + quantization observers. After this point, quantized weights and zero points + will not be updated. Set None to not disable observers during QAT """ self._disable_quantization_observer_epoch = value self._validate_params() @@ -195,7 +202,7 @@ def disable_quantization_observer_epoch(self, value: Union[float, None]): def freeze_bn_stats_epoch(self) -> Union[float, None]: """ :return: Epoch to stop the tracking of batch norm stats. When - None, batch norm stats are track for all of training + None, batch norm stats are track for all of training """ return self._freeze_bn_stats_epoch @@ -203,11 +210,29 @@ def freeze_bn_stats_epoch(self) -> Union[float, None]: def freeze_bn_stats_epoch(self, value: Union[float, None]): """ :params value: Epoch to stop the tracking of batch norm stats. Set - None to not stop tracking batch norm stats during QAT + None to not stop tracking batch norm stats during QAT """ self._freeze_bn_stats_epoch = value self._validate_params() + @ModifierProp() + def quantize_embeddings(self) -> bool: + """ + :return: if True, will perform QAT on torch.nn.Embedding layers + using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake + quantize embedding weights + """ + return self._freeze_bn_stats_epoch + + @quantize_embeddings.setter + def quantize_embeddings(self, value: bool): + """ + :params value: if True, will perform QAT on torch.nn.Embedding layers + using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake + quantize embedding weights + """ + self._quantize_embeddings = value + def initialize( self, module: Module, @@ -350,6 +375,8 @@ def _enable_module_qat(self, module: Module): add_quant_dequant(quant_module) # set model to QAT mode torch_quantization.prepare_qat(quant_module, inplace=True) + if self._quantize_embeddings: + prepare_embeddings_qat(quant_module) self._qat_enabled = True def _disable_quantization_observer_update_ready(self, epoch: float) -> bool: diff --git a/src/sparseml/pytorch/utils/quantization/helpers.py b/src/sparseml/pytorch/utils/quantization/helpers.py index f105a37682a..f28711879d3 100644 --- a/src/sparseml/pytorch/utils/quantization/helpers.py +++ b/src/sparseml/pytorch/utils/quantization/helpers.py @@ -20,7 +20,7 @@ from typing import Any, Callable, List, Union import torch -from torch.nn import BatchNorm2d, Conv2d, Module, ReLU +from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU try: @@ -40,6 +40,7 @@ "add_quant_dequant", "get_qat_qconfig", "fuse_module_conv_bn_relus", + "prepare_embeddings_qat", ] @@ -318,11 +319,15 @@ def add_quant_dequant(module): def 
diff --git a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
index 78ee6fa2d9a..fa463a96469 100644
--- a/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
+++ b/src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -909,6 +909,100 @@ def _convert_quantizable_ops(model: ModelProto):
     )
 
 
+def _quantize_qat_embedding(model: ModelProto):
+    """
+    A pass for quantizing qat embeddings
+
+    Starting with:
+    |   INPUT    QuantizeLinear (with constant embedding)
+    |     |            |
+    |     |      DequantizeLinear
+    |     |            |
+    |         Gather
+    |           |
+    |     QuantizeLinear
+    |           |
+    |     DequantizeLinear
+    |           |
+    |         OUTPUT
+
+    Converts to:
+    |   INPUT
+    |     |
+    |   Gather(UINT8 data initializer)
+    |     |
+    |   DequantizeLinear
+    |     |
+    |   OUTPUT
+    """
+    graph = ONNXGraph(model)
+    gather_nodes = [node for node in model.graph.node if node.op_type == "Gather"]
+
+    converted_nodes = 0
+    for gather_node in gather_nodes:
+        # find input quant and dequant nodes
+        input_dequant_node = graph.get_node_single_parent(gather_node, 0)
+        if not input_dequant_node or input_dequant_node.op_type != "DequantizeLinear":
+            continue
+        input_quant_node = graph.get_node_single_parent(input_dequant_node, 0)
+        if not input_quant_node or input_quant_node.op_type != "QuantizeLinear":
+            continue
+        # find embedding weights, scale, and zero point
+        embedding_initializer = graph.get_init_by_name(input_quant_node.input[0])
+        scale_initializer = graph.get_init_by_name(input_quant_node.input[1])
+        zp_initializer = graph.get_init_by_name(input_quant_node.input[2])
+        if not embedding_initializer or not scale_initializer or not zp_initializer:
+            continue
+
+        # quantize embedding
+        embedding = numpy_helper.to_array(embedding_initializer)
+        scale = numpy_helper.to_array(scale_initializer)
+        zero_point = numpy_helper.to_array(zp_initializer)
+        embedding_quant = _quantize_array(embedding, scale, zero_point)
+        embedding_quant_initializer = numpy_helper.from_array(
+            embedding_quant, name=f"{embedding_initializer.name}_quant"
+        )
+
+        # update graph
+        model.graph.initializer.append(embedding_quant_initializer)
+        gather_node.input[0] = embedding_quant_initializer.name
+
+        # detect QDQ block on output
+        output_quant_node = graph.get_node_single_child(gather_node)
+        if output_quant_node and output_quant_node.op_type == "QuantizeLinear":
+            output_dequant_node = graph.get_node_single_child(output_quant_node)
+            qdq_output = (
+                output_dequant_node
+                and output_dequant_node.op_type == "DequantizeLinear"
+            )
+        else:
+            qdq_output = False
+
+        if qdq_output:
+            # delete unnecessary quantize and dequantize ops
+            delete_quant_node(model, input_quant_node, keep_params=False)
+            delete_quant_node(model, input_dequant_node, keep_params=False)
+            delete_quant_node(model, output_quant_node, keep_params=False)
+            # forward gather output to dequant input
+            output_dequant_node.input[0] = gather_node.output[0]
+
+        else:
+            # use input dequant to dequantize output
+            embedding_quant_output_id = f"{gather_node.output[0]}_quant"
+            input_dequant_node.input[0] = embedding_quant_output_id
+            input_dequant_node.output[0] = gather_node.output[0]
+            gather_node.output[0] = embedding_quant_output_id
+
+            delete_quant_node(model, input_quant_node, keep_params=False)
+
+        graph.update()
+        converted_nodes += 1
+
+    graph.delete_unused_initializers()
+
+    if converted_nodes > 0:
+        _LOGGER.info(f"Converted {converted_nodes} QAT embedding ops to UINT8")
+
+
 def _replace_input_id_model(model: ModelProto, old_id: str, new_id: str):
     for node in model.graph.node:
         for idx, inp in enumerate(node.input):
@@ -996,6 +1090,7 @@ def quantize_torch_qat_export(
     _convert_quantizable_matmul(model)
     _convert_quantizable_matmul_and_add(model)
     _convert_quantizable_ops(model)
+    _quantize_qat_embedding(model)
     quantize_resnet_identity_add_inputs(model)
     quantized_residual_add_optim(model)
     _remove_duplicate_quantize_ops(model)
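Illustration (not part of the patch): _quantize_array is an existing helper whose implementation is not shown in this diff; the array stored as the new Gather data initializer is presumably the standard affine UINT8 quantization of the embedding weights, roughly:

    import numpy as np

    def quantize_to_uint8(data, scale, zero_point):
        # round(x / scale) + zero_point, clipped to the UINT8 range
        quantized = np.round(data / scale) + zero_point
        return np.clip(quantized, 0, 255).astype(np.uint8)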
diff --git a/tests/sparseml/pytorch/optim/test_modifier_quantization.py b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
index 790fd13990b..a2c0da55227 100644
--- a/tests/sparseml/pytorch/optim/test_modifier_quantization.py
+++ b/tests/sparseml/pytorch/optim/test_modifier_quantization.py
@@ -166,20 +166,16 @@ def test_quantization_modifier_yaml():
     model_fuse_fn_name = "fuse_module"
     disable_quantization_observer_epoch = 2.0
     freeze_bn_stats_epoch = 3.0
-    yaml_str = """
+    quantize_embeddings = False
+    yaml_str = f"""
     !QuantizationModifier
         start_epoch: {start_epoch}
         submodules: {submodules}
         model_fuse_fn_name: {model_fuse_fn_name}
         disable_quantization_observer_epoch: {disable_quantization_observer_epoch}
         freeze_bn_stats_epoch: {freeze_bn_stats_epoch}
-    """.format(
-        start_epoch=start_epoch,
-        submodules=submodules,
-        model_fuse_fn_name=model_fuse_fn_name,
-        disable_quantization_observer_epoch=disable_quantization_observer_epoch,
-        freeze_bn_stats_epoch=freeze_bn_stats_epoch,
-    )
+        quantize_embeddings: {quantize_embeddings}
+    """
     yaml_modifier = QuantizationModifier.load_obj(
         yaml_str
     )  # type: QuantizationModifier
@@ -192,6 +188,7 @@ def test_quantization_modifier_yaml():
         model_fuse_fn_name=model_fuse_fn_name,
         disable_quantization_observer_epoch=disable_quantization_observer_epoch,
         freeze_bn_stats_epoch=freeze_bn_stats_epoch,
+        quantize_embeddings=quantize_embeddings,
     )
 
     assert isinstance(yaml_modifier, QuantizationModifier)
@@ -220,3 +217,8 @@ def test_quantization_modifier_yaml():
         == serialized_modifier.freeze_bn_stats_epoch
         == obj_modifier.freeze_bn_stats_epoch
     )
+    assert (
+        yaml_modifier.quantize_embeddings
+        == serialized_modifier.quantize_embeddings
+        == obj_modifier.quantize_embeddings
+    )
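Round-trip sketch (illustrative only, not part of the patch): a minimal recipe load exercising the new field, assuming fields omitted from the YAML fall back to their constructor defaults as in the test above.

    from sparseml.pytorch.optim.modifier_quantization import QuantizationModifier

    recipe = """
    !QuantizationModifier
        start_epoch: 0.0
        quantize_embeddings: False
    """
    modifier = QuantizationModifier.load_obj(recipe)
    assert not modifier.quantize_embeddings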
diff --git a/tests/sparseml/pytorch/utils/quantization/test_helpers.py b/tests/sparseml/pytorch/utils/quantization/test_helpers.py
index d8947a321ba..1a1f55b17ef 100644
--- a/tests/sparseml/pytorch/utils/quantization/test_helpers.py
+++ b/tests/sparseml/pytorch/utils/quantization/test_helpers.py
@@ -25,6 +25,7 @@
     configure_module_qat_wrappers,
     fuse_module_conv_bn_relus,
     get_qat_qconfig,
+    prepare_embeddings_qat,
 )
 
 
@@ -222,3 +223,19 @@ def test_fuse_module_conv_bn_relus(model_lambda, conv_bn_relus, conv_bns):
     fuse_module_conv_bn_relus(module, inplace=True)
     assert _count_submodule_instances(module, conv_bn_relu_class) == conv_bn_relus
     assert _count_submodule_instances(module, conv_bn_class) == conv_bns
+
+
+def test_prepare_embeddings_qat():
+    module = _ModuleWrapper(torch.nn.Embedding(10, 10))
+
+    # check that fake quant observer is properly added
+    assert not hasattr(module.module, "weight_fake_quant")
+    prepare_embeddings_qat(module)
+    assert hasattr(module.module, "weight_fake_quant")
+
+    # check that the observer is updated on embedding forward pass
+    observer = module.module.weight_fake_quant
+    orig_range_min = observer.activation_post_process.min_val.item()
+    module(torch.arange(10))
+    observed_range_min = observer.activation_post_process.min_val.item()
+    assert orig_range_min != observed_range_min