diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py index 29e85e6..2b1a0a6 100644 --- a/nano_graphrag/storage/asyncpg.py +++ b/nano_graphrag/storage/asyncpg.py @@ -11,7 +11,7 @@ import nest_asyncio nest_asyncio.apply() -class AsyncpgVectorStorage(BaseVectorStorage): +class AsyncPGVectorStorage(BaseVectorStorage): table_name_generator: callable = None conn_fetcher: callable = None cosine_better_than_threshold: float = 0.2 @@ -34,15 +34,16 @@ def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_ge loop = always_get_an_event_loop() loop.run_until_complete(self._secure_table()) @asynccontextmanager - async def __get_conn(self): + async def __get_conn(self, vector_register=True): try: conn: asyncpg.Connection = await asyncpg.connect(self.dsn) - await register_vector(conn) + if vector_register: + await register_vector(conn) yield conn finally: await conn.close() async def _secure_table(self): - async with self.conn_fetcher() as conn: + async with self.conn_fetcher(vector_register=False) as conn: conn: asyncpg.Connection await conn.execute('CREATE EXTENSION IF NOT EXISTS vector') result = await conn.fetch( diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index 932cf9e..db6de25 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -4,7 +4,7 @@ from nano_graphrag import GraphRAG from nano_graphrag._utils import wrap_embedding_func_with_attrs -from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage +from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage import asyncpg from nano_graphrag.graphrag import always_get_an_event_loop import os @@ -36,7 +36,7 @@ async def mock_embedding(texts: list[str]) -> np.ndarray: @pytest.fixture def asyncpg_storage(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - return AsyncpgVectorStorage( + return AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -67,7 +67,7 @@ async def test_upsert_and_query(asyncpg_storage): @pytest.mark.asyncio async def test_persistence(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - initial_storage = AsyncpgVectorStorage( + initial_storage = AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -82,7 +82,7 @@ async def test_persistence(setup_teardown): await initial_storage.upsert(test_data) await initial_storage.index_done_callback() - new_storage = AsyncpgVectorStorage( + new_storage = AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -100,7 +100,7 @@ async def test_persistence(setup_teardown): @pytest.mark.asyncio async def test_persistence_large_dataset(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - initial_storage = AsyncpgVectorStorage( + initial_storage = AsyncPGVectorStorage( namespace="test_large", global_config=asdict(rag), embedding_func=mock_embedding, @@ -115,7 +115,7 @@ async def test_persistence_large_dataset(setup_teardown): await initial_storage.upsert(large_data) await initial_storage.index_done_callback() - new_storage = AsyncpgVectorStorage( + new_storage = AsyncPGVectorStorage( namespace="test_large", global_config=asdict(rag), embedding_func=mock_embedding,