Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add table quantization option for pgvector #427

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ Options:
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|bit|halfvec]
quantization type for vectors
quantization type for vectors (in index)
--table-quantization-type [none|bit|halfvec]
quantization type for vectors (in table). If
equal to bit, the parameter
quantization_type will be set to bit too.
--custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K
--custom-case-description TEXT Custom name description
--custom-case-load-timeout INTEGER
Expand Down
14 changes: 13 additions & 1 deletion vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,17 @@ class PgVectorTypedDict(CommonTypedDict):
click.option(
"--quantization-type",
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors",
help="quantization type for vectors (in index)",
required=False,
),
]
table_quantization_type: Annotated[
Optional[str],
click.option(
"--table-quantization-type",
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors (in table). "
"If equal to bit, the parameter quantization_type will be set to bit too.",
required=False,
),
]
Expand Down Expand Up @@ -137,6 +147,7 @@ def PgVectorIVFFlat(
lists=parameters["lists"],
probes=parameters["probes"],
quantization_type=parameters["quantization_type"],
table_quantization_type=parameters["table_quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
Expand Down Expand Up @@ -173,6 +184,7 @@ def PgVectorHNSW(
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
table_quantization_type=parameters["table_quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
Expand Down
20 changes: 16 additions & 4 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,19 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
table_quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"lists": self.lists}
if self.quantization_type == "none":
self.quantization_type = None
if self.quantization_type == "none" or self.quantization_type == None:
self.quantization_type = "vector"
if self.table_quantization_type == "none" or self.table_quantization_type == None:
self.table_quantization_type = "vector"
if self.table_quantization_type == "bit":
self.quantization_type = "bit"
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -185,6 +190,7 @@ def index_param(self) -> PgVectorIndexParam:
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
"table_quantization_type": self.table_quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down Expand Up @@ -218,14 +224,19 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
table_quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
if self.quantization_type == "none":
self.quantization_type = None
if self.quantization_type == "none" or self.quantization_type == None:
self.quantization_type = "vector"
if self.table_quantization_type == "none" or self.table_quantization_type == None:
self.table_quantization_type = "vector"
if self.table_quantization_type == "bit":
self.quantization_type = "bit"
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -235,6 +246,7 @@ def index_param(self) -> PgVectorIndexParam:
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
"table_quantization_type": self.table_quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down
77 changes: 59 additions & 18 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
reranking = self.case_config.search_param()["reranking"]
column_name = (
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
if index_param["quantization_type"] == "bit" and index_param["table_quantization_type"] != "bit"
else sql.SQL("embedding")
)
search_vector = (
Expand All @@ -103,7 +103,8 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
)

# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
if index_param["quantization_type"] != index_param["table_quantization_type"]:
# Reranking makes sense only if table quantization is not "bit"
if index_param["quantization_type"] == "bit" and reranking:
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
search_query = sql.Composed(
Expand All @@ -112,29 +113,33 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
"""
SELECT i.id
FROM (
SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
SELECT id, embedding {reranking_metric_fun_op} %s::{table_quantization_type} AS distance
FROM public.{table_name} {where_clause}
ORDER BY {column_name}::{quantization_type}({dim})
"""
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
search_vector=search_vector,
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(
"""
{search_vector}
{search_vector}::{quantization_type}({dim})
LIMIT {quantized_fetch_limit}
) i
ORDER BY i.distance
LIMIT %s::int
"""
).format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
quantized_fetch_limit=sql.Literal(
self.case_config.search_param()["quantized_fetch_limit"]
),
Expand All @@ -154,7 +159,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
Expand All @@ -167,7 +176,11 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
sql.SQL(" {search_vector}::{quantization_type}({dim}) LIMIT %s::int").format(
search_vector=search_vector,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)

Expand Down Expand Up @@ -334,7 +347,7 @@ def _create_index(self):
else:
with_clause = sql.Composed(())

if index_param["quantization_type"] != None:
if index_param["quantization_type"] != index_param["table_quantization_type"]:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
Expand Down Expand Up @@ -377,15 +390,20 @@ def _create_index(self):
def _create_table(self, dim: int):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()

try:
log.info(f"{self.name} client create table : {self.table_name}")

# create table
self.cursor.execute(
sql.SQL(
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
).format(table_name=sql.Identifier(self.table_name), dim=dim)
"CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding {table_quantization_type}({dim}));"
).format(
table_name=sql.Identifier(self.table_name),
table_quantization_type=sql.SQL(index_param["table_quantization_type"]),
dim=dim)
)
self.cursor.execute(
sql.SQL(
Expand All @@ -407,19 +425,42 @@ def insert_embeddings(
) -> Tuple[int, Optional[Exception]]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()

try:
metadata_arr = np.array(metadata)
embeddings_arr = np.array(embeddings)

with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i]))

if index_param["table_quantization_type"] == "bit":
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT TEXT)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
# Same logic as pgvector binary_quantize
for i, row in enumerate(metadata_arr):
embeddings_bit = ''
for embedding in embeddings_arr[i]:
if embedding > 0:
embeddings_bit += '1'
else:
embeddings_bit += '0'
copy.write_row((str(row), embeddings_bit))
else:
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
table_name=sql.Identifier(self.table_name)
)
) as copy:
if index_param["table_quantization_type"] == "halfvec":
copy.set_types(["bigint", "halfvec"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, np.float16(embeddings_arr[i])))
else:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
copy.write_row((row, embeddings_arr[i]))
self.conn.commit()

if kwargs.get("last_batch"):
Expand Down
15 changes: 15 additions & 0 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,19 @@ class CaseConfigInput(BaseModel):
],
)

CaseConfigParamInput_TableQuantizationType_PgVector = CaseConfigInput(
label=CaseConfigParamType.tableQuantizationType,
inputType=InputType.Option,
inputConfig={
"options": ["none", "bit", "halfvec"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_max_parallel_workers_PgVectorRS = CaseConfigInput(
label=CaseConfigParamType.max_parallel_workers,
displayLabel="Max parallel workers",
Expand Down Expand Up @@ -1149,6 +1162,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_m,
CaseConfigParamInput_EFConstruction_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_TableQuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
Expand All @@ -1160,6 +1174,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_Probes_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_TableQuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
CaseConfigParamInput_reranking_PgVector,
Expand Down
1 change: 1 addition & 0 deletions vectordb_bench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class CaseConfigParamType(Enum):
probes = "probes"
quantizationType = "quantization_type"
quantizationRatio = "quantization_ratio"
tableQuantizationType = "table_quantization_type"
reranking = "reranking"
rerankingMetric = "reranking_metric"
quantizedFetchLimit = "quantized_fetch_limit"
Expand Down
Loading