Skip to content

Commit

Permalink
resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Mar 18, 2024
1 parent 8615b53 commit 86a98b1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
1 change: 0 additions & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def embed_query(self, text: str) -> List[float]:
# Initialize the client
client = GenAIClient(driver, embedder)

client.drop_index(INDEX_NAME)
# Creating the index
client.create_index(
INDEX_NAME,
Expand Down
33 changes: 17 additions & 16 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Optional
from typing import List, Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from .types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import CreateIndexModel, SimilaritySearchModel


class GenAIClient:
Expand Down Expand Up @@ -114,7 +114,8 @@ def similarity_search(
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> List[Neo4jRecord]:
custom_retrieval_query: Optional[str] = None,
) -> Any:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -126,20 +127,23 @@ def similarity_search(
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None
Raises:
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
Any: The `top_k` neighbors found in vector search with their nodes and scores.
If custom_retrieval_query is provided, this is changed.
"""
try:
validated_data = SimilaritySearchModel(
index_name=name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
custom_retrieval_query=custom_retrieval_query,
)
except ValidationError as e:
error_details = e.errors()
Expand All @@ -154,19 +158,16 @@ def similarity_search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

db_query_string = """
query_prefix = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records, _, _ = self.driver.execute_query(db_query_string, parameters)

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
error_details = e.errors()
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)
if parameters.get("custom_retrieval_query") is not None:
search_query = query_prefix + parameters["custom_retrieval_query"]
del parameters["custom_retrieval_query"]
else:
search_query = query_prefix

records, _, _ = self.driver.execute_query(search_query, parameters)
return records
7 changes: 2 additions & 5 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import List, Any, Literal, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator


Neo4jRecord = dict[str, Any]
"""Type alias for data items returned from Neo4j queries"""


class EmbeddingVector(BaseModel):
vector: List[float]

Expand All @@ -23,6 +19,7 @@ class SimilaritySearchModel(BaseModel):
top_k: PositiveInt = 5
query_vector: Optional[List[float]] = None
query_text: Optional[str] = None
custom_retrieval_query: Optional[str] = None

@model_validator(mode="before")
def check_query(cls, values):
Expand Down

0 comments on commit 86a98b1

Please sign in to comment.