
Commit

Merge pull request #1578 from vespa-engine/thomasht86/update-ranking
(colpalidemo) Thomasht86/update ranking
thomasht86 authored Nov 15, 2024
2 parents 9ea12e8 + 84650b9 commit 404661a
Showing 4 changed files with 60 additions and 89 deletions.
74 changes: 45 additions & 29 deletions visual-retrieval-colpali/prepare_feed_deploy.py
@@ -725,7 +725,7 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict:
 ]
 
 # Define the 'bm25' rank profile
-colpali_bm25_profile = RankProfile(
+bm25 = RankProfile(
     name="bm25",
     inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
     first_phase="bm25(title) + bm25(text)",
@@ -743,14 +743,27 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
     )
 
 
-colpali_schema.add_rank_profile(colpali_bm25_profile)
-colpali_schema.add_rank_profile(with_quantized_similarity(colpali_bm25_profile))
+colpali_schema.add_rank_profile(bm25)
+colpali_schema.add_rank_profile(with_quantized_similarity(bm25))
 
-# Update the 'default' rank profile
-colpali_profile = RankProfile(
-    name="default",
-    inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
-    first_phase="bm25_score",
+
+# Update the 'colpali' rank profile
+input_query_tensors = []
+MAX_QUERY_TERMS = 64
+for i in range(MAX_QUERY_TERMS):
+    input_query_tensors.append((f"query(rq{i})", "tensor<int8>(v[16])"))
+
+input_query_tensors.extend(
+    [
+        ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
+        ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
+    ]
+)
+
+colpali = RankProfile(
+    name="colpali",
+    inputs=input_query_tensors,
+    first_phase="max_sim_binary",
     second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
     functions=mapfunctions
     + [
@@ -768,30 +781,33 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
                 )
             """,
         ),
-        Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
+        Function(
+            name="max_sim_binary",
+            expression="""
+                sum(
+                  reduce(
+                    1 / (1 + sum(
+                        hamming(query(qtb), attribute(embedding)), v)
+                    ),
+                    max, patch
+                  ),
+                  querytoken
+                )
+            """,
+        ),
     ],
 )
-colpali_schema.add_rank_profile(colpali_profile)
-colpali_schema.add_rank_profile(with_quantized_similarity(colpali_profile))
+colpali_schema.add_rank_profile(colpali)
+colpali_schema.add_rank_profile(with_quantized_similarity(colpali))
 
-# Update the 'retrieval-and-rerank' rank profile
-input_query_tensors = []
-MAX_QUERY_TERMS = 64
-for i in range(MAX_QUERY_TERMS):
-    input_query_tensors.append((f"query(rq{i})", "tensor<int8>(v[16])"))
-
-input_query_tensors.extend(
-    [
-        ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
-        ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
-    ]
-)
-
-colpali_retrieval_profile = RankProfile(
-    name="retrieval-and-rerank",
+# Update the 'hybrid' rank profile
+hybrid = RankProfile(
+    name="hybrid",
     inputs=input_query_tensors,
     first_phase="max_sim_binary",
-    second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
+    second_phase=SecondPhaseRanking(
+        expression="max_sim + 2 * (bm25(text) + bm25(title))", rerank_count=10
+    ),
     functions=mapfunctions
     + [
         Function(
@@ -824,8 +840,8 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile:
         ),
     ],
 )
-colpali_schema.add_rank_profile(colpali_retrieval_profile)
-colpali_schema.add_rank_profile(with_quantized_similarity(colpali_retrieval_profile))
+colpali_schema.add_rank_profile(hybrid)
+colpali_schema.add_rank_profile(with_quantized_similarity(hybrid))
 
 # +
 from vespa.configuration.services import (
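To make the two similarity functions above easier to follow, here is a minimal NumPy sketch of what `max_sim` and `max_sim_binary` compute per query. The 128-float / 16-byte-binary shapes come from the tensor definitions in the diff; the token and patch counts and the sign-based binarization are illustrative assumptions, not taken from this repository.

# Minimal sketch of the ColPali MaxSim scores used by the 'colpali' and 'hybrid' profiles.
import numpy as np

rng = np.random.default_rng(0)
q_float = rng.standard_normal((4, 128)).astype(np.float32)   # query(qt): 4 query tokens
p_float = rng.standard_normal((20, 128)).astype(np.float32)  # attribute(embedding): 20 patches

def pack_binary(x: np.ndarray) -> np.ndarray:
    # Binarize each dimension at 0 and pack 128 bits into 16 bytes per vector
    # (Vespa stores these bytes as int8; uint8 is fine for this sketch).
    return np.packbits((x > 0).astype(np.uint8), axis=-1)

q_bin, p_bin = pack_binary(q_float), pack_binary(p_float)

# max_sim: for each query token, take the best (max) patch dot product, then sum over tokens.
max_sim = np.sum(np.max(q_float @ p_float.T, axis=1))

# max_sim_binary: same structure, but the per-pair score is an inverted Hamming distance
# over the packed bytes, mirroring
# sum(reduce(1 / (1 + sum(hamming(query(qtb), attribute(embedding)), v)), max, patch), querytoken).
hamming = np.unpackbits(q_bin[:, None, :] ^ p_bin[None, :, :], axis=-1).sum(axis=-1)
max_sim_binary = np.sum(np.max(1.0 / (1.0 + hamming), axis=1))

print(max_sim, max_sim_binary)

The `hybrid` profile then adds `2 * (bm25(text) + bm25(title))` to the float `max_sim` score in its second phase, which is the only difference from the `colpali` profile's reranking expression.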
65 changes: 10 additions & 55 deletions visual-retrieval-colpali/src/backend/vespa_app.py
@@ -104,54 +104,6 @@ def format_query_results(
         self.logger.debug(result_text)
         return response.json
 
-    async def query_vespa_default(
-        self,
-        query: str,
-        q_emb: torch.Tensor,
-        hits: int = 3,
-        timeout: str = "10s",
-        sim_map: bool = False,
-        **kwargs,
-    ) -> dict:
-        """
-        Query Vespa using the default ranking profile.
-        This corresponds to the "Hybrid ColPali+BM25" radio button in the UI.
-
-        Args:
-            query (str): The query text.
-            q_emb (torch.Tensor): Query embeddings.
-            hits (int, optional): Number of hits to retrieve. Defaults to 3.
-            timeout (str, optional): Query timeout. Defaults to "10s".
-
-        Returns:
-            dict: The formatted query results.
-        """
-        async with self.app.asyncio(connections=1) as session:
-            query_embedding = self.format_q_embs(q_emb)
-
-            start = time.perf_counter()
-            response: VespaQueryResponse = await session.query(
-                body={
-                    "yql": (
-                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
-                    ),
-                    "ranking": self.get_rank_profile("default", sim_map),
-                    "query": query,
-                    "timeout": timeout,
-                    "hits": hits,
-                    "input.query(qt)": query_embedding,
-                    "presentation.timing": True,
-                    **kwargs,
-                },
-            )
-            assert response.is_successful(), response.json
-            stop = time.perf_counter()
-            self.logger.debug(
-                f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
-                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
-            )
-            return self.format_query_results(query, response)
-
     async def query_vespa_bm25(
         self,
         query: str,
@@ -286,12 +238,14 @@ async def get_result_from_query(
 
         rank_method = ranking.split("_")[0]
         sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
-        if rank_method == "nn+colpali":
-            result = await self.query_vespa_nearest_neighbor(
-                query, q_embs, sim_map=sim_map
+        if rank_method == "colpali":  # ColPali
+            result = await self.query_vespa_colpali(
+                query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
             )
-        elif rank_method == "bm25+colpali":
-            result = await self.query_vespa_default(query, q_embs, sim_map=sim_map)
+        elif rank_method == "hybrid":  # Hybrid ColPali+BM25
+            result = await self.query_vespa_colpali(
+                query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
+            )
         elif rank_method == "bm25":
             result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
         else:
@@ -419,9 +373,10 @@ def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
         else:
             return ranking
 
-    async def query_vespa_nearest_neighbor(
+    async def query_vespa_colpali(
         self,
         query: str,
+        ranking: str,
         q_emb: torch.Tensor,
         target_hits_per_query_tensor: int = 100,
         hnsw_explore_additional_hits: int = 300,
@@ -467,7 +422,7 @@ async def query_vespa_nearest_neighbor(
                         f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
                     ),
                     "ranking.profile": self.get_rank_profile(
-                        "retrieval-and-rerank", sim_map
+                        ranking=ranking, sim_map=sim_map
                     ),
                     "timeout": timeout,
                     "hits": hits,
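The `nn_string` interpolated into the YQL above is assembled elsewhere in this file and is not part of this diff. As orientation, a hedged sketch of the usual pattern, with placeholder variable names and tensor contents: one `nearestNeighbor` operator per binarized query-token tensor `rq0..rqN`, OR-ed together.

# Sketch only; variable names and tensor contents are assumptions.
target_hits_per_query_tensor = 100
binary_query_embeddings = {f"rq{i}": [0] * 16 for i in range(5)}  # placeholder int8 tensors

nn_clauses = [
    f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding, {name}))"
    for name in binary_query_embeddings
]
nn_string = " OR ".join(nn_clauses)
# -> "({targetHits:100}nearestNeighbor(embedding, rq0)) OR ({targetHits:100}nearestNeighbor(embedding, rq1)) OR ..."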
8 changes: 4 additions & 4 deletions visual-retrieval-colpali/src/frontend/app.py
Expand Up @@ -175,7 +175,7 @@ def ShareButtons():
)


def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
def SearchBox(with_border=False, query_value="", ranking_value="hybrid"):
grid_cls = "grid gap-2 items-center p-3 bg-muted w-full"

if with_border:
@@ -203,7 +203,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
             Span("Ranking by:", cls="text-muted-foreground text-xs font-semibold"),
             RadioGroup(
                 Div(
-                    RadioGroupItem(value="nn+colpali", id="nn+colpali"),
+                    RadioGroupItem(value="colpali", id="colpali"),
                     Label("ColPali", htmlFor="ColPali"),
                     cls="flex items-center space-x-2",
                 ),
@@ -213,7 +213,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
                     cls="flex items-center space-x-2",
                 ),
                 Div(
-                    RadioGroupItem(value="bm25+colpali", id="bm25+colpali"),
+                    RadioGroupItem(value="hybrid", id="hybrid"),
                     Label("Hybrid ColPali + BM25", htmlFor="Hybrid ColPali + BM25"),
                     cls="flex items-center space-x-2",
                 ),
@@ -349,7 +349,7 @@ def AboutThisDemo():
 
 def Search(request, search_results=[]):
     query_value = request.query_params.get("query", "").strip()
-    ranking_value = request.query_params.get("ranking", "nn+colpali")
+    ranking_value = request.query_params.get("ranking", "hybrid")
     return Div(
         Div(
             Div(
2 changes: 1 addition & 1 deletion visual-retrieval-colpali/src/main.py
Expand Up @@ -156,7 +156,7 @@ def get():


@rt("/search")
def get(request, query: str = "", ranking: str = "nn+colpali"):
def get(request, query: str = "", ranking: str = "hybrid"):
logger.info(f"/search: Fetching results for query: {query}, ranking: {ranking}")

# Always render the SearchBox first
Expand Down
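With `hybrid` now the default ranking, the route can be exercised locally once the app is running; the host, port, and query below are assumptions, not taken from this change.

# Quick manual check of the /search route with the renamed ranking values.
import requests

for ranking in ("hybrid", "colpali", "bm25"):
    resp = requests.get(
        "http://localhost:7860/search",  # assumed local address
        params={"query": "energy efficiency", "ranking": ranking},
    )
    print(ranking, resp.status_code)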
