From c66dfb52975598e766fcc29ad43c2850433f743b Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 28 Oct 2024 14:24:52 +0800 Subject: [PATCH] fix pinecone client Signed-off-by: min.tian --- .../backend/clients/pinecone/config.py | 2 - .../backend/clients/pinecone/pinecone.py | 70 +++++++++---------- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/vectordb_bench/backend/clients/pinecone/config.py b/vectordb_bench/backend/clients/pinecone/config.py index dc1596379..2bbcbb350 100644 --- a/vectordb_bench/backend/clients/pinecone/config.py +++ b/vectordb_bench/backend/clients/pinecone/config.py @@ -4,12 +4,10 @@ class PineconeConfig(DBConfig): api_key: SecretStr - environment: SecretStr index_name: str def to_dict(self) -> dict: return { "api_key": self.api_key.get_secret_value(), - "environment": self.environment.get_secret_value(), "index_name": self.index_name, } diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index c2653ee27..c1351f7a9 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -3,7 +3,7 @@ import logging from contextlib import contextmanager from typing import Type - +import pinecone from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType from .config import PineconeConfig @@ -11,7 +11,8 @@ log = logging.getLogger(__name__) PINECONE_MAX_NUM_PER_BATCH = 1000 -PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB +PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB + class Pinecone(VectorDB): def __init__( @@ -23,30 +24,25 @@ def __init__( **kwargs, ): """Initialize wrapper around the milvus vector database.""" - self.index_name = db_config["index_name"] - self.api_key = db_config["api_key"] - self.environment = db_config["environment"] - self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)) - # Pincone will make connections with server while import - # so place the import here. - import pinecone - pinecone.init( - api_key=self.api_key, environment=self.environment) + self.index_name = db_config.get("index_name", "") + self.api_key = db_config.get("api_key", "") + self.batch_size = int( + min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH) + ) + + pc = pinecone.Pinecone(api_key=self.api_key) + index = pc.Index(self.index_name) + if drop_old: - list_indexes = pinecone.list_indexes() - if self.index_name in list_indexes: - index = pinecone.Index(self.index_name) - index_dim = index.describe_index_stats()["dimension"] - if (index_dim != dim): - raise ValueError( - f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}") - log.info( - f"Pinecone client delete old index: {self.index_name}") - index.delete(delete_all=True) - index.close() - else: + index_stats = index.describe_index_stats() + index_dim = index_stats["dimension"] + if index_dim != dim: raise ValueError( - f"Pinecone index {self.index_name} does not exist") + f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}" + ) + for namespace in index_stats["namespaces"]: + log.info(f"Pinecone index delete namespace: {namespace}") + index.delete(delete_all=True, namespace=namespace) self._metadata_key = "meta" @@ -59,13 +55,10 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConf return EmptyDBCaseConfig @contextmanager - def init(self) -> None: - import pinecone - pinecone.init( - api_key=self.api_key, environment=self.environment) - self.index = pinecone.Index(self.index_name) + def init(self): + pc = pinecone.Pinecone(api_key=self.api_key) + self.index = pc.Index(self.index_name) yield - self.index.close() def ready_to_load(self): pass @@ -83,11 +76,16 @@ def insert_embeddings( insert_count = 0 try: for batch_start_offset in range(0, len(embeddings), self.batch_size): - batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) + batch_end_offset = min( + batch_start_offset + self.batch_size, len(embeddings) + ) insert_datas = [] for i in range(batch_start_offset, batch_end_offset): - insert_data = (str(metadata[i]), embeddings[i], { - self._metadata_key: metadata[i]}) + insert_data = ( + str(metadata[i]), + embeddings[i], + {self._metadata_key: metadata[i]}, + ) insert_datas.append(insert_data) self.index.upsert(insert_datas) insert_count += batch_end_offset - batch_start_offset @@ -101,7 +99,7 @@ def search_embedding( k: int = 100, filters: dict | None = None, timeout: int | None = None, - ) -> list[tuple[int, float]]: + ) -> list[int]: if filters is None: pinecone_filters = {} else: @@ -111,9 +109,9 @@ def search_embedding( top_k=k, vector=query, filter=pinecone_filters, - )['matches'] + )["matches"] except Exception as e: print(f"Error querying index: {e}") raise e - id_res = [int(one_res['id']) for one_res in res] + id_res = [int(one_res["id"]) for one_res in res] return id_res