Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
ZX-ModelCloud committed Jul 25, 2024
1 parent 9fd034e commit c552d25
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 13 deletions.
43 changes: 43 additions & 0 deletions tests/test_save_loaded_quantized_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from parameterized import parameterized
from gptqmodel import GPTQModel,BACKEND

MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"

class TestSave(unittest.TestCase):
@parameterized.expand(
[
(BACKEND.AUTO),
(BACKEND.EXLLAMA_V2),
(BACKEND.EXLLAMA),
(BACKEND.TRITON),
(BACKEND.BITBLAS),
(BACKEND.MARLIN),
(BACKEND.QBITS),
]
)
def test_save(self, backend):
prompt = "I am in Paris and"
device = torch.device("cuda:0") if backend != BACKEND.QBITS else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inp = tokenizer(prompt, return_tensors="pt").to(device)

# origin model produce correct output
origin_model = GPTQModel.from_quantized(MODEL_ID, backend=backend)
origin_model_res = origin_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
origin_model_predicted_text = tokenizer.decode(origin_model_res[0])

origin_model.save_quantized("./test_reshard")

# saved model produce wrong output
new_model = GPTQModel.from_quantized("./test_reshard", backend=backend)

new_model_res = new_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
new_model_predicted_text = tokenizer.decode(new_model_res[0])

print("origin_model_predicted_text",origin_model_predicted_text)
print("new_model_predicted_text",new_model_predicted_text)

self.assertEqual(origin_model_predicted_text[:20], new_model_predicted_text[:20])
14 changes: 1 addition & 13 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,11 @@ def test_marlin_local_serialization(self):

self.assertTrue(os.path.isfile(os.path.join(tmpdir, "gptq_model-4bit-128g.safetensors")))

with open(os.path.join(tmpdir, QUANT_CONFIG_FILENAME), "r") as config_file:
config = json.load(config_file)

self.assertTrue(config[FORMAT_FIELD_JSON] == FORMAT.MARLIN)

model = GPTQModel.from_quantized(tmpdir, device="cuda:0", backend=BACKEND.MARLIN)

def test_marlin_hf_cache_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=BACKEND.MARLIN)
self.assertEqual(model.quantize_config.format, FORMAT.MARLIN)

model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0", backend=BACKEND.MARLIN)
self.assertEqual(model.quantize_config.format, FORMAT.MARLIN)

def test_gptq_v1_to_v2_runtime_convert(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")
self.assertEqual(model.quantize_config.format, FORMAT.GPTQ_V2)
self.assertEqual(model.runtime_format, FORMAT.GPTQ_V2)

def test_gptq_v1_serialization(self):
model = GPTQModel.from_quantized(self.MODEL_ID, device="cuda:0")
Expand Down

0 comments on commit c552d25

Please sign in to comment.