
Commit

Merge branch 'main' into damian/fix_fsdp
dbogunowicz authored May 28, 2024
2 parents d0c2920 + 56b7854 commit 14ed0f6
Showing 7 changed files with 99 additions and 38 deletions.
74 changes: 74 additions & 0 deletions examples/llama7b_quantize_sparse_cnn.py
@@ -0,0 +1,74 @@
import torch
from datasets import load_dataset

from sparseml.transformers import (
    SparseAutoModelForCausalLM,
    SparseAutoTokenizer,
    oneshot,
)


# define a sparseml recipe for GPTQ W4A16 quantization
recipe = """
quant_stage:
    quant_modifiers:
        GPTQModifier:
            sequential_update: false
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 4
                        type: "int"
                        symmetric: true
                        strategy: "channel"
                    targets: ["Linear"]
"""

# load in a 50% sparse model with 2:4 sparsity structure
# setting device_map to auto to spread the model evenly across all available GPUs
model_stub = "neuralmagic/SparseLlama-2-7b-cnn-daily-mail-pruned_50.2of4"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)
tokenizer = SparseAutoTokenizer.from_pretrained(model_stub)

# for quantization calibration, we will use a subset of the dataset that was used to
# sparsify and finetune the model
dataset = load_dataset("abisee/cnn_dailymail", "1.0.0", split="train[:5%]")

# set dataset config parameters
max_seq_length = 4096
pad_to_max_length = False
num_calibration_samples = 1024


# preprocess the data into a single text entry, then tokenize the dataset
def process_sample(sample):
    formatted = "Article:\n{}\n\n### Summarization:\n{}".format(
        sample["article"], sample["highlights"]
    )
    return tokenizer(
        formatted, padding=pad_to_max_length, max_length=max_seq_length, truncation=True
    )


tokenized_dataset = dataset.map(
    process_sample, remove_columns=["article", "highlights", "id"]
)

# save location of quantized model out
output_dir = "./llama7b_sparse_24_w4a16_channel_compressed"

# apply the quantization recipe to the model and save the quantized output in int4 packed format
# the sparsity structure of the original model will be maintained
oneshot(
    model=model,
    dataset=tokenized_dataset,
    recipe=recipe,
    output_dir=output_dir,
    max_seq_length=max_seq_length,
    pad_to_max_length=pad_to_max_length,
    num_calibration_samples=num_calibration_samples,
    save_compressed=True,
)
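
For context, a minimal sketch (not part of this commit) of how the compressed output directory written by this example could be loaded back for a quick generation check. The prompt and generation settings are illustrative assumptions, and compressed-checkpoint loading support may vary by SparseML version:

import torch
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer

# reload the W4A16 compressed checkpoint written by the example above
model = SparseAutoModelForCausalLM.from_pretrained(
    "./llama7b_sparse_24_w4a16_channel_compressed",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# the tokenizer is unchanged, so the original model stub can supply it
tokenizer = SparseAutoTokenizer.from_pretrained(
    "neuralmagic/SparseLlama-2-7b-cnn-daily-mail-pruned_50.2of4"
)

# hypothetical article text, formatted the same way as the calibration samples
prompt = "Article:\nThe quick brown fox jumped over the lazy dog.\n\n### Summarization:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))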
7 changes: 2 additions & 5 deletions examples/llama7b_sparse_quantized/2:4_w4a16_recipe.yaml
@@ -22,7 +22,8 @@ finetuning_stage:
 quantization_stage:
   run_type: oneshot
   quantization_modifiers:
-    vLLMQuantizationModifier:
+    GPTQModifier:
+      sequential_update: false
       ignore: ["lm_head"]
       config_groups:
         group_0:
@@ -32,7 +33,3 @@ quantization_stage:
             symmetric: true
             strategy: "channel"
           targets: ["Linear"]
-    SparseGPTModifier:
-      sparsity: 0.0
-      quantize: True
-      sequential_update: false
4 changes: 2 additions & 2 deletions examples/llama7b_sparse_quantized/llama7b_sparse_w4a16.py
@@ -24,12 +24,12 @@
 num_calibration_samples = 512

 # set training parameters for finetuning
-num_train_epochs = 1
+num_train_epochs = 0.5
 logging_steps = 500
 save_steps = 5000
 gradient_checkpointing = True  # saves memory during training
 learning_rate = 0.0001
-bf16 = True  # using bfloat16 for training
+bf16 = False  # using full precision for training
 lr_scheduler_type = "cosine"
 warmup_ratio = 0.1

14 changes: 5 additions & 9 deletions examples/llama7b_w4a16_quantization.ipynb
@@ -25,10 +25,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. Below we create a sample recipe for GPTQ quantization. The recipe is made up of two different algorithms, called modifiers.\n",
+    "SparseML uses recipes to define configurations for different oneshot algorithms. Recipes can be defined as a string or a yaml file. A recipe consists of one or more sparsification or quantization algorithms, called modifiers in SparseML. Below we create a sample recipe for GPTQ quantization that only requires a single modifier.\n",
     "\n",
-    "1. **vLLMQuantizationModifier**: calibrates the model for quantization by calculating scale and zero points from a small amount of calibration data\n",
-    "2. **SparseGPTModifier**: applies the GPTQ algorithm, using the result of the vLLMQuantizationModifier to determine the best quantization bin to place each linear weight into"
+    "This modifier specifies that we should quantize the weights of each linear layer to 4 bits, using a symmetric channelwise quantization pattern. The lm-head will not be quantized even though it is a Linear layer, because it is included in the ignore list."
    ]
   },
   {
@@ -37,10 +36,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "recipe=\"\"\"\n",
+    "recipe = \"\"\"\n",
     "quant_stage:\n",
     "    quant_modifiers:\n",
-    "        vLLMQuantizationModifier:\n",
+    "        GPTQModifier:\n",
+    "            sequential_update: false\n",
     "            ignore: [\"lm_head\"]\n",
     "            config_groups:\n",
     "                group_0:\n",
@@ -50,10 +50,6 @@
     "                        symmetric: true\n",
     "                        strategy: \"channel\"\n",
     "                    targets: [\"Linear\"]\n",
-    "        SparseGPTModifier:\n",
-    "            sparsity: 0.0\n",
-    "            quantize: True\n",
-    "            sequential_update: false\n",
     "\"\"\""
    ]
   },
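
As the notebook's markdown cell notes, a recipe can also live in a standalone YAML file rather than an inline string. A minimal sketch of that variant under assumed names (the recipe path, model stub, and output directory are illustrative, not part of this commit):

import torch
from sparseml.transformers import SparseAutoModelForCausalLM, oneshot

# hypothetical YAML file containing the same GPTQModifier recipe shown above
recipe_path = "recipes/llama7b_w4a16_gptq.yaml"

# illustrative model stub; any causal LM supported by SparseML is handled the same way
model = SparseAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)

oneshot(
    model=model,
    dataset="open_platypus",  # built-in dataset name, as used by this repo's tests
    recipe=recipe_path,  # a file path is accepted in place of an inline recipe string
    output_dir="./llama7b_w4a16_from_yaml",
    num_calibration_samples=512,
    save_compressed=True,
)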
7 changes: 2 additions & 5 deletions examples/llama7b_w4a16_quantization.py
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -17,10 +18,6 @@
                         symmetric: true
                         strategy: "channel"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """

 # setting device_map to auto to spread the model evenly across all available GPUs
9 changes: 3 additions & 6 deletions examples/llama7b_w8a8_quantization.py
@@ -7,7 +7,8 @@
 recipe = """
 quant_stage:
     quant_modifiers:
-        vLLMQuantizationModifier:
+        GPTQModifier:
+            sequential_update: false
             ignore: ["lm_head"]
             config_groups:
                 group_0:
@@ -23,10 +24,6 @@
                         dynamic: True
                         strategy: "token"
                     targets: ["Linear"]
-        SparseGPTModifier:
-            sparsity: 0.0
-            quantize: true
-            sequential_update: false
 """

 # setting device_map to auto to spread the model evenly across all available GPUs
@@ -40,7 +37,7 @@
 dataset = "ultrachat-200k"

 # save location of quantized model out
-output_dir = "./output_llama7b_w8a8_channel_compressed"
+output_dir = "./output_llama7b_w8a8_channel_dynamic_compressed"

 # set dataset config parameters
 splits = {"calibration": "train_gen[:5%]"}
22 changes: 11 additions & 11 deletions tests/sparseml/transformers/finetune/test_oneshot_then_finetune.py
@@ -34,16 +34,13 @@ def setUp(self):
         self.output = Path("./finetune_output")

     def test_oneshot_then_finetune(self):
-        import torch
-
         import sparseml
-        from sparseml.transformers import oneshot, train
+        from sparseml.transformers import SparseAutoModelForCausalLM, oneshot, train

         recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml"
-        model = "Xenova/llama2.c-stories15M"
-        device = "cuda:0"
-        if not torch.cuda.is_available():
-            device = "cpu"
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            "Xenova/llama2.c-stories15M", device_map="auto"
+        )
         dataset = "open_platypus"
         concatenate_data = False
         num_calibration_samples = 64
@@ -59,11 +56,15 @@ def test_oneshot_then_finetune(self):
                 recipe=recipe_str,
                 concatenate_data=concatenate_data,
                 splits=splits,
-                oneshot_device=device,
             )

         recipe_str = "tests/sparseml/transformers/finetune/test_finetune_recipe.yaml"
-        model = self.output / "oneshot_out"
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            self.output / "oneshot_out", device_map="auto"
+        )
+        distill_teacher = SparseAutoModelForCausalLM.from_pretrained(
+            "Xenova/llama2.c-stories15M", device_map="auto"
+        )
         dataset = "open_platypus"
         concatenate_data = False
         output_dir = self.output / "finetune_out"
@@ -73,15 +74,14 @@ def test_oneshot_then_finetune(self):
         with sparseml.create_session():
             train(
                 model=model,
-                distill_teacher="Xenova/llama2.c-stories15M",
+                distill_teacher=distill_teacher,
                 dataset=dataset,
                 output_dir=output_dir,
                 num_calibration_samples=num_calibration_samples,
                 recipe=recipe_str,
                 concatenate_data=concatenate_data,
                 splits=splits,
                 max_steps=max_steps,
-                oneshot_device=device,
             )

     def tearDown(self):
