[GPTQ UX] Add scheme arg with QuantizationScheme support (#2286)
* Update GHA file to install compressed-tensors from source

* Missed commit (#2300)

* Remove src from import

* Style

* Full Scheme support

* Add a small test for accepting full scheme
rahul-tuli authored May 24, 2024
1 parent c672b9a commit 7bb3db3
Showing 2 changed files with 89 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -68,6 +68,10 @@ class GPTQModifier(Modifier):
        not be updated. Leave None to not disable observers during QAT. Default is None
    :param num_calibration_steps: Number of steps to run post training calibration for.
        When None, the entire calibration_dataloader is used
    :param scheme: [Used only if a quantization modifier is not specified] the
        quantization scheme to apply to the model; a dictionary that supports all keys
        from QuantizationScheme except targets, which is set from the targets parameter
        at the modifier level.
    """

    sequential_update: Optional[bool] = False
@@ -79,6 +83,7 @@ class GPTQModifier(Modifier):
    ignore: List[str] = Field(default_factory=list)
    disable_quantization_observer_epoch: Optional[float] = None
    num_calibration_steps: Optional[int] = None
    scheme: Optional[Dict[str, Any]] = None
    compressible_layers_: Optional[List] = None
    quantization_modifier_: Any = None
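
As a usage illustration (not part of the commit): a minimal sketch of constructing the modifier with the new argument from Python. The import path and the concrete values here are assumptions based on this diff.

from sparseml.modifiers.quantization.gptq import GPTQModifier  # assumed import path

# `scheme` accepts any QuantizationScheme keys except `targets`, which is
# taken from the modifier-level `targets` argument instead.
modifier = GPTQModifier(
    targets=["Linear"],
    ignore=["lm_head"],
    scheme={
        "input_activations": None,
        "output_activations": None,
        "weights": {"num_bits": 8, "type": "int", "symmetric": True},
    },
)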

@@ -156,6 +161,14 @@ def _build_quant_modifier(self, framework):
            if getattr(self, key, False)
        }

        if self.scheme is not None:
            # takes precedence over config_groups
            targets = self.targets or ["Linear"]
            config_group = QuantizationScheme.model_validate(
                {"targets": targets, **self.scheme}
            )
            quant_args["config_groups"] = {"config_group_0": config_group}

        if "config_groups" not in quant_args:
            default_quant_scheme = QuantizationScheme.default_scheme(
                targets=self.targets
            )
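
Condensed, the resolution order implemented above is roughly the following. This is a sketch, not the actual helper: `resolve_config_groups` is a hypothetical name, and it assumes QuantizationScheme (from compressed-tensors) exposes model_validate and default_scheme as called in the diff.

def resolve_config_groups(scheme, config_groups, targets):
    # 1) an explicit `scheme` wins and absorbs the modifier-level targets
    if scheme is not None:
        group = QuantizationScheme.model_validate(
            {"targets": targets or ["Linear"], **scheme}
        )
        return {"config_group_0": group}
    # 2) otherwise, user-supplied config_groups pass through unchanged
    if config_groups:
        return config_groups
    # 3) with neither set, fall back to the default scheme for the targets
    return {"config_group_0": QuantizationScheme.default_scheme(targets=targets)}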
76 changes: 76 additions & 0 deletions tests/sparseml/transformers/gptq/test_oneshot.py
@@ -0,0 +1,76 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import shutil
import unittest

from sparseml.transformers.sparsification.sparse_model import SparseAutoModelForCausalLM
from tests.testing_utils import requires_torch


@requires_torch
class TestGPTQOneShotWithFullScheme(unittest.TestCase):
    def setUp(self):
        import torch

        self.output = "./oneshot_output"
        self.model = "roneneldan/TinyStories-1M"
        self.dataset = "open_platypus"
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.recipe = """
        first_stage:
            quant_modifiers:
                GPTQModifier:
                    ignore: ["lm_head"]
                    sequential_update: True
                    dampening_frac: 0.001
                    block_size: 128
                    targets: ["Linear"]
                    scheme:
                        input_activations: null
                        output_activations: null
                        weights:
                            num_bits: 8
                            type: "int"
                            symmetric: true
                            strategy: "tensor"
                            group_size: 128
        """

    def test_oneshot_application(self):
        from sparseml.transformers import oneshot

        oneshot(
            model=self.model,
            dataset=self.dataset,
            output_dir=self.output,
            overwrite_output_dir=True,
            recipe=self.recipe,
            oneshot_device=self.device,
            num_calibration_samples=9,
        )

        model_loaded = SparseAutoModelForCausalLM.from_pretrained(self.output)

        # Check that the model is quantized
        assert model_loaded.quantization_config is not None

        # Check that a specific layer is quantized
        targeted_linear_layer = model_loaded.transformer.h[0].attn.attention.k_proj
        assert hasattr(targeted_linear_layer, "quantization_scheme")

    def tearDown(self):
        shutil.rmtree(self.output)
