
Commit

Merge pull request #1566 from vespa-engine/thomasht86/fix-to-bfloat16
(colpalidemo) fix to bfloat16
thomasht86 authored Nov 6, 2024
2 parents 0f7da4e + 64f1dce commit ad2f0a6
Showing 1 changed file with 1 addition and 2 deletions.
visual-retrieval-colpali/src/backend/colpali.py (1 addition, 2 deletions)
@@ -21,7 +21,6 @@ class SimMapGenerator:
     Generates similarity maps based on query embeddings and image patches using the ColPali model.
     """
 
-    COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
     colormap = cm.get_cmap("viridis")  # Preload colormap for efficiency
 
     def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32):
@@ -47,7 +46,7 @@ def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
         """
         model = ColPali.from_pretrained(
            self.model_name,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype=torch.bfloat16,  # Note: the embeddings created during feed were float32 and then binarized, yet this setting seems to produce the most similar results both locally (MPS) and on HF (CUDA)
            device_map=self.device,
        ).eval()

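For readers outside the repository, here is a minimal standalone sketch of what the load path looks like after this change, based on the colpali-engine API already used in the file. The resolve_device helper, its fallback order, and the model name variable are illustrative assumptions, not part of the commit; only the unconditional torch_dtype=torch.bfloat16 and the from_pretrained calls mirror the diff above.

    # Sketch (not part of the commit): load ColPali in bfloat16 on any backend.
    import torch
    from colpali_engine.models import ColPali, ColPaliProcessor


    def resolve_device() -> str:
        # Hypothetical helper: prefer CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"


    model_name = "vidore/colpali-v1.2"
    device = resolve_device()

    # bfloat16 is now used regardless of device: the feed-time embeddings were
    # float32 and then binarized, but bfloat16 at query time gave the most
    # consistent similarity maps on both MPS and CUDA per the commit comment.
    model = ColPali.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
    ).eval()
    processor = ColPaliProcessor.from_pretrained(model_name)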

