diff --git a/visual-retrieval-colpali/prepare_feed_deploy.py b/visual-retrieval-colpali/prepare_feed_deploy.py index 6497b10b3..2213aa379 100644 --- a/visual-retrieval-colpali/prepare_feed_deploy.py +++ b/visual-retrieval-colpali/prepare_feed_deploy.py @@ -725,7 +725,7 @@ def float_to_binary_embedding(float_query_embedding: dict) -> dict: ] # Define the 'bm25' rank profile -colpali_bm25_profile = RankProfile( +bm25 = RankProfile( name="bm25", inputs=[("query(qt)", "tensor(querytoken{}, v[128])")], first_phase="bm25(title) + bm25(text)", @@ -743,14 +743,27 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile: ) -colpali_schema.add_rank_profile(colpali_bm25_profile) -colpali_schema.add_rank_profile(with_quantized_similarity(colpali_bm25_profile)) +colpali_schema.add_rank_profile(bm25) +colpali_schema.add_rank_profile(with_quantized_similarity(bm25)) -# Update the 'default' rank profile -colpali_profile = RankProfile( - name="default", - inputs=[("query(qt)", "tensor(querytoken{}, v[128])")], - first_phase="bm25_score", + +# Update the 'colpali' rank profile +input_query_tensors = [] +MAX_QUERY_TERMS = 64 +for i in range(MAX_QUERY_TERMS): + input_query_tensors.append((f"query(rq{i})", "tensor(v[16])")) + +input_query_tensors.extend( + [ + ("query(qt)", "tensor(querytoken{}, v[128])"), + ("query(qtb)", "tensor(querytoken{}, v[16])"), + ] +) + +colpali = RankProfile( + name="colpali", + inputs=input_query_tensors, + first_phase="max_sim_binary", second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10), functions=mapfunctions + [ @@ -768,30 +781,33 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile: ) """, ), - Function(name="bm25_score", expression="bm25(title) + bm25(text)"), + Function( + name="max_sim_binary", + expression=""" + sum( + reduce( + 1 / (1 + sum( + hamming(query(qtb), attribute(embedding)), v) + ), + max, patch + ), + querytoken + ) + """, + ), ], ) -colpali_schema.add_rank_profile(colpali_profile) -colpali_schema.add_rank_profile(with_quantized_similarity(colpali_profile)) - -# Update the 'retrieval-and-rerank' rank profile -input_query_tensors = [] -MAX_QUERY_TERMS = 64 -for i in range(MAX_QUERY_TERMS): - input_query_tensors.append((f"query(rq{i})", "tensor(v[16])")) - -input_query_tensors.extend( - [ - ("query(qt)", "tensor(querytoken{}, v[128])"), - ("query(qtb)", "tensor(querytoken{}, v[16])"), - ] -) +colpali_schema.add_rank_profile(colpali) +colpali_schema.add_rank_profile(with_quantized_similarity(colpali)) -colpali_retrieval_profile = RankProfile( - name="retrieval-and-rerank", +# Update the 'hybrid' rank profile +hybrid = RankProfile( + name="hybrid", inputs=input_query_tensors, first_phase="max_sim_binary", - second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10), + second_phase=SecondPhaseRanking( + expression="max_sim + 2 * (bm25(text) + bm25(title))", rerank_count=10 + ), functions=mapfunctions + [ Function( @@ -824,8 +840,8 @@ def with_quantized_similarity(rank_profile: RankProfile) -> RankProfile: ), ], ) -colpali_schema.add_rank_profile(colpali_retrieval_profile) -colpali_schema.add_rank_profile(with_quantized_similarity(colpali_retrieval_profile)) +colpali_schema.add_rank_profile(hybrid) +colpali_schema.add_rank_profile(with_quantized_similarity(hybrid)) # + from vespa.configuration.services import ( diff --git a/visual-retrieval-colpali/src/backend/vespa_app.py b/visual-retrieval-colpali/src/backend/vespa_app.py index a8197e89b..5b4509435 100644 --- a/visual-retrieval-colpali/src/backend/vespa_app.py +++ b/visual-retrieval-colpali/src/backend/vespa_app.py @@ -104,54 +104,6 @@ def format_query_results( self.logger.debug(result_text) return response.json - async def query_vespa_default( - self, - query: str, - q_emb: torch.Tensor, - hits: int = 3, - timeout: str = "10s", - sim_map: bool = False, - **kwargs, - ) -> dict: - """ - Query Vespa using the default ranking profile. - This corresponds to the "Hybrid ColPali+BM25" radio button in the UI. - - Args: - query (str): The query text. - q_emb (torch.Tensor): Query embeddings. - hits (int, optional): Number of hits to retrieve. Defaults to 3. - timeout (str, optional): Query timeout. Defaults to "10s". - - Returns: - dict: The formatted query results. - """ - async with self.app.asyncio(connections=1) as session: - query_embedding = self.format_q_embs(q_emb) - - start = time.perf_counter() - response: VespaQueryResponse = await session.query( - body={ - "yql": ( - f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();" - ), - "ranking": self.get_rank_profile("default", sim_map), - "query": query, - "timeout": timeout, - "hits": hits, - "input.query(qt)": query_embedding, - "presentation.timing": True, - **kwargs, - }, - ) - assert response.is_successful(), response.json - stop = time.perf_counter() - self.logger.debug( - f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was " - f"{response.json.get('timing', {}).get('searchtime', -1)} s" - ) - return self.format_query_results(query, response) - async def query_vespa_bm25( self, query: str, @@ -286,12 +238,14 @@ async def get_result_from_query( rank_method = ranking.split("_")[0] sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim" - if rank_method == "nn+colpali": - result = await self.query_vespa_nearest_neighbor( - query, q_embs, sim_map=sim_map + if rank_method == "colpali": # ColPali + result = await self.query_vespa_colpali( + query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map + ) + elif rank_method == "hybrid": # Hybrid ColPali+BM25 + result = await self.query_vespa_colpali( + query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map ) - elif rank_method == "bm25+colpali": - result = await self.query_vespa_default(query, q_embs, sim_map=sim_map) elif rank_method == "bm25": result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map) else: @@ -419,9 +373,10 @@ def get_rank_profile(self, ranking: str, sim_map: bool) -> str: else: return ranking - async def query_vespa_nearest_neighbor( + async def query_vespa_colpali( self, query: str, + ranking: str, q_emb: torch.Tensor, target_hits_per_query_tensor: int = 100, hnsw_explore_additional_hits: int = 300, @@ -467,7 +422,7 @@ async def query_vespa_nearest_neighbor( f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()" ), "ranking.profile": self.get_rank_profile( - "retrieval-and-rerank", sim_map + ranking=ranking, sim_map=sim_map ), "timeout": timeout, "hits": hits, diff --git a/visual-retrieval-colpali/src/frontend/app.py b/visual-retrieval-colpali/src/frontend/app.py index 06e01ebf2..f550e839f 100644 --- a/visual-retrieval-colpali/src/frontend/app.py +++ b/visual-retrieval-colpali/src/frontend/app.py @@ -175,7 +175,7 @@ def ShareButtons(): ) -def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"): +def SearchBox(with_border=False, query_value="", ranking_value="hybrid"): grid_cls = "grid gap-2 items-center p-3 bg-muted w-full" if with_border: @@ -203,7 +203,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"): Span("Ranking by:", cls="text-muted-foreground text-xs font-semibold"), RadioGroup( Div( - RadioGroupItem(value="nn+colpali", id="nn+colpali"), + RadioGroupItem(value="colpali", id="colpali"), Label("ColPali", htmlFor="ColPali"), cls="flex items-center space-x-2", ), @@ -213,7 +213,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"): cls="flex items-center space-x-2", ), Div( - RadioGroupItem(value="bm25+colpali", id="bm25+colpali"), + RadioGroupItem(value="hybrid", id="hybrid"), Label("Hybrid ColPali + BM25", htmlFor="Hybrid ColPali + BM25"), cls="flex items-center space-x-2", ), @@ -349,7 +349,7 @@ def AboutThisDemo(): def Search(request, search_results=[]): query_value = request.query_params.get("query", "").strip() - ranking_value = request.query_params.get("ranking", "nn+colpali") + ranking_value = request.query_params.get("ranking", "hybrid") return Div( Div( Div( diff --git a/visual-retrieval-colpali/src/main.py b/visual-retrieval-colpali/src/main.py index cc09a2dfe..7117c8a57 100644 --- a/visual-retrieval-colpali/src/main.py +++ b/visual-retrieval-colpali/src/main.py @@ -156,7 +156,7 @@ def get(): @rt("/search") -def get(request, query: str = "", ranking: str = "nn+colpali"): +def get(request, query: str = "", ranking: str = "hybrid"): logger.info(f"/search: Fetching results for query: {query}, ranking: {ranking}") # Always render the SearchBox first