Skip to content

Commit

Permalink
Add table quantization type
Browse files Browse the repository at this point in the history
  • Loading branch information
lucagiac81 committed Dec 16, 2024
1 parent 529837a commit 68c0f59
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 24 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ Options:
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|bit|halfvec]
quantization type for vectors
quantization type for vectors (in index)
--table-quantization-type [none|bit|halfvec]
quantization type for vectors (in table). If
equal to bit, the parameter
quantization_type will be set to bit too.
--custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K
--custom-case-description TEXT Custom name description
--custom-case-load-timeout INTEGER
Expand Down
14 changes: 13 additions & 1 deletion vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,17 @@ class PgVectorTypedDict(CommonTypedDict):
click.option(
"--quantization-type",
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors",
help="quantization type for vectors (in index)",
required=False,
),
]
table_quantization_type: Annotated[
Optional[str],
click.option(
"--table-quantization-type",
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors (in table). "
"If equal to bit, the parameter quantization_type will be set to bit too.",
required=False,
),
]
Expand Down Expand Up @@ -137,6 +147,7 @@ def PgVectorIVFFlat(
lists=parameters["lists"],
probes=parameters["probes"],
quantization_type=parameters["quantization_type"],
table_quantization_type=parameters["table_quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
Expand Down Expand Up @@ -173,6 +184,7 @@ def PgVectorHNSW(
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
table_quantization_type=parameters["table_quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
Expand Down
20 changes: 16 additions & 4 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
table_quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"lists": self.lists}
if self.quantization_type == "none":
self.quantization_type = None
if self.quantization_type == "none" or self.quantization_type == None:
self.quantization_type = "vector"
if self.table_quantization_type == "none" or self.table_quantization_type == None:
self.table_quantization_type = "vector"
if self.table_quantization_type == "bit":
self.quantization_type = "bit"
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -185,6 +190,7 @@ def index_param(self) -> PgVectorIndexParam:
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
"table_quantization_type": self.table_quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down Expand Up @@ -218,14 +224,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
table_quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
if self.quantization_type == "none":
self.quantization_type = None
if self.quantization_type == "none" or self.quantization_type == None:
self.quantization_type = "vector"
if self.table_quantization_type == "none" or self.table_quantization_type == None:
self.table_quantization_type = "vector"
if self.table_quantization_type == "bit":
self.quantization_type = "bit"
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -235,6 +246,7 @@ def index_param(self) -> PgVectorIndexParam:
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
"table_quantization_type": self.table_quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down
77 changes: 59 additions & 18 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
reranking = self.case_config.search_param()["reranking"]
column_name = (
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
else sql.SQL("embedding")
)
search_vector = (
Expand All @@ -103,7 +103,8 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
)

# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
if index_param["quantization_type"] != index_param["table_quantization_type"]:
# Reranking makes sense only if table quantization is not "bit"
if index_param["quantization_type"] == "bit" and reranking:
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
search_query = sql.Composed(
Expand All @@ -112,29 +113,33 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
"""
SELECT i.id
FROM (
SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
FROM public.{table_name} {where_clause}
ORDER BY {column_name}::{quantization_type}({dim})
"""
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
search_vector=search_vector,
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(
"""
{search_vector}
{search_vector}::{quantization_type}({dim})
LIMIT {quantized_fetch_limit}
) i
ORDER BY i.distance
LIMIT %s::int
"""
).format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
quantized_fetch_limit=sql.Literal(
self.case_config.search_param()["quantized_fetch_limit"]
),
Expand All @@ -154,7 +159,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
Expand All @@ -167,7 +176,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)

Expand Down Expand Up @@ -334,7 +347,7 @@ def _create_index(self):
else:
with_clause = sql.Composed(())

if index_param["quantization_type"] != None:
if index_param["quantization_type"] != index_param["table_quantization_type"]:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
Expand Down Expand Up @@ -377,15 +390,20 @@ def _create_index(self):
def _create_table(self, dim: int):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()

try:
log.info(f"{self.name} client create table : {self.table_name}")

# create table
self.cursor.execute(
sql.SQL(
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
).format(table_name=sql.Identifier(self.table_name), dim=dim)
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));"
).format(
table_name=sql.Identifier(self.table_name),
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
dim=dim)
)
self.cursor.execute(
sql.SQL(
Expand All @@ -407,19 +425,42 @@ def insert_embeddings(
) -> Tuple[int, Optional[Exception]]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()

try:
metadata_arr = np.array(metadata)
embeddings_arr = np.array(embeddings)

with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i]))

if index_param["table_quantization_type"] == "bit":
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
# Same logic as pgvector binary_quantize
for i, row in enumerate(metadata_arr):
embeddings_bit = ''
for embedding in embeddings_arr[i]:
if embedding > 0:
embeddings_bit += '1'
else:
embeddings_bit += '0'
copy.write_row((str(row), embeddings_bit))
else:
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
if index_param["table_quantization_type"] == "halfvec":
copy.set_types(["bigint", "halfvec"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, np.float16(embeddings_arr[i])))
else:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i]))
self.conn.commit()

if kwargs.get("last_batch"):
Expand Down
15 changes: 15 additions & 0 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,19 @@ class CaseConfigInput(BaseModel):
],
)

CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput(
label=CaseConfigParamType.tableQuantizationType,
inputType=InputType.Option,
inputConfig={
"options": ["none", "bit", "halfvec"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput(
label=CaseConfigParamType.max_parallel_workers,
displayLabel="Max parallel workers",
Expand Down Expand Up @@ -1149,6 +1162,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_m,
CaseConfigParamInput_EFConstruction_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_TableQuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
Expand All @@ -1160,6 +1174,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_Probes_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_TableQuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
CaseConfigParamInput_reranking_PgVector,
Expand Down
1 change: 1 addition & 0 deletions vectordb_bench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class CaseConfigParamType(Enum):
probes = "probes"
quantizationType = "quantization_type"
quantizationRatio = "quantization_ratio"
tableQuantizationType = "table_quantization_type"
reranking = "reranking"
rerankingMetric = "reranking_metric"
quantizedFetchLimit = "quantized_fetch_limit"
Expand Down

0 comments on commit 68c0f59

Please sign in to comment.