Commit

add vllm sharded test (ModelCloud#267)
* add vllm sharded test

* merge vllm sharded into test_vllm
PZS-ModelCloud authored Jul 23, 2024
1 parent f3d400f commit 5f5eae6
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/test_vllm.py
@@ -19,6 +19,7 @@ def setUpClass(self):
subprocess.check_call([sys.executable, "-m", "pip", "install", "vllm>=0.5.1"])
from vllm import SamplingParams # noqa: E402
self.MODEL_ID = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
self.SHARDED_MODEL_ID = "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-sharded"
self.prompts = [
"The capital of France is",
]
@@ -29,6 +30,7 @@ def test_load_vllm(self):
self.MODEL_ID,
device="cuda:0",
backend=BACKEND.VLLM,
gpu_memory_utilization=0.2,
)
outputs = model.generate(
prompts=self.prompts,
@@ -50,3 +52,21 @@ def test_load_vllm(self):
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text, " Paris. 2. Name the capital of the United States. 3.")

def test_load_sharded_vllm(self):
model = GPTQModel.from_quantized(
self.SHARDED_MODEL_ID,
device="cuda:0",
backend=BACKEND.VLLM,
gpu_memory_utilization=0.2,
)
outputs = model.generate(
prompts=self.prompts,
temperature=0.8,
top_p=0.95,
)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris.\n2. Who has a national flag with a white field surrounded by")

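The new test mirrors test_load_vllm but points at the sharded checkpoint. As a rough standalone sketch of the same flow outside the unittest harness, assuming GPTQModel and BACKEND are importable from the top-level gptqmodel package (the diff only shows the symbols in use), loading and sampling from the sharded model would look roughly like this:

from gptqmodel import BACKEND, GPTQModel  # import path is an assumption; the test only shows these symbols

# Sharded GPTQ checkpoint exercised by the new test.
MODEL_ID = "ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-sharded"

# Load through the vLLM backend; gpu_memory_utilization=0.2 caps how much GPU
# memory vLLM reserves, so the test can coexist with others on a single device.
model = GPTQModel.from_quantized(
    MODEL_ID,
    device="cuda:0",
    backend=BACKEND.VLLM,
    gpu_memory_utilization=0.2,
)

# Same sampling settings as the test; each vLLM output carries the prompt and its completions.
outputs = model.generate(
    prompts=["The capital of France is"],
    temperature=0.8,
    top_p=0.95,
)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")

To run the vLLM tests locally, something like `python -m pytest tests/test_vllm.py` should work, assuming the repository's test suite is driven by pytest.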