From 0b4013d9ccf7c496dfeadd8236748a032c613749 Mon Sep 17 00:00:00 2001
From: Dorbmon
Date: Tue, 17 Sep 2024 18:40:08 +0800
Subject: [PATCH] change name & fix order

Also address review findings:
- asyncpg: call conn_fetcher with the keyword it actually declares
  (vector_register); passing register_vector= raised TypeError in
  _secure_table when the default __get_conn fetcher was used.
- neo4j: name the class Neo4jStorage (keep a NetworkXStorage alias for
  compatibility), build the driver via AsyncGraphDatabase.driver(),
  use modern CREATE CONSTRAINT syntax, match nodes consistently on the
  _id property (upserts previously wrote `id` while reads queried
  `_id`), replace the YOUR_RELATIONSHIP_TYPE placeholder, commit
  upsert transactions, and consume async results with `async for`.
---
 nano_graphrag/storage/asyncpg.py     |   9 +-
 nano_graphrag/storage/neo4j.py       | 160 ++++++++++++++++++++++++++++
 tests/test_asyncpg_vector_storage.py |  12 +-
 3 files changed, 171 insertions(+), 10 deletions(-)
 create mode 100644 nano_graphrag/storage/neo4j.py

diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py
index 29e85e6..7d81a6a 100644
--- a/nano_graphrag/storage/asyncpg.py
+++ b/nano_graphrag/storage/asyncpg.py
@@ -11,7 +11,7 @@
 import nest_asyncio
 nest_asyncio.apply()
 
-class AsyncpgVectorStorage(BaseVectorStorage):
+class AsyncPGVectorStorage(BaseVectorStorage):
     table_name_generator: callable = None
     conn_fetcher: callable = None
     cosine_better_than_threshold: float = 0.2
@@ -34,15 +34,16 @@ def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_ge
         loop = always_get_an_event_loop()
         loop.run_until_complete(self._secure_table())
     @asynccontextmanager
-    async def __get_conn(self):
+    async def __get_conn(self, vector_register=True):
         try:
             conn: asyncpg.Connection = await asyncpg.connect(self.dsn)
-            await register_vector(conn)
+            if vector_register:
+                await register_vector(conn)
             yield conn
         finally:
             await conn.close()
     async def _secure_table(self):
