add TestLmHeadQuant
ZX-ModelCloud committed Jan 10, 2025
1 parent cff67f3 commit de7c0ab
Showing 2 changed files with 83 additions and 4 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/base.py
@@ -449,7 +449,7 @@ def store_input_hook(_, args, kwargs):
        one_kwargs[k] = nested_move_to(v, data_device)
    layer_input_kwargs.append(one_kwargs)

-   if not self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
+   if not self.quantize_config.lm_head or self.quantize_config.lm_head_low_gpu_mem_usage:
        raise ValueError

    lm_head_inputs = []
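For reference, a minimal sketch (not part of the commit) contrasting the two guards above: the old check raised only when both flags were off, while the new check raises unless lm_head quantization is enabled and the low-GPU-memory path is off.

# Hedged sketch: enumerate both flags and compare the old and new guard conditions.
from itertools import product

for lm_head, low_gpu_mem in product([False, True], repeat=2):
    old_raises = not lm_head and not low_gpu_mem
    new_raises = not lm_head or low_gpu_mem
    print(f"lm_head={lm_head} low_gpu_mem={low_gpu_mem} old={old_raises} new={new_raises}")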
85 changes: 82 additions & 3 deletions tests/test_lm_head.py
@@ -1,15 +1,20 @@
# -- do not touch
import os
+import tempfile
import unittest

+from datasets import load_dataset
+from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch

-from gptqmodel import GPTQModel  # noqa: E402
+import vllm
+from gptqmodel import GPTQModel, QuantizeConfig  # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear  # noqa: E402
from models.model_test import ModelTest  # noqa: E402


-class TestLmHead(ModelTest):
+class TestLmHeadLoad(ModelTest):
    NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse"  # "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse"
    DEVICE = "cuda:0"
    NATIVE_ARC_CHALLENGE_ACC = 0.2799
@@ -24,3 +29,77 @@ def test_load(self):

    def test_eval(self):
        self.quant_lm_eval()


class TestLmHeadQuant(ModelTest):
    APPLY_CHAT_TEMPLATE = True

    sample_length = 1024
    samples = 128
    model_id = "Qwen/Qwen1.5-1.8B-Chat"

    @classmethod
    def setUpClass(cls):
        calibration_dataset = load_dataset(
            "allenai/c4",
            data_files="en/c4-train.00001-of-01024.json.gz",
            split="train"
        ).filter(lambda x: len(x["text"]) >= cls.sample_length).select(range(cls.samples))["text"]

        # Truncating sample text to reduce memory usage
        cls.calibration_dataset = [c[:cls.sample_length] for c in calibration_dataset]

    def test_quant_lm_head(self):
        self.NATIVE_ARC_CHALLENGE_ACC = 0.3148464163822526
        self.NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3310580204778157

        quant_config = QuantizeConfig(bits=4, group_size=32, lm_head=True)

        model = GPTQModel.load(self.model_id, quant_config)

        model.quantize(self.calibration_dataset, batch_size=8)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.tokenizer.save_pretrained(tmp_dir)
            model.save(tmp_dir)

            del model.tokenizer
            del model

            model = GPTQModel.load(
                tmp_dir,
                device_map="auto",
            )

            task_results = self.lm_eval(model=model,
                                        apply_chat_template=self.APPLY_CHAT_TEMPLATE,
                                        trust_remote_code=self.TRUST_REMOTE_CODE,
                                        delete_quantized_model=self.DELETE_QUANTIZED_MODEL)
            self.check_results(task_results)

    def test_quant_lm_head_low_gpu(self):
        self.NATIVE_ARC_CHALLENGE_ACC = 0.3199658703071672
        self.NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3225255972696246
        quant_config = QuantizeConfig(bits=4, group_size=32, lm_head=True, lm_head_low_gpu_mem_usage=True)

        model = GPTQModel.load(self.model_id, quant_config)

        model.quantize(self.calibration_dataset, batch_size=8)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.tokenizer.save_pretrained(tmp_dir)
            model.save(tmp_dir)

            del model.tokenizer
            del model

            model = GPTQModel.load(
                tmp_dir,
                device_map="auto",
            )

            task_results = self.lm_eval(model=model,
                                        apply_chat_template=self.APPLY_CHAT_TEMPLATE,
                                        trust_remote_code=self.TRUST_REMOTE_CODE,
                                        delete_quantized_model=self.DELETE_QUANTIZED_MODEL)
            self.check_results(task_results)
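Taken together, the two tests exercise the same user-facing flow and differ only in the lm_head_low_gpu_mem_usage flag. A hedged end-to-end sketch of that flow, using only calls that appear in the tests above (calibration text and output path are placeholders, not from the commit):

from gptqmodel import GPTQModel, QuantizeConfig

# lm_head=True also quantizes the output head; add lm_head_low_gpu_mem_usage=True
# to use the lower-GPU-memory path exercised by the second test.
quant_config = QuantizeConfig(bits=4, group_size=32, lm_head=True)

model = GPTQModel.load("Qwen/Qwen1.5-1.8B-Chat", quant_config)
model.quantize(["sample calibration text ..."] * 128, batch_size=8)  # placeholder data
model.save("./qwen-1.8b-4bit-lm_head")  # placeholder output path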
