From 35dfd7541d1a1a228bfc800b9d3759ce4032b897 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Wed, 4 Dec 2024 13:19:17 +0800
Subject: [PATCH] add test_q4_torch.py

---
 gptqmodel/nn_modules/qlinear/qlinear_cuda.py |  4 +-
 tests/test_q4_cuda.py                        |  2 -
 tests/test_q4_torch.py                       | 93 ++++++++++++++++++++
 tests/test_transformers_integration.py       | 19 ++--
 4 files changed, 106 insertions(+), 12 deletions(-)
 create mode 100644 tests/test_q4_torch.py

diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
index 9e1239acb..a7684bc72 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -33,12 +33,12 @@ def __init__(
         infeatures: int,
         outfeatures: int,
         bias: bool,
-        kernel_switch_threshold=128,
         weight_dtype=torch.float16,
+        kernel_switch_threshold=128,
         **kwargs,
     ):
         super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures,
-                         outfeatures=outfeatures, **kwargs)
+                         outfeatures=outfeatures, bias=bias, weight_dtype=weight_dtype, **kwargs)
 
         self.kernel_switch_threshold = kernel_switch_threshold
         self.gptqmodel_cuda_available = _gptqmodel_cuda_available
diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py
index bb1b6b48d..190e172a6 100644
--- a/tests/test_q4_cuda.py
+++ b/tests/test_q4_cuda.py
@@ -17,7 +17,6 @@ class TestsQ4CUDA(unittest.TestCase):
 
     @parameterized.expand(
         [
-            (torch.float32, "cpu"),
             (torch.float32, "cuda:0"),
             (torch.float16, "cuda:0"),
         ]
@@ -65,7 +64,6 @@ def test_generation_desc_act_true(self, torch_dtype, device):
 
     @parameterized.expand(
         [
-            (torch.float32, "cpu"),
             (torch.float32, "cuda:0"),
             (torch.float16, "cuda:0"),
         ]
diff --git a/tests/test_q4_torch.py b/tests/test_q4_torch.py
new file mode 100644
index 000000000..fcb22672e
--- /dev/null
+++ b/tests/test_q4_torch.py
@@ -0,0 +1,93 @@
+# -- do not touch
+import os
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import unittest  # noqa: E402
+
+import torch  # noqa: E402
+from gptqmodel import BACKEND, GPTQModel  # noqa: E402
+from parameterized import parameterized  # noqa: E402
+from transformers import AutoTokenizer  # noqa: E402
+
+GENERATE_EVAL_SIZE = 100
+
+
+class TestsQ4Torch(unittest.TestCase):
+    @parameterized.expand(
+        [
+            (torch.float32, "cpu"),
+        ]
+    )
+    def test_generation_desc_act_true(self, torch_dtype, device):
+        prompt = "I am in Paris and"
+
+        # CPU implementation is extremely slow.
+        new_tokens = 5
+        reference_output = " I am in Paris and I am in love with"
+
+        model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+        revision = "desc_act_true"
+
+        model_q = GPTQModel.from_quantized(
+            model_id,
+            revision=revision,
+            device=device,
+            backend=BACKEND.TORCH,
+            torch_dtype=torch_dtype,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        # This one uses Autocast.
+        res = model_q.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
+        predicted_text = tokenizer.decode(res[0])
+        print("predicted_text", predicted_text)
+        print("reference_output", reference_output)
+        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])
+
+        # This one does not.
+        res = model_q.model.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
+        predicted_text = tokenizer.decode(res[0])
+        print("predicted_text", predicted_text)
+        print("reference_output", reference_output)
+        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])
+
+    @parameterized.expand(
+        [
+            (torch.float32, "cpu"),
+        ]
+    )
+    def test_generation_desc_act_false(self, torch_dtype, device):
+        prompt = "I am in Paris and"
+
+        # CPU implementation is extremely slow.
+        new_tokens = 5
+        reference_output = " I am in Paris and I am in love with"
+
+        model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+
+        model_q = GPTQModel.from_quantized(
+            model_id,
+            device=device,
+            backend=BACKEND.TORCH,
+            torch_dtype=torch_dtype,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        inp = tokenizer(prompt, return_tensors="pt").to(device)
+
+        # This one uses Autocast.
+        res = model_q.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
+        predicted_text = tokenizer.decode(res[0])
+
+        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])
+
+        # This one does not.
+        res = model_q.model.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
+        predicted_text = tokenizer.decode(res[0])
+
+        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])
diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py
index 7dd15bf96..70a2abd7b 100644
--- a/tests/test_transformers_integration.py
+++ b/tests/test_transformers_integration.py
@@ -35,19 +35,22 @@ def _test_quantize(self, device_map):
         gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
 
         quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map, quantization_config=gptq_config)
 
-        quantized_model.save_pretrained("./opt-125m-gptq")
-        tokenizer.save_pretrained("./opt-125m-gptq")
-        model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map=device_map)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            quantized_model.save_pretrained(tmp_dir)
+            tokenizer.save_pretrained(tmp_dir)
+            del quantized_model
 
-        generate_str = tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0])
+            model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
 
-        expect_str = "gptqmodel is a good way to get a good way for a good way for a good way."
+            generate_str = tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0])
 
-        print('generate_str',generate_str)
-        print('expect_str',expect_str)
+            expect_str = "gptqmodel is a good way to get a good way for a good way for a good way."
 
-        self.assertEqual(generate_str[:40], expect_str[:40])
+            print('generate_str',generate_str)
+            print('expect_str',expect_str)
+
+            self.assertEqual(generate_str[:40], expect_str[:40])
 
     def test_load_quantized_model_gptq_v1_ipex(self):
         self._test_load_quantized_model_gptq_v1(device_map="cpu")
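
Note for reviewers: below is a minimal, self-contained sketch of the CPU/float32 load path that the new tests/test_q4_torch.py exercises. The local model path and the BACKEND.TORCH value are copied from the test above; running the sketch assumes that model directory exists on the host, and it is not part of the patch itself.

    import torch
    from gptqmodel import BACKEND, GPTQModel
    from transformers import AutoTokenizer

    # Local path taken from tests/test_q4_torch.py; substitute any GPTQ 4-bit checkpoint available locally.
    model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

    # Select the TORCH backend on CPU in float32, mirroring the new test.
    model_q = GPTQModel.from_quantized(model_id, device="cpu", backend=BACKEND.TORCH, torch_dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inp = tokenizer("I am in Paris and", return_tensors="pt").to("cpu")
    out = model_q.generate(**inp, num_beams=1, max_new_tokens=5)
    print(tokenizer.decode(out[0]))

The new file can also be run on its own with, for example, python -m pytest tests/test_q4_torch.py -v (assuming pytest is used as the project's test runner).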