diff --git a/examples/llama7b_quantize_sparse_cnn.py b/examples/llama7b_quantize_sparse_cnn.py
new file mode 100644
index 00000000000..2ce3fb0ac6a
--- /dev/null
+++ b/examples/llama7b_quantize_sparse_cnn.py
@@ -0,0 +1,74 @@
+import torch
+from datasets import load_dataset
+
+from sparseml.transformers import (
+    SparseAutoModelForCausalLM,
+    SparseAutoTokenizer,
+    oneshot,
+)
+
+
+# define a sparseml recipe for GPTQ W4A16 quantization
+recipe = """
+quant_stage:
+    quant_modifiers:
+        GPTQModifier:
+            sequential_update: false
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: "channel"
+                    targets: ["Linear"]
+"""
+
+# load in a 50% sparse model with 2:4 sparsity structure
+# setting device_map to auto to spread the model evenly across all available GPUs
+model_stub = "neuralmagic/SparseLlama-2-7b-cnn-daily-mail-pruned_50.2of4"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+tokenizer = SparseAutoTokenizer.from_pretrained(model_stub)
+
+# for quantization calibration, we will use a subset of the dataset that was used to
+# sparsify and finetune the model
+dataset = load_dataset("abisee/cnn_dailymail", "1.0.0", split="train[:5%]")
+
+# set dataset config parameters
+max_seq_length = 4096
+pad_to_max_length = False
+num_calibration_samples = 1024
+
+
+# preprocess the data into a single text entry, then tokenize the dataset
+def process_sample(sample):
+    formatted = "Article:\n{}\n\n### Summarization:\n{}".format(
+        sample["article"], sample["highlights"]
+    )
+    return tokenizer(
+        formatted, padding=pad_to_max_length, max_length=max_seq_length, truncation=True
+    )
+
+
+tokenized_dataset = dataset.map(
+    process_sample, remove_columns=["article", "highlights", "id"]
+)
+
+# save location of quantized model out
+output_dir = "./llama7b_sparse_24_w4a16_channel_compressed"
+
+# apply quantization recipe to the model and save quantized output int4 packed format
+# the sparsity structure of the original model will be maintained
+oneshot(
+    model=model,
+    dataset=tokenized_dataset,
+    recipe=recipe,
+    output_dir=output_dir,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)
diff --git a/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml b/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
index 9969e5d77ce..1c4d2a09802 100644
--- a/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
+++ b/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
@@ -22,7 +22,8 @@ finetuning_stage:
 quantization_stage:
   run_type: oneshot
   quantization_modifiers:
-    vLLMQuantizationModifier:
+    GPTQModifier:
+      sequential_update: false
       ignore: ["lm_head"]
       config_groups:
         group_0:
@@ -32,7 +33,3 @@ quantization_stage:
           symmetric: true
           strategy: "channel"
         targets: ["Linear"]
-    SparseGPTModifier:
-      sparsity: 0.0
-      quantize: True
-      sequential_update: false
\ No newline at end of file
diff --git a/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py b/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
index f70bf20a947..fe454a0d7ad 100644
--- a/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
+++ b/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
@@ -24,12 +24,12 @@
 num_calibration_samples = 512
 
 # set training parameters for finetuning
-num_train_epochs = 1
+num_train_epochs = 0.5
 logging_steps = 500
 save_steps = 5000
 gradient_checkpointing = True # saves memory during training
 learning_rate = 0.0001
-bf16 = True # using bfloat16 for training
+bf16 = False # using full precision for training
 lr_scheduler_type = "cosine"
 warmup_ratio = 0.1
 
diff --git a/examples/llama7b_w4a16_quantization.ipynb b/examples/llama7b_w4a16_quantization.ipynb
index ad1ee7af8ce..194215891fa 100644
--- a/examples/llama7b_w4a16_quantization.ipynb
+++ b/examples/llama7b_w4a16_quantization.ipynb
@@ -25,10 +25,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. Below we create a sample recipe for GPTQ quantization. The recipe is made up of two different algorithms, called modifiers.\n",
+    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. A recipe consists of one or more sparsification or quantization algorithms, called modifiers in SparseML. Below we create a sample recipe for GPTQ quantization that only requires a single modifier.\n",
     "\n",
