Commit

fix pinecone client
Signed-off-by: min.tian <[email protected]>
alwayslove2013 committed Oct 28, 2024
1 parent 51390b7 commit 0ca1b1f
Showing 2 changed files with 34 additions and 38 deletions.
vectordb_bench/backend/clients/pinecone/config.py (0 additions, 2 deletions)
@@ -4,12 +4,10 @@
 
 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str
 
     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
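
Note: the dropped environment field tracks the v3+ pinecone SDK, which authenticates with an API key alone; the old pinecone.init(api_key, environment) flow no longer exists. A minimal sketch of how the trimmed config feeds the new client (key and index name are placeholder values):

    from pinecone import Pinecone

    from vectordb_bench.backend.clients.pinecone.config import PineconeConfig

    cfg = PineconeConfig(api_key="pc-xxx", index_name="vdb-bench-index").to_dict()
    pc = Pinecone(api_key=cfg["api_key"])  # replaces module-level pinecone.init()
    index = pc.Index(cfg["index_name"])    # handle used for upsert/query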
vectordb_bench/backend/clients/pinecone/pinecone.py (34 additions, 36 deletions)
@@ -3,15 +3,16 @@
 import logging
 from contextlib import contextmanager
 from typing import Type
 
+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig
 
 
 log = logging.getLogger(__name__)
 
 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
 
 
 class Pinecone(VectorDB):
@@ -23,30 +24,25 @@ def __init__(
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
-        self.index_name = db_config["index_name"]
-        self.api_key = db_config["api_key"]
-        self.environment = db_config["environment"]
-        self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
-        # Pincone will make connections with server while import
-        # so place the import here.
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)
 
         if drop_old:
-            list_indexes = pinecone.list_indexes()
-            if self.index_name in list_indexes:
-                index = pinecone.Index(self.index_name)
-                index_dim = index.describe_index_stats()["dimension"]
-                if (index_dim != dim):
-                    raise ValueError(
-                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-                log.info(
-                    f"Pinecone client delete old index: {self.index_name}")
-                index.delete(delete_all=True)
-                index.close()
-            else:
-                raise ValueError(
-                    f"Pinecone index {self.index_name} does not exist")
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
+                raise ValueError(
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)
 
         self._metadata_key = "meta"
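
Note: batch_size caps each upsert batch at roughly 2 MB, budgeting about 5 bytes per dimension per vector, and never allows more than 1000 vectors per request. A quick worked example (the dimensions are illustrative):

    PINECONE_MAX_NUM_PER_BATCH = 1000
    PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB

    dim = 768  # 2_097_152 / (768 * 5) ~= 546.1 -> batch_size == 546 (size cap binds)
    print(int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)))

    dim = 128  # 2_097_152 / (128 * 5) ~= 3276.8 -> batch_size == 1000 (count cap binds)
    print(int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)))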

@@ -59,13 +55,10 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         return EmptyDBCaseConfig
 
     @contextmanager
-    def init(self) -> None:
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()
 
     def ready_to_load(self):
         pass
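
Note: init() stays a @contextmanager, so callers hold a live index handle only inside a with block; the explicit self.index.close() is gone because the new SDK's index handle does not expose the old close() call. A usage sketch, assuming the wrapper is driven like the other VectorDB clients (constructor arguments abbreviated and hypothetical):

    client = Pinecone(
        dim=768,
        db_config={"api_key": "pc-xxx", "index_name": "vdb-bench-index"},
        db_case_config=EmptyDBCaseConfig(),
        drop_old=False,
    )
    with client.init():
        ids = client.search_embedding(query=[0.0] * 768, k=10)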
@@ -83,11 +76,16 @@ def insert_embeddings(
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (str(metadata[i]), embeddings[i], {
-                                   self._metadata_key: metadata[i]})
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
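
Note: each batch entry is an (id, values, metadata) tuple, one of the input shapes the pinecone client's upsert() accepts; ids must be strings, so the integer key is also mirrored into the "meta" metadata field to keep filtered search possible. One entry looks like:

    insert_data = ("42", [0.1, 0.2, 0.3], {"meta": 42})  # (id, vector, metadata)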
@@ -101,7 +99,7 @@ def search_embedding(
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[tuple[int, float]]:
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
@@ -111,9 +109,9 @@
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )['matches']
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res['id']) for one_res in res]
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res
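
Note: the return-type fix matches what the method actually produces: match ids only, never (id, score) pairs. A sketch of the "matches" payload being unpacked (field values illustrative):

    res = [
        {"id": "42", "score": 0.97},
        {"id": "7", "score": 0.95},
    ]
    id_res = [int(one_res["id"]) for one_res in res]  # -> [42, 7]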
