
activation ordering #2316

Open
wants to merge 15 commits into base: main
1 change: 1 addition & 0 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -50,6 +50,7 @@ class GPTQModifier(Modifier):
- LayerCompressor.revert_layer_wrappers()


:param actorder: Whether to apply activation reordering, which sorts weight columns
    by decreasing Hessian diagonal before quantization
:param sequential_update: Whether or not to update weights sequentially by layer;
    True saves on GPU memory
:param targets: list of layer names to compress during GPTQ, or '__ALL__'
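For context on what the new flag controls: activation ordering (GPTQ's "act_order") quantizes weight columns in order of decreasing Hessian diagonal, so the columns that carry the most activation energy are quantized first, while quantization error is still being compensated on the remaining columns. Below is a minimal standalone sketch of the permutation bookkeeping, in plain PyTorch; it is illustrative only and not the modifier's API, assuming `W` is a 2-D weight matrix and `H` the input Hessian used by the wrapper.

```python
import torch


def reorder_by_activation(W: torch.Tensor, H: torch.Tensor):
    """Permute weight columns (and H rows/columns) by decreasing Hessian diagonal."""
    perm = torch.argsort(torch.diag(H), descending=True)
    invperm = torch.argsort(perm)  # maps permuted positions back to original columns
    return W[:, perm], H[perm][:, perm], perm, invperm


# toy example: 8 input columns whose Hessian diagonal is 1..8
torch.manual_seed(0)
W = torch.randn(4, 8)
H = torch.eye(8) * torch.arange(1.0, 9.0)
W_perm, H_perm, perm, invperm = reorder_by_activation(W, H)
assert torch.equal(W_perm[:, invperm], W)  # invperm restores the original column order
```

The diff below applies the same `perm`/`invperm` pair inside `fasterprune` and undoes the permutation (`W = W[:, invperm]`) once quantization has finished.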
46 changes: 44 additions & 2 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -88,6 +88,7 @@ def fasterprune(
Run pruning and quantization (if applicable) on the layer up to the target
sparsity value.

:param actorder: Flag to apply activation reordering
:param blocksize: Number of columns to compress in one pass
:param percdamp: Amount of dampening to apply to H, as a fraction of the
diagonal norm
@@ -127,6 +128,9 @@ def fasterprune(
        self.H = torch.linalg.cholesky(self.H, upper=True)
        Hinv = self.H

        actorder = False
        invperm = None

        # See section 3.4 of https://arxiv.org/abs/2203.07259
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
@@ -144,6 +148,7 @@ def fasterprune(
            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                q = w.clone()

                if hasattr(self.layer, "weight_fake_quant"):
@@ -156,18 +161,42 @@ def fasterprune(
                    else:
                        q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                    q = torch.dequantize(q)

                elif hasattr(self.layer, "quantization_scheme"):
                    quant_scheme = self.layer.quantization_scheme
                    if quant_scheme.weights is not None:
                        actorder = quant_scheme.weights.actorder

                        if actorder:
                            # sort columns by decreasing Hessian diagonal
                            # (activation importance)
                            perm = torch.argsort(torch.diag(self.H), descending=True)
                            W = W[:, perm]
                            self.H = self.H[perm][:, perm]
                            invperm = torch.argsort(perm)

                        scale = self.layer.weight_scale
                        zero_point = self.layer.weight_zero_point

                        group_size = quant_scheme.weights.group_size
                        if group_size is None or group_size == -1:
                            group_size = self.layer.weight.shape[1]

                        if actorder:
                            indices = torch.arange(self.columns, device=invperm.device)
                            g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
                            g_idx = g_idx[invperm]
                            self.layer.weight_g_idx.data = g_idx
                        else:
                            indices = torch.arange(
                                self.columns, device=W.device, dtype=torch.int32
                            )
                            g_idx = indices // group_size

                        from compressed_tensors.quantization import QuantizationStrategy
                        from compressed_tensors.quantization.lifecycle.forward import (
                            fake_quantize,
                        )

                        strategy = quant_scheme.weights.strategy

                        if strategy == QuantizationStrategy.TENSOR:
                            q = fake_quantize(
                                q,
@@ -189,11 +218,21 @@ def fasterprune(
                            input_dim_group = (
                                column_idx // quant_scheme.weights.group_size
                            )

                            # Since we're only applying quantization to a slice, this
                            # ends up being a channelwise application
                            altered_qargs = copy(quant_scheme.weights)
                            altered_qargs.strategy = QuantizationStrategy.CHANNEL

                            # apply g_idx
                            if g_idx is not None:
                                # scale and zero point are already grouped by group_size;
                                # extract the first index of each group
                                indices_to_extract = torch.arange(
                                    0, g_idx.shape[0], group_size
                                )
                                scale = scale[:, g_idx[indices_to_extract]]
                                zero_point = zero_point[:, g_idx[indices_to_extract]]

                            q = fake_quantize(
                                q,
                                scale[:, input_dim_group],
@@ -224,6 +263,9 @@ def fasterprune(
        _LOGGER.info("time %.2f" % (time.time() - tick))
        _LOGGER.info("error %.2f" % torch.sum(Losses).item())

        if actorder:
            # restore the original column order before writing weights back
            W = W[:, invperm]

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.reshape(final_shape).to(final_dtype)
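One piece of bookkeeping worth highlighting: with activation ordering, the quantization groups of `group_size` consecutive columns are formed over the permuted column order, so the group that owns a given original column is no longer simply `column // group_size`; that mapping is what gets persisted as `weight_g_idx` above. For comparison, reference act-order implementations (GPTQ-for-LLaMa, AutoGPTQ) apply the permutation once before the per-column loop and derive the per-column group index from `invperm`. The sketch below follows that convention and is illustrative only; the names and the scale layout are assumptions, not this PR's exact code.

```python
import torch


def actorder_group_index(perm: torch.Tensor, group_size: int) -> torch.Tensor:
    """g_idx[k] = quantization group of ORIGINAL column k, where groups of
    `group_size` consecutive columns are formed over the PERMUTED order."""
    invperm = torch.argsort(perm)  # permuted position of each original column
    return (invperm // group_size).to(torch.int32)


# toy check: 8 columns, groups of 4; perm lists columns by decreasing importance
perm = torch.tensor([3, 7, 1, 5, 0, 2, 4, 6])
g_idx = actorder_group_index(perm, group_size=4)
assert g_idx.tolist() == [1, 0, 1, 0, 1, 0, 1, 0]  # 3,7,1,5 -> group 0; 0,2,4,6 -> group 1

# at dequantization time an original column looks up its group's parameters:
scale = torch.rand(16, 2)  # assumed layout: one column of scales per quantization group
col = 5
col_scale = scale[:, g_idx[col]]  # original column 5 uses group 0's scales
```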
2 changes: 2 additions & 0 deletions src/sparseml/modifiers/utils/layer_compressor.py
@@ -134,6 +134,8 @@ def revert_layer_wrappers(self):
    def compress(self):
        """
        Apply compression to each wrapped submodule in the layer

        :param actorder: flag to apply activation reordering
        """

        @torch.no_grad()