Merge branch 'main' into main
bclavie authored Sep 23, 2024
2 parents dfb069c + fd60959 commit 0f6fb60
Showing 6 changed files with 243 additions and 66 deletions.
79 changes: 30 additions & 49 deletions byaldi/colpali.py
@@ -1,34 +1,22 @@
import os
import shutil
import tempfile

# Import version directly from the package metadata
from importlib.metadata import version
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

import srsly
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
from colpali_engine.models import ColPali, ColPaliProcessor
from pdf2image import convert_from_path
from PIL import Image
from transformers import AutoProcessor

from byaldi.objects import Result

from .utils import capture_print

# Import version directly from the package metadata
VERSION = version("Byaldi")


MOCK_IMAGE = Image.new("RGB", (448, 448), (255, 255, 255))


class ColPaliModel:
def __init__(
self,
@@ -41,6 +29,9 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
**kwargs,
):
if isinstance(pretrained_model_name_or_path, Path):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)

if "colpali" not in pretrained_model_name_or_path.lower():
raise ValueError(
"This pre-release version of Byaldi only supports ColPali for now. Incorrect model name specified."
@@ -73,35 +64,27 @@ def __init__(
self.doc_ids_to_file_names = {}
self.doc_ids = set()

# self.model = ColPali.from_pretrained(
# "vidore/colpaligemma-3b-pt-448-base",
# torch_dtype=torch.bfloat16,
# device_map="cuda"
# if device == "cuda"
# or (isinstance(device, torch.device) and device.type == "cuda")
# else None,
# token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
# )

# if verbose > 0:
# print("Loading adapter...")
# print("Adapter name: ", self.pretrained_model_name_or_path)
# self.model.load_adapter(self.pretrained_model_name_or_path)

self.model = ColPali.from_pretrained(
self.pretrained_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map="cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None,
device_map=(
"cuda"
if device == "cuda"
or (isinstance(device, torch.device) and device.type == "cuda")
else None
),
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
)
self.model = self.model.eval()
self.processor = AutoProcessor.from_pretrained(
self.pretrained_model_name_or_path,
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),

self.processor = cast(
ColPaliProcessor,
ColPaliProcessor.from_pretrained(
self.pretrained_model_name_or_path,
token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
),
)

self.device = device
if device != "cuda" and not (
isinstance(device, torch.device) and device.type == "cuda"
@@ -112,14 +95,18 @@ def __init__(
self.full_document_collection = False
self.highest_doc_id = -1
else:
index_path = Path(index_root) / Path(index_name)
if self.index_name is None:
raise ValueError("No index name specified. Cannot load from index.")

index_path = Path(index_root) / Path(self.index_name)
index_config = srsly.read_gzip_json(index_path / "index_config.json.gz")
self.full_document_collection = index_config.get(
"full_document_collection", False
)
self.resize_stored_images = index_config.get("resize_stored_images", False)
self.max_image_width = index_config.get("max_image_width", None)
self.max_image_height = index_config.get("max_image_height", None)

if self.full_document_collection:
collection_path = index_path / "collection"
json_files = sorted(
@@ -524,7 +511,7 @@ def _add_to_index(
f"Document ID {doc_id} with page ID {page_id} already exists in the index"
)

processed_image = process_images(self.processor, [image])
processed_image = self.processor.process_images([image])

# Generate embedding
with torch.no_grad():
@@ -583,12 +570,6 @@ def _add_to_index(
def remove_from_index(self):
raise NotImplementedError("This method is not implemented yet.")

@capture_print
def _score(self, qs: torch.Tensor):
retriever_evaluator = CustomEvaluator(is_multi_vector=True)
scores = retriever_evaluator.evaluate(qs, self.indexed_embeddings)
return scores

def search(
self,
query: Union[str, List[str]],
@@ -612,13 +593,13 @@ def search(
for q in queries:
# Process query
with torch.no_grad():
batch_query = process_queries(self.processor, [q], MOCK_IMAGE)
batch_query = self.processor.process_queries([q])
batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
embeddings_query = self.model(**batch_query)
qs = list(torch.unbind(embeddings_query.to("cpu")))

# Compute scores
scores = self._score(qs)
scores = self.processor.score(qs, self.indexed_embeddings).cpu().numpy()

# Get top k relevant pages
top_pages = scores.argsort(axis=1)[0][-k:][::-1].tolist()
@@ -690,7 +671,7 @@ def encode_image(
raise ValueError(f"Unsupported input type: {type(item)}")

with torch.no_grad():
batch = process_images(self.processor, images)
batch = self.processor.process_images(images)
batch = {k: v.to(self.device) for k, v in batch.items()}
embeddings = self.model(**batch)

@@ -711,7 +692,7 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor:
query = [query]

with torch.no_grad():
batch = process_queries(self.processor, query, MOCK_IMAGE)
batch = self.processor.process_queries(query)
batch = {k: v.to(self.device) for k, v in batch.items()}
embeddings = self.model(**batch)

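
For reference, a minimal sketch of the colpali-engine 0.3.x API this diff migrates to: image/query processing and scoring now live on ColPaliProcessor rather than the old colpali_engine.utils helpers and CustomEvaluator. The checkpoint name, query text, and CPU-only execution below are illustrative assumptions, not part of this commit.

import torch
from colpali_engine.models import ColPali, ColPaliProcessor
from PIL import Image

model_name = "vidore/colpali-v1.2"  # illustrative checkpoint
model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16).eval()
processor = ColPaliProcessor.from_pretrained(model_name)

# Embed one page image and one query, then score them with the processor's
# built-in multi-vector scoring (this replaces CustomEvaluator).
page = Image.new("RGB", (448, 448), (255, 255, 255))
with torch.no_grad():
    image_embeddings = model(**processor.process_images([page]))
    query_embeddings = model(**processor.process_queries(["What is attention?"]))

scores = processor.score(
    list(torch.unbind(query_embeddings)), list(torch.unbind(image_embeddings))
)
print(scores.shape)  # (num_queries, num_pages)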
15 changes: 0 additions & 15 deletions byaldi/utils.py

This file was deleted.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ maintainers = [
]

dependencies = [
"colpali-engine==0.2.2",
"colpali-engine>=0.3.0,<0.4.0",
"ml-dtypes",
"mteb==1.6.35",
"ninja",
@@ -47,6 +47,8 @@ server = ["uvicorn", "fastapi"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::Warning"]
markers = ["slow: marks test as slow"]
testpaths = ["tests"]

[tool.ruff]
# Exclude a variety of commonly ignored directories.
35 changes: 34 additions & 1 deletion tests/all.py
@@ -1,11 +1,26 @@
from pathlib import Path

from colpali_engine.utils.torch_utils import get_torch_device

from byaldi import RAGMultiModalModel

device = get_torch_device("auto")
print(f"Using device: {device}")

path_document_1 = Path("docs/attention.pdf")
path_document_2 = Path("docs/attention_copy.pdf")


def test_single_pdf():
print("Testing single PDF indexing and retrieval...")

# Initialize the model
model = RAGMultiModalModel.from_pretrained("vidore/colpali")
model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device)

if not Path("docs/attention.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}."
)

# Index a single PDF
model.index(
@@ -51,6 +66,15 @@ def test_multi_document():
# Initialize the model
model = RAGMultiModalModel.from_pretrained("vidore/colpali")

if not Path("docs/attention.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}."
)
if not Path("docs/attention_copy.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_2}."
)

# Index a directory of documents
model.index(
input_path="docs/",
@@ -132,6 +156,15 @@ def test_add_to_index():


if __name__ == "__main__":
print("Starting tests...")

print("/n/n----------------- Single PDF test -----------------n")
test_single_pdf()

print("/n/n----------------- Multi document test -----------------n")
test_multi_document()

print("/n/n----------------- Add to index test -----------------n")
test_add_to_index()

print("\nAll tests completed.")
23 changes: 23 additions & 0 deletions tests/test_colpali.py
@@ -0,0 +1,23 @@
from typing import Generator

import pytest
from colpali_engine.models import ColPali
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch

from byaldi import RAGMultiModalModel
from byaldi.colpali import ColPaliModel


@pytest.fixture(scope="module")
def colpali_rag_model() -> Generator[RAGMultiModalModel, None, None]:
device = get_torch_device("auto")
print(f"Using device: {device}")
yield RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device)
tear_down_torch()


@pytest.mark.slow
def test_load_colpali_from_pretrained(colpali_rag_model: RAGMultiModalModel):
assert isinstance(colpali_rag_model, RAGMultiModalModel)
assert isinstance(colpali_rag_model.model, ColPaliModel)
assert isinstance(colpali_rag_model.model.model, ColPali)