From 32b967d1d1bf2f26e3330fa54abc3c224148fd82 Mon Sep 17 00:00:00 2001 From: Will Tai Date: Thu, 4 Apr 2024 12:00:14 +0100 Subject: [PATCH] Added VectorRetriever class and remove GenAIClient --- examples/similarity_search_for_text.py | 12 ++-- examples/similarity_search_for_vector.py | 13 ++-- src/neo4j_genai/__init__.py | 4 +- src/neo4j_genai/{client.py => retrievers.py} | 70 ++------------------ tests/conftest.py | 8 +-- tests/test_indexes.py | 4 +- tests/{test_client.py => test_retrievers.py} | 60 ++++++++--------- 7 files changed, 54 insertions(+), 117 deletions(-) rename src/neo4j_genai/{client.py => retrievers.py} (61%) rename tests/{test_client.py => test_retrievers.py} (71%) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 40c094922..5bbd0c9b7 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -1,9 +1,10 @@ from typing import List from neo4j import GraphDatabase -from neo4j_genai.client import GenAIClient +from neo4j_genai import VectorRetriever from random import random from neo4j_genai.embedder import Embedder +from neo4j_genai.indexes import create_vector_index URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") @@ -23,11 +24,12 @@ def embed_query(self, text: str) -> List[float]: embedder = CustomEmbedder() -# Initialize the client -client = GenAIClient(driver, embedder) +# Initialize the retriever +retriever = VectorRetriever(driver, embedder) # Creating the index -client.create_index( +create_vector_index( + driver, INDEX_NAME, label="Document", property="propertyKey", @@ -50,4 +52,4 @@ def embed_query(self, text: str) -> List[float]: # Perform the similarity search for a text query query_text = "hello world" -print(client.similarity_search(INDEX_NAME, query_text=query_text, top_k=5)) +print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 6c4c9c3aa..0a741427d 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -1,8 +1,10 @@ from neo4j import GraphDatabase -from neo4j_genai.client import GenAIClient +from neo4j_genai import VectorRetriever from random import random +from neo4j_genai.indexes import create_vector_index + URI = "neo4j://localhost:7687" AUTH = ("neo4j", "password") @@ -12,11 +14,12 @@ # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) -# Initialize the client -client = GenAIClient(driver) +# Initialize the retriever +retriever = VectorRetriever(driver) # Creating the index -client.create_index( +create_vector_index( + driver, INDEX_NAME, label="Document", property="propertyKey", @@ -40,4 +43,4 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -print(client.similarity_search(INDEX_NAME, query_vector=query_vector, top_k=5)) +print(retriever.search(INDEX_NAME, query_vector=query_vector, top_k=5)) diff --git a/src/neo4j_genai/__init__.py b/src/neo4j_genai/__init__.py index c7748c58c..de6038a8e 100644 --- a/src/neo4j_genai/__init__.py +++ b/src/neo4j_genai/__init__.py @@ -1,4 +1,4 @@ -from .client import GenAIClient +from .retrievers import VectorRetriever -__all__ = ["GenAIClient"] +__all__ = ["VectorRetriever"] diff --git a/src/neo4j_genai/client.py b/src/neo4j_genai/retrievers.py similarity index 61% rename from src/neo4j_genai/client.py rename to src/neo4j_genai/retrievers.py index 82580a4d7..7226b4d8c 100644 --- a/src/neo4j_genai/client.py +++ b/src/neo4j_genai/retrievers.py @@ -2,12 +2,12 @@ from pydantic import ValidationError from neo4j import Driver from .embedder import Embedder -from .types import VectorIndexModel, SimilaritySearchModel, Neo4jRecord +from .types import SimilaritySearchModel, Neo4jRecord -class GenAIClient: +class VectorRetriever: """ - Provides functionality to use Neo4j's GenAI features + Provides retrieval methods using vector search over embeddings """ def __init__( @@ -46,69 +46,7 @@ def _verify_version(self) -> None: "This package only supports Neo4j version 5.18.1 or greater" ) - def create_index( - self, - name: str, - label: str, - property: str, - dimensions: int, - similarity_fn: str, - ) -> None: - """ - This method constructs a Cypher query and executes it - to create a new vector index in Neo4j. - - See Cypher manual on [Create node index](https://neo4j.com/docs/operations-manual/5/reference/procedures/#procedure_db_index_vector_createNodeIndex) - - Args: - name (str): The unique name of the index. - label (str): The node label to be indexed. - property (str): The property key of a node which contains embedding values. - dimensions (int): Vector embedding dimension - similarity_fn (str): case-insensitive values for the vector similarity function: - ``euclidean`` or ``cosine``. - - Raises: - ValueError: If validation of the input arguments fail. - """ - try: - VectorIndexModel( - **{ - "name": name, - "label": label, - "property": property, - "dimensions": dimensions, - "similarity_fn": similarity_fn, - } - ) - except ValidationError as e: - raise ValueError(f"Error for inputs to create_index {str(e)}") - - query = ( - f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS " - "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" - ) - self.driver.execute_query( - query, - {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, - ) - - def drop_index(self, name: str) -> None: - """ - This method constructs a Cypher query and executes it - to drop a vector index in Neo4j. - See Cypher manual on [Drop vector indexes](https://neo4j.com/docs/cypher-manual/current/indexes-for-vector-search/#indexes-vector-drop) - - Args: - name (str): The name of the index to delete. - """ - query = "DROP INDEX $name" - parameters = { - "name": name, - } - self.driver.execute_query(query, parameters) - - def similarity_search( + def search( self, name: str, query_vector: Optional[List[float]] = None, diff --git a/tests/conftest.py b/tests/conftest.py index f481cdb43..bc181925c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from neo4j_genai import GenAIClient +from neo4j_genai import VectorRetriever from neo4j import Driver from unittest.mock import MagicMock, patch @@ -10,6 +10,6 @@ def driver(): @pytest.fixture -@patch("neo4j_genai.GenAIClient._verify_version") -def client(_verify_version_mock, driver): - return GenAIClient(driver) +@patch("neo4j_genai.VectorRetriever._verify_version") +def retriever(_verify_version_mock, driver): + return VectorRetriever(driver) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 6b0ae6223..20383a65c 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -21,7 +21,7 @@ def test_create_vector_index_happy_path(driver): ) -def test_create_vector_index_ensure_escaping(driver, client): +def test_create_vector_index_ensure_escaping(driver): create_query = ( "CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:People) ON n.name OPTIONS " "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" @@ -97,7 +97,7 @@ def test_create_fulltext_index_empty_node_properties(driver): assert "Error for inputs to create_fulltext_index" in str(excinfo) -def test_create_fulltext_index_ensure_escaping(driver, client): +def test_create_fulltext_index_ensure_escaping(driver): label = "node-label" text_node_properties = ["property-1", "property-2"] create_query = ( diff --git a/tests/test_client.py b/tests/test_retrievers.py similarity index 71% rename from tests/test_client.py rename to tests/test_retrievers.py index 91c433959..a0d8ada71 100644 --- a/tests/test_client.py +++ b/tests/test_retrievers.py @@ -1,51 +1,51 @@ import pytest -from neo4j_genai import GenAIClient -from neo4j_genai.types import Neo4jRecord from unittest.mock import patch, MagicMock +from neo4j_genai import VectorRetriever +from neo4j_genai.types import Neo4jRecord -def test_genai_client_supported_aura_version(driver): +def test_vector_retriever_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] - GenAIClient(driver=driver) + VectorRetriever(driver=driver) -def test_genai_client_no_supported_aura_version(driver): +def test_vector_retriever_no_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None] with pytest.raises(ValueError) as excinfo: - GenAIClient(driver=driver) + VectorRetriever(driver=driver) assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) -def test_genai_client_supported_version(driver): +def test_vector_retriever_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None] - GenAIClient(driver=driver) + VectorRetriever(driver=driver) -def test_genai_client_no_supported_version(driver): +def test_vector_retriever_no_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None] with pytest.raises(ValueError) as excinfo: - GenAIClient(driver=driver) + VectorRetriever(driver=driver) assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) -@patch("neo4j_genai.GenAIClient._verify_version") +@patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_happy_path(_verify_version_mock, driver): custom_embeddings = MagicMock() - client = GenAIClient(driver, custom_embeddings) + retriever = VectorRetriever(driver, custom_embeddings) index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - client.driver.execute_query.return_value = [ + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, None, @@ -55,13 +55,11 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = client.similarity_search( - name=index_name, query_vector=query_vector, top_k=top_k - ) + records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called() - client.driver.execute_query.assert_called_once_with( + retriever.driver.execute_query.assert_called_once_with( search_query, { "index_name": index_name, @@ -73,13 +71,13 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): assert records == [Neo4jRecord(node="dummy-node", score=1.0)] -@patch("neo4j_genai.GenAIClient._verify_version") +@patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_text_happy_path(_verify_version_mock, driver): embed_query_vector = [1.0 for _ in range(1536)] custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector - client = GenAIClient(driver, custom_embeddings) + retriever = VectorRetriever(driver, custom_embeddings) index_name = "my-index" query_text = "may thy knife chip and shatter" @@ -96,9 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = client.similarity_search( - name=index_name, query_text=query_text, top_k=top_k - ) + records = retriever.search(name=index_name, query_text=query_text, top_k=top_k) custom_embeddings.embed_query.assert_called_once_with(query_text) @@ -114,16 +110,16 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): assert records == [Neo4jRecord(node="dummy-node", score=1.0)] -def test_similarity_search_missing_embedder_for_text(client): +def test_similarity_search_missing_embedder_for_text(retriever): index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - client.similarity_search(name=index_name, query_text=query_text, top_k=top_k) + retriever.search(name=index_name, query_text=query_text, top_k=top_k) -def test_similarity_search_both_text_and_vector(client): +def test_similarity_search_both_text_and_vector(retriever): index_name = "my-index" query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] @@ -132,7 +128,7 @@ def test_similarity_search_both_text_and_vector(client): with pytest.raises( ValueError, match="You must provide exactly one of query_vector or query_text." ): - client.similarity_search( + retriever.search( name=index_name, query_text=query_text, query_vector=query_vector, @@ -140,18 +136,18 @@ def test_similarity_search_both_text_and_vector(client): ) -@patch("neo4j_genai.GenAIClient._verify_version") +@patch("neo4j_genai.VectorRetriever._verify_version") def test_similarity_search_vector_bad_results(_verify_version_mock, driver): custom_embeddings = MagicMock() - client = GenAIClient(driver, custom_embeddings) + retriever = VectorRetriever(driver, custom_embeddings) index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 - client.driver.execute_query.return_value = [ + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], None, None, @@ -162,13 +158,11 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): """ with pytest.raises(ValueError): - client.similarity_search( - name=index_name, query_vector=query_vector, top_k=top_k - ) + retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called() - client.driver.execute_query.assert_called_once_with( + retriever.driver.execute_query.assert_called_once_with( search_query, { "index_name": index_name,