From a39891166b1ab7ee73abaf1c4928f4e8c7478990 Mon Sep 17 00:00:00 2001 From: loryruta Date: Wed, 10 Apr 2024 10:24:15 +0200 Subject: [PATCH] query: add any/all #24 --- objectbox/c.py | 16 +++- objectbox/query_builder.py | 117 ++++++++++++++----------- tests/test_hnsw.py | 30 +++---- tests/test_query.py | 170 +++++++++++++++++++++++++------------ 4 files changed, 212 insertions(+), 121 deletions(-) diff --git a/objectbox/c.py b/objectbox/c.py index 1a65936..63ae72e 100644 --- a/objectbox/c.py +++ b/objectbox/c.py @@ -68,6 +68,8 @@ def shlib_name(library: str) -> str: obx_id = ctypes.c_uint64 obx_qb_cond = ctypes.c_int +obx_qb_cond_p = ctypes.POINTER(obx_qb_cond) + # enums OBXPropertyType = ctypes.c_int OBXPropertyFlags = ctypes.c_int @@ -327,6 +329,16 @@ def c_voidp_as_bytes(voidp, size): return memoryview(ctypes.cast(voidp, ctypes.POINTER(ctypes.c_ubyte * size))[0]).tobytes() +def py_list_to_c_array(py_list: List[Any], c_type): + """ Converts the given python list into a C array. """ + return (c_type * len(py_list))(*py_list) + + +def py_list_to_c_pointer(py_list: List[Any], c_type): + """ Converts the given python list into a C array and returns a pointer type. """ + return ctypes.cast(py_list_to_c_array(py_list, c_type), ctypes.POINTER(c_type)) + + # OBX_model* (void); obx_model = c_fn('obx_model', OBX_model_p, []) @@ -656,10 +668,10 @@ def c_voidp_as_bytes(voidp, size): [OBX_query_builder_p, obx_schema_id, ctypes.c_void_p, ctypes.c_size_t]) # OBX_C_API obx_qb_cond obx_qb_all(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count); -obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t]) +obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t]) # OBX_C_API obx_qb_cond obx_qb_any(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count); -obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t]) +obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t]) # OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias); obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p]) diff --git a/objectbox/query_builder.py b/objectbox/query_builder.py index 3687bd8..ab6086e 100644 --- a/objectbox/query_builder.py +++ b/objectbox/query_builder.py @@ -30,85 +30,90 @@ def error_code(self) -> int: def error_message(self) -> str: return obx_qb_error_message(self._c_builder) - def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def not_equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def not_equals_string(self, prop: Union[int, str, Property], value: str, + case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def starts_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def starts_with_string(self, prop: Union[int, str, Property], value: str, + case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def greater_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def greater_than_string(self, prop: Union[int, str, Property], value: str, + case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def greater_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def greater_or_equal_string(self, prop: Union[int, str, Property], value: str, + case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def less_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True): + def less_or_equal_string(self, prop: Union[int, str, Property], value: str, + case_sensitive: bool = True) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive) - return self + cond = obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive) + return cond - def equals_int(self, prop: Union[int, str, Property], value: int): + def equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_equals_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_equals_int(self._c_builder, prop_id, value) + return cond - def not_equals_int(self, prop: Union[int, str, Property], value: int): + def not_equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_not_equals_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_not_equals_int(self._c_builder, prop_id, value) + return cond - def greater_than_int(self, prop: Union[int, str, Property], value: int): + def greater_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_greater_than_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_greater_than_int(self._c_builder, prop_id, value) + return cond - def greater_or_equal_int(self, prop: Union[int, str, Property], value: int): + def greater_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_greater_or_equal_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_greater_or_equal_int(self._c_builder, prop_id, value) + return cond - def less_than_int(self, prop: Union[int, str, Property], value: int): + def less_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_less_than_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_less_than_int(self._c_builder, prop_id, value) + return cond - def less_or_equal_int(self, prop: Union[int, str, Property], value: int): + def less_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_less_or_equal_int(self._c_builder, prop_id, value) - return self + cond = obx_qb_less_or_equal_int(self._c_builder, prop_id, value) + return cond - def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int): + def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int) -> obx_qb_cond: prop_id = self._get_property_id(prop) - obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b) - return self + cond = obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b) + return cond def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: Union[np.ndarray, List[float]], element_count: int): @@ -117,11 +122,21 @@ def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: U raise Exception(f"query_vector dtype must be float32") query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) else: # List[float] - query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector) + query_vector_data = py_list_to_c_array(query_vector, ctypes.c_float) prop_id = self._get_property_id(prop) - obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count) - return self + cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count) + return cond + + def any(self, conditions: List[obx_qb_cond]) -> obx_qb_cond: + c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond) + cond = obx_qb_any(self._c_builder, c_conditions, len(conditions)) + return cond + + def all(self, conditions: List[obx_qb_cond]) -> obx_qb_cond: + c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond) + cond = obx_qb_all(self._c_builder, c_conditions, len(conditions)) + return cond def build(self) -> Query: c_query = obx_query(self._c_builder) diff --git a/tests/test_hnsw.py b/tests/test_hnsw.py index baf5834..1a5eeae 100644 --- a/tests/test_hnsw.py +++ b/tests/test_hnsw.py @@ -57,9 +57,9 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i assert len(expected_result) == k # Run ANN with OBX - query_builder = QueryBuilder(db, box) - query_builder.nearest_neighbors_f32("vector", query_point, k) - query = query_builder.build() + qb = box.query() + qb.nearest_neighbors_f32("vector", query_point, k) + query = qb.build() obx_result = [id_ for id_, score in query.find_ids_with_scores()] # Ignore score assert len(obx_result) == k @@ -100,10 +100,10 @@ def test_combined_nn_search(): assert box.count() == 9 # Test condition + NN search - query = box.query() \ - .nearest_neighbors_f32("vector", [4.1, 4.2], 6) \ - .contains_string("name", "red", case_sensitive=False) \ - .build() + qb = box.query() + qb.nearest_neighbors_f32("vector", [4.1, 4.2], 6) + qb.contains_string("name", "red", case_sensitive=False) + query = qb.build() # 4, 5, 3, 6, 2, 7 # Filtered: 3, 6, 7 search_results = query.find_with_scores() @@ -120,20 +120,20 @@ def test_combined_nn_search(): assert search_results[0][0].name == "Red apple" # Regular condition + NN search - query = box.query() \ - .nearest_neighbors_f32("vector", [9.2, 8.9], 7) \ - .starts_with_string("name", "Blue", case_sensitive=True) \ - .build() + qb = box.query() + qb.nearest_neighbors_f32("vector", [9.2, 8.9], 7) + qb.starts_with_string("name", "Blue", case_sensitive=True) + query = qb.build() search_results = query.find_with_scores() assert len(search_results) == 1 assert search_results[0][0].name == "Blue sea" # Regular condition + NN search - query = box.query() \ - .nearest_neighbors_f32("vector", [7.7, 7.7], 8) \ - .contains_string("name", "blue", case_sensitive=False) \ - .build() + qb = box.query() + qb.nearest_neighbors_f32("vector", [7.7, 7.7], 8) + qb.contains_string("name", "blue", case_sensitive=False) + query = qb.build() # 8, 7, 9, 6, 5, 4, 3, 2 # Filtered: 9, 5, 4, 2 search_results = query.find_ids_with_scores() diff --git a/tests/test_query.py b/tests/test_query.py index 588b787..974ad68 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,11 +2,11 @@ from objectbox.model import * from objectbox.c import * import pytest -from tests.common import (load_empty_test_objectbox, autocleanup) +from tests.common import (load_empty_test_objectbox, create_test_objectbox, autocleanup) from tests.model import TestEntity -def test_query_basics(): +def test_basics(): ob = load_empty_test_objectbox() box = objectbox.Box(ob, TestEntity) object1 = TestEntity() @@ -21,58 +21,58 @@ def test_query_basics(): # String queries str_prop: Property = TestEntity.get_property("str") - query = box.query() \ - .equals_string(str_prop._id, "bar", True) \ - .build() + qb = box.query() + qb.equals_string(str_prop._id, "bar", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "bar" - query = box.query() \ - .not_equals_string(str_prop._id, "bar", True) \ - .build() + qb = box.query() + qb.not_equals_string(str_prop._id, "bar", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "foo" - query = box.query() \ - .contains_string(str_prop._id, "ba", True) \ - .build() + qb = box.query() + qb.contains_string(str_prop._id, "ba", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "bar" - query = box.query() \ - .starts_with_string(str_prop._id, "f", True) \ - .build() + qb = box.query() + qb.starts_with_string(str_prop._id, "f", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "foo" - query = box.query() \ - .ends_with_string(str_prop._id, "o", True) \ - .build() + qb = box.query() + qb.ends_with_string(str_prop._id, "o", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "foo" - query = box.query() \ - .greater_than_string(str_prop._id, "bar", True) \ - .build() + qb = box.query() + qb.greater_than_string(str_prop._id, "bar", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "foo" - query = box.query() \ - .greater_or_equal_string(str_prop._id, "bar", True) \ - .build() + qb = box.query() + qb.greater_or_equal_string(str_prop._id, "bar", True) + query = qb.build() assert query.count() == 2 assert query.find()[0].str == "foo" assert query.find()[1].str == "bar" - query = box.query() \ - .less_than_string(str_prop._id, "foo", True) \ - .build() + qb = box.query() + qb.less_than_string(str_prop._id, "foo", True) + query = qb.build() assert query.count() == 1 assert query.find()[0].str == "bar" - query = box.query() \ - .less_or_equal_string(str_prop._id, "foo", True) \ - .build() + qb = box.query() + qb.less_or_equal_string(str_prop._id, "foo", True) + query = qb.build() assert query.count() == 2 assert query.find()[0].str == "foo" assert query.find()[1].str == "bar" @@ -80,47 +80,47 @@ def test_query_basics(): # Int queries int_prop: Property = TestEntity.get_property("int64") - query = box.query() \ - .equals_int(int_prop._id, 123) \ - .build() + qb = box.query() + qb.equals_int(int_prop._id, 123) + query = qb.build() assert query.count() == 1 assert query.find()[0].int64 == 123 - query = box.query() \ - .not_equals_int(int_prop._id, 123) \ - .build() + qb = box.query() + qb.not_equals_int(int_prop._id, 123) + query = qb.build() assert query.count() == 1 assert query.find()[0].int64 == 456 - query = box.query() \ - .greater_than_int(int_prop._id, 123) \ - .build() + qb = box.query() + qb.greater_than_int(int_prop._id, 123) + query = qb.build() assert query.count() == 1 assert query.find()[0].int64 == 456 - query = box.query() \ - .greater_or_equal_int(int_prop._id, 123) \ - .build() + qb = box.query() + qb.greater_or_equal_int(int_prop._id, 123) + query = qb.build() assert query.count() == 2 assert query.find()[0].int64 == 123 assert query.find()[1].int64 == 456 - query = box.query() \ - .less_than_int(int_prop._id, 456) \ - .build() + qb = box.query() + qb.less_than_int(int_prop._id, 456) + query = qb.build() assert query.count() == 1 assert query.find()[0].int64 == 123 - query = box.query() \ - .less_or_equal_int(int_prop._id, 456) \ - .build() + qb = box.query() + qb.less_or_equal_int(int_prop._id, 456) + query = qb.build() assert query.count() == 2 assert query.find()[0].int64 == 123 assert query.find()[1].int64 == 456 - query = box.query() \ - .between_2ints(int_prop._id, 100, 200) \ - .build() + qb = box.query() + qb.between_2ints(int_prop._id, 100, 200) + query = qb.build() assert query.count() == 1 assert query.find()[0].int64 == 123 @@ -143,9 +143,9 @@ def test_offset_limit(): int_prop: Property = TestEntity.get_property("int64") - query = box.query() \ - .equals_int(int_prop._id, 0) \ - .build() + qb = box.query() + qb.equals_int(int_prop._id, 0) + query = qb.build() assert query.count() == 4 query.offset(2) @@ -160,3 +160,67 @@ def test_offset_limit(): query.offset(0) query.limit(0) assert len(query.find()) == 4 + + +def test_any_all(): + db = create_test_objectbox() + + box = objectbox.Box(db, TestEntity) + + box.put(TestEntity(str="Foo", int32=10, int8=2, float32=3.14, bool=True)) + box.put(TestEntity(str="FooBar", int32=100, int8=50, float32=2.0, bool=True)) + box.put(TestEntity(str="Bar", int32=99, int8=127, float32=1.0, bool=False)) + box.put(TestEntity(str="Test", int32=1, int8=1, float32=0.0001, bool=True)) + box.put(TestEntity(str="test", int32=3232, int8=88, float32=1.0101, bool=False)) + box.put(TestEntity(str="Foo or BAR?", int32=0, int8=0, float32=0.0, bool=False)) + box.put(TestEntity(str="Just a test", int32=6, int8=6, float32=6.111, bool=False)) + box.put(TestEntity(str="EXAMPLE", int32=37, int8=37, float32=100, bool=True)) + + # Test all + qb = box.query() + qb.all([ + qb.starts_with_string("str", "Foo"), + qb.equals_int("int32", 10) + ]) + query = qb.build() + ids = query.find_ids() + assert ids == [1] + + # Test any + qb = box.query() + qb.any([ + qb.starts_with_string("str", "Test", case_sensitive=False), + qb.ends_with_string("str", "?"), + qb.equals_int("int32", 37) + ]) + query = qb.build() + ids = query.find_ids() + # 4, 5, 6, 8 + assert ids == [4, 5, 6, 8] + + # Test all/any + qb = box.query() + qb.any([ + qb.all([qb.contains_string("str", "Foo"), qb.less_than_int("int32", 100)]), + qb.equals_string("str", "Test", case_sensitive=False) + ]) + query = qb.build() + ids = query.find_ids() + # 1, 4, 5, 6 + assert ids == [1, 4, 5, 6] + + # Test all/any + qb = box.query() + qb.all([ + qb.any([ + qb.contains_string("str", "foo", case_sensitive=False), + qb.contains_string("str", "bar", case_sensitive=False) + ]), + qb.greater_than_int("int8", 30) + ]) + query = qb.build() + ids = query.find_ids() + # 2, 3 + assert ids == [2, 3] + +