diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e893b30..cc0f54e 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -41,7 +41,23 @@ jobs:
       - name: Lint with flake8
         run: |
           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+      - name: Setup postgres
+        id: postgres
+        uses: ikalnytskyi/action-setup-postgres@v6
+        with:
+          username: ci
+          password: sw0rdfish
+          database: test
+          port: 12345
+          postgres-version: "14"
+          ssl: "on"
+      - name: Install pgvector
+        run: |
+          sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y
+          sudo apt-get install postgresql-14-pgvector
       - name: Build and Test
+        env:
+          POSTGRES_CONNECTION_STR: ${{ steps.postgres.outputs.connection-uri }}
         run: |
           python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./
       - name: Check codecov file
diff --git a/examples/using_pgvector_as_vectorDB.py b/examples/using_pgvector_as_vectorDB.py
new file mode 100644
index 0000000..c3113da
--- /dev/null
+++ b/examples/using_pgvector_as_vectorDB.py
@@ -0,0 +1,129 @@
+import logging
+import os
+
+import numpy as np
+from dotenv import load_dotenv
+from openai import AsyncOpenAI
+from sentence_transformers import SentenceTransformer
+
+from nano_graphrag import GraphRAG, QueryParam
+from nano_graphrag._llm import gpt_4o_mini_complete
+from nano_graphrag.storage.asyncpg import AsyncPGVectorStorage
+from nano_graphrag.base import BaseKVStorage
+from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
+
+logging.basicConfig(level=logging.WARNING)
+logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
+
+WORKING_DIR = "nano_graphrag_cache_using_pg_as_vectorDB"
+# Load .env before reading the connection string, otherwise a value supplied
+# via .env is silently ignored.
+load_dotenv()
+dsn = os.environ.get("POSTGRES_CONNECTION_STR")
+
+
+EMBED_MODEL = SentenceTransformer(
+    "sentence-transformers/all-MiniLM-L6-v2", cache_folder=WORKING_DIR, device="cpu"
+)
+
+
+@wrap_embedding_func_with_attrs(
+    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
+    max_token_size=EMBED_MODEL.max_seq_length,
+)
+async def local_embedding(texts: list[str]) -> np.ndarray:
+    return EMBED_MODEL.encode(texts, normalize_embeddings=True)
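+
+
+# NOTE: wrap_embedding_func_with_attrs records embedding_dim and
+# max_token_size on the embedding function; AsyncPGVectorStorage later reads
+# embedding_dim to size its vector(...) column, so the two must agree.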
open("./tests/mock_data.txt", encoding="utf-8-sig") as f: + FAKE_TEXT = f.read() + + remove_if_exist(f"{WORKING_DIR}/vdb_entities.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json") + remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml") + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=AsyncPGVectorStorage, + vector_db_storage_cls_kwargs={"dsn": dsn}, + best_model_max_async=10, + cheap_model_max_async=10, + best_model_func=deepseepk_model_if_cache, + cheap_model_func=deepseepk_model_if_cache, + embedding_func=local_embedding + ) + start = time() + rag.insert(FAKE_TEXT) + print("indexing time:", time() - start) + + +def query(): + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=AsyncPGVectorStorage, + vector_db_storage_cls_kwargs={"dsn": dsn}, + best_model_max_token_size=8196, + cheap_model_max_token_size=8196, + best_model_max_async=4, + cheap_model_max_async=4, + best_model_func=gpt_4o_mini_complete, + cheap_model_func=gpt_4o_mini_complete, + embedding_func=local_embedding + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + +if __name__ == "__main__": + insert() + query() \ No newline at end of file diff --git a/nano_graphrag/storage/asyncpg.py b/nano_graphrag/storage/asyncpg.py new file mode 100644 index 0000000..3fb7816 --- /dev/null +++ b/nano_graphrag/storage/asyncpg.py @@ -0,0 +1,106 @@ +from nano_graphrag._storage import BaseVectorStorage +import asyncpg +import asyncio +from contextlib import asynccontextmanager +from nano_graphrag._utils import logger +from pgvector.asyncpg import register_vector +from nano_graphrag.graphrag import always_get_an_event_loop +import numpy as np +import json +from dataclasses import dataclass + +import nest_asyncio +nest_asyncio.apply() + +@dataclass +class AsyncPGVectorStorage(BaseVectorStorage): + table_name_generator: callable = None + conn_fetcher: callable = None + cosine_better_than_threshold: float = 0.2 + dsn = None + def __post_init__(self): + params = self.global_config.get("vector_db_storage_cls_kwargs", {}) + dsn = params.get("dsn", None) + conn_fetcher = params.get("conn_fetcher", None) + table_name_generator = params.get("table_name_generator", None) + self.dsn = dsn + self.conn_fetcher = conn_fetcher + assert self.dsn != None or self.conn_fetcher != None, "Must provide either dsn or conn_fetcher" + if self.dsn: + self.conn_fetcher = self.__get_conn + if not table_name_generator: + self.table_name_generator = lambda working_dir, namespace: f'{working_dir}_{namespace}_vdb' + self._table_name = self.table_name_generator(self.global_config["working_dir"], self.namespace) + self._max_batch_size = self.global_config["embedding_batch_num"] + + self.cosine_better_than_threshold = self.global_config.get( + "query_better_than_threshold", self.cosine_better_than_threshold + ) + loop = always_get_an_event_loop() + loop.run_until_complete(self._secure_table()) + @asynccontextmanager + async def __get_conn(self, vector_register=True): + try: + conn: asyncpg.Connection = await asyncpg.connect(self.dsn) + if vector_register: + await register_vector(conn) + yield conn + finally: + await conn.close() + async def _secure_table(self): + async 
+    async def _secure_table(self):
+        async with self.conn_fetcher(vector_register=False) as conn:
+            conn: asyncpg.Connection
+            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            result = await conn.fetch(
+                "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = $1)",
+                self._table_name,
+            )
+            table_exists = result[0]["exists"]
+            if not table_exists:
+                await conn.execute(
+                    f"CREATE TABLE {self._table_name} "
+                    f"(id text PRIMARY KEY, embedding vector({self.embedding_func.embedding_dim}), data jsonb)"
+                )
+                await conn.execute(
+                    f"CREATE INDEX ON {self._table_name} USING hnsw (embedding vector_cosine_ops)"
+                )
+
+    async def query(self, query: str, top_k: int) -> list[dict]:
+        embedding = await self.embedding_func([query])
+        embedding = embedding[0]
+        async with self.conn_fetcher() as conn:
+            # <=> is pgvector's cosine *distance* operator, so similarity is
+            # 1 - distance: keep rows whose similarity beats the threshold and
+            # return the nearest (smallest distance) rows first.
+            result = await conn.fetch(
+                f"SELECT embedding <=> $1 AS distance, id, data FROM {self._table_name} "
+                f"WHERE 1 - (embedding <=> $1) > $3 ORDER BY embedding <=> $1 ASC LIMIT $2",
+                embedding,
+                top_k,
+                self.cosine_better_than_threshold,
+            )
+        rows = []
+        for row in result:
+            data = json.loads(row["data"])
+            rows.append(
+                {
+                    **data,
+                    "id": row["id"],
+                    "distance": row["distance"],
+                    "similarity": 1 - row["distance"],
+                }
+            )
+        return rows
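+
+    # upsert embeds "content" in batches of embedding_batch_num concurrently,
+    # then relies on ON CONFLICT (id) to overwrite rows that already exist.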
await conn.execute(f"DROP TABLE {table['table_name']} CASCADE") + loop.run_until_complete(clean_table()) + + +@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192) +async def mock_embedding(texts: list[str]) -> np.ndarray: + return np.random.rand(len(texts), 384) + + +@pytest.fixture +def asyncpg_storage(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) + return AsyncPGVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + ) + + +@pytest.mark.asyncio +async def test_upsert_and_query(asyncpg_storage): + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + "2": {"content": "Test content 2", "entity_name": "Entity 2"}, + } + + await asyncpg_storage.upsert(test_data) + + results = await asyncpg_storage.query("Test query", top_k=2) + + assert len(results) == 2 + assert all(isinstance(result, dict) for result in results) + assert all( + "id" in result and "distance" in result and "similarity" in result + for result in results + ) + + +@pytest.mark.asyncio +async def test_persistence(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) + initial_storage = AsyncPGVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + ) + + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + } + + await initial_storage.upsert(test_data) + await initial_storage.index_done_callback() + + new_storage = AsyncPGVectorStorage( + namespace="test", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + ) + + results = await new_storage.query("Test query", top_k=1) + + assert len(results) == 1 + assert results[0]["id"] == "1" + assert "entity_name" in results[0] + + +@pytest.mark.asyncio +async def test_persistence_large_dataset(setup_teardown): + rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding, vector_db_storage_cls_kwargs={"dsn": dsn}) + initial_storage = AsyncPGVectorStorage( + namespace="test_large", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + ) + + large_data = { + str(i): {"content": f"Test content {i}", "entity_name": f"Entity {i}"} + for i in range(1000) + } + await initial_storage.upsert(large_data) + await initial_storage.index_done_callback() + + new_storage = AsyncPGVectorStorage( + namespace="test_large", + global_config=asdict(rag), + embedding_func=mock_embedding, + meta_fields={"entity_name"}, + ) + + results = await new_storage.query("Test query", top_k=500) + assert len(results) == 500 + assert all(result["id"] in large_data for result in results) + + +@pytest.mark.asyncio +async def test_upsert_with_existing_ids(asyncpg_storage): + test_data = { + "1": {"content": "Test content 1", "entity_name": "Entity 1"}, + "2": {"content": "Test content 2", "entity_name": "Entity 2"}, + } + + await asyncpg_storage.upsert(test_data) + + updated_data = { + "1": {"content": "Updated content 1", "entity_name": "Updated Entity 1"}, + "3": {"content": "Test content 3", "entity_name": "Entity 3"}, + } + + await asyncpg_storage.upsert(updated_data) + + results = await asyncpg_storage.query("Updated", top_k=3) + + assert len(results) == 3 + assert any( + result["id"] == "1" and result["entity_name"] == "Updated Entity 1" + for result 
+@pytest.mark.asyncio
+async def test_upsert_and_query(asyncpg_storage):
+    test_data = {
+        "1": {"content": "Test content 1", "entity_name": "Entity 1"},
+        "2": {"content": "Test content 2", "entity_name": "Entity 2"},
+    }
+
+    await asyncpg_storage.upsert(test_data)
+
+    results = await asyncpg_storage.query("Test query", top_k=2)
+
+    assert len(results) == 2
+    assert all(isinstance(result, dict) for result in results)
+    assert all(
+        "id" in result and "distance" in result and "similarity" in result
+        for result in results
+    )
+
+
+@pytest.mark.asyncio
+async def test_persistence(setup_teardown):
+    rag = GraphRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=mock_embedding,
+        vector_db_storage_cls_kwargs={"dsn": dsn},
+    )
+    initial_storage = AsyncPGVectorStorage(
+        namespace="test",
+        global_config=asdict(rag),
+        embedding_func=mock_embedding,
+        meta_fields={"entity_name"},
+    )
+
+    test_data = {
+        "1": {"content": "Test content 1", "entity_name": "Entity 1"},
+    }
+
+    await initial_storage.upsert(test_data)
+    await initial_storage.index_done_callback()
+
+    new_storage = AsyncPGVectorStorage(
+        namespace="test",
+        global_config=asdict(rag),
+        embedding_func=mock_embedding,
+        meta_fields={"entity_name"},
+    )
+
+    results = await new_storage.query("Test query", top_k=1)
+
+    assert len(results) == 1
+    assert results[0]["id"] == "1"
+    assert "entity_name" in results[0]
+
+
+@pytest.mark.asyncio
+async def test_persistence_large_dataset(setup_teardown):
+    rag = GraphRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=mock_embedding,
+        vector_db_storage_cls_kwargs={"dsn": dsn},
+    )
+    initial_storage = AsyncPGVectorStorage(
+        namespace="test_large",
+        global_config=asdict(rag),
+        embedding_func=mock_embedding,
+        meta_fields={"entity_name"},
+    )
+
+    large_data = {
+        str(i): {"content": f"Test content {i}", "entity_name": f"Entity {i}"}
+        for i in range(1000)
+    }
+    await initial_storage.upsert(large_data)
+    await initial_storage.index_done_callback()
+
+    new_storage = AsyncPGVectorStorage(
+        namespace="test_large",
+        global_config=asdict(rag),
+        embedding_func=mock_embedding,
+        meta_fields={"entity_name"},
+    )
+
+    results = await new_storage.query("Test query", top_k=500)
+    assert len(results) == 500
+    assert all(result["id"] in large_data for result in results)
+
+
+@pytest.mark.asyncio
+async def test_upsert_with_existing_ids(asyncpg_storage):
+    test_data = {
+        "1": {"content": "Test content 1", "entity_name": "Entity 1"},
+        "2": {"content": "Test content 2", "entity_name": "Entity 2"},
+    }
+
+    await asyncpg_storage.upsert(test_data)
+
+    updated_data = {
+        "1": {"content": "Updated content 1", "entity_name": "Updated Entity 1"},
+        "3": {"content": "Test content 3", "entity_name": "Entity 3"},
+    }
+
+    await asyncpg_storage.upsert(updated_data)
+
+    results = await asyncpg_storage.query("Updated", top_k=3)
+
+    assert len(results) == 3
+    assert any(
+        result["id"] == "1" and result["entity_name"] == "Updated Entity 1"
+        for result in results
+    )
+    assert any(
+        result["id"] == "2" and result["entity_name"] == "Entity 2"
+        for result in results
+    )
+    assert any(
+        result["id"] == "3" and result["entity_name"] == "Entity 3"
+        for result in results
+    )
+
+
+@pytest.mark.asyncio
+async def test_large_batch_upsert(asyncpg_storage):
+    batch_size = 30
+    large_data = {
+        str(i): {"content": f"Test content {i}", "entity_name": f"Entity {i}"}
+        for i in range(batch_size)
+    }
+
+    await asyncpg_storage.upsert(large_data)
+
+    results = await asyncpg_storage.query("Test query", top_k=batch_size)
+    assert len(results) == batch_size
+    assert all(isinstance(result, dict) for result in results)
+    assert all(
+        "id" in result and "distance" in result and "similarity" in result
+        for result in results
+    )
+
+
+@pytest.mark.asyncio
+async def test_empty_data_insertion(asyncpg_storage):
+    empty_data = {}
+    await asyncpg_storage.upsert(empty_data)
+
+    results = await asyncpg_storage.query("Test query", top_k=1)
+    assert len(results) == 0
+
+
+@pytest.mark.asyncio
+async def test_query_with_no_results(asyncpg_storage):
+    results = await asyncpg_storage.query("Non-existent query", top_k=5)
+    assert len(results) == 0
+
+    test_data = {
+        "1": {"content": "Test content 1", "entity_name": "Entity 1"},
+    }
+    await asyncpg_storage.upsert(test_data)
+
+    results = await asyncpg_storage.query("Non-existent query", top_k=5)
+    assert len(results) == 1
+    assert all(0 <= result["similarity"] <= 1 for result in results)
+    assert "entity_name" in results[0]
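+
+
+# To run these tests locally (assumed setup, mirroring the CI workflow):
+#   export POSTGRES_CONNECTION_STR=postgresql://ci:sw0rdfish@localhost:12345/test
+#   pytest tests/test_asyncpg_vector_storage.py
+# The target database must permit CREATE EXTENSION vector, i.e. pgvector
+# must be installed on the server.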