From d0528f4a773dc625f6afebdaa39d73dc3593c7fa Mon Sep 17 00:00:00 2001 From: willtai Date: Mon, 21 Oct 2024 09:48:33 +0100 Subject: [PATCH] Add check to not use deprecated Cypher syntax when Neo4j version is >= 5.23.0 (#183) * Add check to not use deprecated Cypher syntax when Neo4j version is >= 5.23.0 * Update CHANGELOG * Add variable scope query in Hybrid Retriever based on neo4j version * Include E2E test to test for deprecation warning from deprecated Cypher subquery syntax * Resolve mypy errors * Add neo4j:latest to pr and scheduled E2E tests --- .github/workflows/pr-e2e-tests.yaml | 10 +- .github/workflows/scheduled-e2e-tests.yaml | 13 +- CHANGELOG.md | 1 + .../experimental/components/kg_writer.py | 69 ++++-- src/neo4j_graphrag/neo4j_queries.py | 72 ++++-- src/neo4j_graphrag/retrievers/base.py | 11 + .../retrievers/external/pinecone/types.py | 3 - .../retrievers/external/weaviate/weaviate.py | 2 +- src/neo4j_graphrag/retrievers/hybrid.py | 10 +- tests/e2e/docker-compose.yml | 2 +- tests/e2e/test_hybrid_e2e.py | 21 ++ tests/e2e/test_kg_writer_component_e2e.py | 40 +++ .../experimental/components/test_kg_writer.py | 229 +++++++++++++++++- .../experimental/pipeline/test_kg_builder.py | 57 ++++- tests/unit/retrievers/test_hybrid.py | 26 +- 15 files changed, 500 insertions(+), 66 deletions(-) diff --git a/.github/workflows/pr-e2e-tests.yaml b/.github/workflows/pr-e2e-tests.yaml index 79550dd6..16eecc2e 100644 --- a/.github/workflows/pr-e2e-tests.yaml +++ b/.github/workflows/pr-e2e-tests.yaml @@ -16,10 +16,8 @@ jobs: strategy: matrix: python-version: ['3.9', '3.12'] - neo4j-version: - - 5 - neo4j-edition: - - enterprise + neo4j-tag: + - 'latest' services: t2v-transformers: image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx @@ -37,7 +35,7 @@ jobs: - 8080:8080 - 50051:50051 neo4j: - image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }} + image: neo4j:${{ matrix.neo4j-tag }} env: NEO4J_AUTH: neo4j/password NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval' @@ -93,7 +91,7 @@ jobs: - name: Run tests shell: bash run: | - if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then + if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then poetry run pytest -m 'not enterprise_only' ./tests/e2e else poetry run pytest ./tests/e2e diff --git a/.github/workflows/scheduled-e2e-tests.yaml b/.github/workflows/scheduled-e2e-tests.yaml index 77e495bc..094be3cb 100644 --- a/.github/workflows/scheduled-e2e-tests.yaml +++ b/.github/workflows/scheduled-e2e-tests.yaml @@ -13,11 +13,10 @@ jobs: strategy: matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] - neo4j-version: - - 5 - neo4j-edition: - - community - - enterprise + neo4j-tag: + - '5-community' + - '5-enterprise' + - 'latest' services: t2v-transformers: image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx @@ -41,7 +40,7 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} neo4j: - image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }} + image: neo4j:${{ matrix.neo4j-tag }} env: NEO4J_AUTH: neo4j/password NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval' @@ -100,7 +99,7 @@ jobs: - name: Run tests shell: bash run: | - if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then + if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then poetry run pytest -m 'not enterprise_only' ./tests/e2e else poetry run pytest ./tests/e2e diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bc00e9a..1ba50441 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - Made `relations` and `potential_schema` optional in `SchemaBuilder`. +- Added a check to prevent the use of deprecated Cypher syntax for Neo4j versions 5.23.0 and above. ## 1.1.0 diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index 19ebcdd7..091ca14c 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -33,7 +33,12 @@ Neo4jRelationship, ) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel -from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY +from neo4j_graphrag.neo4j_queries import ( + UPSERT_NODE_QUERY, + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + UPSERT_RELATIONSHIP_QUERY, + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, +) logger = logging.getLogger(__name__) @@ -113,6 +118,7 @@ def __init__( self.neo4j_database = neo4j_database self.batch_size = batch_size self.max_concurrency = max_concurrency + self.is_version_5_23_or_above = self._check_if_version_5_23_or_above() def _db_setup(self) -> None: # create index on __Entity__.id @@ -147,7 +153,12 @@ def _upsert_nodes(self, nodes: list[Neo4jNode]) -> None: nodes (list[Neo4jNode]): The nodes batch to upsert into the database. """ parameters = {"rows": self._nodes_to_rows(nodes)} - self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) + if self.is_version_5_23_or_above: + self.driver.execute_query( + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters + ) + else: + self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) async def _async_upsert_nodes( self, @@ -161,7 +172,32 @@ async def _async_upsert_nodes( """ async with sem: parameters = {"rows": self._nodes_to_rows(nodes)} - await self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters) + await self.driver.execute_query( + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters + ) + + def _get_version(self) -> tuple[int, ...]: + records, _, _ = self.driver.execute_query( + "CALL dbms.components()", database_=self.neo4j_database + ) + version = records[0]["versions"][0] + # Drop everything after the '-' first + version_main, *_ = version.split("-") + # Convert each number between '.' into int + version_tuple = tuple(map(int, version_main.split("."))) + # If no patch version, consider it's 0 + if len(version_tuple) < 3: + version_tuple = (*version_tuple, 0) + return version_tuple + + def _check_if_version_5_23_or_above(self) -> bool: + """ + Check if the connected Neo4j database version supports the required features. + + Sets a flag if the connected Neo4j version is 5.23 or above. + """ + version_tuple = self._get_version() + return version_tuple >= (5, 23, 0) def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: """Upserts a single relationship into the Neo4j database. @@ -170,7 +206,12 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. """ parameters = {"rows": [rel.model_dump() for rel in rels]} - self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters) + if self.is_version_5_23_or_above: + self.driver.execute_query( + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters + ) + else: + self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters) async def _async_upsert_relationships( self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore @@ -182,9 +223,15 @@ async def _async_upsert_relationships( """ async with sem: parameters = {"rows": [rel.model_dump() for rel in rels]} - await self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY, parameters_=parameters - ) + if self.is_version_5_23_or_above: + await self.driver.execute_query( + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_=parameters, + ) + else: + await self.driver.execute_query( + UPSERT_RELATIONSHIP_QUERY, parameters_=parameters + ) @validate_call async def run(self, graph: Neo4jGraph) -> KGWriterModel: @@ -193,12 +240,6 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel: Args: graph (Neo4jGraph): The knowledge graph to upsert into the database. """ - # we disable the notification logger to get rid of the deprecation - # warning about Cypher subqueries. Once the queries are updated - # for Neo4j 5.23, we can remove this line and the 'finally' block - notification_logger = logging.getLogger("neo4j.notifications") - notification_level = notification_logger.level - notification_logger.setLevel(logging.ERROR) try: if inspect.iscoroutinefunction(self.driver.execute_query): await self._async_db_setup() @@ -233,5 +274,3 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel: except neo4j.exceptions.ClientError as e: logger.exception(e) return KGWriterModel(status="FAILURE", metadata={"error": str(e)}) - finally: - notification_logger.setLevel(notification_level) diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index fd819e5c..243f74e1 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -55,6 +55,20 @@ "RETURN elementId(n)" ) +UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = ( + "UNWIND $rows AS row " + "CREATE (n:__KGBuilder__ {id: row.id}) " + "SET n += row.properties " + "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " + "WITH node as n, row CALL (n, row) { " + "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " + "RETURN count(*) as nbEmb " + "} " + "RETURN elementId(n)" +) + UPSERT_RELATIONSHIP_QUERY = ( "UNWIND $rows as row " "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " @@ -69,6 +83,21 @@ "RETURN elementId(rel)" ) + +UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = ( + "UNWIND $rows as row " + "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " + "MATCH (end:__KGBuilder__ {id: row.end_node_id}) " + "WITH start, end, row " + "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " + "WITH rel, row CALL (rel, row) { " + "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " + "} " + "RETURN elementId(rel)" +) + UPSERT_VECTOR_ON_NODE_QUERY = ( "MATCH (n) " "WHERE elementId(n) = $id " @@ -86,19 +115,33 @@ ) -def _get_hybrid_query() -> str: - return ( - f"CALL {{ {VECTOR_INDEX_QUERY} " - f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score " - f"UNWIND nodes AS n " - f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score " - f"UNION " - f"{FULL_TEXT_SEARCH_QUERY} " - f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score " - f"UNWIND nodes AS n " - f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} " - f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" - ) +def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str: + if neo4j_version_is_5_23_or_above: + return ( + f"CALL () {{ {VECTOR_INDEX_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) + else: + return ( + f"CALL {{ {VECTOR_INDEX_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score " + f"UNION " + f"{FULL_TEXT_SEARCH_QUERY} " + f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score " + f"UNWIND nodes AS n " + f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} " + f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k" + ) def _get_filtered_vector_query( @@ -139,6 +182,7 @@ def get_search_query( embedding_node_property: Optional[str] = None, embedding_dimension: Optional[int] = None, filters: Optional[dict[str, Any]] = None, + neo4j_version_is_5_23_or_above: bool = False, ) -> tuple[str, dict[str, Any]]: """Build the search query, including pre-filtering if needed, and return clause. @@ -160,7 +204,7 @@ def get_search_query( if search_type == SearchType.HYBRID: if filters: raise Exception("Filters are not supported with Hybrid Search") - query = _get_hybrid_query() + query = _get_hybrid_query(neo4j_version_is_5_23_or_above) params: dict[str, Any] = {} elif search_type == SearchType.VECTOR: if filters: diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index 17078b54..e4d2a9e4 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -101,6 +101,14 @@ def _get_version(self) -> tuple[tuple[int, ...], bool]: version_tuple = (*version_tuple, 0) return version_tuple, "aura" in version + def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool: + """ + Check if the connected Neo4j database version supports the required features. + + Sets a flag if the connected Neo4j version is 5.23 or above. + """ + return version_tuple >= (5, 23, 0) + def _verify_version(self) -> None: """ Check if the connected Neo4j database version supports vector indexing. @@ -111,6 +119,9 @@ def _verify_version(self) -> None: not supported. """ version_tuple, is_aura = self._get_version() + self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above( + version_tuple + ) if is_aura: target_version = (5, 18, 0) diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/types.py b/src/neo4j_graphrag/retrievers/external/pinecone/types.py index 3c458296..397894fd 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/types.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/types.py @@ -17,10 +17,7 @@ from typing import Any, Callable, Optional, Union import neo4j - - from pinecone import Pinecone - from pydantic import ( BaseModel, ConfigDict, diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py index 85377d82..4a777a33 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Optional import neo4j import weaviate.classes as wvc diff --git a/src/neo4j_graphrag/retrievers/hybrid.py b/src/neo4j_graphrag/retrievers/hybrid.py index 5c14f277..54bec9ce 100644 --- a/src/neo4j_graphrag/retrievers/hybrid.py +++ b/src/neo4j_graphrag/retrievers/hybrid.py @@ -184,7 +184,11 @@ def get_search_results( query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector - search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties) + search_query, _ = get_search_query( + SearchType.HYBRID, + self.return_properties, + neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above, + ) logger.debug("HybridRetriever Cypher parameters: %s", parameters) logger.debug("HybridRetriever Cypher query: %s", search_query) @@ -336,7 +340,9 @@ def get_search_results( del parameters["query_params"] search_query, _ = get_search_query( - SearchType.HYBRID, retrieval_query=self.retrieval_query + SearchType.HYBRID, + retrieval_query=self.retrieval_query, + neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above, ) logger.debug("HybridCypherRetriever Cypher parameters: %s", parameters) diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 1ffdb449..bf4a683a 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -26,7 +26,7 @@ services: environment: ENABLE_CUDA: "0" neo4j: - image: neo4j:5-enterprise + image: neo4j:5.24-enterprise ports: - 7687:7687 - 7474:7474 diff --git a/tests/e2e/test_hybrid_e2e.py b/tests/e2e/test_hybrid_e2e.py index 4abefecc..36b32c92 100644 --- a/tests/e2e/test_hybrid_e2e.py +++ b/tests/e2e/test_hybrid_e2e.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import pytest from neo4j import Driver @@ -40,6 +41,26 @@ def test_hybrid_retriever_search_text( assert isinstance(result, RetrieverResultItem) +@pytest.mark.usefixtures("setup_neo4j_for_retrieval") +def test_hybrid_retriever_no_neo4j_deprecation_warning( + driver: Driver, random_embedder: Embedder, caplog: pytest.LogCaptureFixture +) -> None: + retriever = HybridRetriever( + driver, "vector-index-name", "fulltext-index-name", random_embedder + ) + + top_k = 5 + with caplog.at_level(logging.WARNING): + retriever.search(query_text="Find me a book about Fremen", top_k=top_k) + + for record in caplog.records: + if ( + "Neo.ClientNotification.Statement.FeatureDeprecationWarning" + in record.message + ): + assert False, f"Deprecation warning found in logs: {record.message}" + + @pytest.mark.usefixtures("setup_neo4j_for_retrieval") def test_hybrid_cypher_retriever_search_text( driver: Driver, random_embedder: Embedder diff --git a/tests/e2e/test_kg_writer_component_e2e.py b/tests/e2e/test_kg_writer_component_e2e.py index 6388dccc..2fc0ab90 100644 --- a/tests/e2e/test_kg_writer_component_e2e.py +++ b/tests/e2e/test_kg_writer_component_e2e.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import neo4j import pytest @@ -100,3 +101,42 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: for key, val in node_with_two_embeddings.embedding_properties.items(): assert key in node_c.keys() assert val == node_c.get(key) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_neo4j_for_kg_construction") +async def test_kg_writer_no_neo4j_deprecation_warning( + driver: neo4j.Driver, caplog: pytest.LogCaptureFixture +) -> None: + start_node = Neo4jNode( + id="1", + label="MyLabel", + properties={"chunk": 1}, + embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]}, + ) + end_node = Neo4jNode( + id="2", + label="MyLabel", + properties={}, + embedding_properties=None, + ) + relationship = Neo4jRelationship( + start_node_id="1", end_node_id="2", type="MY_RELATIONSHIP" + ) + graph = Neo4jGraph( + nodes=[start_node, end_node], + relationships=[relationship], + ) + + neo4j_writer = Neo4jWriter(driver=driver) + with caplog.at_level(logging.WARNING): + res = await neo4j_writer.run(graph=graph) + + for record in caplog.records: + if ( + "Neo.ClientNotification.Statement.FeatureDeprecationWarning" + in record.message + ): + assert False, f"Deprecation warning found in logs: {record.message}" + + assert res.status == "SUCCESS" diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index 94271d5a..ed1eb0e5 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -24,7 +24,12 @@ Neo4jNode, Neo4jRelationship, ) -from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY +from neo4j_graphrag.neo4j_queries import ( + UPSERT_NODE_QUERY, + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + UPSERT_RELATIONSHIP_QUERY, + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, +) def test_batched() -> None: @@ -41,11 +46,15 @@ def test_batched() -> None: ] +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", return_value=None, ) -def test_upsert_nodes(driver: MagicMock) -> None: +def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label", properties={"key": "value"}) neo4j_writer._upsert_nodes(nodes=[node]) @@ -65,11 +74,16 @@ def test_upsert_nodes(driver: MagicMock) -> None: ) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", return_value=None, ) def test_upsert_nodes_with_embedding( + _: Mock, driver: MagicMock, ) -> None: neo4j_writer = Neo4jWriter(driver=driver) @@ -97,11 +111,15 @@ def test_upsert_nodes_with_embedding( ) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", return_value=None, ) -def test_upsert_relationship(driver: MagicMock) -> None: +def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: neo4j_writer = Neo4jWriter(driver=driver) rel = Neo4jRelationship( start_node_id="1", @@ -127,6 +145,10 @@ def test_upsert_relationship(driver: MagicMock) -> None: ) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", return_value=None, @@ -159,6 +181,10 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: ) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @pytest.mark.asyncio @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", @@ -201,6 +227,10 @@ async def test_run(_: Mock, driver: MagicMock) -> None: ) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 22, 0), +) @pytest.mark.asyncio @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", @@ -241,3 +271,196 @@ async def test_run_async_driver(_: Mock, async_driver: MagicMock) -> None: UPSERT_RELATIONSHIP_QUERY, parameters_=parameters_, ) + + +@pytest.mark.asyncio +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", + return_value=None, +) +async def test_run_is_version_below_5_23(_: Mock) -> None: + driver = MagicMock() + driver.execute_query = Mock(return_value=([{"versions": ["5.22.0"]}], None, None)) + + neo4j_writer = Neo4jWriter(driver=driver) + + node = Neo4jNode(id="1", label="Label") + rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") + graph = Neo4jGraph(nodes=[node], relationships=[rel]) + await neo4j_writer.run(graph=graph) + + driver.execute_query.assert_any_call( + UPSERT_NODE_QUERY, + parameters_={ + "rows": [ + { + "label": "Label", + "labels": ["Label", "__Entity__"], + "id": "1", + "properties": {}, + "embedding_properties": None, + } + ] + }, + ) + parameters_ = { + "rows": [ + { + "type": "RELATIONSHIP", + "start_node_id": "1", + "end_node_id": "2", + "properties": {}, + "embedding_properties": None, + } + ] + } + driver.execute_query.assert_any_call( + UPSERT_RELATIONSHIP_QUERY, + parameters_=parameters_, + ) + + +@pytest.mark.asyncio +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._db_setup", + return_value=None, +) +async def test_run_is_version_5_23_or_above(_: Mock) -> None: + driver = MagicMock() + driver.execute_query = Mock(return_value=([{"versions": ["5.23.0"]}], None, None)) + + neo4j_writer = Neo4jWriter(driver=driver) + neo4j_writer.is_version_5_23_or_above = True + + node = Neo4jNode(id="1", label="Label") + rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") + graph = Neo4jGraph(nodes=[node], relationships=[rel]) + await neo4j_writer.run(graph=graph) + + driver.execute_query.assert_any_call( + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_={ + "rows": [ + { + "label": "Label", + "labels": ["Label", "__Entity__"], + "id": "1", + "properties": {}, + "embedding_properties": None, + } + ] + }, + ) + parameters_ = { + "rows": [ + { + "type": "RELATIONSHIP", + "start_node_id": "1", + "end_node_id": "2", + "properties": {}, + "embedding_properties": None, + } + ] + } + driver.execute_query.assert_any_call( + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_=parameters_, + ) + + +@pytest.mark.asyncio +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", + return_value=None, +) +async def test_run_async_driver_is_version_below_5_23(_: Mock) -> None: + async_driver = MagicMock() + async_driver.execute_query = Mock( + return_value=([{"versions": ["5.22.0"]}], None, None) + ) + + neo4j_writer = Neo4jWriter(driver=async_driver) + + node = Neo4jNode(id="1", label="Label") + rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") + graph = Neo4jGraph(nodes=[node], relationships=[rel]) + await neo4j_writer.run(graph=graph) + + async_driver.execute_query.assert_any_call( + UPSERT_NODE_QUERY, + parameters_={ + "rows": [ + { + "label": "Label", + "labels": ["Label", "__Entity__"], + "id": "1", + "properties": {}, + "embedding_properties": None, + } + ] + }, + ) + parameters_ = { + "rows": [ + { + "type": "RELATIONSHIP", + "start_node_id": "1", + "end_node_id": "2", + "properties": {}, + "embedding_properties": None, + } + ] + } + async_driver.execute_query.assert_any_call( + UPSERT_RELATIONSHIP_QUERY, + parameters_=parameters_, + ) + + +@pytest.mark.asyncio +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._async_db_setup", + return_value=None, +) +async def test_run_async_driver_is_version_5_23_or_above(_: Mock) -> None: + async_driver = MagicMock() + async_driver.execute_query = Mock( + return_value=([{"versions": ["5.23.0"]}], None, None) + ) + + neo4j_writer = Neo4jWriter(driver=async_driver) + + node = Neo4jNode(id="1", label="Label") + rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") + graph = Neo4jGraph(nodes=[node], relationships=[rel]) + await neo4j_writer.run(graph=graph) + + async_driver.execute_query.assert_any_call( + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_={ + "rows": [ + { + "label": "Label", + "labels": ["Label", "__Entity__"], + "id": "1", + "properties": {}, + "embedding_properties": None, + } + ] + }, + ) + parameters_ = { + "rows": [ + { + "type": "RELATIONSHIP", + "start_node_id": "1", + "end_node_id": "2", + "properties": {}, + "embedding_properties": None, + } + ] + } + async_driver.execute_query.assert_any_call( + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + parameters_=parameters_, + ) diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index 188180f2..64da47a0 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -12,7 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock, patch +from unittest import mock +from unittest.mock import MagicMock, Mock, patch import neo4j import pytest @@ -25,8 +26,12 @@ from neo4j_graphrag.llm.base import LLMInterface +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_init_with_text() -> None: +async def test_knowledge_graph_builder_init_with_text(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -60,8 +65,12 @@ async def test_knowledge_graph_builder_init_with_text() -> None: assert pipe_inputs["splitter"]["text"] == text_input +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_init_with_file_path() -> None: +async def test_knowledge_graph_builder_init_with_file_path(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -94,8 +103,12 @@ async def test_knowledge_graph_builder_init_with_file_path() -> None: assert pipe_inputs["pdf_loader"]["filepath"] == file_path +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_run_with_both_inputs() -> None: +async def test_knowledge_graph_builder_run_with_both_inputs(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -118,8 +131,12 @@ async def test_knowledge_graph_builder_run_with_both_inputs() -> None: ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_run_with_no_inputs() -> None: +async def test_knowledge_graph_builder_run_with_no_inputs(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -139,8 +156,12 @@ async def test_knowledge_graph_builder_run_with_no_inputs() -> None: ) or "Expected 'text' argument when 'from_pdf' is False." in str(exc_info.value) +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_document_info_with_file() -> None: +async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -167,8 +188,12 @@ async def test_knowledge_graph_builder_document_info_with_file() -> None: assert "extractor" not in pipe_inputs +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_document_info_with_text() -> None: +async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -194,8 +219,12 @@ async def test_knowledge_graph_builder_document_info_with_text() -> None: assert pipe_inputs["splitter"] == {"text": text_input} +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) @pytest.mark.asyncio -async def test_knowledge_graph_builder_with_entities_and_file() -> None: +async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -234,7 +263,11 @@ async def test_knowledge_graph_builder_with_entities_and_file() -> None: assert pipe_inputs["schema"]["potential_schema"] == potential_schema -def test_simple_kg_pipeline_on_error_conversion() -> None: +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) +def test_simple_kg_pipeline_on_error_conversion(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -265,7 +298,11 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: assert "Expected one of ['RAISE', 'IGNORE']" in str(exc_info.value) -def test_simple_kg_pipeline_no_entity_resolution() -> None: +@mock.patch( + "neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter._get_version", + return_value=(5, 23, 0), +) +def test_simple_kg_pipeline_no_entity_resolution(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index c9f2f9b0..0ea0c104 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -87,6 +87,7 @@ def test_hybrid_retriever_with_result_format_function( embedder, result_formatter=result_formatter, ) + retriever.neo4j_version_is_5_23_or_above = True retriever.driver.execute_query.return_value = [ # type: ignore [neo4j_record], None, @@ -174,12 +175,16 @@ def test_hybrid_search_text_happy_path( retriever = HybridRetriever( driver, vector_index_name, fulltext_index_name, embedder ) + retriever.neo4j_version_is_5_23_or_above = True retriever.driver.execute_query.return_value = [ # type: ignore [neo4j_record], None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query( + SearchType.HYBRID, + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, + ) records = retriever.search(query_text=query_text, top_k=top_k) @@ -226,12 +231,16 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( embedder, neo4j_database=database, ) + retriever.neo4j_version_is_5_23_or_above = True retriever.driver.execute_query.return_value = [ # type: ignore [neo4j_record], None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID) + search_query, _ = get_search_query( + SearchType.HYBRID, + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, + ) retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k) @@ -300,12 +309,17 @@ def test_hybrid_retriever_return_properties( embedder, return_properties, ) + retriever.neo4j_version_is_5_23_or_above = True driver.execute_query.return_value = [ [neo4j_record], None, None, ] - search_query, _ = get_search_query(SearchType.HYBRID, return_properties) + search_query, _ = get_search_query( + SearchType.HYBRID, + return_properties, + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, + ) records = retriever.search(query_text=query_text, top_k=top_k) @@ -355,13 +369,16 @@ def test_hybrid_cypher_retrieval_query_with_params( retrieval_query, embedder, ) + retriever.neo4j_version_is_5_23_or_above = True driver.execute_query.return_value = [ [neo4j_record], None, None, ] search_query, _ = get_search_query( - SearchType.HYBRID, retrieval_query=retrieval_query + SearchType.HYBRID, + retrieval_query=retrieval_query, + neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above, ) records = retriever.search( @@ -419,6 +436,7 @@ def test_hybrid_cypher_retriever_with_result_format_function( embedder, result_formatter=result_formatter, ) + retriever.neo4j_version_is_5_23_or_above = True retriever.driver.execute_query.return_value = [ # type: ignore [neo4j_record], None,