Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binary Quantization Support for pgvector HNSW Algorithm #389

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ Options:
--m INTEGER hnsw m
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|halfvec]
--quantization-type [none|bit|halfvec]
quantization type for vectors
--custom-case-name TEXT Custom case name i.e. PerformanceCase1536D50K
--custom-case-description TEXT Custom name description
Expand Down
2 changes: 2 additions & 0 deletions vectordb_bench/backend/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class MetricType(str, Enum):
L2 = "L2"
COSINE = "COSINE"
IP = "IP"
HAMMING = "HAMMING"
JACCARD = "JACCARD"


class IndexType(str, Enum):
Expand Down
48 changes: 47 additions & 1 deletion vectordb_bench/backend/clients/pgvector/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
from pydantic import SecretStr

from vectordb_bench.backend.clients.api import MetricType

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
Expand All @@ -16,6 +18,13 @@
from vectordb_bench.backend.clients import DB



def set_default_quantized_fetch_limit(ctx, param, value):
if ctx.params.get("reranking") and value is None:
# ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
return ctx.params["ef_search"]
return value

class PgVectorTypedDict(CommonTypedDict):
user_name: Annotated[
str, click.option("--user-name", type=str, help="Db username", required=True)
Expand Down Expand Up @@ -61,11 +70,45 @@ class PgVectorTypedDict(CommonTypedDict):
Optional[str],
click.option(
"--quantization-type",
type=click.Choice(["none", "halfvec"]),
type=click.Choice(["none", "bit", "halfvec"]),
help="quantization type for vectors",
required=False,
),
]
reranking: Annotated[
Optional[bool],
click.option(
"--reranking/--skip-reranking",
type=bool,
help="Enable reranking for HNSW search for binary quantization",
default=False,
),
]
reranking_metric: Annotated[
Optional[str],
click.option(
"--reranking-metric",
type=click.Choice(
[metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]]
),
help="Distance metric for reranking",
default="COSINE",
show_default=True,
),
]
quantized_fetch_limit: Annotated[
Optional[int],
click.option(
"--quantized-fetch-limit",
type=int,
help="Limit of fetching quantized vector ranked by distance for reranking \
-- bound by ef_search",
required=False,
callback=set_default_quantized_fetch_limit,
)
]



class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
...
Expand Down Expand Up @@ -126,6 +169,9 @@ def PgVectorHNSW(
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
reranking=parameters["reranking"],
reranking_metric=parameters["reranking_metric"],
quantized_fetch_limit=parameters["quantized_fetch_limit"],
),
**parameters,
)
33 changes: 28 additions & 5 deletions vectordb_bench/backend/clients/pgvector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def parse_metric(self) -> str:
elif self.metric_type == MetricType.IP:
return "halfvec_ip_ops"
return "halfvec_cosine_ops"
elif self.quantization_type == "bit":
if self.metric_type == MetricType.JACCARD:
return "bit_jaccard_ops"
return "bit_hamming_ops"
else:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
Expand All @@ -73,18 +77,31 @@ def parse_metric(self) -> str:
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"
if self.quantization_type == "bit":
if self.metric_type == MetricType.JACCARD:
return "<%>"
return "<~>"
else:
if self.metric_type == MetricType.L2:
return "<->"
elif self.metric_type == MetricType.IP:
return "<#>"
return "<=>"

def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
elif self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"

def parse_reranking_metric_fun_op(self) -> LiteralString:
if self.reranking_metric == MetricType.L2:
return "<->"
elif self.reranking_metric == MetricType.IP:
return "<#>"
return "<=>"


@abstractmethod
def index_param(self) -> PgVectorIndexParam:
Expand Down Expand Up @@ -195,6 +212,9 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None
reranking: Optional[bool] = None
quantized_fetch_limit: Optional[int] = None
reranking_metric: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
Expand All @@ -214,6 +234,9 @@ def index_param(self) -> PgVectorIndexParam:
def search_param(self) -> PgVectorSearchParam:
return {
"metric_fun_op": self.parse_metric_fun_op(),
"reranking": self.reranking,
"reranking_metric_fun_op": self.parse_reranking_metric_fun_op(),
"quantized_fetch_limit": self.quantized_fetch_limit,
}

