Commit

group size
horheynm committed May 23, 2024
1 parent a81edca commit fdb60b3
Showing 2 changed files with 93 additions and 0 deletions.
examples/llama7b_sparse_quantized/2:4_w4a16_recipe-group.yaml (39 additions, 0 deletions)
@@ -0,0 +1,39 @@
sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
finetuning_stage:
  run_type: train
  finetuning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    vLLMQuantizationModifier:
      ignore: ["lm_head"]
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "group"
            group_size: 128
          targets: ["Linear"]
    SparseGPTModifier:
      sparsity: 0.0
      quantize: True
      sequential_update: false
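
The key change in this recipe is strategy: "group" with group_size: 128, i.e. each row of a quantized Linear weight gets an independent int4 scale for every block of 128 columns, rather than a single per-channel or per-tensor scale. The snippet below is a minimal standalone sketch of symmetric per-group quantization to illustrate the idea; it is not SparseML's internal implementation, and the helper name quantize_per_group is invented for illustration.

import torch

def quantize_per_group(weight: torch.Tensor, num_bits: int = 4, group_size: int = 128):
    # Split each row into contiguous groups of `group_size` columns and give
    # every group its own symmetric scale, mirroring strategy: "group" above.
    q_max = 2 ** (num_bits - 1) - 1          # 7 for symmetric int4
    rows, cols = weight.shape
    assert cols % group_size == 0, "columns must divide evenly into groups"
    grouped = weight.reshape(rows, cols // group_size, group_size)
    scales = grouped.abs().amax(dim=-1, keepdim=True) / q_max
    scales = scales.clamp(min=1e-8)          # avoid division by zero
    q = torch.clamp(torch.round(grouped / scales), -q_max - 1, q_max)
    dequantized = (q * scales).reshape(rows, cols)
    return q.to(torch.int8), scales.squeeze(-1), dequantized

A smaller group_size tracks the weight distribution more closely (better accuracy) at the cost of storing more scales; 128 is a common middle ground.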
examples/llama7b_sparse_quantized/llama7b_sparse_w4a16-group.py (54 additions, 0 deletions)
@@ -0,0 +1,54 @@
import torch

from sparseml.transformers import SparseAutoModelForCausalLM, apply


# define a recipe to handle sparsity, finetuning and quantization
recipe = "2:4_w4a16_recipe-group.yaml"

# load the model in as bfloat16 to save on memory and compute
model_stub = "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base"
model = SparseAutoModelForCausalLM.from_pretrained(
    model_stub, torch_dtype=torch.bfloat16, device_map="auto"
)

# uses SparseML's built-in preprocessing for UltraChat
dataset = "ultrachat-200k"

# save location of quantized model
output_dir = "output_llama7b_2:4_w4a16_group"

# set dataset config parameters
splits = {"calibration": "train_gen[:5%]", "train": "train_gen"}
max_seq_length = 512
num_calibration_samples = 512

# set training parameters for finetuning
num_train_epochs = 1
logging_steps = 500
save_steps = 5000
gradient_checkpointing = True # saves memory during training
learning_rate = 0.0001
bf16 = True # using bfloat16 for training
lr_scheduler_type = "cosine"
warmup_ratio = 0.1

# this will run the recipe stage by stage:
# oneshot sparsification -> finetuning -> oneshot quantization
apply(
    model=model,
    dataset=dataset,
    recipe=recipe,
    bf16=bf16,
    output_dir=output_dir,
    splits=splits,
    max_seq_length=max_seq_length,
    num_calibration_samples=num_calibration_samples,
    num_train_epochs=num_train_epochs,
    logging_steps=logging_steps,
    save_steps=save_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
)
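
# Hypothetical follow-up sketch: once apply() finishes, the compressed model
# is written under output_dir and can be reloaded with the same class. The
# exact on-disk layout (e.g. a per-stage subdirectory) is an assumption and
# may differ in practice.
compressed = SparseAutoModelForCausalLM.from_pretrained(
    output_dir, torch_dtype=torch.bfloat16, device_map="auto"
)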
