add test_q4_torch.py
ZX-ModelCloud committed Dec 4, 2024
1 parent 9b44f80 commit 35dfd75
Showing 4 changed files with 106 additions and 12 deletions.
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -33,12 +33,12 @@ def __init__(
         infeatures: int,
         outfeatures: int,
         bias: bool,
-        kernel_switch_threshold=128,
         weight_dtype=torch.float16,
+        kernel_switch_threshold=128,
         **kwargs,
     ):
         super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures,
-                         outfeatures=outfeatures, **kwargs)
+                         outfeatures=outfeatures, bias=bias, weight_dtype=weight_dtype, **kwargs)
 
         self.kernel_switch_threshold = kernel_switch_threshold
         self.gptqmodel_cuda_available = _gptqmodel_cuda_available
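The hunk above reorders the keyword defaults so weight_dtype comes before kernel_switch_threshold and forwards bias and weight_dtype to the parent constructor. A minimal, self-contained sketch of that forwarding pattern follows; the class names are illustrative stand-ins rather than gptqmodel's real classes, and only the parameter names come from the diff.

# Illustrative stand-ins only -- not gptqmodel's real class names.
import torch

class BaseQuantLinear:
    def __init__(self, bits, group_size, sym, desc_act, infeatures, outfeatures,
                 bias=False, weight_dtype=torch.float16, **kwargs):
        # In this sketch the base class owns bias/weight_dtype handling.
        self.bias, self.weight_dtype = bias, weight_dtype
        self.infeatures, self.outfeatures = infeatures, outfeatures

class CudaStyleQuantLinear(BaseQuantLinear):
    def __init__(self, bits, group_size, sym, desc_act, infeatures, outfeatures,
                 bias, weight_dtype=torch.float16, kernel_switch_threshold=128, **kwargs):
        # Mirrors the diff: forward bias/weight_dtype upward, keep only the
        # CUDA-specific kernel_switch_threshold here.
        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act,
                         infeatures=infeatures, outfeatures=outfeatures,
                         bias=bias, weight_dtype=weight_dtype, **kwargs)
        self.kernel_switch_threshold = kernel_switch_threshold

layer = CudaStyleQuantLinear(bits=4, group_size=128, sym=True, desc_act=False,
                             infeatures=64, outfeatures=64, bias=True)
assert layer.weight_dtype is torch.float16 and layer.kernel_switch_threshold == 128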
2 changes: 0 additions & 2 deletions tests/test_q4_cuda.py
@@ -17,7 +17,6 @@
 class TestsQ4CUDA(unittest.TestCase):
     @parameterized.expand(
         [
-            (torch.float32, "cpu"),
             (torch.float32, "cuda:0"),
             (torch.float16, "cuda:0"),
         ]
@@ -65,7 +64,6 @@ def test_generation_desc_act_true(self, torch_dtype, device):
 
     @parameterized.expand(
         [
-            (torch.float32, "cpu"),
             (torch.float32, "cuda:0"),
             (torch.float16, "cuda:0"),
         ]
93 changes: 93 additions & 0 deletions tests/test_q4_torch.py
@@ -0,0 +1,93 @@
# -- do not touch
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch

import unittest # noqa: E402

import torch # noqa: E402
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402

GENERATE_EVAL_SIZE = 100


class TestsQ4Torch(unittest.TestCase):
    @parameterized.expand(
        [
            (torch.float32, "cpu"),
        ]
    )
    def test_generation_desc_act_true(self, torch_dtype, device):
        prompt = "I am in Paris and"

        # CPU implementation is extremely slow.
        new_tokens = 5
        reference_output = "<s> I am in Paris and I am in love with"

        model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
        revision = "desc_act_true"

        model_q = GPTQModel.from_quantized(
            model_id,
            revision=revision,
            device=device,
            backend=BACKEND.TORCH,
            torch_dtype=torch_dtype,
        )

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        inp = tokenizer(prompt, return_tensors="pt").to(device)

        # This one uses Autocast.
        res = model_q.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
        predicted_text = tokenizer.decode(res[0])
        print("predicted_text", predicted_text)
        print("reference_output", reference_output)
        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])

        # This one does not.
        res = model_q.model.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
        predicted_text = tokenizer.decode(res[0])
        print("predicted_text", predicted_text)
        print("reference_output", reference_output)
        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])

    @parameterized.expand(
        [
            (torch.float32, "cpu"),
        ]
    )
    def test_generation_desc_act_false(self, torch_dtype, device):
        prompt = "I am in Paris and"

        # CPU implementation is extremely slow.
        new_tokens = 5
        reference_output = "<s> I am in Paris and I am in love with"

        model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

        model_q = GPTQModel.from_quantized(
            model_id,
            device=device,
            backend=BACKEND.TORCH,
            torch_dtype=torch_dtype,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        inp = tokenizer(prompt, return_tensors="pt").to(device)

        # This one uses Autocast.
        res = model_q.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
        predicted_text = tokenizer.decode(res[0])

        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])

        # This one does not.
        res = model_q.model.generate(**inp, num_beams=1, min_new_tokens=new_tokens, max_new_tokens=new_tokens)
        predicted_text = tokenizer.decode(res[0])

        self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE])
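Not part of the commit: a small sketch of running only this new file with the standard unittest loader, assuming it is executed from the repository root.

# Run only tests/test_q4_torch.py via unittest discovery (assumes repo root as cwd).
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_q4_torch.py")
unittest.TextTestRunner(verbosity=2).run(suite)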
19 changes: 11 additions & 8 deletions tests/test_transformers_integration.py
@@ -35,19 +35,22 @@ def _test_quantize(self, device_map):
         gptq_config = GPTQConfig(bits=4, dataset=dataset, tokenizer=tokenizer)
         quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device_map,
                                                                quantization_config=gptq_config)
-        quantized_model.save_pretrained("./opt-125m-gptq")
-        tokenizer.save_pretrained("./opt-125m-gptq")
-
-        model = AutoModelForCausalLM.from_pretrained("./opt-125m-gptq", device_map=device_map)
-
-        generate_str = tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0])
-
-        expect_str = "</s>gptqmodel is a good way to get a good way for a good way for a good way."
-
-        print('generate_str',generate_str)
-        print('expect_str',expect_str)
-
-        self.assertEqual(generate_str[:40], expect_str[:40])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            quantized_model.save_pretrained(tmp_dir)
+            tokenizer.save_pretrained(tmp_dir)
+            del quantized_model
+
+            model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
+
+            generate_str = tokenizer.decode(model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0])
+
+            expect_str = "</s>gptqmodel is a good way to get a good way for a good way for a good way."
+
+            print('generate_str',generate_str)
+            print('expect_str',expect_str)
+
+            self.assertEqual(generate_str[:40], expect_str[:40])
 
     def test_load_quantized_model_gptq_v1_ipex(self):
         self._test_load_quantized_model_gptq_v1(device_map="cpu")
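For reference, a standalone sketch of the save-to-a-temporary-directory-and-reload pattern the test now uses, so repeated runs leave no ./opt-125m-gptq directory behind; the model id below is a placeholder for illustration.

import tempfile

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"  # placeholder; the test quantizes its own model first
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Everything written here is deleted automatically when the block exits.
    model.save_pretrained(tmp_dir)
    tokenizer.save_pretrained(tmp_dir)
    del model  # free the in-memory copy before reloading from disk

    reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir)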
