From 56fdfb8be8e8bab6e3c84067860e1f8de3229622 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Sun, 15 Sep 2024 22:29:28 +0800 Subject: [PATCH 01/10] implement asyncpg support for vector db --- nano_graphrag/storage/asyncpg.py | 100 +++++++++++++ requirements.txt | 2 + tests/test_asyncpg_vector_storage.py | 204 +++++++++++++++++++++++++++ 3 files changed, 306 insertions(+) create mode 100644 nano_graphrag/storage/asyncpg.py create mode 100644 tests/test_asyncpg_vector_storage.py diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py new file mode 100644 index 0000000..29e85e6 --- /dev/null +++ b/nano_graphrag/storage/asyncpg.py @@ -0,0 +1,100 @@ +from nano_graphrag._storage import BaseVectorStorage +import asyncpg +import asyncio +from contextlib import asynccontextmanager +from nano_graphrag._utils import logger +from pgvector.asyncpg import register_vector +from nano_graphrag.graphrag import always_get_an_event_loop +import numpy as np +import json + +import nest_asyncio +nest_asyncio.apply() + +class AsyncpgVectorStorage(BaseVectorStorage): + table_name_generator: callable = None + conn_fetcher: callable = None + cosine_better_than_threshold: float = 0.2 + dsn = None + def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_generator: callable = None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dsn = dsn + self.conn_fetcher = conn_fetcher + assert self.dsn != None or self.conn_fetcher != None, "Must provide either dsn or conn_fetcher" + if self.dsn: + self.conn_fetcher = self.__get_conn + if not table_name_generator: + self.table_name_generator = lambda working_dir, namespace: f'{working_dir}_{namespace}_vdb' + self._table_name = self.table_name_generator(self.global_config["working_dir"], self.namespace) + self._max_batch_size = self.global_config["embedding_batch_num"] + + self.cosine_better_than_threshold = self.global_config.get( + "query_better_than_threshold", self.cosine_better_than_threshold + ) + loop = always_get_an_event_loop() + loop.run_until_complete(self._secure_table()) + @asynccontextmanager + async def __get_conn(self): + try: + conn: asyncpg.Connection = await asyncpg.connect(self.dsn) + await register_vector(conn) + yield conn + finally: + await conn.close() + async def _secure_table(self): + async with self.conn_fetcher() as conn: + conn: asyncpg.Connection + await conn.execute('CREATE EXTENSION IF NOT EXISTS vector') + result = await conn.fetch( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = $1)", self._table_name) + table_exists = result[0]['exists'] + if not table_exists: + # create the table + await conn.execute(f'CREATE TABLE {self._table_name} (id text PRIMARY KEY, embedding vector({self.embedding_func.embedding_dim}), data jsonb)') + await conn.execute(f'CREATE INDEX ON {self._table_name} USING hnsw (embedding vector_cosine_ops)') + async def query(self, query: str, top_k: int) -> list[dict]: + embedding = await self.embedding_func([query]) + embedding = embedding[0] + async with self.conn_fetcher() as conn: + + result = await conn.fetch(f'SELECT embedding <=> $1 as similarity, id, embedding, data FROM {self._table_name} WHERE embedding <=> $1 > $3 ORDER BY embedding <=> $1 DESC LIMIT $2', embedding, top_k, self.cosine_better_than_threshold) + + rows = [] + for row in result: + data = json.loads(row['data']) + rows.append({ + **data, + 'id': row['id'], + 'distance': 1 - row['similarity'], + 'similarity': row['similarity'] + }) + return rows + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + list_data = [ + { + "__id__": k, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings_list = np.concatenate(embeddings_list) + insert_rows = [] + for i, d in enumerate(list_data): + row = [d["__id__"], embeddings_list[i], json.dumps(d)] + insert_rows.append(row) + async with self.conn_fetcher() as conn: + conn: asyncpg.Connection + stmt = f"INSERT INTO {self._table_name} (id, embedding, data) VALUES ($1, $2, $3) ON CONFLICT (id) DO UPDATE SET embedding = $2, data = $3" + return await conn.executemany(stmt, insert_rows) diff --git a/requirements.txt b/requirements.txt index 014ecba..5e1e241 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,5 @@ hnswlib xxhash tenacity dspy-ai +pgvector==0.3.3 +asyncpg==0.29.0 \ No newline at end of file diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py new file mode 100644 index 0000000..6f0f9e6 --- /dev/null +++ b/tests/test_asyncpg_vector_storage.py @@ -0,0 +1,204 @@ +import numpy as np +import pytest +from dataclasses import asdict +from nano_graphrag import GraphRAG +from nano_graphrag._utils import wrap_embedding_func_with_attrs + +from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage +import asyncpg +from nano_graphrag.graphrag import always_get_an_event_loop +WORKING_DIR = "nano_graphrag_cache_asyncpg_vector_storage_test" +dsn='postgresql://username:password@127.0.0.1:12345/db' + +@pytest.fixture(scope="function") +def setup_teardown(): + + yield + loop = always_get_an_event_loop() + async def clean_table(): + conn: asyncpg.Connection = await asyncpg.connect(dsn) + async with conn.transaction(): + tables = await conn.fetch( + f"SELECT table_name FROM information_schema.tables WHERE table_name LIKE '{WORKING_DIR}%'" + ) + + for table in tables: + await conn.execute(f"DROP TABLE {table['table_name']} CASCADE") + loop.run_until_complete(clean_table()) + + +@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192) +async def mock_embedding(texts: list[str]) -> np.ndarray: + return np.random.rand(len(texts), 384) + + +@pytest.fixture +def asyncpg_storage(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + return AsyncpgVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + dsn=dsn + ) + + +@pytest.mark.asyncio +async def test_upsert_and_query(asyncpg_storage): + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + "2": {"content": "Test content 2", "entity_name": "Entity 2"}, + } + + await asyncpg_storage.upsert(test_data) + + results = await asyncpg_storage.query("Test query", top_k=2) + + assert len(results) == 2 + assert all(isinstance(result, dict) for result in results) + assert all( + "id" in result and "distance" in result and "similarity" in result + for result in results + ) + + +@pytest.mark.asyncio +async def test_persistence(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + initial_storage = AsyncpgVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + dsn=dsn + ) + + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + } + + await initial_storage.upsert(test_data) + await initial_storage.index_done_callback() + + new_storage = AsyncpgVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + dsn=dsn + ) + + results = await new_storage.query("Test query", top_k=1) + + assert len(results) == 1 + assert results[0]["id"] == "1" + assert "entity_name" in results[0] + + +@pytest.mark.asyncio +async def test_persistence_large_dataset(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + initial_storage = AsyncpgVectorStorage( + namespace="test_large", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + dsn=dsn + ) + + large_data = { + str(i): {"content": f"Test content {i}", "entity_name": f"Entity {i}"} + for i in range(1000) + } + await initial_storage.upsert(large_data) + await initial_storage.index_done_callback() + + new_storage = AsyncpgVectorStorage( + namespace="test_large", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + dsn=dsn + ) + + results = await new_storage.query("Test query", top_k=500) + assert len(results) == 500 + assert all(result["id"] in large_data for result in results) + + +@pytest.mark.asyncio +async def test_upsert_with_existing_ids(asyncpg_storage): + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + "2": {"content": "Test content 2", "entity_name": "Entity 2"}, + } + + await asyncpg_storage.upsert(test_data) + + updated_data = { + "1": {"content": "Updated content 1", "entity_name": "Updated Entity 1"}, + "3": {"content": "Test content 3", "entity_name": "Entity 3"}, + } + + await asyncpg_storage.upsert(updated_data) + + results = await asyncpg_storage.query("Updated", top_k=3) + + assert len(results) == 3 + assert any( + result["id"] == "1" and result["entity_name"] == "Updated Entity 1" + for result in results + ) + assert any( + result["id"] == "2" and result["entity_name"] == "Entity 2" + for result in results + ) + assert any( + result["id"] == "3" and result["entity_name"] == "Entity 3" + for result in results + ) + + +@pytest.mark.asyncio +async def test_large_batch_upsert(asyncpg_storage): + batch_size = 30 + large_data = { + str(i): {"content": f"Test content {i}", "entity_name": f"Entity {i}"} + for i in range(batch_size) + } + + await asyncpg_storage.upsert(large_data) + + results = await asyncpg_storage.query("Test query", top_k=batch_size) + assert len(results) == batch_size + assert all(isinstance(result, dict) for result in results) + assert all( + "id" in result and "distance" in result and "similarity" in result + for result in results + ) + + +@pytest.mark.asyncio +async def test_empty_data_insertion(asyncpg_storage): + empty_data = {} + await asyncpg_storage.upsert(empty_data) + + results = await asyncpg_storage.query("Test query", top_k=1) + assert len(results) == 0 + + +@pytest.mark.asyncio +async def test_query_with_no_results(asyncpg_storage): + results = await asyncpg_storage.query("Non-existent query", top_k=5) + assert len(results) == 0 + + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + } + await asyncpg_storage.upsert(test_data) + + results = await asyncpg_storage.query("Non-existent query", top_k=5) + assert len(results) == 1 + assert all(0 <= result["similarity"] <= 1 for result in results) + assert "entity_name" in results[0] From ad99ba9a4c9c387ff7c6c6b0d8fcdc2a24d19c84 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 18:16:51 +0800 Subject: [PATCH 02/10] use nest_asyncio --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 5e1e241..d1a0f0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ xxhash tenacity dspy-ai pgvector==0.3.3 -asyncpg==0.29.0 \ No newline at end of file +asyncpg==0.29.0 +nest_asyncio==1.6.0 \ No newline at end of file From 0e9da7937b141a0b1d878332c7c84a8861ac2380 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 18:23:50 +0800 Subject: [PATCH 03/10] make ci happy --- .github/workflows/test.yml | 9 +++++++++ tests/test_asyncpg_vector_storage.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e893b30..fd1dde2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,6 +41,15 @@ jobs: - name: Lint with flake8 run: | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + - uses: ikalnytskyi/action-setup-postgres@v6 + with: + username: test + password: test + database: test + port: 12345 + postgres-version: "14" + ssl: "on" + id: postgres - name: Build and Test run: | python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./ diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index 6f0f9e6..f408637 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -8,7 +8,7 @@ import asyncpg from nano_graphrag.graphrag import always_get_an_event_loop WORKING_DIR = "nano_graphrag_cache_asyncpg_vector_storage_test" -dsn='postgresql://username:password@127.0.0.1:12345/db' +dsn='postgresql://test:test@127.0.0.1:12345/test' @pytest.fixture(scope="function") def setup_teardown(): From c4a6e935a0d62c7ef3f6d0bbf46c738cbad5b1b5 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 18:26:35 +0800 Subject: [PATCH 04/10] fix --- .github/workflows/test.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fd1dde2..2fdb0cc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,12 +41,13 @@ jobs: - name: Lint with flake8 run: | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - - uses: ikalnytskyi/action-setup-postgres@v6 + - name: Setup postgres + uses: ikalnytskyi/action-setup-postgres@v6 with: - username: test - password: test + username: ci + password: sw0rdfish database: test - port: 12345 + port: 34837 postgres-version: "14" ssl: "on" id: postgres From 93366a2ccb31af36ecb111c680ee656e21e42977 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 18:33:41 +0800 Subject: [PATCH 05/10] use env to get connection string --- .github/workflows/test.yml | 4 +++- tests/test_asyncpg_vector_storage.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2fdb0cc..9a8c8e1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,11 +47,13 @@ jobs: username: ci password: sw0rdfish database: test - port: 34837 + port: 12345 postgres-version: "14" ssl: "on" id: postgres - name: Build and Test + env: + POSTGRES_CONNECTION_STR: ${{ steps.postgres.outputs.connection-uri }} run: | python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./ - name: Check codecov file diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index f408637..932cf9e 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -7,8 +7,9 @@ from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage import asyncpg from nano_graphrag.graphrag import always_get_an_event_loop +import os WORKING_DIR = "nano_graphrag_cache_asyncpg_vector_storage_test" -dsn='postgresql://test:test@127.0.0.1:12345/test' +dsn=os.environ['POSTGRES_CONNECTION_STR'] @pytest.fixture(scope="function") def setup_teardown(): From 3753c7b2e595be5d8e23235738407db406347868 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 18:45:23 +0800 Subject: [PATCH 06/10] fix --- nano_graphrag/storage/asyncpg.py | 9 +++++---- tests/test_asyncpg_vector_storage.py | 12 ++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py index 29e85e6..2b1a0a6 100644 --- a/nano_graphrag/storage/asyncpg.py +++ b/nano_graphrag/storage/asyncpg.py @@ -11,7 +11,7 @@ import nest_asyncio nest_asyncio.apply() -class AsyncpgVectorStorage(BaseVectorStorage): +class AsyncPGVectorStorage(BaseVectorStorage): table_name_generator: callable = None conn_fetcher: callable = None cosine_better_than_threshold: float = 0.2 @@ -34,15 +34,16 @@ def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_ge loop = always_get_an_event_loop() loop.run_until_complete(self._secure_table()) @asynccontextmanager - async def __get_conn(self): + async def __get_conn(self, vector_register=True): try: conn: asyncpg.Connection = await asyncpg.connect(self.dsn) - await register_vector(conn) + if vector_register: + await register_vector(conn) yield conn finally: await conn.close() async def _secure_table(self): - async with self.conn_fetcher() as conn: + async with self.conn_fetcher(vector_register=False) as conn: conn: asyncpg.Connection await conn.execute('CREATE EXTENSION IF NOT EXISTS vector') result = await conn.fetch( diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index 932cf9e..db6de25 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -4,7 +4,7 @@ from nano_graphrag import GraphRAG from nano_graphrag._utils import wrap_embedding_func_with_attrs -from nano_graphrag.storage.asyncpg import AsyncpgVectorStorage +from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage import asyncpg from nano_graphrag.graphrag import always_get_an_event_loop import os @@ -36,7 +36,7 @@ async def mock_embedding(texts: list[str]) -> np.ndarray: @pytest.fixture def asyncpg_storage(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - return AsyncpgVectorStorage( + return AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -67,7 +67,7 @@ async def test_upsert_and_query(asyncpg_storage): @pytest.mark.asyncio async def test_persistence(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - initial_storage = AsyncpgVectorStorage( + initial_storage = AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -82,7 +82,7 @@ async def test_persistence(setup_teardown): await initial_storage.upsert(test_data) await initial_storage.index_done_callback() - new_storage = AsyncpgVectorStorage( + new_storage = AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, @@ -100,7 +100,7 @@ async def test_persistence(setup_teardown): @pytest.mark.asyncio async def test_persistence_large_dataset(setup_teardown): rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) - initial_storage = AsyncpgVectorStorage( + initial_storage = AsyncPGVectorStorage( namespace="test_large", global_config=asdict(rag), embedding_func=mock_embedding, @@ -115,7 +115,7 @@ async def test_persistence_large_dataset(setup_teardown): await initial_storage.upsert(large_data) await initial_storage.index_done_callback() - new_storage = AsyncpgVectorStorage( + new_storage = AsyncPGVectorStorage( namespace="test_large", global_config=asdict(rag), embedding_func=mock_embedding, From c1f30db504ec811bef2b134d01f3fdc08506d3b3 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 19:03:53 +0800 Subject: [PATCH 07/10] fix argument --- .github/workflows/test.yml | 4 + examples/using_pgvector_as_vectorDB.py | 129 +++++++++++++++++++++++++ nano_graphrag/storage/asyncpg.py | 9 +- tests/test_asyncpg_vector_storage.py | 10 +- 4 files changed, 142 insertions(+), 10 deletions(-) create mode 100644 examples/using_pgvector_as_vectorDB.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9a8c8e1..cc0f54e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -51,6 +51,10 @@ jobs: postgres-version: "14" ssl: "on" id: postgres + - name: Install pgvector + run: | + sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y + sudo apt-get install postgresql-14-pgvector - name: Build and Test env: POSTGRES_CONNECTION_STR: ${{ steps.postgres.outputs.connection-uri }} diff --git a/examples/using_pgvector_as_vectorDB.py b/examples/using_pgvector_as_vectorDB.py new file mode 100644 index 0000000..c3113da --- /dev/null +++ b/examples/using_pgvector_as_vectorDB.py @@ -0,0 +1,129 @@ +import os +from openai import AsyncOpenAI +from dotenv import load_dotenv +import logging +import numpy as np +from sentence_transformers import SentenceTransformer +from nano_graphrag import GraphRAG, QueryParam +from nano_graphrag._llm import gpt_4o_mini_complete +from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage +from nano_graphrag.base import BaseKVStorage +from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("nano-graphrag").setLevel(logging.DEBUG) + +WORKING_DIR = "nano_graphrag_cache_using_pg_as_vectorDB" +dsn = os.environ.get("POSTGRES_CONNECTION_STR") +load_dotenv() + + +EMBED_MODEL = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu" +) + + +@wrap_embedding_func_with_attrs( + embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(), + max_token_size=EMBED_MODEL.max_seq_length, +) +async def local_embedding(texts: list[str]) -> np.ndarray: + return EMBED_MODEL.encode(texts, normalize_embeddings=True) + + +async def deepseepk_model_if_cache( + prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs +) -> str: + openai_async_client = AsyncOpenAI( + api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com" + ) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Get the cached response if having------------------- + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + # ----------------------------------------------------- + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + # Cache the response if having------------------- + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + # ----------------------------------------------------- + return response.choices[0].message.content + + + +def remove_if_exist(file): + if os.path.exists(file): + os.remove(file) + + +def insert(): + from time import time + + with open("./tests/mock_data.txt", encoding="utf-8-sig") as f: + FAKE_TEXT = f.read() + + remove_if_exist(f"{WORKING_DIR}/vdb_entities.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json") + remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml") + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=AsyncPGVectorStorage, + vector_db_storage_cls_kwargs={"dsn": dsn}, + best_model_max_async=10, + cheap_model_max_async=10, + best_model_func=deepseepk_model_if_cache, + cheap_model_func=deepseepk_model_if_cache, + embedding_func=local_embedding + ) + start = time() + rag.insert(FAKE_TEXT) + print("indexing time:", time() - start) + + +def query(): + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=AsyncPGVectorStorage, + vector_db_storage_cls_kwargs={"dsn": dsn}, + best_model_max_token_size=8196, + cheap_model_max_token_size=8196, + best_model_max_async=4, + cheap_model_max_async=4, + best_model_func=gpt_4o_mini_complete, + cheap_model_func=gpt_4o_mini_complete, + embedding_func=local_embedding + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + +if __name__ == "__main__": + insert() + query() \ No newline at end of file diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py index 2b1a0a6..ea2792b 100644 --- a/nano_graphrag/storage/asyncpg.py +++ b/nano_graphrag/storage/asyncpg.py @@ -7,7 +7,7 @@ from nano_graphrag.graphrag import always_get_an_event_loop import numpy as np import json - +import os import nest_asyncio nest_asyncio.apply() @@ -16,8 +16,11 @@ class AsyncPGVectorStorage(BaseVectorStorage): conn_fetcher: callable = None cosine_better_than_threshold: float = 0.2 dsn = None - def __init__(self, dsn: str = None, conn_fetcher: callable = None, table_name_generator: callable = None, *args, **kwargs): - super().__init__(*args, **kwargs) + def __post_init__(self): + params = self.global_config.get("vector_db_storage_cls_kwargs", {}) + dsn = params.get("dsn", None) + conn_fetcher = params.get("conn_fetcher", None) + table_name_generator = params.get("table_name_generator", None) self.dsn = dsn self.conn_fetcher = conn_fetcher assert self.dsn != None or self.conn_fetcher != None, "Must provide either dsn or conn_fetcher" diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index db6de25..82e9db4 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -35,13 +35,12 @@ async def mock_embedding(texts: list[str]) -> np.ndarray: @pytest.fixture def asyncpg_storage(setup_teardown): - rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) return AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, meta_fields={"entity_name"}, - dsn=dsn ) @@ -66,13 +65,12 @@ async def test_upsert_and_query(asyncpg_storage): @pytest.mark.asyncio async def test_persistence(setup_teardown): - rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) initial_storage = AsyncPGVectorStorage( namespace="test", global_config=asdict(rag), embedding_func=mock_embedding, meta_fields={"entity_name"}, - dsn=dsn ) test_data = { @@ -99,13 +97,12 @@ async def test_persistence(setup_teardown): @pytest.mark.asyncio async def test_persistence_large_dataset(setup_teardown): - rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding) + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) initial_storage = AsyncPGVectorStorage( namespace="test_large", global_config=asdict(rag), embedding_func=mock_embedding, meta_fields={"entity_name"}, - dsn=dsn ) large_data = { @@ -120,7 +117,6 @@ async def test_persistence_large_dataset(setup_teardown): global_config=asdict(rag), embedding_func=mock_embedding, meta_fields={"entity_name"}, - dsn=dsn ) results = await new_storage.query("Test query", top_k=500) From b6b79d6b541538043958464aa3494dbe4d7aaaa6 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 19:08:35 +0800 Subject: [PATCH 08/10] fix --- nano_graphrag/storage/asyncpg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py index ea2792b..1417b63 100644 --- a/nano_graphrag/storage/asyncpg.py +++ b/nano_graphrag/storage/asyncpg.py @@ -7,7 +7,8 @@ from nano_graphrag.graphrag import always_get_an_event_loop import numpy as np import json -import os +from dataclasses import dataclass + import nest_asyncio nest_asyncio.apply() From aa2e9f022122a7a775971519c794ed9e86b5cad6 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 19:11:49 +0800 Subject: [PATCH 09/10] fix --- nano_graphrag/storage/asyncpg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py index 1417b63..3fb7816 100644 --- a/nano_graphrag/storage/asyncpg.py +++ b/nano_graphrag/storage/asyncpg.py @@ -12,6 +12,7 @@ import nest_asyncio nest_asyncio.apply() +@dataclass class AsyncPGVectorStorage(BaseVectorStorage): table_name_generator: callable = None conn_fetcher: callable = None From 770e248b30fda3d325ec5cae40c3662c63bfba06 Mon Sep 17 00:00:00 2001 From: Dorbmon Date: Tue, 17 Sep 2024 19:15:09 +0800 Subject: [PATCH 10/10] fix --- tests/test_asyncpg_vector_storage.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_asyncpg_vector_storage.py b/tests/test_asyncpg_vector_storage.py index 82e9db4..20bb8e1 100644 --- a/tests/test_asyncpg_vector_storage.py +++ b/tests/test_asyncpg_vector_storage.py @@ -85,7 +85,6 @@ async def test_persistence(setup_teardown): global_config=asdict(rag), embedding_func=mock_embedding, meta_fields={"entity_name"}, - dsn=dsn ) results = await new_storage.query("Test query", top_k=1)