Commit 2691f85

activation ordering

horheynm committed Jun 4, 2024
1 parent 5caa557 commit 2691f85
Showing 4 changed files with 15 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -80,6 +80,7 @@ class GPTQModifier(Modifier):
     and activation 8 bit quantization on the Linear layers.
     """
 
+    actorder: bool = False
     sequential_update: Optional[bool] = False
     targets: Union[str, List[str], None] = None
     block_size: int = 128
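For context, a minimal usage sketch (not part of the commit) of how the new field could be set when constructing the modifier directly in Python. The import path is assumed from the file location above, and every argument shown mirrors a field visible in this hunk.

# Hypothetical sketch: enabling activation ordering on the modifier.
# Import path assumed from src/sparseml/modifiers/quantization/gptq/base.py.
from sparseml.modifiers.quantization.gptq.base import GPTQModifier

modifier = GPTQModifier(
    targets="Linear",   # which layers to quantize (str or list of str, per the field type above)
    block_size=128,     # GPTQ block size, default shown above
    actorder=True,      # new flag: reorder weight columns by Hessian diagonal before quantizing
)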
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -156,7 +156,7 @@ def apply_compression(
             layer_compressor.pre_compress()
             _LOGGER.info(f"Calibrating {layer_compressor.name}...")
             run_calibration_forward(self.model, dataloader, mask_padding=True)
-            layer_compressor.compress()
+            layer_compressor.compress(self.actorder)
             layer_compressor.post_compress()
             layer_compressor.revert_layer_wrappers()
             torch.cuda.empty_cache()
11 changes: 11 additions & 0 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -81,6 +81,7 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
 
     def fasterprune(
         self,
+        actorder: bool = False,
         blocksize: int = 128,
         percdamp: float = 0.01,
     ):
@@ -109,6 +110,12 @@ def fasterprune(
         self.H[dead, dead] = 1
         W[:, dead] = 0
 
+        if actorder:
+            perm = torch.argsort(torch.diag(H), descending=True)
+            W = W[:, perm]
+            H = H[perm][:, perm]
+            invperm = torch.argsort(perm)
+
         Losses = torch.zeros(self.rows, device=self.dev)
 
         damp = percdamp * torch.mean(torch.diag(self.H))
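A small, self-contained sketch (toy sizes, plain PyTorch, not from the commit) of what the ordering step above computes: columns whose Hessian diagonal entry is largest, i.e. the weights hit by the most activation energy, are moved to the front and quantized first.

# Toy illustration of the activation-ordering permutation above.
import torch

H = torch.diag(torch.tensor([0.5, 3.0, 1.5]))         # stand-in Hessian; diag ~ per-column activation energy
W = torch.tensor([[10., 20., 30.]])                    # one output row, three weight columns

perm = torch.argsort(torch.diag(H), descending=True)   # tensor([1, 2, 0]): largest-energy column first
W_ordered = W[:, perm]                                  # tensor([[20., 30., 10.]])
H_ordered = H[perm][:, perm]                            # rows and columns permuted consistently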
@@ -153,6 +160,7 @@ def fasterprune(
                 for i in range(count):
                     w = W1[:, i]
                     d = Hinv1[i, i]
+
                     q = w.clone()
                     if sparsity >= SPARSITY_THRESHOLD:
                         q[mask1[:, i]] = 0
Expand Down Expand Up @@ -227,6 +235,9 @@ def fasterprune(
_LOGGER.info("time %.2f" % (time.time() - tick))
_LOGGER.info("error %.2f" % torch.sum(Losses).item())

if actorder:
W = W[:, invperm]

if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.reshape(final_shape).to(final_dtype)
Expand Down
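The inverse permutation applied at the end is the standard argsort-of-argsort trick; a quick check (toy values, not from the commit) that it restores the original column order once quantization has finished.

# Toy check of the invperm logic above.
import torch

perm = torch.tensor([1, 2, 0])                 # same ordering as the example above
invperm = torch.argsort(perm)                  # tensor([2, 0, 1])
W = torch.tensor([[10., 20., 30.]])

assert torch.equal(W[:, perm][:, invperm], W)  # permute, then undo: columns back in place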
4 changes: 2 additions & 2 deletions src/sparseml/modifiers/utils/layer_compressor.py
@@ -131,7 +131,7 @@ def revert_layer_wrappers(self):
             module_wrapper.free()
         self.modules = None
 
-    def compress(self):
+    def compress(self, actorder: bool = False):
         """
         Apply compression to each wrapped submodule in the layer
         """
@@ -141,7 +141,7 @@ def prune(module):
             if isinstance(module, self.module_compressor_class):
                 full_name = self._get_full_submodule_name(module.name)
                 _LOGGER.info(f"Compressing {full_name}...")
-                module.fasterprune(**self.args)
+                module.fasterprune(actorder=actorder, **self.args)
 
         self.layer.apply(prune)
 
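Taken together with the pytorch.py change above, the flag is simply threaded from the modifier down to every wrapped module; a hedged call sketch, assuming a LayerCompressor instance named layer_compressor is already set up as in apply_compression.

# Hypothetical call sketch: compress() now accepts the flag and forwards it,
# via layer.apply(prune), to each wrapped module's fasterprune().
layer_compressor.compress(actorder=True)   # pytorch.py passes it positionally as compress(self.actorder)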