-    "1. **vLLMQuantizationModifier**: calibrates the model for quantization by calculating scale and zero points from a small amount of calibration data\n",
-    "2. **SparseGPTModifier**: applies the GPTQ algorithm, using the result of the vLLMQuantizationModifier to determine the best quantization bin to place each linear weight into"
+    "This modifier specifies that we should quantize the weights of each linear layer to 4 bits, using a symmetric channelwise quantization pattern. The lm-head will not be quantized even though it is a Linear layer, because it is included in the ignore list."
   ]
  },
  {
@@ -37,10 +36,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "recipe=\"\"\"\n",
+    "recipe = \"\"\"\n",
     "quant_stage:\n",
     "    quant_modifiers:\n",
-    "        vLLMQuantizationModifier:\n",
+    "        GPTQModifier:\n",
+    "            sequential_update: false\n",
     "            ignore: [\"lm_head\"]\n",
     "            config_groups:\n",
     "                group_0:\n",
@@ -50,10 +50,6 @@
     "                        symmetric: true\n",
     "                        strategy: \"channel\"\n",
     "                    targets: [\"Linear\"]\n",
-    "        SparseGPTModifier:\n",
-    "            sparsity: 0.0\n",
-    "            quantize: True\n",
-    "            sequential_update: false\n",
     "\"\"\""
   ]
  },
diff --git a/examples/llama7b_w4a16_quantization.py b/examples/llama7b_w4a16_quantization.py
index 5aabf496436..a4a5f6bbb53 100644
--- a/examples/llama7b_w4a16_quantization.py
+++ b/examples/llama7b_w4a16_quantization.py
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -17,10 +18,6 @@
                         symmetric: true
                         strategy: "channel"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
diff --git a/examples/llama7b_w8a8_quantization.py b/examples/llama7b_w8a8_quantization.py
index 5f70a2f1ae7..c894613ffbb 100644
--- a/examples/llama7b_w8a8_quantization.py
+++ b/examples/llama7b_w8a8_quantization.py
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -23,10 +24,6 @@
                         dynamic: True
                         strategy: "token"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """
 
 # setting device_map to auto to spread the model evenly across all available GPUs
@@ -40,7 +37,7 @@
 dataset = "ultrachat-200k"
 
 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_compressed"
+output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"
 
 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
diff --git a/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py b/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py
index eb010083343..ef7a1b30a5b 100644
--- a/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py
+++ b/tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py
@@ -34,16 +34,13 @@ def setUp(self):
         self.output = Path("./finetune_output")
 
     def test_oneshot_then_finetune(self):
-        import torch
-
         import sparseml
-        from sparseml.transformers import oneshot, train
+        from sparseml.transformers import SparseAutoModelForCausalLM, oneshot, train
 
         recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml"
-        model = "Xenova/llama2.c-stories15M"
-        device = "cuda:0"
-        if not torch.cuda.is_available():
-            device = "cpu"
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            "Xenova/llama2.c-stories15M", device_map="auto"
+        )
         dataset = "open_platypus"
         concatenate_data = False
         num_calibration_samples = 64
@@ -59,11 +56,15 @@
                 recipe=recipe_str,
                 concatenate_data=concatenate_data,
                 splits=splits,
-                oneshot_device=device,
             )
 
         recipe_str = "tests/sparseml/transformers/finetune/test_finetune_recipe.yaml"
-        model = self.output / "oneshot_out"
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            self.output / "oneshot_out", device_map="auto"
+        )
+        distill_teacher = SparseAutoModelForCausalLM.from_pretrained(
+            "Xenova/llama2.c-stories15M", device_map="auto"
+        )
         dataset = "open_platypus"
         concatenate_data = False
         output_dir = self.output / "finetune_out"
@@ -73,7 +74,7 @@
         with sparseml.create_session():
             train(
                 model=model,
-                distill_teacher="Xenova/llama2.c-stories15M",
+                distill_teacher=distill_teacher,
                 dataset=dataset,
                 output_dir=output_dir,
                 num_calibration_samples=num_calibration_samples,
@@ -81,7 +82,6 @@
                 concatenate_data=concatenate_data,
                 splits=splits,
                 max_steps=max_steps,
-                oneshot_device=device,
             )
 
     def tearDown(self):