From fa6234bc9e73428510dcd0dcba84c0225a612fe9 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Fri, 7 Jun 2024 14:45:10 +0000
Subject: [PATCH] Fix docstrings

---
 .../pytorch/converters/transformations.py    | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/sparseml/utils/pytorch/converters/transformations.py b/src/sparseml/utils/pytorch/converters/transformations.py
index fb9f26f38ee..9a96a847b87 100644
--- a/src/sparseml/utils/pytorch/converters/transformations.py
+++ b/src/sparseml/utils/pytorch/converters/transformations.py
@@ -50,7 +50,9 @@ def is_gptq_quantization_target(key: str) -> bool:
 @_log_transformation
 def transform_exllama_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
     """
-    Transforms the state_dict keys to match with exllama format
+    Transforms the exllama state_dict keys to be compatible with
+    SparseAutoModel classes.
+
     The renames include:
         - scales -> weight_fake_quant.scale
         - qzeros -> weight_fake_quant.zero_point
@@ -85,17 +87,16 @@ def transform_autogptq_weights_and_reshape_tensors(
     state_dict: Dict[str, Tensor]
 ) -> Dict[str, Tensor]:
     """
-    Tranforms weights into their required shapes and types for Exllama format
+    Transforms weights into their required shapes and types for Exllama
+    to CompressedTensors conversion
+
     The transformations include:
-        - Quantize the weight tensor using the scales, zeros, and g_idx tensors
-          additonally pack a group of 8 of them into a single 32 bit integer
-          and rename the tensor to qweight
-        - Reshape the scales tensor to [1, x] and convert to fp16
-        - Reshape the zero points tensor to [1, x] of type int32 and fill with zeros
-          (it is assumed that quantization was symmetric)
+        - Unpack and dequantize the weight tensor using the scales, zeros, and g_idx tensors
+        - Squeeze the scales tensor to [x] from [1, x]
+
     :pre-condition: The state_dict should be for a quantized model
-    :pre-condition: The state_dict should have been transformed to exllama names
     :pre-condition: The state_dict should have the bias and g_idx tensors added
+
     :param state_dict: The state_dict to be transformed
     :return: The transformed state_dict, with repacked and reshaped tensors
     """
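
For illustration only (not part of the patch), below is a minimal sketch of the key
renaming that the transform_exllama_names docstring describes, assuming simple suffix
replacement on state_dict keys; the names rename_exllama_keys and EXLLAMA_RENAMES are
hypothetical and do not reflect SparseML's actual implementation.

    from typing import Dict

    from torch import Tensor

    # Hypothetical mapping of exllama tensor-name suffixes to the
    # SparseAutoModel-compatible names listed in the docstring.
    EXLLAMA_RENAMES = {
        "scales": "weight_fake_quant.scale",
        "qzeros": "weight_fake_quant.zero_point",
    }


    def rename_exllama_keys(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return a copy of state_dict with exllama key suffixes renamed."""
        renamed = {}
        for key, value in state_dict.items():
            for old_suffix, new_suffix in EXLLAMA_RENAMES.items():
                if key.endswith(old_suffix):
                    key = key[: -len(old_suffix)] + new_suffix
                    break
            renamed[key] = value
        return renamed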