Skip to content

Commit

Permalink
[GPTQ UX] Add string aliasing support for scheme (#2287)
Browse files Browse the repository at this point in the history
* Update GHA file to install compressed-tensors from source

* Missed commit (#2300)

* Remove src from import

* Style

* Full Scheme support

* Add a small test for accepting full scheme

* Add support for string aliasing

* Style
  • Loading branch information
rahul-tuli authored May 24, 2024
1 parent 7bb3db3 commit 2cee0b5
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/sparseml/modifiers/quantization/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@

from pydantic import Field

from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationScheme,
is_preset_scheme,
)
from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.model.base import ModifiableModel
Expand Down Expand Up @@ -71,7 +75,9 @@ class GPTQModifier(Modifier):
:param scheme: [Used, if a quantization modifier is not specified], the quantization
scheme to apply to the model, this is a dictionary that supports all keys from
QuantizationScheme except targets, which will be set to the targets parameter
set at the modifier level.
set at the modifier level. Can also be set to a dictionary of the format
`preset_scheme_name: targets` for example: `W8A8: ['Linear']` for weight 8 bit
and activation 8 bit quantization on the Linear layers.
"""

sequential_update: Optional[bool] = False
Expand Down Expand Up @@ -163,6 +169,19 @@ def _build_quant_modifier(self, framework):

if self.scheme is not None:
# takes precedence over config_groups

if any(is_preset_scheme(key) for key in self.scheme.keys()):
config_groups = QuantizationConfig(
config_groups=self.scheme
).config_groups
quant_args["config_groups"] = config_groups
else:
targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
)
quant_args["config_groups"] = {"config_group_0": config_group}

targets = self.targets or ["Linear"]
config_group = QuantizationScheme.model_validate(
{"targets": targets, **self.scheme}
Expand Down

0 comments on commit 2cee0b5

Please sign in to comment.