QAT and quant postprocessing for torch.nn.Embedding (#374)
* QAT and quant postprocessing for torch.nn.Embedding

* cleanup

* residual optim and logging fixes

* response to comments
bfineran committed Sep 8, 2021
1 parent a1fda05 commit 47d9472
Showing 6 changed files with 222 additions and 24 deletions.
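For orientation, a minimal sketch of enabling the new behavior from the modifier API. This is a sketch under assumptions: the import path follows the file location below, `start_epoch` and the epoch values come from the scheduled-modifier base API and are purely illustrative; only `quantize_embeddings` is introduced by this commit.

from sparseml.pytorch.optim import QuantizationModifier

# illustrative epochs; quantize_embeddings=True (the new default) fake-quantizes
# torch.nn.Embedding weights once QAT is enabled
modifier = QuantizationModifier(
    start_epoch=0.0,
    disable_quantization_observer_epoch=4.0,
    freeze_bn_stats_epoch=3.0,
    quantize_embeddings=True,
)
assert modifier.quantize_embeddings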
6 changes: 5 additions & 1 deletion src/sparseml/onnx/utils/graph_optimizer.py
@@ -165,8 +165,12 @@ def quantize_resnet_identity_add_inputs(quantized_model: onnx.ModelProto) -> bool
quantize_node = get_quantize_parent_for_dequantize_node(
quantized_model, dequantize_node
)

# check that the quantize block takes input from the same relu
if quantize_node.input[0] != other_input_node.output[0]:
if (
quantize_node is None
or quantize_node.input[0] != other_input_node.output[0]
):
continue

# create de-quantize node for identity
51 changes: 39 additions & 12 deletions src/sparseml/pytorch/optim/modifier_quantization.py
@@ -41,6 +41,7 @@
configure_module_qat_wrappers,
fuse_module_conv_bn_relus,
get_qat_qconfig,
prepare_embeddings_qat,
)


@@ -80,6 +81,10 @@ class QuantizationModifier(ScheduledModifier):
exception. For compatibility with YAML serialization only.
:param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
to the model fusing function
:param quantize_embeddings: if True, will perform QAT on torch.nn.Embedding layers
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
quantize embedding weights. Default is True. Models without embedding layers
will be unaffected
"""

def __init__(
@@ -91,6 +96,7 @@ def __init__(
freeze_bn_stats_epoch: Union[float, None] = None,
end_epoch: float = -1,
model_fuse_fn_kwargs: Dict[str, Any] = None,
quantize_embeddings: bool = True,
):
if torch_quantization is None or torch_intrinsic is None:
raise RuntimeError(
@@ -112,6 +118,7 @@ def __init__(
self._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
self._disable_quantization_observer_epoch = disable_quantization_observer_epoch
self._freeze_bn_stats_epoch = freeze_bn_stats_epoch
self._quantize_embeddings = quantize_embeddings

self._modules_to_quantize = None
self._qat_enabled = False
@@ -140,7 +147,7 @@ def submodules(self) -> Union[List[str], None]:
def submodules(self, value: Union[List[str], None]):
"""
:params value: List of submodule names to perform QAT on. Set None to quantize
entire model
entire model
"""
self._submodules = value
if isinstance(self._submodules, list):
@@ -151,18 +158,18 @@ def submodules(self, value: Union[List[str], None]):
def model_fuse_fn_name(self) -> Union[str, None]:
"""
:return: Name of model function to fuse the model in place prior
to performing QAT. None to use the default function
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
to performing QAT. None to use the default function
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
"""
return self._model_fuse_fn_name

@model_fuse_fn_name.setter
def model_fuse_fn_name(self, value: Union[str, None]):
"""
:params value: Name of model function to fuse the model in place prior
to performing QAT. Set None to use the default function
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
to skip module fusing.
to performing QAT. Set None to use the default function
`sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
to skip module fusing.
"""
self._model_fuse_fn_name = value
if (
@@ -176,17 +183,17 @@ def model_fuse_fn_name(self, value: Union[str, None]):
def disable_quantization_observer_epoch(self) -> Union[float, None]:
"""
:return: Epoch to disable updates to the module's
quantization observers. After this point, quantized weights and zero points will
not be updated. When None, observers are never disabled during QAT
quantization observers. After this point, quantized weights and zero points
will not be updated. When None, observers are never disabled during QAT
"""
return self._disable_quantization_observer_epoch

@disable_quantization_observer_epoch.setter
def disable_quantization_observer_epoch(self, value: Union[float, None]):
"""
:params value: Epoch to disable updates to the module's
quantization observers. After this point, quantized weights and zero points will
not be updated. Set None to not disable observers during QAT
quantization observers. After this point, quantized weights and zero points
will not be updated. Set None to not disable observers during QAT
"""
self._disable_quantization_observer_epoch = value
self._validate_params()
@@ -195,19 +202,37 @@ def disable_quantization_observer_epoch(self, value: Union[float, None]):
def freeze_bn_stats_epoch(self) -> Union[float, None]:
"""
:return: Epoch to stop the tracking of batch norm stats. When
None, batch norm stats are tracked for all of training
None, batch norm stats are tracked for all of training
"""
return self._freeze_bn_stats_epoch

@freeze_bn_stats_epoch.setter
def freeze_bn_stats_epoch(self, value: Union[float, None]):
"""
:params value: Epoch to stop the tracking of batch norm stats. Set
None to not stop tracking batch norm stats during QAT
None to not stop tracking batch norm stats during QAT
"""
self._freeze_bn_stats_epoch = value
self._validate_params()

@ModifierProp()
def quantize_embeddings(self) -> bool:
"""
:return: if True, will perform QAT on torch.nn.Embedding layers
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
quantize embedding weights
"""
return self._quantize_embeddings

@quantize_embeddings.setter
def quantize_embeddings(self, value: bool):
"""
:params value: if True, will perform QAT on torch.nn.Embedding layers
using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
quantize embedding weights
"""
self._quantize_embeddings = value

def initialize(
self,
module: Module,
@@ -350,6 +375,8 @@ def _enable_module_qat(self, module: Module):
add_quant_dequant(quant_module)
# set model to QAT mode
torch_quantization.prepare_qat(quant_module, inplace=True)
if self._quantize_embeddings:
prepare_embeddings_qat(quant_module)
self._qat_enabled = True

def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
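A minimal sketch of the QAT flow that `_enable_module_qat` now runs, reproduced with the helpers imported above. The toy model is illustrative, and the sketch assumes the 2021-era eager-mode `torch.quantization` API, in which `prepare_qat` leaves `torch.nn.Embedding` modules untouched:

import torch
from torch import quantization as torch_quantization
from sparseml.pytorch.utils.quantization import get_qat_qconfig, prepare_embeddings_qat

# toy model; prepare_qat wraps the Linear layer but skips the Embedding
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),
    torch.nn.Linear(64, 16),
)
model.qconfig = get_qat_qconfig()
torch_quantization.prepare_qat(model, inplace=True)
prepare_embeddings_qat(model)  # new step: fake-quantize the Embedding weights too

tokens = torch.randint(0, 1000, (4, 12))
out = model(tokens)  # embedding lookups now pass through weight_fake_quant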
59 changes: 56 additions & 3 deletions src/sparseml/pytorch/utils/quantization/helpers.py
@@ -20,7 +20,7 @@
from typing import Any, Callable, List, Union

import torch
from torch.nn import BatchNorm2d, Conv2d, Module, ReLU
from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU


try:
@@ -40,6 +40,7 @@
"add_quant_dequant",
"get_qat_qconfig",
"fuse_module_conv_bn_relus",
"prepare_embeddings_qat",
]


@@ -318,11 +319,15 @@ def add_quant_dequant(module):

def get_qat_qconfig(
symmetric_activations: bool = False,
symmetric_weights: bool = True,
) -> "torch.quantization.QConfig":
"""
:param symmetric_activations: if True, activations will have a symmetric
quantization range with zero point set to 128. Otherwise activations
UINT8 quantization range with zero point set to 128. Otherwise activations
will use asymmetric quantization with any zero point. Default is False
:param symmetric_weights: if True, weights will have a symmetric
INT8 quantization range with zero point set to 0. Otherwise weights
will use asymmetric quantization with any zero point. Default is True
:return: A QAT fake quantization config for symmetric weight quantization and
asymmetric activation quantization. The difference between this and
torch.quantization.default_qat_qconfig is that the activation observer
@@ -339,7 +344,17 @@ def get_qat_qconfig(
qscheme=activation_qscheme,
reduce_range=False,
)
weight_observer = torch_quantization.default_weight_fake_quant
weight_qscheme = (
torch.per_tensor_symmetric if symmetric_weights else torch.per_tensor_affine
)
weight_observer = torch_quantization.FakeQuantize.with_args(
observer=torch_quantization.MovingAverageMinMaxObserver,
quant_min=-128,
quant_max=127,
dtype=torch.qint8,
qscheme=weight_qscheme,
reduce_range=False,
)
return torch_quantization.QConfig(
activation=activation_observer,
weight=weight_observer,
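To illustrate the new `symmetric_weights` switch, a small sketch that builds the weight fake-quantize op from each config; the tensor and variable names are illustrative, and the import path follows the docstring in the modifier above:

import torch
from sparseml.pytorch.utils.quantization import get_qat_qconfig

sym_qconfig = get_qat_qconfig()                          # symmetric INT8 weights (default)
asym_qconfig = get_qat_qconfig(symmetric_weights=False)  # asymmetric, used for embeddings

weight = torch.randn(4, 8)
fake_quant = asym_qconfig.weight()   # instantiate the FakeQuantize module from the config
observed = fake_quant(weight)        # observe min/max, then fake-quantize: float in, float out
print(observed.shape)                # same shape; values are snapped to an INT8 grid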
@@ -423,6 +438,44 @@ def fuse_module_conv_bn_relus(
return module


def prepare_embeddings_qat(
module: Module,
qconfig: "torch.quantization.QConfig" = None,
):
"""
adds a fake quantize call to the weights of any Embedding modules in the given
module
:param module: module to run QAT for the embeddings of
:param qconfig: qconfig to generate the fake quantize ops from. Default uses INT8
asymmetric range
"""
if qconfig is None:
qconfig = get_qat_qconfig(symmetric_weights=False)
for submodule in module.modules():
if type(submodule) is Embedding:
_prepare_qat_embedding(submodule, qconfig)


def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConfig"):
embedding.weight_fake_quant = qconfig.weight()

def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.embedding(
input,
self.weight_fake_quant(self.weight),
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)

# bind qat forward to embedding
qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__)
setattr(embedding, "forward", qat_forward_bound)


def _set_submodule(root_module, sub_module_path, sub_module):
current_module = root_module
sub_module_path = sub_module_path.split(".")
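A short sketch of what `prepare_embeddings_qat` does to an `Embedding` in isolation, useful for checking the patched forward; the toy sizes and names are illustrative:

import torch
from sparseml.pytorch.utils.quantization import prepare_embeddings_qat

embedding = torch.nn.Embedding(100, 16)
model = torch.nn.Sequential(embedding)

prepare_embeddings_qat(model)  # attaches weight_fake_quant and rebinds forward
assert hasattr(embedding, "weight_fake_quant")

ids = torch.tensor([[1, 2, 3]])
qat_out = embedding(ids)  # runs the bound _qat_forward defined above
ref_out = torch.nn.functional.embedding(ids, embedding.weight)
print((qat_out - ref_out).abs().max())  # typically small but nonzero fake-quantization error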
95 changes: 95 additions & 0 deletions src/sparseml/pytorch/utils/quantization/quantize_qat_export.py
@@ -909,6 +909,100 @@ def _convert_quantizable_ops(model: ModelProto):
)


def _quantize_qat_embedding(model: ModelProto):
"""
A pass for quantizing QAT embeddings
Starting with:
| INPUT QuantizeLinear (with constant embedding)
| | |
| | DequantizeLinear
| | |
| Gather
| |
| QuantizeLinear
| |
| DequantizeLinear
| |
| OUTPUT
Converts to:
| INPUT
| |
| Gather(UINT8 data initializer)
| |
| DequantizeLinear
| |
| OUTPUT
"""
graph = ONNXGraph(model)
gather_nodes = [node for node in model.graph.node if node.op_type == "Gather"]

converted_nodes = 0
for gather_node in gather_nodes:
# find input quant and dequant nodes
input_dequant_node = graph.get_node_single_parent(gather_node, 0)
if not input_dequant_node or input_dequant_node.op_type != "DequantizeLinear":
continue
input_quant_node = graph.get_node_single_parent(input_dequant_node, 0)
if not input_quant_node or input_quant_node.op_type != "QuantizeLinear":
continue
# find embedding weights, scale, and zero point
embedding_initializer = graph.get_init_by_name(input_quant_node.input[0])
scale_initializer = graph.get_init_by_name(input_quant_node.input[1])
zp_initializer = graph.get_init_by_name(input_quant_node.input[2])
if not embedding_initializer or not scale_initializer or not zp_initializer:
continue

# quantize embedding
embedding = numpy_helper.to_array(embedding_initializer)
scale = numpy_helper.to_array(scale_initializer)
zero_point = numpy_helper.to_array(zp_initializer)
embedding_quant = _quantize_array(embedding, scale, zero_point)
embedding_quant_initializer = numpy_helper.from_array(
embedding_quant, name=f"{embedding_initializer.name}_quant"
)

# update graph
model.graph.initializer.append(embedding_quant_initializer)
gather_node.input[0] = embedding_quant_initializer.name

# detect QDQ block on output
output_quant_node = graph.get_node_single_child(gather_node)
if output_quant_node and output_quant_node.op_type == "QuantizeLinear":
output_dequant_node = graph.get_node_single_child(output_quant_node)
qdq_output = (
output_dequant_node
and output_dequant_node.op_type == "DequantizeLinear"
)
else:
qdq_output = False

if qdq_output:
# delete unnecessary quantize and dequantize ops
delete_quant_node(model, input_quant_node, keep_params=False)
delete_quant_node(model, input_dequant_node, keep_params=False)
delete_quant_node(model, output_quant_node, keep_params=False)
# forward gather output to dequant input
output_dequant_node.input[0] = gather_node.output[0]

else:
# use input dequant to dequantize output
embedding_quant_output_id = f"{gather_node.output[0]}_quant"
input_dequant_node.input[0] = embedding_quant_output_id
input_dequant_node.output[0] = gather_node.output[0]
gather_node.output[0] = embedding_quant_output_id

delete_quant_node(model, input_quant_node, keep_params=False)
graph.update()
converted_nodes += 1

graph.delete_unused_initializers()

if converted_nodes > 0:
_LOGGER.info(f"Converted {converted_nodes} QAT embedding ops to UINT8")


def _replace_input_id_model(model: ModelProto, old_id: str, new_id: str):
for node in model.graph.node:
for idx, inp in enumerate(node.input):
@@ -996,6 +1090,7 @@ def quantize_torch_qat_export(
_convert_quantizable_matmul(model)
_convert_quantizable_matmul_and_add(model)
_convert_quantizable_ops(model)
_quantize_qat_embedding(model)
quantize_resnet_identity_add_inputs(model)
quantized_residual_add_optim(model)
_remove_duplicate_quantize_ops(model)
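The `_quantize_qat_embedding` pass above relies on a `_quantize_array` helper that is not part of this diff. A hypothetical stand-in, `quantize_array_sketch`, shows the standard affine UINT8 quantization such a helper would need to perform on the embedding table; all names and values here are illustrative:

import numpy as np

def quantize_array_sketch(arr, scale, zero_point):
    # hypothetical stand-in for _quantize_array: affine quantization to UINT8
    quantized = np.round(arr / scale) + zero_point
    return np.clip(quantized, 0, 255).astype(np.uint8)

# toy embedding table with a single per-tensor scale / zero point
embedding = np.random.randn(10, 4).astype(np.float32)
scale = np.float32((embedding.max() - embedding.min()) / 255.0)
zero_point = np.uint8(np.round(-embedding.min() / scale))
print(quantize_array_sketch(embedding, scale, zero_point).dtype)  # uint8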