Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quantization option for pgvector with support for halfvec #366

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ Options:
--m INTEGER hnsw m
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|halfvec]
quantization type for vectors
--help Show this message and exit.
```
#### Using a configuration file.
Expand Down
16 changes: 14 additions & 2 deletions vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,15 @@ class PgVectorTypedDict(CommonTypedDict):
required=False,
),
]

quantization_type: Annotated[
Optional[str],
click.option(
"--quantization-type",
type=click.Choice(["none", "halfvec"]),
help="quantization type for vectors",
required=False,
),
]

class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
...
Expand All @@ -79,7 +87,10 @@ def PgVectorIVFFlat(
db_name=parameters["db_name"],
),
db_case_config=PgVectorIVFFlatConfig(
metric_type=None, lists=parameters["lists"], probes=parameters["probes"]
metric_type=None,
lists=parameters["lists"],
probes=parameters["probes"],
quantization_type=parameters["quantization_type"],
),
**parameters,
)
Expand Down Expand Up @@ -111,6 +122,7 @@ def PgVectorHNSW(
ef_search=parameters["ef_search"],
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
),
**parameters,
)
25 changes: 20 additions & 5 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
create_index_after_load: bool = True

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"
if self.quantization_type == "halfvec":
if self.metric_type == MetricType.L2:
return "halfvec_l2_ops"
elif self.metric_type == MetricType.IP:
return "halfvec_ip_ops"
return "halfvec_cosine_ops"
else:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
Expand Down Expand Up @@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
index: IndexType = IndexType.ES_IVFFlat
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"lists": self.lists}
if self.quantization_type == "none":
self.quantization_type = None
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -154,6 +164,7 @@ def index_param(self) -> PgVectorIndexParam:
),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down Expand Up @@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
index: IndexType = IndexType.ES_HNSW
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
if self.quantization_type == "none":
self.quantization_type = None
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
Expand All @@ -194,6 +208,7 @@ def index_param(self) -> PgVectorIndexParam:
),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
Expand Down
113 changes: 84 additions & 29 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,63 @@ def init(self) -> Generator[None, None, None]:
self.cursor.execute(command)
self.conn.commit()

self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
index_param = self.case_config.index_param()
# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
if index_param["quantization_type"] != None:
self._unfiltered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

try:
yield
Expand Down Expand Up @@ -265,17 +303,34 @@ def _create_index(self):
else:
with_clause = sql.Composed(())

index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
embedding_metric=sql.Identifier(index_param["metric"]),
)
if index_param["quantization_type"] != None:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
# This assumes that the quantization_type value matches the quantization function name
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=self.dim,
embedding_metric=sql.Identifier(index_param["metric"]),
)
else:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
embedding_metric=sql.Identifier(index_param["metric"]),
)

index_create_sql_with_with_clause = (
index_create_sql + with_clause
).join(" ")
Expand Down
15 changes: 15 additions & 0 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,19 @@ class CaseConfigInput(BaseModel):
],
)

CaseConfigParamInput_QuantizationType_PgVector = CaseConfigInput(
label=CaseConfigParamType.quantizationType,
inputType=InputType.Option,
inputConfig={
"options": ["none", "halfvec"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput(
label=CaseConfigParamType.quantizationRatio,
inputType=InputType.Option,
Expand Down Expand Up @@ -831,6 +844,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_m,
CaseConfigParamInput_EFConstruction_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
Expand All @@ -841,6 +855,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_EFSearch_PgVector,
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_Probes_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
Expand Down