fix pinecone client #387

Merged · 1 commit · Oct 28, 2024
vectordb_bench/backend/clients/pinecone/config.py (0 additions, 2 deletions)

@@ -4,12 +4,10 @@

 class PineconeConfig(DBConfig):
     api_key: SecretStr
-    environment: SecretStr
     index_name: str

     def to_dict(self) -> dict:
         return {
             "api_key": self.api_key.get_secret_value(),
-            "environment": self.environment.get_secret_value(),
             "index_name": self.index_name,
         }
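With `environment` gone, the config carries only the API key and index name, which is all the v3+ `pinecone` client needs. A minimal sketch of the resulting round-trip, assuming `DBConfig` is a pydantic model (the key and index name are placeholders):

```python
from pydantic import SecretStr

# Placeholder values for illustration only.
config = PineconeConfig(api_key=SecretStr("pc-xxxx"), index_name="vdb-bench")
print(config.to_dict())
# {'api_key': 'pc-xxxx', 'index_name': 'vdb-bench'}
```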
vectordb_bench/backend/clients/pinecone/pinecone.py (34 additions, 36 deletions)

@@ -3,15 +3,16 @@
 import logging
 from contextlib import contextmanager
 from typing import Type

+import pinecone
 from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
 from .config import PineconeConfig


 log = logging.getLogger(__name__)

 PINECONE_MAX_NUM_PER_BATCH = 1000
-PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
+PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB


 class Pinecone(VectorDB):
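The module-level `import pinecone` replaces the old in-function import: per the removed comment, the legacy client made server connections at import time, while the v3+ client only connects once a `Pinecone` object is created. The two constants cap each upsert batch at 1000 vectors or roughly 2 MB, estimated in `__init__` at about 5 bytes per dimension. A worked example of that batch-size arithmetic (dim = 768 is just an illustration, not taken from the PR):

```python
PINECONE_MAX_NUM_PER_BATCH = 1000
PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB

dim = 768  # example dimension
batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
print(batch_size)  # 2097152 / 3840 = 546.13..., truncated to 546
```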
@@ -23,30 +24,25 @@ def __init__(
         **kwargs,
     ):
         """Initialize wrapper around the milvus vector database."""
-        self.index_name = db_config["index_name"]
-        self.api_key = db_config["api_key"]
-        self.environment = db_config["environment"]
-        self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
-        # Pincone will make connections with server while import
-        # so place the import here.
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
+        self.index_name = db_config.get("index_name", "")
+        self.api_key = db_config.get("api_key", "")
+        self.batch_size = int(
+            min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+        )
+
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        index = pc.Index(self.index_name)

         if drop_old:
-            list_indexes = pinecone.list_indexes()
-            if self.index_name in list_indexes:
-                index = pinecone.Index(self.index_name)
-                index_dim = index.describe_index_stats()["dimension"]
-                if (index_dim != dim):
-                    raise ValueError(
-                        f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
-                log.info(
-                    f"Pinecone client delete old index: {self.index_name}")
-                index.delete(delete_all=True)
-                index.close()
-            else:
-                raise ValueError(
-                    f"Pinecone index {self.index_name} does not exist")
+            index_stats = index.describe_index_stats()
+            index_dim = index_stats["dimension"]
+            if index_dim != dim:
+                raise ValueError(
+                    f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+                )
+            for namespace in index_stats["namespaces"]:
+                log.info(f"Pinecone index delete namespace: {namespace}")
+                index.delete(delete_all=True, namespace=namespace)

         self._metadata_key = "meta"
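The new flow builds a `pinecone.Pinecone` client object instead of calling the removed `pinecone.init()`, and `drop_old` now wipes every namespace rather than deleting (or requiring) the whole index. A standalone sketch of that flow, assuming pinecone-client 3.x or later and a pre-existing index (key, index name, and dimension are placeholders):

```python
import pinecone

pc = pinecone.Pinecone(api_key="pc-xxxx")  # placeholder API key
index = pc.Index("vdb-bench")              # placeholder index name

stats = index.describe_index_stats()
if stats["dimension"] != 768:              # fail fast on dimension mismatch
    raise ValueError("dimension mismatch")
for namespace in stats["namespaces"]:      # wipe the data, keep the index
    index.delete(delete_all=True, namespace=namespace)
```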

@@ -59,13 +55,10 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
         return EmptyDBCaseConfig

     @contextmanager
-    def init(self) -> None:
-        import pinecone
-        pinecone.init(
-            api_key=self.api_key, environment=self.environment)
-        self.index = pinecone.Index(self.index_name)
+    def init(self):
+        pc = pinecone.Pinecone(api_key=self.api_key)
+        self.index = pc.Index(self.index_name)
         yield
-        self.index.close()

     def ready_to_load(self):
         pass
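Because `init()` is a `@contextmanager` and the v3 client has no `close()` on index handles, the trailing cleanup disappears; callers simply enter the context and use the wrapper. A usage sketch, where `db` stands for an already-constructed `Pinecone` wrapper and the data is toy data:

```python
# `db` is a constructed Pinecone wrapper; embeddings/metadata are toy values.
with db.init():
    db.insert_embeddings(embeddings=[[0.1, 0.2]], metadata=[0])
    ids = db.search_embedding(query=[0.1, 0.2], k=10)
```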
@@ -83,11 +76,16 @@ def insert_embeddings(
         insert_count = 0
         try:
             for batch_start_offset in range(0, len(embeddings), self.batch_size):
-                batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
+                batch_end_offset = min(
+                    batch_start_offset + self.batch_size, len(embeddings)
+                )
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    insert_data = (str(metadata[i]), embeddings[i], {
-                                   self._metadata_key: metadata[i]})
+                    insert_data = (
+                        str(metadata[i]),
+                        embeddings[i],
+                        {self._metadata_key: metadata[i]},
+                    )
                     insert_datas.append(insert_data)
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
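Each record is an `(id, values, metadata)` tuple, one of the shapes `Index.upsert` accepts; the id is the stringified metadata key so query results can later be mapped back to integer ids. A toy batch mirroring the tuples built above:

```python
# Toy batch in the same (id, values, metadata) shape as insert_datas above.
batch = [
    ("0", [0.1, 0.2, 0.3], {"meta": 0}),
    ("1", [0.4, 0.5, 0.6], {"meta": 1}),
]
index.upsert(batch)  # `index` as obtained from pc.Index(...) in __init__
```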
@@ -101,7 +99,7 @@ def search_embedding(
         k: int = 100,
         filters: dict | None = None,
         timeout: int | None = None,
-    ) -> list[tuple[int, float]]:
+    ) -> list[int]:
         if filters is None:
             pinecone_filters = {}
         else:
@@ -111,9 +109,9 @@
                 top_k=k,
                 vector=query,
                 filter=pinecone_filters,
-            )['matches']
+            )["matches"]
         except Exception as e:
             print(f"Error querying index: {e}")
             raise e
-        id_res = [int(one_res['id']) for one_res in res]
+        id_res = [int(one_res["id"]) for one_res in res]
         return id_res
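The return annotation now matches what the method actually produces: a list of integer ids recovered from each match. A sketch of the query-and-extract step under the same v3 client assumptions (the query vector is toy data):

```python
res = index.query(
    top_k=10,
    vector=[0.1, 0.2, 0.3],  # toy query vector
    filter={},               # empty dict when no metadata filter applies
)["matches"]
ids = [int(m["id"]) for m in res]  # ids were upserted as stringified ints
```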