Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dorbmon committed Sep 17, 2024
1 parent 93366a2 commit 3753c7b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
9 changes: 5 additions & 4 deletions nano_graphrag/storage/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import nest_asyncio
nest_asyncio.apply()

class AsyncpgVectorStorage(BaseVectorStorage):
class AsyncPGVectorStorage(BaseVectorStorage):
table_name_generator: callable = None
conn_fetcher: callable = None
cosine_better_than_threshold: float = 0.2
Expand All @@ -34,15 +34,16 @@ def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_ge
loop = always_get_an_event_loop()
loop.run_until_complete(self._secure_table())
@asynccontextmanager
async def __get_conn(self):
async def __get_conn(self, vector_register=True):
try:
conn: asyncpg.Connection = await asyncpg.connect(self.dsn)
await register_vector(conn)
if vector_register:
await register_vector(conn)
yield conn
finally:
await conn.close()
async def _secure_table(self):
async with self.conn_fetcher() as conn:
async with self.conn_fetcher(vector_register=False) as conn:
conn: asyncpg.Connection
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
result = await conn.fetch(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_asyncpg_vector_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from nano_graphrag import GraphRAG
from nano_graphrag._utils import wrap_embedding_func_with_attrs

from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage
from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage
import asyncpg
from nano_graphrag.graphrag import always_get_an_event_loop
import os
Expand Down Expand Up @@ -36,7 +36,7 @@ async def mock_embedding(texts: list[str]) -> np.ndarray:
@pytest.fixture
def asyncpg_storage(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
return AsyncpgVectorStorage(
return AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand Down Expand Up @@ -67,7 +67,7 @@ async def test_upsert_and_query(asyncpg_storage):
@pytest.mark.asyncio
async def test_persistence(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
initial_storage = AsyncpgVectorStorage(
initial_storage = AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -82,7 +82,7 @@ async def test_persistence(setup_teardown):
await initial_storage.upsert(test_data)
await initial_storage.index_done_callback()

new_storage = AsyncpgVectorStorage(
new_storage = AsyncPGVectorStorage(
namespace="test",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -100,7 +100,7 @@ async def test_persistence(setup_teardown):
@pytest.mark.asyncio
async def test_persistence_large_dataset(setup_teardown):
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
initial_storage = AsyncpgVectorStorage(
initial_storage = AsyncPGVectorStorage(
namespace="test_large",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand All @@ -115,7 +115,7 @@ async def test_persistence_large_dataset(setup_teardown):
await initial_storage.upsert(large_data)
await initial_storage.index_done_callback()

new_storage = AsyncpgVectorStorage(
new_storage = AsyncPGVectorStorage(
namespace="test_large",
global_config=asdict(rag),
embedding_func=mock_embedding,
Expand Down

0 comments on commit 3753c7b

Please sign in to comment.