From c6996975ad51fd7fc8b5decec3d3e42f4947da0c Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 23 May 2024 16:35:34 +0000 Subject: [PATCH 1/6] Update tests; diff updated on compressed tensors side --- .../pytorch/modifiers/pruning/sparsegpt/test_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index e52b6e2ef23..814d146cdda 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -100,7 +100,7 @@ def test_create_default_quant_modifier(self): self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) # input activations are symmetric by default in vLLMQuantizationModifier assert should_be_default_quant_scheme.input_activations.symmetric - + self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) assert should_be_default_quant_scheme.weights.symmetric From 08b39dc68f2603a9a6433d1a318ec24a8311ca6c Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Thu, 23 May 2024 17:25:41 +0000 Subject: [PATCH 2/6] Style --- .../pytorch/modifiers/pruning/sparsegpt/test_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 814d146cdda..e52b6e2ef23 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -100,7 +100,7 @@ def test_create_default_quant_modifier(self): self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8) # input activations are symmetric by default in vLLMQuantizationModifier assert should_be_default_quant_scheme.input_activations.symmetric - + self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8) assert should_be_default_quant_scheme.weights.symmetric From d6709ddb086520e410850f08bf96a6ad08f9fe05 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Mon, 27 May 2024 12:26:49 +0000 Subject: [PATCH 3/6] Initial commit --- src/sparseml/evaluation/integrations/perplexity.py | 2 +- src/sparseml/modifiers/quantization/gptq/pytorch.py | 3 ++- src/sparseml/pytorch/utils/sparsification.py | 3 +++ src/sparseml/utils/pytorch/module.py | 1 - 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sparseml/evaluation/integrations/perplexity.py b/src/sparseml/evaluation/integrations/perplexity.py index b3ae2d12ec4..5d8b3d5c1fc 100644 --- a/src/sparseml/evaluation/integrations/perplexity.py +++ b/src/sparseml/evaluation/integrations/perplexity.py @@ -61,7 +61,7 @@ def perplexity_eval( dataset_config_name = _infer_dataset_config_name(datasets) task = "text-generation" split = kwargs.pop("split", None) - model = SparseAutoModelForCausalLM.from_pretrained(model_path) + model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") tokenizer = SparseAutoTokenizer.from_pretrained(model_path) input_text = _load_perplexity_dataset( diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py index 6f1c9f40bbd..854e14b8e4f 100644 --- a/src/sparseml/modifiers/quantization/gptq/pytorch.py +++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py @@ -23,7 +23,7 @@ from 
sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper from sparseml.modifiers.utils.layer_compressor import LayerCompressor from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward - +from sparseml.utils.fsdp.context import fix_fsdp_module_name __all__ = ["GPTQModifierPyTorch"] @@ -116,6 +116,7 @@ def initialize_compression( self.layer_compressors_ = [] for idx, (name, layer) in enumerate(self.compressible_layers_.items()): + name = fix_fsdp_module_name(name) _LOGGER.info(f"Preparing {name} for compression") args = self._pruning_arguments() comp_cls = self._compression_class() diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index f22750c85c6..9891752143f 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -69,6 +69,9 @@ def __init__( self.state_dict = state_dict if self.state_dict is not None: + # when analyzing an FSDP model, the state_dict does not differentiate between + # trainable and non-trainable parameters (e.g. it can contain buffers) + # this means that the self.trainable_parameters may be overestimated self.trainable_params = [param for _, param in state_dict.items()] else: self.trainable_params = list( diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index 780f1255db1..437a2e723f2 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -188,7 +188,6 @@ def get_layer(target: str, module: Module) -> Tuple[str, Module]: def set_layer(target: str, layer: Module, module: Module) -> Module: - target = fix_fsdp_module_name(target) with summon_full_params_context(module): # importing here to avoid circular import from sparseml.utils.fsdp.helpers import maybe_get_wrapped From eee2526306cc05c910ce29ec331b88eae47506b1 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Tue, 28 May 2024 07:03:29 +0000 Subject: [PATCH 4/6] fix the FSDP name stripping --- src/sparseml/utils/fsdp/context.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/sparseml/utils/fsdp/context.py b/src/sparseml/utils/fsdp/context.py index d6a3063f05c..9805028db2d 100644 --- a/src/sparseml/utils/fsdp/context.py +++ b/src/sparseml/utils/fsdp/context.py @@ -30,7 +30,7 @@ "fix_fsdp_module_name", ] -FSDP_WRAPPER_NAME = "_fsdp_wrapped_module." +FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" def summon_full_params_context(model, offload_to_cpu: bool = False): @@ -66,4 +66,12 @@ def fix_fsdp_module_name(name: str) -> str: :param name: name to strip :return: stripped name """ - return name.replace(FSDP_WRAPPER_NAME, "") + if FSDP_WRAPPER_NAME + "." in name: + # accounting for the scenario, where the FSDP_WRAPPER_NAME + # is not the last part of the name + return name.replace(FSDP_WRAPPER_NAME + ".", "") + elif "." + FSDP_WRAPPER_NAME in name: + # accounting for the scenario, where the FSDP_WRAPPER_NAME + # is the last part of the name + return name.replace("." 
+ FSDP_WRAPPER_NAME, "") + return name From 874f7c064858c8788cd7f1c4026932be1b654873 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Tue, 28 May 2024 07:09:08 +0000 Subject: [PATCH 5/6] cleanup after rebase --- src/sparseml/evaluation/integrations/perplexity.py | 2 +- src/sparseml/modifiers/quantization/gptq/pytorch.py | 1 + src/sparseml/pytorch/utils/sparsification.py | 7 ++++--- src/sparseml/utils/fsdp/context.py | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sparseml/evaluation/integrations/perplexity.py b/src/sparseml/evaluation/integrations/perplexity.py index 5d8b3d5c1fc..b3ae2d12ec4 100644 --- a/src/sparseml/evaluation/integrations/perplexity.py +++ b/src/sparseml/evaluation/integrations/perplexity.py @@ -61,7 +61,7 @@ def perplexity_eval( dataset_config_name = _infer_dataset_config_name(datasets) task = "text-generation" split = kwargs.pop("split", None) - model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto") + model = SparseAutoModelForCausalLM.from_pretrained(model_path) tokenizer = SparseAutoTokenizer.from_pretrained(model_path) input_text = _load_perplexity_dataset( diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py index 854e14b8e4f..2eb14e8d2d7 100644 --- a/src/sparseml/modifiers/quantization/gptq/pytorch.py +++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py @@ -25,6 +25,7 @@ from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward from sparseml.utils.fsdp.context import fix_fsdp_module_name + __all__ = ["GPTQModifierPyTorch"] _LOGGER = logging.getLogger(__name__) diff --git a/src/sparseml/pytorch/utils/sparsification.py b/src/sparseml/pytorch/utils/sparsification.py index 9891752143f..9542c730a0b 100644 --- a/src/sparseml/pytorch/utils/sparsification.py +++ b/src/sparseml/pytorch/utils/sparsification.py @@ -69,9 +69,10 @@ def __init__( self.state_dict = state_dict if self.state_dict is not None: - # when analyzing an FSDP model, the state_dict does not differentiate between - # trainable and non-trainable parameters (e.g. it can contain buffers) - # this means that the self.trainable_parameters may be overestimated + # when analyzing an FSDP model, the state_dict does not differentiate + # between trainable and non-trainable parameters + # (e.g. it can contain buffers) this means that the + # self.trainable_parameters may be overestimated self.trainable_params = [param for _, param in state_dict.items()] else: self.trainable_params = list( diff --git a/src/sparseml/utils/fsdp/context.py b/src/sparseml/utils/fsdp/context.py index 9805028db2d..ee18510d8be 100644 --- a/src/sparseml/utils/fsdp/context.py +++ b/src/sparseml/utils/fsdp/context.py @@ -67,7 +67,7 @@ def fix_fsdp_module_name(name: str) -> str: :return: stripped name """ if FSDP_WRAPPER_NAME + "." in name: - # accounting for the scenario, where the FSDP_WRAPPER_NAME + # accounting for the scenario, where the FSDP_WRAPPER_NAME # is not the last part of the name return name.replace(FSDP_WRAPPER_NAME + ".", "") elif "." 
+ FSDP_WRAPPER_NAME in name: From d0c29207fb43b02db2d231614bdfc9de80170748 Mon Sep 17 00:00:00 2001 From: "bogunowicz@arrival.com" Date: Tue, 28 May 2024 07:11:09 +0000 Subject: [PATCH 6/6] refactoring --- src/sparseml/utils/fsdp/context.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/sparseml/utils/fsdp/context.py b/src/sparseml/utils/fsdp/context.py index ee18510d8be..6d9470e20a2 100644 --- a/src/sparseml/utils/fsdp/context.py +++ b/src/sparseml/utils/fsdp/context.py @@ -61,17 +61,13 @@ def main_process_first_context(): def fix_fsdp_module_name(name: str) -> str: """ - Remove FSDP wrapper prefixes from a module name + Remove FSDP wrapper prefixes from a module name. + Accounts for scenario where FSDP_WRAPPER_NAME is + at the end of the name, as well as in the middle. :param name: name to strip :return: stripped name """ - if FSDP_WRAPPER_NAME + "." in name: - # accounting for the scenario, where the FSDP_WRAPPER_NAME - # is not the last part of the name - return name.replace(FSDP_WRAPPER_NAME + ".", "") - elif "." + FSDP_WRAPPER_NAME in name: - # accounting for the scenario, where the FSDP_WRAPPER_NAME - # is the last part of the name - return name.replace("." + FSDP_WRAPPER_NAME, "") - return name + return name.replace(FSDP_WRAPPER_NAME + ".", "").replace( + "." + FSDP_WRAPPER_NAME, "" + )
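
For reference, a minimal standalone sketch (not part of the patch series; the module names below are hypothetical examples of FSDP-wrapped names) showing how the refactored fix_fsdp_module_name from the final commit behaves:

    FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"

    def fix_fsdp_module_name(name: str) -> str:
        # strip the wrapper both when it appears mid-name and when it is the last part
        return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
            "." + FSDP_WRAPPER_NAME, ""
        )

    # wrapper in the middle of the name
    assert fix_fsdp_module_name(
        "model._fsdp_wrapped_module.layers.0.self_attn"
    ) == "model.layers.0.self_attn"

    # wrapper as the last part of the name
    assert fix_fsdp_module_name("model.layers.0._fsdp_wrapped_module") == "model.layers.0"

    # a name without the wrapper is returned unchanged
    assert fix_fsdp_module_name("model.layers.0.mlp") == "model.layers.0.mlp"

Note that patch 4 drops the trailing dot from FSDP_WRAPPER_NAME precisely so the second case, where the wrapper is the final component of the name, can also be stripped.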