Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bugs: should normalize cosine dataset when test with milvus gpu_index #326

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vectordb_bench/backend/clients/milvus/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ class MilvusIndexConfig(BaseModel):

index: IndexType
metric_type: MetricType | None = None

@property
def is_gpu_index(self) -> bool:
return self.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]

def parse_metric(self) -> str:
if not self.metric_type:
return ""

# if self.metric_type == MetricType.COSINE:
# return MetricType.L2.value
if self.is_gpu_index and self.metric_type == MetricType.COSINE:
return MetricType.L2.value
return self.metric_type.value


Expand Down
8 changes: 6 additions & 2 deletions vectordb_bench/backend/clients/milvus/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pymilvus import Collection, utility
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException

from ..api import VectorDB, IndexType
from ..api import VectorDB
from .config import MilvusIndexConfig


Expand Down Expand Up @@ -119,7 +119,7 @@ def wait_index():
wait_index()

# Skip compaction if use GPU indexType
if self.case_config.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]:
if self.case_config.is_gpu_index:
log.debug("skip compaction for gpu index type.")
else :
self.col.compact()
Expand Down Expand Up @@ -157,6 +157,10 @@ def optimize(self):

def need_normalize_cosine(self) -> bool:
"""Wheather this database need to normalize dataset to support COSINE"""
if self.case_config.is_gpu_index:
log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.")
return True

return False

def insert_embeddings(
Expand Down