-        async with self.conn_fetcher() as conn:
+        async with self.conn_fetcher(vector_register=False) as conn:
             conn: asyncpg.Connection
             await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
             result = await conn.fetch(
diff --git a/nano_graphrag/storage/neo4j.py b/nano_graphrag/storage/neo4j.py
new file mode 100644
index 0000000..5de7973
--- /dev/null
+++ b/nano_graphrag/storage/neo4j.py
@@ -0,0 +1,160 @@
+"""Neo4j-backed implementation of BaseGraphStorage.
+
+Nodes carry the label ``Node`` and are identified by the unique ``_id``
+property; edges use the generic ``RELATED`` relationship type.
+"""
+from typing import Union
+
+import numpy as np
+import neo4j
+from neo4j import AsyncGraphDatabase
+import nest_asyncio
+
+from nano_graphrag._storage import BaseGraphStorage
+from nano_graphrag.base import SingleCommunitySchema
+from nano_graphrag.graphrag import always_get_an_event_loop
+
+nest_asyncio.apply()
+
+
+class Neo4jStorage(BaseGraphStorage):
+    """Graph storage backed by a Neo4j database via the async driver."""
+
+    def __init__(self, uri: str, user: str, password: str):
+        # AsyncGraphDatabase.driver() is the documented factory; calling
+        # the class directly does not return a usable driver.
+        self._driver: neo4j.AsyncDriver = AsyncGraphDatabase.driver(
+            uri, auth=(user, password)
+        )
+        loop = always_get_an_event_loop()
+        loop.run_until_complete(self._secure_table())
+
+    async def _secure_table(self):
+        """Ensure a uniqueness constraint on node ids (idempotent)."""
+        async with self._driver.session() as session:
+            # Neo4j 4.4+ syntax; the legacy `CREATE CONSTRAINT ON ... ASSERT`
+            # form (with `_id` mistakenly used as a label) is removed in 5.x.
+            await session.run(
+                "CREATE CONSTRAINT node_id_unique IF NOT EXISTS "
+                "FOR (n:Node) REQUIRE n._id IS UNIQUE"
+            )
+
+    async def has_node(self, node_id: str) -> bool:
+        """Return True if a node with the given id exists."""
+        query = "MATCH (n:Node {_id: $node_id}) RETURN COUNT(n) > 0 AS nodeExists"
+        async with self._driver.session() as session:
+            result = await session.run(query, node_id=node_id)
+            record = await result.single()
+            return bool(record["nodeExists"]) if record else False
+
+    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+        """Return True if any edge connects the two nodes (either direction)."""
+        query = (
+            "MATCH (n1:Node {_id: $node1_id})-[r]-(n2:Node {_id: $node2_id}) "
+            "RETURN COUNT(r) > 0 AS relationshipExists"
+        )
+        async with self._driver.session() as session:
+            result = await session.run(
+                query, node1_id=source_node_id, node2_id=target_node_id
+            )
+            record = await result.single()
+            return bool(record["relationshipExists"]) if record else False
+
+    async def node_degree(self, node_id: str) -> int:
+        """Number of edges attached to the node; 0 if the node is missing."""
+        query = "MATCH (n:Node {_id: $node_id})-[r]-() RETURN COUNT(r) AS degree"
+        async with self._driver.session() as session:
+            result = await session.run(query, node_id=node_id)
+            record = await result.single()
+            return record["degree"] if record else 0
+
+    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        """Sum of the two endpoint degrees (missing nodes count as 0)."""
+        # `size((n)--())` was removed in Neo4j 5; reuse node_degree instead.
+        return await self.node_degree(src_id) + await self.node_degree(tgt_id)
+
+    async def get_node(self, node_id: str) -> Union[dict, None]:
+        """Return the node's properties, or None if it does not exist."""
+        async with self._driver.session() as session:
+            result = await session.run(
+                "MATCH (n:Node {_id: $node_id}) RETURN n", node_id=node_id
+            )
+            record = await result.single()
+            return dict(record["n"]) if record else None
+
+    async def get_edge(
+        self, source_node_id: str, target_node_id: str
+    ) -> Union[dict, None]:
+        """Return the properties of one edge between the nodes, or None."""
+        async with self._driver.session() as session:
+            # Match on the `_id` property, not the internal id() of the node,
+            # for consistency with every other query in this class.
+            result = await session.run(
+                "MATCH (start:Node {_id: $start_node_id})-[r]-"
+                "(end:Node {_id: $end_node_id}) RETURN r",
+                start_node_id=source_node_id,
+                end_node_id=target_node_id,
+            )
+            record = await result.single()
+            return dict(record["r"]) if record else None
+
+    async def get_node_edges(
+        self, source_node_id: str
+    ) -> Union[list[tuple[str, str]], None]:
+        """Return (source_id, target_id) pairs for all outgoing edges."""
+        async with self._driver.session() as session:
+            result = await session.run(
+                "MATCH (startNode:Node {_id: $start_node_id})-[]->(endNode) "
+                "RETURN endNode._id AS target_id",
+                start_node_id=source_node_id,
+            )
+            # AsyncResult must be consumed with `async for`.
+            return [
+                (source_node_id, record["target_id"]) async for record in result
+            ]
+
+    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        """Create or update a node; the caller's dict is not mutated."""
+        props = {**node_data, "_id": node_id}
+        query = "MERGE (n:Node {_id: $node_id}) SET n += $props"
+        async with self._driver.session() as session:
+            async with session.begin_transaction() as tx:
+                await tx.run(query, node_id=node_id, props=props)
+                # Leaving the context without committing rolls back.
+                await tx.commit()
+
+    async def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        """Create or update a RELATED edge; both endpoints must already exist."""
+        query = (
+            "MATCH (source:Node {_id: $source_node_id}), "
+            "(target:Node {_id: $target_node_id}) "
+            "MERGE (source)-[edge:RELATED]->(target) "
+            "SET edge += $edge_data"
+        )
+        async with self._driver.session() as session:
+            async with session.begin_transaction() as tx:
+                await tx.run(
+                    query,
+                    source_node_id=source_node_id,
+                    target_node_id=target_node_id,
+                    edge_data=edge_data,
+                )
+                # Leaving the context without committing rolls back.
+                await tx.commit()
+
+    async def clustering(self, algorithm: str):
+        raise NotImplementedError
+
+    async def community_schema(self) -> dict[str, SingleCommunitySchema]:
+        """Return the community representation with report and nodes"""
+        raise NotImplementedError
+
+    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+        raise NotImplementedError("Node embedding is not used in nano-graphrag.")
+
+
+# Backward-compatible alias: the class was originally (mis)named after the
+# NetworkX backend it was adapted from.
+NetworkXStorage = Neo4jStorage
diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py
index 932cf9e..db6de25 100644
--- a/tests/test_asyncpg_vector_storage.py
+++ b/tests/test_asyncpg_vector_storage.py
@@ -4,7 +4,7 @@
 
 from nano_graphrag import GraphRAG
 from nano_graphrag._utils import wrap_embedding_func_with_attrs
-from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage
+from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage
 import asyncpg
 from nano_graphrag.graphrag import always_get_an_event_loop
 import os
@@ -36,7 +36,7 @@ async def mock_embedding(texts: list[str]) -> np.ndarray:
 @pytest.fixture
 def asyncpg_storage(setup_teardown):
     rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
-    return AsyncpgVectorStorage(
+    return AsyncPGVectorStorage(
         namespace="test",
         global_config=asdict(rag),
         embedding_func=mock_embedding,
@@ -67,7 +67,7 @@ async def test_upsert_and_query(asyncpg_storage):
 @pytest.mark.asyncio
 async def test_persistence(setup_teardown):
     rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
-    initial_storage = AsyncpgVectorStorage(
+    initial_storage = AsyncPGVectorStorage(
         namespace="test",
         global_config=asdict(rag),
         embedding_func=mock_embedding,
@@ -82,7 +82,7 @@ async def test_persistence(setup_teardown):
     await initial_storage.upsert(test_data)
     await initial_storage.index_done_callback()
 
-    new_storage = AsyncpgVectorStorage(
+    new_storage = AsyncPGVectorStorage(
         namespace="test",
         global_config=asdict(rag),
         embedding_func=mock_embedding,
@@ -100,7 +100,7 @@ async def test_persistence(setup_teardown):
 @pytest.mark.asyncio
 async def test_persistence_large_dataset(setup_teardown):
     rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
-    initial_storage = AsyncpgVectorStorage(
+    initial_storage = AsyncPGVectorStorage(
         namespace="test_large",
         global_config=asdict(rag),
         embedding_func=mock_embedding,
@@ -115,7 +115,7 @@ async def test_persistence_large_dataset(setup_teardown):
     await initial_storage.upsert(large_data)
     await initial_storage.index_done_callback()
 
-    new_storage = AsyncpgVectorStorage(
+    new_storage = AsyncPGVectorStorage(
         namespace="test_large",
         global_config=asdict(rag),
         embedding_func=mock_embedding,