Quantization Compressor Support (#2260)
* initial commit

* update setup.py

* Update setup.py

* fix setup.py

* move all config to sparsetensors

* cleanup class name and comments

* initial implementation untested

* fixing issues

* add test script

* update perplexity test

* refactor to compressed-tensors

* rename sparsetensors

* update setup

* Sa/model reload (#2250)

* working reload

* sparsegpt

* cleanup

* refactor tests

* only run oneshot once

* all tests passing

* remove unused config

* reset models on each parameterize

* style

* bring back SparsityConfigMetadata

* Update setup.py

Co-authored-by: Rahul Tuli <[email protected]>

* add more comparisons, tighten threshold

* use wikitext for perplexity

* update setup

* fix import problem

* fix clearml test

* compressed-tensors are transformers dep

* address PR comments

* can't repeat freeze

* UX pr comments

* initial commit

* style

* skipping unit tests

* tests for quantization

* reloading unit tests

* backwards compat

* test updates

* update format

* fix inferring

* quality

* shape consistency

* address PR comments

* PR comments

* fixing some things

* style

* pull from cp main

* postmerge too

* export needs it too

* Update src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py

Co-authored-by: Rahul Tuli <[email protected]>

---------

Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: George Ohashi <[email protected]>
5 people authored May 9, 2024
1 parent 214873b commit 8a7fc99
Showing 10 changed files with 327 additions and 98 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/Integrations-post-merge-check.yaml
@@ -41,6 +41,15 @@ jobs:
run: pip3 install -U pip && pip3 install setuptools sparsezoo/
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- uses: actions/checkout@v2
with:
repository: "neuralmagic/compressed-tensors"
path: "compressed-tensors"
ref: ${{needs.test-setup.outputs.branch}}
- name: "⚙️ Install compressed-tensors dependencies"
run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
- name: "Clean compressed-tensors directory"
run: rm -r compressed-tensors/
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,torchvision,deepsparse,onnxruntime,transformers,yolov5]
- name: "🔬 Running integrations tests (cadence: commit}})"
9 changes: 9 additions & 0 deletions .github/workflows/integrations-check.yaml
@@ -62,6 +62,15 @@ jobs:
run: pip3 install -U pip && pip3 install setuptools sparsezoo/
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- uses: actions/checkout@v2
with:
repository: "neuralmagic/compressed-tensors"
path: "compressed-tensors"
ref: ${{needs.test-setup.outputs.branch}}
- name: "⚙️ Install compressed-tensors dependencies"
run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
- name: "Clean compressed-tensors directory"
run: rm -r compressed-tensors/
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,torchvision,deepsparse,onnxruntime,transformers,yolov5]
- name: "🔬 Running integrations tests (cadence: pre-commit}})"
18 changes: 18 additions & 0 deletions .github/workflows/test-check.yaml
@@ -246,6 +246,15 @@ jobs:
run: pip3 install -U pip && pip3 install setuptools sparsezoo/
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- uses: actions/checkout@v2
with:
repository: "neuralmagic/compressed-tensors"
path: "compressed-tensors"
ref: ${{needs.test-setup.outputs.branch}}
- name: "⚙️ Install compressed-tensors dependencies"
run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
- name: "Clean compressed-tensors directory"
run: rm -r compressed-tensors/
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,torch,transformers,onnxruntime]
- name: "🔬 Running transformers tests"
@@ -270,6 +279,15 @@ jobs:
run: pip3 install -U pip && pip3 install setuptools sparsezoo/
- name: "Clean sparsezoo directory"
run: rm -r sparsezoo/
- uses: actions/checkout@v2
with:
repository: "neuralmagic/compressed-tensors"
path: "compressed-tensors"
ref: ${{needs.test-setup.outputs.branch}}
- name: "⚙️ Install compressed-tensors dependencies"
run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
- name: "Clean compressed-tensors directory"
run: rm -r compressed-tensors/
- name: "⚙️ Install dependencies"
run: pip3 install .[dev,torch,transformers,torchvision,onnxruntime]
- name: "🔬 Running export tests"
43 changes: 27 additions & 16 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -172,28 +172,39 @@ def fasterprune(
q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
q = torch.dequantize(q)
elif hasattr(self.layer, "quantization_scheme"):
if self.layer.quantization_scheme.weights is not None:
quant_scheme = self.layer.quantization_scheme
if quant_scheme.weights is not None:
scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point
from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
fake_quantize,
)

while scale.ndim < 2:
scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)

while q.ndim < 2:
q = q.unsqueeze(1)
q = fake_quantize(
q,
scale[:, i],
zero_point[:, i],
self.layer.quantization_scheme.weights,
)

while q.ndim != 1:
q.squeeze()
if quant_scheme.weights.strategy == QuantizationStrategy.TENSOR:
q = fake_quantize(
q,
scale,
zero_point,
self.layer.quantization_scheme.weights,
)
else:
while scale.ndim < 2:
scale = scale.unsqueeze(scale.ndim)
zero_point = zero_point.unsqueeze(zero_point.ndim)

while q.ndim < 2:
q = q.unsqueeze(q.ndim)

q = fake_quantize(
q,
scale[:, i],
zero_point[:, i],
self.layer.quantization_scheme.weights,
)

while q.ndim > 1:
q = q.squeeze()

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
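The hunk above replaces the hardcoded per-channel handling with dispatch on the compressed-tensors quantization strategy: under the TENSOR strategy a single scale/zero-point pair is applied directly, while channel- and group-wise schemes lift scale, zero_point, and the weight column to 2D, slice out column i, and squeeze back down (also fixing the old `q.squeeze()` loop, which discarded its result). A minimal self-contained sketch of that round-trip is below; the simplified fake_quantize, the int8 range, and the shapes are illustrative assumptions standing in for compressed_tensors.quantization.lifecycle.forward.fake_quantize.

```python
import torch


def fake_quantize(x, scale, zero_point, q_min=-128, q_max=127):
    # Quantize onto the integer grid, then dequantize immediately; a
    # simplified stand-in for compressed_tensors' fake_quantize (int8 assumed).
    q = torch.clamp(torch.round(x / scale) + zero_point, q_min, q_max)
    return (q - zero_point) * scale


rows, i = 8, 0
q = torch.randn(rows)  # one weight column produced by the pruning loop

# TENSOR strategy: one scale/zero-point pair covers the whole weight, so the
# column can be fake-quantized as-is.
q_tensor = fake_quantize(q, torch.tensor(0.05), torch.tensor(0.0))

# Channel/group strategy: scale and zero_point carry one value per output row
# and group, so everything is lifted to 2D before slicing, then squeezed back,
# mirroring the while-loops in the hunk above.
scale = torch.full((rows, 1), 0.05)  # (rows, num_groups); one group here
zero_point = torch.zeros(rows, 1)
while q.ndim < 2:
    q = q.unsqueeze(q.ndim)  # (rows, 1)
q_col = fake_quantize(q, scale[:, i : i + 1], zero_point[:, i : i + 1])
while q_col.ndim > 1:
    q_col = q_col.squeeze(-1)  # back to (rows,)
print(q_tensor.shape, q_col.shape)  # torch.Size([8]) torch.Size([8])
```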
48 changes: 48 additions & 0 deletions src/sparseml/transformers/compression/quantization_format.py
@@ -0,0 +1,48 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.quantization.utils import is_model_quantized


__all__ = ["infer_quantization_format"]


def infer_quantization_format(
model, quantization_format: Optional[str] = None, save_compressed: bool = False
) -> Optional[str]:
"""
Infers a quantization format based on model state and compression args
:param model: model to check for quantization, if the model is not quantized no
quantization format is returned
:param quantization_format: user provided quantization format, supersedes any
inferred quantization format
:param save_compressed: used to infer a quantization format if None is provided
:return: compression format appropriate for the model
"""
if not is_model_quantized(model):
return None

if quantization_format is not None:
return quantization_format

if save_compressed:
return CompressionFormat.int_quantized
else:
# format will be inferred from config
return None
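For reference, the precedence implemented here: a model that is not quantized short-circuits to None, an explicit user-provided format wins, save_compressed selects the int-quantized format, and otherwise the choice is deferred to the saved config. A standalone sketch of that logic, with a boolean and plain strings standing in for is_model_quantized(model) and the CompressionFormat enum values:

```python
from typing import Optional


def infer_format_sketch(
    model_is_quantized: bool,
    quantization_format: Optional[str] = None,
    save_compressed: bool = False,
) -> Optional[str]:
    if not model_is_quantized:
        return None  # dense model: no quantization format to record
    if quantization_format is not None:
        return quantization_format  # explicit user choice always wins
    if save_compressed:
        return "int-quantized"  # stand-in for CompressionFormat.int_quantized
    return None  # format will be inferred from the config at load time


assert infer_format_sketch(False, "int-quantized", True) is None
assert infer_format_sketch(True, quantization_format="pack-quantized") == "pack-quantized"
assert infer_format_sketch(True, save_compressed=True) == "int-quantized"
assert infer_format_sketch(True) is None
```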
20 changes: 12 additions & 8 deletions src/sparseml/transformers/compression/sparsity_config.py
@@ -18,13 +18,14 @@
from torch.nn import Module

import sparseml
from compressed_tensors import CompressionConfig
from compressed_tensors import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization.utils import is_model_quantized
from sparseml.pytorch.utils import ModuleSparsificationInfo


class SparsityConfigMetadata:
"""
Class of helper functions for filling out a CompressionConfig with readable
Class of helper functions for filling out a SparsityCompressionConfig with readable
metadata from the model
"""

@@ -72,7 +73,7 @@ def from_pretrained(
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
) -> Optional["CompressionConfig"]:
) -> Optional["SparsityCompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
@@ -91,20 +92,23 @@ def from_pretrained(
return None

sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()
if compress:
format = "sparse_bitmask"
if is_model_quantized(model):
# compressing a sparse quantized model is not supported yet
format = CompressionFormat.dense.value
elif compress:
format = CompressionFormat.sparse_bitmask.value
else:
format = "dense_sparsity"
format = CompressionFormat.dense.value

return CompressionConfig.load_from_registry(
return SparsityCompressionConfig.load_from_registry(
format,
global_sparsity=global_sparsity,
sparsity_structure=sparsity_structure,
)

@staticmethod
def fill_config_details(
config: CompressionConfig,
config: SparsityCompressionConfig,
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
):
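The from_pretrained change above encodes a compatibility rule: compressing a model that is both sparse and quantized is not supported yet, so quantized models always receive the dense sparsity format, even when compression was requested. A minimal sketch of that decision, with plain strings standing in for the CompressionFormat enum values:

```python
def select_sparsity_format(model_is_quantized: bool, compress: bool) -> str:
    if model_is_quantized:
        # compressing a sparse quantized model is not supported yet
        return "dense"  # stand-in for CompressionFormat.dense.value
    if compress:
        return "sparse-bitmask"  # stand-in for CompressionFormat.sparse_bitmask.value
    return "dense"


assert select_sparsity_format(model_is_quantized=True, compress=True) == "dense"
assert select_sparsity_format(model_is_quantized=False, compress=True) == "sparse-bitmask"
assert select_sparsity_format(model_is_quantized=False, compress=False) == "dense"
```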
@@ -12,24 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import weakref
from functools import wraps
from typing import Optional

from transformers import PreTrainedModel
from transformers.file_utils import CONFIG_NAME

from compressed_tensors import (
QUANTIZATION_CONFIG_NAME,
SPARSITY_CONFIG_NAME,
CompressionConfig,
ModelCompressor,
QuantizationConfig,
)

from compressed_tensors import ModelCompressor, SparsityCompressionConfig
from compressed_tensors.quantization.utils import is_model_quantized
from sparseml.transformers.compression.quantization_format import (
infer_quantization_format,
)
from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata
from sparseml.utils.pytorch import qat_active

@@ -60,7 +54,8 @@ def save_pretrained_compressed(save_pretrained_method):
@wraps(original_save_pretrained)
def save_pretrained_wrapper(
save_directory: str,
sparsity_config: Optional[CompressionConfig] = None,
sparsity_config: Optional[SparsityCompressionConfig] = None,
quantization_format: Optional[str] = None,
save_compressed: bool = False,
skip_compression_stats: bool = False,
**kwargs,
@@ -73,6 +68,8 @@ def save_pretrained_wrapper(
:param save_directory: output directory to save model to
:param sparsity_config: optional sparsity config to compress model with,
if no config is provided it will be inferred from the model
:param quantization_format: optional compression format for quantized
models. If none is provided it will be inferred from the model
:param save_compressed: whether or not to compress the model on disk
:param skip_compression_stats: whether to skip the calculation of
compression statistics (such as global sparsity and sparsity structure) when
@@ -98,30 +95,6 @@ def save_pretrained_wrapper(

return

elif qat_active(model): # quantized in new framework
_LOGGER.info(
"Sparsity compression for quantized models is not yet supported. "
"No sparsity statistics will be calculated and no sparsity config "
"will be saved."
)

original_save_pretrained.__get__(model, model_class)(
save_directory, **kwargs
)

quant_config = QuantizationConfig.from_pretrained(model)
quant_config_data = quant_config.model_dump(exclude_unset=True)
config_file_path = os.path.join(save_directory, CONFIG_NAME)

# add the sparsity config to the model's config file
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)

return

if sparsity_config is not None:
sparsity_config.global_sparsity = (
SparsityConfigMetadata.infer_global_sparsity(
@@ -131,7 +104,6 @@ def save_pretrained_wrapper(
sparsity_config.sparsity_structure = (
SparsityConfigMetadata.infer_sparsity_structure()
)

elif not skip_compression_stats:
# try to infer a sparsity config from the model if none is provided
_LOGGER.info(
@@ -144,38 +116,36 @@ def save_pretrained_wrapper(
model, state_dict=state_dict, compress=save_compressed
)

if sparsity_config is None:
# model is not sparse, save as dense
quantization_format = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
)
compressor = ModelCompressor.from_pretrained_model(
model,
sparsity_config=sparsity_config,
quantization_format=quantization_format,
)
if compressor is None:
# model is not compressed or quantized, save as normal
return original_save_pretrained.__get__(model, model_class)(
save_directory, **kwargs
)

# if we've gotten to this point we have a config so we can run compression
kwargs["safe_serialization"] = True
compressor = ModelCompressor.load_from_registry(
sparsity_config.format, config=sparsity_config
)

if state_dict is None:
state_dict = model.state_dict()

# make sure we're on the main process when saving
if state_dict is not None and len(state_dict) > 0:
compressed_state_dict = compressor.compress(state_dict)
compressed_state_dict = compressor.compress(model, state_dict)
kwargs["state_dict"] = compressed_state_dict

original_save_pretrained.__get__(model, model_class)(
save_directory, **kwargs
)
sparsity_config_data = sparsity_config.dict()
config_file_path = os.path.join(save_directory, CONFIG_NAME)

# add the sparsity config to the model's config file
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)
config_data[SPARSITY_CONFIG_NAME] = sparsity_config_data
with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)
compressor.update_config(save_directory)

save_pretrained_wrapper._overriden = True
return save_pretrained_wrapper
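Taken together, the rewritten wrapper delegates every compression decision to ModelCompressor: from_pretrained_model bundles the sparsity config with the inferred quantization format, compress rewrites the state dict, and update_config amends config.json, replacing the hand-rolled JSON merge that was removed above. A hedged usage sketch of the resulting API; the checkpoint paths are placeholders, and SparseAutoModelForCausalLM is assumed to install the wrapper when the model is loaded:

```python
from sparseml.transformers import SparseAutoModelForCausalLM

# Placeholder path; any oneshot-pruned and/or quantized checkpoint would do.
model = SparseAutoModelForCausalLM.from_pretrained("./obcq_oneshot_output")

# save_compressed=True makes infer_quantization_format return int-quantized
# for a quantized model and selects bitmask compression for a sparse one; the
# resulting ModelCompressor rewrites the state dict and updates config.json.
model.save_pretrained("./compressed_output", save_compressed=True)

# A dense, unquantized model yields no compressor and saves normally.
```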