[CI] FIX test perplexity fail (ModelCloud#160)
* fix not defined error

* fix test_perplexity fail

* modify dataset filter text length

* modify assert the difference of ppl

* modify dataset filter with text length
ZYC-ModelCloud authored Jul 4, 2024
1 parent fb388f3 commit d5c1024
Showing 2 changed files with 56 additions and 34 deletions.
3 changes: 2 additions & 1 deletion gptqmodel/utils/perplexity.py
@@ -70,7 +70,8 @@ def _prepare_data(self):
             self._dataset_name = "wikitext-2-raw-v1"
 
         # Load the dataset
-        data = load_dataset(self._dataset_path, self._dataset_name, split=self._split).filter(lambda x: len(x[self._text_column]) >= 512).select(range(1024))
+        length = 512 if self._dataset_path == "wikitext" else 2048
+        data = load_dataset(self._dataset_path, self._dataset_name, split=self._split).filter(lambda x: len(x[self._text_column]) >= length).select(range(1024))
         # Format the text column of the dataset
         text_list = [" \n" if s == "" else s for s in data[self._text_column]]
         return "".join(text_list)
87 changes: 54 additions & 33 deletions tests/test_perplexity.py
@@ -10,37 +10,37 @@
 import unittest  # noqa: E402
 
 from datasets import load_dataset  # noqa: E402
-from parameterized import parameterized  # noqa: E402
-from transformers import AutoTokenizer  # noqa: E402
 
 from gptqmodel import GPTQModel  # noqa: E402
 from gptqmodel.quantization import FORMAT, QuantizeConfig  # noqa: E402
 from gptqmodel.utils import Perplexity  # noqa: E402
+from parameterized import parameterized  # noqa: E402
+from transformers import AutoTokenizer, AutoModelForCausalLM  # noqa: E402
 
 
 class TestPerplexity(unittest.TestCase):
-    NATIVE_MODEL_ID = "ModelCloud/tinyllama-15M-stories"
+    TINYLLAMA_MODEL_ID = "ModelCloud/tinyllama-15M-stories"
+    OPT_MODEL_ID = "facebook/opt-125m"
 
-    # DATASET_PATH = "wikitext"
-    # DATASET_NAME = "wikitext-2-raw-v1"
-    # DATASET_SPLIT = "test"
-    # DATASET_COLUMN = "text"
-    DATASET_PATH = "skeskinen/TinyStories-hf"
-    DATASET_NAME = "default"
-    DATASET_SPLIT = "train"
-    DATASET_COLUMN = "text"
+    OPT_DATASET_PATH = "wikitext"
+    OPT_DATASET_NAME = "wikitext-2-raw-v1"
+    OPT_DATASET_SPLIT = "test"
+    OPT_DATASET_COLUMN = "text"
+    TINYLLAMA_DATASET_PATH = "skeskinen/TinyStories-hf"
+    TINYLLAMA_DATASET_NAME = "default"
+    TINYLLAMA_DATASET_SPLIT = "train"
+    TINYLLAMA_DATASET_COLUMN = "text"
 
     N_CTX = 512
     N_BATCH = 512
 
-    def calculate_avg_ppl(self, model, tokenizer):
+    def calculate_avg_ppl(self, model, tokenizer, format: FORMAT):
         ppl = Perplexity(
             model=model,
             tokenizer=tokenizer,
-            dataset_path=self.DATASET_PATH,
-            # dataset_name=self.DATASET_NAME,
-            split=self.DATASET_SPLIT,
-            text_column=self.DATASET_COLUMN,
+            dataset_path=self.OPT_DATASET_PATH if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_PATH,
+            dataset_name=self.OPT_DATASET_NAME if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_NAME,
+            split=self.OPT_DATASET_SPLIT if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_SPLIT,
+            text_column=self.OPT_DATASET_COLUMN if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_COLUMN,
         )
 
         all = ppl.calculate(n_ctx=self.N_CTX, n_batch=self.N_BATCH)
@@ -52,27 +52,43 @@ def calculate_avg_ppl(self, model, tokenizer):

     @classmethod
     def setUpClass(self):
-        from transformers import AutoModelForCausalLM
-        self.tokenizer = AutoTokenizer.from_pretrained(self.NATIVE_MODEL_ID, use_fast=True)
+        self.tinyllama_tokenizer = AutoTokenizer.from_pretrained(self.TINYLLAMA_MODEL_ID, use_fast=True)
 
-        if not self.tokenizer.pad_token_id:
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if not self.tinyllama_tokenizer.pad_token_id:
+            self.tinyllama_tokenizer.pad_token_id = self.tinyllama_tokenizer.eos_token_id
+
+        self.opt_tokenizer = AutoTokenizer.from_pretrained(self.OPT_MODEL_ID, use_fast=True)
+
+        if not self.opt_tokenizer.pad_token_id:
+            self.opt_tokenizer.pad_token_id = self.opt_tokenizer.eos_token_id
+
+        self.tinyllama_calibration_dataset, self.tinyllama_native_ppl = self.calculate_native_ppl(self, self.tinyllama_tokenizer, FORMAT.GPTQ)
+        self.opt_calibration_dataset, self.opt_native_ppl = self.calculate_native_ppl(self, self.opt_tokenizer, FORMAT.MARLIN)
+
+    def calculate_native_ppl(self, tokenizer, format):
+        model_id = self.OPT_MODEL_ID if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_MODEL_ID
         model = AutoModelForCausalLM.from_pretrained(
-            self.NATIVE_MODEL_ID,
+            model_id,
             device_map="auto",
         )
 
-        self.native_ppl = self.calculate_avg_ppl(self, model, self.tokenizer)
+        native_ppl = self.calculate_avg_ppl(self, model, tokenizer, format)
 
-        print(f"Native PPL: {self.native_ppl}")
+        print(f"{model_id} Native PPL: {native_ppl}")
 
         # 4090: [wikitext-2-raw-v1, test, text, 512, 512] data split, tinyllama ppl == 8.4790, opt ppl == 30.02
         # assert self.native_ppl < 30.5
 
-        traindata = load_dataset(self.DATASET_PATH, split=self.DATASET_SPLIT).filter(lambda x: len(x[self.DATASET_COLUMN]) >= 512)
-        # traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(lambda x: len(x['text']) >= 512)
-        self.calibration_dataset = [self.tokenizer(example[self.DATASET_COLUMN]) for example in traindata.select(range(1024))]
+        dataset_path = self.OPT_DATASET_PATH if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_PATH
+        dataset_name = self.OPT_DATASET_NAME if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_NAME
+        dataset_split = self.OPT_DATASET_SPLIT if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_SPLIT
+        dataset_column = self.OPT_DATASET_COLUMN if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_DATASET_COLUMN
+
+        length = 512 if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else 2048
+        traindata = load_dataset(dataset_path, dataset_name, split=dataset_split).filter(lambda x: len(x[dataset_column]) >= length)
+        calibration_dataset = [self.tinyllama_tokenizer(example[dataset_column]) for example in traindata.select(range(1024))]
+        return calibration_dataset, native_ppl

     @parameterized.expand(
         [
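
A note on the conditionals this commit introduces throughout the test: the check format == FORMAT.MARLIN or format == FORMAT.BITBLAS is repeated in roughly a dozen ternaries. A hypothetical helper, not part of this commit, shows how the predicate could be named once:

# Hypothetical refactor (not in this commit): name the predicate once.
def uses_opt_path(format: FORMAT) -> bool:
    # MARLIN and BITBLAS kernels are exercised against opt-125m/wikitext;
    # every other format runs the tinyllama-15M/TinyStories path.
    return format in (FORMAT.MARLIN, FORMAT.BITBLAS)

# e.g.:
# model_id = self.OPT_MODEL_ID if uses_opt_path(format) else self.TINYLLAMA_MODEL_ID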
@@ -95,11 +95,12 @@ def test_quantized_perplexity(self, format: FORMAT):
         quantize_config.desc_act = False
 
         model = GPTQModel.from_pretrained(
-            self.NATIVE_MODEL_ID,
+            self.OPT_MODEL_ID if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.TINYLLAMA_MODEL_ID,
             quantize_config=quantize_config,
         )
 
-        model.quantize(self.calibration_dataset, batch_size=256)
+        dataset = self.opt_calibration_dataset if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.tinyllama_calibration_dataset
+        model.quantize(dataset, batch_size=256)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_quantized(
@@ -113,11 +130,15 @@
device_map="auto",
)

quantized_ppl = self.calculate_avg_ppl(model, self.tokenizer)
quantized_ppl = self.calculate_avg_ppl(model, self.opt_tokenizer if format == FORMAT.MARLIN or format == FORMAT.BITBLAS else self.tinyllama_tokenizer, format)

print(f"Format {format}, Quantized PPL: {quantized_ppl}")

# 4090: [wikitext-2-raw-v1, test, text, 512, 512] data split
# FORMAT.GTPQ and FORMAT.GTPQ_V2 Tinyllama ppl == 8.7863, FORMAT.MARLIN Tinyllama ppl == 9.0036
# FORMAT.GTPQ and FORMAT.GTPQ_V2 opt ppl == 30.91, FORMAT.MARLIN opt ppl == 31.09
assert abs(quantized_ppl - self.native_ppl) < 1.1
# FORMAT.MARLIN opt ppl == 34.85, FORMAT.BITBLAS opt ppl == 34.11, native opt ppl == 30.39
# FORMAT.GTPQ and FORMAT.GTPQ_V2 Tinyllama-15M ppl == 111.32, native Tinyllama-15M ppl == 54.61
if format == FORMAT.MARLIN or format == FORMAT.BITBLAS:
assert abs(quantized_ppl - self.opt_native_ppl) < 4.7
else:
assert abs(quantized_ppl - self.tinyllama_native_ppl) < 56.8
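
End to end, the updated test quantizes a base model, round-trips it through disk, and compares quantized perplexity against the native baseline captured in setUpClass. A condensed sketch of that flow, assuming the gptqmodel API as it appears in this diff (from_pretrained / quantize / save_quantized plus a from_quantized loader) and illustrative quantization settings:

import tempfile

from datasets import load_dataset
from gptqmodel import GPTQModel
from gptqmodel.quantization import FORMAT, QuantizeConfig
from gptqmodel.utils import Perplexity
from transformers import AutoTokenizer

model_id = "facebook/opt-125m"  # the MARLIN/BITBLAS path of the test
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Calibration set, built the same way as in setUpClass (n.b. the committed
# test tokenizes both paths with the tinyllama tokenizer).
traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test").filter(lambda x: len(x["text"]) >= 512)
calibration_dataset = [tokenizer(example["text"]) for example in traindata.select(range(1024))]

# Illustrative config values; the test parameterizes over FORMAT and
# sets desc_act = False.
quantize_config = QuantizeConfig(bits=4, group_size=128, format=FORMAT.MARLIN)

model = GPTQModel.from_pretrained(model_id, quantize_config=quantize_config)
model.quantize(calibration_dataset, batch_size=256)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_quantized(tmp_dir)
    model = GPTQModel.from_quantized(tmp_dir, device_map="auto")

ppl = Perplexity(
    model=model,
    tokenizer=tokenizer,
    dataset_path="wikitext",
    dataset_name="wikitext-2-raw-v1",
    split="test",
    text_column="text",
)
# calculate_avg_ppl in the test averages what calculate() returns per chunk.
values = ppl.calculate(n_ctx=512, n_batch=512)
print(sum(values) / len(values))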
