From 86a98b19883ea80b12eeb76e1de4a1e98e515cf4 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Mon, 18 Mar 2024 11:28:42 +0000 Subject: [PATCH] resolve conflicts --- examples/similarity_search_for_text.py | 1 - src/neo4j_genai/client.py | 33 +++++++++++++------------- src/neo4j_genai/types.py | 7 ++---- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 5d278cfba..40c094922 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -26,7 +26,6 @@ def embed_query(self, text: str) -> List[float]: # Initialize the client client = GenAIClient(driver, embedder) -client.drop_index(INDEX_NAME) # Creating the index client.create_index( INDEX_NAME, diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/client.py index d7643ecc6..085bc38a2 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/client.py @@ -1,8 +1,8 @@ -from typing import List, Optional +from typing import List, Optional, Any from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder -from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord +from .types import CreateIndexModel, SimilaritySearchModel class GenAIClient: @@ -114,7 +114,8 @@ def similarity_search( query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, top_k: int = 5, - ) -> List[Neo4jRecord]: + custom_retrieval_query: Optional[str] = None, + ) -> Any: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. See the following documentation for more details: @@ -126,13 +127,15 @@ def similarity_search( query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. + custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None Raises: ValueError: If validation of the input arguments fail. ValueError: If no embedder is provided. Returns: - List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores. + Any: The `top_k` neighbors found in vector search with their nodes and scores. + If custom_retrieval_query is provided, this is changed. """ try: validated_data = SimilaritySearchModel( @@ -140,6 +143,7 @@ def similarity_search( top_k=top_k, query_vector=query_vector, query_text=query_text, + custom_retrieval_query=custom_retrieval_query, ) except ValidationError as e: error_details = e.errors() @@ -154,19 +158,16 @@ def similarity_search( parameters["query_vector"] = query_vector del parameters["query_text"] - db_query_string = """ + query_prefix = """ CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector) YIELD node, score """ - records, _, _ = self.driver.execute_query(db_query_string, parameters) - try: - return [ - Neo4jRecord(node=record["node"], score=record["score"]) - for record in records - ] - except ValidationError as e: - error_details = e.errors() - raise ValueError( - f"Validation failed while constructing output: {error_details}" - ) + if parameters.get("custom_retrieval_query") is not None: + search_query = query_prefix + parameters["custom_retrieval_query"] + del parameters["custom_retrieval_query"] + else: + search_query = query_prefix + + records, _, _ = self.driver.execute_query(search_query, parameters) + return records diff --git a/src/neo4j_genai/types.py b/src/neo4j_genai/types.py index 4fa90594b..6649f782f 100644 --- a/src/neo4j_genai/types.py +++ b/src/neo4j_genai/types.py @@ -1,11 +1,7 @@ -from typing import List, Any, Literal, Optional +from typing import List, Literal, Optional from pydantic import BaseModel, PositiveInt, model_validator -Neo4jRecord = dict[str, Any] -"""Type alias for data items returned from Neo4j queries""" - - class EmbeddingVector(BaseModel): vector: List[float] @@ -23,6 +19,7 @@ class SimilaritySearchModel(BaseModel): top_k: PositiveInt = 5 query_vector: Optional[List[float]] = None query_text: Optional[str] = None + custom_retrieval_query: Optional[str] = None @model_validator(mode="before") def check_query(cls, values):