From 8a7fc99b56825369e58b46a7525a5fc4aa8e79fb Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 9 May 2024 10:39:55 -0400
Subject: [PATCH] Quantization Compressor Support (#2260)

* initial commit
* update setup.py
* Update setup.py
* fix setup.py
* move all config to sparsetensors
* cleanup class name and comments
* initial implementation untested
* fixing issues
* add test script
* update perplexity test
* refactor to compressed-tensors
* rename sparsetensors
* update setup
* Sa/model reload (#2250)
* working reload
* sparsegpt
* cleanup
* refactor tests
* only run oneshot once
* all tests passing
* remove unused config
* reset models on each parameterize
* style
* bring back SparsityConfigMetadata
* Update setup.py

Co-authored-by: Rahul Tuli

* add more comparisons, tighten threshold
* use wikitext for perplexity
* update setup
* fix import problem
* fix clearml test
* compressed-tensors are transformers dep
* address PR comments
* can't repeat freeze
* UX pr comments
* initial commit
* style
* skipping unit tests
* tests for quantization
* reloading unit tests
* backwards compat
* test updates
* update format
* fix inferring
* quality
* shape consistency
* address PR comments
* PR comments
* fixing some things
* style
* pull from cp main
* postmerge too
* export needs it too
* Update src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py

Co-authored-by: Rahul Tuli

---------

Co-authored-by: dbogunowicz
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Co-authored-by: Rahul Tuli
Co-authored-by: George Ohashi
---
 .../Integrations-post-merge-check.yaml        |   9 ++
 .github/workflows/integrations-check.yaml     |   9 ++
 .github/workflows/test-check.yaml             |  18 +++
 .../modifiers/obcq/utils/sgpt_wrapper.py      |  43 +++--
 .../compression/quantization_format.py        |  48 ++++++
 .../compression/sparsity_config.py            |  20 ++-
 .../compressed_tensors_utils.py               |  76 +++------
 .../sparsification/sparse_model.py            |  25 +--
 .../compression/recipes/new_quant_simple.yaml |  27 ++++
 .../test_compress_tensor_utils.py             | 150 +++++++++++++++++-
 10 files changed, 327 insertions(+), 98 deletions(-)
 create mode 100644 src/sparseml/transformers/compression/quantization_format.py
 create mode 100644 tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml

diff --git a/.github/workflows/Integrations-post-merge-check.yaml b/.github/workflows/Integrations-post-merge-check.yaml
index 25aeea10051..f3c29e8e3b3 100644
--- a/.github/workflows/Integrations-post-merge-check.yaml
+++ b/.github/workflows/Integrations-post-merge-check.yaml
@@ -41,6 +41,15 @@ jobs:
         run: pip3 install -U pip && pip3 install setuptools sparsezoo/
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
+      - uses: actions/checkout@v2
+        with:
+          repository: "neuralmagic/compressed-tensors"
+          path: "compressed-tensors"
+          ref: ${{needs.test-setup.outputs.branch}}
+      - name: "⚙️ Install compressed-tensors dependencies"
+        run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
+      - name: "Clean compressed-tensors directory"
+        run: rm -r compressed-tensors/
       - name: "⚙️ Install dependencies"
         run: pip3 install .[dev,torchvision,deepsparse,onnxruntime,transformers,yolov5]
       - name: "🔬 Running integrations tests (cadence: commit}})"
diff --git a/.github/workflows/integrations-check.yaml b/.github/workflows/integrations-check.yaml
index 86c37b57890..ff57a0db08d 100644
--- a/.github/workflows/integrations-check.yaml
+++ b/.github/workflows/integrations-check.yaml
@@ -62,6 +62,15 @@ jobs:
         run: pip3 install -U pip && pip3 install setuptools sparsezoo/
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
+      - uses: actions/checkout@v2
+        with:
+          repository: "neuralmagic/compressed-tensors"
+          path: "compressed-tensors"
+          ref: ${{needs.test-setup.outputs.branch}}
+      - name: "⚙️ Install compressed-tensors dependencies"
+        run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
+      - name: "Clean compressed-tensors directory"
+        run: rm -r compressed-tensors/
       - name: "⚙️ Install dependencies"
         run: pip3 install .[dev,torchvision,deepsparse,onnxruntime,transformers,yolov5]
       - name: "🔬 Running integrations tests (cadence: pre-commit}})"
diff --git a/.github/workflows/test-check.yaml b/.github/workflows/test-check.yaml
index 362fd297321..887dd745a81 100644
--- a/.github/workflows/test-check.yaml
+++ b/.github/workflows/test-check.yaml
@@ -246,6 +246,15 @@ jobs:
         run: pip3 install -U pip && pip3 install setuptools sparsezoo/
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
+      - uses: actions/checkout@v2
+        with:
+          repository: "neuralmagic/compressed-tensors"
+          path: "compressed-tensors"
+          ref: ${{needs.test-setup.outputs.branch}}
+      - name: "⚙️ Install compressed-tensors dependencies"
+        run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
+      - name: "Clean compressed-tensors directory"
+        run: rm -r compressed-tensors/
       - name: "⚙️ Install dependencies"
         run: pip3 install .[dev,torch,transformers,onnxruntime]
       - name: "🔬 Running transformers tests"
@@ -270,6 +279,15 @@ jobs:
         run: pip3 install -U pip && pip3 install setuptools sparsezoo/
       - name: "Clean sparsezoo directory"
         run: rm -r sparsezoo/
+      - uses: actions/checkout@v2
+        with:
+          repository: "neuralmagic/compressed-tensors"
+          path: "compressed-tensors"
+          ref: ${{needs.test-setup.outputs.branch}}
+      - name: "⚙️ Install compressed-tensors dependencies"
+        run: pip3 install -U pip && pip3 install setuptools compressed-tensors/
+      - name: "Clean compressed-tensors directory"
+        run: rm -r compressed-tensors/
       - name: "⚙️ Install dependencies"
         run: pip3 install .[dev,torch,transformers,torchvision,onnxruntime]
       - name: "🔬 Running export tests"
diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index 2b439862b4e..99484490e01 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -172,28 +172,39 @@ def fasterprune(
                     q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                     q = torch.dequantize(q)
                 elif hasattr(self.layer, "quantization_scheme"):
-                    if self.layer.quantization_scheme.weights is not None:
+                    quant_scheme = self.layer.quantization_scheme
+                    if quant_scheme.weights is not None:
                         scale = self.layer.weight_scale
                         zero_point = self.layer.weight_zero_point
+                        from compressed_tensors.quantization import QuantizationStrategy
                         from compressed_tensors.quantization.lifecycle.forward import (
                             fake_quantize,
                         )
-                        while scale.ndim < 2:
-                            scale = scale.unsqueeze(1)
-                            zero_point = zero_point.unsqueeze(1)
-
-                        while q.ndim < 2:
-                            q = q.unsqueeze(1)
-                        q = fake_quantize(
-                            q,
-                            scale[:, i],
-                            zero_point[:, i],
-                            self.layer.quantization_scheme.weights,
-                        )
-
-                        while q.ndim != 1:
-                            q.squeeze()
+                        if quant_scheme.weights.strategy == QuantizationStrategy.TENSOR:
+                            q = fake_quantize(
+                                q,
+                                scale,
+                                zero_point,
+                                self.layer.quantization_scheme.weights,
+                            )
+                        else:
+                            while scale.ndim < 2:
+                                scale = scale.unsqueeze(scale.ndim)
+                                zero_point = zero_point.unsqueeze(zero_point.ndim)
+
+                            while q.ndim < 2:
+                                q = q.unsqueeze(q.ndim)
+
+                            q = fake_quantize(
+                                q,
+                                scale[:, i],
+                                zero_point[:, i],
+                                self.layer.quantization_scheme.weights,
+                            )
+
+                            while q.ndim > 1:
+                                q = q.squeeze()
 
                 Q1[:, i] = q
                 Losses1[:, i] = (w - q) ** 2 / d**2
diff --git a/src/sparseml/transformers/compression/quantization_format.py b/src/sparseml/transformers/compression/quantization_format.py
new file mode 100644
index 00000000000..5f8f8722753
--- /dev/null
+++ b/src/sparseml/transformers/compression/quantization_format.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization.utils import is_model_quantized
+
+
+__all__ = ["infer_quantization_format"]
+
+
+def infer_quantization_format(
+    model, quantization_format: Optional[str] = None, save_compressed: bool = False
+) -> Optional[str]:
+    """
+    Infers a quantization format based on model state and compression args
+
+    :param model: model to check for quantization, if the model is not quantized no
+        quantization format is returned
+    :param quantization_format: user provided quantization format, supersedes any
+        inferred quantization format
+    :param save_compressed: used to infer a quantization format if None is provided
+    :return: compression format appropriate for the model
+    """
+    if not is_model_quantized(model):
+        return None
+
+    if quantization_format is not None:
+        return quantization_format
+
+    if save_compressed:
+        return CompressionFormat.int_quantized
+    else:
+        # format will be inferred from config
+        return None
diff --git a/src/sparseml/transformers/compression/sparsity_config.py b/src/sparseml/transformers/compression/sparsity_config.py
index 665e3c6a340..b5f69cb83e1 100644
--- a/src/sparseml/transformers/compression/sparsity_config.py
+++ b/src/sparseml/transformers/compression/sparsity_config.py
@@ -18,13 +18,14 @@ from torch.nn import Module
 
 import sparseml
-from compressed_tensors import CompressionConfig
+from compressed_tensors import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.quantization.utils import is_model_quantized
 from sparseml.pytorch.utils import ModuleSparsificationInfo
 
 
 class SparsityConfigMetadata:
     """
-    Class of helper functions for filling out a CompressionConfig with readable
+    Class of helper functions for filling out a SparsityCompressionConfig with readable
     metadata from the model
     """
 
 
@@ -72,7 +73,7 @@ def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
-    ) -> Optional["CompressionConfig"]:
+    ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
 
@@ -91,12 +92,15 @@ def from_pretrained(
             return None
 
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()
-        if compress:
-            format = "sparse_bitmask"
+        if is_model_quantized(model):
+            # compressing a sparse quantized model is not supported yet
+            format = CompressionFormat.dense.value
+        elif compress:
+            format = CompressionFormat.sparse_bitmask.value
         else:
-            format = "dense_sparsity"
+            format = CompressionFormat.dense.value
 
-        return CompressionConfig.load_from_registry(
+        return SparsityCompressionConfig.load_from_registry(
             format,
             global_sparsity=global_sparsity,
             sparsity_structure=sparsity_structure,
@@ -104,7 +108,7 @@
 
     @staticmethod
     def fill_config_details(
-        config: CompressionConfig,
+        config: SparsityCompressionConfig,
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
     ):
diff --git a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
index b6852535a2c..c62a1eb9bf9 100644
--- a/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/sparseml/transformers/sparsification/compressed_tensors_utils.py
@@ -12,24 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
-import os
 import weakref
 from functools import wraps
 from typing import Optional
 
 from transformers import PreTrainedModel
-from transformers.file_utils import CONFIG_NAME
-
-from compressed_tensors import (
-    QUANTIZATION_CONFIG_NAME,
-    SPARSITY_CONFIG_NAME,
-    CompressionConfig,
-    ModelCompressor,
-    QuantizationConfig,
-)
+
+from compressed_tensors import ModelCompressor, SparsityCompressionConfig
 from compressed_tensors.quantization.utils import is_model_quantized
+from sparseml.transformers.compression.quantization_format import (
+    infer_quantization_format,
+)
 from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata
 from sparseml.utils.pytorch import qat_active
 
@@ -60,7 +54,8 @@ def save_pretrained_compressed(save_pretrained_method):
         @wraps(original_save_pretrained)
         def save_pretrained_wrapper(
             save_directory: str,
-            sparsity_config: Optional[CompressionConfig] = None,
+            sparsity_config: Optional[SparsityCompressionConfig] = None,
+            quantization_format: Optional[str] = None,
             save_compressed: bool = False,
             skip_compression_stats: bool = False,
             **kwargs,
@@ -73,6 +68,8 @@ def save_pretrained_wrapper(
             :param save_directory: output directory to save model to
             :param sparsity_config: optional sparsity config to compress model with, if no
                 config is provided it will be inferred from the model
+            :param quantization_format: optional compression format for quantized
+                models. If none is provided it will be inferred from the model
             :param save_compressed: whether or not to compress the model on disk
             :param skip_compression_stats: whether to skip the calculation of compression
                 statistics (such as global sparsity and sparsity structure) when
@@ -98,30 +95,6 @@ def save_pretrained_wrapper(
 
                 return
 
-            elif qat_active(model):  # quantized in new framework
-                _LOGGER.info(
-                    "Sparsity compression for quantized models is not yet supported. "
-                    "No sparsity statistics will be calculated and no sparsity config "
-                    "will be saved."
-                )
-
-                original_save_pretrained.__get__(model, model_class)(
-                    save_directory, **kwargs
-                )
-
-                quant_config = QuantizationConfig.from_pretrained(model)
-                quant_config_data = quant_config.model_dump(exclude_unset=True)
-                config_file_path = os.path.join(save_directory, CONFIG_NAME)
-
-                # add the sparsity config to the model's config file
-                with open(config_file_path, "r") as config_file:
-                    config_data = json.load(config_file)
-                config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-                with open(config_file_path, "w") as config_file:
-                    json.dump(config_data, config_file, indent=2, sort_keys=True)
-
-                return
-
             if sparsity_config is not None:
                 sparsity_config.global_sparsity = (
                     SparsityConfigMetadata.infer_global_sparsity(
@@ -131,7 +104,6 @@ def save_pretrained_wrapper(
                 sparsity_config.sparsity_structure = (
                     SparsityConfigMetadata.infer_sparsity_structure()
                 )
-
             elif not skip_compression_stats:
                 # try to infer a sparsity config from the model if none is provided
                 _LOGGER.info(
@@ -144,38 +116,36 @@ def save_pretrained_wrapper(
                     model, state_dict=state_dict, compress=save_compressed
                 )
 
-            if sparsity_config is None:
-                # model is not sparse, save as dense
+            quantization_format = infer_quantization_format(
+                model=model,
+                quantization_format=quantization_format,
+                save_compressed=save_compressed,
+            )
+            compressor = ModelCompressor.from_pretrained_model(
+                model,
+                sparsity_config=sparsity_config,
+                quantization_format=quantization_format,
+            )
+            if compressor is None:
+                # model is not compressed or quantized, save as normal
                 return original_save_pretrained.__get__(model, model_class)(
                     save_directory, **kwargs
                 )
             # if we've gotten to this point we have a config so we can run compression
             kwargs["safe_serialization"] = True
-            compressor = ModelCompressor.load_from_registry(
-                sparsity_config.format, config=sparsity_config
-            )
-
             if state_dict is None:
                 state_dict = model.state_dict()
 
             # make sure we're on the main process when saving
             if state_dict is not None and len(state_dict) > 0:
-                compressed_state_dict = compressor.compress(state_dict)
+                compressed_state_dict = compressor.compress(model, state_dict)
                 kwargs["state_dict"] = compressed_state_dict
 
                 original_save_pretrained.__get__(model, model_class)(
                     save_directory, **kwargs
                 )
 
-            sparsity_config_data = sparsity_config.dict()
-            config_file_path = os.path.join(save_directory, CONFIG_NAME)
-
-            # add the sparsity config to the model's config file
-            with open(config_file_path, "r") as config_file:
-                config_data = json.load(config_file)
-            config_data[SPARSITY_CONFIG_NAME] = sparsity_config_data
-            with open(config_file_path, "w") as config_file:
-                json.dump(config_data, config_file, indent=2, sort_keys=True)
+            compressor.update_config(save_directory)
 
         save_pretrained_wrapper._overriden = True
         return save_pretrained_wrapper
diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index 995b349f513..76e75862fff 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -31,11 +31,6 @@
 from transformers.file_utils import WEIGHTS_NAME
 
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.quantization import (
-    QuantizationConfig,
-    apply_quantization_config,
-    load_pretrained_quantization,
-)
 from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.pytorch.model_load.helpers import (
     apply_recipe_structure_to_model,
@@ -105,11 +100,8 @@ def skip(*args, **kwargs):
             pretrained_model_name_or_path, **kwargs
         )
 
-        # determine compression format, if any, from the model config
+        # instantiate compressor from model config
         compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
-        quantization_config = QuantizationConfig.from_model_config(
-            pretrained_model_name_or_path
-        )
 
         # temporarily set the log level to error, to ignore printing out long missing
         # and unexpected key error messages (these are EXPECTED for quantized models)
@@ -123,18 +115,13 @@ def skip(*args, **kwargs):
         # override the PreTrainedModel instance with compression save function
         modify_save_pretrained(model)
 
-        # If model is compressed on disk, decompress and load the weights
+        # If model is quantized or compressed on disk, initialize quantization
+        # structure and run decompression
         if compressor is not None:
-            # decompress weights
-            compressor.overwrite_weights(
-                model_path=pretrained_model_name_or_path, model=model
-            )
-
-        if quantization_config is not None:
-            # if we loaded from a HF stub, find the cached model
-            apply_quantization_config(model, quantization_config)
-            load_pretrained_quantization(model, pretrained_model_name_or_path)
+            # initialize quantization and decompress weights
+            compressor.decompress(model_path=pretrained_model_name_or_path, model=model)
         else:
+            # legacy loading for old quantization modifier
             recipe = resolve_recipe(
                 recipe=recipe, model_path=pretrained_model_name_or_path
             )
diff --git a/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml b/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml
new file mode 100644
index 00000000000..753605fc1dd
--- /dev/null
+++ b/tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml
@@ -0,0 +1,27 @@
+test_stage:
+  quant_modifiers:
+    vLLMQuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 8
+            type: "int"
+            symmetric: true
+            strategy: "tensor"
+          input_activations:
+            num_bits: 8
+            type: "int"
+            symmetric: false
+            strategy: "tensor"
+          output_activations: null
+          targets: ["Linear"]
+        group_1:
+          weights:
+            num_bits: 8
+            type: "int"
+            symmetric: true
+            strategy: "tensor"
+          input_activations: null
+          output_activations: null
+          targets: ["Embedding"]
diff --git a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
index fd4594046e9..f80db1b4005 100644
--- a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
@@ -20,8 +20,18 @@ from transformers import AutoConfig
 
 import sparseml
-from compressed_tensors import SPARSITY_CONFIG_NAME
+from compressed_tensors import (
+    COMPRESSION_CONFIG_NAME,
+    QUANTIZATION_CONFIG_NAME,
+    SPARSITY_CONFIG_NAME,
+)
 from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
+from compressed_tensors.quantization import (
+    QuantizationStatus,
+    compress_quantized_weights,
+    freeze_module_quantization,
+)
+from safetensors import safe_open
 from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
 from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata
@@ -59,6 +69,7 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
         splits=splits,
         oneshot_device=device,
         precision=dtype,
+        clear_sparse_session=False,
     )
 
     model = SparseAutoModelForCausalLM.from_pretrained(
@@ -77,7 +88,8 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
     )
 
     config = AutoConfig.from_pretrained(tmp_path / "compress_out")
-    sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
+    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+    sparsity_config = compression_config.get(SPARSITY_CONFIG_NAME, None)
     assert (
         sparsity_config["format"] == "dense"
         if (not compressed and config is None)
@@ -129,3 +141,137 @@ def test_dense_model_save(tmp_path, skip_compression_stats, save_compressed):
     assert sparsity_config is None
 
     shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "format,dtype",
+    [
+        ["dense", torch.float32],
+        ["dense", torch.float16],
+        ["int_quantized", torch.float32],
+        # ["int_quantized", torch.float16],
+    ],
+)
+def test_quant_model_reload(format, dtype, tmp_path):
+    recipe_str = "tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml"
+    model_path = "Xenova/llama2.c-stories15M"
+    device = "cuda:0"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    num_calibration_samples = 64
+    output_dir = tmp_path / "oneshot_out"
+    splits = {"calibration": "train[:10%]"}
+
+    # create a quantized model
+    oneshot(
+        model=model_path,
+        dataset=dataset,
+        output_dir=output_dir,
+        num_calibration_samples=num_calibration_samples,
+        recipe=recipe_str,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        precision=dtype,
+    )
+
+    model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "oneshot_out", torch_dtype=dtype
+    )
+
+    for _, module in model.named_modules():
+        if hasattr(module, "quantization_scheme"):
+            assert module.weight.dtype == dtype
+            assert module.quantization_status == QuantizationStatus.FROZEN
+
+    model.save_pretrained(
+        tmp_path / "compress_out",
+        quantization_format=format,
+        save_compressed=True,
+    )
+
+    config = AutoConfig.from_pretrained(tmp_path / "compress_out")
+    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+    quant_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None)
+    assert quant_config["format"] == format
+
+    dense_model = SparseAutoModelForCausalLM.from_pretrained(
+        tmp_path / "compress_out", torch_dtype="auto"
+    )
+
+    og_state_dict = model.state_dict()
+    reconstructed_state_dict = dense_model.state_dict()
+    assert len(og_state_dict) == len(reconstructed_state_dict)
+    for key in og_state_dict.keys():
+        dense_tensor = og_state_dict[key]
+        reconstructed_tensor = reconstructed_state_dict[key]
+        assert dense_tensor.dtype == reconstructed_tensor.dtype
+        if key.endswith("weight") and format != "dense":
+            # we don't expect an exact match for compressed
+            diff = torch.abs(dense_tensor - reconstructed_tensor)
+            assert not torch.any(diff > 0.01).item()
+        else:
+            assert torch.equal(dense_tensor, reconstructed_tensor)
+
+    shutil.rmtree(tmp_path)
+
+
+@pytest.mark.parametrize(
+    "status,expected_format,expected_dtype",
+    [
+        [QuantizationStatus.FROZEN, "dense", torch.float32],
+        [QuantizationStatus.COMPRESSED, "int-quantized", torch.int8],
+    ],
+)
+def test_quant_infer_format(status, expected_format, expected_dtype, tmp_path):
+    recipe_str = "tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml"
+    model_path = "Xenova/llama2.c-stories15M"
+    device = "cuda:0"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    num_calibration_samples = 64
+    output_dir = tmp_path / "oneshot_out"
+    splits = {"calibration": "train[:10%]"}
+
+    model = SparseAutoModelForCausalLM.from_pretrained(model_path)
+
+    # create a quantized model
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output_dir,
+        num_calibration_samples=num_calibration_samples,
+        recipe=recipe_str,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+    )
+
+    if status == QuantizationStatus.FROZEN:
+        model.apply(freeze_module_quantization)
+    elif status == QuantizationStatus.COMPRESSED:
+        model.apply(compress_quantized_weights)
+
+    for _, module in model.named_modules():
+        if hasattr(module, "quantization_scheme"):
+            assert module.quantization_status == status
+
+    model.save_pretrained(tmp_path / "compress_out")
+
+    config = AutoConfig.from_pretrained(tmp_path / "compress_out")
+    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+    quant_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None)
+    assert quant_config["quantization_status"] == status.value
+    assert quant_config["format"] == expected_format
+
+    with safe_open(
+        tmp_path / "compress_out" / "model.safetensors", framework="pt", device=device
+    ) as f:
+        test_tensor = f.get_tensor("model.layers.0.mlp.down_proj.weight")
+        assert test_tensor.dtype == expected_dtype
+
+    shutil.rmtree(tmp_path)
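
Reviewer note (not part of the patch): a minimal sketch of the quantized save/reload flow this PR enables, pieced together from the wrapped save_pretrained() in compressed_tensors_utils.py, infer_quantization_format(), and the ModelCompressor.decompress() path in sparse_model.py. The model stub, recipe path, and output directories below are illustrative placeholders; when quantization_format is omitted and save_compressed=True, the format is inferred as int_quantized.

# Sketch only: exercises the new save/reload path added in this PR.
# Paths and the model stub are placeholders, not part of the patch.
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

# one-shot quantize a small model with the new-framework recipe
oneshot(
    model="Xenova/llama2.c-stories15M",
    dataset="open_platypus",
    recipe="tests/sparseml/transformers/compression/recipes/new_quant_simple.yaml",
    output_dir="./oneshot_out",
    num_calibration_samples=64,
)

model = SparseAutoModelForCausalLM.from_pretrained("./oneshot_out")

# save compressed; quantization_format is optional and, if omitted,
# is inferred by infer_quantization_format() from save_compressed
model.save_pretrained(
    "./compressed_out",
    save_compressed=True,
    quantization_format="int_quantized",
)

# reload: ModelCompressor.from_pretrained() reads the compression config and
# decompress() re-initializes quantization structure and restores the weights
restored = SparseAutoModelForCausalLM.from_pretrained("./compressed_out")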