From 6e0ad2eda2fc45094ff0a1bb233fde35be44cb2e Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 22 Apr 2024 06:32:05 +0000 Subject: [PATCH] fix qdrant client: compatibe with open-source qdrant Signed-off-by: min.tian --- .../backend/clients/qdrant_cloud/config.py | 25 ++++++++++++++----- .../clients/qdrant_cloud/qdrant_cloud.py | 18 +++++++------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index a2c8d1a86..5b1dd7f18 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -1,18 +1,31 @@ from pydantic import BaseModel, SecretStr from ..api import DBConfig, DBCaseConfig, MetricType +from pydantic import validator - +# Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant. class QdrantConfig(DBConfig): url: SecretStr api_key: SecretStr def to_dict(self) -> dict: - return { - "url": self.url.get_secret_value(), - "api_key": self.api_key.get_secret_value(), - "prefer_grpc": True, - } + api_key = self.api_key.get_secret_value() + if len(api_key) > 0: + return { + "url": self.url.get_secret_value(), + "api_key": self.api_key.get_secret_value(), + "prefer_grpc": True, + } + else: + return {"url": self.url.get_secret_value(),} + + @validator("*") + def not_empty_field(cls, v, field): + if field.name in ["api_key", "db_label"]: + return v + if isinstance(v, (str, SecretStr)) and len(v) == 0: + raise ValueError("Empty string!") + return v class QdrantIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index c4619e657..a51632bc6 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -43,8 +43,7 @@ def __init__( if drop_old: log.info(f"QdrantCloud client drop_old collection: {self.collection_name}") tmp_client.delete_collection(self.collection_name) - - self._create_collection(dim, tmp_client) + self._create_collection(dim, tmp_client) tmp_client = None @contextmanager @@ -110,13 +109,18 @@ def insert_embeddings( ) -> (int, Exception): """Insert embeddings into Milvus. should call self.init() first""" assert self.qdrant_client is not None + QDRANT_BATCH_SIZE = 500 try: # TODO: counts - _ = self.qdrant_client.upsert( - collection_name=self.collection_name, - wait=True, - points=Batch(ids=metadata, payloads=[{self._primary_field: v} for v in metadata], vectors=embeddings) - ) + for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE): + vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE] + ids = metadata[offset: offset + QDRANT_BATCH_SIZE] + payloads=[{self._primary_field: v} for v in ids] + _ = self.qdrant_client.upsert( + collection_name=self.collection_name, + wait=True, + points=Batch(ids=ids, payloads=payloads, vectors=vectors), + ) except Exception as e: log.info(f"Failed to insert data, {e}") return 0, e