diff --git a/tests/test_vllm.py b/tests/test_vllm.py
index 420525307..0e46ee5bb 100644
--- a/tests/test_vllm.py
+++ b/tests/test_vllm.py
@@ -19,6 +19,7 @@ def setUpClass(self):
         subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm>=0.5.1"])
         from vllm import SamplingParams  # noqa: E402
         self.MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
+        self.SHARDED_MODEL_ID = "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-sharded"
         self.prompts = [
             "The capital of France is",
         ]
@@ -29,6 +30,7 @@ def test_load_vllm(self):
             self.MODEL_ID,
             device="cuda:0",
             backend=BACKEND.VLLM,
+            gpu_memory_utilization=0.2,
         )
         outputs = model.generate(
             prompts=self.prompts,
@@ -50,3 +52,21 @@ def test_load_vllm(self):
             print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
             self.assertEquals(generated_text, " Paris. 2. Name the capital of the United States. 3.")
 
+    def test_load_sharded_vllm(self):
+        model = GPTQModel.from_quantized(
+            self.SHARDED_MODEL_ID,
+            device="cuda:0",
+            backend=BACKEND.VLLM,
+            gpu_memory_utilization=0.2,
+        )
+        outputs = model.generate(
+            prompts=self.prompts,
+            temperature=0.8,
+            top_p=0.95,
+        )
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            self.assertEquals(generated_text,
+                              " Paris.\n2. Who has a national flag with a white field surrounded by")
\ No newline at end of file