diff --git a/src/sparseml/utils/pytorch/__init__.py b/src/sparseml/utils/pytorch/__init__.py index 10c86104af1..05bd7af1510 100644 --- a/src/sparseml/utils/pytorch/__init__.py +++ b/src/sparseml/utils/pytorch/__init__.py @@ -14,4 +14,5 @@ # flake8: noqa +from .converters import * from .module import * diff --git a/src/sparseml/utils/pytorch/converters/__init__.py b/src/sparseml/utils/pytorch/converters/__init__.py new file mode 100644 index 00000000000..87c7a2fed59 --- /dev/null +++ b/src/sparseml/utils/pytorch/converters/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa + + +from .converters import * diff --git a/src/sparseml/utils/pytorch/converters/converters.py b/src/sparseml/utils/pytorch/converters/converters.py new file mode 100644 index 00000000000..283136a3bb9 --- /dev/null +++ b/src/sparseml/utils/pytorch/converters/converters.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import logging +import shutil +from abc import ABC +from pathlib import Path +from typing import Callable, Dict, Iterable, Union + +import torch + +from safetensors.torch import save_file +from sparseml.pytorch.model_load.helpers import load_safetensors_state_dict +from sparseml.utils.pytorch.converters.transformations import ( + transform_autogptq_weights_and_reshape_tensors, + transform_exllama_names, +) + + +StateDictType = Union[Dict[str, torch.Tensor], str, Path] +TransformationType = Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class BaseConverter(ABC): + @classmethod + def translate(cls, state_dict: StateDictType, **kwargs) -> StateDictType: + """ + Applies transformations to the state_dict + + :param state_dict: The state_dict to apply transformations to + :param kwargs: Additional arguments to pass to the transformations + :return: The transformed state_dict + """ + _LOGGER.info("Applying transformations...") + new_state_dict = copy.copy(state_dict) + for transformation in cls.transformations(): + new_state_dict = transformation(new_state_dict, **kwargs) + return new_state_dict + + @classmethod + def convert_from_safetensors(cls, filepath: str, save_dir: str = None) -> str: + """ + Convert a .safetensors file or directory of .safetensors files, applying + transformations to the state_dict and saving the new state_dict to a new + directory + + :param filepath: The file path to the .safetensors file or directory + containing .safetensors files to convert + :param save_dir: The directory to save the converted state_dict to + :return: The directory where the converted state_dict was saved + """ + _validate_safetensors_file_path(filepath) + + filepath_: Path = Path(filepath) + if not save_dir: + save_dir = "compressed_tensors_model" + + save_dir_: Path = Path(save_dir) + save_dir_.mkdir(exist_ok=True, parents=True) + + metadata = {"format": "pt", "source": "Created by SparseML"} + + # transform and save the state_dict + if filepath_.is_dir(): + for file in filepath_.glob("*.safetensors"): + _LOGGER.info(f"Loading file: {file}") + state_dict: StateDictType = load_safetensors_state_dict(file) + new_state_dict = cls.translate(state_dict=state_dict) + save_file( + new_state_dict, filename=save_dir_ / file.name, metadata=metadata + ) + _copy_non_safetensor_files_(filepath_, save_dir_) + _update_quantization_config(filepath_, save_dir_) + + elif filepath_.is_file(): + state_dict: StateDictType = load_safetensors_state_dict(filepath) + new_state_dict = cls.translate(state_dict=state_dict) + save_file( + new_state_dict, save_path=save_dir_ / filepath_.name, metadata=metadata + ) + + return str(save_dir_) + + @classmethod + def transformations(cls) -> Iterable[TransformationType]: + """ + Returns an iterable of transformations that are applied in the converter, + each transformation should be a callable that takes a state_dict and returns + a transformed state_dict + """ + raise NotImplementedError() + + +class ExllamaToCompressedTensorConverter(BaseConverter): + """ + A converter that applies transformations to the state_dict of a autogptq + quantized model to convert it to a compressed tensor model, which can be + loaded by the SparseAutoModel classes + """ + + @classmethod + def transformations(cls): + return (transform_autogptq_weights_and_reshape_tensors, transform_exllama_names) + + +def _validate_safetensors_file_path(filepath: str): + """ + Given a file path, it is valid if: + - The file exists + - The file is either a single .safetensors file or a + directory containing .safetensors files + + :param filepath: A string file path to validate + """ + + filepath_: Path = Path(filepath) + + if not filepath_.exists(): + raise FileNotFoundError(f"File not found: {filepath}") + + if filepath_.is_dir() and not any(filepath_.glob("*.safetensors")): + raise FileNotFoundError(f"No .safetensors files found in directory: {filepath}") + + if filepath_.is_file() and not filepath_.suffix == ".safetensors": + raise ValueError(f"File must be a .safetensors file: {filepath}") + + +def _copy_non_safetensor_files_(source_dir: Path, dest_dir: Path): + """ + A helper function to copy all auxillary files in a directory that are + not .safetensors files, for example (config.json, recipe.yaml, ...) + + :param source_dir: The directory to copy files from + :param dest_dir: The directory to copy files to + """ + for file in source_dir.glob("*"): + if file.suffix != ".safetensors": + _LOGGER.info(f"Copying file: {file} to {dest_dir}") + shutil.copy(file, dest_dir / file.name) + + +def _update_quantization_config(source_dir: Path, dest_dir: Path): + """ + Updates config.json file in the destination directory by removing the + quantization_config attribute + + :param source_dir: The directory containing the original config.json file + :param dest_dir: The directory to save the updated config.json file + """ + from sparseml.transformers import SparseAutoConfig + + config = SparseAutoConfig.from_pretrained(source_dir) + + if hasattr(config, "quantization_config"): + _LOGGER.info("Updating quantization config...") + delattr(config, "quantization_config") + config.save_pretrained(dest_dir) diff --git a/src/sparseml/utils/pytorch/converters/transformations.py b/src/sparseml/utils/pytorch/converters/transformations.py new file mode 100644 index 00000000000..9a96a847b87 --- /dev/null +++ b/src/sparseml/utils/pytorch/converters/transformations.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa: F821 + +import functools +import logging +from typing import Dict + +import numpy +import numpy as np +import torch +from torch import Tensor + + +_LOGGER = logging.getLogger(__name__) + + +def _log_transformation(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + _LOGGER.info("Applying transformation: %s", func.__name__.upper()) + return_value = func(*args, **kwargs) + _LOGGER.info("Transformation: %s complete", func.__name__.upper()) + return return_value + + return wrapper + + +def is_gptq_quantization_target(key: str) -> bool: + """ + Assumes self_attn and mlp are the only quantization targets + in model layers of the state_dict. + :param key: The key of the state_dict + :return: True if the key is a quantization target, False otherwise + """ + return "model.layers" in key and ("self_attn" in key or "mlp" in key) + + +@_log_transformation +def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Transforms the exallama state_dict keys to be compatible with + SparseAutoModel classes. + + The renames include: + - scales -> weight_fake_quant.scale + - qzeros -> weight_fake_quant.zero_point + - qweight -> weight + + Note: does not transforms the actual tensor values + + :pre-condition: The state_dict should be for a quantized model + :pre-condition: Targets only the weights of the self_attn and mlp nodes + :param state_dict: The quantized state_dict to be transformed + :return: The transformed state_dict + """ + + name_map: Dict[str, str] = { + ".scales": ".weight_fake_quant.scale", + ".qzeros": ".weight_fake_quant.zero_point", + ".qweight": ".weight", + } + + updated_state_dict = {} + for key, tensor in state_dict.items(): + if any(key.endswith(target_suffix := suffix) for suffix in name_map): + updated_key = key.replace(target_suffix, name_map[target_suffix]) + updated_state_dict[updated_key] = tensor + else: + updated_state_dict[key] = tensor + return updated_state_dict + + +@_log_transformation +def transform_autogptq_weights_and_reshape_tensors( + state_dict: Dict[str, Tensor] +) -> Dict[str, Tensor]: + """ + Tranforms weights into their required shapes and types for Exllama + to CompressedTensors conversion + + The transformations include: + - Unpack ad dequantize the weight tensor using the scales, zeros, and g_idx tensors + - Squeeze the scales tensor to [x] from [1, x] + + :pre-condition: The state_dict should be for a quantized model + :pre-condition: The state_dict should have the bias and g_idx tensors added + + :param state_dict: The state_dict to be transformed + :return: The transformed state_dict, with repacked and reshaped tensors + """ + + transformed_state_dict: Dict[str, Tensor] = {} + + # auxillary dict to store transformed weights + transformed_weights_dict: Dict[str, Tensor] = {} + + # quantize qweights before scales, and qzeros + # because the ordering in which tensors are fetched + # is not guaranteed by our implementation + for key, tensor in state_dict.items(): + if is_gptq_quantization_target(key) and key.endswith(".qweight"): + # quantize the weight tensor + scales = state_dict[key.replace("qweight", "scales")] + qzeros = state_dict[key.replace("qweight", "qzeros")] + g_idx = state_dict[key.replace("qweight", "g_idx")] + + zeros = unpack_zeros(qzeros) + qweight = unpack_int32_into_fp32( + qweight=tensor, + scales=scales, + zeros=zeros, + g_idx=g_idx, + ) + transformed_weights_dict[key] = qweight + + # transform scales + for key, tensor in state_dict.items(): + if is_gptq_quantization_target(key) and key.endswith(".scales"): + # scales [1, x] should be reshaped to [x] + scales = tensor.squeeze(0) + transformed_state_dict[key] = scales + else: + transformed_state_dict[key] = tensor + + # overwrite old weights with the new quantized weights + transformed_state_dict.update(transformed_weights_dict) + + # auxillary weights_dict not needed anymore + del transformed_weights_dict + + return transformed_state_dict + + +def unpack_zeros(qzeros): + """ + Unpack the quantized zero points tensor from 32 bit integers into 4 bit integers. + + :param qzeros: The quantized zero points tensor of int32 dtype and shape [1, 8x] + """ + bits = 4 + qzeros = qzeros.numpy().astype(np.uint32) + intzeros = np.zeros( + (qzeros.shape[0], qzeros.shape[1] * 32 // bits), dtype=np.uint32 + ) + + i = 0 + col = 0 + while col < intzeros.shape[1]: + if bits in [4]: + for j in range(i, min(i + (32 // bits), intzeros.shape[1])): + intzeros[:, j] = (qzeros[:, col] >> (bits * (j - i))) & 0xF + i += 32 // bits + col += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + intzeros = intzeros.astype(np.int32) + intzeros = torch.from_numpy(intzeros) + + return intzeros + + +def unpack_int32_into_fp32( + qweight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor +) -> Tensor: + """ + Unpack the quantized weight tensor from 32 bit integers into 4 bit integers, + and then dequantize them using the scales, zeros, and g_idx tensors. + + :param qweight: The quantized weight tensor of int32 dtype and shape [x, y] + :param scales: The scales tensor + :param zeros: The zero points tensor + :param g_idx: The group index tensor + :return: The dequantized weight tensor of shape [x, 8y] + """ + bits = 4 + qweight = qweight.numpy().astype(numpy.uint32) + intweight = numpy.zeros( + (qweight.shape[0] * 32 // bits, qweight.shape[1]), dtype=numpy.uint32 + ) + + i = 0 + row = 0 + while row < intweight.shape[0]: + if bits in [4]: + for j in range(i, min(i + (32 // bits), intweight.shape[0])): + intweight[j] = (qweight[row] >> (bits * (j - i))) & 0xF + i += 32 // bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + intweight = torch.from_numpy(intweight.astype(numpy.int32)) + intweight = intweight.t().contiguous() + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + scales = scales.clone().half() + + weight = [] + infeatures = intweight.shape[1] + for idx in range(infeatures): + weight.append( + ( + intweight[:, idx].float() * scales[:, g_idx[idx]] + - scale_zeros[:, g_idx[idx]] + )[:, None] + ) + weight = torch.cat(weight, dim=1) + + return weight