diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index 0ef843967..811b6149d 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -44,7 +44,7 @@ env:
MAX_JOBS: 8
RUNNER: 10.0.14.248
TRANSFORMERS_DIFF_TESTS: "models/test_internlm,models/test_internlm2_5,models/test_xverse"
- TORCH_2_5_TESTS: "test_q4_ipex.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
+ TORCH_2_5_TESTS: "test_q4_ipex.py,test_ipex_xpu.py,test_save_loaded_quantized_model,test_quant_formats,models/test_hymba"
IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral"
GPTQMODEL_FORCE_BUILD: 1
repo: ${{ github.event.inputs.repo || github.repository }}
@@ -190,7 +190,9 @@ jobs:
- name: Install requirements
run: |
+ echo "===== init test env ====="
bash -c "$(curl -L http://$RUNNER/files/scripts/init_unit_tests.sh)" @ 12.4 2.4.1 3.11
+ echo "===== install transformers typing-extensions ====="
uv pip install transformers typing-extensions -U -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
- name: Compile
@@ -302,20 +304,24 @@ jobs:
- name: Install wheel
run: |
- uv pip install optimum bitblas==0.0.1.dev13 parameterized uvicorn -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
+ echo "===== install optimum bitblas ====="
+ uv pip install optimum bitblas==0.0.1.dev13 uvicorn -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
+ echo "===== install dist/whl ====="
uv pip install dist/*.whl
if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ]; then
+ echo "===== install auto_round ====="
uv pip install auto_round
fi
bash -c "$(curl -L http://$RUNNER/files/scripts/init_unit_tests.sh)" @ 12.4 2.4.1 3.11
- uv pip install typing-extensions numpy==1.26.4 -U -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
if [ "${{ matrix.test_script }}" == "test_cohere2" ]; then
+ echo "===== install transformers from git ====="
uv pip install -U git+https://github.com/huggingface/transformers.git@5615a393691c81e00251e420c73e4d04c6fe22e5
else
+ echo "===== install transformers from pypi ====="
uv pip install transformers -U
fi
-
-
+ echo "===== install typing-extensions numpy==1.26.4 ====="
+ uv pip install typing-extensions numpy==1.26.4 -U -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
- name: Check platform
run: |
@@ -427,14 +433,20 @@ jobs:
- name: Install wheel
run: |
+ echo "===== install optimum bitblas parameterized uvicorn ====="
uv pip install optimum bitblas==0.0.1.dev13 parameterized uvicorn -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
+ echo "===== install dist/whl ====="
uv pip install dist/*.whl
+ echo "===== init test env ====="
bash -c "$(curl -L http://$RUNNER/files/scripts/init_unit_tests.sh)" @ 12.4 2.4.1 3.11
+ echo "===== install transformers==4.38.2 typing-extensions numpy==1.26.4 peft==0.13.2 ====="
uv pip install transformers==4.38.2 typing-extensions numpy==1.26.4 peft==0.13.2 -U -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
if [ "${{ matrix.test_script }}" = "test_xverse" ]; then
+ echo "===== install tokenizers==0.15.2 ====="
uv pip install tokenizers==0.15.2 -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
fi
if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ]; then
+ echo "===== install auto_round ====="
uv pip install auto_round
fi
@@ -474,8 +486,10 @@ jobs:
runs-on: self-hosted
if: always() && !cancelled() && (needs.build.result == 'success' || github.event.inputs.artifact_id != '') && needs.list-test-files.outputs.torch-2-5-files != '[]'
container:
- image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v2-torch2.5.1
+ image: ${{ needs.check-vm.outputs.ip }}:5000/modelcloud/gptqmodel:github-ci-v3-torch2.5.1
+ options: --device /dev/dri --ipc=host
volumes:
+ - /dev/dri/by-path:/dev/dri/by-path
- /home/ci/models:/monster/data/model
strategy:
fail-fast: false
@@ -533,11 +547,17 @@ jobs:
- name: Install wheel
run: |
- bash -c "$(curl -L http://$RUNNER/files/scripts/init_unit_tests.sh)" @ 12.4 2.5.1 3.11
- uv pip install -U intel_extension_for_pytorch typing-extensions bitblas==0.0.1.dev13 -i http://${{ needs.check-vm.outputs.ip }}/simple/ --trusted-host ${{ needs.check-vm.outputs.ip }}
+ if [ "${{ matrix.test_script }}" == "test_ipex_xpu" ]; then
+ source /etc/profile.d/pyenv.sh && pyenv activate xpu
+ else
+ bash -c "$(curl -L http://$RUNNER/files/scripts/init_unit_tests.sh)" @ 12.4 2.5.1 3.11
+ fi
+
if [ "${{ matrix.test_script }}" == "test_quant_formats" ] || [ "${{ matrix.test_script }}" == "test_perplexity" ]; then
+ echo "===== install auto_round ====="
uv pip install auto_round
fi
+ echo "===== install dist/whl ====="
uv pip install dist/*.whl
- name: Find suitable GPU
@@ -562,7 +582,16 @@ jobs:
- name: Run tests
if: ${{ (!github.event.inputs.test_names || contains(github.event.inputs.test_names, matrix.test_script)) && !cancelled() }}
- run: pytest --durations=0 tests/${{ matrix.test_script }}.py
+ run: |
+ if [ "${{ matrix.test_script }}" == "test_ipex_xpu" ]; then
+ export CUDA_VISIBLE_DEVICES=""
+ source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
+ source $ONEAPI_ROOT/../pti/0.9/env/vars.sh
+ export Pti_DIR=$ONEAPI_ROOT/../pti/0.9/lib/cmake/pti
+ source /etc/profile.d/pyenv.sh && pyenv activate xpu
+ pip list
+ fi
+ pytest --durations=0 tests/${{ matrix.test_script }}.py
- name: Release GPU
if: always()
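The XPU leg of this job runs inside a pre-built `xpu` pyenv with the oneAPI runtime sourced and Nvidia devices hidden via `CUDA_VISIBLE_DEVICES=""`. A hedged sanity check of that environment in Python (assumes the container's PyTorch 2.5 build ships XPU support):

```python
import torch

# CUDA is masked for this job, so the Intel XPU backend should be the one picked up
assert not torch.cuda.is_available()
print("xpu available:", torch.xpu.is_available())
```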
diff --git a/README.md b/README.md
index cd83c98c2..758f3f9dd 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
## News
+* 12/16/2024 1.4.5-dev: Windows 11 support added/validated. Fixed `dynamic` loading.
* 12/15/2024 [1.4.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.2): MacOS `gpu` (Metal) and `cpu` (M+) support added/validated for inference and quantization. Cohere 2 model support added.
* 12/13/2024 [1.4.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.1): Added Qwen2-VL model support. `mse` quantization control exposed in `QuantizeConfig`. Monkey patch `patch_vllm()` and `patch_hf()` api added to allow Transformers/Optimum/PEFT and vLLM to correctly load GPTQModel quantized models while upstream PRs are in pending status.
* 12/10/2024 [1.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.0) `EvalPlus` harness integration merged upstream. We now support both `lm-eval` and `EvalPlus`. Added pure torch `Torch` kernel. Refactored `Cuda` kernel to be `DynamicCuda` kernel. `Triton` kernel now auto-padded for max model support. `Dynamic` quantization now supports both positive `+:` (default) and `-:` negative matching, which allows matched modules to be skipped entirely from quantization. Fixed auto-`Marlin` kernel selection. Added auto-kernel fallback for unsupported kernel/module pairs. Lots of internal refactor and cleanup in preparation for transformers/optimum/peft upstream PR merge. Deprecated the saving of `Marlin` weight format since `Marlin` supports auto conversion of `gptq` format to `Marlin` during runtime.
@@ -16,12 +17,14 @@
* 11/29/2024 [1.3.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.3.1) Olmo2 model support. Intel XPU acceleration via IPEX. Model sharding Transformer compat fix due to api deprecation in HF. Removed triton dependency. Triton kernel now optionally dependent on triton pkg.
* 11/26/2024 [1.3.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.3.0) Zero-Day Hymba model support. Removed `tqdm` and `rogue` dependency.
* 11/24/2024 [1.2.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.3) HF GLM model support. ClearML logging integration. Use `device-smi` and replace `gputil` + `psutil` depends. Fixed model unit tests.
-* 11/11/2024 🚀 [1.2.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.1) Meta MobileLLM model support added. `lm-eval[gptqmodel]` integration merged upstream. Intel/IPEX cpu inference merged replacing QBits (deprecated). Auto-fix/patch ChatGLM-3/GLM-4 compat with latest transformers. New `.load()` and `.save()` api.
-* 10/29/2024 🚀 [1.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.1.0) IBM Granite model support. Full auto-buildless wheel install from pypi. Reduce max cpu memory usage by >20% during quantization. 100% CI model/feature coverage.
Archived News:
+* 11/11/2024 🚀 [1.2.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.1) Meta MobileLLM model support added. `lm-eval[gptqmodel]` integration merged upstream. Intel/IPEX cpu inference merged replacing QBits (deprecated). Auto-fix/patch ChatGLM-3/GLM-4 compat with latest transformers. New `.load()` and `.save()` api.
+
+* 10/29/2024 🚀 [1.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.1.0) IBM Granite model support. Full auto-buildless wheel install from pypi. Reduce max cpu memory usage by >20% during quantization. 100% CI model/feature coverage.
+
* 10/12/2024 ✨ [1.0.9](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.9) Move AutoRound to optional and fix pip install regression in v1.0.8.
* 10/11/2024 ✨ [1.0.8](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.8) Add wheel for python 3.12 and cuda 11.8.
@@ -61,6 +64,7 @@ Public tests/papers and ModelCloud's internal tests have shown that GPTQ is on-p
## Features
* 🚀 Extensive model support for: `Llama 1-3.3`, `Qwen2-VL`, `Olmo2`, `Hymba`, `GLM`, `IBM Granite`, `Llama 3.2 Vision`, `MiniCPM3`, `GRIN-Moe`, `Phi 1-4`, `EXAONE 3.0`, `InternLM 2.5`, `Gemma 2`, `DeepSeek-V2`, `DeepSeek-V2-Lite`, `ChatGLM`, `MiniCPM`, `Qwen2MoE`, `DBRX`.
+* ✨ Linux, MacOS, and Windows platform support for quantization and accelerated inference.
* 💯 100% CI unit-test coverage for all supported models and kernels including post-quantization quality regression.
* ✨ `Dynamic`/Mixed quantization control on a per-module basis. Each layer/module can have a unique quantization config or be excluded from quantization altogether.
* 🚀 [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) inference integration for quantized model where format = `FORMAT.GPTQ`
@@ -79,7 +83,9 @@ Public tests/papers and ModelCloud's internal tests have shown that GPTQ is on-p
## Quality: GPTQModel 4bit can match BF16:
🤗 [ModelCloud quantized ultra-high recovery vortex-series models on HF](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2)
-![image](https://github.com/user-attachments/assets/aab69119-f9c8-4c94-9634-a3c63e57095e)
+![image](https://github.com/user-attachments/assets/7b2db012-b8af-4d19-a25d-7023cef19220)
+
+
## Model Support: 🚀 (Added by GPTQModel)
| Model | | | | | | | | |
@@ -96,17 +102,17 @@ Public tests/papers and ModelCloud's internal tests have shown that GPTQ is on-p
| EXAONE 3.0 | 🚀 | InternLM 1/2.5 | 🚀 | OPT | ✅ | Yi | ✅ | |
-## Kernel and HW Accelerator Support
+## Platform and HW Support
-GPTQModel is validated for Linux x86_64 with the following devices:
+GPTQModel is validated for Linux, MacOS, and Windows 11:
| Platform | Device | | Optimized Arch | Kernels |
|-----------------|---------------| --- | -------------- | -------------- |
| Linux | Nvidia GPU | ✅ | Ampere or Higher | Marlin, Exllama V2, Exllama V1, Triton, DynamicCuda, Torch |
| Linux | Intel/AMD CPU | ✅ | `avx512` or `amx` | IPEX, Torch |
| Linux | Intel XPU | ✅ | Intel Arc + Datacenter Max | IPEX, Torch |
-| MacOS | GPU (Metal) and CPU | ✅ | M1+ | Torch |
-
+| MacOS | GPU (Metal) / CPU | ✅ | M1+ | Torch |
+| Windows 11 | GPU (Nvidia) / CPU | ✅ | Nvidia | DynamicCuda, Torch |
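A minimal usage sketch of the cross-platform path the table above describes (the model id is a placeholder): `get_best_device()` resolves to `cuda`, `xpu`, `mps` or `cpu`, and the loader then auto-selects a kernel validated for that platform.

```python
from gptqmodel import GPTQModel, get_best_device

device = get_best_device()  # cuda / xpu / mps / cpu depending on the host
model = GPTQModel.load("ModelCloud/some-gptq-4bit-model", device=device)  # hypothetical model id
```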
## Install
diff --git a/examples/benchmark/generation_speed.py b/examples/benchmark/generation_speed.py
index 5bfc61db2..10c957896 100644
--- a/examples/benchmark/generation_speed.py
+++ b/examples/benchmark/generation_speed.py
@@ -7,11 +7,13 @@
import torch
from datasets import Dataset, load_dataset
-from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
-from gptqmodel.utils.progress import ProgressBar
from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+from gptqmodel.utils.progress import ProgressBar
+
+
logger = logging.getLogger(__name__)
random.seed(0)
diff --git a/examples/benchmark/ipex.py b/examples/benchmark/ipex.py
index 1fed35bef..753858d52 100644
--- a/examples/benchmark/ipex.py
+++ b/examples/benchmark/ipex.py
@@ -4,6 +4,7 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
try:
from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf
bind_cores_for_best_perf()
@@ -13,6 +14,7 @@
import argparse
+
parser = argparse.ArgumentParser(description="Benchmark IPEX vs HF on a pre-trained model.")
parser.add_argument("--model", type=str, required=True, help="Path or name of the pre-trained model.")
parser.add_argument("--cores", type=int, default=8, help="Number of CPU cores to use.")
diff --git a/examples/benchmark/perplexity.py b/examples/benchmark/perplexity.py
index 8d6c21d36..ca045ce98 100644
--- a/examples/benchmark/perplexity.py
+++ b/examples/benchmark/perplexity.py
@@ -2,9 +2,11 @@
import os
import torch
-from gptqmodel.utils import Perplexity
from transformers import AutoTokenizer
+from gptqmodel.utils import Perplexity
+
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
if __name__ == "__main__":
@@ -51,7 +53,7 @@
tokenizer.pad_token_id = tokenizer.eos_token_id
if args.is_quantized:
- from gptqmodel import GPTQModel, BACKEND
+ from gptqmodel import BACKEND, GPTQModel
model = GPTQModel.load(
args.model_name,
diff --git a/examples/evaluation/run_language_modeling_task.py b/examples/evaluation/run_language_modeling_task.py
index fb33e1d94..ed384215d 100644
--- a/examples/evaluation/run_language_modeling_task.py
+++ b/examples/evaluation/run_language_modeling_task.py
@@ -2,12 +2,13 @@
import datasets
import torch
-from gptqmodel import GPTQModel, QuantizeConfig, BACKEND
-from gptqmodel.eval_tasks import LanguageModelingTask
from transformers import AutoTokenizer
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+from gptqmodel.eval_tasks import LanguageModelingTask
from gptqmodel.utils.torch import torch_empty_cache
+
DATASET = "tatsu-lab/alpaca"
WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n"
WITHOUT_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nOutput:\n"
diff --git a/examples/evaluation/run_sequence_classification_task.py b/examples/evaluation/run_sequence_classification_task.py
index 489914fa0..f3344c858 100644
--- a/examples/evaluation/run_sequence_classification_task.py
+++ b/examples/evaluation/run_sequence_classification_task.py
@@ -3,12 +3,13 @@
import datasets
import torch
-from gptqmodel import GPTQModel, QuantizeConfig, BACKEND
-from gptqmodel.eval_tasks import SequenceClassificationTask
from transformers import AutoTokenizer
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+from gptqmodel.eval_tasks import SequenceClassificationTask
from gptqmodel.utils.torch import torch_empty_cache
+
DATASET = "cardiffnlp/tweet_sentiment_multilingual"
TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
ID2LABEL = {0: "negative", 1: "neutral", 2: "positive"}
diff --git a/examples/evaluation/run_text_summarization_task.py b/examples/evaluation/run_text_summarization_task.py
index ae44fe7ec..2357baebe 100644
--- a/examples/evaluation/run_text_summarization_task.py
+++ b/examples/evaluation/run_text_summarization_task.py
@@ -3,12 +3,13 @@
import datasets
import torch
-from gptqmodel import GPTQModel, QuantizeConfig, BACKEND
-from gptqmodel.eval_tasks import TextSummarizationTask
from transformers import AutoTokenizer, GenerationConfig
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
+from gptqmodel.eval_tasks import TextSummarizationTask
from gptqmodel.utils.torch import torch_empty_cache
+
os.system("pip install py7zr")
diff --git a/examples/inference/run_transformers.py b/examples/inference/run_transformers.py
index 348515d3a..077dc25ea 100644
--- a/examples/inference/run_transformers.py
+++ b/examples/inference/run_transformers.py
@@ -1,5 +1,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
+
tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
quantized_model = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
print(tokenizer.decode(quantized_model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(quantized_model.device))[0]))
diff --git a/examples/inference/run_with_different_backends.py b/examples/inference/run_with_different_backends.py
index 428a8a04f..5b018c036 100644
--- a/examples/inference/run_with_different_backends.py
+++ b/examples/inference/run_with_different_backends.py
@@ -3,9 +3,11 @@
import sys
from argparse import ArgumentParser
-from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_best_device
from transformers import AutoTokenizer
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_best_device
+
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "./TinyLlama/TinyLlama-1.1B-Chat-v1.0-4bit-128g"
diff --git a/examples/quantization/basic_usage.py b/examples/quantization/basic_usage.py
index d2aba4e3b..1fb6ce61d 100644
--- a/examples/quantization/basic_usage.py
+++ b/examples/quantization/basic_usage.py
@@ -1,8 +1,10 @@
import os
-from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
from transformers import AutoTokenizer
+from gptqmodel import GPTQModel, QuantizeConfig, get_best_device
+
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
diff --git a/examples/quantization/basic_usage_autoround.py b/examples/quantization/basic_usage_autoround.py
index 4b0e2e0e6..ecf1ca363 100644
--- a/examples/quantization/basic_usage_autoround.py
+++ b/examples/quantization/basic_usage_autoround.py
@@ -1,7 +1,9 @@
import torch
+from transformers import AutoTokenizer
+
from gptqmodel import GPTQModel
from gptqmodel.quantization.config import AutoRoundQuantizeConfig # noqa: E402
-from transformers import AutoTokenizer
+
pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "./autoround/TinyLlama-1.1B-Chat-v1.0-4bit-128g"
diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py
index c93af66a2..1c07aa6ed 100644
--- a/examples/quantization/basic_usage_wikitext2.py
+++ b/examples/quantization/basic_usage_wikitext2.py
@@ -1,8 +1,10 @@
import torch
from datasets import load_dataset
-from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer
+from gptqmodel import GPTQModel, QuantizeConfig
+
+
pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "TinyLlama-1.1B-Chat-v1.0-4bit-128g"
diff --git a/examples/quantization/transformers_usage.py b/examples/quantization/transformers_usage.py
index d6e279d29..03f4b5100 100755
--- a/examples/quantization/transformers_usage.py
+++ b/examples/quantization/transformers_usage.py
@@ -1,5 +1,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py
index fa771cce9..277c43d11 100644
--- a/gptqmodel/__init__.py
+++ b/gptqmodel/__init__.py
@@ -1,5 +1,5 @@
from .models import GPTQModel, get_best_device
-from .utils import BACKEND
from .quantization import BaseQuantizeConfig, QuantizeConfig
+from .utils import BACKEND
from .utils.exllama import exllama_set_max_input_length
from .version import __version__
diff --git a/gptqmodel/integration/src/optimum/gptq/quantizer.py b/gptqmodel/integration/src/optimum/gptq/quantizer.py
index f87d99d7d..4706b38f3 100644
--- a/gptqmodel/integration/src/optimum/gptq/quantizer.py
+++ b/gptqmodel/integration/src/optimum/gptq/quantizer.py
@@ -625,7 +625,7 @@ def tmp(_, input, output):
h.remove()
for name in subset_name_list:
logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...")
- quant_outputs = gptq[name].hf_quantize(
+ quant_outputs = gptq[name].fasterquant(
percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
)
scale, zero, g_idx = quant_outputs[0], quant_outputs[1], quant_outputs[2]
diff --git a/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py
index f9b3416c5..e6f0d6d15 100644
--- a/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py
+++ b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py
@@ -72,7 +72,7 @@ def validate_environment(self, *args, **kwargs):
)
elif is_gptqmodel_available() and (
version.parse(importlib.metadata.version("gptqmodel")) < version.parse("1.4.3")
- or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.99")
+ or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.3")
):
raise ImportError("The gptqmodel version should be >= 1.4.3, optimum version should >= 1.24.0")
diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py
index 97bf0bc69..b8b7bd368 100644
--- a/gptqmodel/models/_const.py
+++ b/gptqmodel/models/_const.py
@@ -1,11 +1,11 @@
-import sys
from enum import Enum
import torch
from torch import device
from ..utils import BACKEND
-from ..utils.torch import HAS_XPU, HAS_MPS, HAS_CUDA
+from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU
+
CPU = device("cpu")
CUDA = device("cuda")
@@ -15,11 +15,18 @@
MPS = device("mps")
class DEVICE(str, Enum):
+    ALL = "all" # All devices
CPU = "cpu" # All CPU
CUDA = "cuda" # Nvidia GPU
XPU = "xpu" # Intel GPU
MPS = "mps" # MacOS GPU
+class PLATFORM(str, Enum):
+    ALL = "all" # All platforms
+ LINUX = "linux" # linux
+ WIN32 = "win32" # windows
+ DARWIN = "darwin" # macos
+
def validate_cuda_support(raise_exception: bool = False):
got_cuda = HAS_CUDA
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 64149561f..6cbd9eed1 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -1,7 +1,9 @@
from __future__ import annotations
+
import os
import sys
+
# TODO: waiting for pytorch implementation of aten ops for MPS
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -68,6 +70,7 @@
from .definitions.yi import YiGPTQ
from .definitions.ovis import OvisGPTQ
+
logger = setup_logger()
MODEL_MAP = {
@@ -124,7 +127,6 @@
HAS_IPEX = False
try:
- from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear
HAS_IPEX = True
except Exception:
pass
@@ -276,10 +278,11 @@ def eval(
if task not in EVAL.get_task_enums():
raise ValueError(f"lm_eval support tasks: {EVAL.get_all_tasks_string()}")
- from gptqmodel.utils.eval import lm_eval
from lm_eval.utils import make_table
from transformers import AutoTokenizer
+ from gptqmodel.utils.eval import lm_eval
+
tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
model_name = 'hf' if backend == 'gptqmodel' else backend
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 257a8d7d8..a89701de5 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -19,14 +19,23 @@
from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory
from ..utils.importer import select_quant_linear
from ..utils.logger import setup_logger
-from ..utils.model import (check_to_quantized, find_layers, get_device, get_module_by_name_prefix,
- get_module_by_name_suffix, get_moe_layer_modules, move_to,
- nested_move_to, pack_model, simple_dispatch_model)
+from ..utils.model import (
+ check_to_quantized,
+ find_layers,
+ get_device,
+ get_module_by_name_prefix,
+ get_module_by_name_suffix,
+ get_moe_layer_modules,
+ move_to,
+ nested_move_to,
+ pack_model,
+ simple_dispatch_model,
+)
from ..utils.progress import ProgressBar
-from ._const import CPU, get_best_device, DEVICE
+from ..utils.torch import torch_empty_cache
+from ._const import CPU, DEVICE, get_best_device
from .loader import ModelLoader
from .writer import QUANT_LOG_DAMP, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter
-from ..utils.torch import torch_empty_cache
def check_support_param_buffer_assignment(*args, **kwargs):
diff --git a/gptqmodel/models/definitions/gemma2.py b/gptqmodel/models/definitions/gemma2.py
index 15cbdba01..2bde8126d 100644
--- a/gptqmodel/models/definitions/gemma2.py
+++ b/gptqmodel/models/definitions/gemma2.py
@@ -2,6 +2,7 @@
from ...utils.logger import setup_logger
from ..base import BaseGPTQModel
+
logger = setup_logger()
SUPPORT_ERR = "Currently, only vLLM/SGLang with flashinfer enabled can correctly inference a quantized Gemma2-27B model. Pre-quantized model with sample vLLM code: https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit ."
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 53ebf6bae..f5aeb4d6e 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -16,15 +16,27 @@
from ..quantization import QuantizeConfig
from ..quantization.config import FORMAT, FORMAT_FIELD_JSON, MIN_VERSION_WITH_V2
from ..utils.backend import BACKEND
-from ..utils.importer import select_quant_linear, select_device
+from ..utils.importer import select_device, select_quant_linear
from ..utils.logger import setup_logger
-from ..utils.marlin import (_validate_marlin_compatibility,
- _validate_marlin_device_support, prepare_model_for_marlin_load)
-from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, find_layers,
- get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, make_quant,
- simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
+from ..utils.marlin import (
+ _validate_marlin_compatibility,
+ _validate_marlin_device_support,
+ prepare_model_for_marlin_load,
+)
+from ..utils.model import (
+ auto_dtype_from_config,
+ convert_gptq_v1_to_v2_format,
+ find_layers,
+ get_checkpoints,
+ get_moe_layer_modules,
+ gptqmodel_post_init,
+ make_quant,
+ simple_dispatch_model,
+ verify_model_hash,
+ verify_sharded_model_hashes,
+)
from ._const import DEVICE, SUPPORTED_MODELS, normalize_device
-from ..utils.torch import HAS_CUDA, HAS_XPU, HAS_MPS
+
logger = setup_logger()
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 2f2f5a0ef..50ae49ecd 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -18,17 +18,34 @@
from transformers.modeling_utils import no_init_weights
from transformers.utils.generic import ContextManagers
-from ..quantization.config import (FORMAT, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, META_FIELD_MSE,
- META_FIELD_QUANTIZER, META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL,
- META_FIELD_URI, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2)
+from ..quantization.config import (
+ FORMAT,
+ META_FIELD_DAMP_AUTO_INCREMENT,
+ META_FIELD_DAMP_PERCENT,
+ META_FIELD_MSE,
+ META_FIELD_QUANTIZER,
+ META_FIELD_STATIC_GROUPS,
+ META_FIELD_TRUE_SEQUENTIAL,
+ META_FIELD_URI,
+ META_QUANTIZER_GPTQMODEL,
+ META_VALUE_URI,
+ MIN_VERSION_WITH_V2,
+)
from ..utils.backend import BACKEND
from ..utils.logger import setup_logger
-from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_layers,
- get_model_files_size, get_moe_layer_modules, make_quant)
+from ..utils.model import (
+ convert_gptq_v2_to_v1_format,
+ copy_py_files,
+ find_layers,
+ get_model_files_size,
+ get_moe_layer_modules,
+ make_quant,
+)
from ..utils.torch import torch_empty_cache
from ..version import __version__
from ._const import CPU
+
logger = setup_logger()
QUANT_LOG_LAYER = "layer"
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 1190fc489..40886483f 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -1,8 +1,11 @@
+import sys
from typing import List, Optional, Tuple, Union
+import torch
import torch.nn as nn
-from ...models._const import DEVICE, normalize_device
+from ...models._const import DEVICE, PLATFORM, normalize_device
+
class BaseQuantLinear(nn.Module):
SUPPORTS_BITS: List[int] = None
@@ -16,6 +19,7 @@ class BaseQuantLinear(nn.Module):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY: List[int] = None
SUPPORTS_DEVICES: List[DEVICE] = None
+ SUPPORTS_PLATFORM: List[PLATFORM] = None
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, *args,
**kwargs):
@@ -72,7 +76,11 @@ def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynami
outfeatures:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]:
cls.verify_supports_params()
- if device is not None:
+ if PLATFORM.ALL not in cls.SUPPORTS_PLATFORM and sys.platform not in cls.SUPPORTS_PLATFORM:
+ err = f"{cls} does not support platform: {sys.platform}"
+ return False, NotImplementedError(err)
+
+ if DEVICE.ALL not in cls.SUPPORTS_DEVICES and device is not None:
try:
cls.validate_device(device)
except NotImplementedError:
@@ -150,7 +158,7 @@ def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynami
return True, None
@classmethod
- def validate_device(cls, device: DEVICE):
+ def validate_device(cls, device: str|DEVICE|int|torch.device):
dev = normalize_device(device)
if dev not in cls.SUPPORTS_DEVICES:
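The new `SUPPORTS_PLATFORM` constant is checked before any device validation, so a kernel can now be ruled out purely by `sys.platform`. A hedged sketch of how a kernel declares both gates (the class is hypothetical and omits the other required `SUPPORTS_*` constants):

```python
from gptqmodel.models._const import DEVICE, PLATFORM
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

class MyKernelQuantLinear(BaseQuantLinear):
    # DEVICE.ALL / PLATFORM.ALL bypass the corresponding check in _validate()
    SUPPORTS_DEVICES = [DEVICE.CUDA]
    SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32]
    # ... remaining SUPPORTS_* constants and __init__ omitted ...
```

On macOS (`sys.platform == "darwin"`), `_validate()` would return `(False, NotImplementedError(...))` for this class before `validate_device()` is ever consulted.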
diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py
index e13763e82..00b89787a 100644
--- a/gptqmodel/nn_modules/qlinear/bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/bitblas.py
@@ -9,11 +9,13 @@
import numpy as np
import torch
import torch.nn as nn
+
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
from ...utils.logger import setup_logger
+
logger = setup_logger()
BITBLAS_TARGET = None
@@ -86,6 +88,7 @@ class BitBLASQuantLinear(BaseQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32]
OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512]
zeros_mode = "quantized" # "original" or "rescale" or "quantized"
@@ -136,12 +139,10 @@ def __init__(
self.reset_parameters()
@classmethod
- def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
- outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
- bool, Optional[Exception]]:
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if not BITBLAS_AVAILABLE:
return False, ValueError(BITBLAS_INSTALL_HINT)
- return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+ return cls._validate(**args)
def _validate_parameters(
self, group_size: int, infeatures: int, outfeatures: int
diff --git a/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py b/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py
index b98dd29ab..5034e5c22 100644
--- a/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py
+++ b/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py
@@ -8,6 +8,7 @@
from ...utils.logger import setup_logger
+
logger = setup_logger()
TARGET_MISSING_ERROR = (
diff --git a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py
index 28a29ac04..dcbbe9606 100644
--- a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py
+++ b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py
@@ -1,10 +1,12 @@
# License: GPTQModel/licenses/LICENSE.apache
import torch
+
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
from gptqmodel.utils.logger import setup_logger
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
+
logger = setup_logger()
@@ -29,6 +31,7 @@ class DynamicCudaQuantLinear(TorchQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32]
# for transformers/optimum tests compat
QUANT_TYPE = "cuda"
diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py
index 569bd9bbd..85669e037 100644
--- a/gptqmodel/nn_modules/qlinear/exllama.py
+++ b/gptqmodel/nn_modules/qlinear/exllama.py
@@ -3,15 +3,18 @@
import math
from logging import getLogger
+from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
+
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
+
exllama_import_exception = None
try:
@@ -54,6 +57,7 @@ class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
# for transformers/optimum tests compat
QUANT_TYPE = "exllama"
@@ -111,6 +115,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None
+ @classmethod
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+ if exllama_import_exception is not None:
+ return False, exllama_import_exception
+ return cls._validate(**args)
+
def post_init(self):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None
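The keyword-only `validate(cls, **args)` override added here (and to ExllamaV2/Marlin below) mirrors the refactored BitBLAS/Triton/IPEX pattern: surface the captured import failure, otherwise forward everything to `BaseQuantLinear._validate()`. A hedged call sketch (parameter values are arbitrary):

```python
from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear

ok, err = ExllamaQuantLinear.validate(bits=4, group_size=128, desc_act=False, sym=True)
if not ok:
    print(f"Exllama kernel unavailable: {err}")  # e.g. the original import exception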
diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py
index 6d977a868..7088c0279 100644
--- a/gptqmodel/nn_modules/qlinear/exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/exllamav2.py
@@ -2,14 +2,17 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
import math
+from typing import Optional, Tuple
import torch
import torch.nn.functional as F
+
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
from ...utils.logger import setup_logger
+
exllama_v2_import_exception = None
try:
from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix
@@ -115,6 +118,7 @@ class ExllamaV2QuantLinear(BaseQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
# for transformers/optimum tests compat
QUANT_TYPE = "exllamav2"
@@ -178,6 +182,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None
+ @classmethod
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+ if exllama_v2_import_exception is not None:
+ return False, exllama_v2_import_exception
+ return cls._validate(**args)
+
def post_init(self, temp_dq):
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None
diff --git a/gptqmodel/nn_modules/qlinear/ipex.py b/gptqmodel/nn_modules/qlinear/ipex.py
index 37fec5fa8..0e1c27c87 100644
--- a/gptqmodel/nn_modules/qlinear/ipex.py
+++ b/gptqmodel/nn_modules/qlinear/ipex.py
@@ -8,12 +8,14 @@
import torch
import torch.nn as nn
import transformers
-from gptqmodel.models._const import DEVICE
+
+from gptqmodel.models._const import DEVICE, PLATFORM
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from ...utils.logger import setup_logger
from ...utils.torch import HAS_XPU
+
logger = setup_logger()
BITS_DTYPE_MAPPING = {
@@ -64,6 +66,7 @@ class IPEXQuantLinear(BaseQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
# for transformers/optimum tests compat
QUANT_TYPE = "ipex"
@@ -133,23 +136,21 @@ def __init__(
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
@classmethod
- def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
- outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
- bool, Optional[Exception]]:
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if sys.platform != "linux":
return False, Exception("IPEX is only available on Linux platform.")
if not HAS_IPEX:
return False, IPEX_ERROR_LOG
- return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+ return cls._validate(**args)
def post_init(self):
self.validate_device(self.qweight.device.type)
def init_ipex_linear(self, x: torch.Tensor):
if not self.training and HAS_IPEX and not x.requires_grad:
- self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, \
- self.infeatures, self.outfeatures, None, self.bias, \
+ self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros,
+ self.infeatures, self.outfeatures, None, self.bias,
self.group_size, self.g_idx, quant_method=0, dtype=0)
def pack(self, linear, scales, zeros, g_idx=None):
diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py
index 9141d3afb..37768edd6 100644
--- a/gptqmodel/nn_modules/qlinear/marlin.py
+++ b/gptqmodel/nn_modules/qlinear/marlin.py
@@ -5,10 +5,12 @@
import numpy as np
import torch
-from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from torch.nn.parameter import Parameter
-from ...models._const import DEVICE
+from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+
+from ...models._const import DEVICE, PLATFORM
+
marlin_import_exception = None
try:
@@ -150,6 +152,7 @@ class MarlinQuantLinear(BaseQuantLinear):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX]
# for transformers/optimum tests compat
QUANT_TYPE = "marlin"
@@ -283,6 +286,12 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None
+ @classmethod
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
+ if marlin_import_exception is not None:
+ return False, marlin_import_exception
+ return cls._validate(**args)
+
def post_init(self):
device = self.qweight.device
self.validate_device(device.type)
diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py
index 47ef422c8..4fde4b552 100644
--- a/gptqmodel/nn_modules/qlinear/torch.py
+++ b/gptqmodel/nn_modules/qlinear/torch.py
@@ -7,10 +7,12 @@
import torch.nn as nn
import torch.nn.functional as F
import transformers
+
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.utils.logger import setup_logger
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
+
logger = setup_logger()
@@ -25,7 +27,8 @@ class TorchQuantLinear(BaseQuantLinear):
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]
- SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU, DEVICE.CUDA, DEVICE.MPS]
+ SUPPORTS_DEVICES = [DEVICE.ALL]
+ SUPPORTS_PLATFORM = [PLATFORM.ALL]
# for transformers/optimum tests compat
QUANT_TYPE = "torch"
diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py
index 6e617eb59..84cdacbef 100644
--- a/gptqmodel/nn_modules/qlinear/tritonv2.py
+++ b/gptqmodel/nn_modules/qlinear/tritonv2.py
@@ -10,11 +10,12 @@
import transformers
from packaging import version
-from ...models._const import DEVICE
+from ...models._const import DEVICE, PLATFORM
from ...utils.logger import setup_logger
from ..triton_utils.mixin import TritonModuleMixin
from . import BaseQuantLinear
+
try:
from triton import __version__ as triton_version
@@ -42,6 +43,7 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_DEVICES = [DEVICE.CUDA]
+ SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32]
# for transformers/optimum tests compat
QUANT_TYPE = "tritonv2"
@@ -98,12 +100,10 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
self.bias = None
@classmethod
- def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
- outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
- bool, Optional[Exception]]:
+ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if not TRITON_AVAILABLE:
return False, ValueError(TRITON_INSTALL_HINT)
- return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
+ return cls._validate(**args)
def post_init(self):
self.validate_device(self.qweight.device.type)
diff --git a/gptqmodel/nn_modules/triton_utils/custom_autotune.py b/gptqmodel/nn_modules/triton_utils/custom_autotune.py
index fde5ca2cc..9356e8b13 100644
--- a/gptqmodel/nn_modules/triton_utils/custom_autotune.py
+++ b/gptqmodel/nn_modules/triton_utils/custom_autotune.py
@@ -5,6 +5,7 @@
import triton
+
# code based https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
diff --git a/gptqmodel/nn_modules/triton_utils/kernels.py b/gptqmodel/nn_modules/triton_utils/kernels.py
index 40a45bee2..7150d34dd 100644
--- a/gptqmodel/nn_modules/triton_utils/kernels.py
+++ b/gptqmodel/nn_modules/triton_utils/kernels.py
@@ -6,6 +6,7 @@
from ...utils.logger import setup_logger
from . import custom_autotune
+
logger = setup_logger()
diff --git a/gptqmodel/quantization/__init__.py b/gptqmodel/quantization/__init__.py
index d97184ac4..a9e03bbcf 100644
--- a/gptqmodel/quantization/__init__.py
+++ b/gptqmodel/quantization/__init__.py
@@ -1,4 +1,13 @@
-from .config import (FORMAT, FORMAT_FIELD_CODE, FORMAT_FIELD_COMPAT_MARLIN, FORMAT_FIELD_JSON,
- QUANT_CONFIG_FILENAME, QUANT_METHOD, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig)
+from .config import (
+ FORMAT,
+ FORMAT_FIELD_CODE,
+ FORMAT_FIELD_COMPAT_MARLIN,
+ FORMAT_FIELD_JSON,
+ QUANT_CONFIG_FILENAME,
+ QUANT_METHOD,
+ QUANT_METHOD_FIELD,
+ BaseQuantizeConfig,
+ QuantizeConfig,
+)
from .gptq import GPTQ
from .quantizer import Quantizer, quantize
diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index 02c254a06..fc29cf35b 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -98,6 +98,21 @@ def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None:
if isinstance(value, dict):
dict_scale_dtype_to_str(value)
+
+def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], layer_name: str, key: str = None,
+ default_value: Union[int, bool] = None) -> Union[Dict, int, bool]:
+ for pattern, pattern_dict in dynamic.items():
+ if pattern.startswith("-:"):
+ if re.match(pattern.removeprefix("-:"), layer_name):
+ return False
+ elif re.match(pattern.removeprefix("+:"), layer_name):
+ if key is None:
+ return pattern_dict
+ else:
+ return pattern_dict.get(key, default_value)
+ return default_value
+
+
@dataclass
class QuantizeConfig():
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
@@ -183,16 +198,7 @@ def meta_get(self, key: str) -> Any:
return self.meta.get(key)
def dynamic_get(self, layer_name: str, key: str = None, default_value: Union[int, bool] = None) -> Union[Dict, int, bool]:
- for pattern, pattern_dict in self.dynamic.items():
- if pattern.startswith("-:"):
- if re.match(pattern.removeprefix("-:"), layer_name):
- return False
- elif re.match(pattern.removeprefix("+:"), layer_name):
- if key is None:
- return pattern_dict
- else:
- return pattern_dict.get(key, default_value)
- return default_value
+ return dynamic_get(self.dynamic, layer_name, key, default_value)
# versionable is a meta.property that pairs value with version i.e "value:1.0.0"
def meta_set_versionable(self, key: str, value: List[str]):
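`dynamic_get()` is now a module-level helper shared by `QuantizeConfig.dynamic_get()` and `utils/model.py`, so the `+:`/`-:` pattern semantics live in one place: a `-:` match returns `False` (exclude the module), a positive match returns the override dict or a single key, and no match falls through to the default. A hedged sketch of the return values (layer names and patterns are illustrative):

```python
from gptqmodel.quantization.config import dynamic_get

dynamic = {
    r"+:model\.layers\.0\..*": {"bits": 8, "group_size": 64},  # overrides for block 0
    r"-:model\.layers\.31\..*": {},                            # exclude block 31 entirely
}

dynamic_get(dynamic, "model.layers.0.self_attn.q_proj", "bits", 4)  # -> 8 (override)
dynamic_get(dynamic, "model.layers.31.mlp.gate_proj")               # -> False (excluded)
dynamic_get(dynamic, "model.layers.5.mlp.up_proj", "bits", 4)       # -> 4 (default)
```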
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index c04b445a2..647f02694 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -11,11 +11,13 @@
import transformers
from ..utils.logger import setup_logger
+from ..utils.torch import torch_empty_cache, torch_sync
from .quantizer import Quantizer
-from ..utils.torch import torch_sync, torch_empty_cache
+
logger = setup_logger()
+# TODO do we really need max precision?
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
@@ -23,20 +25,29 @@
class GPTQ:
def __init__(self, layer):
self.layer = layer
- self.dev = self.layer.weight.device
- W = layer.weight.data.clone()
+ self.device = self.layer.weight.device
+
+ self.layer_copy = self._clone_layer()
+
+ self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
+ self.H = torch.zeros((self.columns, self.columns), device=self.device)
+ self.nsamples = 0
+ self.quantizer = Quantizer()
+
+ def _clone_layer(self):
+        # mps (m1+) uses unified memory, so staging the clone on cpu saves nothing
+ if self.device.type not in ["mps", "cpu"]:
+ clone = self.layer.weight.data.cpu()
+ else:
+ clone = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
- W = W.flatten(1)
+ clone = clone.flatten(1)
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
- W = W.t()
+ clone = clone.t()
- self.rows = W.shape[0]
- self.columns = W.shape[1]
- self.H = torch.zeros((self.columns, self.columns), device=self.dev)
- self.nsamples = 0
- self.quantizer = Quantizer()
+ return clone.to(device=self.device, dtype=torch.float)
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
@@ -70,14 +81,28 @@ def add_batch(self, inp, out):
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
+ # wrapper for backward compat with optimum
+ # TODO: mark for deprecation
+ def fasterquant(
+ self,
+ blocksize=128,
+ percdamp=0.01,
+ damp_auto_increment=0.0015,
+ group_size=-1,
+ actorder=False,
+ static_groups=False,
+ ):
+ return self.hf_quantize(blocksize, percdamp, damp_auto_increment, group_size, actorder, static_groups)
+
+ # public api exposed to hf
def hf_quantize(
- self,
- blocksize=128,
- percdamp=0.01,
- damp_auto_increment=0.0015,
- group_size=-1,
- actorder=False,
- static_groups=False,
+ self,
+ blocksize=128,
+ percdamp=0.01,
+ damp_auto_increment=0.0015,
+ group_size=-1,
+ actorder=False,
+ static_groups=False,
):
return self.quantize(blocksize, percdamp, damp_auto_increment, group_size, actorder, static_groups)
@@ -91,24 +116,16 @@ def quantize(
actorder=False,
static_groups=False,
):
+ start = time.time()
# TODO: waiting for pytorch implementation of ops for MPS
if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
raise RuntimeError("For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.")
- # save mem and temp move to cpu
- self.layer.weight.data = self.layer.weight.data.cpu()
-
- W = self.layer.weight.data.clone()
-
- if isinstance(self.layer, nn.Conv2d):
- W = W.flatten(1)
-
- if isinstance(self.layer, transformers.Conv1D):
- W = W.t()
-
- W = W.to(device=self.dev, dtype=torch.float)
-
- tick = time.time()
+ if self.layer_copy is None:
+ W = self._clone_layer()
+ else:
+ W = self.layer_copy
+ self.layer_copy = None
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
@@ -119,7 +136,7 @@ def quantize(
H[dead, dead] = 1
W[:, dead] = 0
- g_idx = []
+ # g_idx = []
scale = []
zero = []
now_idx = 1
@@ -148,7 +165,7 @@ def quantize(
while 1 > percdamp > 0:
try:
damp = percdamp * torch.mean(torch.diag(H))
- diag = torch.arange(self.columns, device=self.dev)
+ diag = torch.arange(self.columns, device=self.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
@@ -217,9 +234,8 @@ def quantize(
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
- torch_sync(self.dev)
+ torch_sync(self.device)
- duration = time.time() - tick
avg_loss = torch.sum(Losses).item() / self.nsamples
if math.isnan(avg_loss):
@@ -248,7 +264,7 @@ def quantize(
self.layer.weight.data = Q.cpu().type_as(self.layer.weight.data)
# move back to self.dev
- self.layer.weight.data = self.layer.weight.data.to(device=self.dev)
+ self.layer.weight.data = self.layer.weight.data.to(device=self.device)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
@@ -260,6 +276,7 @@ def quantize(
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
+ duration = time.time() - start
return scale, zero, g_idx, duration, avg_loss, percdamp
def free(self):
@@ -271,7 +288,10 @@ def free(self):
self.Losses = None
self.Trace = None
- torch_empty_cache(self.dev)
+ self.quantizer = None
+ self.layer_copy = None
+
+ torch_empty_cache(self.device)
__all__ = ["GPTQ"]
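The optimum integration above now calls `fasterquant()`, which this wrapper forwards to `hf_quantize()`/`quantize()`; the weight clone is taken once in `__init__` via `_clone_layer()`, and the reported `duration` covers the full `quantize()` call. A hedged sketch of the call surface (assumes `gptq_obj` is a `GPTQ` instance whose quantizer has been configured and whose Hessian was accumulated through `add_batch()` during calibration):

```python
# fasterquant() -> hf_quantize() -> quantize(); all three share one signature
scale, zero, g_idx, duration, avg_loss, damp = gptq_obj.fasterquant(
    percdamp=0.01,    # Hessian dampening; damp_auto_increment raises it if Cholesky fails
    group_size=128,
    actorder=True,    # activation-order quantization (desc_act)
)
gptq_obj.free()       # drops H, layer_copy and the quantizer to release memory
```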
diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py
index f10ceacef..971777f2a 100644
--- a/gptqmodel/quantization/quantizer.py
+++ b/gptqmodel/quantization/quantizer.py
@@ -6,6 +6,7 @@
from ..utils.logger import setup_logger
+
logger = setup_logger()
diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py
index be9bb9c43..76a6a63f8 100644
--- a/gptqmodel/utils/bitblas.py
+++ b/gptqmodel/utils/bitblas.py
@@ -1,4 +1,3 @@
-import gc
import os
import accelerate
@@ -6,12 +5,13 @@
import torch
from accelerate.utils import find_tied_parameters
-from .torch import torch_empty_cache
from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
from ..quantization import FORMAT, QuantizeConfig
from ..utils.logger import setup_logger
from .model import recurse_getattr, recurse_setattr
from .progress import ProgressBar
+from .torch import torch_empty_cache
+
logger = setup_logger()
diff --git a/gptqmodel/utils/device.py b/gptqmodel/utils/device.py
index dff9d5cac..4ba3e10c5 100644
--- a/gptqmodel/utils/device.py
+++ b/gptqmodel/utils/device.py
@@ -1,5 +1,6 @@
from device_smi import Device
+
from gptqmodel.models._const import CPU, CUDA_0
diff --git a/gptqmodel/utils/exllama.py b/gptqmodel/utils/exllama.py
index c7d717e80..68a65e49f 100644
--- a/gptqmodel/utils/exllama.py
+++ b/gptqmodel/utils/exllama.py
@@ -1,9 +1,8 @@
-import gc
import torch
-from .torch import torch_empty_cache
from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
+from .torch import torch_empty_cache
def exllama_set_max_input_length(model, max_input_length: int):
diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py
index 8c29bcfa2..093d1a14a 100644
--- a/gptqmodel/utils/importer.py
+++ b/gptqmodel/utils/importer.py
@@ -1,10 +1,6 @@
from collections import OrderedDict
-from typing import Dict, Optional, Type, Union, Tuple
+from typing import Dict, Optional, Tuple, Type, Union
-import torch
-
-from . import BACKEND
-from .torch import HAS_XPU, HAS_CUDA, HAS_MPS
from ..models._const import DEVICE, normalize_device
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
@@ -17,6 +13,9 @@
from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear
from ..quantization import FORMAT
from ..utils.logger import setup_logger
+from . import BACKEND
+from .torch import HAS_CUDA, HAS_MPS, HAS_XPU
+
message_logged = False
logger = setup_logger()
diff --git a/gptqmodel/utils/logger.py b/gptqmodel/utils/logger.py
index b62e21fcb..8a4fdbf2d 100644
--- a/gptqmodel/utils/logger.py
+++ b/gptqmodel/utils/logger.py
@@ -1,5 +1,6 @@
import logging
+
logger = None
def setup_logger():
global logger
diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py
index a2c1a16c0..b612458fa 100644
--- a/gptqmodel/utils/marlin.py
+++ b/gptqmodel/utils/marlin.py
@@ -1,15 +1,15 @@
-import gc
import accelerate
import torch
from accelerate.utils import find_tied_parameters
-from .torch import torch_empty_cache
from ..nn_modules.qlinear.marlin import MarlinQuantLinear, _get_perms, unpack_qzeros
from ..quantization import FORMAT, QuantizeConfig
from ..utils.logger import setup_logger
from .model import recurse_getattr, recurse_setattr
from .progress import ProgressBar
+from .torch import torch_empty_cache
+
logger = setup_logger()
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 16e86cdc8..59465e448 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -21,8 +21,7 @@
from transformers import AutoConfig, PretrainedConfig
from transformers.utils.hub import cached_file
-from .torch import torch_empty_cache
-from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS, DEVICE
+from ..models._const import CPU, DEVICE, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
@@ -33,6 +32,8 @@
from .importer import select_quant_linear
from .logger import setup_logger
from .progress import ProgressBar
+from .torch import torch_empty_cache
+from ..quantization.config import dynamic_get
logger = setup_logger()
@@ -190,6 +191,10 @@ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module,
d_sym = sym
# dynamic bits, group_size, sym for each layer/module
if dynamic is not None:
+ if dynamic_get(dynamic=dynamic, layer_name=name) == False: # noqa: E712
+                # skip creating a quant linear for this module
+ continue
+
for pattern, pattern_dict in dynamic.items():
if re.match(pattern, name):
d_bits = pattern_dict.get("bits", bits)
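With the early `dynamic_get(...) == False` check, a `-:` pattern in `QuantizeConfig.dynamic` now prevents the quant linear from being created at all, so the matched modules simply keep their original (unquantized) layers. A hedged end-user sketch (the layer pattern is illustrative):

```python
from gptqmodel import QuantizeConfig

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    dynamic={
        r"-:model\.layers\.31\..*": {},  # modules in the last block are not quantized
    },
)
```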
diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py
index 72f5bcd16..f2c8183bc 100644
--- a/gptqmodel/utils/perplexity.py
+++ b/gptqmodel/utils/perplexity.py
@@ -3,6 +3,7 @@
import numpy as np
import torch
from datasets import load_dataset
+
from gptqmodel.utils.progress import ProgressBar
diff --git a/gptqmodel/utils/sglang.py b/gptqmodel/utils/sglang.py
index b813bfbd0..2aba002a5 100644
--- a/gptqmodel/utils/sglang.py
+++ b/gptqmodel/utils/sglang.py
@@ -3,6 +3,7 @@
import torch
from transformers import AutoConfig
+
try:
import sglang as sgl
SGLANG_AVAILABLE = True
diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py
index 8194ee34f..38a4be37a 100644
--- a/gptqmodel/utils/torch.py
+++ b/gptqmodel/utils/torch.py
@@ -1,6 +1,8 @@
-import torch
import gc as py_gc
+import torch
+
+
HAS_CUDA = False
HAS_XPU = False
HAS_MPS = False
@@ -53,4 +55,4 @@ def torch_empty_cache(device: torch.device = None, gc: bool = True):
elif device.type == "xpu":
torch.xpu.empty_cache()
elif device.type == "mps":
- torch.mps.empty_cache()
\ No newline at end of file
+ torch.mps.empty_cache()
diff --git a/gptqmodel/utils/vllm.py b/gptqmodel/utils/vllm.py
index 9d0d47d75..d9ff25e68 100644
--- a/gptqmodel/utils/vllm.py
+++ b/gptqmodel/utils/vllm.py
@@ -2,6 +2,7 @@
import torch
+
try:
from vllm import LLM, SamplingParams
diff --git a/gptqmodel/version.py b/gptqmodel/version.py
index 19fec1e6f..81a610005 100644
--- a/gptqmodel/version.py
+++ b/gptqmodel/version.py
@@ -1 +1 @@
-__version__ = "1.4.4-dev"
+__version__ = "1.4.5-dev"
diff --git a/setup.py b/setup.py
index 0a8ab3edd..7fd617cbc 100644
--- a/setup.py
+++ b/setup.py
@@ -9,6 +9,7 @@
from setuptools import find_packages, setup
from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel
+
CUDA_RELEASE = os.environ.get("CUDA_RELEASE", None)
TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")
@@ -88,6 +89,7 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
import torch # noqa: E402
+
if TORCH_CUDA_ARCH_LIST is None:
got_cuda_v6 = any(torch.cuda.get_device_capability(i)[0] >= 6 for i in range(torch.cuda.device_count()))
got_cuda_between_v6_and_v8 = any(6 <= torch.cuda.get_device_capability(i)[0] < 8 for i in range(torch.cuda.device_count()))
@@ -174,40 +176,47 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
),
- cpp_ext.CUDAExtension(
- "gptqmodel_marlin_kernels",
- [
- "gptqmodel_ext/marlin/marlin_cuda.cpp",
- "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
- "gptqmodel_ext/marlin/marlin_repack.cu",
- ],
- extra_link_args=extra_link_args,
- extra_compile_args=extra_compile_args,
- ),
- cpp_ext.CUDAExtension(
- "gptqmodel_exllama_kernels",
- [
- "gptqmodel_ext/exllama/exllama_ext.cpp",
- "gptqmodel_ext/exllama/cuda_buffers.cu",
- "gptqmodel_ext/exllama/cuda_func/column_remap.cu",
- "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
- "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
- ],
- extra_link_args=extra_link_args,
- extra_compile_args=extra_compile_args,
- ),
- cpp_ext.CUDAExtension(
- "gptqmodel_exllamav2_kernels",
- [
- "gptqmodel_ext/exllamav2/ext.cpp",
- "gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
- "gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
- ],
- extra_link_args=extra_link_args,
- extra_compile_args=extra_compile_args,
- )
]
+ if sys.platform != "win32":
+ extensions += [
+ # TODO: VC++: fatal error C1061: compiler limit : blocks nested too deeply
+ cpp_ext.CUDAExtension(
+ "gptqmodel_marlin_kernels",
+ [
+ "gptqmodel_ext/marlin/marlin_cuda.cpp",
+ "gptqmodel_ext/marlin/marlin_cuda_kernel.cu",
+ "gptqmodel_ext/marlin/marlin_repack.cu",
+ ],
+ extra_link_args=extra_link_args,
+ extra_compile_args=extra_compile_args,
+ ),
+ # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
+ cpp_ext.CUDAExtension(
+ "gptqmodel_exllama_kernels",
+ [
+ "gptqmodel_ext/exllama/exllama_ext.cpp",
+ "gptqmodel_ext/exllama/cuda_buffers.cu",
+ "gptqmodel_ext/exllama/cuda_func/column_remap.cu",
+ "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu",
+ "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu",
+ ],
+ extra_link_args=extra_link_args,
+ extra_compile_args=extra_compile_args,
+ ),
+ # TODO: VC++: error lnk2001 unresolved external symbol cublasHgemm
+ cpp_ext.CUDAExtension(
+ "gptqmodel_exllamav2_kernels",
+ [
+ "gptqmodel_ext/exllamav2/ext.cpp",
+ "gptqmodel_ext/exllamav2/cuda/q_matrix.cu",
+ "gptqmodel_ext/exllamav2/cuda/q_gemm.cu",
+ ],
+ extra_link_args=extra_link_args,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+
additional_setup_kwargs = {"ext_modules": extensions, "cmdclass": {"build_ext": cpp_ext.BuildExtension}}
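Taken on its own, the platform gate introduced above has this shape (a sketch with placeholder extension names and source paths; only the `sys.platform != "win32"` split and the `BuildExtension` wiring mirror the patch):

```python
import sys
from torch.utils import cpp_extension as cpp_ext

# Extensions that are expected to build everywhere.
extensions = [
    cpp_ext.CUDAExtension("example_base_kernels", ["ext/base.cpp", "ext/base_kernel.cu"]),
]

# Kernels that currently fail to compile or link under MSVC are appended only off Windows.
if sys.platform != "win32":
    extensions += [
        cpp_ext.CUDAExtension("example_marlin_kernels", ["ext/marlin.cpp", "ext/marlin_kernel.cu"]),
    ]

setup_kwargs = {"ext_modules": extensions, "cmdclass": {"build_ext": cpp_ext.BuildExtension}}
```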
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 3a9ad0b01..7af9a7bcb 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -2,6 +2,7 @@
import os
import sys
+
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -13,15 +14,17 @@
import torch.cuda # noqa: E402
from datasets import load_dataset # noqa: E402
-from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
+from lm_eval.utils import make_table # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
from gptqmodel.quantization import FORMAT # noqa: E402
from gptqmodel.quantization.config import QuantizeConfig # noqa: E402
from gptqmodel.utils.eval import lm_eval # noqa: E402
-from lm_eval.utils import make_table # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
from ovis_calibration_dataset import get_calib_dataset
+from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
+
RAND_SEED = 898
diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py
index e0c8a8ad2..ff782ac21 100644
--- a/tests/models/test_opt.py
+++ b/tests/models/test_opt.py
@@ -1,6 +1,7 @@
+from model_test import ModelTest
+
from gptqmodel import BACKEND
from gptqmodel.utils.importer import backend_dict
-from model_test import ModelTest
class TestOpt(ModelTest):
diff --git a/tests/test_asym_gptq_v1.py b/tests/test_asym_gptq_v1.py
index f2592ec64..a42a1df9d 100644
--- a/tests/test_asym_gptq_v1.py
+++ b/tests/test_asym_gptq_v1.py
@@ -1,11 +1,13 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-from gptqmodel.quantization import FORMAT # noqa: E402
# -- end do not touch
from models.model_test import ModelTest # noqa: E402
+from gptqmodel.quantization import FORMAT # noqa: E402
+
class Test(ModelTest):
NATIVE_MODEL_ID = "ModelCloud/Llama3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct"
diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py
index 323f9cf61..795d7c52c 100644
--- a/tests/test_dynamic.py
+++ b/tests/test_dynamic.py
@@ -1,20 +1,22 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import tempfile # noqa: E402
import unittest # noqa: E402
from datasets import load_dataset # noqa: E402
+from parameterized import parameterized # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402
from gptqmodel.quantization import QuantizeConfig # noqa: E402
from gptqmodel.utils import Perplexity # noqa: E402
-from parameterized import parameterized # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
class TestDynamic(unittest.TestCase):
@@ -112,3 +114,17 @@ def test_skip_module(self):
for name, submodule in model.named_modules():
if name == 'model.model.layers.0.self_attn.q_proj' and isinstance(submodule, BaseQuantLinear): # module 0 was skipped
raise ValueError("first layer should be native module")
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save(tmp_dir)
+ del model
+
+ q_model = GPTQModel.load(tmp_dir)
+ generate_str = self.tokenizer.decode(
+ q_model.generate(
+                    **self.tokenizer("The capital of France is", return_tensors="pt").to(q_model.device),
+ max_new_tokens=2)[0])
+
+ print(f"generate_str: {generate_str}")
+
+ self.assertIn("paris", generate_str.lower())
diff --git a/tests/test_estimate_vram.py b/tests/test_estimate_vram.py
index 2a1fe5387..edbee17ff 100644
--- a/tests/test_estimate_vram.py
+++ b/tests/test_estimate_vram.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
diff --git a/tests/test_eval.py b/tests/test_eval.py
index 62781ce51..2e53b6590 100644
--- a/tests/test_eval.py
+++ b/tests/test_eval.py
@@ -3,9 +3,11 @@
import unittest
from typing import Union
+from parameterized import parameterized
+
from gptqmodel import GPTQModel
from gptqmodel.utils import EVAL
-from parameterized import parameterized
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
diff --git a/tests/test_evalplus.py b/tests/test_evalplus.py
index 4ceda12eb..949e042b8 100644
--- a/tests/test_evalplus.py
+++ b/tests/test_evalplus.py
@@ -1,10 +1,13 @@
+# -- do not touch
import os
-import tempfile
-import unittest
-
-from gptqmodel.utils.eval import evalplus
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import tempfile # noqa: E402
+import unittest # noqa: E402
+
+from gptqmodel.utils.eval import evalplus # noqa: E402
class TestEvalplus(unittest.TestCase):
diff --git a/tests/test_group_size.py b/tests/test_group_size.py
index 311d54960..91713107b 100644
--- a/tests/test_group_size.py
+++ b/tests/test_group_size.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import logging # noqa: E402
@@ -8,6 +9,9 @@
import traceback # noqa: E402
import unittest # noqa: E402
+from lm_eval.utils import make_table # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.dynamic_cuda import DynamicCudaQuantLinear # noqa: E402
@@ -18,8 +22,7 @@
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402
from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402
from gptqmodel.utils.eval import lm_eval # noqa: E402
-from lm_eval.utils import make_table # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
+
logger = logging.getLogger(__name__)
diff --git a/tests/test_ipex_xpu.py b/tests/test_ipex_xpu.py
new file mode 100644
index 000000000..b509977f7
--- /dev/null
+++ b/tests/test_ipex_xpu.py
@@ -0,0 +1,41 @@
+# -- do not touch
+import os
+
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# -- end do not touch
+
+import tempfile # noqa: E402
+
+from models.model_test import ModelTest # noqa: E402
+
+from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
+from gptqmodel.models._const import DEVICE # noqa: E402
+
+
+class TestsIPEX(ModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct"
+
+ def test(self):
+ origin_model = GPTQModel.load(
+ self.NATIVE_MODEL_ID,
+ quantize_config=QuantizeConfig(),
+ backend=BACKEND.IPEX,
+ device=DEVICE.XPU,
+ )
+ tokenizer = self.load_tokenizer(self.NATIVE_MODEL_ID)
+ calibration_dataset = self.load_dataset(tokenizer)
+ origin_model.quantize(calibration_dataset)
+ with tempfile.TemporaryDirectory() as tmpdir:
+ origin_model.save(tmpdir)
+
+ model = GPTQModel.load(
+ tmpdir,
+ backend=BACKEND.IPEX,
+ device=DEVICE.XPU,
+ )
+        generate_str = tokenizer.decode(model.generate(**tokenizer("The capital of France is", return_tensors="pt").to(model.device), max_new_tokens=2)[0])
+
+ print(f"generate_str: {generate_str}")
+
+ self.assertIn("paris", generate_str.lower())
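Since this new test assumes an XPU-equipped runner (and a local model path), a guard along these lines could skip it cleanly on other machines; this is a hedged suggestion, not part of the patch:

```python
import unittest
import torch

# Assumption: torch exposes the XPU backend (torch >= 2.4 or an IPEX-enabled build).
XPU_AVAILABLE = hasattr(torch, "xpu") and torch.xpu.is_available()

@unittest.skipUnless(XPU_AVAILABLE, "Intel XPU device not available")
class TestXPUVisible(unittest.TestCase):
    def test_device_count(self):
        # Only runs when an XPU device is present.
        self.assertGreater(torch.xpu.device_count(), 0)

if __name__ == "__main__":
    unittest.main()
```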
diff --git a/tests/test_lm_eval.py b/tests/test_lm_eval.py
index afd3f2b99..bd33d7c6a 100644
--- a/tests/test_lm_eval.py
+++ b/tests/test_lm_eval.py
@@ -1,12 +1,15 @@
# -- do not touch
import os
+
# -- end do not touch
import tempfile # noqa: E402
import unittest # noqa: E402
-from gptqmodel.utils.eval import lm_eval # noqa: E402
from lm_eval.utils import make_table
+from gptqmodel.utils.eval import lm_eval # noqa: E402
+
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py
index 151ec33b8..8189aa48b 100644
--- a/tests/test_lm_head.py
+++ b/tests/test_lm_head.py
@@ -1,12 +1,14 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
+from models.model_test import ModelTest # noqa: E402
+
from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
-from models.model_test import ModelTest # noqa: E402
class TestLmHead(ModelTest):
diff --git a/tests/test_packing.py b/tests/test_packing.py
index a7e37d6c8..416f1d894 100644
--- a/tests/test_packing.py
+++ b/tests/test_packing.py
@@ -1,11 +1,13 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
+
# isort: off
import torch # noqa: E402
import torch.nn as nn # noqa: E402
diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py
index 762b83662..ab2efdc21 100644
--- a/tests/test_perplexity.py
+++ b/tests/test_perplexity.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -8,11 +9,12 @@
import unittest # noqa: E402
from datasets import load_dataset # noqa: E402
+from parameterized import parameterized # noqa: E402
+from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
+
from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.quantization.config import FORMAT, QUANT_METHOD, AutoRoundQuantizeConfig, QuantizeConfig # noqa: E402
from gptqmodel.utils import Perplexity # noqa: E402
-from parameterized import parameterized # noqa: E402
-from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402
class TestPerplexity(unittest.TestCase):
diff --git a/tests/test_q4_bitblas.py b/tests/test_q4_bitblas.py
index 851222d85..6cc42ff4d 100644
--- a/tests/test_q4_bitblas.py
+++ b/tests/test_q4_bitblas.py
@@ -1,15 +1,17 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
import torch # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
class TestQ4BitBLAS(unittest.TestCase):
diff --git a/tests/test_q4_cuda.py b/tests/test_q4_cuda.py
index 29f66542c..8d2927937 100644
--- a/tests/test_q4_cuda.py
+++ b/tests/test_q4_cuda.py
@@ -1,16 +1,19 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
import torch # noqa: E402
-from gptqmodel import BACKEND, GPTQModel # noqa: E402
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel # noqa: E402
+
+
GENERATE_EVAL_SIZE = 100
diff --git a/tests/test_q4_exllama_v1.py b/tests/test_q4_exllama_v1.py
index 8fd3bc6bd..07f6f860e 100644
--- a/tests/test_q4_exllama_v1.py
+++ b/tests/test_q4_exllama_v1.py
@@ -1,20 +1,23 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import torch # noqa: E402
+from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402
+from models.model_test import ModelTest # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel, exllama_set_max_input_length # noqa: E402
from gptqmodel.models._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH # noqa: E402
from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402
from gptqmodel.quantization import FORMAT # noqa: E402
from gptqmodel.utils.importer import select_quant_linear # noqa: E402
from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402
-from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402
-from models.model_test import ModelTest # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
+
REFERENCE = torch.Tensor(
[
diff --git a/tests/test_q4_exllama_v2.py b/tests/test_q4_exllama_v2.py
index c684b6a25..ee136e257 100644
--- a/tests/test_q4_exllama_v2.py
+++ b/tests/test_q4_exllama_v2.py
@@ -1,19 +1,22 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
import torch # noqa: E402
+from test_q4_exllama_v1 import REFERENCE, get_diff # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402
from gptqmodel.quantization import FORMAT # noqa: E402
from gptqmodel.utils.importer import select_quant_linear # noqa: E402
from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402
-from test_q4_exllama_v1 import REFERENCE, get_diff # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
+
GENERATE_EVAL_SIZE = 100
diff --git a/tests/test_q4_ipex.py b/tests/test_q4_ipex.py
index 1d963ac0f..f0518abb4 100644
--- a/tests/test_q4_ipex.py
+++ b/tests/test_q4_ipex.py
@@ -2,13 +2,16 @@
import os
import sys
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import torch # noqa: E402
-from gptqmodel import BACKEND # noqa: E402
from models.model_test import ModelTest # noqa: E402
+from gptqmodel import BACKEND # noqa: E402
+
+
GENERATE_EVAL_SIZE = 100
diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py
index 52734df23..aa810e366 100644
--- a/tests/test_q4_marlin.py
+++ b/tests/test_q4_marlin.py
@@ -1,16 +1,18 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
import torch # noqa: E402
-from gptqmodel import BACKEND, GPTQModel # noqa: E402
-from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel # noqa: E402
+from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402
+
class TestQ4Marlin(unittest.TestCase):
diff --git a/tests/test_q4_torch.py b/tests/test_q4_torch.py
index cb95861d2..d55964771 100644
--- a/tests/test_q4_torch.py
+++ b/tests/test_q4_torch.py
@@ -1,16 +1,19 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import sys # noqa: E402
import unittest # noqa: E402
import torch # noqa: E402
-from gptqmodel import BACKEND, GPTQModel # noqa: E402
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel # noqa: E402
+
+
GENERATE_EVAL_SIZE = 100
diff --git a/tests/test_q4_triton.py b/tests/test_q4_triton.py
index 08b0f28fd..8c79121fd 100644
--- a/tests/test_q4_triton.py
+++ b/tests/test_q4_triton.py
@@ -1,15 +1,18 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import unittest # noqa: E402
import torch # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
+
GENERATE_EVAL_SIZE = 100
diff --git a/tests/test_quant_batch.py b/tests/test_quant_batch.py
index 8e44308b8..1cd4e44f3 100644
--- a/tests/test_quant_batch.py
+++ b/tests/test_quant_batch.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -8,10 +9,11 @@
import unittest # noqa: E402
from datasets import load_dataset # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.quantization import QuantizeConfig # noqa: E402
from gptqmodel.utils import Perplexity # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
class TestQuantBatch(unittest.TestCase):
diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py
index 24fba00a7..c8f0d87ec 100644
--- a/tests/test_quant_formats.py
+++ b/tests/test_quant_formats.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -10,14 +11,19 @@
import unittest # noqa: E402
from datasets import load_dataset # noqa: E402
-from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
-from gptqmodel import BACKEND, GPTQModel, __version__, get_best_device # noqa: E402
-from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QUANT_METHOD # noqa: E402
-from gptqmodel.quantization.config import (META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, # noqa: E402
- AutoRoundQuantizeConfig, QuantizeConfig)
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, __version__, get_best_device # noqa: E402
+from gptqmodel.quantization import FORMAT, QUANT_CONFIG_FILENAME, QUANT_METHOD # noqa: E402
+from gptqmodel.quantization.config import ( # noqa: E402
+ META_FIELD_QUANTIZER,
+ META_QUANTIZER_GPTQMODEL,
+ AutoRoundQuantizeConfig,
+ QuantizeConfig,
+)
+from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
+
class TestQuantization(unittest.TestCase):
diff --git a/tests/test_quant_trust_remote.py b/tests/test_quant_trust_remote.py
index aae8da9ab..81127d131 100644
--- a/tests/test_quant_trust_remote.py
+++ b/tests/test_quant_trust_remote.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -8,9 +9,10 @@
import unittest # noqa: E402
from datasets import load_dataset # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import GPTQModel # noqa: E402
from gptqmodel.quantization import FORMAT, QuantizeConfig # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
class TestQuantWithTrustRemoteTrue(unittest.TestCase):
diff --git a/tests/test_save_loaded_quantized_model.py b/tests/test_save_loaded_quantized_model.py
index 28730fe70..8a160e025 100644
--- a/tests/test_save_loaded_quantized_model.py
+++ b/tests/test_save_loaded_quantized_model.py
@@ -1,15 +1,18 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import tempfile # noqa: E402
import unittest # noqa: E402
-from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402
from parameterized import parameterized # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402
+
+
MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
class TestSave(unittest.TestCase):
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 5aa1d1be7..ffaedca97 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
diff --git a/tests/test_sglang.py b/tests/test_sglang.py
index 59fd8320b..374ab8a57 100644
--- a/tests/test_sglang.py
+++ b/tests/test_sglang.py
@@ -1,5 +1,6 @@
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -9,6 +10,7 @@
import unittest # noqa: E402
import torch # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel # noqa: E402
diff --git a/tests/test_sharded.py b/tests/test_sharded.py
index e8275d8c6..1d2338c13 100644
--- a/tests/test_sharded.py
+++ b/tests/test_sharded.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -9,9 +10,10 @@
import unittest # noqa: E402
import torch # noqa: E402
-from gptqmodel import GPTQModel # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import GPTQModel # noqa: E402
+
class TestSharded(unittest.TestCase):
MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
diff --git a/tests/test_tgi.py b/tests/test_tgi.py
index c3ebf8045..d26c51ecc 100644
--- a/tests/test_tgi.py
+++ b/tests/test_tgi.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
import json # noqa: E402
diff --git a/tests/test_transformers_integration.py b/tests/test_transformers_integration.py
index 6c988e816..2977fb09d 100644
--- a/tests/test_transformers_integration.py
+++ b/tests/test_transformers_integration.py
@@ -1,9 +1,10 @@
import tempfile
import unittest
-from gptqmodel.integration import integration
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+from gptqmodel.integration import integration
+
class TestTransformersIntegration(unittest.TestCase):
diff --git a/tests/test_triton.py b/tests/test_triton.py
index c9a5bf878..da71ea565 100644
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
@@ -9,9 +10,11 @@
import torch # noqa: E402
import torch.utils.benchmark as benchmark # noqa: E402
-from gptqmodel import BACKEND, GPTQModel # noqa: E402
from transformers import AutoTokenizer # noqa: E402
+from gptqmodel import BACKEND, GPTQModel # noqa: E402
+
+
MODEL_ID = "/monster/data/model/Llama-7B-GPTQ"
DATASET_ID = "timdettmers/openassistant-guanaco"
LEARNING_RATE = 3e-5
diff --git a/tests/test_verify_hash.py b/tests/test_verify_hash.py
index f49591b27..7c0b246da 100644
--- a/tests/test_verify_hash.py
+++ b/tests/test_verify_hash.py
@@ -1,6 +1,7 @@
# -- do not touch
import os
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
diff --git a/tests/test_vllm.py b/tests/test_vllm.py
index 449315ec5..ef4e1b1ae 100644
--- a/tests/test_vllm.py
+++ b/tests/test_vllm.py
@@ -1,22 +1,23 @@
# -- do not touch
import os
-import tempfile
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# -- end do not touch
-import gc # noqa: E402
import importlib.util # noqa: E402
import subprocess # noqa: E402
import sys # noqa: E402
+import tempfile # noqa: E402
import unittest # noqa: E402
-import torch
-from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
-from transformers import AutoTokenizer # noqa: E402
-from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
+import torch # noqa: E402
from datasets import load_dataset # noqa: E402
+from transformers import AutoTokenizer # noqa: E402
+
from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402
+from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402
+from gptqmodel.utils.torch import torch_empty_cache # noqa: E402
class TestLoadVLLM(unittest.TestCase):