diff --git a/examples/openai_search.py b/examples/openai_search.py index 83d160d03..f42f76962 100644 --- a/examples/openai_search.py +++ b/examples/openai_search.py @@ -20,7 +20,7 @@ embedder = OpenAIEmbeddings(model="text-embedding-3-large") # Initialize the retriever -retriever = VectorRetriever(driver, embedder) +retriever = VectorRetriever(driver, INDEX_NAME, embedder) # Creating the index create_vector_index( @@ -47,4 +47,4 @@ # Perform the similarity search for a text query query_text = "hello world" -print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5)) +print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_text.py b/examples/similarity_search_for_text.py index 5bbd0c9b7..3104bbc6e 100644 --- a/examples/similarity_search_for_text.py +++ b/examples/similarity_search_for_text.py @@ -24,9 +24,6 @@ def embed_query(self, text: str) -> List[float]: embedder = CustomEmbedder() -# Initialize the retriever -retriever = VectorRetriever(driver, embedder) - # Creating the index create_vector_index( driver, @@ -37,6 +34,9 @@ def embed_query(self, text: str) -> List[float]: similarity_fn="euclidean", ) +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME, embedder) + # Upsert the query vector = [random() for _ in range(DIMENSION)] insert_query = ( @@ -52,4 +52,4 @@ def embed_query(self, text: str) -> List[float]: # Perform the similarity search for a text query query_text = "hello world" -print(retriever.search(INDEX_NAME, query_text=query_text, top_k=5)) +print(retriever.search(query_text=query_text, top_k=5)) diff --git a/examples/similarity_search_for_vector.py b/examples/similarity_search_for_vector.py index 0a741427d..310456760 100644 --- a/examples/similarity_search_for_vector.py +++ b/examples/similarity_search_for_vector.py @@ -14,9 +14,6 @@ # Connect to Neo4j database driver = GraphDatabase.driver(URI, auth=AUTH) -# Initialize the retriever -retriever = VectorRetriever(driver) - # Creating the index create_vector_index( driver, @@ -27,6 +24,9 @@ similarity_fn="euclidean", ) +# Initialize the retriever +retriever = VectorRetriever(driver, INDEX_NAME) + # Upsert the vector vector = [random() for _ in range(DIMENSION)] insert_query = ( @@ -43,4 +43,4 @@ # Perform the similarity search for a vector query query_vector = [random() for _ in range(DIMENSION)] -print(retriever.search(INDEX_NAME, query_vector=query_vector, top_k=5)) +print(retriever.search(query_vector=query_vector, top_k=5)) diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index d1e62ddc0..29301d61b 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -46,7 +46,7 @@ def create_vector_index( raise ValueError(f"Error for inputs to create_vector_index {str(e)}") query = ( - f"CREATE VECTOR INDEX $name IF NOT EXISTS FOR (n:{label}) ON n.{property} OPTIONS " + f"CREATE VECTOR INDEX $name FOR (n:{label}) ON n.{property} OPTIONS " "{ indexConfig: { `vector.dimensions`: toInteger($dimensions), `vector.similarity_function`: $similarity_fn } }" ) driver.execute_query( diff --git a/src/neo4j_genai/retrievers.py b/src/neo4j_genai/retrievers.py index 7226b4d8c..ef1ab84db 100644 --- a/src/neo4j_genai/retrievers.py +++ b/src/neo4j_genai/retrievers.py @@ -13,10 +13,12 @@ class VectorRetriever: def __init__( self, driver: Driver, + index_name: str, embedder: Optional[Embedder] = None, ) -> None: self.driver = driver self._verify_version() + self.index_name = index_name self.embedder = embedder def _verify_version(self) -> None: @@ -48,7 +50,6 @@ def _verify_version(self) -> None: def search( self, - name: str, query_vector: Optional[List[float]] = None, query_text: Optional[str] = None, top_k: int = 5, @@ -74,7 +75,7 @@ def search( """ try: validated_data = SimilaritySearchModel( - index_name=name, + index_name=self.index_name, top_k=top_k, query_vector=query_vector, query_text=query_text, diff --git a/tests/conftest.py b/tests/conftest.py index bc181925c..d05db4bde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,4 +12,4 @@ def driver(): @pytest.fixture @patch("neo4j_genai.VectorRetriever._verify_version") def retriever(_verify_version_mock, driver): - return VectorRetriever(driver) + return VectorRetriever(driver, "my-index") diff --git a/tests/test_retrievers.py b/tests/test_retrievers.py index a0d8ada71..cb8014e7b 100644 --- a/tests/test_retrievers.py +++ b/tests/test_retrievers.py @@ -7,14 +7,14 @@ def test_vector_retriever_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.18-aura"]}], None, None] - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") def test_vector_retriever_no_supported_aura_version(driver): driver.execute_query.return_value = [[{"versions": ["5.3-aura"]}], None, None] with pytest.raises(ValueError) as excinfo: - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) @@ -22,14 +22,14 @@ def test_vector_retriever_no_supported_aura_version(driver): def test_vector_retriever_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["5.19.0"]}], None, None] - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") def test_vector_retriever_no_supported_version(driver): driver.execute_query.return_value = [[{"versions": ["4.3.5"]}], None, None] with pytest.raises(ValueError) as excinfo: - VectorRetriever(driver=driver) + VectorRetriever(driver=driver, index_name="my-index") assert "This package only supports Neo4j version 5.18.1 or greater" in str(excinfo) @@ -38,13 +38,13 @@ def test_vector_retriever_no_supported_version(driver): def test_similarity_search_vector_happy_path(_verify_version_mock, driver): custom_embeddings = MagicMock() - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, @@ -55,7 +55,7 @@ def test_similarity_search_vector_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) + records = retriever.search(query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called() @@ -77,12 +77,12 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): custom_embeddings = MagicMock() custom_embeddings.embed_query.return_value = embed_query_vector - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + driver.execute_query.return_value = [ [{"node": "dummy-node", "score": 1.0}], None, @@ -94,7 +94,7 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): YIELD node, score """ - records = retriever.search(name=index_name, query_text=query_text, top_k=top_k) + records = retriever.search(query_text=query_text, top_k=top_k) custom_embeddings.embed_query.assert_called_once_with(query_text) @@ -111,16 +111,14 @@ def test_similarity_search_text_happy_path(_verify_version_mock, driver): def test_similarity_search_missing_embedder_for_text(retriever): - index_name = "my-index" query_text = "may thy knife chip and shatter" top_k = 5 with pytest.raises(ValueError, match="Embedding method required for text query"): - retriever.search(name=index_name, query_text=query_text, top_k=top_k) + retriever.search(query_text=query_text, top_k=top_k) def test_similarity_search_both_text_and_vector(retriever): - index_name = "my-index" query_text = "may thy knife chip and shatter" query_vector = [1.1, 2.2, 3.3] top_k = 5 @@ -129,7 +127,6 @@ def test_similarity_search_both_text_and_vector(retriever): ValueError, match="You must provide exactly one of query_vector or query_text." ): retriever.search( - name=index_name, query_text=query_text, query_vector=query_vector, top_k=top_k, @@ -140,13 +137,13 @@ def test_similarity_search_both_text_and_vector(retriever): def test_similarity_search_vector_bad_results(_verify_version_mock, driver): custom_embeddings = MagicMock() - retriever = VectorRetriever(driver, custom_embeddings) - index_name = "my-index" dimensions = 1536 query_vector = [1.0 for _ in range(dimensions)] top_k = 5 + retriever = VectorRetriever(driver, index_name, custom_embeddings) + retriever.driver.execute_query.return_value = [ [{"node": "dummy-node", "score": "adsa"}], None, @@ -158,7 +155,7 @@ def test_similarity_search_vector_bad_results(_verify_version_mock, driver): """ with pytest.raises(ValueError): - retriever.search(name=index_name, query_vector=query_vector, top_k=top_k) + retriever.search(query_vector=query_vector, top_k=top_k) custom_embeddings.embed_query.assert_not_called()