From 753c46d4d16b46a9904cb9dd16862a3f96dec345 Mon Sep 17 00:00:00 2001 From: Sheharyar Ahmad Date: Mon, 2 Sep 2024 07:16:38 +0500 Subject: [PATCH] Add support for filtered search in pgvectorscale (#364) * Add support for filtered search in pgvectorscale * Fixed filtered query result not assigned. --- .../clients/pgvectorscale/pgvectorscale.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py index 7c8c314c2..d8f26394c 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py @@ -22,6 +22,9 @@ class PgVectorScale(VectorDB): conn: psycopg.Connection[Any] | None = None coursor: psycopg.Cursor[Any] | None = None + _unfiltered_search: sql.Composed + _filtered_search: sql.Composed + def __init__( self, dim: int, @@ -99,6 +102,16 @@ def init(self) -> Generator[None, None, None]: self.cursor.execute(command) self.conn.commit() + self._filtered_search = sql.Composed( + [ + sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format( + sql.Identifier(self.table_name), + ), + sql.SQL(self.case_config.search_param()["metric_fun_op"]), + sql.SQL(" %s::vector LIMIT %s::int") + ] + ) + self._unfiltered_search = sql.Composed( [ sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( @@ -264,9 +277,14 @@ def search_embedding( assert self.cursor is not None, "Cursor is not initialized" q = np.asarray(query) - # TODO add filters support - result = self.cursor.execute( - self._unfiltered_search, (q, k), prepare=True, binary=True - ) + if filters: + gt = filters.get("id") + result = self.cursor.execute( + self._filtered_search, (gt, q, k), prepare=True, binary=True + ) + else: + result = self.cursor.execute( + self._unfiltered_search, (q, k), prepare=True, binary=True + ) return [int(i[0]) for i in result.fetchall()]