From 4b3871237208acb2ffb9db9bef0077b692feb69c Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Fri, 24 May 2024 10:07:23 +0800 Subject: [PATCH] fix bugs: should normalize cosine dataset when test with milvus gpu_index Signed-off-by: min.tian --- vectordb_bench/backend/clients/milvus/config.py | 8 ++++++-- vectordb_bench/backend/clients/milvus/milvus.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 0e93e82b8..eea7b11f5 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -14,13 +14,17 @@ class MilvusIndexConfig(BaseModel): index: IndexType metric_type: MetricType | None = None + + @property + def is_gpu_index(self) -> bool: + return self.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ] def parse_metric(self) -> str: if not self.metric_type: return "" - # if self.metric_type == MetricType.COSINE: - # return MetricType.L2.value + if self.is_gpu_index and self.metric_type == MetricType.COSINE: + return MetricType.L2.value return self.metric_type.value diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 58334efe9..2436e2680 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -8,7 +8,7 @@ from pymilvus import Collection, utility from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException -from ..api import VectorDB, IndexType +from ..api import VectorDB from .config import MilvusIndexConfig @@ -119,7 +119,7 @@ def wait_index(): wait_index() # Skip compaction if use GPU indexType - if self.case_config.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]: + if self.case_config.is_gpu_index: log.debug("skip compaction for gpu index type.") else : self.col.compact() @@ -157,6 +157,10 @@ def optimize(self): def need_normalize_cosine(self) -> bool: """Wheather this database need to normalize dataset to support COSINE""" + if self.case_config.is_gpu_index: + log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.") + return True + return False def insert_embeddings(