Commit

Update device to cuda
Signed-off-by: Angel Luu <[email protected]>
aluu317 committed Sep 11, 2024
1 parent cc9f2a1 commit 6e9e7f2
Showing 1 changed file with 1 addition and 1 deletion.
scripts/run_inference.py (1 addition, 1 deletion)
@@ -198,7 +198,7 @@ def load(
         attn_implementation="flash_attention_2"
         if use_flash_attn
         else None,
-        device_map="auto",
+        device_map="cuda",
         torch_dtype=torch.float16
         if use_flash_attn
         else None,  # since we are using exllama kernel, we have to use float16
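The change swaps `device_map="auto"` (which lets accelerate decide placement, possibly sharding across GPUs and CPU) for `device_map="cuda"` (which puts the whole model on the default CUDA device). The sketch below shows the resulting keyword arguments as a standalone helper; this helper and the string placeholder for `torch.float16` are illustrative assumptions, not code from the repository, so the sketch runs without torch or a GPU installed.

```python
def build_load_kwargs(use_flash_attn: bool) -> dict:
    """Hypothetical helper mirroring the from_pretrained kwargs after this commit.

    "torch.float16" is represented as a string placeholder so this sketch
    does not require torch; the real call passes the torch dtype object.
    """
    return {
        # Flash Attention 2 is only requested when the flag is set.
        "attn_implementation": "flash_attention_2" if use_flash_attn else None,
        # Changed from "auto": load the entire model onto the CUDA device
        # rather than letting accelerate pick a device layout.
        "device_map": "cuda",
        # Per the in-code comment: the exllama kernel requires float16.
        "torch_dtype": "torch.float16" if use_flash_attn else None,
    }


print(build_load_kwargs(True)["device_map"])   # → cuda
print(build_load_kwargs(False)["torch_dtype"]) # → None
```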
