From 805f3a33885d8c66904c79190aeaa443ae02a31c Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 18 Sep 2024 10:33:18 +0200 Subject: [PATCH 01/39] add poc draft --- tests/unit/test_q.py | 484 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 484 insertions(+) create mode 100644 tests/unit/test_q.py diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py new file mode 100644 index 00000000..004cd4af --- /dev/null +++ b/tests/unit/test_q.py @@ -0,0 +1,484 @@ +# test_querybuilder.py + +import unittest +from dataclasses import dataclass +from typing import Any, List, Union, Optional, Dict +import json + + +@dataclass +class Field: + name: str + + def __eq__(self, other: Any) -> "Condition": + return Condition(f"{self.name} = {self._format_value(other)}") + + def __ne__(self, other: Any) -> "Condition": + return Condition(f"{self.name} != {self._format_value(other)}") + + def __lt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} < {self._format_value(other)}") + + def __le__(self, other: Any) -> "Condition": + return Condition(f"{self.name} <= {self._format_value(other)}") + + def __gt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} > {self._format_value(other)}") + + def __ge__(self, other: Any) -> "Condition": + return Condition(f"{self.name} >= {self._format_value(other)}") + + def contains(self, value: Any) -> "Condition": + return Condition(f"{self.name} contains {self._format_value(value)}") + + def matches(self, value: Any) -> "Condition": + return Condition(f"{self.name} matches {self._format_value(value)}") + + def in_range(self, start: Any, end: Any) -> "Condition": + return Condition(f"range({self.name}, {start}, {end})") + + def le(self, value: Any) -> "Condition": + return self.__le__(value) + + def lt(self, value: Any) -> "Condition": + return self.__lt__(value) + + def ge(self, value: Any) -> "Condition": + return self.__ge__(value) + + def gt(self, value: Any) -> "Condition": + return 
self.__gt__(value) + + def eq(self, value: Any) -> "Condition": + return self.__eq__(value) + + def __str__(self) -> str: + return self.name + + @staticmethod + def _format_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + else: + return str(value) + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f'"{k}":{self._format_annotation_value(v)}' for k, v in annotations.items() + ) + return Condition(f"({{{{{annotations_str}}}}})({self.name})") + + @staticmethod + def _format_annotation_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, dict): + return ( + "{" + + ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in value.items() + ) + + "}" + ) + elif isinstance(value, list): + return ( + "[" + + ",".join(f"{Field._format_annotation_value(v)}" for v in value) + + "]" + ) + else: + return str(value) + + +@dataclass +class Condition: + expression: str + + def __and__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + if " and " in left or " or " in left: + left = f"({left})" + if " and " in right or " or " in right: + right = f"({right})" + + return Condition(f"{left} and {right}") + + def __or__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + if " and " in left or " or " in left: + left = f"({left})" + if " and " in right or " or " in right: + right = f"({right})" + + return Condition(f"{left} or {right}") + + def __invert__(self) -> "Condition": + return Condition(f"!({self.expression})") + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' for k, v in annotations.items() + ) + return Condition(f"([{annotations_str}]({self.expression}))") + + def build(self) -> str: + return 
self.expression + + +class Query: + def __init__(self, select_fields: Union[str, List[str], List[Field]]): + self.select_fields = ( + ", ".join(select_fields) + if isinstance(select_fields, List) + and all(isinstance(f, str) for f in select_fields) + else ", ".join(str(f) for f in select_fields) + ) + self.sources = "*" + self.condition = None + self.order_by_clauses = [] + self.limit_value = None + self.offset_value = None + self.timeout_value = None + self.parameters = {} + self.grouping = None + + def from_(self, *sources: str) -> "Query": + self.sources = ", ".join(sources) + return self + + def where(self, condition: Union[Condition, Field]) -> "Query": + if isinstance(condition, Field): + self.condition = condition + else: + self.condition = condition + return self + + def order_by_field( + self, + field: str, + ascending: bool = True, + annotations: Optional[Dict[str, Any]] = None, + ) -> "Query": + direction = "asc" if ascending else "desc" + if annotations: + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in annotations.items() + ) + self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") + else: + self.order_by_clauses.append(f"{field} {direction}") + return self + + def orderByAsc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, True, annotations) + + def orderByDesc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, False, annotations) + + def set_limit(self, limit: int) -> "Query": + self.limit_value = limit + return self + + def set_offset(self, offset: int) -> "Query": + self.offset_value = offset + return self + + def set_timeout(self, timeout: int) -> "Query": + self.timeout_value = timeout + return self + + def add_parameter(self, key: str, value: Any) -> "Query": + self.parameters[key] = value + return self + + def param(self, key: str, value: Any) -> 
"Query": + return self.add_parameter(key, value) + + def group(self, group_expression: str) -> "Query": + self.grouping = group_expression + return self + + def build(self) -> str: + query = f"yql=select {self.select_fields} from {self.sources}" + if self.condition: + query += f" where {self.condition.build()}" + if self.grouping: + query += f" | {self.grouping}" + if self.order_by_clauses: + query += " order by " + ", ".join(self.order_by_clauses) + if self.limit_value is not None: + query += f" limit {self.limit_value}" + if self.offset_value is not None: + query += f" offset {self.offset_value}" + if self.timeout_value is not None: + query += f" timeout {self.timeout_value}" + if self.parameters: + params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) + query += params + return query + + +class Q: + @staticmethod + def select(*fields): + return Query(select_fields=list(fields)) + + @staticmethod + def p(*args): + if not args: + return Condition("") + else: + condition = args[0] + for arg in args[1:]: + condition = condition & arg + return condition + + @staticmethod + def ui(value: str = "", index: Optional[str] = None) -> Condition: + if index is None: + # Only value provided + return ( + Condition(f'userQuery("{value}")') + if value + else Condition("userQuery()") + ) + else: + # Both index and value provided + default_index_json = json.dumps({"defaultIndex": index}) + return Condition(f'({default_index_json})userQuery("{value}")') + + @staticmethod + def dotPdt(field: str, vector: Dict[str, int]) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + return Condition(f"dotProduct({field}, {vector_str})") + + @staticmethod + def wtdSet(field: str, vector: Dict[str, int]) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + return Condition(f"weightedSet({field}, {vector_str})") + + @staticmethod + def nonEmpty(condition: Condition) -> Condition: + return 
Condition(f"nonEmpty({condition.build()})") + + @staticmethod + def wand(field: str, weights, annotations: Dict[str, Any] = None) -> Condition: + if isinstance(weights, list): + weights_str = "[" + ",".join(str(item) for item in weights) + "]" + elif isinstance(weights, dict): + weights_str = "{" + ",".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + else: + raise ValueError("Invalid weights for wand") + expr = f"wand({field}, {weights_str})" + if annotations: + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weakand(*conditions, annotations: Dict[str, Any] = None) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"weakAnd({conditions_str})" + if annotations: + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def geoLocation(field: str, lat: float, lng: float, radius: str) -> Condition: + return Condition(f'geoLocation({field}, {lat}, {lng}, "{radius}")') + + @staticmethod + def nearestNeighbor( + field: str, query_vector: str, annotations: Dict[str, Any] = None + ) -> Condition: + if annotations: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) + else: + raise ValueError("Annotations are required for nearestNeighbor") + + @staticmethod + def rank(*queries) -> Condition: + queries_str = ", ".join(query.build() for query in queries) + return Condition(f"rank({queries_str})") + + +class G: + @staticmethod + def all(*args) -> str: + return "all(" + " ".join(args) + ")" + + 
@staticmethod + def group(field: str) -> str: + return f"group({field})" + + @staticmethod + def maxRtn(value: int) -> str: + return f"max({value})" + + @staticmethod + def each(*args) -> str: + return "each(" + " ".join(args) + ")" + + @staticmethod + def output(output_func: str) -> str: + return f"output({output_func})" + + @staticmethod + def count() -> str: + return "count()" + + @staticmethod + def summary() -> str: + return "summary()" + + +class A: + @staticmethod + def a(*args, **kwargs) -> Dict[str, Any]: + if args and isinstance(args[0], dict): + return args[0] + else: + annotations = {} + for i in range(0, len(args), 2): + annotations[args[i]] = args[i + 1] + annotations.update(kwargs) + return annotations + + @staticmethod + def filter() -> Dict[str, Any]: + return {"filter": True} + + @staticmethod + def defaultIndex(index: str) -> Dict[str, Any]: + return {"defaultIndex": index} + + @staticmethod + def append(annotations: Dict[str, Any], other: Dict[str, Any]) -> Dict[str, Any]: + annotations.update(other) + return annotations + + +class QTest(unittest.TestCase): + def test_select_specific_fields(self): + f1 = Field("f1") + condition = f1.contains("v1") + q = Query(select_fields=["f1", "f2"]).from_("sd1").where(condition).build() + + self.assertEqual(q, 'yql=select f1, f2 from sd1 where f1 contains "v1"') + + def test_select_from_specific_sources(self): + f1 = Field("f1") + condition = f1.contains("v1") + q = Query(select_fields="*").from_("sd1").where(condition).build() + + self.assertEqual(q, 'yql=select * from sd1 where f1 contains "v1"') + + def test_select_from_multiples_sources(self): + f1 = Field("f1") + condition = f1.contains("v1") + q = Query(select_fields="*").from_("sd1", "sd2").where(condition).build() + + self.assertEqual(q, 'yql=select * from sd1, sd2 where f1 contains "v1"') + + def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): + f1 = Field("f1") + f2 = Field("f2") + f3 = Field("f3") + f4 = Field("f4") + 
condition = ((f1.contains("v1") & f2.contains("v2")) | f3.contains("v3")) & ( + ~f4.contains("v4") + ) + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .set_offset(1) + .set_limit(2) + .set_timeout(3) + .orderByDesc("f1") + .orderByAsc("f2") + .param("paramk1", "paramv1") + .build() + ) + + expected = 'yql=select * from sd1 where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by f1 desc, f2 asc limit 2 offset 1 timeout 3¶mk1=paramv1' + self.assertEqual(q, expected) + + def test_matches(self): + condition = ( + (Field("f1").matches("v1") & Field("f2").matches("v2")) + | Field("f3").matches("v3") + ) & ~Field("f4").matches("v4") + q = Query(select_fields="*").from_("sd1").where(condition).build() + expected = 'yql=select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' + self.assertEqual(q, expected) + + def test_nested_queries(self): + nested_query = (Field("f2").contains("2") & Field("f3").contains("3")) | ( + Field("f2").contains("4") & ~Field("f3").contains("5") + ) + condition = Field("f1").contains("1") & ~nested_query + q = Query(select_fields="*").from_("sd1").where(condition).build() + expected = 'yql=select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' + self.assertEqual(q, expected) + + def test_userInput_with_and_without_defaultIndex(self): + condition = Q.ui(value="value1") & Q.ui(index="index", value="value2") + q = Query(select_fields="*").from_("sd1").where(condition).build() + expected = 'yql=select * from sd1 where userQuery("value1") and ({"defaultIndex": "index"})userQuery("value2")' + self.assertEqual(q, expected) + + def test_fields_duration(self): + f1 = Field("subject") + f2 = Field("display_date") + f3 = Field("duration") + condition = ( + Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0).build() + ) + expected = "yql=select subject, 
display_date from calendar where duration > 0" + self.assertEqual(condition, expected) + + def test_nearest_neighbor(self): + condition_uq = Q.ui() + condition_nn = Q.nearestNeighbor( + field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} + ) + q = ( + Query(select_fields=["id, text"]) + .from_("m") + .where(condition_uq | condition_nn) + .build() + ) + expected = "yql=select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" + self.assertEqual(q, expected) + + +if __name__ == "__main__": + unittest.main() From 5bbd8e7ce85ebea825fd506f99ecfa9c37040a31 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Tue, 24 Sep 2024 08:48:15 +0200 Subject: [PATCH 02/39] rename ui to userQuery --- tests/unit/test_q.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 004cd4af..7083e793 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -246,7 +246,7 @@ def p(*args): return condition @staticmethod - def ui(value: str = "", index: Optional[str] = None) -> Condition: + def userQuery(value: str = "", index: Optional[str] = None) -> Condition: if index is None: # Only value provided return ( @@ -450,7 +450,9 @@ def test_nested_queries(self): self.assertEqual(q, expected) def test_userInput_with_and_without_defaultIndex(self): - condition = Q.ui(value="value1") & Q.ui(index="index", value="value2") + condition = Q.userQuery(value="value1") & Q.userQuery( + index="index", value="value2" + ) q = Query(select_fields="*").from_("sd1").where(condition).build() expected = 'yql=select * from sd1 where userQuery("value1") and ({"defaultIndex": "index"})userQuery("value2")' self.assertEqual(q, expected) @@ -466,7 +468,7 @@ def test_fields_duration(self): self.assertEqual(condition, expected) def test_nearest_neighbor(self): - condition_uq = Q.ui() + condition_uq = Q.userQuery() condition_nn = Q.nearestNeighbor( field="dense_rep", query_vector="q_dense", 
annotations={"targetHits": 10} ) From 8cd21f3492dc37fdff709a6d2a3c33f804a90d93 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Tue, 24 Sep 2024 09:51:44 +0200 Subject: [PATCH 03/39] nn operators - not passing --- tests/unit/test_q.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 7083e793..4e1440db 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -481,6 +481,24 @@ def test_nearest_neighbor(self): expected = "yql=select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) + def test_build_many_nn_operators(self): + conditions = [ + Q.nearestNeighbor( + field="colbert", + query_vector=f"binary_vector_{i}", + annotations={"targetHits": 100}, + ) + for i in range(32) + ] + q = ( + Query(select_fields="*") + .from_("doc") + .where(condition=Q.p(*conditions)) + .build() + ) + expected = "yql=select * from doc where ({targetHits:100}nearestNeighbor(colbert, binary_vector_0)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_1)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_2)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_3)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_4)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_5)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_6)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_7)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_8)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_9)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_10)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_11)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_12)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_13)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_14)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_15)) OR 
({targetHits:100}nearestNeighbor(colbert, binary_vector_16)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_17)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_18)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_19)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_20)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_21)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_22)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_23)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_24)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_25)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_26)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_27)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_28)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_29)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_30)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_31))" + self.assertEqual(q, expected) + if __name__ == "__main__": unittest.main() From 8b8297f3ff9a5f747e3a70e64bc9d65d54282736 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 25 Sep 2024 09:00:43 +0200 Subject: [PATCH 04/39] start with annotations --- tests/unit/test_q.py | 300 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 282 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 4e1440db..0ebc8647 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -64,9 +64,9 @@ def _format_value(value: Any) -> str: def annotate(self, annotations: Dict[str, Any]) -> "Condition": annotations_str = ",".join( - f'"{k}":{self._format_annotation_value(v)}' for k, v in annotations.items() + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() ) - return Condition(f"({{{{{annotations_str}}}}})({self.name})") + return Condition(f"({{{annotations_str}}}){self.name}") @staticmethod def 
_format_annotation_value(value: Any) -> str: @@ -101,10 +101,9 @@ def __and__(self, other: "Condition") -> "Condition": left = self.expression right = other.expression - if " and " in left or " or " in left: - left = f"({left})" - if " and " in right or " or " in right: - right = f"({right})" + # Adjust parentheses based on operator precedence + left = f"({left})" if " or " in left else left + right = f"({right})" if " or " in right else right return Condition(f"{left} and {right}") @@ -112,10 +111,9 @@ def __or__(self, other: "Condition") -> "Condition": left = self.expression right = other.expression - if " and " in left or " or " in left: - left = f"({left})" - if " and " in right or " or " in right: - right = f"({right})" + # Always add parentheses if 'and' or 'or' is in the expressions + left = f"({left})" if " and " in left or " or " in left else left + right = f"({right})" if " and " in right or " or " in right else right return Condition(f"{left} or {right}") @@ -124,13 +122,39 @@ def __invert__(self) -> "Condition": def annotate(self, annotations: Dict[str, Any]) -> "Condition": annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' for k, v in annotations.items() + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() ) - return Condition(f"([{annotations_str}]({self.expression}))") + return Condition(f"({{{annotations_str}}}){self.expression}") def build(self) -> str: return self.expression + @classmethod + def all(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical AND.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'or' in parentheses + if " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " and ".join(expressions) + return Condition(combined_expression) + + @classmethod + def any(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical OR.""" + 
expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'and' or 'or' in parentheses + if " and " in expr or " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " or ".join(expressions) + return Condition(combined_expression) + class Query: def __init__(self, select_fields: Union[str, List[str], List[Field]]): @@ -256,7 +280,9 @@ def userQuery(value: str = "", index: Optional[str] = None) -> Condition: ) else: # Both index and value provided - default_index_json = json.dumps({"defaultIndex": index}) + default_index_json = json.dumps( + {"defaultIndex": index}, separators=(",", ":") + ) return Condition(f'({default_index_json})userQuery("{value}")') @staticmethod @@ -270,8 +296,12 @@ def wtdSet(field: str, vector: Dict[str, int]) -> Condition: return Condition(f"weightedSet({field}, {vector_str})") @staticmethod - def nonEmpty(condition: Condition) -> Condition: - return Condition(f"nonEmpty({condition.build()})") + def nonEmpty(condition: Union[Condition, Field]) -> Condition: + if isinstance(condition, Field): + expr = str(condition) + else: + expr = condition.build() + return Condition(f"nonEmpty({expr})") @staticmethod def wand(field: str, weights, annotations: Dict[str, Any] = None) -> Condition: @@ -454,7 +484,7 @@ def test_userInput_with_and_without_defaultIndex(self): index="index", value="value2" ) q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where userQuery("value1") and ({"defaultIndex": "index"})userQuery("value2")' + expected = 'yql=select * from sd1 where userQuery("value1") and ({"defaultIndex":"index"})userQuery("value2")' self.assertEqual(q, expected) def test_fields_duration(self): @@ -482,6 +512,7 @@ def test_nearest_neighbor(self): self.assertEqual(q, expected) def test_build_many_nn_operators(self): + self.maxDiff = None conditions = [ Q.nearestNeighbor( field="colbert", @@ -490,13 +521,246 @@ def 
test_build_many_nn_operators(self): ) for i in range(32) ] + # Use Condition.any to combine conditions with OR q = ( Query(select_fields="*") .from_("doc") - .where(condition=Q.p(*conditions)) + .where(condition=Condition.any(*conditions)) + .build() + ) + expected = "yql=select * from doc where " + " or ".join( + [ + f"({{targetHits:100}}nearestNeighbor(colbert, binary_vector_{i}))" + for i in range(32) + ] + ) + self.assertEqual(q, expected) + + def test_field_comparison_operators(self): + f1 = Field("age") + condition = (f1 > 30) & (f1 <= 50) + q = Query(select_fields="*").from_("people").where(condition).build() + expected = "yql=select * from people where age > 30 and age <= 50" + self.assertEqual(q, expected) + + def test_field_in_range(self): + f1 = Field("age") + condition = f1.in_range(18, 65) + q = Query(select_fields="*").from_("people").where(condition).build() + expected = "yql=select * from people where range(age, 18, 65)" + self.assertEqual(q, expected) + + def test_field_annotation(self): + f1 = Field("title") + annotations = {"highlight": True} + annotated_field = f1.annotate(annotations) + q = Query(select_fields="*").from_("articles").where(annotated_field).build() + expected = "yql=select * from articles where ({highlight:true})title" + self.assertEqual(q, expected) + + def test_condition_annotation(self): + f1 = Field("title") + condition = f1.contains("Python") + annotated_condition = condition.annotate({"filter": True}) + q = ( + Query(select_fields="*") + .from_("articles") + .where(annotated_condition) + .build() + ) + expected = ( + 'yql=select * from articles where ({filter:true})title contains "Python"' + ) + self.assertEqual(q, expected) + + def test_grouping_aggregation(self): + grouping = G.all(G.group("category"), G.output(G.count())) + q = Query(select_fields="*").from_("products").group(grouping).build() + expected = "yql=select * from products | all(group(category) output(count()))" + self.assertEqual(q, expected) + + def 
test_add_parameter(self): + f1 = Field("title") + condition = f1.contains("Python") + q = ( + Query(select_fields="*") + .from_("articles") + .where(condition) + .add_parameter("tracelevel", 1) + .build() + ) + expected = ( + 'yql=select * from articles where title contains "Python"&tracelevel=1' + ) + self.assertEqual(q, expected) + + def test_custom_ranking_expression(self): + condition = Q.rank( + Q.userQuery(), Q.dotPdt("embedding", {"feature1": 1, "feature2": 2}) + ) + q = Query(select_fields="*").from_("documents").where(condition).build() + expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' + self.assertEqual(q, expected) + + def test_wand(self): + condition = Q.wand("keywords", {"apple": 10, "banana": 20}) + q = Query(select_fields="*").from_("fruits").where(condition).build() + expected = ( + 'yql=select * from fruits where wand(keywords, {"apple":10,"banana":20})' + ) + self.assertEqual(q, expected) + + def test_weakand(self): + condition1 = Field("title").contains("Python") + condition2 = Field("description").contains("Programming") + condition = Q.weakand( + condition1, condition2, annotations={"targetNumHits": 100} + ) + q = Query(select_fields="*").from_("articles").where(condition).build() + expected = 'yql=select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' + self.assertEqual(q, expected) + + def test_geoLocation(self): + condition = Q.geoLocation("location_field", 37.7749, -122.4194, "10km") + q = Query(select_fields="*").from_("places").where(condition).build() + expected = 'yql=select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' + self.assertEqual(q, expected) + + def test_condition_all_any(self): + c1 = Field("f1").contains("v1") + c2 = Field("f2").contains("v2") + c3 = Field("f3").contains("v3") + condition = Condition.all(c1, c2, Condition.any(c3, ~c1)) + q = 
Query(select_fields="*").from_("sd1").where(condition).build() + expected = 'yql=select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' + self.assertEqual(q, expected) + + def test_order_by_with_annotations(self): + f1 = "relevance" + f2 = "price" + annotations = A.a("strength", 0.5) + q = ( + Query(select_fields="*") + .from_("products") + .orderByDesc(f1, annotations) + .orderByAsc(f2) + .build() + ) + expected = 'yql=select * from products order by {"strength":0.5}relevance desc, price asc' + self.assertEqual(q, expected) + + def test_field_comparison_methods(self): + f1 = Field("age") + condition = f1.ge(18) & f1.lt(30) + q = Query(select_fields="*").from_("users").where(condition).build() + expected = "yql=select * from users where age >= 18 and age < 30" + self.assertEqual(q, expected) + + def test_filter_annotation(self): + f1 = Field("title") + condition = f1.contains("Python").annotate({"filter": True}) + q = Query(select_fields="*").from_("articles").where(condition).build() + expected = ( + 'yql=select * from articles where ({filter:true})title contains "Python"' + ) + self.assertEqual(q, expected) + + def test_nonEmpty(self): + condition = Q.nonEmpty(Field("comments").eq("any_value")) + q = Query(select_fields="*").from_("posts").where(condition).build() + expected = 'yql=select * from posts where nonEmpty(comments = "any_value")' + self.assertEqual(q, expected) + + def test_dotProduct(self): + condition = Q.dotPdt("vector_field", {"feature1": 1, "feature2": 2}) + q = Query(select_fields="*").from_("vectors").where(condition).build() + expected = 'yql=select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' + self.assertEqual(q, expected) + + def test_in_range_string_values(self): + f1 = Field("date") + condition = f1.in_range("2021-01-01", "2021-12-31") + q = Query(select_fields="*").from_("events").where(condition).build() + expected = "yql=select * from events where 
range(date, 2021-01-01, 2021-12-31)" + self.assertEqual(q, expected) + + def test_condition_inversion(self): + f1 = Field("status") + condition = ~f1.eq("inactive") + q = Query(select_fields="*").from_("users").where(condition).build() + expected = 'yql=select * from users where !(status = "inactive")' + self.assertEqual(q, expected) + + def test_multiple_parameters(self): + f1 = Field("title") + condition = f1.contains("Python") + q = ( + Query(select_fields="*") + .from_("articles") + .where(condition) + .add_parameter("tracelevel", 1) + .add_parameter("language", "en") .build() ) - expected = "yql=select * from doc where ({targetHits:100}nearestNeighbor(colbert, binary_vector_0)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_1)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_2)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_3)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_4)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_5)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_6)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_7)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_8)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_9)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_10)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_11)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_12)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_13)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_14)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_15)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_16)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_17)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_18)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_19)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_20)) OR 
({targetHits:100}nearestNeighbor(colbert, binary_vector_21)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_22)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_23)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_24)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_25)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_26)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_27)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_28)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_29)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_30)) OR ({targetHits:100}nearestNeighbor(colbert, binary_vector_31))" + expected = 'yql=select * from articles where title contains "Python"&tracelevel=1&language=en' + self.assertEqual(q, expected) + + def test_multiple_groupings(self): + grouping = G.all( + G.group("category"), + G.maxRtn(10), + G.output(G.count()), + G.each(G.group("subcategory"), G.output(G.summary())), + ) + q = Query(select_fields="*").from_("products").group(grouping).build() + expected = "yql=select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" + self.assertEqual(q, expected) + + def test_default_index_annotation(self): + condition = Q.userQuery("search terms", index="default_field") + q = Query(select_fields="*").from_("documents").where(condition).build() + expected = 'yql=select * from documents where ({"defaultIndex":"default_field"})userQuery("search terms")' + self.assertEqual(q, expected) + + def test_Q_p_function(self): + condition = Q.p( + Field("f1").contains("v1"), + Field("f2").contains("v2"), + Field("f3").contains("v3"), + ) + q = Query(select_fields="*").from_("sd1").where(condition).build() + expected = 'yql=select * from sd1 where f1 contains "v1" and f2 contains "v2" and f3 contains "v3"' + self.assertEqual(q, expected) + + def test_rank_multiple_conditions(self): + condition = Q.rank( + 
Q.userQuery(), + Q.dotPdt("embedding", {"feature1": 1}), + Q.wtdSet("tags", {"tag1": 2}), + ) + q = Query(select_fields="*").from_("documents").where(condition).build() + expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' + self.assertEqual(q, expected) + + def test_nonEmpty_with_annotations(self): + annotated_field = Field("comments").annotate(A.filter()) + condition = Q.nonEmpty(annotated_field) + q = Query(select_fields="*").from_("posts").where(condition).build() + expected = "yql=select * from posts where nonEmpty(({filter:true})comments)" + self.assertEqual(q, expected) + + def test_weight_annotation(self): + condition = Field("title").contains("heads", annotations={"weight": 200}) + q = Query(select_fields="*").from_("s1").where(condition).build() + expected = "yql=select * from s1 where title contains({weight:200}'heads')" + self.assertEqual(q, expected) + + def test_nearest_neighbor_annotations(self): + condition = Q.nearestNeighbor( + field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} + ) + q = Query(select_fields=["id, text"]).from_("m").where(condition).build() + expected = "yql=select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) From 6fce39db07c0ad1271e5f994ad880957bb13951b Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 25 Sep 2024 09:32:18 +0200 Subject: [PATCH 05/39] more tests --- tests/unit/test_q.py | 190 ++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 67 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 0ebc8647..c1261087 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,5 +1,3 @@ -# test_querybuilder.py - import unittest from dataclasses import dataclass from typing import Any, List, Union, Optional, Dict @@ -28,14 +26,45 @@ def __gt__(self, other: Any) -> "Condition": def __ge__(self, other: Any) -> 
"Condition": return Condition(f"{self.name} >= {self._format_value(other)}") - def contains(self, value: Any) -> "Condition": - return Condition(f"{self.name} contains {self._format_value(value)}") + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} contains {value_str}") - def matches(self, value: Any) -> "Condition": - return Condition(f"{self.name} matches {self._format_value(value)}") + def matches( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} matches {value_str}") - def in_range(self, start: Any, end: Any) -> "Condition": - return Condition(f"range({self.name}, {start}, {end})") + def in_range( + self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}range({self.name}, {start}, {end}))" + ) + else: + return Condition(f"range({self.name}, {start}, {end})") def le(self, value: Any) -> "Condition": return self.__le__(value) @@ -59,15 +88,11 @@ def __str__(self) -> str: def _format_value(value: Any) -> str: if isinstance(value, str): return f'"{value}"' + elif isinstance(value, Condition): + return value.build() else: return str(value) - def annotate(self, annotations: Dict[str, Any]) -> "Condition": 
- annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() - ) - return Condition(f"({{{annotations_str}}}){self.name}") - @staticmethod def _format_annotation_value(value: Any) -> str: if isinstance(value, str): @@ -92,6 +117,12 @@ def _format_annotation_value(value: Any) -> str: else: return str(value) + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.name}") + @dataclass class Condition: @@ -286,14 +317,32 @@ def userQuery(value: str = "", index: Optional[str] = None) -> Condition: return Condition(f'({default_index_json})userQuery("{value}")') @staticmethod - def dotPdt(field: str, vector: Dict[str, int]) -> Condition: + def dotProduct( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" - return Condition(f"dotProduct({field}, {vector_str})") + expr = f"dotProduct({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) @staticmethod - def wtdSet(field: str, vector: Dict[str, int]) -> Condition: + def weightedSet( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" - return Condition(f"weightedSet({field}, {vector_str})") + expr = f"weightedSet({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) @staticmethod def nonEmpty(condition: Union[Condition, Field]) -> 
Condition: @@ -304,7 +353,9 @@ def nonEmpty(condition: Union[Condition, Field]) -> Condition: return Condition(f"nonEmpty({expr})") @staticmethod - def wand(field: str, weights, annotations: Dict[str, Any] = None) -> Condition: + def wand( + field: str, weights, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: if isinstance(weights, list): weights_str = "[" + ",".join(str(item) for item in weights) + "]" elif isinstance(weights, dict): @@ -314,14 +365,14 @@ def wand(field: str, weights, annotations: Dict[str, Any] = None) -> Condition: expr = f"wand({field}, {weights_str})" if annotations: annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" return Condition(expr) @staticmethod - def weakand(*conditions, annotations: Dict[str, Any] = None) -> Condition: + def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: conditions_str = ", ".join(cond.build() for cond in conditions) expr = f"weakAnd({conditions_str})" if annotations: @@ -333,25 +384,34 @@ def weakand(*conditions, annotations: Dict[str, Any] = None) -> Condition: return Condition(expr) @staticmethod - def geoLocation(field: str, lat: float, lng: float, radius: str) -> Condition: - return Condition(f'geoLocation({field}, {lat}, {lng}, "{radius}")') - - @staticmethod - def nearestNeighbor( - field: str, query_vector: str, annotations: Dict[str, Any] = None + def geoLocation( + field: str, + lat: float, + lng: float, + radius: str, + annotations: Optional[Dict[str, Any]] = None, ) -> Condition: + expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' if annotations: - if "targetHits" not in annotations: - raise ValueError("targetHits annotation is required") annotations_str = ",".join( f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() ) - return Condition( - f"({{{annotations_str}}}nearestNeighbor({field}, 
{query_vector}))" - ) - else: - raise ValueError("Annotations are required for nearestNeighbor") + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nearestNeighbor( + field: str, query_vector: str, annotations: Dict[str, Any] + ) -> Condition: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) @staticmethod def rank(*queries) -> Condition: @@ -389,33 +449,29 @@ def summary() -> str: return "summary()" -class A: - @staticmethod - def a(*args, **kwargs) -> Dict[str, Any]: - if args and isinstance(args[0], dict): - return args[0] - else: - annotations = {} - for i in range(0, len(args), 2): - annotations[args[i]] = args[i + 1] - annotations.update(kwargs) - return annotations - - @staticmethod - def filter() -> Dict[str, Any]: - return {"filter": True} - - @staticmethod - def defaultIndex(index: str) -> Dict[str, Any]: - return {"defaultIndex": index} - - @staticmethod - def append(annotations: Dict[str, Any], other: Dict[str, Any]) -> Dict[str, Any]: - annotations.update(other) - return annotations +class QTest(unittest.TestCase): + def test_dotProduct_with_annotations(self): + condition = Q.dotProduct( + "vector_field", + {"feature1": 1, "feature2": 2}, + annotations={"label": "myDotProduct"}, + ) + q = Query(select_fields="*").from_("vectors").where(condition).build() + expected = 'yql=select * from vectors where ({label:"myDotProduct"}dotProduct(vector_field, {"feature1":1,"feature2":2}))' + self.assertEqual(q, expected) + def test_geoLocation_with_annotations(self): + condition = Q.geoLocation( + "location_field", + 37.7749, + -122.4194, + "10km", + annotations={"targetHits": 100}, + ) + q = Query(select_fields="*").from_("places").where(condition).build() + expected = 'yql=select 
* from places where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' + self.assertEqual(q, expected) -class QTest(unittest.TestCase): def test_select_specific_fields(self): f1 = Field("f1") condition = f1.contains("v1") @@ -596,7 +652,7 @@ def test_add_parameter(self): def test_custom_ranking_expression(self): condition = Q.rank( - Q.userQuery(), Q.dotPdt("embedding", {"feature1": 1, "feature2": 2}) + Q.userQuery(), Q.dotProduct("embedding", {"feature1": 1, "feature2": 2}) ) q = Query(select_fields="*").from_("documents").where(condition).build() expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' @@ -613,7 +669,7 @@ def test_wand(self): def test_weakand(self): condition1 = Field("title").contains("Python") condition2 = Field("description").contains("Programming") - condition = Q.weakand( + condition = Q.weakAnd( condition1, condition2, annotations={"targetNumHits": 100} ) q = Query(select_fields="*").from_("articles").where(condition).build() @@ -638,7 +694,7 @@ def test_condition_all_any(self): def test_order_by_with_annotations(self): f1 = "relevance" f2 = "price" - annotations = A.a("strength", 0.5) + annotations = {"strength": 0.5} q = ( Query(select_fields="*") .from_("products") @@ -672,7 +728,7 @@ def test_nonEmpty(self): self.assertEqual(q, expected) def test_dotProduct(self): - condition = Q.dotPdt("vector_field", {"feature1": 1, "feature2": 2}) + condition = Q.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) q = Query(select_fields="*").from_("vectors").where(condition).build() expected = 'yql=select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' self.assertEqual(q, expected) @@ -735,15 +791,15 @@ def test_Q_p_function(self): def test_rank_multiple_conditions(self): condition = Q.rank( Q.userQuery(), - Q.dotPdt("embedding", {"feature1": 1}), - Q.wtdSet("tags", {"tag1": 2}), + Q.dotProduct("embedding", {"feature1": 1}), + 
Q.weightedSet("tags", {"tag1": 2}), ) q = Query(select_fields="*").from_("documents").where(condition).build() expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' self.assertEqual(q, expected) def test_nonEmpty_with_annotations(self): - annotated_field = Field("comments").annotate(A.filter()) + annotated_field = Field("comments").annotate({"filter": True}) condition = Q.nonEmpty(annotated_field) q = Query(select_fields="*").from_("posts").where(condition).build() expected = "yql=select * from posts where nonEmpty(({filter:true})comments)" @@ -752,7 +808,7 @@ def test_nonEmpty_with_annotations(self): def test_weight_annotation(self): condition = Field("title").contains("heads", annotations={"weight": 200}) q = Query(select_fields="*").from_("s1").where(condition).build() - expected = "yql=select * from s1 where title contains({weight:200}'heads')" + expected = 'yql=select * from s1 where title contains({weight:200}"heads")' self.assertEqual(q, expected) def test_nearest_neighbor_annotations(self): From c58a8a492a7593e7e75064df28c416838b7ad50f Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 25 Sep 2024 09:38:39 +0200 Subject: [PATCH 06/39] =?UTF-8?q?rest=20of=20methods=20=F0=9F=9A=80=20pass?= =?UTF-8?q?ing=20tests=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_q.py | 223 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index c1261087..47e191a5 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -52,6 +52,12 @@ def matches( else: return Condition(f"{self.name} matches {value_str}") + def in_(self, *values) -> "Condition": + values_str = ", ".join( + f'"{v}"' if isinstance(v, str) else str(v) for v in values + ) + return Condition(f"{self.name} in ({values_str})") + def in_range( self, start: Any, end: Any, 
annotations: Optional[Dict[str, Any]] = None ) -> "Condition": @@ -418,6 +424,123 @@ def rank(*queries) -> Condition: queries_str = ", ".join(query.build() for query in queries) return Condition(f"rank({queries_str})") + @staticmethod + def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"phrase({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"near({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def onear(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"onear({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def sameElement(*conditions) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"sameElement({conditions_str})" + return Condition(expr) + + @staticmethod + def equiv(*terms) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"equiv({terms_str})" + return Condition(expr) + + @staticmethod + def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'uri("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = 
f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'fuzzy("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def userInput( + value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if value is None: + expr = "userInput()" + elif value.startswith("@"): + expr = f"userInput({value})" + else: + expr = f'userInput("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def predicate( + field: str, + attributes: Optional[Dict[str, Any]] = None, + range_attributes: Optional[Dict[str, Any]] = None, + ) -> Condition: + if attributes is None: + attributes_str = "0" + else: + attributes_str = ( + "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" + ) + if range_attributes is None: + range_attributes_str = "0" + else: + range_attributes_str = ( + "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" + ) + expr = f"predicate({field},{attributes_str},{range_attributes_str})" + return Condition(expr) + + @staticmethod + def true() -> Condition: + return Condition("true") + + @staticmethod + def false() -> Condition: + return Condition("false") + class G: @staticmethod @@ -820,5 +943,105 @@ def test_nearest_neighbor_annotations(self): self.assertEqual(q, expected) +class TestQueryBuilder(unittest.TestCase): + def test_phrase(self): + text = Field("text") + condition = text.contains(Q.phrase("st", "louis", "blues")) + query = Q.select("*").where(condition).build() + expected = ( + 'yql=select * from * where text contains phrase("st", "louis", 
"blues")' + ) + self.assertEqual(query, expected) + + def test_near(self): + title = Field("title") + condition = title.contains(Q.near("madonna", "saint")) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where title contains near("madonna", "saint")' + self.assertEqual(query, expected) + + def test_onear(self): + title = Field("title") + condition = title.contains(Q.onear("madonna", "saint")) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where title contains onear("madonna", "saint")' + self.assertEqual(query, expected) + + def test_sameElement(self): + persons = Field("persons") + first_name = Field("first_name") + last_name = Field("last_name") + year_of_birth = Field("year_of_birth") + condition = persons.contains( + Q.sameElement( + first_name.contains("Joe"), + last_name.contains("Smith"), + year_of_birth < 1940, + ) + ) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + self.assertEqual(query, expected) + + def test_equiv(self): + fieldName = Field("fieldName") + condition = fieldName.contains(Q.equiv("A", "B")) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where fieldName contains equiv("A", "B")' + self.assertEqual(query, expected) + + def test_uri(self): + myUrlField = Field("myUrlField") + condition = myUrlField.contains(Q.uri("vespa.ai/foo")) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where myUrlField contains uri("vespa.ai/foo")' + self.assertEqual(query, expected) + + def test_fuzzy(self): + myStringAttribute = Field("myStringAttribute") + annotations = {"prefixLength": 1, "maxEditDistance": 2} + condition = myStringAttribute.contains( + Q.fuzzy("parantesis", annotations=annotations) + ) + query = Q.select("*").where(condition).build() + expected = 'yql=select 
* from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + self.assertEqual(query, expected) + + def test_userInput(self): + condition = Q.userInput("@animal") + query = Q.select("*").where(condition).param("animal", "panda").build() + expected = "yql=select * from * where userInput(@animal)&animal=panda" + self.assertEqual(query, expected) + + def test_in_operator(self): + integer_field = Field("integer_field") + condition = integer_field.in_(10, 20, 30) + query = Q.select("*").where(condition).build() + expected = "yql=select * from * where integer_field in (10, 20, 30)" + self.assertEqual(query, expected) + + def test_predicate(self): + condition = Q.predicate( + "predicate_field", + attributes={"gender": "Female"}, + range_attributes={"age": "20L"}, + ) + query = Q.select("*").where(condition).build() + expected = 'yql=select * from * where predicate(predicate_field,{"gender":"Female"},{"age":20L})' + self.assertEqual(query, expected) + + def test_true(self): + condition = Q.true() + query = Q.select("*").where(condition).build() + expected = "yql=select * from * where true" + self.assertEqual(query, expected) + + def test_false(self): + condition = Q.false() + query = Q.select("*").where(condition).build() + expected = "yql=select * from * where false" + self.assertEqual(query, expected) + + if __name__ == "__main__": unittest.main() From d749f54e3a7362eac27084b6f3d8ae9c9aadab73 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 25 Sep 2024 09:53:13 +0200 Subject: [PATCH 07/39] unify test classes --- tests/unit/test_q.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 47e191a5..98efd600 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -572,7 +572,7 @@ def summary() -> str: return "summary()" -class QTest(unittest.TestCase): +class TestQueryBuilder(unittest.TestCase): def test_dotProduct_with_annotations(self): condition = 
Q.dotProduct( "vector_field", @@ -942,8 +942,6 @@ def test_nearest_neighbor_annotations(self): expected = "yql=select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) - -class TestQueryBuilder(unittest.TestCase): def test_phrase(self): text = Field("text") condition = text.contains(Q.phrase("st", "louis", "blues")) From b4bf206e02135217f864a902fcb00ef74a40ad10 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Fri, 27 Sep 2024 14:29:39 +0200 Subject: [PATCH 08/39] restructure --- tests/unit/test_q.py | 576 +-------------------------------- vespa/querybuilder/__init__.py | 2 + vespa/querybuilder/main.py | 573 ++++++++++++++++++++++++++++++++ 3 files changed, 578 insertions(+), 573 deletions(-) create mode 100644 vespa/querybuilder/__init__.py create mode 100644 vespa/querybuilder/main.py diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 98efd600..daeb0201 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,586 +1,16 @@ import unittest -from dataclasses import dataclass -from typing import Any, List, Union, Optional, Dict -import json - - -@dataclass -class Field: - name: str - - def __eq__(self, other: Any) -> "Condition": - return Condition(f"{self.name} = {self._format_value(other)}") - - def __ne__(self, other: Any) -> "Condition": - return Condition(f"{self.name} != {self._format_value(other)}") - - def __lt__(self, other: Any) -> "Condition": - return Condition(f"{self.name} < {self._format_value(other)}") - - def __le__(self, other: Any) -> "Condition": - return Condition(f"{self.name} <= {self._format_value(other)}") - - def __gt__(self, other: Any) -> "Condition": - return Condition(f"{self.name} > {self._format_value(other)}") - - def __ge__(self, other: Any) -> "Condition": - return Condition(f"{self.name} >= {self._format_value(other)}") - - def contains( - self, value: Any, annotations: Optional[Dict[str, Any]] = None - ) -> "Condition": - value_str = 
self._format_value(value) - if annotations: - annotations_str = ",".join( - f"{k}:{self._format_annotation_value(v)}" - for k, v in annotations.items() - ) - return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") - else: - return Condition(f"{self.name} contains {value_str}") - - def matches( - self, value: Any, annotations: Optional[Dict[str, Any]] = None - ) -> "Condition": - value_str = self._format_value(value) - if annotations: - annotations_str = ",".join( - f"{k}:{self._format_annotation_value(v)}" - for k, v in annotations.items() - ) - return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") - else: - return Condition(f"{self.name} matches {value_str}") - - def in_(self, *values) -> "Condition": - values_str = ", ".join( - f'"{v}"' if isinstance(v, str) else str(v) for v in values - ) - return Condition(f"{self.name} in ({values_str})") - - def in_range( - self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None - ) -> "Condition": - if annotations: - annotations_str = ",".join( - f"{k}:{self._format_annotation_value(v)}" - for k, v in annotations.items() - ) - return Condition( - f"({{{annotations_str}}}range({self.name}, {start}, {end}))" - ) - else: - return Condition(f"range({self.name}, {start}, {end})") - - def le(self, value: Any) -> "Condition": - return self.__le__(value) - - def lt(self, value: Any) -> "Condition": - return self.__lt__(value) - - def ge(self, value: Any) -> "Condition": - return self.__ge__(value) - - def gt(self, value: Any) -> "Condition": - return self.__gt__(value) - - def eq(self, value: Any) -> "Condition": - return self.__eq__(value) - - def __str__(self) -> str: - return self.name - - @staticmethod - def _format_value(value: Any) -> str: - if isinstance(value, str): - return f'"{value}"' - elif isinstance(value, Condition): - return value.build() - else: - return str(value) - - @staticmethod - def _format_annotation_value(value: Any) -> str: - if isinstance(value, 
str): - return f'"{value}"' - elif isinstance(value, bool): - return str(value).lower() - elif isinstance(value, dict): - return ( - "{" - + ",".join( - f'"{k}":{Field._format_annotation_value(v)}' - for k, v in value.items() - ) - + "}" - ) - elif isinstance(value, list): - return ( - "[" - + ",".join(f"{Field._format_annotation_value(v)}" for v in value) - + "]" - ) - else: - return str(value) - - def annotate(self, annotations: Dict[str, Any]) -> "Condition": - annotations_str = ",".join( - f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() - ) - return Condition(f"({{{annotations_str}}}){self.name}") - - -@dataclass -class Condition: - expression: str - - def __and__(self, other: "Condition") -> "Condition": - left = self.expression - right = other.expression - - # Adjust parentheses based on operator precedence - left = f"({left})" if " or " in left else left - right = f"({right})" if " or " in right else right - - return Condition(f"{left} and {right}") - - def __or__(self, other: "Condition") -> "Condition": - left = self.expression - right = other.expression - - # Always add parentheses if 'and' or 'or' is in the expressions - left = f"({left})" if " and " in left or " or " in left else left - right = f"({right})" if " and " in right or " or " in right else right - - return Condition(f"{left} or {right}") - - def __invert__(self) -> "Condition": - return Condition(f"!({self.expression})") - - def annotate(self, annotations: Dict[str, Any]) -> "Condition": - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() - ) - return Condition(f"({{{annotations_str}}}){self.expression}") - - def build(self) -> str: - return self.expression - - @classmethod - def all(cls, *conditions: "Condition") -> "Condition": - """Combine multiple conditions using logical AND.""" - expressions = [] - for cond in conditions: - expr = cond.expression - # Wrap expressions with 'or' in parentheses - if " or " in 
expr: - expr = f"({expr})" - expressions.append(expr) - combined_expression = " and ".join(expressions) - return Condition(combined_expression) - - @classmethod - def any(cls, *conditions: "Condition") -> "Condition": - """Combine multiple conditions using logical OR.""" - expressions = [] - for cond in conditions: - expr = cond.expression - # Wrap expressions with 'and' or 'or' in parentheses - if " and " in expr or " or " in expr: - expr = f"({expr})" - expressions.append(expr) - combined_expression = " or ".join(expressions) - return Condition(combined_expression) - - -class Query: - def __init__(self, select_fields: Union[str, List[str], List[Field]]): - self.select_fields = ( - ", ".join(select_fields) - if isinstance(select_fields, List) - and all(isinstance(f, str) for f in select_fields) - else ", ".join(str(f) for f in select_fields) - ) - self.sources = "*" - self.condition = None - self.order_by_clauses = [] - self.limit_value = None - self.offset_value = None - self.timeout_value = None - self.parameters = {} - self.grouping = None - - def from_(self, *sources: str) -> "Query": - self.sources = ", ".join(sources) - return self - - def where(self, condition: Union[Condition, Field]) -> "Query": - if isinstance(condition, Field): - self.condition = condition - else: - self.condition = condition - return self - - def order_by_field( - self, - field: str, - ascending: bool = True, - annotations: Optional[Dict[str, Any]] = None, - ) -> "Query": - direction = "asc" if ascending else "desc" - if annotations: - annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' - for k, v in annotations.items() - ) - self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") - else: - self.order_by_clauses.append(f"{field} {direction}") - return self - - def orderByAsc( - self, field: str, annotations: Optional[Dict[str, Any]] = None - ) -> "Query": - return self.order_by_field(field, True, annotations) - - def orderByDesc( - self, field: 
str, annotations: Optional[Dict[str, Any]] = None - ) -> "Query": - return self.order_by_field(field, False, annotations) - - def set_limit(self, limit: int) -> "Query": - self.limit_value = limit - return self - - def set_offset(self, offset: int) -> "Query": - self.offset_value = offset - return self - - def set_timeout(self, timeout: int) -> "Query": - self.timeout_value = timeout - return self - - def add_parameter(self, key: str, value: Any) -> "Query": - self.parameters[key] = value - return self - - def param(self, key: str, value: Any) -> "Query": - return self.add_parameter(key, value) - - def group(self, group_expression: str) -> "Query": - self.grouping = group_expression - return self - - def build(self) -> str: - query = f"yql=select {self.select_fields} from {self.sources}" - if self.condition: - query += f" where {self.condition.build()}" - if self.grouping: - query += f" | {self.grouping}" - if self.order_by_clauses: - query += " order by " + ", ".join(self.order_by_clauses) - if self.limit_value is not None: - query += f" limit {self.limit_value}" - if self.offset_value is not None: - query += f" offset {self.offset_value}" - if self.timeout_value is not None: - query += f" timeout {self.timeout_value}" - if self.parameters: - params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) - query += params - return query - - -class Q: - @staticmethod - def select(*fields): - return Query(select_fields=list(fields)) - - @staticmethod - def p(*args): - if not args: - return Condition("") - else: - condition = args[0] - for arg in args[1:]: - condition = condition & arg - return condition - - @staticmethod - def userQuery(value: str = "", index: Optional[str] = None) -> Condition: - if index is None: - # Only value provided - return ( - Condition(f'userQuery("{value}")') - if value - else Condition("userQuery()") - ) - else: - # Both index and value provided - default_index_json = json.dumps( - {"defaultIndex": index}, separators=(",", ":") - 
) - return Condition(f'({default_index_json})userQuery("{value}")') - - @staticmethod - def dotProduct( - field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None - ) -> Condition: - vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" - expr = f"dotProduct({field}, {vector_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def weightedSet( - field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None - ) -> Condition: - vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" - expr = f"weightedSet({field}, {vector_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def nonEmpty(condition: Union[Condition, Field]) -> Condition: - if isinstance(condition, Field): - expr = str(condition) - else: - expr = condition.build() - return Condition(f"nonEmpty({expr})") - - @staticmethod - def wand( - field: str, weights, annotations: Optional[Dict[str, Any]] = None - ) -> Condition: - if isinstance(weights, list): - weights_str = "[" + ",".join(str(item) for item in weights) + "]" - elif isinstance(weights, dict): - weights_str = "{" + ",".join(f'"{k}":{v}' for k, v in weights.items()) + "}" - else: - raise ValueError("Invalid weights for wand") - expr = f"wand({field}, {weights_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: - conditions_str = ", ".join(cond.build() for cond in conditions) - expr = 
f"weakAnd({conditions_str})" - if annotations: - annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def geoLocation( - field: str, - lat: float, - lng: float, - radius: str, - annotations: Optional[Dict[str, Any]] = None, - ) -> Condition: - expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def nearestNeighbor( - field: str, query_vector: str, annotations: Dict[str, Any] - ) -> Condition: - if "targetHits" not in annotations: - raise ValueError("targetHits annotation is required") - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() - ) - return Condition( - f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" - ) - - @staticmethod - def rank(*queries) -> Condition: - queries_str = ", ".join(query.build() for query in queries) - return Condition(f"rank({queries_str})") - - @staticmethod - def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: - terms_str = ", ".join(f'"{term}"' for term in terms) - expr = f"phrase({terms_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: - terms_str = ", ".join(f'"{term}"' for term in terms) - expr = f"near({terms_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - 
@staticmethod - def onear(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: - terms_str = ", ".join(f'"{term}"' for term in terms) - expr = f"onear({terms_str})" - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def sameElement(*conditions) -> Condition: - conditions_str = ", ".join(cond.build() for cond in conditions) - expr = f"sameElement({conditions_str})" - return Condition(expr) - - @staticmethod - def equiv(*terms) -> Condition: - terms_str = ", ".join(f'"{term}"' for term in terms) - expr = f"equiv({terms_str})" - return Condition(expr) - - @staticmethod - def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: - expr = f'uri("{value}")' - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: - expr = f'fuzzy("{value}")' - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def userInput( - value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None - ) -> Condition: - if value is None: - expr = "userInput()" - elif value.startswith("@"): - expr = f"userInput({value})" - else: - expr = f'userInput("{value}")' - if annotations: - annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" - for k, v in annotations.items() - ) - expr = f"({{{annotations_str}}}{expr})" - return Condition(expr) - - @staticmethod - def predicate( - field: str, - attributes: Optional[Dict[str, Any]] = None, - range_attributes: Optional[Dict[str, 
Any]] = None, - ) -> Condition: - if attributes is None: - attributes_str = "0" - else: - attributes_str = ( - "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" - ) - if range_attributes is None: - range_attributes_str = "0" - else: - range_attributes_str = ( - "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" - ) - expr = f"predicate({field},{attributes_str},{range_attributes_str})" - return Condition(expr) - - @staticmethod - def true() -> Condition: - return Condition("true") - - @staticmethod - def false() -> Condition: - return Condition("false") - - -class G: - @staticmethod - def all(*args) -> str: - return "all(" + " ".join(args) + ")" - - @staticmethod - def group(field: str) -> str: - return f"group({field})" - - @staticmethod - def maxRtn(value: int) -> str: - return f"max({value})" - - @staticmethod - def each(*args) -> str: - return "each(" + " ".join(args) + ")" - - @staticmethod - def output(output_func: str) -> str: - return f"output({output_func})" - - @staticmethod - def count() -> str: - return "count()" - - @staticmethod - def summary() -> str: - return "summary()" +from vespa.querybuilder import Query, Q, Field, G, Condition class TestQueryBuilder(unittest.TestCase): def test_dotProduct_with_annotations(self): condition = Q.dotProduct( - "vector_field", + "weightedset_field", {"feature1": 1, "feature2": 2}, annotations={"label": "myDotProduct"}, ) q = Query(select_fields="*").from_("vectors").where(condition).build() - expected = 'yql=select * from vectors where ({label:"myDotProduct"}dotProduct(vector_field, {"feature1":1,"feature2":2}))' + expected = 'yql=select * from vectors where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) def test_geoLocation_with_annotations(self): diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py new file mode 100644 index 00000000..6ccdd50e --- /dev/null +++ 
b/vespa/querybuilder/__init__.py @@ -0,0 +1,2 @@ +# Export all from main +from .main import * diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py new file mode 100644 index 00000000..b7606086 --- /dev/null +++ b/vespa/querybuilder/main.py @@ -0,0 +1,573 @@ +from dataclasses import dataclass +from typing import Any, List, Union, Optional, Dict +import json + + +@dataclass +class Field: + name: str + + def __eq__(self, other: Any) -> "Condition": + return Condition(f"{self.name} = {self._format_value(other)}") + + def __ne__(self, other: Any) -> "Condition": + return Condition(f"{self.name} != {self._format_value(other)}") + + def __lt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} < {self._format_value(other)}") + + def __le__(self, other: Any) -> "Condition": + return Condition(f"{self.name} <= {self._format_value(other)}") + + def __gt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} > {self._format_value(other)}") + + def __ge__(self, other: Any) -> "Condition": + return Condition(f"{self.name} >= {self._format_value(other)}") + + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} contains {value_str}") + + def matches( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} matches {value_str}") + + def in_(self, *values) -> "Condition": + values_str = ", ".join( 
+ f'"{v}"' if isinstance(v, str) else str(v) for v in values + ) + return Condition(f"{self.name} in ({values_str})") + + def in_range( + self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}range({self.name}, {start}, {end}))" + ) + else: + return Condition(f"range({self.name}, {start}, {end})") + + def le(self, value: Any) -> "Condition": + return self.__le__(value) + + def lt(self, value: Any) -> "Condition": + return self.__lt__(value) + + def ge(self, value: Any) -> "Condition": + return self.__ge__(value) + + def gt(self, value: Any) -> "Condition": + return self.__gt__(value) + + def eq(self, value: Any) -> "Condition": + return self.__eq__(value) + + def __str__(self) -> str: + return self.name + + @staticmethod + def _format_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, Condition): + return value.build() + else: + return str(value) + + @staticmethod + def _format_annotation_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, dict): + return ( + "{" + + ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in value.items() + ) + + "}" + ) + elif isinstance(value, list): + return ( + "[" + + ",".join(f"{Field._format_annotation_value(v)}" for v in value) + + "]" + ) + else: + return str(value) + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.name}") + + +@dataclass +class Condition: + expression: str + + def __and__(self, other: "Condition") -> "Condition": + left = self.expression + right = 
other.expression + + # Adjust parentheses based on operator precedence + left = f"({left})" if " or " in left else left + right = f"({right})" if " or " in right else right + + return Condition(f"{left} and {right}") + + def __or__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Always add parentheses if 'and' or 'or' is in the expressions + left = f"({left})" if " and " in left or " or " in left else left + right = f"({right})" if " and " in right or " or " in right else right + + return Condition(f"{left} or {right}") + + def __invert__(self) -> "Condition": + return Condition(f"!({self.expression})") + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.expression}") + + def build(self) -> str: + return self.expression + + @classmethod + def all(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical AND.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'or' in parentheses + if " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " and ".join(expressions) + return Condition(combined_expression) + + @classmethod + def any(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical OR.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'and' or 'or' in parentheses + if " and " in expr or " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " or ".join(expressions) + return Condition(combined_expression) + + +class Query: + def __init__(self, select_fields: Union[str, List[str], List[Field]]): + self.select_fields = ( + ", ".join(select_fields) + if isinstance(select_fields, List) + and all(isinstance(f, 
str) for f in select_fields) + else ", ".join(str(f) for f in select_fields) + ) + self.sources = "*" + self.condition = None + self.order_by_clauses = [] + self.limit_value = None + self.offset_value = None + self.timeout_value = None + self.parameters = {} + self.grouping = None + + def from_(self, *sources: str) -> "Query": + self.sources = ", ".join(sources) + return self + + def where(self, condition: Union[Condition, Field]) -> "Query": + if isinstance(condition, Field): + self.condition = condition + else: + self.condition = condition + return self + + def order_by_field( + self, + field: str, + ascending: bool = True, + annotations: Optional[Dict[str, Any]] = None, + ) -> "Query": + direction = "asc" if ascending else "desc" + if annotations: + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in annotations.items() + ) + self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") + else: + self.order_by_clauses.append(f"{field} {direction}") + return self + + def orderByAsc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, True, annotations) + + def orderByDesc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, False, annotations) + + def set_limit(self, limit: int) -> "Query": + self.limit_value = limit + return self + + def set_offset(self, offset: int) -> "Query": + self.offset_value = offset + return self + + def set_timeout(self, timeout: int) -> "Query": + self.timeout_value = timeout + return self + + def add_parameter(self, key: str, value: Any) -> "Query": + self.parameters[key] = value + return self + + def param(self, key: str, value: Any) -> "Query": + return self.add_parameter(key, value) + + def group(self, group_expression: str) -> "Query": + self.grouping = group_expression + return self + + def build(self, prepend_yql=True) -> str: + query = f"select 
{self.select_fields} from {self.sources}" + if prepend_yql: + query = f"yql={query}" + if self.condition: + query += f" where {self.condition.build()}" + if self.grouping: + query += f" | {self.grouping}" + if self.order_by_clauses: + query += " order by " + ", ".join(self.order_by_clauses) + if self.limit_value is not None: + query += f" limit {self.limit_value}" + if self.offset_value is not None: + query += f" offset {self.offset_value}" + if self.timeout_value is not None: + query += f" timeout {self.timeout_value}" + if self.parameters: + params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) + query += params + return query + + +class Q: + @staticmethod + def select(*fields): + return Query(select_fields=list(fields)) + + @staticmethod + def p(*args): + if not args: + return Condition("") + else: + condition = args[0] + for arg in args[1:]: + condition = condition & arg + return condition + + @staticmethod + def userQuery(value: str = "", index: Optional[str] = None) -> Condition: + if index is None: + # Only value provided + return ( + Condition(f'userQuery("{value}")') + if value + else Condition("userQuery()") + ) + else: + # Both index and value provided + default_index_json = json.dumps( + {"defaultIndex": index}, separators=(",", ":") + ) + return Condition(f'({default_index_json})userQuery("{value}")') + + @staticmethod + def dotProduct( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"dotProduct({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weightedSet( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, 
v in vector.items()) + "}" + expr = f"weightedSet({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nonEmpty(condition: Union[Condition, Field]) -> Condition: + if isinstance(condition, Field): + expr = str(condition) + else: + expr = condition.build() + return Condition(f"nonEmpty({expr})") + + @staticmethod + def wand( + field: str, weights, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if isinstance(weights, list): + weights_str = "[" + ",".join(str(item) for item in weights) + "]" + elif isinstance(weights, dict): + weights_str = "{" + ",".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + else: + raise ValueError("Invalid weights for wand") + expr = f"wand({field}, {weights_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"weakAnd({conditions_str})" + if annotations: + annotations_str = ",".join( + f'"{k}":{Field._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def geoLocation( + field: str, + lat: float, + lng: float, + radius: str, + annotations: Optional[Dict[str, Any]] = None, + ) -> Condition: + expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nearestNeighbor( + field: str, query_vector: 
str, annotations: Dict[str, Any] + ) -> Condition: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) + + @staticmethod + def rank(*queries) -> Condition: + queries_str = ", ".join(query.build() for query in queries) + return Condition(f"rank({queries_str})") + + @staticmethod + def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"phrase({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"near({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def onear(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"onear({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def sameElement(*conditions) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"sameElement({conditions_str})" + return Condition(expr) + + @staticmethod + def equiv(*terms) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"equiv({terms_str})" + return 
Condition(expr) + + @staticmethod + def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'uri("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'fuzzy("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def userInput( + value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if value is None: + expr = "userInput()" + elif value.startswith("@"): + expr = f"userInput({value})" + else: + expr = f'userInput("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Field._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def predicate( + field: str, + attributes: Optional[Dict[str, Any]] = None, + range_attributes: Optional[Dict[str, Any]] = None, + ) -> Condition: + if attributes is None: + attributes_str = "0" + else: + attributes_str = ( + "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" + ) + if range_attributes is None: + range_attributes_str = "0" + else: + range_attributes_str = ( + "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" + ) + expr = f"predicate({field},{attributes_str},{range_attributes_str})" + return Condition(expr) + + @staticmethod + def true() -> Condition: + return Condition("true") + + @staticmethod + def false() -> Condition: + return Condition("false") + + +class G: + @staticmethod + def all(*args) -> str: + return "all(" + " ".join(args) + ")" + + @staticmethod + def 
group(field: str) -> str: + return f"group({field})" + + @staticmethod + def maxRtn(value: int) -> str: + return f"max({value})" + + @staticmethod + def each(*args) -> str: + return "each(" + " ".join(args) + ")" + + @staticmethod + def output(output_func: str) -> str: + return f"output({output_func})" + + @staticmethod + def count() -> str: + return "count()" + + @staticmethod + def summary() -> str: + return "summary()" From 511bb71b9adf45fa9ce130bf862457d364148fed Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Fri, 27 Sep 2024 14:29:47 +0200 Subject: [PATCH 09/39] add integration test --- tests/integration/test_integration_queries.py | 167 ++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/integration/test_integration_queries.py diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py new file mode 100644 index 00000000..7149e844 --- /dev/null +++ b/tests/integration/test_integration_queries.py @@ -0,0 +1,167 @@ +import unittest +from vespa.deployment import VespaDocker +from vespa.package import ( + ApplicationPackage, + Schema, + Document, + Field, + StructField, + Struct, +) +from vespa.querybuilder import Query, Q + + +class TestQueriesIntegration(unittest.TestCase): + @classmethod + def setUpClass(cls): + application_name = "querybuilder" + cls.application_name = application_name + + # Define all fields used in the unit tests + fields = [ + Field( + name="weightedset_field", + type="weightedset", + indexing=["attribute"], + ), + Field(name="location_field", type="position", indexing=["attribute"]), + Field(name="f1", type="string", indexing=["index", "summary"]), + Field(name="f2", type="string", indexing=["index", "summary"]), + Field(name="f3", type="string", indexing=["index", "summary"]), + Field(name="f4", type="string", indexing=["index", "summary"]), + Field(name="age", type="int", indexing=["attribute", "summary"]), + Field(name="duration", type="int", 
indexing=["attribute", "summary"]), + Field(name="id", type="string", indexing=["attribute", "summary"]), + Field(name="text", type="string", indexing=["index", "summary"]), + Field(name="title", type="string", indexing=["index", "summary"]), + Field(name="description", type="string", indexing=["index", "summary"]), + Field(name="date", type="string", indexing=["attribute", "summary"]), + Field(name="status", type="string", indexing=["attribute", "summary"]), + Field(name="comments", type="string", indexing=["attribute", "summary"]), + Field( + name="embedding", + type="tensor(x[128])", + indexing=["attribute"], + ), + Field(name="tags", type="array", indexing=["attribute", "summary"]), + Field( + name="timestamp", + type="long", + indexing=["attribute", "summary"], + ), + Field(name="integer_field", type="int", indexing=["attribute", "summary"]), + Field( + name="predicate_field", + type="predicate", + indexing=["attribute", "summary"], + ), + Field( + name="myStringAttribute", type="string", indexing=["index", "summary"] + ), + Field(name="myUrlField", type="string", indexing=["index", "summary"]), + Field(name="fieldName", type="string", indexing=["index", "summary"]), + Field( + name="dense_rep", + type="tensor(x[128])", + indexing=["attribute"], + ), + Field(name="artist", type="string", indexing=["attribute", "summary"]), + Field(name="subject", type="string", indexing=["attribute", "summary"]), + Field( + name="display_date", type="string", indexing=["attribute", "summary"] + ), + Field(name="price", type="double", indexing=["attribute", "summary"]), + Field(name="keywords", type="string", indexing=["index", "summary"]), + ] + email_struct = Struct( + name="email", + fields=[ + Field(name="sender", type="string"), + Field(name="recipient", type="string"), + Field(name="subject", type="string"), + Field(name="content", type="string"), + ], + ) + emails_field = Field( + name="emails", + type="array", + indexing=["summary"], + struct_fields=[ + StructField( + 
name="content", indexing=["attribute"], attribute=["fast-search"] + ) + ], + ) + document = Document(fields=fields, structs=[email_struct]) + schema = Schema(name=application_name, document=document) + schema.add_fields(emails_field) + application_package = ApplicationPackage(name=application_name, schema=[schema]) + print(application_package.schema.schema_to_text) + # Deploy the application + cls.vespa_docker = VespaDocker(port=8089) + cls.app = cls.vespa_docker.deploy(application_package=application_package) + + @classmethod + def tearDown(cls): + cls.vespa_docker.container.stop(timeout=5) + cls.vespa_docker.container.remove() + + def test_dotProduct_with_annotations(self): + # Feed a document with 'weightedset_field' + field = "weightedset_field" + doc = { + "id": f"id:{self.application_name}:{self.application_name}::1", + "fields": {field: {"feature1": 0.5, "feature2": 1.0}}, + } + self.app.feed_data_point( + schema=self.application_name, data_id=doc["id"], fields=doc["fields"] + ) + + # Build and send the query + condition = Q.dotProduct( + field, + {"feature1": 1, "feature2": 2}, + annotations={"label": "myDotProduct"}, + ) + q = ( + Query(select_fields=[field]) + .from_(self.application_name) + .where(condition) + .build(prepend_yql=False) + ) + print(q) + result = self.app.query(yql=q) + + # Check the result + self.assertEqual(len(result.hits), 1) + self.assertEqual(result.hits[0]["id"], doc["id"]) + + def test_geoLocation_with_annotations(self): + # Feed a document with 'location_field' + doc = { + "id": f"id:{self.application_name}:{self.application_name}::2", + "fields": {"location_field": "37.7749, -122.4194"}, + } + self.app.feed_data_point( + schema=self.application_name, data_id=doc["id"], fields=doc["fields"] + ) + + # Build and send the query + condition = Q.geoLocation( + "location_field", + 37.7749, + -122.4194, + "10km", + annotations={"targetHits": 100}, + ) + q = ( + Query(select_fields="") + .from_(self.application_name) + .where(condition) + 
.build(prepend_yql=False) + ) + result = self.app.query(yql=q) + + # Check the result + self.assertEqual(len(result.hits), 1) + self.assertEqual(result.hits[0]["id"], doc["id"]) From bc49b31cc553da1b4c56311a2849d4998f58529a Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 30 Sep 2024 13:07:48 +0200 Subject: [PATCH 10/39] rename Field to Queryfield --- vespa/querybuilder/main.py | 46 ++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index b7606086..346c27ba 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -4,7 +4,7 @@ @dataclass -class Field: +class Queryfield: name: str def __eq__(self, other: Any) -> "Condition": @@ -108,7 +108,7 @@ def _format_annotation_value(value: Any) -> str: return ( "{" + ",".join( - f'"{k}":{Field._format_annotation_value(v)}' + f'"{k}":{Queryfield._format_annotation_value(v)}' for k, v in value.items() ) + "}" @@ -116,7 +116,7 @@ def _format_annotation_value(value: Any) -> str: elif isinstance(value, list): return ( "[" - + ",".join(f"{Field._format_annotation_value(v)}" for v in value) + + ",".join(f"{Queryfield._format_annotation_value(v)}" for v in value) + "]" ) else: @@ -158,7 +158,8 @@ def __invert__(self) -> "Condition": def annotate(self, annotations: Dict[str, Any]) -> "Condition": annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() ) return Condition(f"({{{annotations_str}}}){self.expression}") @@ -193,7 +194,7 @@ def any(cls, *conditions: "Condition") -> "Condition": class Query: - def __init__(self, select_fields: Union[str, List[str], List[Field]]): + def __init__(self, select_fields: Union[str, List[str], List[Queryfield]]): self.select_fields = ( ", ".join(select_fields) if isinstance(select_fields, List) @@ -213,8 +214,8 @@ def from_(self, *sources: 
str) -> "Query": self.sources = ", ".join(sources) return self - def where(self, condition: Union[Condition, Field]) -> "Query": - if isinstance(condition, Field): + def where(self, condition: Union[Condition, Queryfield]) -> "Query": + if isinstance(condition, Queryfield): self.condition = condition else: self.condition = condition @@ -229,7 +230,7 @@ def order_by_field( direction = "asc" if ascending else "desc" if annotations: annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' + f'"{k}":{Queryfield._format_annotation_value(v)}' for k, v in annotations.items() ) self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") @@ -331,7 +332,7 @@ def dotProduct( expr = f"dotProduct({field}, {vector_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -345,15 +346,15 @@ def weightedSet( expr = f"weightedSet({field}, {vector_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" return Condition(expr) @staticmethod - def nonEmpty(condition: Union[Condition, Field]) -> Condition: - if isinstance(condition, Field): + def nonEmpty(condition: Union[Condition, Queryfield]) -> Condition: + if isinstance(condition, Queryfield): expr = str(condition) else: expr = condition.build() @@ -372,7 +373,7 @@ def wand( expr = f"wand({field}, {weights_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -384,7 +385,7 @@ def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: expr = f"weakAnd({conditions_str})" if annotations: 
annotations_str = ",".join( - f'"{k}":{Field._format_annotation_value(v)}' + f'"{k}":{Queryfield._format_annotation_value(v)}' for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -401,7 +402,7 @@ def geoLocation( expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -414,7 +415,8 @@ def nearestNeighbor( if "targetHits" not in annotations: raise ValueError("targetHits annotation is required") annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" for k, v in annotations.items() + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() ) return Condition( f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" @@ -431,7 +433,7 @@ def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: expr = f"phrase({terms_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -443,7 +445,7 @@ def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: expr = f"near({terms_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -455,7 +457,7 @@ def onear(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: expr = f"onear({terms_str})" if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -478,7 +480,7 @@ def uri(value: str, annotations: 
Optional[Dict[str, Any]] = None) -> Condition: expr = f'uri("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -489,7 +491,7 @@ def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition expr = f'fuzzy("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -507,7 +509,7 @@ def userInput( expr = f'userInput("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Field._format_annotation_value(v)}" + f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" From 26e8c1d3bacd7d64a16a61ec41af020d7d8f6793 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 30 Sep 2024 13:08:11 +0200 Subject: [PATCH 11/39] use unit tests for integration tests --- tests/integration/test_integration_queries.py | 155 +++++++++++++----- 1 file changed, 112 insertions(+), 43 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 7149e844..8ce3150a 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -7,8 +7,11 @@ Field, StructField, Struct, + RankProfile, ) -from vespa.querybuilder import Query, Q +from tests.unit.test_q import TestQueryBuilder + +qb = TestQueryBuilder() class TestQueriesIntegration(unittest.TestCase): @@ -92,76 +95,142 @@ def setUpClass(cls): ) ], ) + rank_profiles = [ + RankProfile( + name="dotproduct", + first_phase="rawScore(weightedset_field)", + summary_features=["rawScore(weightedset_field)"], + ), + RankProfile( + name="geolocation", + first_phase="distance(location_field)", + 
summary_features=["distance(location_field).km"], + ), + ] document = Document(fields=fields, structs=[email_struct]) - schema = Schema(name=application_name, document=document) + schema = Schema( + name=application_name, document=document, rank_profiles=rank_profiles + ) schema.add_fields(emails_field) application_package = ApplicationPackage(name=application_name, schema=[schema]) print(application_package.schema.schema_to_text) # Deploy the application cls.vespa_docker = VespaDocker(port=8089) cls.app = cls.vespa_docker.deploy(application_package=application_package) + cls.app.wait_for_application_up() @classmethod - def tearDown(cls): + def tearDownClass(cls): cls.vespa_docker.container.stop(timeout=5) cls.vespa_docker.container.remove() + # @unittest.skip("Skip until we have a better way to test this") def test_dotProduct_with_annotations(self): # Feed a document with 'weightedset_field' field = "weightedset_field" - doc = { - "id": f"id:{self.application_name}:{self.application_name}::1", - "fields": {field: {"feature1": 0.5, "feature2": 1.0}}, - } + fields = {field: {"feature1": 2, "feature2": 4}} + data_id = 1 self.app.feed_data_point( - schema=self.application_name, data_id=doc["id"], fields=doc["fields"] + schema=self.application_name, data_id=data_id, fields=fields ) - - # Build and send the query - condition = Q.dotProduct( - field, - {"feature1": 1, "feature2": 2}, - annotations={"label": "myDotProduct"}, + q = qb.test_dotProduct_with_annotations() + with self.app.syncio() as sess: + result = sess.query(yql=q, ranking="dotproduct") + print(result.json) + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.application_name}:{self.application_name}::{data_id}", ) - q = ( - Query(select_fields=[field]) - .from_(self.application_name) - .where(condition) - .build(prepend_yql=False) + self.assertEqual( + result.hits[0]["fields"]["summaryfeatures"]["rawScore(weightedset_field)"], + 10, ) - print(q) - result = 
self.app.query(yql=q) - - # Check the result - self.assertEqual(len(result.hits), 1) - self.assertEqual(result.hits[0]["id"], doc["id"]) def test_geoLocation_with_annotations(self): # Feed a document with 'location_field' - doc = { - "id": f"id:{self.application_name}:{self.application_name}::2", - "fields": {"location_field": "37.7749, -122.4194"}, + field_name = "location_field" + fields = { + field_name: { + "lat": 37.77491, + "lng": -122.41941, + }, # 0.00001 degrees more than the query } + data_id = 2 self.app.feed_data_point( - schema=self.application_name, data_id=doc["id"], fields=doc["fields"] + schema=self.application_name, data_id=data_id, fields=fields ) - # Build and send the query - condition = Q.geoLocation( - "location_field", - 37.7749, - -122.4194, - "10km", - annotations={"targetHits": 100}, + q = qb.test_geoLocation_with_annotations() + with self.app.syncio() as sess: + result = sess.query(yql=q, ranking="geolocation") + # Check the result + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.application_name}:{self.application_name}::{data_id}", ) - q = ( - Query(select_fields="") - .from_(self.application_name) - .where(condition) - .build(prepend_yql=False) + self.assertAlmostEqual( + result.hits[0]["fields"]["summaryfeatures"]["distance(location_field).km"], + 0.001417364012462494, ) - result = self.app.query(yql=q) + print(result.json) + def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): + docs = [ + { # Should not match + "f1": "v1", + "f2": "v2", + "f3": "asdf", + "f4": "d", + "age": 10, + "duration": 100, + }, + { # Should match + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "d", + "age": 20, + "duration": 200, + }, + { # Should match + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "d", + "age": 30, + "duration": 300, + }, + { # Should not match + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "v4", + "age": 30, + "duration": 300, + }, + ] + id_to_match = 2 + docs 
= [ + { + "fields": doc, + "id": data_id, + } + for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.application_name) + # Build and send the query + q = qb.test_basic_and_andnot_or_offset_limit_param_order_by_and_contains() + print(q) + with self.app.syncio() as sess: + result = sess.query( + yql=q, + ) # Check the result self.assertEqual(len(result.hits), 1) - self.assertEqual(result.hits[0]["id"], doc["id"]) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.application_name}:{self.application_name}::{id_to_match}", + ) + print(result.json) From 1f9c40bf21e86ba24eddaebc3b73549673d63021 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 30 Sep 2024 13:08:27 +0200 Subject: [PATCH 12/39] refactoring of unit tests --- tests/unit/test_q.py | 476 ++++++++++++++++++++++++++++--------------- 1 file changed, 314 insertions(+), 162 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index daeb0201..7bbff88a 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,5 +1,5 @@ import unittest -from vespa.querybuilder import Query, Q, Field, G, Condition +from vespa.querybuilder import Query, Q, Queryfield, G, Condition class TestQueryBuilder(unittest.TestCase): @@ -9,9 +9,15 @@ def test_dotProduct_with_annotations(self): {"feature1": 1, "feature2": 2}, annotations={"label": "myDotProduct"}, ) - q = Query(select_fields="*").from_("vectors").where(condition).build() - expected = 'yql=select * from vectors where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' + q = ( + Query(select_fields="*") + .from_("querybuilder") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from querybuilder where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) + return q def test_geoLocation_with_annotations(self): condition = Q.geoLocation( @@ -21,89 +27,128 @@ def
test_geoLocation_with_annotations(self): "10km", annotations={"targetHits": 100}, ) - q = Query(select_fields="*").from_("places").where(condition).build() - expected = 'yql=select * from places where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' + q = ( + Query(select_fields="*") + .from_("querybuilder") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from querybuilder where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' self.assertEqual(q, expected) + return q def test_select_specific_fields(self): - f1 = Field("f1") + f1 = Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields=["f1", "f2"]).from_("sd1").where(condition).build() + q = ( + Query(select_fields=["f1", "f2"]) + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) - self.assertEqual(q, 'yql=select f1, f2 from sd1 where f1 contains "v1"') + self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') def test_select_from_specific_sources(self): - f1 = Field("f1") + f1 = Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields="*").from_("sd1").where(condition).build() + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) - self.assertEqual(q, 'yql=select * from sd1 where f1 contains "v1"') + self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') def test_select_from_multiples_sources(self): - f1 = Field("f1") + f1 = Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields="*").from_("sd1", "sd2").where(condition).build() + q = ( + Query(select_fields="*") + .from_("sd1", "sd2") + .where(condition) + .build(prepend_yql=False) + ) - self.assertEqual(q, 'yql=select * from sd1, sd2 where f1 contains "v1"') + self.assertEqual(q, 'select * from sd1, sd2 where f1 contains "v1"') def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): - f1 = Field("f1") - f2 = Field("f2") - f3 = 
Field("f3") - f4 = Field("f4") + f1 = Queryfield("f1") + f2 = Queryfield("f2") + f3 = Queryfield("f3") + f4 = Queryfield("f4") condition = ((f1.contains("v1") & f2.contains("v2")) | f3.contains("v3")) & ( ~f4.contains("v4") ) q = ( Query(select_fields="*") - .from_("sd1") + .from_("querybuilder") .where(condition) .set_offset(1) .set_limit(2) - .set_timeout(3) - .orderByDesc("f1") - .orderByAsc("f2") - .param("paramk1", "paramv1") - .build() + .set_timeout(3000) + .orderByDesc("age") + .orderByAsc("duration") + .build(prepend_yql=False) ) - expected = 'yql=select * from sd1 where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by f1 desc, f2 asc limit 2 offset 1 timeout 3&paramk1=paramv1' + expected = 'select * from querybuilder where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by age desc, duration asc limit 2 offset 1 timeout 3000' self.assertEqual(q, expected) + return q def test_matches(self): condition = ( - (Field("f1").matches("v1") & Field("f2").matches("v2")) - | Field("f3").matches("v3") - ) & ~Field("f4").matches("v4") - q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' + (Queryfield("f1").matches("v1") & Queryfield("f2").matches("v2")) + | Queryfield("f3").matches("v3") + ) & ~Queryfield("f4").matches("v4") + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' self.assertEqual(q, expected) def test_nested_queries(self): - nested_query = (Field("f2").contains("2") & Field("f3").contains("3")) | ( - Field("f2").contains("4") & ~Field("f3").contains("5") + nested_query = ( + Queryfield("f2").contains("2") & Queryfield("f3").contains("3") + ) |
(Queryfield("f2").contains("4") & ~Queryfield("f3").contains("5")) + condition = Queryfield("f1").contains("1") & ~nested_query + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) ) - condition = Field("f1").contains("1") & ~nested_query - q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' + expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' self.assertEqual(q, expected) def test_userInput_with_and_without_defaultIndex(self): condition = Q.userQuery(value="value1") & Q.userQuery( index="index", value="value2" ) - q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where userQuery("value1") and ({"defaultIndex":"index"})userQuery("value2")' + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from sd1 where userQuery("value1") and ({"defaultIndex":"index"})userQuery("value2")' self.assertEqual(q, expected) def test_fields_duration(self): - f1 = Field("subject") - f2 = Field("display_date") - f3 = Field("duration") + f1 = Queryfield("subject") + f2 = Queryfield("display_date") + f3 = Queryfield("duration") condition = ( - Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0).build() + Query(select_fields=[f1, f2]) + .from_("calendar") + .where(f3 > 0) + .build(prepend_yql=False) ) - expected = "yql=select subject, display_date from calendar where duration > 0" + expected = "select subject, display_date from calendar where duration > 0" self.assertEqual(condition, expected) def test_nearest_neighbor(self): @@ -115,9 +160,9 @@ def test_nearest_neighbor(self): Query(select_fields=["id, text"]) .from_("m") .where(condition_uq | condition_nn) 
- .build() + .build(prepend_yql=False) ) - expected = "yql=select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" + expected = "select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) def test_build_many_nn_operators(self): @@ -135,9 +180,9 @@ def test_build_many_nn_operators(self): Query(select_fields="*") .from_("doc") .where(condition=Condition.any(*conditions)) - .build() + .build(prepend_yql=False) ) - expected = "yql=select * from doc where " + " or ".join( + expected = "select * from doc where " + " or ".join( [ f"({{targetHits:100}}nearestNeighbor(colbert, binary_vector_{i}))" for i in range(32) @@ -146,102 +191,141 @@ def test_build_many_nn_operators(self): self.assertEqual(q, expected) def test_field_comparison_operators(self): - f1 = Field("age") + f1 = Queryfield("age") condition = (f1 > 30) & (f1 <= 50) - q = Query(select_fields="*").from_("people").where(condition).build() - expected = "yql=select * from people where age > 30 and age <= 50" + q = ( + Query(select_fields="*") + .from_("people") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select * from people where age > 30 and age <= 50" self.assertEqual(q, expected) def test_field_in_range(self): - f1 = Field("age") + f1 = Queryfield("age") condition = f1.in_range(18, 65) - q = Query(select_fields="*").from_("people").where(condition).build() - expected = "yql=select * from people where range(age, 18, 65)" + q = ( + Query(select_fields="*") + .from_("people") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select * from people where range(age, 18, 65)" self.assertEqual(q, expected) def test_field_annotation(self): - f1 = Field("title") + f1 = Queryfield("title") annotations = {"highlight": True} annotated_field = f1.annotate(annotations) - q = Query(select_fields="*").from_("articles").where(annotated_field).build() - expected = "yql=select * from 
articles where ({highlight:true})title" + q = ( + Query(select_fields="*") + .from_("articles") + .where(annotated_field) + .build(prepend_yql=False) + ) + expected = "select * from articles where ({highlight:true})title" self.assertEqual(q, expected) def test_condition_annotation(self): - f1 = Field("title") + f1 = Queryfield("title") condition = f1.contains("Python") annotated_condition = condition.annotate({"filter": True}) q = ( Query(select_fields="*") .from_("articles") .where(annotated_condition) - .build() - ) - expected = ( - 'yql=select * from articles where ({filter:true})title contains "Python"' + .build(prepend_yql=False) ) + expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, expected) def test_grouping_aggregation(self): grouping = G.all(G.group("category"), G.output(G.count())) - q = Query(select_fields="*").from_("products").group(grouping).build() - expected = "yql=select * from products | all(group(category) output(count()))" + q = ( + Query(select_fields="*") + .from_("products") + .group(grouping) + .build(prepend_yql=False) + ) + expected = "select * from products | all(group(category) output(count()))" self.assertEqual(q, expected) def test_add_parameter(self): - f1 = Field("title") + f1 = Queryfield("title") condition = f1.contains("Python") q = ( Query(select_fields="*") .from_("articles") .where(condition) .add_parameter("tracelevel", 1) - .build() - ) - expected = ( - 'yql=select * from articles where title contains "Python"&tracelevel=1' + .build(prepend_yql=False) ) + expected = 'select * from articles where title contains "Python"&tracelevel=1' self.assertEqual(q, expected) def test_custom_ranking_expression(self): condition = Q.rank( Q.userQuery(), Q.dotProduct("embedding", {"feature1": 1, "feature2": 2}) ) - q = Query(select_fields="*").from_("documents").where(condition).build() - expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, 
{"feature1":1,"feature2":2}))' + q = ( + Query(select_fields="*") + .from_("documents") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) def test_wand(self): condition = Q.wand("keywords", {"apple": 10, "banana": 20}) - q = Query(select_fields="*").from_("fruits").where(condition).build() - expected = ( - 'yql=select * from fruits where wand(keywords, {"apple":10,"banana":20})' + q = ( + Query(select_fields="*") + .from_("fruits") + .where(condition) + .build(prepend_yql=False) ) + expected = 'select * from fruits where wand(keywords, {"apple":10,"banana":20})' self.assertEqual(q, expected) def test_weakand(self): - condition1 = Field("title").contains("Python") - condition2 = Field("description").contains("Programming") + condition1 = Queryfield("title").contains("Python") + condition2 = Queryfield("description").contains("Programming") condition = Q.weakAnd( condition1, condition2, annotations={"targetNumHits": 100} ) - q = Query(select_fields="*").from_("articles").where(condition).build() - expected = 'yql=select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' + q = ( + Query(select_fields="*") + .from_("articles") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' self.assertEqual(q, expected) def test_geoLocation(self): condition = Q.geoLocation("location_field", 37.7749, -122.4194, "10km") - q = Query(select_fields="*").from_("places").where(condition).build() - expected = 'yql=select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' + q = ( + Query(select_fields="*") + .from_("places") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from places where 
geoLocation(location_field, 37.7749, -122.4194, "10km")' self.assertEqual(q, expected) def test_condition_all_any(self): - c1 = Field("f1").contains("v1") - c2 = Field("f2").contains("v2") - c3 = Field("f3").contains("v3") + c1 = Queryfield("f1").contains("v1") + c2 = Queryfield("f2").contains("v2") + c3 = Queryfield("f3").contains("v3") condition = Condition.all(c1, c2, Condition.any(c3, ~c1)) - q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' self.assertEqual(q, expected) def test_order_by_with_annotations(self): @@ -253,55 +337,85 @@ def test_order_by_with_annotations(self): .from_("products") .orderByDesc(f1, annotations) .orderByAsc(f2) - .build() + .build(prepend_yql=False) + ) + expected = ( + 'select * from products order by {"strength":0.5}relevance desc, price asc' ) - expected = 'yql=select * from products order by {"strength":0.5}relevance desc, price asc' self.assertEqual(q, expected) def test_field_comparison_methods(self): - f1 = Field("age") + f1 = Queryfield("age") condition = f1.ge(18) & f1.lt(30) - q = Query(select_fields="*").from_("users").where(condition).build() - expected = "yql=select * from users where age >= 18 and age < 30" + q = ( + Query(select_fields="*") + .from_("users") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select * from users where age >= 18 and age < 30" self.assertEqual(q, expected) def test_filter_annotation(self): - f1 = Field("title") + f1 = Queryfield("title") condition = f1.contains("Python").annotate({"filter": True}) - q = Query(select_fields="*").from_("articles").where(condition).build() - expected = ( - 'yql=select * from 
articles where ({filter:true})title contains "Python"' + q = ( + Query(select_fields="*") + .from_("articles") + .where(condition) + .build(prepend_yql=False) ) + expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, expected) def test_nonEmpty(self): - condition = Q.nonEmpty(Field("comments").eq("any_value")) - q = Query(select_fields="*").from_("posts").where(condition).build() - expected = 'yql=select * from posts where nonEmpty(comments = "any_value")' + condition = Q.nonEmpty(Queryfield("comments").eq("any_value")) + q = ( + Query(select_fields="*") + .from_("posts") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from posts where nonEmpty(comments = "any_value")' self.assertEqual(q, expected) def test_dotProduct(self): condition = Q.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) - q = Query(select_fields="*").from_("vectors").where(condition).build() - expected = 'yql=select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' + q = ( + Query(select_fields="*") + .from_("vectors") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' self.assertEqual(q, expected) def test_in_range_string_values(self): - f1 = Field("date") + f1 = Queryfield("date") condition = f1.in_range("2021-01-01", "2021-12-31") - q = Query(select_fields="*").from_("events").where(condition).build() - expected = "yql=select * from events where range(date, 2021-01-01, 2021-12-31)" + q = ( + Query(select_fields="*") + .from_("events") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" self.assertEqual(q, expected) def test_condition_inversion(self): - f1 = Field("status") + f1 = Queryfield("status") condition = ~f1.eq("inactive") - q = Query(select_fields="*").from_("users").where(condition).build() - expected = 
'yql=select * from users where !(status = "inactive")' + q = ( + Query(select_fields="*") + .from_("users") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from users where !(status = "inactive")' self.assertEqual(q, expected) def test_multiple_parameters(self): - f1 = Field("title") + f1 = Queryfield("title") condition = f1.contains("Python") q = ( Query(select_fields="*") @@ -309,9 +423,9 @@ def test_multiple_parameters(self): .where(condition) .add_parameter("tracelevel", 1) .add_parameter("language", "en") - .build() + .build(prepend_yql=False) ) - expected = 'yql=select * from articles where title contains "Python"&tracelevel=1&language=en' + expected = 'select * from articles where title contains "Python"&tracelevel=1&language=en' self.assertEqual(q, expected) def test_multiple_groupings(self): @@ -321,24 +435,39 @@ def test_multiple_groupings(self): G.output(G.count()), G.each(G.group("subcategory"), G.output(G.summary())), ) - q = Query(select_fields="*").from_("products").group(grouping).build() - expected = "yql=select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" + q = ( + Query(select_fields="*") + .from_("products") + .group(grouping) + .build(prepend_yql=False) + ) + expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" self.assertEqual(q, expected) def test_default_index_annotation(self): condition = Q.userQuery("search terms", index="default_field") - q = Query(select_fields="*").from_("documents").where(condition).build() - expected = 'yql=select * from documents where ({"defaultIndex":"default_field"})userQuery("search terms")' + q = ( + Query(select_fields="*") + .from_("documents") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from documents where ({"defaultIndex":"default_field"})userQuery("search terms")' self.assertEqual(q, expected) def 
test_Q_p_function(self): condition = Q.p( - Field("f1").contains("v1"), - Field("f2").contains("v2"), - Field("f3").contains("v3"), + Queryfield("f1").contains("v1"), + Queryfield("f2").contains("v2"), + Queryfield("f3").contains("v3"), ) - q = Query(select_fields="*").from_("sd1").where(condition).build() - expected = 'yql=select * from sd1 where f1 contains "v1" and f2 contains "v2" and f3 contains "v3"' + q = ( + Query(select_fields="*") + .from_("sd1") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and f3 contains "v3"' self.assertEqual(q, expected) def test_rank_multiple_conditions(self): @@ -347,59 +476,77 @@ def test_rank_multiple_conditions(self): Q.dotProduct("embedding", {"feature1": 1}), Q.weightedSet("tags", {"tag1": 2}), ) - q = Query(select_fields="*").from_("documents").where(condition).build() - expected = 'yql=select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' + q = ( + Query(select_fields="*") + .from_("documents") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' self.assertEqual(q, expected) def test_nonEmpty_with_annotations(self): - annotated_field = Field("comments").annotate({"filter": True}) + annotated_field = Queryfield("comments").annotate({"filter": True}) condition = Q.nonEmpty(annotated_field) - q = Query(select_fields="*").from_("posts").where(condition).build() - expected = "yql=select * from posts where nonEmpty(({filter:true})comments)" + q = ( + Query(select_fields="*") + .from_("posts") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select * from posts where nonEmpty(({filter:true})comments)" self.assertEqual(q, expected) def test_weight_annotation(self): - condition = Field("title").contains("heads", annotations={"weight": 200}) - 
q = Query(select_fields="*").from_("s1").where(condition).build() - expected = 'yql=select * from s1 where title contains({weight:200}"heads")' + condition = Queryfield("title").contains("heads", annotations={"weight": 200}) + q = ( + Query(select_fields="*") + .from_("s1") + .where(condition) + .build(prepend_yql=False) + ) + expected = 'select * from s1 where title contains({weight:200}"heads")' self.assertEqual(q, expected) def test_nearest_neighbor_annotations(self): condition = Q.nearestNeighbor( field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} ) - q = Query(select_fields=["id, text"]).from_("m").where(condition).build() - expected = "yql=select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" + q = ( + Query(select_fields=["id, text"]) + .from_("m") + .where(condition) + .build(prepend_yql=False) + ) + expected = "select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) def test_phrase(self): - text = Field("text") + text = Queryfield("text") condition = text.contains(Q.phrase("st", "louis", "blues")) - query = Q.select("*").where(condition).build() - expected = ( - 'yql=select * from * where text contains phrase("st", "louis", "blues")' - ) + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where text contains phrase("st", "louis", "blues")' self.assertEqual(query, expected) def test_near(self): - title = Field("title") + title = Queryfield("title") condition = title.contains(Q.near("madonna", "saint")) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where title contains near("madonna", "saint")' + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where title contains near("madonna", "saint")' self.assertEqual(query, expected) def test_onear(self): - title = Field("title") + title = Queryfield("title") condition = 
title.contains(Q.onear("madonna", "saint")) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where title contains onear("madonna", "saint")' + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where title contains onear("madonna", "saint")' self.assertEqual(query, expected) def test_sameElement(self): - persons = Field("persons") - first_name = Field("first_name") - last_name = Field("last_name") - year_of_birth = Field("year_of_birth") + persons = Queryfield("persons") + first_name = Queryfield("first_name") + last_name = Queryfield("last_name") + year_of_birth = Queryfield("year_of_birth") condition = persons.contains( Q.sameElement( first_name.contains("Joe"), @@ -407,45 +554,50 @@ def test_sameElement(self): year_of_birth < 1940, ) ) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' self.assertEqual(query, expected) def test_equiv(self): - fieldName = Field("fieldName") + fieldName = Queryfield("fieldName") condition = fieldName.contains(Q.equiv("A", "B")) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where fieldName contains equiv("A", "B")' + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where fieldName contains equiv("A", "B")' self.assertEqual(query, expected) def test_uri(self): - myUrlField = Field("myUrlField") + myUrlField = Queryfield("myUrlField") condition = myUrlField.contains(Q.uri("vespa.ai/foo")) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where myUrlField contains uri("vespa.ai/foo")' + query = 
Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where myUrlField contains uri("vespa.ai/foo")' self.assertEqual(query, expected) def test_fuzzy(self): - myStringAttribute = Field("myStringAttribute") + myStringAttribute = Queryfield("myStringAttribute") annotations = {"prefixLength": 1, "maxEditDistance": 2} condition = myStringAttribute.contains( Q.fuzzy("parantesis", annotations=annotations) ) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' self.assertEqual(query, expected) def test_userInput(self): condition = Q.userInput("@animal") - query = Q.select("*").where(condition).param("animal", "panda").build() - expected = "yql=select * from * where userInput(@animal)&animal=panda" + query = ( + Q.select("*") + .where(condition) + .param("animal", "panda") + .build(prepend_yql=False) + ) + expected = "select * from * where userInput(@animal)&animal=panda" self.assertEqual(query, expected) def test_in_operator(self): - integer_field = Field("integer_field") + integer_field = Queryfield("integer_field") condition = integer_field.in_(10, 20, 30) - query = Q.select("*").where(condition).build() - expected = "yql=select * from * where integer_field in (10, 20, 30)" + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = "select * from * where integer_field in (10, 20, 30)" self.assertEqual(query, expected) def test_predicate(self): @@ -454,20 +606,20 @@ def test_predicate(self): attributes={"gender": "Female"}, range_attributes={"age": "20L"}, ) - query = Q.select("*").where(condition).build() - expected = 'yql=select * from * where predicate(predicate_field,{"gender":"Female"},{"age":20L})' + query = 
Q.select("*").where(condition).build(prepend_yql=False) + expected = 'select * from * where predicate(predicate_field,{"gender":"Female"},{"age":20L})' self.assertEqual(query, expected) def test_true(self): condition = Q.true() - query = Q.select("*").where(condition).build() - expected = "yql=select * from * where true" + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = "select * from * where true" self.assertEqual(query, expected) def test_false(self): condition = Q.false() - query = Q.select("*").where(condition).build() - expected = "yql=select * from * where false" + query = Q.select("*").where(condition).build(prepend_yql=False) + expected = "select * from * where false" self.assertEqual(query, expected) From d0371e03793fc9dcfd05ae396272be741ae68c00 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 23 Oct 2024 04:42:58 +0200 Subject: [PATCH 13/39] simplify --- tests/unit/test_q.py | 259 ++++++++----------------------------------- 1 file changed, 44 insertions(+), 215 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 7bbff88a..53c6e629 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -9,12 +9,7 @@ def test_dotProduct_with_annotations(self): {"feature1": 1, "feature2": 2}, annotations={"label": "myDotProduct"}, ) - q = ( - Query(select_fields="*") - .from_("querybuilder") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("querybuilder").where(condition) expected = 'select * from querybuilder where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) return q @@ -27,12 +22,7 @@ def test_geoLocation_with_annotations(self): "10km", annotations={"targetHits": 100}, ) - q = ( - Query(select_fields="*") - .from_("querybuilder") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("querybuilder").where(condition) expected = 'select * from querybuilder where 
({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' self.assertEqual(q, expected) return q @@ -40,36 +30,21 @@ def test_geoLocation_with_annotations(self): def test_select_specific_fields(self): f1 = Queryfield("f1") condition = f1.contains("v1") - q = ( - Query(select_fields=["f1", "f2"]) - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields=["f1", "f2"]).from_("sd1").where(condition) self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') def test_select_from_specific_sources(self): f1 = Queryfield("f1") condition = f1.contains("v1") - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') def test_select_from_multiples_sources(self): f1 = Queryfield("f1") condition = f1.contains("v1") - q = ( - Query(select_fields="*") - .from_("sd1", "sd2") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1", "sd2").where(condition) self.assertEqual(q, 'select * from sd1, sd2 where f1 contains "v1"') @@ -90,7 +65,6 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): .set_timeout(3000) .orderByDesc("age") .orderByAsc("duration") - .build(prepend_yql=False) ) expected = 'select * from querybuilder where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by age desc, duration asc limit 2 offset 1 timeout 3000' @@ -102,12 +76,7 @@ def test_matches(self): (Queryfield("f1").matches("v1") & Queryfield("f2").matches("v2")) | Queryfield("f3").matches("v3") ) & ~Queryfield("f4").matches("v4") - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 
matches "v3") and !(f4 matches "v4")' self.assertEqual(q, expected) @@ -116,12 +85,7 @@ def test_nested_queries(self): Queryfield("f2").contains("2") & Queryfield("f3").contains("3") ) | (Queryfield("f2").contains("4") & ~Queryfield("f3").contains("5")) condition = Queryfield("f1").contains("1") & ~nested_query - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' self.assertEqual(q, expected) @@ -129,12 +93,7 @@ def test_userInput_with_and_without_defaultIndex(self): condition = Q.userQuery(value="value1") & Q.userQuery( index="index", value="value2" ) - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where userQuery("value1") and ({"defaultIndex":"index"})userQuery("value2")' self.assertEqual(q, expected) @@ -142,12 +101,7 @@ def test_fields_duration(self): f1 = Queryfield("subject") f2 = Queryfield("display_date") f3 = Queryfield("duration") - condition = ( - Query(select_fields=[f1, f2]) - .from_("calendar") - .where(f3 > 0) - .build(prepend_yql=False) - ) + condition = Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0) expected = "select subject, display_date from calendar where duration > 0" self.assertEqual(condition, expected) @@ -160,7 +114,6 @@ def test_nearest_neighbor(self): Query(select_fields=["id, text"]) .from_("m") .where(condition_uq | condition_nn) - .build(prepend_yql=False) ) expected = "select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) @@ -180,7 +133,6 @@ def test_build_many_nn_operators(self): Query(select_fields="*") .from_("doc") 
.where(condition=Condition.any(*conditions)) - .build(prepend_yql=False) ) expected = "select * from doc where " + " or ".join( [ @@ -193,24 +145,14 @@ def test_build_many_nn_operators(self): def test_field_comparison_operators(self): f1 = Queryfield("age") condition = (f1 > 30) & (f1 <= 50) - q = ( - Query(select_fields="*") - .from_("people") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("people").where(condition) expected = "select * from people where age > 30 and age <= 50" self.assertEqual(q, expected) def test_field_in_range(self): f1 = Queryfield("age") condition = f1.in_range(18, 65) - q = ( - Query(select_fields="*") - .from_("people") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("people").where(condition) expected = "select * from people where range(age, 18, 65)" self.assertEqual(q, expected) @@ -218,12 +160,7 @@ def test_field_annotation(self): f1 = Queryfield("title") annotations = {"highlight": True} annotated_field = f1.annotate(annotations) - q = ( - Query(select_fields="*") - .from_("articles") - .where(annotated_field) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("articles").where(annotated_field) expected = "select * from articles where ({highlight:true})title" self.assertEqual(q, expected) @@ -231,23 +168,13 @@ def test_condition_annotation(self): f1 = Queryfield("title") condition = f1.contains("Python") annotated_condition = condition.annotate({"filter": True}) - q = ( - Query(select_fields="*") - .from_("articles") - .where(annotated_condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("articles").where(annotated_condition) expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, expected) def test_grouping_aggregation(self): grouping = G.all(G.group("category"), G.output(G.count())) - q = ( - Query(select_fields="*") - .from_("products") - .group(grouping) - 
.build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("products").group(grouping) expected = "select * from products | all(group(category) output(count()))" self.assertEqual(q, expected) @@ -259,7 +186,6 @@ def test_add_parameter(self): .from_("articles") .where(condition) .add_parameter("tracelevel", 1) - .build(prepend_yql=False) ) expected = 'select * from articles where title contains "Python"&tracelevel=1' self.assertEqual(q, expected) @@ -268,23 +194,13 @@ def test_custom_ranking_expression(self): condition = Q.rank( Q.userQuery(), Q.dotProduct("embedding", {"feature1": 1, "feature2": 2}) ) - q = ( - Query(select_fields="*") - .from_("documents") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) def test_wand(self): condition = Q.wand("keywords", {"apple": 10, "banana": 20}) - q = ( - Query(select_fields="*") - .from_("fruits") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("fruits").where(condition) expected = 'select * from fruits where wand(keywords, {"apple":10,"banana":20})' self.assertEqual(q, expected) @@ -294,23 +210,13 @@ def test_weakand(self): condition = Q.weakAnd( condition1, condition2, annotations={"targetNumHits": 100} ) - q = ( - Query(select_fields="*") - .from_("articles") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("articles").where(condition) expected = 'select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' self.assertEqual(q, expected) def test_geoLocation(self): condition = Q.geoLocation("location_field", 37.7749, -122.4194, "10km") - q = ( - Query(select_fields="*") - .from_("places") - .where(condition) - .build(prepend_yql=False) - ) + q = 
Query(select_fields="*").from_("places").where(condition) expected = 'select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' self.assertEqual(q, expected) @@ -319,12 +225,7 @@ def test_condition_all_any(self): c2 = Queryfield("f2").contains("v2") c3 = Queryfield("f3").contains("v3") condition = Condition.all(c1, c2, Condition.any(c3, ~c1)) - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' self.assertEqual(q, expected) @@ -337,7 +238,6 @@ def test_order_by_with_annotations(self): .from_("products") .orderByDesc(f1, annotations) .orderByAsc(f2) - .build(prepend_yql=False) ) expected = ( 'select * from products order by {"strength":0.5}relevance desc, price asc' @@ -347,70 +247,40 @@ def test_order_by_with_annotations(self): def test_field_comparison_methods(self): f1 = Queryfield("age") condition = f1.ge(18) & f1.lt(30) - q = ( - Query(select_fields="*") - .from_("users") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("users").where(condition) expected = "select * from users where age >= 18 and age < 30" self.assertEqual(q, expected) def test_filter_annotation(self): f1 = Queryfield("title") condition = f1.contains("Python").annotate({"filter": True}) - q = ( - Query(select_fields="*") - .from_("articles") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("articles").where(condition) expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, expected) def test_nonEmpty(self): condition = Q.nonEmpty(Queryfield("comments").eq("any_value")) - q = ( - Query(select_fields="*") - .from_("posts") - .where(condition) - .build(prepend_yql=False) - ) + q = 
Query(select_fields="*").from_("posts").where(condition) expected = 'select * from posts where nonEmpty(comments = "any_value")' self.assertEqual(q, expected) def test_dotProduct(self): condition = Q.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) - q = ( - Query(select_fields="*") - .from_("vectors") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("vectors").where(condition) expected = 'select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' self.assertEqual(q, expected) def test_in_range_string_values(self): f1 = Queryfield("date") condition = f1.in_range("2021-01-01", "2021-12-31") - q = ( - Query(select_fields="*") - .from_("events") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("events").where(condition) expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" self.assertEqual(q, expected) def test_condition_inversion(self): f1 = Queryfield("status") condition = ~f1.eq("inactive") - q = ( - Query(select_fields="*") - .from_("users") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("users").where(condition) expected = 'select * from users where !(status = "inactive")' self.assertEqual(q, expected) @@ -423,7 +293,6 @@ def test_multiple_parameters(self): .where(condition) .add_parameter("tracelevel", 1) .add_parameter("language", "en") - .build(prepend_yql=False) ) expected = 'select * from articles where title contains "Python"&tracelevel=1&language=en' self.assertEqual(q, expected) @@ -435,23 +304,13 @@ def test_multiple_groupings(self): G.output(G.count()), G.each(G.group("subcategory"), G.output(G.summary())), ) - q = ( - Query(select_fields="*") - .from_("products") - .group(grouping) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("products").group(grouping) expected = "select * from products | all(group(category) max(10) output(count()) 
each(group(subcategory) output(summary())))" self.assertEqual(q, expected) def test_default_index_annotation(self): condition = Q.userQuery("search terms", index="default_field") - q = ( - Query(select_fields="*") - .from_("documents") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("documents").where(condition) expected = 'select * from documents where ({"defaultIndex":"default_field"})userQuery("search terms")' self.assertEqual(q, expected) @@ -461,12 +320,7 @@ def test_Q_p_function(self): Queryfield("f2").contains("v2"), Queryfield("f3").contains("v3"), ) - q = ( - Query(select_fields="*") - .from_("sd1") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and f3 contains "v3"' self.assertEqual(q, expected) @@ -476,35 +330,20 @@ def test_rank_multiple_conditions(self): Q.dotProduct("embedding", {"feature1": 1}), Q.weightedSet("tags", {"tag1": 2}), ) - q = ( - Query(select_fields="*") - .from_("documents") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' self.assertEqual(q, expected) def test_nonEmpty_with_annotations(self): annotated_field = Queryfield("comments").annotate({"filter": True}) condition = Q.nonEmpty(annotated_field) - q = ( - Query(select_fields="*") - .from_("posts") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("posts").where(condition) expected = "select * from posts where nonEmpty(({filter:true})comments)" self.assertEqual(q, expected) def test_weight_annotation(self): condition = Queryfield("title").contains("heads", annotations={"weight": 200}) - q = ( - Query(select_fields="*") - .from_("s1") - .where(condition) - 
.build(prepend_yql=False) - ) + q = Query(select_fields="*").from_("s1").where(condition) expected = 'select * from s1 where title contains({weight:200}"heads")' self.assertEqual(q, expected) @@ -512,33 +351,28 @@ def test_nearest_neighbor_annotations(self): condition = Q.nearestNeighbor( field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} ) - q = ( - Query(select_fields=["id, text"]) - .from_("m") - .where(condition) - .build(prepend_yql=False) - ) + q = Query(select_fields=["id, text"]).from_("m").where(condition) expected = "select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) def test_phrase(self): text = Queryfield("text") condition = text.contains(Q.phrase("st", "louis", "blues")) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where text contains phrase("st", "louis", "blues")' self.assertEqual(query, expected) def test_near(self): title = Queryfield("title") condition = title.contains(Q.near("madonna", "saint")) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where title contains near("madonna", "saint")' self.assertEqual(query, expected) def test_onear(self): title = Queryfield("title") condition = title.contains(Q.onear("madonna", "saint")) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where title contains onear("madonna", "saint")' self.assertEqual(query, expected) @@ -554,21 +388,21 @@ def test_sameElement(self): year_of_birth < 1940, ) ) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' self.assertEqual(query, expected) 
def test_equiv(self): fieldName = Queryfield("fieldName") condition = fieldName.contains(Q.equiv("A", "B")) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where fieldName contains equiv("A", "B")' self.assertEqual(query, expected) def test_uri(self): myUrlField = Queryfield("myUrlField") condition = myUrlField.contains(Q.uri("vespa.ai/foo")) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where myUrlField contains uri("vespa.ai/foo")' self.assertEqual(query, expected) @@ -578,25 +412,20 @@ def test_fuzzy(self): condition = myStringAttribute.contains( Q.fuzzy("parantesis", annotations=annotations) ) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' self.assertEqual(query, expected) def test_userInput(self): condition = Q.userInput("@animal") - query = ( - Q.select("*") - .where(condition) - .param("animal", "panda") - .build(prepend_yql=False) - ) + query = Q.select("*").where(condition).param("animal", "panda") expected = "select * from * where userInput(@animal)&animal=panda" self.assertEqual(query, expected) def test_in_operator(self): integer_field = Queryfield("integer_field") condition = integer_field.in_(10, 20, 30) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = "select * from * where integer_field in (10, 20, 30)" self.assertEqual(query, expected) @@ -606,19 +435,19 @@ def test_predicate(self): attributes={"gender": "Female"}, range_attributes={"age": "20L"}, ) - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = 'select * from * where 
predicate(predicate_field,{"gender":"Female"},{"age":20L})' self.assertEqual(query, expected) def test_true(self): condition = Q.true() - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = "select * from * where true" self.assertEqual(query, expected) def test_false(self): condition = Q.false() - query = Q.select("*").where(condition).build(prepend_yql=False) + query = Q.select("*").where(condition) expected = "select * from * where false" self.assertEqual(query, expected) From 2eb7f63f6cd63ccb0780f58a56c1958d84350952 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 23 Oct 2024 04:43:15 +0200 Subject: [PATCH 14/39] add dunder methods to remove build --- vespa/querybuilder/main.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 346c27ba..066d0d80 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -194,7 +194,9 @@ def any(cls, *conditions: "Condition") -> "Condition": class Query: - def __init__(self, select_fields: Union[str, List[str], List[Queryfield]]): + def __init__( + self, select_fields: Union[str, List[str], List[Queryfield]], prepend_yql=False + ): self.select_fields = ( ", ".join(select_fields) if isinstance(select_fields, List) @@ -209,6 +211,19 @@ def __init__(self, select_fields: Union[str, List[str], List[Queryfield]]): self.timeout_value = None self.parameters = {} self.grouping = None + self.prepend_yql = prepend_yql + + def __str__(self) -> str: + return self.build(self.prepend_yql) + + def __eq__(self, other: Any) -> bool: + return self.build() == other + + def __ne__(self, other: Any) -> bool: + return self.build() != other + + def __repr__(self) -> str: + return str(self) def from_(self, *sources: str) -> "Query": self.sources = ", ".join(sources) @@ -271,7 +286,7 @@ def group(self, group_expression: str) -> "Query": self.grouping = group_expression 
return self - def build(self, prepend_yql=True) -> str: + def build(self, prepend_yql=False) -> str: query = f"select {self.select_fields} from {self.sources}" if prepend_yql: query = f"yql={query}" From b60c6cf16a0173f0de19e75da14add7c7a880a84 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 09:44:24 +0100 Subject: [PATCH 15/39] no index annotation on userQuery --- vespa/querybuilder/main.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 066d0d80..6feea77c 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from typing import Any, List, Union, Optional, Dict -import json @dataclass @@ -324,20 +323,18 @@ def p(*args): return condition @staticmethod - def userQuery(value: str = "", index: Optional[str] = None) -> Condition: - if index is None: - # Only value provided - return ( - Condition(f'userQuery("{value}")') - if value - else Condition("userQuery()") - ) - else: - # Both index and value provided - default_index_json = json.dumps( - {"defaultIndex": index}, separators=(",", ":") - ) - return Condition(f'({default_index_json})userQuery("{value}")') + def userQuery(value: str = "") -> Condition: + return Condition(f'userQuery("{value}")') if value else Condition("userQuery()") + # else: + # # Both index and value provided + # default_index_json = json.dumps( + # {"defaultIndex": index}, separators=(",", ":") + # ) + # return ( + # Condition(f'({default_index_json}userQuery("{value}"))') + # if value + # else Condition(f"({default_index_json}userQuery())") + # ) @staticmethod def dotProduct( From 8538efdc4f11ff382dca10b055f7830956d776d1 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 09:44:41 +0100 Subject: [PATCH 16/39] more integration tests --- tests/integration/test_integration_queries.py | 289 ++++++++++++++++-- 1 file changed, 256 insertions(+), 33 
deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 8ce3150a..3dfd1e19 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -5,6 +5,7 @@ Schema, Document, Field, + FieldSet, StructField, Struct, RankProfile, @@ -19,8 +20,12 @@ class TestQueriesIntegration(unittest.TestCase): def setUpClass(cls): application_name = "querybuilder" cls.application_name = application_name - + schema_name1 = "sd1" + schema_name2 = "sd2" + cls.schema_name1 = schema_name1 + cls.schema_name2 = schema_name2 # Define all fields used in the unit tests + # Schema 1 fields = [ Field( name="weightedset_field", @@ -28,10 +33,10 @@ def setUpClass(cls): indexing=["attribute"], ), Field(name="location_field", type="position", indexing=["attribute"]), - Field(name="f1", type="string", indexing=["index", "summary"]), - Field(name="f2", type="string", indexing=["index", "summary"]), - Field(name="f3", type="string", indexing=["index", "summary"]), - Field(name="f4", type="string", indexing=["index", "summary"]), + Field(name="f1", type="string", indexing=["attribute", "summary"]), + Field(name="f2", type="string", indexing=["attribute", "summary"]), + Field(name="f3", type="string", indexing=["attribute", "summary"]), + Field(name="f4", type="string", indexing=["attribute", "summary"]), Field(name="age", type="int", indexing=["attribute", "summary"]), Field(name="duration", type="int", indexing=["attribute", "summary"]), Field(name="id", type="string", indexing=["attribute", "summary"]), @@ -106,14 +111,29 @@ def setUpClass(cls): first_phase="distance(location_field)", summary_features=["distance(location_field).km"], ), + RankProfile( + name="bm25", first_phase="bm25(text)", summary_features=["bm25(text)"] + ), ] + fieldset = FieldSet(name="default", fields=["text", "title", "description"]) document = Document(fields=fields, structs=[email_struct]) - schema = Schema( - 
name=application_name, document=document, rank_profiles=rank_profiles + schema1 = Schema( + name=schema_name1, + document=document, + rank_profiles=rank_profiles, + fieldsets=[fieldset], ) - schema.add_fields(emails_field) - application_package = ApplicationPackage(name=application_name, schema=[schema]) - print(application_package.schema.schema_to_text) + schema1.add_fields(emails_field) + ## Schema 2 + schema2 = Schema( + name=schema_name2, document=document, rank_profiles=rank_profiles + ) + # Create the application package + application_package = ApplicationPackage( + name=application_name, schema=[schema1, schema2] + ) + print(application_package.get_schema(schema_name1).schema_to_text) + print(application_package.get_schema(schema_name2).schema_to_text) # Deploy the application cls.vespa_docker = VespaDocker(port=8089) cls.app = cls.vespa_docker.deploy(application_package=application_package) @@ -124,14 +144,13 @@ def tearDownClass(cls): cls.vespa_docker.container.stop(timeout=5) cls.vespa_docker.container.remove() - # @unittest.skip("Skip until we have a better way to test this") def test_dotProduct_with_annotations(self): # Feed a document with 'weightedset_field' field = "weightedset_field" fields = {field: {"feature1": 2, "feature2": 4}} data_id = 1 self.app.feed_data_point( - schema=self.application_name, data_id=data_id, fields=fields + schema=self.schema_name1, data_id=data_id, fields=fields ) q = qb.test_dotProduct_with_annotations() with self.app.syncio() as sess: @@ -140,14 +159,14 @@ def test_dotProduct_with_annotations(self): self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.application_name}:{self.application_name}::{data_id}", + f"id:{self.schema_name1}:{self.schema_name1}::{data_id}", ) self.assertEqual( result.hits[0]["fields"]["summaryfeatures"]["rawScore(weightedset_field)"], 10, ) - def test_geoLocation_with_annotations(self): + def test_geolocation_with_annotations(self): # Feed a document with 
'location_field' field_name = "location_field" fields = { @@ -158,17 +177,17 @@ def test_geoLocation_with_annotations(self): } data_id = 2 self.app.feed_data_point( - schema=self.application_name, data_id=data_id, fields=fields + schema=self.schema_name1, data_id=data_id, fields=fields ) # Build and send the query - q = qb.test_geoLocation_with_annotations() + q = qb.test_geolocation_with_annotations() with self.app.syncio() as sess: result = sess.query(yql=q, ranking="geolocation") # Check the result self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.application_name}:{self.application_name}::{data_id}", + f"id:{self.schema_name1}:{self.schema_name1}::{data_id}", ) self.assertAlmostEqual( result.hits[0]["fields"]["summaryfeatures"]["distance(location_field).km"], @@ -178,7 +197,7 @@ def test_geoLocation_with_annotations(self): def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): docs = [ - { # Should not match + { # Should not match - f3 doesn't contain "v3" "f1": "v1", "f2": "v2", "f3": "asdf", @@ -202,35 +221,239 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): "age": 30, "duration": 300, }, - { # Should not match + { # Should not match - contains f4="v4" "f1": "v1", "f2": "v2", "f3": "v3", "f4": "v4", - "age": 30, - "duration": 300, + "age": 40, + "duration": 400, }, ] - id_to_match = 2 + + # Feed documents + docs = [{"id": data_id, "fields": doc} for data_id, doc in enumerate(docs, 1)] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + + # Build and send query + q = qb.test_basic_and_andnot_or_offset_limit_param_order_by_and_contains() + print(f"Executing query: {q}") + + with self.app.syncio() as sess: + result = sess.query(yql=q) + + # Verify results + self.assertEqual( + len(result.hits), 1 + ) # Should get 1 hit due to offset=1, limit=2 + + # The query orders by age desc, duration asc with offset 1 + # So we should get doc ID 2 (since doc ID 3 is 
skipped due to offset) + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::2") + + # Verify the matching document has expected field values + self.assertEqual(hit["fields"]["age"], 20) + self.assertEqual(hit["fields"]["duration"], 200) + self.assertEqual(hit["fields"]["f1"], "v1") + self.assertEqual(hit["fields"]["f2"], "v2") + self.assertEqual(hit["fields"]["f3"], "v3") + self.assertEqual(hit["fields"]["f4"], "d") + + print(result.json) + + def test_matches(self): + # Matches is a regex (or substring) match + # Feed test documents + docs = [ + { # Doc 1: Should match - satisfies (f1="v1" AND f2="v2") and f4!="v4" + "f1": "v1", + "f2": "v2", + "f3": "other", + "f4": "nothing", + }, + { # Doc 2: Should not match - fails f4!="v4" condition + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "v4", + }, + { # Doc 3: Should match - satisfies f3="v3" and f4!="v4" + "f1": "other", + "f2": "other", + "f3": "v3", + "f4": "nothing", + }, + { # Doc 4: Should not match - fails all conditions + "f1": "other", + "f2": "other", + "f3": "other", + "f4": "v4", + }, + ] + + # Ensure fields are properly indexed for matching docs = [ { "fields": doc, - "id": data_id, + "id": str(data_id), } for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.application_name) - # Build and send the query - q = qb.test_basic_and_andnot_or_offset_limit_param_order_by_and_contains() - print(q) + + # Feed documents + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + + # Build and send query + q = qb.test_matches() + # select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4") + print(f"Executing query: {q}") + with self.app.syncio() as sess: - result = sess.query( - yql=q, - ) - # Check the result + result = sess.query(yql=q) + + # Check result count + self.assertEqual(len(result.hits), 2) + + # Verify specific matches + ids = sorted([hit["id"] for hit in 
result.hits]) + expected_ids = sorted( + [ + f"id:{self.schema_name1}:{self.schema_name1}::1", + f"id:{self.schema_name1}:{self.schema_name1}::3", + ] + ) + + self.assertEqual(ids, expected_ids) + print(result.json) + + def test_nested_queries(self): + # Contains is an exact match + # q = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' + # Feed test documents + docs = [ + { # Doc 1: Should not match - satisfies f1 contains "1" but fails inner query + "f1": "1", + "f2": "2", + "f3": "3", + }, + { # Doc 2: Should match + "f1": "1", + "f2": "4", + "f3": "5", + }, + { # Doc 3: Should not match - fails f1 contains "1" + "f1": "other", + "f2": "2", + "f3": "3", + }, + { # Doc 4: Should not match + "f1": "1", + "f2": "4", + "f3": "other", + }, + ] + docs = [ + { + "fields": doc, + "id": str(data_id), + } + for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + q = qb.test_nested_queries() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + print(result.json) self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.application_name}:{self.application_name}::{id_to_match}", + f"id:{self.schema_name1}:{self.schema_name1}::2", ) - print(result.json) + + def test_userquery_defaultindex(self): + # 'select * from sd1 where ({"defaultIndex":"text"}userQuery())' + # Feed test documents + docs = [ + { # Doc 1: Should match + "description": "foo", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "bar", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + + # Execute query + q = qb.test_userquery() + query = "foo" + 
print(f"Executing query: {q}") + body = { + "yql": str(q), + "query": query, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + self.assertEqual(len(result.hits), 2) + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + + def test_userquery_customindex(self): + # 'select * from sd1 where userQuery())' + # Feed test documents + docs = [ + { # Doc 1: Should match + "description": "foo", + "text": "foo", + }, + { # Doc 2: Should not match + "description": "foo", + "text": "bar", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + + # Execute query + q = qb.test_userquery() + query = "foo" + print(f"Executing query: {q}") + body = { + "yql": str(q), + "query": query, + "ranking": "bm25", + "model.defaultIndex": "text", # userQuery() needs this to set index, see https://docs.vespa.ai/en/query-api.html#using-a-fieldset + } + with self.app.syncio() as sess: + result = sess.query(body=body) + # Verify only one document matches both conditions + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], f"id:{self.schema_name1}:{self.schema_name1}::1" + ) + + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["fields"]["description"], "foo") + self.assertEqual(hit["fields"]["text"], "foo") From 41f70ecdc879085870a99560809d37d57838483b Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 09:44:51 +0100 Subject: [PATCH 17/39] more unit tests --- tests/unit/test_q.py | 95 +++++++++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py 
index 53c6e629..218751b7 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -9,12 +9,12 @@ def test_dotProduct_with_annotations(self): {"feature1": 1, "feature2": 2}, annotations={"label": "myDotProduct"}, ) - q = Query(select_fields="*").from_("querybuilder").where(condition) - expected = 'select * from querybuilder where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' + q = Query(select_fields="*").from_("sd1").where(condition) + expected = 'select * from sd1 where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) return q - def test_geoLocation_with_annotations(self): + def test_geolocation_with_annotations(self): condition = Q.geoLocation( "location_field", 37.7749, @@ -22,8 +22,8 @@ def test_geoLocation_with_annotations(self): "10km", annotations={"targetHits": 100}, ) - q = Query(select_fields="*").from_("querybuilder").where(condition) - expected = 'select * from querybuilder where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' + q = Query(select_fields="*").from_("sd1").where(condition) + expected = 'select * from sd1 where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' self.assertEqual(q, expected) return q @@ -31,22 +31,22 @@ def test_select_specific_fields(self): f1 = Queryfield("f1") condition = f1.contains("v1") q = Query(select_fields=["f1", "f2"]).from_("sd1").where(condition) - self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') + return q def test_select_from_specific_sources(self): f1 = Queryfield("f1") condition = f1.contains("v1") q = Query(select_fields="*").from_("sd1").where(condition) - self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') + return q def test_select_from_multiples_sources(self): f1 = Queryfield("f1") condition = f1.contains("v1") q = Query(select_fields="*").from_("sd1", "sd2").where(condition) - self.assertEqual(q, 'select * from 
sd1, sd2 where f1 contains "v1"') + return q def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): f1 = Queryfield("f1") @@ -58,7 +58,7 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): ) q = ( Query(select_fields="*") - .from_("querybuilder") + .from_("sd1") .where(condition) .set_offset(1) .set_limit(2) @@ -67,7 +67,7 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): .orderByAsc("duration") ) - expected = 'select * from querybuilder where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by age desc, duration asc limit 2 offset 1 timeout 3000' + expected = 'select * from sd1 where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by age desc, duration asc limit 2 offset 1 timeout 3000' self.assertEqual(q, expected) return q @@ -79,6 +79,7 @@ def test_matches(self): q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' self.assertEqual(q, expected) + return q def test_nested_queries(self): nested_query = ( @@ -88,22 +89,23 @@ def test_nested_queries(self): q = Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' self.assertEqual(q, expected) + return q - def test_userInput_with_and_without_defaultIndex(self): - condition = Q.userQuery(value="value1") & Q.userQuery( - index="index", value="value2" - ) + def test_userquery(self): + condition = Q.userQuery() q = Query(select_fields="*").from_("sd1").where(condition) - expected = 'select * from sd1 where userQuery("value1") and ({"defaultIndex":"index"})userQuery("value2")' + expected = "select * from sd1 where userQuery()" self.assertEqual(q, expected) + return q def 
test_fields_duration(self): f1 = Queryfield("subject") f2 = Queryfield("display_date") f3 = Queryfield("duration") - condition = Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0) + q = Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0) expected = "select subject, display_date from calendar where duration > 0" - self.assertEqual(condition, expected) + self.assertEqual(q, expected) + return q def test_nearest_neighbor(self): condition_uq = Q.userQuery() @@ -117,6 +119,7 @@ def test_nearest_neighbor(self): ) expected = "select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) + return q def test_build_many_nn_operators(self): self.maxDiff = None @@ -141,6 +144,7 @@ def test_build_many_nn_operators(self): ] ) self.assertEqual(q, expected) + return q def test_field_comparison_operators(self): f1 = Queryfield("age") @@ -148,6 +152,7 @@ def test_field_comparison_operators(self): q = Query(select_fields="*").from_("people").where(condition) expected = "select * from people where age > 30 and age <= 50" self.assertEqual(q, expected) + return q def test_field_in_range(self): f1 = Queryfield("age") @@ -155,6 +160,7 @@ def test_field_in_range(self): q = Query(select_fields="*").from_("people").where(condition) expected = "select * from people where range(age, 18, 65)" self.assertEqual(q, expected) + return q def test_field_annotation(self): f1 = Queryfield("title") @@ -163,6 +169,7 @@ def test_field_annotation(self): q = Query(select_fields="*").from_("articles").where(annotated_field) expected = "select * from articles where ({highlight:true})title" self.assertEqual(q, expected) + return q def test_condition_annotation(self): f1 = Queryfield("title") @@ -171,12 +178,14 @@ def test_condition_annotation(self): q = Query(select_fields="*").from_("articles").where(annotated_condition) expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, 
expected) + return q def test_grouping_aggregation(self): grouping = G.all(G.group("category"), G.output(G.count())) q = Query(select_fields="*").from_("products").group(grouping) expected = "select * from products | all(group(category) output(count()))" self.assertEqual(q, expected) + return q def test_add_parameter(self): f1 = Queryfield("title") @@ -189,6 +198,7 @@ def test_add_parameter(self): ) expected = 'select * from articles where title contains "Python"&tracelevel=1' self.assertEqual(q, expected) + return q def test_custom_ranking_expression(self): condition = Q.rank( @@ -197,12 +207,14 @@ def test_custom_ranking_expression(self): q = Query(select_fields="*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) + return q def test_wand(self): condition = Q.wand("keywords", {"apple": 10, "banana": 20}) q = Query(select_fields="*").from_("fruits").where(condition) expected = 'select * from fruits where wand(keywords, {"apple":10,"banana":20})' self.assertEqual(q, expected) + return q def test_weakand(self): condition1 = Queryfield("title").contains("Python") @@ -213,12 +225,14 @@ def test_weakand(self): q = Query(select_fields="*").from_("articles").where(condition) expected = 'select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' self.assertEqual(q, expected) + return q - def test_geoLocation(self): + def test_geolocation(self): condition = Q.geoLocation("location_field", 37.7749, -122.4194, "10km") q = Query(select_fields="*").from_("places").where(condition) expected = 'select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' self.assertEqual(q, expected) + return q def test_condition_all_any(self): c1 = Queryfield("f1").contains("v1") @@ -228,6 +242,7 @@ def test_condition_all_any(self): q = 
Query(select_fields="*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' self.assertEqual(q, expected) + return q def test_order_by_with_annotations(self): f1 = "relevance" @@ -243,6 +258,7 @@ def test_order_by_with_annotations(self): 'select * from products order by {"strength":0.5}relevance desc, price asc' ) self.assertEqual(q, expected) + return q def test_field_comparison_methods(self): f1 = Queryfield("age") @@ -250,6 +266,7 @@ def test_field_comparison_methods(self): q = Query(select_fields="*").from_("users").where(condition) expected = "select * from users where age >= 18 and age < 30" self.assertEqual(q, expected) + return q def test_filter_annotation(self): f1 = Queryfield("title") @@ -257,18 +274,21 @@ def test_filter_annotation(self): q = Query(select_fields="*").from_("articles").where(condition) expected = 'select * from articles where ({filter:true})title contains "Python"' self.assertEqual(q, expected) + return q - def test_nonEmpty(self): + def test_non_empty(self): condition = Q.nonEmpty(Queryfield("comments").eq("any_value")) q = Query(select_fields="*").from_("posts").where(condition) expected = 'select * from posts where nonEmpty(comments = "any_value")' self.assertEqual(q, expected) + return q - def test_dotProduct(self): + def test_dotproduct(self): condition = Q.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) q = Query(select_fields="*").from_("vectors").where(condition) expected = 'select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' self.assertEqual(q, expected) + return q def test_in_range_string_values(self): f1 = Queryfield("date") @@ -276,6 +296,7 @@ def test_in_range_string_values(self): q = Query(select_fields="*").from_("events").where(condition) expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" self.assertEqual(q, expected) + return q def 
test_condition_inversion(self): f1 = Queryfield("status") @@ -283,6 +304,7 @@ def test_condition_inversion(self): q = Query(select_fields="*").from_("users").where(condition) expected = 'select * from users where !(status = "inactive")' self.assertEqual(q, expected) + return q def test_multiple_parameters(self): f1 = Queryfield("title") @@ -296,6 +318,7 @@ def test_multiple_parameters(self): ) expected = 'select * from articles where title contains "Python"&tracelevel=1&language=en' self.assertEqual(q, expected) + return q def test_multiple_groupings(self): grouping = G.all( @@ -307,14 +330,16 @@ def test_multiple_groupings(self): q = Query(select_fields="*").from_("products").group(grouping) expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" self.assertEqual(q, expected) + return q - def test_default_index_annotation(self): - condition = Q.userQuery("search terms", index="default_field") + def test_userquery_basic(self): + condition = Q.userQuery("search terms") q = Query(select_fields="*").from_("documents").where(condition) - expected = 'select * from documents where ({"defaultIndex":"default_field"})userQuery("search terms")' + expected = 'select * from documents where userQuery("search terms")' self.assertEqual(q, expected) + return q - def test_Q_p_function(self): + def test_q_p_function(self): condition = Q.p( Queryfield("f1").contains("v1"), Queryfield("f2").contains("v2"), @@ -333,19 +358,22 @@ def test_rank_multiple_conditions(self): q = Query(select_fields="*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' self.assertEqual(q, expected) + return q - def test_nonEmpty_with_annotations(self): + def test_non_empty_with_annotations(self): annotated_field = Queryfield("comments").annotate({"filter": True}) condition = Q.nonEmpty(annotated_field) q = 
Query(select_fields="*").from_("posts").where(condition) expected = "select * from posts where nonEmpty(({filter:true})comments)" self.assertEqual(q, expected) + return q def test_weight_annotation(self): condition = Queryfield("title").contains("heads", annotations={"weight": 200}) q = Query(select_fields="*").from_("s1").where(condition) expected = 'select * from s1 where title contains({weight:200}"heads")' self.assertEqual(q, expected) + return q def test_nearest_neighbor_annotations(self): condition = Q.nearestNeighbor( @@ -354,6 +382,7 @@ def test_nearest_neighbor_annotations(self): q = Query(select_fields=["id, text"]).from_("m").where(condition) expected = "select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) + return q def test_phrase(self): text = Queryfield("text") @@ -361,6 +390,7 @@ def test_phrase(self): query = Q.select("*").where(condition) expected = 'select * from * where text contains phrase("st", "louis", "blues")' self.assertEqual(query, expected) + return query def test_near(self): title = Queryfield("title") @@ -368,6 +398,7 @@ def test_near(self): query = Q.select("*").where(condition) expected = 'select * from * where title contains near("madonna", "saint")' self.assertEqual(query, expected) + return query def test_onear(self): title = Queryfield("title") @@ -375,8 +406,9 @@ def test_onear(self): query = Q.select("*").where(condition) expected = 'select * from * where title contains onear("madonna", "saint")' self.assertEqual(query, expected) + return query - def test_sameElement(self): + def test_same_element(self): persons = Queryfield("persons") first_name = Queryfield("first_name") last_name = Queryfield("last_name") @@ -391,6 +423,7 @@ def test_sameElement(self): query = Q.select("*").where(condition) expected = 'select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' self.assertEqual(query, expected) + 
return query def test_equiv(self): fieldName = Queryfield("fieldName") @@ -398,6 +431,7 @@ def test_equiv(self): query = Q.select("*").where(condition) expected = 'select * from * where fieldName contains equiv("A", "B")' self.assertEqual(query, expected) + return query def test_uri(self): myUrlField = Queryfield("myUrlField") @@ -405,6 +439,7 @@ def test_uri(self): query = Q.select("*").where(condition) expected = 'select * from * where myUrlField contains uri("vespa.ai/foo")' self.assertEqual(query, expected) + return query def test_fuzzy(self): myStringAttribute = Queryfield("myStringAttribute") @@ -415,12 +450,14 @@ def test_fuzzy(self): query = Q.select("*").where(condition) expected = 'select * from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' self.assertEqual(query, expected) + return query - def test_userInput(self): + def test_userinput(self): condition = Q.userInput("@animal") query = Q.select("*").where(condition).param("animal", "panda") expected = "select * from * where userInput(@animal)&animal=panda" self.assertEqual(query, expected) + return query def test_in_operator(self): integer_field = Queryfield("integer_field") @@ -428,6 +465,7 @@ def test_in_operator(self): query = Q.select("*").where(condition) expected = "select * from * where integer_field in (10, 20, 30)" self.assertEqual(query, expected) + return query def test_predicate(self): condition = Q.predicate( @@ -438,18 +476,21 @@ def test_predicate(self): query = Q.select("*").where(condition) expected = 'select * from * where predicate(predicate_field,{"gender":"Female"},{"age":20L})' self.assertEqual(query, expected) + return query def test_true(self): condition = Q.true() query = Q.select("*").where(condition) expected = "select * from * where true" self.assertEqual(query, expected) + return query def test_false(self): condition = Q.false() query = Q.select("*").where(condition) expected = "select * from * where false" self.assertEqual(query, 
expected) + return query if __name__ == "__main__": From cb30290bf448040dab11f64253a7356029506495 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 09:45:25 +0100 Subject: [PATCH 18/39] clean userquery --- vespa/querybuilder/main.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 6feea77c..793d2e34 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -325,16 +325,6 @@ def p(*args): @staticmethod def userQuery(value: str = "") -> Condition: return Condition(f'userQuery("{value}")') if value else Condition("userQuery()") - # else: - # # Both index and value provided - # default_index_json = json.dumps( - # {"defaultIndex": index}, separators=(",", ":") - # ) - # return ( - # Condition(f'({default_index_json}userQuery("{value}"))') - # if value - # else Condition(f"({default_index_json}userQuery())") - # ) @staticmethod def dotProduct( From a336bee6f5f2570842ea215ab0bcd7b2dd9d19b6 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 10:11:50 +0100 Subject: [PATCH 19/39] add userinput --- tests/unit/test_q.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 218751b7..75219198 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -453,9 +453,23 @@ def test_fuzzy(self): return query def test_userinput(self): + condition = Q.userInput("@myvar") + query = Q.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where userInput(@myvar)" + self.assertEqual(query, expected) + return query + + def test_userinput_param(self): condition = Q.userInput("@animal") - query = Q.select("*").where(condition).param("animal", "panda") - expected = "select * from * where userInput(@animal)&animal=panda" + query = Q.select("*").from_("sd1").where(condition).param("animal", "panda") + expected = "select * from sd1 where userInput(@animal)&animal=panda" + 
self.assertEqual(query, expected) + return query + + def test_userinput_with_defaultindex(self): + condition = Q.userInput("@myvar").annotate({"defaultindex": "text"}) + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where ({defaultindex:"text"})userInput(@myvar)' self.assertEqual(query, expected) return query From 10cc0b39c0602db299eb26efe9e17e2cbcaa37a2 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 10:12:04 +0100 Subject: [PATCH 20/39] userinput integration --- tests/integration/test_integration_queries.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 3dfd1e19..bdc2b16f 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -457,3 +457,43 @@ def test_userquery_customindex(self): hit = result.hits[0] self.assertEqual(hit["fields"]["description"], "foo") self.assertEqual(hit["fields"]["text"], "foo") + + def test_userinput(self): + # 'select * from sd1 where userInput(@myvar)' + # Feed test documents + myvar = "panda" + docs = [ + { # Doc 1: Should match + "description": "a panda is a cute", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "you are a cool panda", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_userinput() + print(f"Executing query: {q}") + body = { + "yql": str(q), + "ranking": "bm25", + "myvar": myvar, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in 
result.hits]) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) From 52b609ec7f027dce0608043c7250727d2b82e6c2 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 11:25:53 +0100 Subject: [PATCH 21/39] no parenthesis annotations --- vespa/querybuilder/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 793d2e34..7b57c280 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -160,7 +160,7 @@ def annotate(self, annotations: Dict[str, Any]) -> "Condition": f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) - return Condition(f"({{{annotations_str}}}){self.expression}") + return Condition(f"{{{annotations_str}}}{self.expression}") def build(self) -> str: return self.expression From e1b3cb78df5b9a6df1c4acba15b621316d980854 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 11:26:02 +0100 Subject: [PATCH 22/39] predicate --- tests/integration/test_integration_queries.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index bdc2b16f..2ae81109 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -62,6 +62,7 @@ def setUpClass(cls): name="predicate_field", type="predicate", indexing=["attribute", "summary"], + index="arity: 2", # This is required for predicate fields ), Field( name="myStringAttribute", type="string", indexing=["index", "summary"] @@ -497,3 +498,148 @@ def test_userinput(self): ids = sorted([hit["id"] for hit in result.hits]) self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + + def test_userinput_with_defaultindex(self): + # 'select * from sd1 where 
{defaultIndex:"text"}userInput(@myvar)' + # Feed test documents + myvar = "panda" + docs = [ + { # Doc 1: Should not match + "description": "a panda is a cute", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "you are a cool panda", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_userinput_with_defaultindex() + print(f"Executing query: {q}") + body = { + "yql": str(q), + "ranking": "bm25", + "myvar": myvar, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + print(result.json) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::2") + + def test_in_operator_intfield(self): + # 'select * from * where integer_field in (10, 20, 30)' + # We use age field for this test + # Feed test documents + docs = [ + { # Doc 1: Should match + "age": 10, + }, + { # Doc 2: Should match + "age": 20, + }, + { # Doc 3: Should not match + "age": 31, + }, + { # Doc 4: Should not match + "age": 40, + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_in_operator_intfield() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + print(result.json) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) + 
self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + + def test_in_operator_stringfield(self): + # 'select * from sd1 where status in ("active", "inactive")' + # Feed test documents + docs = [ + { # Doc 1: Should match + "status": "active", + }, + { # Doc 2: Should match + "status": "inactive", + }, + { # Doc 3: Should not match + "status": "foo", + }, + { # Doc 4: Should not match + "status": "bar", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_in_operator_stringfield() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + + def test_predicate(self): + # 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":25L})' + # Feed test documents with predicate_field + docs = [ + { # Doc 1: Should match - satisfies both predicates + "predicate_field": 'gender in ["Female"] and age in [20..30]', + }, + { # Doc 2: Should not match - wrong gender + "predicate_field": 'gender in ["Male"] and age in [20..30]', + }, + { # Doc 3: Should not match - too young + "predicate_field": 'gender in ["Female"] and age in [30..40]', + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + + # Execute query using predicate search + q = qb.test_predicate() + print(f"Executing query: {q}") + + with self.app.syncio() as sess: + result = sess.query(yql=q) + + # Verify only one 
document matches both predicates + self.assertEqual(len(result.hits), 1) + + # Verify matching document has expected id + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") From 21c89546d526e9f5dbb9adad16420138981a8118 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 11:26:13 +0100 Subject: [PATCH 23/39] annotation formatting and predicate --- tests/unit/test_q.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 75219198..a16e7ea9 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -176,7 +176,7 @@ def test_condition_annotation(self): condition = f1.contains("Python") annotated_condition = condition.annotate({"filter": True}) q = Query(select_fields="*").from_("articles").where(annotated_condition) - expected = 'select * from articles where ({filter:true})title contains "Python"' + expected = 'select * from articles where {filter:true}title contains "Python"' self.assertEqual(q, expected) return q @@ -272,7 +272,7 @@ def test_filter_annotation(self): f1 = Queryfield("title") condition = f1.contains("Python").annotate({"filter": True}) q = Query(select_fields="*").from_("articles").where(condition) - expected = 'select * from articles where ({filter:true})title contains "Python"' + expected = 'select * from articles where {filter:true}title contains "Python"' self.assertEqual(q, expected) return q @@ -467,17 +467,25 @@ def test_userinput_param(self): return query def test_userinput_with_defaultindex(self): - condition = Q.userInput("@myvar").annotate({"defaultindex": "text"}) + condition = Q.userInput("@myvar").annotate({"defaultIndex": "text"}) query = Q.select("*").from_("sd1").where(condition) - expected = 'select * from sd1 where ({defaultindex:"text"})userInput(@myvar)' + expected = 'select * from sd1 where {defaultIndex:"text"}userInput(@myvar)' self.assertEqual(query, expected) 
return query - def test_in_operator(self): - integer_field = Queryfield("integer_field") + def test_in_operator_intfield(self): + integer_field = Queryfield("age") condition = integer_field.in_(10, 20, 30) - query = Q.select("*").where(condition) - expected = "select * from * where integer_field in (10, 20, 30)" + query = Q.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where age in (10, 20, 30)" + self.assertEqual(query, expected) + return query + + def test_in_operator_stringfield(self): + string_field = Queryfield("status") + condition = string_field.in_("active", "inactive") + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where status in ("active", "inactive")' self.assertEqual(query, expected) return query @@ -487,22 +495,22 @@ def test_predicate(self): attributes={"gender": "Female"}, range_attributes={"age": "20L"}, ) - query = Q.select("*").where(condition) - expected = 'select * from * where predicate(predicate_field,{"gender":"Female"},{"age":20L})' + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":20L})' self.assertEqual(query, expected) return query def test_true(self): condition = Q.true() - query = Q.select("*").where(condition) - expected = "select * from * where true" + query = Q.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where true" self.assertEqual(query, expected) return query def test_false(self): condition = Q.false() - query = Q.select("*").where(condition) - expected = "select * from * where false" + query = Q.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where false" self.assertEqual(query, expected) return query From 634513c501ca6e63fc01427c6266cce741e61488 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 13:53:02 +0100 Subject: [PATCH 24/39] wand --- tests/unit/test_q.py | 39 
++++++++++++++++++++++++++++++++++---- vespa/querybuilder/main.py | 18 +++++++++++++----- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index a16e7ea9..511e10dc 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -212,7 +212,30 @@ def test_custom_ranking_expression(self): def test_wand(self): condition = Q.wand("keywords", {"apple": 10, "banana": 20}) q = Query(select_fields="*").from_("fruits").where(condition) - expected = 'select * from fruits where wand(keywords, {"apple":10,"banana":20})' + expected = ( + 'select * from fruits where wand(keywords, {"apple":10, "banana":20})' + ) + self.assertEqual(q, expected) + return q + + def test_wand_numeric(self): + condition = Q.wand("description", [[11, 1], [37, 2]]) + q = Query(select_fields="*").from_("fruits").where(condition) + expected = "select * from fruits where wand(description, [[11, 1], [37, 2]])" + self.assertEqual(q, expected) + return q + + def test_wand_annotations(self): + self.maxDiff = None + condition = Q.wand( + "description", + weights={"a": 1, "b": 2}, + annotations={"scoreThreshold": 0.13, "targetHits": 7}, + ) + q = Query(select_fields="*").from_("fruits").where(condition) + expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a":1, "b":2}))' + print(q) + print(expected) self.assertEqual(q, expected) return q @@ -223,7 +246,7 @@ def test_weakand(self): condition1, condition2, annotations={"targetNumHits": 100} ) q = Query(select_fields="*").from_("articles").where(condition) - expected = 'select * from articles where ({"targetNumHits":100}weakAnd(title contains "Python", description contains "Programming"))' + expected = 'select * from articles where ({"targetNumHits": 100}weakAnd(title contains "Python", description contains "Programming"))' self.assertEqual(q, expected) return q @@ -260,14 +283,22 @@ def test_order_by_with_annotations(self): self.assertEqual(q, expected) return 
q - def test_field_comparison_methods(self): + def test_field_comparison_methods_builtins(self): f1 = Queryfield("age") - condition = f1.ge(18) & f1.lt(30) + condition = (f1 >= 18) & (f1 < 30) q = Query(select_fields="*").from_("users").where(condition) expected = "select * from users where age >= 18 and age < 30" self.assertEqual(q, expected) return q + def test_field_comparison_methods(self): + f1 = Queryfield("age") + condition = (f1.ge(18) & f1.lt(30)) | f1.eq(40) + q = Query(select_fields="*").from_("users").where(condition) + expected = "select * from users where (age >= 18 and age < 30) or age = 40" + self.assertEqual(q, expected) + return q + def test_filter_annotation(self): f1 = Queryfield("title") condition = f1.contains("Python").annotate({"filter": True}) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 7b57c280..1b1d8335 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -24,6 +24,12 @@ def __gt__(self, other: Any) -> "Condition": def __ge__(self, other: Any) -> "Condition": return Condition(f"{self.name} >= {self._format_value(other)}") + def __and__(self, other: Any) -> "Condition": + return Condition(f"{self.name} and {self._format_value(other)}") + + def __or__(self, other: Any) -> "Condition": + return Condition(f"{self.name} or {self._format_value(other)}") + def contains( self, value: Any, annotations: Optional[Dict[str, Any]] = None ) -> "Condition": @@ -367,15 +373,17 @@ def wand( field: str, weights, annotations: Optional[Dict[str, Any]] = None ) -> Condition: if isinstance(weights, list): - weights_str = "[" + ",".join(str(item) for item in weights) + "]" + weights_str = "[" + ", ".join(str(item) for item in weights) + "]" elif isinstance(weights, dict): - weights_str = "{" + ",".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + weights_str = ( + "{" + ", ".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + ) else: raise ValueError("Invalid weights for wand") expr = f"wand({field}, 
{weights_str})" if annotations: - annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + annotations_str = ", ".join( + f"{k}: {Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -387,7 +395,7 @@ def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: expr = f"weakAnd({conditions_str})" if annotations: annotations_str = ",".join( - f'"{k}":{Queryfield._format_annotation_value(v)}' + f'"{k}": {Queryfield._format_annotation_value(v)}' for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" From f791d4932ae064bed0f07ab73afac8c76755fe3e Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 13:56:53 +0100 Subject: [PATCH 25/39] timeout --- tests/unit/test_q.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 511e10dc..63b02cb8 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -71,6 +71,14 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): self.assertEqual(q, expected) return q + def test_timeout(self): + f1 = Queryfield("title") + condition = f1.contains("madonna") + q = Query(select_fields="*").from_("sd1").where(condition).set_timeout(70) + expected = 'select * from sd1 where title contains "madonna" timeout 70' + self.assertEqual(q, expected) + return q + def test_matches(self): condition = ( (Queryfield("f1").matches("v1") & Queryfield("f2").matches("v2")) From a191796e8ead8ba323665ca08513c0fd32319c4b Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 2 Dec 2024 14:17:56 +0100 Subject: [PATCH 26/39] exclude markdown files in resources --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 67503190..ccbcdab1 100644 --- a/.gitignore +++ b/.gitignore @@ -150,4 +150,7 @@ markdown_before/ # All .pem and .crt files *.pem -*.crt \ No newline at end of file +*.crt + +# Exclude 
markdown files in vespa/resources - directory +vespa/resources/*.md \ No newline at end of file From 86c85afbc3e95617bc07487efb6fc12276f49cea Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 4 Dec 2024 14:57:34 +0100 Subject: [PATCH 27/39] add fuzzy integration test --- tests/integration/test_integration_queries.py | 47 ++++++++++++++----- tests/unit/test_q.py | 6 +-- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 2ae81109..bba8e435 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -21,9 +21,7 @@ def setUpClass(cls): application_name = "querybuilder" cls.application_name = application_name schema_name1 = "sd1" - schema_name2 = "sd2" cls.schema_name1 = schema_name1 - cls.schema_name2 = schema_name2 # Define all fields used in the unit tests # Schema 1 fields = [ @@ -118,23 +116,17 @@ def setUpClass(cls): ] fieldset = FieldSet(name="default", fields=["text", "title", "description"]) document = Document(fields=fields, structs=[email_struct]) - schema1 = Schema( + schema = Schema( name=schema_name1, document=document, rank_profiles=rank_profiles, fieldsets=[fieldset], ) - schema1.add_fields(emails_field) - ## Schema 2 - schema2 = Schema( - name=schema_name2, document=document, rank_profiles=rank_profiles - ) + schema.add_fields(emails_field) + # Create the application package - application_package = ApplicationPackage( - name=application_name, schema=[schema1, schema2] - ) + application_package = ApplicationPackage(name=application_name, schema=[schema]) print(application_package.get_schema(schema_name1).schema_to_text) - print(application_package.get_schema(schema_name2).schema_to_text) # Deploy the application cls.vespa_docker = VespaDocker(port=8089) cls.app = cls.vespa_docker.deploy(application_package=application_package) @@ -643,3 +635,34 @@ def test_predicate(self): # Verify 
matching document has expected id hit = result.hits[0] self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") + + def test_fuzzy(self): + # 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + # Feed test documents + docs = [ + { # Doc 1: Should match + "f1": "parantesis", + }, + { # Doc 2: Should match - edit distance 1 + "f1": "paranthesis", + }, + { # Doc 3: Should not match - edit distance 3 + "f1": "parrenthesis", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_fuzzy() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) + self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 63b02cb8..b20833a5 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -481,13 +481,13 @@ def test_uri(self): return query def test_fuzzy(self): - myStringAttribute = Queryfield("myStringAttribute") + myStringAttribute = Queryfield("f1") annotations = {"prefixLength": 1, "maxEditDistance": 2} condition = myStringAttribute.contains( Q.fuzzy("parantesis", annotations=annotations) ) - query = Q.select("*").where(condition) - expected = 'select * from * where myStringAttribute contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' self.assertEqual(query, expected) return query From 
0d896a18aa7823ed7f6a1b5559e0e10e723ccc64 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 4 Dec 2024 15:02:06 +0100 Subject: [PATCH 28/39] add uri integration test --- tests/integration/test_integration_queries.py | 30 +++++++++++++++++++ tests/unit/test_q.py | 4 +-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index bba8e435..4dd642e7 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -666,3 +666,33 @@ def test_fuzzy(self): ids = sorted([hit["id"] for hit in result.hits]) self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + + def test_uri(self): + # 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' + # Feed test documents + docs = [ + { # Doc 1: Should match + "myUrlField": "https://vespa.ai/foo", + }, + { # Doc 2: Should not match - wrong path + "myUrlField": "https://vespa.ai/bar", + }, + { # Doc 3: Should not match - wrong domain + "myUrlField": "https://google.com/foo", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_uri() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index b20833a5..59c46ccb 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -475,8 +475,8 @@ def test_equiv(self): def test_uri(self): myUrlField = Queryfield("myUrlField") condition = 
myUrlField.contains(Q.uri("vespa.ai/foo")) - query = Q.select("*").where(condition) - expected = 'select * from * where myUrlField contains uri("vespa.ai/foo")' + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' self.assertEqual(query, expected) return query From 0a53f301c392553db0cefc5663b159f7f43b652c Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 4 Dec 2024 15:09:36 +0100 Subject: [PATCH 29/39] sameElement integration test --- tests/integration/test_integration_queries.py | 58 ++++++++++++++++++- tests/unit/test_q.py | 4 +- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 4dd642e7..e247df48 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -99,6 +99,24 @@ def setUpClass(cls): ) ], ) + person_struct = Struct( + name="person", + fields=[ + Field(name="first_name", type="string"), + Field(name="last_name", type="string"), + Field(name="year_of_birth", type="int"), + ], + ) + persons_field = Field( + name="persons", + type="array", + indexing=["summary"], + struct_fields=[ + StructField(name="first_name", indexing=["attribute"]), + StructField(name="last_name", indexing=["attribute"]), + StructField(name="year_of_birth", indexing=["attribute"]), + ], + ) rank_profiles = [ RankProfile( name="dotproduct", @@ -115,14 +133,14 @@ def setUpClass(cls): ), ] fieldset = FieldSet(name="default", fields=["text", "title", "description"]) - document = Document(fields=fields, structs=[email_struct]) + document = Document(fields=fields, structs=[email_struct, person_struct]) schema = Schema( name=schema_name1, document=document, rank_profiles=rank_profiles, fieldsets=[fieldset], ) - schema.add_fields(emails_field) + schema.add_fields(emails_field, persons_field) # Create the application package application_package = 
ApplicationPackage(name=application_name, schema=[schema]) @@ -696,3 +714,39 @@ def test_uri(self): # Verify matching document has expected values hit = result.hits[0] self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") + + def test_same_element(self): + # 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + # Feed test documents + docs = [ + { # Doc 1: Should match + "persons": [ + {"first_name": "Joe", "last_name": "Smith", "year_of_birth": 1930} + ], + }, + { # Doc 2: Should not match - wrong last name + "persons": [ + {"first_name": "Joe", "last_name": "Johnson", "year_of_birth": 1930} + ], + }, + { # Doc 3: Should not match - wrong year of birth + "persons": [ + {"first_name": "Joe", "last_name": "Smith", "year_of_birth": 1940} + ], + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name1) + # Execute query + q = qb.test_same_element() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 59c46ccb..ff95fe83 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -459,8 +459,8 @@ def test_same_element(self): year_of_birth < 1940, ) ) - query = Q.select("*").where(condition) - expected = 'select * from * where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + query = Q.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", 
year_of_birth < 1940)' self.assertEqual(query, expected) return query From 7d5b5df0bc0c6cbe33c5087808be9145c09b99e8 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Wed, 4 Dec 2024 15:25:51 +0100 Subject: [PATCH 30/39] near and onear with distance --- tests/unit/test_q.py | 16 ++++++++++++++++ vespa/querybuilder/main.py | 20 +++++++++++++++++--- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index ff95fe83..5277cf95 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -439,6 +439,14 @@ def test_near(self): self.assertEqual(query, expected) return query + def test_near_with_distance(self): + title = Queryfield("title") + condition = title.contains(Q.near("madonna", "saint", distance=10)) + query = Q.select("*").where(condition) + expected = 'select * from * where title contains ({distance:10}near("madonna", "saint"))' + self.assertEqual(query, expected) + return query + def test_onear(self): title = Queryfield("title") condition = title.contains(Q.onear("madonna", "saint")) @@ -447,6 +455,14 @@ def test_onear(self): self.assertEqual(query, expected) return query + def test_onear_with_distance(self): + title = Queryfield("title") + condition = title.contains(Q.onear("madonna", "saint", distance=5)) + query = Q.select("*").where(condition) + expected = 'select * from * where title contains ({distance:5}onear("madonna", "saint"))' + self.assertEqual(query, expected) + return query + def test_same_element(self): persons = Queryfield("persons") first_name = Queryfield("first_name") diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 1b1d8335..de94084f 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -450,11 +450,18 @@ def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: return Condition(expr) @staticmethod - def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + def near( + *terms, annotations: 
Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: terms_str = ", ".join(f'"{term}"' for term in terms) expr = f"near({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) if annotations: - annotations_str = ",".join( + annotations_str = ", ".join( f"{k}:{Queryfield._format_annotation_value(v)}" for k, v in annotations.items() ) @@ -462,9 +469,16 @@ def near(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: return Condition(expr) @staticmethod - def onear(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + def onear( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: terms_str = ", ".join(f'"{term}"' for term in terms) expr = f"onear({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) if annotations: annotations_str = ",".join( f"{k}:{Queryfield._format_annotation_value(v)}" From 2ac91cc70a441c01ad60dec9a4e77c5caf0cb361 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 09:37:31 +0100 Subject: [PATCH 31/39] passing grouping --- tests/integration/test_integration_queries.py | 357 ++++++++++++++++-- tests/unit/test_q.py | 25 +- vespa/querybuilder/main.py | 22 +- 3 files changed, 348 insertions(+), 56 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index e247df48..3d105361 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -1,4 +1,5 @@ import unittest +import requests from vespa.deployment import VespaDocker from vespa.package import ( ApplicationPackage, @@ -20,8 +21,8 @@ class TestQueriesIntegration(unittest.TestCase): def setUpClass(cls): application_name = "querybuilder" cls.application_name = application_name - schema_name1 = "sd1" - cls.schema_name1 = schema_name1 + schema_name = "sd1" 
+ cls.schema_name = schema_name # Define all fields used in the unit tests # Schema 1 fields = [ @@ -135,16 +136,62 @@ def setUpClass(cls): fieldset = FieldSet(name="default", fields=["text", "title", "description"]) document = Document(fields=fields, structs=[email_struct, person_struct]) schema = Schema( - name=schema_name1, + name=schema_name, document=document, rank_profiles=rank_profiles, fieldsets=[fieldset], ) schema.add_fields(emails_field, persons_field) + # Add purchase schema for grouping test + # schema purchase { + # document purchase { + + # field date type long { + # indexing: summary | attribute + # } + + # field price type int { + # indexing: summary | attribute + # } + + # field tax type double { + # indexing: summary | attribute + # } + + # field item type string { + # indexing: summary | attribute + # } + + # field customer type string { + # indexing: summary | attribute + # } + + # } + purchase_schema = Schema( + name="purchase", + document=Document( + fields=[ + Field(name="date", type="long", indexing=["summary", "attribute"]), + Field(name="price", type="int", indexing=["summary", "attribute"]), + Field(name="tax", type="double", indexing=["summary", "attribute"]), + Field( + name="item", type="string", indexing=["summary", "attribute"] + ), + Field( + name="customer", + type="string", + indexing=["summary", "attribute"], + ), + ] + ), + ) # Create the application package - application_package = ApplicationPackage(name=application_name, schema=[schema]) - print(application_package.get_schema(schema_name1).schema_to_text) + application_package = ApplicationPackage( + name=application_name, schema=[schema, purchase_schema] + ) + print(application_package.get_schema(schema_name).schema_to_text) + print(application_package.get_schema("purchase").schema_to_text) # Deploy the application cls.vespa_docker = VespaDocker(port=8089) cls.app = cls.vespa_docker.deploy(application_package=application_package) @@ -161,7 +208,7 @@ def 
test_dotProduct_with_annotations(self): fields = {field: {"feature1": 2, "feature2": 4}} data_id = 1 self.app.feed_data_point( - schema=self.schema_name1, data_id=data_id, fields=fields + schema=self.schema_name, data_id=data_id, fields=fields ) q = qb.test_dotProduct_with_annotations() with self.app.syncio() as sess: @@ -170,7 +217,7 @@ def test_dotProduct_with_annotations(self): self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.schema_name1}:{self.schema_name1}::{data_id}", + f"id:{self.schema_name}:{self.schema_name}::{data_id}", ) self.assertEqual( result.hits[0]["fields"]["summaryfeatures"]["rawScore(weightedset_field)"], @@ -188,7 +235,7 @@ def test_geolocation_with_annotations(self): } data_id = 2 self.app.feed_data_point( - schema=self.schema_name1, data_id=data_id, fields=fields + schema=self.schema_name, data_id=data_id, fields=fields ) # Build and send the query q = qb.test_geolocation_with_annotations() @@ -198,7 +245,7 @@ def test_geolocation_with_annotations(self): self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.schema_name1}:{self.schema_name1}::{data_id}", + f"id:{self.schema_name}:{self.schema_name}::{data_id}", ) self.assertAlmostEqual( result.hits[0]["fields"]["summaryfeatures"]["distance(location_field).km"], @@ -244,7 +291,7 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): # Feed documents docs = [{"id": data_id, "fields": doc} for data_id, doc in enumerate(docs, 1)] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Build and send query q = qb.test_basic_and_andnot_or_offset_limit_param_order_by_and_contains() @@ -261,7 +308,7 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): # The query orders by age desc, duration asc with offset 1 # So we should get doc ID 2 (since doc ID 3 is skipped due to offset) hit = result.hits[0] - 
self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::2") + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::2") # Verify the matching document has expected field values self.assertEqual(hit["fields"]["age"], 20) @@ -313,7 +360,7 @@ def test_matches(self): ] # Feed documents - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Build and send query q = qb.test_matches() @@ -330,8 +377,8 @@ def test_matches(self): ids = sorted([hit["id"] for hit in result.hits]) expected_ids = sorted( [ - f"id:{self.schema_name1}:{self.schema_name1}::1", - f"id:{self.schema_name1}:{self.schema_name1}::3", + f"id:{self.schema_name}:{self.schema_name}::1", + f"id:{self.schema_name}:{self.schema_name}::3", ] ) @@ -371,7 +418,7 @@ def test_nested_queries(self): } for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) q = qb.test_nested_queries() print(f"Executing query: {q}") with self.app.syncio() as sess: @@ -380,7 +427,7 @@ def test_nested_queries(self): self.assertEqual(len(result.hits), 1) self.assertEqual( result.hits[0]["id"], - f"id:{self.schema_name1}:{self.schema_name1}::2", + f"id:{self.schema_name}:{self.schema_name}::2", ) def test_userquery_defaultindex(self): @@ -405,7 +452,7 @@ def test_userquery_defaultindex(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_userquery() @@ -419,8 +466,8 @@ def test_userquery_defaultindex(self): result = sess.query(body=body) self.assertEqual(len(result.hits), 2) ids = sorted([hit["id"] for hit in result.hits]) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) - 
self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) def test_userquery_customindex(self): # 'select * from sd1 where userQuery())' @@ -444,7 +491,7 @@ def test_userquery_customindex(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_userquery() @@ -461,7 +508,7 @@ def test_userquery_customindex(self): # Verify only one document matches both conditions self.assertEqual(len(result.hits), 1) self.assertEqual( - result.hits[0]["id"], f"id:{self.schema_name1}:{self.schema_name1}::1" + result.hits[0]["id"], f"id:{self.schema_name}:{self.schema_name}::1" ) # Verify matching document has expected values @@ -491,7 +538,7 @@ def test_userinput(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_userinput() print(f"Executing query: {q}") @@ -506,8 +553,8 @@ def test_userinput(self): self.assertEqual(len(result.hits), 2) # Verify matching documents have expected values ids = sorted([hit["id"] for hit in result.hits]) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) def test_userinput_with_defaultindex(self): # 'select * from sd1 where {defaultIndex:"text"}userInput(@myvar)' @@ -531,7 +578,7 @@ def test_userinput_with_defaultindex(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] 
- self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_userinput_with_defaultindex() print(f"Executing query: {q}") @@ -547,7 +594,7 @@ def test_userinput_with_defaultindex(self): self.assertEqual(len(result.hits), 1) # Verify matching document has expected values hit = result.hits[0] - self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::2") + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::2") def test_in_operator_intfield(self): # 'select * from * where integer_field in (10, 20, 30)' @@ -571,7 +618,7 @@ def test_in_operator_intfield(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_in_operator_intfield() print(f"Executing query: {q}") @@ -582,8 +629,8 @@ def test_in_operator_intfield(self): self.assertEqual(len(result.hits), 2) # Verify matching documents have expected values ids = sorted([hit["id"] for hit in result.hits]) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) def test_in_operator_stringfield(self): # 'select * from sd1 where status in ("active", "inactive")' @@ -606,7 +653,7 @@ def test_in_operator_stringfield(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_in_operator_stringfield() print(f"Executing query: {q}") @@ -616,8 +663,8 @@ def test_in_operator_stringfield(self): 
self.assertEqual(len(result.hits), 2) # Verify matching documents have expected values ids = sorted([hit["id"] for hit in result.hits]) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) def test_predicate(self): # 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":25L})' @@ -638,7 +685,7 @@ def test_predicate(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query using predicate search q = qb.test_predicate() @@ -652,7 +699,7 @@ def test_predicate(self): # Verify matching document has expected id hit = result.hits[0] - self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") def test_fuzzy(self): # 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' @@ -672,7 +719,7 @@ def test_fuzzy(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_fuzzy() print(f"Executing query: {q}") @@ -682,8 +729,8 @@ def test_fuzzy(self): self.assertEqual(len(result.hits), 2) # Verify matching documents have expected values ids = sorted([hit["id"] for hit in result.hits]) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::1", ids) - self.assertIn(f"id:{self.schema_name1}:{self.schema_name1}::2", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + 
self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) def test_uri(self): # 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' @@ -703,7 +750,7 @@ def test_uri(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_uri() print(f"Executing query: {q}") @@ -713,7 +760,7 @@ def test_uri(self): self.assertEqual(len(result.hits), 1) # Verify matching document has expected values hit = result.hits[0] - self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") def test_same_element(self): # 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' @@ -739,7 +786,7 @@ def test_same_element(self): docs = [ {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) ] - self.app.feed_iterable(iter=docs, schema=self.schema_name1) + self.app.feed_iterable(iter=docs, schema=self.schema_name) # Execute query q = qb.test_same_element() print(f"Executing query: {q}") @@ -749,4 +796,232 @@ def test_same_element(self): self.assertEqual(len(result.hits), 1) # Verify matching document has expected values hit = result.hits[0] - self.assertEqual(hit["id"], f"id:{self.schema_name1}:{self.schema_name1}::1") + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") + + def test_grouping(self): + # "select * from purchase | all(group(customer) each(output(sum(price))))" + # sample data from https://github.com/vespa-cloud/vespa-documentation-search#feed-grouping-examples + sample_data = [ + { + "fields": { + "customer": "Smith", + "date": 1157526000, + "item": "Intake valve", + "price": "1000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::0", + }, + { + "fields": { + 
"customer": "Smith", + "date": 1157616000, + "item": "Rocker arm", + "price": "1000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::1", + }, + { + "fields": { + "customer": "Smith", + "date": 1157619600, + "item": "Spring", + "price": "2000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::2", + }, + { + "fields": { + "customer": "Jones", + "date": 1157709600, + "item": "Valve cover", + "price": "3000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::3", + }, + { + "fields": { + "customer": "Jones", + "date": 1157702400, + "item": "Intake port", + "price": "5000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::4", + }, + { + "fields": { + "customer": "Brown", + "date": 1157706000, + "item": "Head", + "price": "8000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::5", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Coolant", + "price": "1300", + "tax": "0.24", + }, + "put": "id:purchase:purchase::6", + }, + { + "fields": { + "customer": "Jones", + "date": 1157788800, + "item": "Engine block", + "price": "2100", + "tax": "0.12", + }, + "put": "id:purchase:purchase::7", + }, + { + "fields": { + "customer": "Brown", + "date": 1157792400, + "item": "Oil pan", + "price": "3400", + "tax": "0.24", + }, + "put": "id:purchase:purchase::8", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Oil sump", + "price": "5500", + "tax": "0.12", + }, + "put": "id:purchase:purchase::9", + }, + { + "fields": { + "customer": "Jones", + "date": 1157875200, + "item": "Camshaft", + "price": "8900", + "tax": "0.24", + }, + "put": "id:purchase:purchase::10", + }, + { + "fields": { + "customer": "Brown", + "date": 1157878800, + "item": "Exhaust valve", + "price": "1440", + "tax": "0.12", + }, + "put": "id:purchase:purchase::11", + }, + { + "fields": { + "customer": "Brown", + "date": 1157882400, + "item": "Rocker arm", + "price": "2330", + "tax": "0.24", + }, + "put": "id:purchase:purchase::12", + }, + 
{ + "fields": { + "customer": "Brown", + "date": 1157875200, + "item": "Spring", + "price": "3770", + "tax": "0.12", + }, + "put": "id:purchase:purchase::13", + }, + { + "fields": { + "customer": "Smith", + "date": 1157878800, + "item": "Spark plug", + "price": "6100", + "tax": "0.24", + }, + "put": "id:purchase:purchase::14", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Exhaust port", + "price": "9870", + "tax": "0.12", + }, + "put": "id:purchase:purchase::15", + }, + { + "fields": { + "customer": "Brown", + "date": 1157961600, + "item": "Piston", + "price": "1597", + "tax": "0.24", + }, + "put": "id:purchase:purchase::16", + }, + { + "fields": { + "customer": "Smith", + "date": 1157965200, + "item": "Connection rod", + "price": "2584", + "tax": "0.12", + }, + "put": "id:purchase:purchase::17", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Rod bearing", + "price": "4181", + "tax": "0.24", + }, + "put": "id:purchase:purchase::18", + }, + { + "fields": { + "customer": "Jones", + "date": 1157972400, + "item": "Crankshaft", + "price": "6765", + "tax": "0.12", + }, + "put": "id:purchase:purchase::19", + }, + ] + # map data to correct format + sample_data = [ + {"fields": doc["fields"], "id": doc["put"].split("::")[-1]} + for doc in sample_data + ] + # Feed documents + self.app.feed_iterable(iter=sample_data, schema="purchase") + # Execute query + q = qb.test_grouping_with_condition() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + result_children = result.json["root"]["children"][0]["children"] + # also get result from https://api.search.vespa.ai/search/?yql=select%20*%20from%20purchase%20where%20true%20%7C%20all(%20group(customer)%20each(output(sum(price)))%20) + # to compare + api_resp = requests.get( + "https://api.search.vespa.ai/search/?yql=select%20*%20from%20purchase%20where%20true%20%7C%20all(%20group(customer)%20each(output(sum(price)))%20)", + 
) + api_resp = api_resp.json() + api_children = api_resp["root"]["children"][0]["children"] + self.maxDiff = None + self.assertEqual(result_children, api_children) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 5277cf95..a70ec8a8 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -188,10 +188,17 @@ def test_condition_annotation(self): self.assertEqual(q, expected) return q - def test_grouping_aggregation(self): - grouping = G.all(G.group("category"), G.output(G.count())) - q = Query(select_fields="*").from_("products").group(grouping) - expected = "select * from products | all(group(category) output(count()))" + def test_grouping_with_condition(self): + self.maxDiff = None + grouping = G.all(G.group("customer"), G.each(G.output(G.sum("price")))) + q = ( + Query(select_fields="*") + .from_("purchase") + .where(True) + .set_limit(0) + .groupby(grouping) + ) + expected = "select * from purchase where true limit 0 | all(group(customer) each(output(sum(price))))" self.assertEqual(q, expected) return q @@ -362,11 +369,11 @@ def test_multiple_parameters(self): def test_multiple_groupings(self): grouping = G.all( G.group("category"), - G.maxRtn(10), + G.max(10), G.output(G.count()), G.each(G.group("subcategory"), G.output(G.summary())), ) - q = Query(select_fields="*").from_("products").group(grouping) + q = Query(select_fields="*").from_("products").groupby(grouping) expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" self.assertEqual(q, expected) return q @@ -556,15 +563,13 @@ def test_predicate(self): return query def test_true(self): - condition = Q.true() - query = Q.select("*").from_("sd1").where(condition) + query = Q.select("*").from_("sd1").where(True) expected = "select * from sd1 where true" self.assertEqual(query, expected) return query def test_false(self): - condition = Q.false() - query = Q.select("*").from_("sd1").where(condition) + query = 
Q.select("*").from_("sd1").where(False) expected = "select * from sd1 where false" self.assertEqual(query, expected) return query diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index de94084f..14208e5c 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -234,9 +234,11 @@ def from_(self, *sources: str) -> "Query": self.sources = ", ".join(sources) return self - def where(self, condition: Union[Condition, Queryfield]) -> "Query": + def where(self, condition: Union[Condition, Queryfield, bool]) -> "Query": if isinstance(condition, Queryfield): self.condition = condition + elif isinstance(condition, bool): + self.condition = Condition("true") if condition else Condition("false") else: self.condition = condition return self @@ -287,7 +289,7 @@ def add_parameter(self, key: str, value: Any) -> "Query": def param(self, key: str, value: Any) -> "Query": return self.add_parameter(key, value) - def group(self, group_expression: str) -> "Query": + def groupby(self, group_expression: str) -> "Query": self.grouping = group_expression return self @@ -297,8 +299,6 @@ def build(self, prepend_yql=False) -> str: query = f"yql={query}" if self.condition: query += f" where {self.condition.build()}" - if self.grouping: - query += f" | {self.grouping}" if self.order_by_clauses: query += " order by " + ", ".join(self.order_by_clauses) if self.limit_value is not None: @@ -307,6 +307,8 @@ def build(self, prepend_yql=False) -> str: query += f" offset {self.offset_value}" if self.timeout_value is not None: query += f" timeout {self.timeout_value}" + if self.grouping: + query += f" | {self.grouping}" if self.parameters: params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) query += params @@ -579,9 +581,19 @@ def group(field: str) -> str: return f"group({field})" @staticmethod - def maxRtn(value: int) -> str: + def max(value: int) -> str: return f"max({value})" + # min + @staticmethod + def min(value: int) -> str: + return 
f"min({value})" + + # sum + @staticmethod + def sum(value: int) -> str: + return f"sum({value})" + @staticmethod def each(*args) -> str: return "each(" + " ".join(args) + ")" From 69693663bb0939881a05e026cf360f130e1b3b27 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 09:49:09 +0100 Subject: [PATCH 32/39] verify results of grouping --- tests/integration/test_integration_queries.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 3d105361..6993d5d8 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -1023,5 +1023,15 @@ def test_grouping(self): ) api_resp = api_resp.json() api_children = api_resp["root"]["children"][0]["children"] - self.maxDiff = None self.assertEqual(result_children, api_children) + # Verify the result + group_results = result_children[0]["children"] + self.assertEqual(group_results[0]["id"], "group:string:Brown") + self.assertEqual(group_results[0]["value"], "Brown") + self.assertEqual(group_results[0]["fields"]["sum(price)"], 20537) + self.assertEqual(group_results[1]["id"], "group:string:Jones") + self.assertEqual(group_results[1]["value"], "Jones") + self.assertEqual(group_results[1]["fields"]["sum(price)"], 39816) + self.assertEqual(group_results[2]["id"], "group:string:Smith") + self.assertEqual(group_results[2]["value"], "Smith") + self.assertEqual(group_results[2]["fields"]["sum(price)"], 19484) From a82b411c781e653d98a48417da421b75bd3a1c79 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 13:43:43 +0100 Subject: [PATCH 33/39] more grouping methods --- vespa/querybuilder/main.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 14208e5c..46bf35d3 100644 --- a/vespa/querybuilder/main.py +++ 
b/vespa/querybuilder/main.py @@ -581,19 +581,33 @@ def group(field: str) -> str: return f"group({field})" @staticmethod - def max(value: int) -> str: + def max(value: Union[int, float]) -> str: return f"max({value})" - # min @staticmethod - def min(value: int) -> str: + def precision(value: int) -> str: + return f"precision({value})" + + @staticmethod + def min(value: Union[int, float]) -> str: return f"min({value})" - # sum @staticmethod - def sum(value: int) -> str: + def sum(value: Union[int, float]) -> str: return f"sum({value})" + @staticmethod + def avg(value: Union[int, float]) -> str: + return f"avg({value})" + + @staticmethod + def stddev(value: Union[int, float]) -> str: + return f"stddev({value})" + + @staticmethod + def xor(value: str) -> str: + return f"xor({value})" + @staticmethod def each(*args) -> str: return "each(" + " ".join(args) + ")" @@ -604,7 +618,19 @@ def output(output_func: str) -> str: @staticmethod def count() -> str: - return "count()" + # Also need to handle negative count + class MaybenegativeCount(str): + def __new__(cls, value): + return super().__new__(cls, value) + + def __neg__(self): + return f"-{self}" + + return MaybenegativeCount("count()") + + @staticmethod + def order(value: str) -> str: + return f"order({value})" @staticmethod def summary() -> str: From d4997ee218d8c5f0207c7c5dcb6ab3508eb400df Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 13:43:58 +0100 Subject: [PATCH 34/39] more grouping unit tests --- tests/unit/test_q.py | 64 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index a70ec8a8..843b18ba 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -189,7 +189,6 @@ def test_condition_annotation(self): return q def test_grouping_with_condition(self): - self.maxDiff = None grouping = G.all(G.group("customer"), G.each(G.output(G.sum("price")))) q = ( Query(select_fields="*") @@ -202,6 
+201,69 @@ def test_grouping_with_condition(self): self.assertEqual(q, expected) return q + def test_grouping_with_ordering_and_limiting(self): + self.maxDiff = None + grouping = G.all( + G.group("customer"), + G.max(2), + G.precision(12), + G.order(-G.count()), + G.each(G.output(G.sum("price"))), + ) + q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" + self.assertEqual(q, expected) + return q + + def test_grouping_with_map_keys(self): + grouping = G.all( + G.group("mymap.key"), + G.each(G.group("mymap.value"), G.each(G.output(G.count()))), + ) + q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(mymap.key) each(group(mymap.value) each(output(count()))))" + self.assertEqual(q, expected) + return q + + def test_group_by_year(self): + grouping = G.all(G.group("time.year(a)"), G.each(G.output(G.count()))) + q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(time.year(a)) each(output(count())))" + self.assertEqual(q, expected) + return q + + def test_grouping_with_date_agg(self): + grouping = G.all( + G.group("time.year(a)"), + G.each( + G.output(G.count()), + G.all( + G.group("time.monthofyear(a)"), + G.each( + G.output(G.count()), + G.all( + G.group("time.dayofmonth(a)"), + G.each( + G.output(G.count()), + G.all( + G.group("time.hourofday(a)"), + G.each(G.output(G.count())), + ), + ), + ), + ), + ), + ), + ) + q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count())))))))))" + 
# q select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) + # e select * from purchase where true | all(group(time.year(a)) each(output(count() all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) + print(q) + print(expected) + self.assertEqual(q, expected) + return q + def test_add_parameter(self): f1 = Queryfield("title") condition = f1.contains("Python") From 940268bd02a56be4146c9fd0e20d6db22621d2f4 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 13:44:16 +0100 Subject: [PATCH 35/39] more grouping integration tests --- tests/integration/test_integration_queries.py | 476 +++++++++--------- 1 file changed, 240 insertions(+), 236 deletions(-) diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py index 6993d5d8..80107f1c 100644 --- a/tests/integration/test_integration_queries.py +++ b/tests/integration/test_integration_queries.py @@ -143,31 +143,6 @@ def setUpClass(cls): ) schema.add_fields(emails_field, persons_field) # Add purchase schema for grouping test - # schema purchase { - - # document purchase { - - # field date type long { - # indexing: summary | attribute - # } - - # field price type int { - # indexing: summary | attribute - # } - - # field tax type double { - # indexing: summary | attribute - # } - - # field item type string { - # indexing: summary | attribute - # } - - # field customer type string { - # indexing: summary | attribute - # } - - # } purchase_schema = Schema( name="purchase", document=Document( @@ -202,6 +177,221 @@ def tearDownClass(cls): cls.vespa_docker.container.stop(timeout=5) cls.vespa_docker.container.remove() + @property + def sample_grouping_data(self): + sample_data = [ + 
{ + "fields": { + "customer": "Smith", + "date": 1157526000, + "item": "Intake valve", + "price": "1000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::0", + }, + { + "fields": { + "customer": "Smith", + "date": 1157616000, + "item": "Rocker arm", + "price": "1000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::1", + }, + { + "fields": { + "customer": "Smith", + "date": 1157619600, + "item": "Spring", + "price": "2000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::2", + }, + { + "fields": { + "customer": "Jones", + "date": 1157709600, + "item": "Valve cover", + "price": "3000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::3", + }, + { + "fields": { + "customer": "Jones", + "date": 1157702400, + "item": "Intake port", + "price": "5000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::4", + }, + { + "fields": { + "customer": "Brown", + "date": 1157706000, + "item": "Head", + "price": "8000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::5", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Coolant", + "price": "1300", + "tax": "0.24", + }, + "put": "id:purchase:purchase::6", + }, + { + "fields": { + "customer": "Jones", + "date": 1157788800, + "item": "Engine block", + "price": "2100", + "tax": "0.12", + }, + "put": "id:purchase:purchase::7", + }, + { + "fields": { + "customer": "Brown", + "date": 1157792400, + "item": "Oil pan", + "price": "3400", + "tax": "0.24", + }, + "put": "id:purchase:purchase::8", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Oil sump", + "price": "5500", + "tax": "0.12", + }, + "put": "id:purchase:purchase::9", + }, + { + "fields": { + "customer": "Jones", + "date": 1157875200, + "item": "Camshaft", + "price": "8900", + "tax": "0.24", + }, + "put": "id:purchase:purchase::10", + }, + { + "fields": { + "customer": "Brown", + "date": 1157878800, + "item": "Exhaust valve", + "price": "1440", + "tax": "0.12", + }, + "put": 
"id:purchase:purchase::11", + }, + { + "fields": { + "customer": "Brown", + "date": 1157882400, + "item": "Rocker arm", + "price": "2330", + "tax": "0.24", + }, + "put": "id:purchase:purchase::12", + }, + { + "fields": { + "customer": "Brown", + "date": 1157875200, + "item": "Spring", + "price": "3770", + "tax": "0.12", + }, + "put": "id:purchase:purchase::13", + }, + { + "fields": { + "customer": "Smith", + "date": 1157878800, + "item": "Spark plug", + "price": "6100", + "tax": "0.24", + }, + "put": "id:purchase:purchase::14", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Exhaust port", + "price": "9870", + "tax": "0.12", + }, + "put": "id:purchase:purchase::15", + }, + { + "fields": { + "customer": "Brown", + "date": 1157961600, + "item": "Piston", + "price": "1597", + "tax": "0.24", + }, + "put": "id:purchase:purchase::16", + }, + { + "fields": { + "customer": "Smith", + "date": 1157965200, + "item": "Connection rod", + "price": "2584", + "tax": "0.12", + }, + "put": "id:purchase:purchase::17", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Rod bearing", + "price": "4181", + "tax": "0.24", + }, + "put": "id:purchase:purchase::18", + }, + { + "fields": { + "customer": "Jones", + "date": 1157972400, + "item": "Crankshaft", + "price": "6765", + "tax": "0.12", + }, + "put": "id:purchase:purchase::19", + }, + ] + docs = [ + {"fields": doc["fields"], "id": doc["put"].split("::")[-1]} + for doc in sample_data + ] + return docs + + def feed_grouping_data(self): + # Feed documents + self.app.feed_iterable(iter=self.sample_grouping_data, schema="purchase") + return + def test_dotProduct_with_annotations(self): # Feed a document with 'weightedset_field' field = "weightedset_field" @@ -798,218 +988,10 @@ def test_same_element(self): hit = result.hits[0] self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") - def test_grouping(self): + def test_grouping_with_condition(self): # "select * 
from purchase | all(group(customer) each(output(sum(price))))" - # sample data from https://github.com/vespa-cloud/vespa-documentation-search#feed-grouping-examples - sample_data = [ - { - "fields": { - "customer": "Smith", - "date": 1157526000, - "item": "Intake valve", - "price": "1000", - "tax": "0.24", - }, - "put": "id:purchase:purchase::0", - }, - { - "fields": { - "customer": "Smith", - "date": 1157616000, - "item": "Rocker arm", - "price": "1000", - "tax": "0.12", - }, - "put": "id:purchase:purchase::1", - }, - { - "fields": { - "customer": "Smith", - "date": 1157619600, - "item": "Spring", - "price": "2000", - "tax": "0.24", - }, - "put": "id:purchase:purchase::2", - }, - { - "fields": { - "customer": "Jones", - "date": 1157709600, - "item": "Valve cover", - "price": "3000", - "tax": "0.12", - }, - "put": "id:purchase:purchase::3", - }, - { - "fields": { - "customer": "Jones", - "date": 1157702400, - "item": "Intake port", - "price": "5000", - "tax": "0.24", - }, - "put": "id:purchase:purchase::4", - }, - { - "fields": { - "customer": "Brown", - "date": 1157706000, - "item": "Head", - "price": "8000", - "tax": "0.12", - }, - "put": "id:purchase:purchase::5", - }, - { - "fields": { - "customer": "Smith", - "date": 1157796000, - "item": "Coolant", - "price": "1300", - "tax": "0.24", - }, - "put": "id:purchase:purchase::6", - }, - { - "fields": { - "customer": "Jones", - "date": 1157788800, - "item": "Engine block", - "price": "2100", - "tax": "0.12", - }, - "put": "id:purchase:purchase::7", - }, - { - "fields": { - "customer": "Brown", - "date": 1157792400, - "item": "Oil pan", - "price": "3400", - "tax": "0.24", - }, - "put": "id:purchase:purchase::8", - }, - { - "fields": { - "customer": "Smith", - "date": 1157796000, - "item": "Oil sump", - "price": "5500", - "tax": "0.12", - }, - "put": "id:purchase:purchase::9", - }, - { - "fields": { - "customer": "Jones", - "date": 1157875200, - "item": "Camshaft", - "price": "8900", - "tax": "0.24", - }, - "put": 
"id:purchase:purchase::10", - }, - { - "fields": { - "customer": "Brown", - "date": 1157878800, - "item": "Exhaust valve", - "price": "1440", - "tax": "0.12", - }, - "put": "id:purchase:purchase::11", - }, - { - "fields": { - "customer": "Brown", - "date": 1157882400, - "item": "Rocker arm", - "price": "2330", - "tax": "0.24", - }, - "put": "id:purchase:purchase::12", - }, - { - "fields": { - "customer": "Brown", - "date": 1157875200, - "item": "Spring", - "price": "3770", - "tax": "0.12", - }, - "put": "id:purchase:purchase::13", - }, - { - "fields": { - "customer": "Smith", - "date": 1157878800, - "item": "Spark plug", - "price": "6100", - "tax": "0.24", - }, - "put": "id:purchase:purchase::14", - }, - { - "fields": { - "customer": "Jones", - "date": 1157968800, - "item": "Exhaust port", - "price": "9870", - "tax": "0.12", - }, - "put": "id:purchase:purchase::15", - }, - { - "fields": { - "customer": "Brown", - "date": 1157961600, - "item": "Piston", - "price": "1597", - "tax": "0.24", - }, - "put": "id:purchase:purchase::16", - }, - { - "fields": { - "customer": "Smith", - "date": 1157965200, - "item": "Connection rod", - "price": "2584", - "tax": "0.12", - }, - "put": "id:purchase:purchase::17", - }, - { - "fields": { - "customer": "Jones", - "date": 1157968800, - "item": "Rod bearing", - "price": "4181", - "tax": "0.24", - }, - "put": "id:purchase:purchase::18", - }, - { - "fields": { - "customer": "Jones", - "date": 1157972400, - "item": "Crankshaft", - "price": "6765", - "tax": "0.12", - }, - "put": "id:purchase:purchase::19", - }, - ] - # map data to correct format - sample_data = [ - {"fields": doc["fields"], "id": doc["put"].split("::")[-1]} - for doc in sample_data - ] - # Feed documents - self.app.feed_iterable(iter=sample_data, schema="purchase") + # Feed test documents + self.feed_grouping_data() # Execute query q = qb.test_grouping_with_condition() print(f"Executing query: {q}") @@ -1035,3 +1017,25 @@ def test_grouping(self): 
self.assertEqual(group_results[2]["id"], "group:string:Smith") self.assertEqual(group_results[2]["value"], "Smith") self.assertEqual(group_results[2]["fields"]["sum(price)"], 19484) + + def test_grouping_with_ordering_and_limiting(self): + # "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" + # Feed test documents + self.feed_grouping_data() + # Execute query + q = qb.test_grouping_with_ordering_and_limiting() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + result_children = result.json["root"]["children"][0]["children"][0]["children"] + print(result_children) + # assert 2 groups + self.assertEqual(len(result_children), 2) + # assert the first group is Jones + self.assertEqual(result_children[0]["id"], "group:string:Jones") + self.assertEqual(result_children[0]["value"], "Jones") + self.assertEqual(result_children[0]["fields"]["sum(price)"], 39816) + # assert the second group is Smith + self.assertEqual(result_children[1]["id"], "group:string:Smith") + self.assertEqual(result_children[1]["value"], "Smith") + self.assertEqual(result_children[1]["fields"]["sum(price)"], 19484) From 44e8fb10219d66c13a0e7dd67e65351c481be0ea Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 14:21:29 +0100 Subject: [PATCH 36/39] module level working --- tests/unit/test_q.py | 5 +- vespa/querybuilder/__init__.py | 36 +- vespa/querybuilder/builder/__init__.py | 0 vespa/querybuilder/builder/builder.py | 571 ++++++++++++++++++++++++ vespa/querybuilder/grouping/__init__.py | 0 vespa/querybuilder/grouping/grouping.py | 67 +++ vespa/querybuilder/main.py | 66 --- 7 files changed, 675 insertions(+), 70 deletions(-) create mode 100644 vespa/querybuilder/builder/__init__.py create mode 100644 vespa/querybuilder/builder/builder.py create mode 100644 vespa/querybuilder/grouping/__init__.py create mode 100644 vespa/querybuilder/grouping/grouping.py diff --git 
a/tests/unit/test_q.py b/tests/unit/test_q.py index 843b18ba..4122c441 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,15 +1,16 @@ import unittest from vespa.querybuilder import Query, Q, Queryfield, G, Condition +import vespa.querybuilder as qb class TestQueryBuilder(unittest.TestCase): def test_dotProduct_with_annotations(self): - condition = Q.dotProduct( + condition = qb.dotProduct( "weightedset_field", {"feature1": 1, "feature2": 2}, annotations={"label": "myDotProduct"}, ) - q = Query(select_fields="*").from_("sd1").where(condition) + q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) return q diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py index 6ccdd50e..7681bfae 100644 --- a/vespa/querybuilder/__init__.py +++ b/vespa/querybuilder/__init__.py @@ -1,2 +1,34 @@ -# Export all from main -from .main import * +from .builder.builder import Query, Q, Queryfield, Condition +from .grouping.grouping import G +import inspect + +# Import original classes +# ...existing code... 
+ +# Automatically expose all static methods from Q and G classes +for cls in [Q, G]: + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): + if not name.startswith("_"): + # Create function with same name and signature as the static method + globals()[name] = method + +# Create __all__ list dynamically +__all__ = [ + # Classes + "Query", + "Q", + "Queryfield", + "G", + "Condition", + # Add all exposed functions + *( + name + for name, method in inspect.getmembers(Q, predicate=inspect.isfunction) + if not name.startswith("_") + ), + *( + name + for name, method in inspect.getmembers(G, predicate=inspect.isfunction) + if not name.startswith("_") + ), +] diff --git a/vespa/querybuilder/builder/__init__.py b/vespa/querybuilder/builder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vespa/querybuilder/builder/builder.py b/vespa/querybuilder/builder/builder.py new file mode 100644 index 00000000..de92f382 --- /dev/null +++ b/vespa/querybuilder/builder/builder.py @@ -0,0 +1,571 @@ +from dataclasses import dataclass +from typing import Any, List, Union, Optional, Dict + + +@dataclass +class Queryfield: + name: str + + def __eq__(self, other: Any) -> "Condition": + return Condition(f"{self.name} = {self._format_value(other)}") + + def __ne__(self, other: Any) -> "Condition": + return Condition(f"{self.name} != {self._format_value(other)}") + + def __lt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} < {self._format_value(other)}") + + def __le__(self, other: Any) -> "Condition": + return Condition(f"{self.name} <= {self._format_value(other)}") + + def __gt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} > {self._format_value(other)}") + + def __ge__(self, other: Any) -> "Condition": + return Condition(f"{self.name} >= {self._format_value(other)}") + + def __and__(self, other: Any) -> "Condition": + return Condition(f"{self.name} and {self._format_value(other)}") + + def __or__(self, 
other: Any) -> "Condition": + return Condition(f"{self.name} or {self._format_value(other)}") + + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} contains {value_str}") + + def matches( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} matches {value_str}") + + def in_(self, *values) -> "Condition": + values_str = ", ".join( + f'"{v}"' if isinstance(v, str) else str(v) for v in values + ) + return Condition(f"{self.name} in ({values_str})") + + def in_range( + self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}range({self.name}, {start}, {end}))" + ) + else: + return Condition(f"range({self.name}, {start}, {end})") + + def le(self, value: Any) -> "Condition": + return self.__le__(value) + + def lt(self, value: Any) -> "Condition": + return self.__lt__(value) + + def ge(self, value: Any) -> "Condition": + return self.__ge__(value) + + def gt(self, value: Any) -> "Condition": + return self.__gt__(value) + + def eq(self, value: Any) -> "Condition": + return self.__eq__(value) + + def __str__(self) -> str: + return self.name + + @staticmethod + def _format_value(value: Any) -> str: + if isinstance(value, 
str): + return f'"{value}"' + elif isinstance(value, Condition): + return value.build() + else: + return str(value) + + @staticmethod + def _format_annotation_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, dict): + return ( + "{" + + ",".join( + f'"{k}":{Queryfield._format_annotation_value(v)}' + for k, v in value.items() + ) + + "}" + ) + elif isinstance(value, list): + return ( + "[" + + ",".join(f"{Queryfield._format_annotation_value(v)}" for v in value) + + "]" + ) + else: + return str(value) + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.name}") + + +@dataclass +class Condition: + expression: str + + def __and__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Adjust parentheses based on operator precedence + left = f"({left})" if " or " in left else left + right = f"({right})" if " or " in right else right + + return Condition(f"{left} and {right}") + + def __or__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Always add parentheses if 'and' or 'or' is in the expressions + left = f"({left})" if " and " in left or " or " in left else left + right = f"({right})" if " and " in right or " or " in right else right + + return Condition(f"{left} or {right}") + + def __invert__(self) -> "Condition": + return Condition(f"!({self.expression})") + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{{{annotations_str}}}{self.expression}") + + def build(self) -> str: + return self.expression + + @classmethod + def all(cls, 
*conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical AND.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'or' in parentheses + if " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " and ".join(expressions) + return Condition(combined_expression) + + @classmethod + def any(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical OR.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'and' or 'or' in parentheses + if " and " in expr or " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " or ".join(expressions) + return Condition(combined_expression) + + +class Query: + def __init__( + self, select_fields: Union[str, List[str], List[Queryfield]], prepend_yql=False + ): + self.select_fields = ( + ", ".join(select_fields) + if isinstance(select_fields, List) + and all(isinstance(f, str) for f in select_fields) + else ", ".join(str(f) for f in select_fields) + ) + self.sources = "*" + self.condition = None + self.order_by_clauses = [] + self.limit_value = None + self.offset_value = None + self.timeout_value = None + self.parameters = {} + self.grouping = None + self.prepend_yql = prepend_yql + + def __str__(self) -> str: + return self.build(self.prepend_yql) + + def __eq__(self, other: Any) -> bool: + return self.build() == other + + def __ne__(self, other: Any) -> bool: + return self.build() != other + + def __repr__(self) -> str: + return str(self) + + def from_(self, *sources: str) -> "Query": + self.sources = ", ".join(sources) + return self + + def where(self, condition: Union[Condition, Queryfield, bool]) -> "Query": + if isinstance(condition, Queryfield): + self.condition = condition + elif isinstance(condition, bool): + self.condition = Condition("true") if condition else Condition("false") + else: + self.condition = 
condition + return self + + def order_by_field( + self, + field: str, + ascending: bool = True, + annotations: Optional[Dict[str, Any]] = None, + ) -> "Query": + direction = "asc" if ascending else "desc" + if annotations: + annotations_str = ",".join( + f'"{k}":{Queryfield._format_annotation_value(v)}' + for k, v in annotations.items() + ) + self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") + else: + self.order_by_clauses.append(f"{field} {direction}") + return self + + def orderByAsc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, True, annotations) + + def orderByDesc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, False, annotations) + + def set_limit(self, limit: int) -> "Query": + self.limit_value = limit + return self + + def set_offset(self, offset: int) -> "Query": + self.offset_value = offset + return self + + def set_timeout(self, timeout: int) -> "Query": + self.timeout_value = timeout + return self + + def add_parameter(self, key: str, value: Any) -> "Query": + self.parameters[key] = value + return self + + def param(self, key: str, value: Any) -> "Query": + return self.add_parameter(key, value) + + def groupby(self, group_expression: str) -> "Query": + self.grouping = group_expression + return self + + def build(self, prepend_yql=False) -> str: + query = f"select {self.select_fields} from {self.sources}" + if prepend_yql: + query = f"yql={query}" + if self.condition: + query += f" where {self.condition.build()}" + if self.order_by_clauses: + query += " order by " + ", ".join(self.order_by_clauses) + if self.limit_value is not None: + query += f" limit {self.limit_value}" + if self.offset_value is not None: + query += f" offset {self.offset_value}" + if self.timeout_value is not None: + query += f" timeout {self.timeout_value}" + if self.grouping: + query += f" | {self.grouping}" + if 
self.parameters: + params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) + query += params + return query + + +class Q: + @staticmethod + def select(*fields): + return Query(select_fields=list(fields)) + + @staticmethod + def p(*args): + if not args: + return Condition("") + else: + condition = args[0] + for arg in args[1:]: + condition = condition & arg + return condition + + @staticmethod + def userQuery(value: str = "") -> Condition: + return Condition(f'userQuery("{value}")') if value else Condition("userQuery()") + + @staticmethod + def dotProduct( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"dotProduct({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weightedSet( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"weightedSet({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nonEmpty(condition: Union[Condition, Queryfield]) -> Condition: + if isinstance(condition, Queryfield): + expr = str(condition) + else: + expr = condition.build() + return Condition(f"nonEmpty({expr})") + + @staticmethod + def wand( + field: str, weights, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if isinstance(weights, list): + weights_str = "[" + ", ".join(str(item) for item in weights) + "]" + elif isinstance(weights, dict): + weights_str = ( + "{" + ", ".join(f'"{k}":{v}' 
for k, v in weights.items()) + "}" + ) + else: + raise ValueError("Invalid weights for wand") + expr = f"wand({field}, {weights_str})" + if annotations: + annotations_str = ", ".join( + f"{k}: {Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"weakAnd({conditions_str})" + if annotations: + annotations_str = ",".join( + f'"{k}": {Queryfield._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def geoLocation( + field: str, + lat: float, + lng: float, + radius: str, + annotations: Optional[Dict[str, Any]] = None, + ) -> Condition: + expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nearestNeighbor( + field: str, query_vector: str, annotations: Dict[str, Any] + ) -> Condition: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) + + @staticmethod + def rank(*queries) -> Condition: + queries_str = ", ".join(query.build() for query in queries) + return Condition(f"rank({queries_str})") + + @staticmethod + def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"phrase({terms_str})" + if annotations: + annotations_str = ",".join( + 
f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def near( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"near({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ", ".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def onear( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"onear({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def sameElement(*conditions) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"sameElement({conditions_str})" + return Condition(expr) + + @staticmethod + def equiv(*terms) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"equiv({terms_str})" + return Condition(expr) + + @staticmethod + def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'uri("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = 
f'fuzzy("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def userInput( + value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if value is None: + expr = "userInput()" + elif value.startswith("@"): + expr = f"userInput({value})" + else: + expr = f'userInput("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def predicate( + field: str, + attributes: Optional[Dict[str, Any]] = None, + range_attributes: Optional[Dict[str, Any]] = None, + ) -> Condition: + if attributes is None: + attributes_str = "0" + else: + attributes_str = ( + "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" + ) + if range_attributes is None: + range_attributes_str = "0" + else: + range_attributes_str = ( + "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" + ) + expr = f"predicate({field},{attributes_str},{range_attributes_str})" + return Condition(expr) + + @staticmethod + def true() -> Condition: + return Condition("true") + + @staticmethod + def false() -> Condition: + return Condition("false") diff --git a/vespa/querybuilder/grouping/__init__.py b/vespa/querybuilder/grouping/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vespa/querybuilder/grouping/grouping.py b/vespa/querybuilder/grouping/grouping.py new file mode 100644 index 00000000..f9494d84 --- /dev/null +++ b/vespa/querybuilder/grouping/grouping.py @@ -0,0 +1,67 @@ +from typing import Union + + +class G: + @staticmethod + def all(*args) -> str: + return "all(" + " ".join(args) + ")" + + @staticmethod + def group(field: str) -> str: + return f"group({field})" + + 
@staticmethod + def max(value: Union[int, float]) -> str: + return f"max({value})" + + @staticmethod + def precision(value: int) -> str: + return f"precision({value})" + + @staticmethod + def min(value: Union[int, float]) -> str: + return f"min({value})" + + @staticmethod + def sum(value: Union[int, float]) -> str: + return f"sum({value})" + + @staticmethod + def avg(value: Union[int, float]) -> str: + return f"avg({value})" + + @staticmethod + def stddev(value: Union[int, float]) -> str: + return f"stddev({value})" + + @staticmethod + def xor(value: str) -> str: + return f"xor({value})" + + @staticmethod + def each(*args) -> str: + return "each(" + " ".join(args) + ")" + + @staticmethod + def output(output_func: str) -> str: + return f"output({output_func})" + + @staticmethod + def count() -> str: + # Also need to handle negative count + class MaybenegativeCount(str): + def __new__(cls, value): + return super().__new__(cls, value) + + def __neg__(self): + return f"-{self}" + + return MaybenegativeCount("count()") + + @staticmethod + def order(value: str) -> str: + return f"order({value})" + + @staticmethod + def summary() -> str: + return "summary()" diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py index 46bf35d3..de92f382 100644 --- a/vespa/querybuilder/main.py +++ b/vespa/querybuilder/main.py @@ -569,69 +569,3 @@ def true() -> Condition: @staticmethod def false() -> Condition: return Condition("false") - - -class G: - @staticmethod - def all(*args) -> str: - return "all(" + " ".join(args) + ")" - - @staticmethod - def group(field: str) -> str: - return f"group({field})" - - @staticmethod - def max(value: Union[int, float]) -> str: - return f"max({value})" - - @staticmethod - def precision(value: int) -> str: - return f"precision({value})" - - @staticmethod - def min(value: Union[int, float]) -> str: - return f"min({value})" - - @staticmethod - def sum(value: Union[int, float]) -> str: - return f"sum({value})" - - @staticmethod - def 
avg(value: Union[int, float]) -> str: - return f"avg({value})" - - @staticmethod - def stddev(value: Union[int, float]) -> str: - return f"stddev({value})" - - @staticmethod - def xor(value: str) -> str: - return f"xor({value})" - - @staticmethod - def each(*args) -> str: - return "each(" + " ".join(args) + ")" - - @staticmethod - def output(output_func: str) -> str: - return f"output({output_func})" - - @staticmethod - def count() -> str: - # Also need to handle negative count - class MaybenegativeCount(str): - def __new__(cls, value): - return super().__new__(cls, value) - - def __neg__(self): - return f"-{self}" - - return MaybenegativeCount("count()") - - @staticmethod - def order(value: str) -> str: - return f"order({value})" - - @staticmethod - def summary() -> str: - return "summary()" From 58f19dd458b0f00e407ef009c47836a5a686f590 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Thu, 5 Dec 2024 14:52:01 +0100 Subject: [PATCH 37/39] next - restructure and add docstrings to public methods --- tests/unit/test_q.py | 323 ++++++++++++-------------- vespa/querybuilder/__init__.py | 33 +-- vespa/querybuilder/builder/builder.py | 22 +- 3 files changed, 177 insertions(+), 201 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 4122c441..4ae60351 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,5 +1,5 @@ import unittest -from vespa.querybuilder import Query, Q, Queryfield, G, Condition +from vespa.querybuilder import G import vespa.querybuilder as qb @@ -16,49 +16,49 @@ def test_dotProduct_with_annotations(self): return q def test_geolocation_with_annotations(self): - condition = Q.geoLocation( + condition = qb.geoLocation( "location_field", 37.7749, -122.4194, "10km", annotations={"targetHits": 100}, ) - q = Query(select_fields="*").from_("sd1").where(condition) + q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' 
self.assertEqual(q, expected) return q def test_select_specific_fields(self): - f1 = Queryfield("f1") + f1 = qb.Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields=["f1", "f2"]).from_("sd1").where(condition) + q = qb.select(["f1", "f2"]).from_("sd1").where(condition) self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') return q def test_select_from_specific_sources(self): - f1 = Queryfield("f1") + f1 = qb.Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields="*").from_("sd1").where(condition) + q = qb.select("*").from_("sd1").where(condition) self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') return q def test_select_from_multiples_sources(self): - f1 = Queryfield("f1") + f1 = qb.Queryfield("f1") condition = f1.contains("v1") - q = Query(select_fields="*").from_("sd1", "sd2").where(condition) + q = qb.select("*").from_("sd1", "sd2").where(condition) self.assertEqual(q, 'select * from sd1, sd2 where f1 contains "v1"') return q def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): - f1 = Queryfield("f1") - f2 = Queryfield("f2") - f3 = Queryfield("f3") - f4 = Queryfield("f4") + f1 = qb.Queryfield("f1") + f2 = qb.Queryfield("f2") + f3 = qb.Queryfield("f3") + f4 = qb.Queryfield("f4") condition = ((f1.contains("v1") & f2.contains("v2")) | f3.contains("v3")) & ( ~f4.contains("v4") ) q = ( - Query(select_fields="*") + qb.select("*") .from_("sd1") .where(condition) .set_offset(1) @@ -73,59 +73,55 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): return q def test_timeout(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") condition = f1.contains("madonna") - q = Query(select_fields="*").from_("sd1").where(condition).set_timeout(70) + q = qb.select("*").from_("sd1").where(condition).set_timeout(70) expected = 'select * from sd1 where title contains "madonna" timeout 70' self.assertEqual(q, expected) return q def test_matches(self): condition = 
( - (Queryfield("f1").matches("v1") & Queryfield("f2").matches("v2")) - | Queryfield("f3").matches("v3") - ) & ~Queryfield("f4").matches("v4") - q = Query(select_fields="*").from_("sd1").where(condition) + (qb.Queryfield("f1").matches("v1") & qb.Queryfield("f2").matches("v2")) + | qb.Queryfield("f3").matches("v3") + ) & ~qb.Queryfield("f4").matches("v4") + q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' self.assertEqual(q, expected) return q def test_nested_queries(self): nested_query = ( - Queryfield("f2").contains("2") & Queryfield("f3").contains("3") - ) | (Queryfield("f2").contains("4") & ~Queryfield("f3").contains("5")) - condition = Queryfield("f1").contains("1") & ~nested_query - q = Query(select_fields="*").from_("sd1").where(condition) + qb.Queryfield("f2").contains("2") & qb.Queryfield("f3").contains("3") + ) | (qb.Queryfield("f2").contains("4") & ~qb.Queryfield("f3").contains("5")) + condition = qb.Queryfield("f1").contains("1") & ~nested_query + q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' self.assertEqual(q, expected) return q def test_userquery(self): - condition = Q.userQuery() - q = Query(select_fields="*").from_("sd1").where(condition) + condition = qb.userQuery() + q = qb.select("*").from_("sd1").where(condition) expected = "select * from sd1 where userQuery()" self.assertEqual(q, expected) return q def test_fields_duration(self): - f1 = Queryfield("subject") - f2 = Queryfield("display_date") - f3 = Queryfield("duration") - q = Query(select_fields=[f1, f2]).from_("calendar").where(f3 > 0) + f1 = qb.Queryfield("subject") + f2 = qb.Queryfield("display_date") + f3 = qb.Queryfield("duration") + q = qb.select([f1, f2]).from_("calendar").where(f3 > 0) expected = "select subject, 
display_date from calendar where duration > 0" self.assertEqual(q, expected) return q def test_nearest_neighbor(self): - condition_uq = Q.userQuery() - condition_nn = Q.nearestNeighbor( + condition_uq = qb.userQuery() + condition_nn = qb.nearestNeighbor( field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} ) - q = ( - Query(select_fields=["id, text"]) - .from_("m") - .where(condition_uq | condition_nn) - ) + q = qb.select(["id, text"]).from_("m").where(condition_uq | condition_nn) expected = "select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) return q @@ -133,7 +129,7 @@ def test_nearest_neighbor(self): def test_build_many_nn_operators(self): self.maxDiff = None conditions = [ - Q.nearestNeighbor( + qb.nearestNeighbor( field="colbert", query_vector=f"binary_vector_{i}", annotations={"targetHits": 100}, @@ -141,11 +137,7 @@ def test_build_many_nn_operators(self): for i in range(32) ] # Use Condition.any to combine conditions with OR - q = ( - Query(select_fields="*") - .from_("doc") - .where(condition=Condition.any(*conditions)) - ) + q = qb.select("*").from_("doc").where(condition=qb.any(*conditions)) expected = "select * from doc where " + " or ".join( [ f"({{targetHits:100}}nearestNeighbor(colbert, binary_vector_{i}))" @@ -156,48 +148,42 @@ def test_build_many_nn_operators(self): return q def test_field_comparison_operators(self): - f1 = Queryfield("age") + f1 = qb.Queryfield("age") condition = (f1 > 30) & (f1 <= 50) - q = Query(select_fields="*").from_("people").where(condition) + q = qb.select("*").from_("people").where(condition) expected = "select * from people where age > 30 and age <= 50" self.assertEqual(q, expected) return q def test_field_in_range(self): - f1 = Queryfield("age") + f1 = qb.Queryfield("age") condition = f1.in_range(18, 65) - q = Query(select_fields="*").from_("people").where(condition) + q = qb.select("*").from_("people").where(condition) 
expected = "select * from people where range(age, 18, 65)" self.assertEqual(q, expected) return q def test_field_annotation(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") annotations = {"highlight": True} annotated_field = f1.annotate(annotations) - q = Query(select_fields="*").from_("articles").where(annotated_field) + q = qb.select("*").from_("articles").where(annotated_field) expected = "select * from articles where ({highlight:true})title" self.assertEqual(q, expected) return q def test_condition_annotation(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") condition = f1.contains("Python") annotated_condition = condition.annotate({"filter": True}) - q = Query(select_fields="*").from_("articles").where(annotated_condition) + q = qb.select("*").from_("articles").where(annotated_condition) expected = 'select * from articles where {filter:true}title contains "Python"' self.assertEqual(q, expected) return q def test_grouping_with_condition(self): grouping = G.all(G.group("customer"), G.each(G.output(G.sum("price")))) - q = ( - Query(select_fields="*") - .from_("purchase") - .where(True) - .set_limit(0) - .groupby(grouping) - ) + q = qb.select("*").from_("purchase").where(True).set_limit(0).groupby(grouping) expected = "select * from purchase where true limit 0 | all(group(customer) each(output(sum(price))))" self.assertEqual(q, expected) return q @@ -211,7 +197,7 @@ def test_grouping_with_ordering_and_limiting(self): G.order(-G.count()), G.each(G.output(G.sum("price"))), ) - q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" self.assertEqual(q, expected) return q @@ -221,14 +207,14 @@ def test_grouping_with_map_keys(self): G.group("mymap.key"), G.each(G.group("mymap.value"), G.each(G.output(G.count()))), ) - q = 
Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(mymap.key) each(group(mymap.value) each(output(count()))))" self.assertEqual(q, expected) return q def test_group_by_year(self): grouping = G.all(G.group("time.year(a)"), G.each(G.output(G.count()))) - q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(time.year(a)) each(output(count())))" self.assertEqual(q, expected) return q @@ -256,7 +242,7 @@ def test_grouping_with_date_agg(self): ), ), ) - q = Query(select_fields="*").from_("purchase").where(True).groupby(grouping) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count())))))))))" # q select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) # e select * from purchase where true | all(group(time.year(a)) each(output(count() all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) @@ -266,10 +252,10 @@ def test_grouping_with_date_agg(self): return q def test_add_parameter(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") condition = f1.contains("Python") q = ( - Query(select_fields="*") + qb.select("*") .from_("articles") .where(condition) .add_parameter("tracelevel", 1) @@ -279,17 +265,17 @@ def 
test_add_parameter(self): return q def test_custom_ranking_expression(self): - condition = Q.rank( - Q.userQuery(), Q.dotProduct("embedding", {"feature1": 1, "feature2": 2}) + condition = qb.rank( + qb.userQuery(), qb.dotProduct("embedding", {"feature1": 1, "feature2": 2}) ) - q = Query(select_fields="*").from_("documents").where(condition) + q = qb.select("*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' self.assertEqual(q, expected) return q def test_wand(self): - condition = Q.wand("keywords", {"apple": 10, "banana": 20}) - q = Query(select_fields="*").from_("fruits").where(condition) + condition = qb.wand("keywords", {"apple": 10, "banana": 20}) + q = qb.select("*").from_("fruits").where(condition) expected = ( 'select * from fruits where wand(keywords, {"apple":10, "banana":20})' ) @@ -297,20 +283,20 @@ def test_wand(self): return q def test_wand_numeric(self): - condition = Q.wand("description", [[11, 1], [37, 2]]) - q = Query(select_fields="*").from_("fruits").where(condition) + condition = qb.wand("description", [[11, 1], [37, 2]]) + q = qb.select("*").from_("fruits").where(condition) expected = "select * from fruits where wand(description, [[11, 1], [37, 2]])" self.assertEqual(q, expected) return q def test_wand_annotations(self): self.maxDiff = None - condition = Q.wand( + condition = qb.wand( "description", weights={"a": 1, "b": 2}, annotations={"scoreThreshold": 0.13, "targetHits": 7}, ) - q = Query(select_fields="*").from_("fruits").where(condition) + q = qb.select("*").from_("fruits").where(condition) expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a":1, "b":2}))' print(q) print(expected) @@ -318,29 +304,29 @@ def test_wand_annotations(self): return q def test_weakand(self): - condition1 = Queryfield("title").contains("Python") - condition2 = Queryfield("description").contains("Programming") - 
condition = Q.weakAnd( + condition1 = qb.Queryfield("title").contains("Python") + condition2 = qb.Queryfield("description").contains("Programming") + condition = qb.weakAnd( condition1, condition2, annotations={"targetNumHits": 100} ) - q = Query(select_fields="*").from_("articles").where(condition) + q = qb.select("*").from_("articles").where(condition) expected = 'select * from articles where ({"targetNumHits": 100}weakAnd(title contains "Python", description contains "Programming"))' self.assertEqual(q, expected) return q def test_geolocation(self): - condition = Q.geoLocation("location_field", 37.7749, -122.4194, "10km") - q = Query(select_fields="*").from_("places").where(condition) + condition = qb.geoLocation("location_field", 37.7749, -122.4194, "10km") + q = qb.select("*").from_("places").where(condition) expected = 'select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' self.assertEqual(q, expected) return q def test_condition_all_any(self): - c1 = Queryfield("f1").contains("v1") - c2 = Queryfield("f2").contains("v2") - c3 = Queryfield("f3").contains("v3") - condition = Condition.all(c1, c2, Condition.any(c3, ~c1)) - q = Query(select_fields="*").from_("sd1").where(condition) + c1 = qb.Queryfield("f1").contains("v1") + c2 = qb.Queryfield("f2").contains("v2") + c3 = qb.Queryfield("f3").contains("v3") + condition = qb.all(c1, c2, qb.any(c3, ~c1)) + q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' self.assertEqual(q, expected) return q @@ -349,12 +335,7 @@ def test_order_by_with_annotations(self): f1 = "relevance" f2 = "price" annotations = {"strength": 0.5} - q = ( - Query(select_fields="*") - .from_("products") - .orderByDesc(f1, annotations) - .orderByAsc(f2) - ) + q = qb.select("*").from_("products").orderByDesc(f1, annotations).orderByAsc(f2) expected = ( 'select * from products order by 
{"strength":0.5}relevance desc, price asc' ) @@ -362,64 +343,64 @@ def test_order_by_with_annotations(self): return q def test_field_comparison_methods_builtins(self): - f1 = Queryfield("age") + f1 = qb.Queryfield("age") condition = (f1 >= 18) & (f1 < 30) - q = Query(select_fields="*").from_("users").where(condition) + q = qb.select("*").from_("users").where(condition) expected = "select * from users where age >= 18 and age < 30" self.assertEqual(q, expected) return q def test_field_comparison_methods(self): - f1 = Queryfield("age") + f1 = qb.Queryfield("age") condition = (f1.ge(18) & f1.lt(30)) | f1.eq(40) - q = Query(select_fields="*").from_("users").where(condition) + q = qb.select("*").from_("users").where(condition) expected = "select * from users where (age >= 18 and age < 30) or age = 40" self.assertEqual(q, expected) return q def test_filter_annotation(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") condition = f1.contains("Python").annotate({"filter": True}) - q = Query(select_fields="*").from_("articles").where(condition) + q = qb.select("*").from_("articles").where(condition) expected = 'select * from articles where {filter:true}title contains "Python"' self.assertEqual(q, expected) return q def test_non_empty(self): - condition = Q.nonEmpty(Queryfield("comments").eq("any_value")) - q = Query(select_fields="*").from_("posts").where(condition) + condition = qb.nonEmpty(qb.Queryfield("comments").eq("any_value")) + q = qb.select("*").from_("posts").where(condition) expected = 'select * from posts where nonEmpty(comments = "any_value")' self.assertEqual(q, expected) return q def test_dotproduct(self): - condition = Q.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) - q = Query(select_fields="*").from_("vectors").where(condition) + condition = qb.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) + q = qb.select("*").from_("vectors").where(condition) expected = 'select * from vectors where dotProduct(vector_field, 
{"feature1":1,"feature2":2})' self.assertEqual(q, expected) return q def test_in_range_string_values(self): - f1 = Queryfield("date") + f1 = qb.Queryfield("date") condition = f1.in_range("2021-01-01", "2021-12-31") - q = Query(select_fields="*").from_("events").where(condition) + q = qb.select("*").from_("events").where(condition) expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" self.assertEqual(q, expected) return q def test_condition_inversion(self): - f1 = Queryfield("status") + f1 = qb.Queryfield("status") condition = ~f1.eq("inactive") - q = Query(select_fields="*").from_("users").where(condition) + q = qb.select("*").from_("users").where(condition) expected = 'select * from users where !(status = "inactive")' self.assertEqual(q, expected) return q def test_multiple_parameters(self): - f1 = Queryfield("title") + f1 = qb.Queryfield("title") condition = f1.contains("Python") q = ( - Query(select_fields="*") + qb.select("*") .from_("articles") .where(condition) .add_parameter("tracelevel", 1) @@ -436,203 +417,195 @@ def test_multiple_groupings(self): G.output(G.count()), G.each(G.group("subcategory"), G.output(G.summary())), ) - q = Query(select_fields="*").from_("products").groupby(grouping) + q = qb.select("*").from_("products").groupby(grouping) expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" self.assertEqual(q, expected) return q def test_userquery_basic(self): - condition = Q.userQuery("search terms") - q = Query(select_fields="*").from_("documents").where(condition) + condition = qb.userQuery("search terms") + q = qb.select("*").from_("documents").where(condition) expected = 'select * from documents where userQuery("search terms")' self.assertEqual(q, expected) return q - def test_q_p_function(self): - condition = Q.p( - Queryfield("f1").contains("v1"), - Queryfield("f2").contains("v2"), - Queryfield("f3").contains("v3"), - ) - q = 
Query(select_fields="*").from_("sd1").where(condition) - expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and f3 contains "v3"' - self.assertEqual(q, expected) - def test_rank_multiple_conditions(self): - condition = Q.rank( - Q.userQuery(), - Q.dotProduct("embedding", {"feature1": 1}), - Q.weightedSet("tags", {"tag1": 2}), + condition = qb.rank( + qb.userQuery(), + qb.dotProduct("embedding", {"feature1": 1}), + qb.weightedSet("tags", {"tag1": 2}), ) - q = Query(select_fields="*").from_("documents").where(condition) + q = qb.select("*").from_("documents").where(condition) expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' self.assertEqual(q, expected) return q def test_non_empty_with_annotations(self): - annotated_field = Queryfield("comments").annotate({"filter": True}) - condition = Q.nonEmpty(annotated_field) - q = Query(select_fields="*").from_("posts").where(condition) + annotated_field = qb.Queryfield("comments").annotate({"filter": True}) + condition = qb.nonEmpty(annotated_field) + q = qb.select("*").from_("posts").where(condition) expected = "select * from posts where nonEmpty(({filter:true})comments)" self.assertEqual(q, expected) return q def test_weight_annotation(self): - condition = Queryfield("title").contains("heads", annotations={"weight": 200}) - q = Query(select_fields="*").from_("s1").where(condition) + condition = qb.Queryfield("title").contains( + "heads", annotations={"weight": 200} + ) + q = qb.select("*").from_("s1").where(condition) expected = 'select * from s1 where title contains({weight:200}"heads")' self.assertEqual(q, expected) return q def test_nearest_neighbor_annotations(self): - condition = Q.nearestNeighbor( + condition = qb.nearestNeighbor( field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} ) - q = Query(select_fields=["id, text"]).from_("m").where(condition) + q = qb.select(["id, 
text"]).from_("m").where(condition) expected = "select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" self.assertEqual(q, expected) return q def test_phrase(self): - text = Queryfield("text") - condition = text.contains(Q.phrase("st", "louis", "blues")) - query = Q.select("*").where(condition) + text = qb.Queryfield("text") + condition = text.contains(qb.phrase("st", "louis", "blues")) + query = qb.select("*").where(condition) expected = 'select * from * where text contains phrase("st", "louis", "blues")' self.assertEqual(query, expected) return query def test_near(self): - title = Queryfield("title") - condition = title.contains(Q.near("madonna", "saint")) - query = Q.select("*").where(condition) + title = qb.Queryfield("title") + condition = title.contains(qb.near("madonna", "saint")) + query = qb.select("*").where(condition) expected = 'select * from * where title contains near("madonna", "saint")' self.assertEqual(query, expected) return query def test_near_with_distance(self): - title = Queryfield("title") - condition = title.contains(Q.near("madonna", "saint", distance=10)) - query = Q.select("*").where(condition) + title = qb.Queryfield("title") + condition = title.contains(qb.near("madonna", "saint", distance=10)) + query = qb.select("*").where(condition) expected = 'select * from * where title contains ({distance:10}near("madonna", "saint"))' self.assertEqual(query, expected) return query def test_onear(self): - title = Queryfield("title") - condition = title.contains(Q.onear("madonna", "saint")) - query = Q.select("*").where(condition) + title = qb.Queryfield("title") + condition = title.contains(qb.onear("madonna", "saint")) + query = qb.select("*").where(condition) expected = 'select * from * where title contains onear("madonna", "saint")' self.assertEqual(query, expected) return query def test_onear_with_distance(self): - title = Queryfield("title") - condition = title.contains(Q.onear("madonna", "saint", distance=5)) - 
query = Q.select("*").where(condition) + title = qb.Queryfield("title") + condition = title.contains(qb.onear("madonna", "saint", distance=5)) + query = qb.select("*").where(condition) expected = 'select * from * where title contains ({distance:5}onear("madonna", "saint"))' self.assertEqual(query, expected) return query def test_same_element(self): - persons = Queryfield("persons") - first_name = Queryfield("first_name") - last_name = Queryfield("last_name") - year_of_birth = Queryfield("year_of_birth") + persons = qb.Queryfield("persons") + first_name = qb.Queryfield("first_name") + last_name = qb.Queryfield("last_name") + year_of_birth = qb.Queryfield("year_of_birth") condition = persons.contains( - Q.sameElement( + qb.sameElement( first_name.contains("Joe"), last_name.contains("Smith"), year_of_birth < 1940, ) ) - query = Q.select("*").from_("sd1").where(condition) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' self.assertEqual(query, expected) return query def test_equiv(self): - fieldName = Queryfield("fieldName") - condition = fieldName.contains(Q.equiv("A", "B")) - query = Q.select("*").where(condition) + fieldName = qb.Queryfield("fieldName") + condition = fieldName.contains(qb.equiv("A", "B")) + query = qb.select("*").where(condition) expected = 'select * from * where fieldName contains equiv("A", "B")' self.assertEqual(query, expected) return query def test_uri(self): - myUrlField = Queryfield("myUrlField") - condition = myUrlField.contains(Q.uri("vespa.ai/foo")) - query = Q.select("*").from_("sd1").where(condition) + myUrlField = qb.Queryfield("myUrlField") + condition = myUrlField.contains(qb.uri("vespa.ai/foo")) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' self.assertEqual(query, expected) return query def 
test_fuzzy(self): - myStringAttribute = Queryfield("f1") + myStringAttribute = qb.Queryfield("f1") annotations = {"prefixLength": 1, "maxEditDistance": 2} condition = myStringAttribute.contains( - Q.fuzzy("parantesis", annotations=annotations) + qb.fuzzy("parantesis", annotations=annotations) ) - query = Q.select("*").from_("sd1").where(condition) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' self.assertEqual(query, expected) return query def test_userinput(self): - condition = Q.userInput("@myvar") - query = Q.select("*").from_("sd1").where(condition) + condition = qb.userInput("@myvar") + query = qb.select("*").from_("sd1").where(condition) expected = "select * from sd1 where userInput(@myvar)" self.assertEqual(query, expected) return query def test_userinput_param(self): - condition = Q.userInput("@animal") - query = Q.select("*").from_("sd1").where(condition).param("animal", "panda") + condition = qb.userInput("@animal") + query = qb.select("*").from_("sd1").where(condition).param("animal", "panda") expected = "select * from sd1 where userInput(@animal)&animal=panda" self.assertEqual(query, expected) return query def test_userinput_with_defaultindex(self): - condition = Q.userInput("@myvar").annotate({"defaultIndex": "text"}) - query = Q.select("*").from_("sd1").where(condition) + condition = qb.userInput("@myvar").annotate({"defaultIndex": "text"}) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where {defaultIndex:"text"}userInput(@myvar)' self.assertEqual(query, expected) return query def test_in_operator_intfield(self): - integer_field = Queryfield("age") + integer_field = qb.Queryfield("age") condition = integer_field.in_(10, 20, 30) - query = Q.select("*").from_("sd1").where(condition) + query = qb.select("*").from_("sd1").where(condition) expected = "select * from sd1 where age in (10, 20, 30)" 
self.assertEqual(query, expected) return query def test_in_operator_stringfield(self): - string_field = Queryfield("status") + string_field = qb.Queryfield("status") condition = string_field.in_("active", "inactive") - query = Q.select("*").from_("sd1").where(condition) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where status in ("active", "inactive")' self.assertEqual(query, expected) return query def test_predicate(self): - condition = Q.predicate( + condition = qb.predicate( "predicate_field", attributes={"gender": "Female"}, range_attributes={"age": "20L"}, ) - query = Q.select("*").from_("sd1").where(condition) + query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":20L})' self.assertEqual(query, expected) return query def test_true(self): - query = Q.select("*").from_("sd1").where(True) + query = qb.select("*").from_("sd1").where(True) expected = "select * from sd1 where true" self.assertEqual(query, expected) return query def test_false(self): - query = Q.select("*").from_("sd1").where(False) + query = qb.select("*").from_("sd1").where(False) expected = "select * from sd1 where false" self.assertEqual(query, expected) return query diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py index 7681bfae..271d5511 100644 --- a/vespa/querybuilder/__init__.py +++ b/vespa/querybuilder/__init__.py @@ -1,34 +1,35 @@ -from .builder.builder import Query, Q, Queryfield, Condition +from .builder.builder import Q from .grouping.grouping import G import inspect # Import original classes # ...existing code... 
-# Automatically expose all static methods from Q and G classes -for cls in [Q, G]: +# Automatically expose all static methods from Q +for cls in [Q]: # do not expose G for now for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): if not name.startswith("_"): # Create function with same name and signature as the static method globals()[name] = method + +def get_function_members(cls): + return [ + name + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) + if not name.startswith("_") + ] + + # Create __all__ list dynamically __all__ = [ # Classes - "Query", + # "Query", "Q", - "Queryfield", + # "Queryfield", "G", - "Condition", + # "Condition", # Add all exposed functions - *( - name - for name, method in inspect.getmembers(Q, predicate=inspect.isfunction) - if not name.startswith("_") - ), - *( - name - for name, method in inspect.getmembers(G, predicate=inspect.isfunction) - if not name.startswith("_") - ), + *get_function_members(Q), + *get_function_members(G), ] diff --git a/vespa/querybuilder/builder/builder.py b/vespa/querybuilder/builder/builder.py index de92f382..527603fa 100644 --- a/vespa/querybuilder/builder/builder.py +++ b/vespa/querybuilder/builder/builder.py @@ -30,6 +30,10 @@ def __and__(self, other: Any) -> "Condition": def __or__(self, other: Any) -> "Condition": return Condition(f"{self.name} or {self._format_value(other)}") + # repr as str + def __repr__(self) -> str: + return self.name + def contains( self, value: Any, annotations: Optional[Dict[str, Any]] = None ) -> "Condition": @@ -317,18 +321,16 @@ def build(self, prepend_yql=False) -> str: class Q: @staticmethod - def select(*fields): - return Query(select_fields=list(fields)) + def select(fields): + return Query(select_fields=fields) @staticmethod - def p(*args): - if not args: - return Condition("") - else: - condition = args[0] - for arg in args[1:]: - condition = condition & arg - return condition + def any(*conditions): + return 
Condition.any(*conditions) + + @staticmethod + def all(*conditions): + return Condition.all(*conditions) @staticmethod def userQuery(value: str = "") -> Condition: From 1310616b916fc4222d70b7348c7ee410c7dfc4e9 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Fri, 6 Dec 2024 08:30:17 +0100 Subject: [PATCH 38/39] fix module export --- vespa/querybuilder/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py index 271d5511..da996f31 100644 --- a/vespa/querybuilder/__init__.py +++ b/vespa/querybuilder/__init__.py @@ -1,4 +1,4 @@ -from .builder.builder import Q +from .builder.builder import Q, Queryfield from .grouping.grouping import G import inspect @@ -25,8 +25,8 @@ def get_function_members(cls): __all__ = [ # Classes # "Query", - "Q", - # "Queryfield", + # "Q", + "Queryfield", "G", # "Condition", # Add all exposed functions From 0b7222bb6146951f5d2d58f162b0bcaa42a39f8a Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Fri, 6 Dec 2024 13:19:38 +0100 Subject: [PATCH 39/39] renames --- tests/unit/test_q.py | 175 +++++++++++++----------- vespa/querybuilder/__init__.py | 9 +- vespa/querybuilder/builder/builder.py | 44 +++--- vespa/querybuilder/grouping/grouping.py | 2 +- 4 files changed, 120 insertions(+), 110 deletions(-) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py index 4ae60351..be475a41 100644 --- a/tests/unit/test_q.py +++ b/tests/unit/test_q.py @@ -1,5 +1,5 @@ import unittest -from vespa.querybuilder import G +from vespa.querybuilder import Grouping import vespa.querybuilder as qb @@ -29,31 +29,31 @@ def test_geolocation_with_annotations(self): return q def test_select_specific_fields(self): - f1 = qb.Queryfield("f1") + f1 = qb.QueryField("f1") condition = f1.contains("v1") q = qb.select(["f1", "f2"]).from_("sd1").where(condition) self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') return q def test_select_from_specific_sources(self): - f1 = 
qb.Queryfield("f1") + f1 = qb.QueryField("f1") condition = f1.contains("v1") q = qb.select("*").from_("sd1").where(condition) self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') return q def test_select_from_multiples_sources(self): - f1 = qb.Queryfield("f1") + f1 = qb.QueryField("f1") condition = f1.contains("v1") q = qb.select("*").from_("sd1", "sd2").where(condition) self.assertEqual(q, 'select * from sd1, sd2 where f1 contains "v1"') return q def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): - f1 = qb.Queryfield("f1") - f2 = qb.Queryfield("f2") - f3 = qb.Queryfield("f3") - f4 = qb.Queryfield("f4") + f1 = qb.QueryField("f1") + f2 = qb.QueryField("f2") + f3 = qb.QueryField("f3") + f4 = qb.QueryField("f4") condition = ((f1.contains("v1") & f2.contains("v2")) | f3.contains("v3")) & ( ~f4.contains("v4") ) @@ -73,7 +73,7 @@ def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): return q def test_timeout(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") condition = f1.contains("madonna") q = qb.select("*").from_("sd1").where(condition).set_timeout(70) expected = 'select * from sd1 where title contains "madonna" timeout 70' @@ -82,9 +82,9 @@ def test_timeout(self): def test_matches(self): condition = ( - (qb.Queryfield("f1").matches("v1") & qb.Queryfield("f2").matches("v2")) - | qb.Queryfield("f3").matches("v3") - ) & ~qb.Queryfield("f4").matches("v4") + (qb.QueryField("f1").matches("v1") & qb.QueryField("f2").matches("v2")) + | qb.QueryField("f3").matches("v3") + ) & ~qb.QueryField("f4").matches("v4") q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4")' self.assertEqual(q, expected) @@ -92,9 +92,9 @@ def test_matches(self): def test_nested_queries(self): nested_query = ( - qb.Queryfield("f2").contains("2") & qb.Queryfield("f3").contains("3") - ) | 
(qb.Queryfield("f2").contains("4") & ~qb.Queryfield("f3").contains("5")) - condition = qb.Queryfield("f1").contains("1") & ~nested_query + qb.QueryField("f2").contains("2") & qb.QueryField("f3").contains("3") + ) | (qb.QueryField("f2").contains("4") & ~qb.QueryField("f3").contains("5")) + condition = qb.QueryField("f1").contains("1") & ~nested_query q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' self.assertEqual(q, expected) @@ -108,9 +108,9 @@ def test_userquery(self): return q def test_fields_duration(self): - f1 = qb.Queryfield("subject") - f2 = qb.Queryfield("display_date") - f3 = qb.Queryfield("duration") + f1 = qb.QueryField("subject") + f2 = qb.QueryField("display_date") + f3 = qb.QueryField("duration") q = qb.select([f1, f2]).from_("calendar").where(f3 > 0) expected = "select subject, display_date from calendar where duration > 0" self.assertEqual(q, expected) @@ -148,7 +148,7 @@ def test_build_many_nn_operators(self): return q def test_field_comparison_operators(self): - f1 = qb.Queryfield("age") + f1 = qb.QueryField("age") condition = (f1 > 30) & (f1 <= 50) q = qb.select("*").from_("people").where(condition) expected = "select * from people where age > 30 and age <= 50" @@ -156,7 +156,7 @@ def test_field_comparison_operators(self): return q def test_field_in_range(self): - f1 = qb.Queryfield("age") + f1 = qb.QueryField("age") condition = f1.in_range(18, 65) q = qb.select("*").from_("people").where(condition) expected = "select * from people where range(age, 18, 65)" @@ -164,7 +164,7 @@ def test_field_in_range(self): return q def test_field_annotation(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") annotations = {"highlight": True} annotated_field = f1.annotate(annotations) q = qb.select("*").from_("articles").where(annotated_field) @@ -173,7 +173,7 @@ def test_field_annotation(self): return q 
def test_condition_annotation(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") condition = f1.contains("Python") annotated_condition = condition.annotate({"filter": True}) q = qb.select("*").from_("articles").where(annotated_condition) @@ -182,7 +182,10 @@ def test_condition_annotation(self): return q def test_grouping_with_condition(self): - grouping = G.all(G.group("customer"), G.each(G.output(G.sum("price")))) + grouping = Grouping.all( + Grouping.group("customer"), + Grouping.each(Grouping.output(Grouping.sum("price"))), + ) q = qb.select("*").from_("purchase").where(True).set_limit(0).groupby(grouping) expected = "select * from purchase where true limit 0 | all(group(customer) each(output(sum(price))))" self.assertEqual(q, expected) @@ -190,12 +193,12 @@ def test_grouping_with_condition(self): def test_grouping_with_ordering_and_limiting(self): self.maxDiff = None - grouping = G.all( - G.group("customer"), - G.max(2), - G.precision(12), - G.order(-G.count()), - G.each(G.output(G.sum("price"))), + grouping = Grouping.all( + Grouping.group("customer"), + Grouping.max(2), + Grouping.precision(12), + Grouping.order(-Grouping.count()), + Grouping.each(Grouping.output(Grouping.sum("price"))), ) q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" @@ -203,9 +206,12 @@ def test_grouping_with_ordering_and_limiting(self): return q def test_grouping_with_map_keys(self): - grouping = G.all( - G.group("mymap.key"), - G.each(G.group("mymap.value"), G.each(G.output(G.count()))), + grouping = Grouping.all( + Grouping.group("mymap.key"), + Grouping.each( + Grouping.group("mymap.value"), + Grouping.each(Grouping.output(Grouping.count())), + ), ) q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(mymap.key) each(group(mymap.value) 
each(output(count()))))" @@ -213,28 +219,31 @@ def test_grouping_with_map_keys(self): return q def test_group_by_year(self): - grouping = G.all(G.group("time.year(a)"), G.each(G.output(G.count()))) + grouping = Grouping.all( + Grouping.group("time.year(a)"), + Grouping.each(Grouping.output(Grouping.count())), + ) q = qb.select("*").from_("purchase").where(True).groupby(grouping) expected = "select * from purchase where true | all(group(time.year(a)) each(output(count())))" self.assertEqual(q, expected) return q def test_grouping_with_date_agg(self): - grouping = G.all( - G.group("time.year(a)"), - G.each( - G.output(G.count()), - G.all( - G.group("time.monthofyear(a)"), - G.each( - G.output(G.count()), - G.all( - G.group("time.dayofmonth(a)"), - G.each( - G.output(G.count()), - G.all( - G.group("time.hourofday(a)"), - G.each(G.output(G.count())), + grouping = Grouping.all( + Grouping.group("time.year(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.monthofyear(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.dayofmonth(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.hourofday(a)"), + Grouping.each(Grouping.output(Grouping.count())), ), ), ), @@ -252,7 +261,7 @@ def test_grouping_with_date_agg(self): return q def test_add_parameter(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") condition = f1.contains("Python") q = ( qb.select("*") @@ -304,8 +313,8 @@ def test_wand_annotations(self): return q def test_weakand(self): - condition1 = qb.Queryfield("title").contains("Python") - condition2 = qb.Queryfield("description").contains("Programming") + condition1 = qb.QueryField("title").contains("Python") + condition2 = qb.QueryField("description").contains("Programming") condition = qb.weakAnd( condition1, condition2, annotations={"targetNumHits": 100} ) @@ -322,9 +331,9 @@ def test_geolocation(self): return 
q def test_condition_all_any(self): - c1 = qb.Queryfield("f1").contains("v1") - c2 = qb.Queryfield("f2").contains("v2") - c3 = qb.Queryfield("f3").contains("v3") + c1 = qb.QueryField("f1").contains("v1") + c2 = qb.QueryField("f2").contains("v2") + c3 = qb.QueryField("f3").contains("v3") condition = qb.all(c1, c2, qb.any(c3, ~c1)) q = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' @@ -343,7 +352,7 @@ def test_order_by_with_annotations(self): return q def test_field_comparison_methods_builtins(self): - f1 = qb.Queryfield("age") + f1 = qb.QueryField("age") condition = (f1 >= 18) & (f1 < 30) q = qb.select("*").from_("users").where(condition) expected = "select * from users where age >= 18 and age < 30" @@ -351,7 +360,7 @@ def test_field_comparison_methods_builtins(self): return q def test_field_comparison_methods(self): - f1 = qb.Queryfield("age") + f1 = qb.QueryField("age") condition = (f1.ge(18) & f1.lt(30)) | f1.eq(40) q = qb.select("*").from_("users").where(condition) expected = "select * from users where (age >= 18 and age < 30) or age = 40" @@ -359,7 +368,7 @@ def test_field_comparison_methods(self): return q def test_filter_annotation(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") condition = f1.contains("Python").annotate({"filter": True}) q = qb.select("*").from_("articles").where(condition) expected = 'select * from articles where {filter:true}title contains "Python"' @@ -367,7 +376,7 @@ def test_filter_annotation(self): return q def test_non_empty(self): - condition = qb.nonEmpty(qb.Queryfield("comments").eq("any_value")) + condition = qb.nonEmpty(qb.QueryField("comments").eq("any_value")) q = qb.select("*").from_("posts").where(condition) expected = 'select * from posts where nonEmpty(comments = "any_value")' self.assertEqual(q, expected) @@ -381,7 +390,7 @@ def test_dotproduct(self): return q def 
test_in_range_string_values(self): - f1 = qb.Queryfield("date") + f1 = qb.QueryField("date") condition = f1.in_range("2021-01-01", "2021-12-31") q = qb.select("*").from_("events").where(condition) expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" @@ -389,7 +398,7 @@ def test_in_range_string_values(self): return q def test_condition_inversion(self): - f1 = qb.Queryfield("status") + f1 = qb.QueryField("status") condition = ~f1.eq("inactive") q = qb.select("*").from_("users").where(condition) expected = 'select * from users where !(status = "inactive")' @@ -397,7 +406,7 @@ def test_condition_inversion(self): return q def test_multiple_parameters(self): - f1 = qb.Queryfield("title") + f1 = qb.QueryField("title") condition = f1.contains("Python") q = ( qb.select("*") @@ -411,11 +420,13 @@ def test_multiple_parameters(self): return q def test_multiple_groupings(self): - grouping = G.all( - G.group("category"), - G.max(10), - G.output(G.count()), - G.each(G.group("subcategory"), G.output(G.summary())), + grouping = Grouping.all( + Grouping.group("category"), + Grouping.max(10), + Grouping.output(Grouping.count()), + Grouping.each( + Grouping.group("subcategory"), Grouping.output(Grouping.summary()) + ), ) q = qb.select("*").from_("products").groupby(grouping) expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" @@ -441,7 +452,7 @@ def test_rank_multiple_conditions(self): return q def test_non_empty_with_annotations(self): - annotated_field = qb.Queryfield("comments").annotate({"filter": True}) + annotated_field = qb.QueryField("comments").annotate({"filter": True}) condition = qb.nonEmpty(annotated_field) q = qb.select("*").from_("posts").where(condition) expected = "select * from posts where nonEmpty(({filter:true})comments)" @@ -449,7 +460,7 @@ def test_non_empty_with_annotations(self): return q def test_weight_annotation(self): - condition = 
qb.Queryfield("title").contains( + condition = qb.QueryField("title").contains( "heads", annotations={"weight": 200} ) q = qb.select("*").from_("s1").where(condition) @@ -467,7 +478,7 @@ def test_nearest_neighbor_annotations(self): return q def test_phrase(self): - text = qb.Queryfield("text") + text = qb.QueryField("text") condition = text.contains(qb.phrase("st", "louis", "blues")) query = qb.select("*").where(condition) expected = 'select * from * where text contains phrase("st", "louis", "blues")' @@ -475,7 +486,7 @@ def test_phrase(self): return query def test_near(self): - title = qb.Queryfield("title") + title = qb.QueryField("title") condition = title.contains(qb.near("madonna", "saint")) query = qb.select("*").where(condition) expected = 'select * from * where title contains near("madonna", "saint")' @@ -483,7 +494,7 @@ def test_near(self): return query def test_near_with_distance(self): - title = qb.Queryfield("title") + title = qb.QueryField("title") condition = title.contains(qb.near("madonna", "saint", distance=10)) query = qb.select("*").where(condition) expected = 'select * from * where title contains ({distance:10}near("madonna", "saint"))' @@ -491,7 +502,7 @@ def test_near_with_distance(self): return query def test_onear(self): - title = qb.Queryfield("title") + title = qb.QueryField("title") condition = title.contains(qb.onear("madonna", "saint")) query = qb.select("*").where(condition) expected = 'select * from * where title contains onear("madonna", "saint")' @@ -499,7 +510,7 @@ def test_onear(self): return query def test_onear_with_distance(self): - title = qb.Queryfield("title") + title = qb.QueryField("title") condition = title.contains(qb.onear("madonna", "saint", distance=5)) query = qb.select("*").where(condition) expected = 'select * from * where title contains ({distance:5}onear("madonna", "saint"))' @@ -507,10 +518,10 @@ def test_onear_with_distance(self): return query def test_same_element(self): - persons = qb.Queryfield("persons") - 
first_name = qb.Queryfield("first_name") - last_name = qb.Queryfield("last_name") - year_of_birth = qb.Queryfield("year_of_birth") + persons = qb.QueryField("persons") + first_name = qb.QueryField("first_name") + last_name = qb.QueryField("last_name") + year_of_birth = qb.QueryField("year_of_birth") condition = persons.contains( qb.sameElement( first_name.contains("Joe"), @@ -524,7 +535,7 @@ def test_same_element(self): return query def test_equiv(self): - fieldName = qb.Queryfield("fieldName") + fieldName = qb.QueryField("fieldName") condition = fieldName.contains(qb.equiv("A", "B")) query = qb.select("*").where(condition) expected = 'select * from * where fieldName contains equiv("A", "B")' @@ -532,7 +543,7 @@ def test_equiv(self): return query def test_uri(self): - myUrlField = qb.Queryfield("myUrlField") + myUrlField = qb.QueryField("myUrlField") condition = myUrlField.contains(qb.uri("vespa.ai/foo")) query = qb.select("*").from_("sd1").where(condition) expected = 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' @@ -540,7 +551,7 @@ def test_uri(self): return query def test_fuzzy(self): - myStringAttribute = qb.Queryfield("f1") + myStringAttribute = qb.QueryField("f1") annotations = {"prefixLength": 1, "maxEditDistance": 2} condition = myStringAttribute.contains( qb.fuzzy("parantesis", annotations=annotations) @@ -572,7 +583,7 @@ def test_userinput_with_defaultindex(self): return query def test_in_operator_intfield(self): - integer_field = qb.Queryfield("age") + integer_field = qb.QueryField("age") condition = integer_field.in_(10, 20, 30) query = qb.select("*").from_("sd1").where(condition) expected = "select * from sd1 where age in (10, 20, 30)" @@ -580,7 +591,7 @@ def test_in_operator_intfield(self): return query def test_in_operator_stringfield(self): - string_field = qb.Queryfield("status") + string_field = qb.QueryField("status") condition = string_field.in_("active", "inactive") query = qb.select("*").from_("sd1").where(condition) 
expected = 'select * from sd1 where status in ("active", "inactive")' diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py index da996f31..ecdad8bb 100644 --- a/vespa/querybuilder/__init__.py +++ b/vespa/querybuilder/__init__.py @@ -1,5 +1,5 @@ -from .builder.builder import Q, Queryfield -from .grouping.grouping import G +from .builder.builder import Q, QueryField +from .grouping.grouping import Grouping import inspect # Import original classes @@ -26,10 +26,9 @@ def get_function_members(cls): # Classes # "Query", # "Q", - "Queryfield", - "G", + "QueryField", + "Grouping", # "Condition", # Add all exposed functions *get_function_members(Q), - *get_function_members(G), ] diff --git a/vespa/querybuilder/builder/builder.py b/vespa/querybuilder/builder/builder.py index 527603fa..bad58251 100644 --- a/vespa/querybuilder/builder/builder.py +++ b/vespa/querybuilder/builder/builder.py @@ -3,7 +3,7 @@ @dataclass -class Queryfield: +class QueryField: name: str def __eq__(self, other: Any) -> "Condition": @@ -117,7 +117,7 @@ def _format_annotation_value(value: Any) -> str: return ( "{" + ",".join( - f'"{k}":{Queryfield._format_annotation_value(v)}' + f'"{k}":{QueryField._format_annotation_value(v)}' for k, v in value.items() ) + "}" @@ -125,7 +125,7 @@ def _format_annotation_value(value: Any) -> str: elif isinstance(value, list): return ( "[" - + ",".join(f"{Queryfield._format_annotation_value(v)}" for v in value) + + ",".join(f"{QueryField._format_annotation_value(v)}" for v in value) + "]" ) else: @@ -167,7 +167,7 @@ def __invert__(self) -> "Condition": def annotate(self, annotations: Dict[str, Any]) -> "Condition": annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) return Condition(f"{{{annotations_str}}}{self.expression}") @@ -204,7 +204,7 @@ def any(cls, *conditions: "Condition") -> "Condition": class Query: def __init__( - self, 
select_fields: Union[str, List[str], List[Queryfield]], prepend_yql=False + self, select_fields: Union[str, List[str], List[QueryField]], prepend_yql=False ): self.select_fields = ( ", ".join(select_fields) @@ -238,8 +238,8 @@ def from_(self, *sources: str) -> "Query": self.sources = ", ".join(sources) return self - def where(self, condition: Union[Condition, Queryfield, bool]) -> "Query": - if isinstance(condition, Queryfield): + def where(self, condition: Union[Condition, QueryField, bool]) -> "Query": + if isinstance(condition, QueryField): self.condition = condition elif isinstance(condition, bool): self.condition = Condition("true") if condition else Condition("false") @@ -256,7 +256,7 @@ def order_by_field( direction = "asc" if ascending else "desc" if annotations: annotations_str = ",".join( - f'"{k}":{Queryfield._format_annotation_value(v)}' + f'"{k}":{QueryField._format_annotation_value(v)}' for k, v in annotations.items() ) self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") @@ -344,7 +344,7 @@ def dotProduct( expr = f"dotProduct({field}, {vector_str})" if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -358,15 +358,15 @@ def weightedSet( expr = f"weightedSet({field}, {vector_str})" if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" return Condition(expr) @staticmethod - def nonEmpty(condition: Union[Condition, Queryfield]) -> Condition: - if isinstance(condition, Queryfield): + def nonEmpty(condition: Union[Condition, QueryField]) -> Condition: + if isinstance(condition, QueryField): expr = str(condition) else: expr = condition.build() @@ -387,7 +387,7 @@ def wand( expr = f"wand({field}, 
{weights_str})" if annotations: annotations_str = ", ".join( - f"{k}: {Queryfield._format_annotation_value(v)}" + f"{k}: {QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -399,7 +399,7 @@ def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: expr = f"weakAnd({conditions_str})" if annotations: annotations_str = ",".join( - f'"{k}": {Queryfield._format_annotation_value(v)}' + f'"{k}": {QueryField._format_annotation_value(v)}' for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -416,7 +416,7 @@ def geoLocation( expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -429,7 +429,7 @@ def nearestNeighbor( if "targetHits" not in annotations: raise ValueError("targetHits annotation is required") annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) return Condition( @@ -447,7 +447,7 @@ def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: expr = f"phrase({terms_str})" if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -466,7 +466,7 @@ def near( annotations.update(kwargs) if annotations: annotations_str = ", ".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -485,7 +485,7 @@ def onear( annotations.update(kwargs) if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + 
f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -508,7 +508,7 @@ def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: expr = f'uri("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -519,7 +519,7 @@ def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition expr = f'fuzzy("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" @@ -537,7 +537,7 @@ def userInput( expr = f'userInput("{value}")' if annotations: annotations_str = ",".join( - f"{k}:{Queryfield._format_annotation_value(v)}" + f"{k}:{QueryField._format_annotation_value(v)}" for k, v in annotations.items() ) expr = f"({{{annotations_str}}}{expr})" diff --git a/vespa/querybuilder/grouping/grouping.py b/vespa/querybuilder/grouping/grouping.py index f9494d84..ea75570f 100644 --- a/vespa/querybuilder/grouping/grouping.py +++ b/vespa/querybuilder/grouping/grouping.py @@ -1,7 +1,7 @@ from typing import Union -class G: +class Grouping: @staticmethod def all(*args) -> str: return "all(" + " ".join(args) + ")"