From 53541f3bc3f2a80cde54d91ff724f072a0a6a270 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Wed, 22 May 2024 14:22:01 -0400
Subject: [PATCH] Llama2 7B Quantization Examples (#2285)

* add channel quant example

* typo

* typo

* memory improvement

* false sequential

* remove tb files

* revert hessian typing

* dynamic w8a8 example

* support for packing

* sparse example

* README for sparse example

* fix README links

* remove FSDP references for now

* update path

* Update examples/llama7b_sparse_quantized/README.md

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>

* clarify examples

* update stage name

* add memory requirements

---------
---
 Makefile                                      |   2 +-
 .../2:4_w4a16_recipe.yaml                     |  38 ++++
 examples/llama7b_sparse_quantized/README.md   |  47 +++++
 .../llama7b_sparse_w4a16.py                   |  54 +++++
 examples/llama7b_w4a16_quantization.ipynb     | 185 ++++++++++++++++++
 examples/llama7b_w4a16_quantization.py        |  56 ++++++
 examples/llama7b_w8a8_quantization.py         |  62 ++++++
 .../compression/quantization_format.py        |  27 ++-
 8 files changed, 469 insertions(+), 2 deletions(-)
 create mode 100644 examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
 create mode 100644 examples/llama7b_sparse_quantized/README.md
 create mode 100644 examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
 create mode 100644 examples/llama7b_w4a16_quantization.ipynb
 create mode 100644 examples/llama7b_w4a16_quantization.py
 create mode 100644 examples/llama7b_w8a8_quantization.py

diff --git a/Makefile b/Makefile
index 2b715ffd5d3..6d2514c137c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
 .PHONY: build docs test
 
 BUILDDIR := $(PWD)
-CHECKDIRS := integrations src tests utils status setup.py
+CHECKDIRS := integrations src tests utils status examples setup.py
 CHECKGLOBS := 'integrations/**/*.py' 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' 'status/**/*.py' setup.py
 DOCDIR := docs
 MDCHECKGLOBS := 'docs/**/*.md' 'docs/**/*.rst' 'integrations/**/*.md'
diff --git a/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml b/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
new file mode 100644
index 00000000000..9969e5d77ce
--- /dev/null
+++ b/examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
@@ -0,0 +1,38 @@
+sparsity_stage:
+  run_type: oneshot
+  sparsity_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.5
+      mask_structure: "2:4"
+      sequential_update: false
+finetuning_stage:
+  run_type: train
+  finetuning_modifiers:
+    ConstantPruningModifier:
+      targets: [
+        're:.*q_proj.weight',
+        're:.*k_proj.weight',
+        're:.*v_proj.weight',
+        're:.*o_proj.weight',
+        're:.*gate_proj.weight',
+        're:.*up_proj.weight',
+        're:.*down_proj.weight',
+      ]
+      start: 0
+quantization_stage:
+  run_type: oneshot
+  quantization_modifiers:
+    vLLMQuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 4
+            type: "int"
+            symmetric: true
+            strategy: "channel"
+          targets: ["Linear"]
+    SparseGPTModifier:
+      sparsity: 0.0
+      quantize: True
+      sequential_update: false
\ No newline at end of file
diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md
new file mode 100644
index 00000000000..45a86627d20
--- /dev/null
+++ b/examples/llama7b_sparse_quantized/README.md
@@ -0,0 +1,47 @@
+# Creating a Sparse Quantized Llama7b Model
+
+The example in this folder runs in multiple stages to create a Llama 7b model with
+a 2:4 sparsity pattern and W4A16 post-training quantization (PTQ). The model is
+calibrated and trained with the ultrachat200k dataset.
+At least 75GB of GPU memory is required to run this example.
+
+## Recipe Summary
+
+The recipe used for this flow is located in [2:4_w4a16_recipe.yaml](./2:4_w4a16_recipe.yaml). It contains 3 stages that are outlined below.
+
+
+### Stage 1: Sparsification
+
+Runs the SparseGPT one-shot algorithm to prune the model to 50% sparsity with a 2:4
+sparsity pattern. This means that 2 weights out of every group of 4 weights are masked to 0.
+
+### Stage 2: Finetuning Recovery
+
+This stage runs a single epoch of training on the ultrachat200k dataset while maintaining
+the sparsity mask from stage 1. The purpose of this stage is to recover any accuracy lost
+during the sparsification process.
+
+### Stage 3: Quantization
+
+Finally, we run the GPTQ one-shot algorithm to quantize all linear weights to 4 bits
+with channel-wise quantization.
+
+## How to Run
+
+We can run the entire staged recipe with one call to SparseML's `apply` pathway. This
+will save a checkpoint of the model after each stage.
+
+```python examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py```
+
+### Compression
+
+The resulting model will be uncompressed. To save a final compressed copy of the model,
+run the following:
+
+```python
+import torch
+from sparseml.transformers import SparseAutoModelForCausalLM
+
+model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
+model.save_pretrained(compressed_output_dir, save_compressed=True)
+```
diff --git a/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py b/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
new file mode 100644
index 00000000000..f70bf20a947
--- /dev/null
+++ b/examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
@@ -0,0 +1,54 @@
+import torch
+
+from sparseml.transformers import SparseAutoModelForCausalLM, apply
+
+
+# define a recipe to handle sparsity, finetuning and quantization
+recipe = "2:4_w4a16_recipe.yaml"
+
+# load the model in as bfloat16 to save on memory and compute
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+# uses SparseML's built-in preprocessing for ultra chat
+dataset = "ultrachat-200k"
+
+# save location of quantized model
+output_dir = "output_llama7b_2:4_w4a16_channel"
+
+# set dataset config parameters
+splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
+max_seq_length = 512
+num_calibration_samples = 512
+
+# set training parameters for finetuning
+num_train_epochs = 1
+logging_steps = 500
+save_steps = 5000
+gradient_checkpointing = True  # saves memory during training
+learning_rate = 0.0001
+bf16 = True  # using bfloat16 for training
+lr_scheduler_type = "cosine"
+warmup_ratio = 0.1
+
+# this will run the recipe stage by stage:
+# oneshot sparsification -> finetuning -> oneshot quantization
+apply(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    bf16=bf16,
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=max_seq_length,
+    num_calibration_samples=num_calibration_samples,
+    num_train_epochs=num_train_epochs,
+    logging_steps=logging_steps,
+    save_steps=save_steps,
+    gradient_checkpointing=gradient_checkpointing,
+    learning_rate=learning_rate,
+    lr_scheduler_type=lr_scheduler_type,
+    warmup_ratio=warmup_ratio,
+)
diff --git a/examples/llama7b_w4a16_quantization.ipynb b/examples/llama7b_w4a16_quantization.ipynb
new file mode 100644
index 00000000000..ad1ee7af8ce
--- /dev/null
+++ b/examples/llama7b_w4a16_quantization.ipynb
@@ -0,0 +1,185 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Quantizing Llama 7B to W4A16 Using SparseML's OneShot Pathway\n",
+    "\n",
+    "This example notebook walks through how to quantize Llama 7B using SparseML. We apply int4 channel-wise quantization to all Linear layers, using UltraChat 200k as a calibration dataset.\n",
+    "\n",
+    "This example requires at least 45GB of GPU memory to run. The memory requirement can be reduced to 32GB by setting `sequential_update: true` in the recipe definition, but this will increase the runtime significantly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from sparseml.transformers import SparseAutoModelForCausalLM, oneshot"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. Below we create a sample recipe for GPTQ quantization. The recipe is made up of two different algorithms, called modifiers.\n",
+    "\n",
+    "1. **vLLMQuantizationModifier**: calibrates the model for quantization by calculating scales and zero points from a small amount of calibration data\n",
+    "2. **SparseGPTModifier**: applies the GPTQ algorithm, using the result of the vLLMQuantizationModifier to determine the best quantization bin to place each linear weight into"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "recipe=\"\"\"\n",
+    "quant_stage:\n",
+    "    quant_modifiers:\n",
+    "        vLLMQuantizationModifier:\n",
+    "            ignore: [\"lm_head\"]\n",
+    "            config_groups:\n",
+    "                group_0:\n",
+    "                    weights:\n",
+    "                        num_bits: 4\n",
+    "                        type: \"int\"\n",
+    "                        symmetric: true\n",
+    "                        strategy: \"channel\"\n",
+    "                    targets: [\"Linear\"]\n",
+    "        SparseGPTModifier:\n",
+    "            sparsity: 0.0\n",
+    "            quantize: True\n",
+    "            sequential_update: false\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next we need to initialize the model we wish to quantize, and define a dataset for calibration. We will use a llama2 7b model that has been pretrained on the ultrachat 200k dataset, and we will calibrate on that same dataset for our one shot run. \n",
+    "\n",
+    "SparseML supports several datasets, such as ultrachat-200k, out of the box. You can also pass in a tokenized `datasets.Dataset` object for custom dataset support."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# by setting the device_map to auto, we can spread the model evenly across all available GPUs\n",
+    "# load the model in as bfloat16 to save on memory and compute\n",
+    "model_stub = \"zoo:llama2-7b-ultrachat200k_llama2_pretrain-base\"\n",
+    "model = SparseAutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.bfloat16, device_map=\"auto\")\n",
+    "\n",
+    "# uses SparseML's built-in preprocessing for ultra chat\n",
+    "dataset = \"ultrachat-200k\"\n",
+    "\n",
+    "# save location of quantized model\n",
+    "output_dir = \"./output_llama7b_W4A16_channel_compressed\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we will configure our calibration dataset. To save on load time, we load only a small subset of ultrachat200k's `train_gen` split and label it as calibration data. For oneshot we do not need to pad the input, so we set `pad_to_max_length` to false. ",
+    "We also truncate each sample to a maximum of 512 tokens and select 512 samples for calibration. \n",
+    "\n",
+    "Using more calibration samples can improve model performance but will take longer to run. Generally 256-2048 calibration samples are recommended."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set dataset config parameters\n",
+    "splits = {\"calibration\": \"train_gen[:5%]\"}\n",
+    "max_seq_length = 512\n",
+    "pad_to_max_length = False\n",
+    "num_calibration_samples = 512"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we can launch our quantization recipe using the `oneshot` function. This function call will apply the algorithms defined in `recipe` to the input `model`, using `num_calibration_samples` from `dataset` as calibration data. We will save the quantized model to `output_dir`.\n",
+    "\n",
+    "By setting `save_compressed` to True, the model will be saved by packing every 8 int4 weights into a single int32. This will enable the model to be loaded by vLLM. Once a model has been saved in this way, you can no longer recover the original unquantized weights. To save the model in a \"fake quantized\" state instead so that the original weights are preserved, set `save_compressed` to False."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "oneshot(\n",
+    "    model=model,\n",
+    "    dataset=dataset,\n",
+    "    recipe=recipe,\n",
+    "    output_dir=output_dir,\n",
+    "    splits=splits,\n",
+    "    max_seq_length=max_seq_length,\n",
+    "    pad_to_max_length=pad_to_max_length,\n",
+    "    num_calibration_samples=num_calibration_samples,\n",
+    "    save_compressed=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The quantized model should now be stored in the defined `output_dir`. Its `config.json` will contain a new `compression_config` field that describes how the model has been quantized. This config will be used to load the model into vLLM.\n",
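+    "\n",
+    "As a rough, optional illustration (not run as part of this notebook, and dependent on the vLLM version you have installed), loading the compressed checkpoint for inference might look like the following sketch:\n",
+    "\n",
+    "```python\n",
+    "# hypothetical sketch -- assumes a vLLM build that understands this compression_config\n",
+    "from vllm import LLM, SamplingParams\n",
+    "\n",
+    "llm = LLM(model=output_dir)\n",
+    "outputs = llm.generate([\"What is quantization?\"], SamplingParams(max_tokens=64))\n",
+    "print(outputs[0].outputs[0].text)\n",
+    "```"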
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "output_dir"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.save_pretrained(\"./output_llama7b_W4A16_channel_packed\", save_compressed=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/llama7b_w4a16_quantization.py b/examples/llama7b_w4a16_quantization.py
new file mode 100644
index 00000000000..5aabf496436
--- /dev/null
+++ b/examples/llama7b_w4a16_quantization.py
@@ -0,0 +1,56 @@
+import torch
+
+from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
+
+
+# define a sparseml recipe for GPTQ W4A16 quantization
+recipe = """
+quant_stage:
+    quant_modifiers:
+        vLLMQuantizationModifier:
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: "channel"
+                    targets: ["Linear"]
+        SparseGPTModifier:
+            sparsity: 0.0
+            quantize: true
+            sequential_update: false
+"""
+
+# setting device_map to auto to spread the model evenly across all available GPUs
+# load the model in as bfloat16 to save on memory and compute
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+# uses SparseML's built-in preprocessing for ultra chat
+dataset = "ultrachat-200k"
+
+# save location of quantized model output
+output_dir = "./output_llama7b_w4a16_channel_compressed"
+
+# set dataset config parameters
+splits = {"calibration": "train_gen[:5%]"}
+max_seq_length = 512
+pad_to_max_length = False
+num_calibration_samples = 512
+
+# apply recipe to the model and save quantized output in an int4 packed format
+oneshot(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)
diff --git a/examples/llama7b_w8a8_quantization.py b/examples/llama7b_w8a8_quantization.py
new file mode 100644
index 00000000000..5f70a2f1ae7
--- /dev/null
+++ b/examples/llama7b_w8a8_quantization.py
@@ -0,0 +1,62 @@
+import torch
+
+from sparseml.transformers import SparseAutoModelForCausalLM, oneshot
+
+
+# define a sparseml recipe for GPTQ W8A8 quantization
+recipe = """
+quant_stage:
+    quant_modifiers:
+        vLLMQuantizationModifier:
+            ignore: ["lm_head"]
+            config_groups:
+                group_0:
+                    weights:
+                        num_bits: 8
+                        type: "int"
+                        symmetric: true
+                        strategy: "channel"
+                    input_activations:
+                        num_bits: 8
+                        type: "int"
+                        symmetric: true
+                        dynamic: True
+                        strategy: "token"
+                    targets: ["Linear"]
+        SparseGPTModifier:
+            sparsity: 0.0
+            quantize: true
+            sequential_update: false
+"""
+
+# setting device_map to auto to spread the model evenly across all available GPUs
+# load the model in as bfloat16 to save on memory and compute
+model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
+model = SparseAutoModelForCausalLM.from_pretrained(
+    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
+)
+
+# uses SparseML's built-in preprocessing for ultra chat
+dataset = "ultrachat-200k"
+
+# save location of quantized model output
+output_dir = "./output_llama7b_w8a8_channel_compressed"
+
+# set dataset config parameters
+splits = {"calibration": "train_gen[:5%]"}
+max_seq_length = 512
+pad_to_max_length = False
+num_calibration_samples = 512
+
+# apply recipe to the model and save quantized output in an int8 compressed format
+oneshot(
+    model=model,
+    dataset=dataset,
+    recipe=recipe,
+    output_dir=output_dir,
+    splits=splits,
+    max_seq_length=max_seq_length,
+    pad_to_max_length=pad_to_max_length,
+    num_calibration_samples=num_calibration_samples,
+    save_compressed=True,
+)
diff --git a/src/sparseml/transformers/compression/quantization_format.py b/src/sparseml/transformers/compression/quantization_format.py
index 5f8f8722753..48d47d68d22 100644
--- a/src/sparseml/transformers/compression/quantization_format.py
+++ b/src/sparseml/transformers/compression/quantization_format.py
@@ -16,7 +16,11 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.quantization.utils import is_model_quantized
+from compressed_tensors.quantization.utils import (
+    is_model_quantized,
+    is_module_quantized,
+    iter_named_leaf_modules,
+)
 
 
 __all__ = ["infer_quantization_format"]
@@ -42,7 +46,28 @@ def infer_quantization_format(
         return quantization_format
 
     if save_compressed:
+        quant_depths = _get_quant_depths(model)
+        if quant_depths == [4]:  # save packed if everything is int4
+            return CompressionFormat.pack_quantized
+
+        # otherwise fall back to the int8-based compressed format
         return CompressionFormat.int_quantized
    else:  # format will be inferred from config
         return None
+
+
+def _get_quant_depths(model):
+    """
+    Gets a list of all the quantized bit depths present in the model
+    """
+    quant_depths = []
+    for _, submodule in iter_named_leaf_modules(model):
+        if is_module_quantized(submodule):
+            weight_scheme = submodule.quantization_scheme.weights
+            if weight_scheme is not None:
+                weight_bit_depth = weight_scheme.num_bits
+                if weight_bit_depth not in quant_depths:
+                    quant_depths.append(weight_bit_depth)
+
+    return quant_depths