diff --git a/src/sparseml/modifiers/quantization/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py
index 004fce2ee7a..c3254ab31ca 100644
--- a/src/sparseml/modifiers/quantization/gptq/base.py
+++ b/src/sparseml/modifiers/quantization/gptq/base.py
@@ -80,6 +80,7 @@ class GPTQModifier(Modifier):
         and activation 8 bit quantization on the Linear layers.
     """
 
+    actorder: bool = False
     sequential_update: Optional[bool] = False
     targets: Union[str, List[str], None] = None
     block_size: int = 128
diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py
index e9e3f715625..66898688f12 100644
--- a/src/sparseml/modifiers/quantization/gptq/pytorch.py
+++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -156,7 +156,7 @@ def apply_compression(
             layer_compressor.pre_compress()
             _LOGGER.info(f"Calibrating {layer_compressor.name}...")
             run_calibration_forward(self.model, dataloader, mask_padding=True)
-            layer_compressor.compress()
+            layer_compressor.compress(self.actorder)
             layer_compressor.post_compress()
             layer_compressor.revert_layer_wrappers()
             torch.cuda.empty_cache()
diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 73321c0d0aa..f7b54f56038 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -81,6 +81,7 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
 
     def fasterprune(
         self,
+        actorder: bool = False,
         blocksize: int = 128,
         percdamp: float = 0.01,
     ):
@@ -109,6 +110,12 @@ def fasterprune(
         self.H[dead, dead] = 1
         W[:, dead] = 0
 
+        if actorder:
+            perm = torch.argsort(torch.diag(self.H), descending=True)
+            W = W[:, perm]
+            self.H = self.H[perm][:, perm]
+            invperm = torch.argsort(perm)
+
         Losses = torch.zeros(self.rows, device=self.dev)
 
         damp = percdamp * torch.mean(torch.diag(self.H))
@@ -153,6 +160,7 @@ def fasterprune(
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
+                q = w.clone()
 
                 if sparsity >= SPARSITY_THRESHOLD:
                     q[mask1[:, i]] = 0
@@ -227,6 +235,9 @@ def fasterprune(
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
 
+        if actorder:
+            W = W[:, invperm]
+
         if isinstance(self.layer, transformers.Conv1D):
             W = W.t()
         W = W.reshape(final_shape).to(final_dtype)
diff --git a/src/sparseml/modifiers/utils/layer_compressor.py b/src/sparseml/modifiers/utils/layer_compressor.py
index e5a36f77278..5090539d84e 100644
--- a/src/sparseml/modifiers/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/utils/layer_compressor.py
@@ -131,7 +131,7 @@ def revert_layer_wrappers(self):
                 module_wrapper.free()
         self.modules = None
 
-    def compress(self):
+    def compress(self, actorder: bool = False):
         """
         Apply compression to each wrapped submodule in the layer
        """
@@ -141,7 +141,7 @@ def prune(module):
             if isinstance(module, self.module_compressor_class):
                 full_name = self._get_full_submodule_name(module.name)
                 _LOGGER.info(f"Compressing {full_name}...")
-                module.fasterprune(**self.args)
+                module.fasterprune(actorder=actorder, **self.args)
 
         self.layer.apply(prune)
 
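
A minimal usage sketch (not part of the diff) of how the new flag could be enabled from Python. It assumes GPTQModifier accepts its declared fields as keyword arguments, as the field declarations in base.py suggest; the "Linear" target value is illustrative only.

from sparseml.modifiers.quantization.gptq.base import GPTQModifier

# Hypothetical configuration sketch: actorder is opt-in and defaults to
# False, so existing recipes keep their current behavior.
modifier = GPTQModifier(
    actorder=True,        # new flag: reorder columns by descending Hessian diagonal
    sequential_update=False,
    block_size=128,
    targets="Linear",     # illustrative target; adjust for the model being compressed
)

When actorder is set, fasterprune() permutes the weight columns by descending Hessian diagonal before quantizing and restores the original column order via invperm afterwards.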