Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(colpalidemo) Thomasht86/update ranking #1578

Merged
merged 5 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 45 additions & 29 deletions visual-retrieval-colpali/prepare_feed_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
]

# Define the 'bm25' rank profile
colpali_bm25_profile = RankProfile(
bm25 = RankProfile(
name="bm25",
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
first_phase="bm25(title) + bm25(text)",
Expand All @@ -743,14 +743,27 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
)


colpali_schema.add_rank_profile(colpali_bm25_profile)
colpali_schema.add_rank_profile(with_quantized_similarity(colpali_bm25_profile))
colpali_schema.add_rank_profile(bm25)
colpali_schema.add_rank_profile(with_quantized_similarity(bm25))

# Update the 'default' rank profile
colpali_profile = RankProfile(
name="default",
inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
first_phase="bm25_score",

# Update the 'colpali' rank profile
# Number of indexed int8 query tensor inputs (rq0 .. rq63) declared below.
MAX_QUERY_TERMS = 64

# Rank-profile inputs: one int8 v[16] tensor per index rq0..rq{MAX_QUERY_TERMS-1},
# followed by the float `qt` and int8 `qtb` mapped (querytoken{}) query tensors.
input_query_tensors = [
    (f"query(rq{i})", "tensor<int8>(v[16])") for i in range(MAX_QUERY_TERMS)
] + [
    ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
    ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
]

colpali = RankProfile(
name="colpali",
inputs=input_query_tensors,
first_phase="max_sim_binary",
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
functions=mapfunctions
+ [
Expand All @@ -768,30 +781,33 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
)
""",
),
Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
Function(
name="max_sim_binary",
expression="""
sum(
reduce(
1 / (1 + sum(
hamming(query(qtb), attribute(embedding)), v)
),
max, patch
),
querytoken
)
""",
),
],
)
colpali_schema.add_rank_profile(colpali_profile)
colpali_schema.add_rank_profile(with_quantized_similarity(colpali_profile))

# Update the 'retrieval-and-rerank' rank profile
# Number of indexed int8 query tensor inputs (rq0 .. rq63) declared below.
MAX_QUERY_TERMS = 64

# Rank-profile inputs: one int8 v[16] tensor per index rq0..rq{MAX_QUERY_TERMS-1},
# followed by the float `qt` and int8 `qtb` mapped (querytoken{}) query tensors.
input_query_tensors = [
    (f"query(rq{i})", "tensor<int8>(v[16])") for i in range(MAX_QUERY_TERMS)
] + [
    ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
    ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
]
colpali_schema.add_rank_profile(colpali)
colpali_schema.add_rank_profile(with_quantized_similarity(colpali))

colpali_retrieval_profile = RankProfile(
name="retrieval-and-rerank",
# Update the 'hybrid' rank profile
hybrid = RankProfile(
name="hybrid",
inputs=input_query_tensors,
first_phase="max_sim_binary",
second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
second_phase=SecondPhaseRanking(
expression="max_sim + 2 * (bm25(text) + bm25(title))", rerank_count=10
),
functions=mapfunctions
+ [
Function(
Expand Down Expand Up @@ -824,8 +840,8 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
),
],
)
colpali_schema.add_rank_profile(colpali_retrieval_profile)
colpali_schema.add_rank_profile(with_quantized_similarity(colpali_retrieval_profile))
colpali_schema.add_rank_profile(hybrid)
colpali_schema.add_rank_profile(with_quantized_similarity(hybrid))

# +
from vespa.configuration.services import (
Expand Down
65 changes: 10 additions & 55 deletions visual-retrieval-colpali/src/backend/vespa_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,54 +104,6 @@ def format_query_results(
self.logger.debug(result_text)
return response.json

async def query_vespa_default(
self,
query: str,
q_emb: torch.Tensor,
hits: int = 3,
timeout: str = "10s",
sim_map: bool = False,
**kwargs,
) -> dict:
"""
Query Vespa using the default ranking profile.
This corresponds to the "Hybrid ColPali+BM25" radio button in the UI.

Args:
query (str): The query text.
q_emb (torch.Tensor): Query embeddings.
hits (int, optional): Number of hits to retrieve. Defaults to 3.
timeout (str, optional): Query timeout. Defaults to "10s".
sim_map (bool, optional): Forwarded to get_fields() and get_rank_profile()
when building the YQL field list and the ranking value. Defaults to False.
**kwargs: Extra entries merged into the request body. They are unpacked
last, so they override any key set by this method.

Returns:
dict: The formatted query results.

Raises:
AssertionError: If the Vespa response is not successful.
"""
# A single connection is opened for this one request.
async with self.app.asyncio(connections=1) as session:
query_embedding = self.format_q_embs(q_emb)

# Wall-clock timing around the request, as observed by this client.
start = time.perf_counter()
response: VespaQueryResponse = await session.query(
body={
"yql": (
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
),
"ranking": self.get_rank_profile("default", sim_map),
"query": query,
"timeout": timeout,
"hits": hits,
"input.query(qt)": query_embedding,
"presentation.timing": True,
**kwargs,
},
)
# NOTE(review): assert statements are stripped under `python -O`; consider
# raising an explicit exception if this check must always run.
assert response.is_successful(), response.json
stop = time.perf_counter()
self.logger.debug(
f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
f"{response.json.get('timing', {}).get('searchtime', -1)} s"
)
return self.format_query_results(query, response)

async def query_vespa_bm25(
self,
query: str,
Expand Down Expand Up @@ -286,12 +238,14 @@ async def get_result_from_query(

rank_method = ranking.split("_")[0]
sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
if rank_method == "nn+colpali":
result = await self.query_vespa_nearest_neighbor(
query, q_embs, sim_map=sim_map
if rank_method == "colpali": # ColPali
result = await self.query_vespa_colpali(
query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
)
elif rank_method == "hybrid": # Hybrid ColPali+BM25
result = await self.query_vespa_colpali(
query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
)
elif rank_method == "bm25+colpali":
result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
elif rank_method == "bm25":
result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
else:
Expand Down Expand Up @@ -419,9 +373,10 @@ def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
else:
return ranking

async def query_vespa_nearest_neighbor(
async def query_vespa_colpali(
self,
query: str,
ranking: str,
q_emb: torch.Tensor,
target_hits_per_query_tensor: int = 100,
hnsw_explore_additional_hits: int = 300,
Expand Down Expand Up @@ -467,7 +422,7 @@ async def query_vespa_nearest_neighbor(
f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
),
"ranking.profile": self.get_rank_profile(
"retrieval-and-rerank", sim_map
ranking=ranking, sim_map=sim_map
),
"timeout": timeout,
"hits": hits,
Expand Down
8 changes: 4 additions & 4 deletions visual-retrieval-colpali/src/frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def ShareButtons():
)


def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
def SearchBox(with_border=False, query_value="", ranking_value="hybrid"):
grid_cls = "grid gap-2 items-center p-3 bg-muted w-full"

if with_border:
Expand Down Expand Up @@ -203,7 +203,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
Span("Ranking by:", cls="text-muted-foreground text-xs font-semibold"),
RadioGroup(
Div(
RadioGroupItem(value="nn+colpali", id="nn+colpali"),
RadioGroupItem(value="colpali", id="colpali"),
Label("ColPali", htmlFor="ColPali"),
cls="flex items-center space-x-2",
),
Expand All @@ -213,7 +213,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
cls="flex items-center space-x-2",
),
Div(
RadioGroupItem(value="bm25+colpali", id="bm25+colpali"),
RadioGroupItem(value="hybrid", id="hybrid"),
Label("Hybrid ColPali + BM25", htmlFor="Hybrid ColPali + BM25"),
cls="flex items-center space-x-2",
),
Expand Down Expand Up @@ -349,7 +349,7 @@ def AboutThisDemo():

def Search(request, search_results=[]):
query_value = request.query_params.get("query", "").strip()
ranking_value = request.query_params.get("ranking", "nn+colpali")
ranking_value = request.query_params.get("ranking", "hybrid")
return Div(
Div(
Div(
Expand Down
2 changes: 1 addition & 1 deletion visual-retrieval-colpali/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get():


@rt("/search")
def get(request, query: str = "", ranking: str = "nn+colpali"):
def get(request, query: str = "", ranking: str = "hybrid"):
logger.info(f"/search: Fetching results for query: {query}, ranking: {ranking}")

# Always render the SearchBox first
Expand Down