Skip to content

Commit

Permalink
split custom and vector search methods
Browse files Browse the repository at this point in the history
  • Loading branch information
willtai committed Mar 28, 2024
1 parent dd6ae19 commit 779960f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 18 deletions.
15 changes: 9 additions & 6 deletions src/neo4j_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def search_similar_vectors(
query_vector: Optional[List[float]] = None,
query_text: Optional[str] = None,
top_k: int = 5,
) -> Any:
) -> List[Neo4jRecord]:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
See the following documentation for more details:
Expand All @@ -137,8 +137,7 @@ def search_similar_vectors(
ValueError: If no embedder is provided.
Returns:
Any: The `top_k` neighbors found in vector search with their nodes and scores.
If custom_retrieval_query is provided, this is changed.
List[Neo4jRecord]: The `top_k` neighbors found in vector search with their nodes and scores.
"""
try:
validated_data = SimilaritySearchModel(
Expand Down Expand Up @@ -196,15 +195,14 @@ def custom_search_similar_vectors(
query_text (Optional[str], optional): The text to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
custom_retrieval_query (Optional[str], optional): Custom query to use as suffix for retrieval query. Defaults to None.
custom_params (Optional[str], optional: Custom query to use as suffix for retrieval query. Defaults to None
custom_params (Optional[Dict[str, Any]], optional): Query parameters to provide for the custom query. Defaults to None.
Raises:
ValueError: If validation of the input arguments fails.
ValueError: If no embedder is provided.
Returns:
Any: The `top_k` neighbors found in vector search with their nodes and scores.
If custom_retrieval_query is provided, this is changed.
"""
try:
validated_data = CustomSimilaritySearchModel(
Expand All @@ -226,12 +224,17 @@ def custom_search_similar_vectors(
parameters["query_vector"] = self.embedder.embed_query(query_text)
del parameters["query_text"]

if custom_params:
for key, value in custom_params.items():
if key not in parameters:
parameters[key] = value
del parameters["custom_params"]

query_prefix = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
search_query = query_prefix + parameters["custom_retrieval_query"]
del parameters["custom_retrieval_query"]

records, _, _ = self.driver.execute_query(search_query, parameters)
return records
11 changes: 0 additions & 11 deletions src/neo4j_genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,3 @@ def check_only_either_vector_or_text(cls, values):
class CustomSimilaritySearchModel(SimilaritySearchModel):
custom_retrieval_query: str
custom_params: Optional[Dict[str, Any]] = None

@model_validator(mode="before")
def combine_custom_params(cls, values):
"""
Combine custom_params dict into the main model's fields.
"""
custom_params = values.pop("custom_params", None) or {}
for key, value in custom_params.items():
if key not in values:
values[key] = value
return values
54 changes: 53 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
YIELD node, score
"""
custom_retrieval_query = """
RETURN node.id as node_id, node.text as text, score
RETURN node.id AS node_id, node.text AS text, score
"""

records = client.custom_search_similar_vectors(
Expand All @@ -289,6 +289,58 @@ def test_custom_retrieval_query_happy_path(_verify_version_mock, driver):
assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


# Verifies that entries of `custom_params` are forwarded to the driver as
# top-level query parameters, next to index_name, top_k and query_vector.
@patch("neo4j_genai.GenAIClient._verify_version")
def test_custom_retrieval_query_with_params(_verify_version_mock, driver):
# Mock embedder: embed_query returns a fixed 1536-element vector of 1.0.
embed_query_vector = [1.0 for _ in range(1536)]
custom_embeddings = MagicMock()
custom_embeddings.embed_query.return_value = embed_query_vector

client = GenAIClient(driver, custom_embeddings)

index_name = "my-index"
query_text = "may thy knife chip and shatter"
top_k = 5

# Three-element return value; the first element holds the records that the
# test expects the search call to return.
driver.execute_query.return_value = [
[{"node_id": 123, "text": "dummy-text", "score": 1.0}],
None,
None,
]

# NOTE(review): the assertion below compares the full query string exactly,
# so the whitespace inside these triple-quoted literals must match the prefix
# the client builds -- confirm against the real file before re-indenting.
search_query = """
CALL db.index.vector.queryNodes($index_name, $top_k, $query_vector)
YIELD node, score
"""
# The custom suffix references $param, which is supplied through custom_params.
custom_retrieval_query = """
RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata
"""
custom_params = {
"param": "dummy-param",
}

records = client.custom_search_similar_vectors(
name=index_name,
query_text=query_text,
top_k=top_k,
custom_retrieval_query=custom_retrieval_query,
custom_params=custom_params,
)

# The text query must be embedded exactly once.
custom_embeddings.embed_query.assert_called_once_with(query_text)

# Expected parameters: "param" appears at the top level, and there is no
# "query_text", "custom_params" or "custom_retrieval_query" key.
driver.execute_query.assert_called_once_with(
search_query + custom_retrieval_query,
{
"index_name": index_name,
"top_k": top_k,
"query_vector": embed_query_vector,
"param": "dummy-param",
},
)

# The records from the driver are returned unchanged.
assert records == [{"node_id": 123, "text": "dummy-text", "score": 1.0}]


@patch("neo4j_genai.GenAIClient._verify_version")
def test_custom_retrieval_query_cypher_error(_verify_version_mock, driver):
embed_query_vector = [1.0 for _ in range(1536)]
Expand Down

0 comments on commit 779960f

Please sign in to comment.