Skip to content

Commit

Permalink
Add custom retrieval query option in similarity search
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Mar 26, 2024
1 parent e2d9546 commit 259c2d6
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 32 deletions.
1 change: 0 additions & 1 deletion examples/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def embed_query(self, text: str) -> List[float]:
# Initialize the client
client = GenAIClient(driver, embedder)

client.drop_index(INDEX_NAME)
# Creating the index
client.create_index(
INDEX_NAME,
Expand Down
35 changes: 18 additions & 17 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Optional
from typing import List, Optional, Any
from pydantic import ValidationError
from neo4j import Driver
from .embedder import Embedder
from neo4j_genai.types import CreateIndexModel, SimilaritySearchModel, Neo4jRecord
from .types import CreateIndexModel, SimilaritySearchModel


class GenAIClient:
Expand Down Expand Up @@ -91,7 +91,7 @@ def create_index(
"toInteger($dimensions),"
"$similarity_fn )"
)
self.database_query(query, params=index_data.dict())
self.driver.execute_query(query, index_data.model_dump())

def drop_index(self, name: str) -> None:
"""
Expand All @@ -114,7 +114,8 @@ def similarity_search(
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> List[Neo4jRecord]:
custom_retrieval_query: Optional[str] = None,
) -> Any:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -126,20 +127,23 @@ def similarity_search(
query_vector (Optional[List[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
custom_retrieval_query (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None
Raises:
ValueError: If validation of the input arguments fail.
ValueError: If no embedder is provided.
Returns:
List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
Any: The `top_k` neighbors found in vector search with their nodes and scores.
If custom_retrieval_query is provided, this is changed.
"""
try:
validated_data = SimilaritySearchModel(
index_name=name,
top_k=top_k,
query_vector=query_vector,
query_text=query_text,
custom_retrieval_query=custom_retrieval_query,
)
except ValidationError as e:
error_details = e.errors()
Expand All @@ -154,19 +158,16 @@ def similarity_search(
parameters["query_vector"] = query_vector
del parameters["query_text"]

db_query_string = """
query_prefix = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
records, _, _ = self.driver.execute_query(db_query_string, parameters)

try:
return [
Neo4jRecord(node=record["node"], score=record["score"])
for record in records
]
except ValidationError as e:
error_details = e.errors()
raise ValueError(
f"Validation failed while constructing output: {error_details}"
)
if parameters.get("custom_retrieval_query") is not None:
search_query = query_prefix + parameters["custom_retrieval_query"]
del parameters["custom_retrieval_query"]
else:
search_query = query_prefix

records, _, _ = self.driver.execute_query(search_query, parameters)
return records
8 changes: 2 additions & 6 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import List, Any, Literal, Optional
from typing import List, Literal, Optional
from pydantic import BaseModel, PositiveInt, model_validator


class Neo4jRecord(BaseModel):
node: Any
score: float


class EmbeddingVector(BaseModel):
vector: List[float]

Expand All @@ -24,6 +19,7 @@ class SimilaritySearchModel(BaseModel):
top_k: PositiveInt = 5
query_vector: Optional[List[float]] = None
query_text: Optional[str] = None
custom_retrieval_query: Optional[str] = None

@model_validator(mode="before")
def check_query(cls, values):
Expand Down
86 changes: 78 additions & 8 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from neo4j_genai import GenAIClient
from neo4j_genai.types import Neo4jRecord
from unittest.mock import patch, MagicMock
from neo4j.exceptions import CypherSyntaxError


def test_genai_client_supported_aura_version(driver):
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [{"node": "dummy-node", "score": 1.0}]


@patch("neo4j_genai.GenAIClient._verify_version")
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver):
},
)

assert records == [Neo4jRecord(node="dummy-node", score=1.0)]
assert records == [{"node": "dummy-node", "score": 1.0}]


def test_similarity_search_missing_embedder_for_text(client):
Expand Down Expand Up @@ -204,11 +204,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
query_vector = [1.0 for _ in range(dimensions)]
top_k = 5

client.driver.execute_query.return_value = [
[{"node": "dummy-node", "score": "adsa"}],
None,
None,
]
client.driver.execute_query.side_effect = ValueError
search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
Expand All @@ -229,3 +225,77 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver):
"query_vector": query_vector,
},
)


@patch("neo4j_genai.GenAIClient._verify_version")
def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

client = GenAIClient(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

driver.execute_query.return_value = [
[{"node_id": 123, "text": "dummy-text", "score": 1.0}],
None,
None,
]

search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
custom_retrieval_query = """
RETURN node.id as node_id, node.text as text, score
"""

records = client.similarity_search(
name=index_name,
query_text=query_text,
top_k=top_k,
custom_retrieval_query=custom_retrieval_query,
)

custom_embeddings.embed_query.assert_called_once_with(query_text)

driver.execute_query.assert_called_once_with(
search_query + custom_retrieval_query,
{
"index_name": index_name,
"top_k": top_k,
"query_vector": embed_query_vector,
},
)

assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GenAIClient._verify_version")
def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

client = GenAIClient(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

driver.execute_query.side_effect = CypherSyntaxError

custom_retrieval_query = """
this is not a cypher query
"""

with pytest.raises(CypherSyntaxError):
client.similarity_search(
name=index_name,
query_text=query_text,
top_k=top_k,
custom_retrieval_query=custom_retrieval_query,
)

0 comments on commit 259c2d6

Please sign in to comment.