def session_param(self) -> PgVectorSessionCommands:
Expand Down
177 changes: 113 additions & 64 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from psycopg import Connection, Cursor, sql

from ..api import VectorDB
from .config import PgVectorConfigDict, PgVectorIndexConfig
from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,6 +87,92 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
assert cursor is not None, "Cursor is not initialized"

return conn, cursor

def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
index_param = self.case_config.index_param()
reranking = self.case_config.search_param()["reranking"]
column_name = (
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
else sql.SQL("embedding")
)
search_vector = (
sql.SQL("binary_quantize({0})").format(sql.Placeholder())
if index_param["quantization_type"] == "bit"
else sql.Placeholder()
)

# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
if index_param["quantization_type"] == "bit" and reranking:
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
search_query = sql.Composed(
[
sql.SQL(
"""
SELECT i.id
FROM (
SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
FROM public.{table_name} {where_clause}
ORDER BY {column_name}::{quantization_type}({dim})
"""
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(
"""
{search_vector}
LIMIT {quantized_fetch_limit}
) i
ORDER BY i.distance
LIMIT %s::int
"""
).format(
search_vector=search_vector,
quantized_fetch_limit=sql.Literal(
self.case_config.search_param()["quantized_fetch_limit"]
),
),
]
)
else:
search_query = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
]
)
else:
search_query = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
).format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

return search_query


@contextmanager
def init(self) -> Generator[None, None, None]:
Expand All @@ -112,63 +198,8 @@ def init(self) -> Generator[None, None, None]:
self.cursor.execute(command)
self.conn.commit()

index_param = self.case_config.index_param()
# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

if index_param["quantization_type"] != None:
self._unfiltered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
self._filtered_search = self._generate_search_query(filtered=True)
self._unfiltered_search = self._generate_search_query()

try:
yield
Expand Down Expand Up @@ -306,12 +337,17 @@ def _create_index(self):
if index_param["quantization_type"] != None:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
column_name=(
sql.SQL("binary_quantize({0})").format(sql.Identifier("embedding"))
if index_param["quantization_type"] == "bit"
else sql.Identifier("embedding")
),
index_type=sql.Identifier(index_param["index_type"]),
# This assumes that the quantization_type value matches the quantization function name
quantization_type=sql.SQL(index_param["quantization_type"]),
Expand Down Expand Up @@ -406,15 +442,28 @@ def search_embedding(
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"

index_param = self.case_config.index_param()
search_param = self.case_config.search_param()
q = np.asarray(query)
if filters:
gt = filters.get("id")
result = self.cursor.execute(
if index_param["quantization_type"] == "bit" and search_param["reranking"]:
result = self.cursor.execute(
self._filtered_search, (q, gt, q, k), prepare=True, binary=True
)
else:
result = self.cursor.execute(
self._filtered_search, (gt, q, k), prepare=True, binary=True
)
)

else:
result = self.cursor.execute(
if index_param["quantization_type"] == "bit" and search_param["reranking"]:
result = self.cursor.execute(
self._unfiltered_search, (q, q, k), prepare=True, binary=True
)
else:
result = self.cursor.execute(
self._unfiltered_search, (q, k), prepare=True, binary=True
)
)

return [int(i[0]) for i in result.fetchall()]
6 changes: 6 additions & 0 deletions vectordb_bench/frontend/components/run_test/caseSelector.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def caseConfigSetting(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, active
value=config.inputConfig["value"],
help=config.inputHelp,
)
elif config.inputType == InputType.Bool:
caseConfig[config.label] = column.checkbox(
config.displayLabel if config.displayLabel else config.label.value,
value=config.inputConfig["value"],
help=config.inputHelp,
)
k += 1
if k == 0:
columns[1].write("Auto")
Loading