diff --git a/README.md b/README.md index eeda249ae..e4b9a7c62 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ All the database client supported |pgvector|`pip install vectordb-bench[pgvector]`| |pgvecto.rs|`pip install vectordb-bench[pgvecto_rs]`| |redis|`pip install vectordb-bench[redis]`| +|memorydb| `pip install vectordb-bench[memorydb]`| |chromadb|`pip install vectordb-bench[chromadb]`| ### Run diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py index 29a812e96..c623cd083 100644 --- a/vectordb_bench/backend/clients/memorydb/cli.py +++ b/vectordb_bench/backend/clients/memorydb/cli.py @@ -26,7 +26,7 @@ class MemoryDBTypedDict(TypedDict): is_flag=True, show_default=True, default=True, - help="Enable or disable SSL for Redis", + help="Enable or disable SSL for MemoryDB", ), ] ssl_ca_certs: Annotated[ @@ -44,7 +44,7 @@ class MemoryDBTypedDict(TypedDict): is_flag=True, show_default=True, default=False, - help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn", + help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports CME mode", ), ] diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py index 4b8166f11..20f00ccba 100644 --- a/vectordb_bench/backend/clients/memorydb/memorydb.py +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -29,24 +29,24 @@ def __init__( self.case_config = db_case_config self.collection_name = INDEX_NAME self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None - self.insert_batch_size = db_case_config.insert_batch_size or 1 + self.insert_batch_size = db_case_config.insert_batch_size or 10 self.dbsize = kwargs.get("num_rows") - # Create a redis connection, if db has password configured, add it to the connection here and in init(): - log.info(f"Redis establishing connection to: {self.db_config}") + # Create a MemoryDB connection, if db has password configured, add it to the connection here and in init(): + log.info(f"Establishing connection to: {self.db_config}") conn = self.get_client(primary=True) log.info(f"Connection established: {conn}") log.info(conn.execute_command("INFO server")) if drop_old: try: - log.info(f"Redis client getting info for: {INDEX_NAME}") + log.info(f"MemoryDB client getting info for: {INDEX_NAME}") info = conn.ft(INDEX_NAME).info() log.info(f"Index info: {info}") except redis.exceptions.ResponseError as e: log.error(e) drop_old = False - log.info(f"Redis client drop_old collection: {self.collection_name}") + log.info(f"MemoryDB client drop_old collection: {self.collection_name}") log.info("Executing FLUSHALL") conn.flushall() @@ -73,11 +73,9 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): index_param = self.case_config.index_param() search_param = self.case_config.search_param() vector_parameters = { # Vector Index Type: FLAT or HNSW - "TYPE": "FLOAT32", # FLOAT32 or FLOAT64 + "TYPE": "FLOAT32", "DIM": vector_dimensions, # Number of Vector Dimensions - "DISTANCE_METRIC": index_param[ - "metric" - ], # Vector Search Distance Metric + "DISTANCE_METRIC": index_param["metric"], # Vector Search Distance Metric } if index_param["m"]: vector_parameters["M"] = index_param["m"] @@ -89,7 +87,7 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): schema = ( TagField("id"), NumericField("metadata"), - VectorField("vector", # Vector Field Name + VectorField("vector", # Vector Field Name "HNSW", vector_parameters ), ) @@ -100,8 +98,8 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): def get_client(self, **kwargs): """ - Gets either cluster connection or normal redis connection based on `cmd` flag. - CMD stands for Cluster Mode Disabled and is a "mode" for Redis. + Gets either cluster connection or normal connection based on `cmd` flag. + CMD stands for Cluster Mode Disabled and is a "mode". """ if not self.db_config["cmd"]: # Cluster mode enabled @@ -228,7 +226,7 @@ def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis): def search_embedding( self, query: list[float], - k: int = 100, + k: int = 10, filters: dict | None = None, timeout: int | None = None, **kwargs: Any, @@ -236,7 +234,7 @@ def search_embedding( assert self.conn is not None query_vector = np.array(query).astype(np.float32).tobytes() - query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2) + query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) query_params = {"vec": query_vector} if filters: @@ -246,11 +244,11 @@ def search_embedding( # Removing '>=' from the id_value: '>=10000' metadata_value = filters.get("metadata")[2:] if id_value and metadata_value: - query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2) + query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) elif id_value: #gets exact match for id - query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2) + query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) else: #metadata only case, greater than or equal to metadata value - query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2) + query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) res = self.conn.ft(INDEX_NAME).search(query_obj, query_params) return [int(doc["id"]) for doc in res.docs] \ No newline at end of file