diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 05ff9541..885995de 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -1,4 +1,4 @@ -from typing import Annotated, TypedDict, Unpack +from typing import Annotated, TypedDict, Unpack, Optional import click from pydantic import SecretStr @@ -21,6 +21,12 @@ class MilvusTypedDict(TypedDict): uri: Annotated[ str, click.option("--uri", type=str, help="uri connection string", required=True) ] + user_name: Annotated[ + Optional[str], click.option("--user-name", type=str, help="Db username", required=False) + ] + password: Annotated[ + Optional[str], click.option("--password", type=str, help="Db password", required=False) + ] class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): @@ -37,6 +43,8 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=AutoIndexConfig(), **parameters, @@ -53,6 +61,8 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=FLATConfig(), **parameters, @@ -73,6 +83,8 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]) if parameters["password"] else None, ), db_case_config=HNSWConfig( M=parameters["m"], @@ -97,6 +109,8 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=IVFFlatConfig( nlist=parameters["nlist"], @@ -116,6 +130,8 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=IVFSQ8Config( nlist=parameters["nlist"], @@ -143,6 +159,8 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=DISKANNConfig( search_list=parameters["search_list"], @@ -174,6 +192,8 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=GPUIVFFlatConfig( nlist=parameters["nlist"], @@ -208,6 +228,8 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=GPUIVFPQConfig( nlist=parameters["nlist"], @@ -274,6 +296,8 @@ def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]): db_config=MilvusConfig( db_label=parameters["db_label"], uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), ), db_case_config=GPUCAGRAConfig( intermediate_graph_degree=parameters["intermediate_graph_degree"], diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index eea7b11f..059ef046 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -1,12 +1,26 @@ -from pydantic import BaseModel, SecretStr +from pydantic import BaseModel, SecretStr, validator from ..api import DBConfig, DBCaseConfig, MetricType, IndexType class MilvusConfig(DBConfig): uri: SecretStr = "http://localhost:19530" + user: str | None = None + password: SecretStr | None = None def to_dict(self) -> dict: - return {"uri": self.uri.get_secret_value()} + return { + "uri": self.uri.get_secret_value(), + "user": self.user if self.user else None, + "password": self.password.get_secret_value() if self.password else None, + } + + @validator("*") + def not_empty_field(cls, v, field): + if field.name in cls.common_short_configs() or field.name in cls.common_long_configs() or field.name in ["user", "password"]: + return v + if isinstance(v, (str, SecretStr)) and len(v) == 0: + raise ValueError("Empty string!") + return v class MilvusIndexConfig(BaseModel):