Skip to content

Commit

Permalink
query: add any/all objectbox#24
Browse files Browse the repository at this point in the history
  • Loading branch information
loryruta committed Apr 10, 2024
1 parent 557738b commit a398911
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 121 deletions.
16 changes: 14 additions & 2 deletions objectbox/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def shlib_name(library: str) -> str:
obx_id = ctypes.c_uint64
obx_qb_cond = ctypes.c_int

obx_qb_cond_p = ctypes.POINTER(obx_qb_cond)

# enums
OBXPropertyType = ctypes.c_int
OBXPropertyFlags = ctypes.c_int
Expand Down Expand Up @@ -327,6 +329,16 @@ def c_voidp_as_bytes(voidp, size):
return memoryview(ctypes.cast(voidp, ctypes.POINTER(ctypes.c_ubyte * size))[0]).tobytes()


def py_list_to_c_array(py_list: List[Any], c_type):
""" Converts the given python list into a C array. """
return (c_type * len(py_list))(*py_list)


def py_list_to_c_pointer(py_list: List[Any], c_type):
""" Converts the given python list into a C array and returns a pointer type. """
return ctypes.cast(py_list_to_c_array(py_list, c_type), ctypes.POINTER(c_type))


# OBX_model* (void);
obx_model = c_fn('obx_model', OBX_model_p, [])

Expand Down Expand Up @@ -656,10 +668,10 @@ def c_voidp_as_bytes(voidp, size):
[OBX_query_builder_p, obx_schema_id, ctypes.c_void_p, ctypes.c_size_t])

# OBX_C_API obx_qb_cond obx_qb_all(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count);
obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t])
obx_qb_all = c_fn('obx_qb_all', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t])

# OBX_C_API obx_qb_cond obx_qb_any(OBX_query_builder* builder, const obx_qb_cond conditions[], size_t count);
obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond, ctypes.c_size_t])
obx_qb_any = c_fn('obx_qb_any', obx_qb_cond, [OBX_query_builder_p, obx_qb_cond_p, ctypes.c_size_t])

# OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias);
obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p])
Expand Down
117 changes: 66 additions & 51 deletions objectbox/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,85 +30,90 @@ def error_code(self) -> int:
def error_message(self) -> str:
return obx_qb_error_message(self._c_builder)

def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def not_equals_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def not_equals_string(self, prop: Union[int, str, Property], value: str,
case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_not_equals_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def contains_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_contains_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def starts_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def starts_with_string(self, prop: Union[int, str, Property], value: str,
case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_starts_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def ends_with_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_ends_with_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def greater_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def greater_than_string(self, prop: Union[int, str, Property], value: str,
case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_greater_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def greater_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def greater_or_equal_string(self, prop: Union[int, str, Property], value: str,
case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_greater_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def less_than_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_less_than_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def less_or_equal_string(self, prop: Union[int, str, Property], value: str, case_sensitive: bool = True):
def less_or_equal_string(self, prop: Union[int, str, Property], value: str,
case_sensitive: bool = True) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return self
cond = obx_qb_less_or_equal_string(self._c_builder, prop_id, c_str(value), case_sensitive)
return cond

def equals_int(self, prop: Union[int, str, Property], value: int):
def equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_equals_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_equals_int(self._c_builder, prop_id, value)
return cond

def not_equals_int(self, prop: Union[int, str, Property], value: int):
def not_equals_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_not_equals_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_not_equals_int(self._c_builder, prop_id, value)
return cond

def greater_than_int(self, prop: Union[int, str, Property], value: int):
def greater_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_greater_than_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_greater_than_int(self._c_builder, prop_id, value)
return cond

def greater_or_equal_int(self, prop: Union[int, str, Property], value: int):
def greater_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_greater_or_equal_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_greater_or_equal_int(self._c_builder, prop_id, value)
return cond

def less_than_int(self, prop: Union[int, str, Property], value: int):
def less_than_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_less_than_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_less_than_int(self._c_builder, prop_id, value)
return cond

def less_or_equal_int(self, prop: Union[int, str, Property], value: int):
def less_or_equal_int(self, prop: Union[int, str, Property], value: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_less_or_equal_int(self._c_builder, prop_id, value)
return self
cond = obx_qb_less_or_equal_int(self._c_builder, prop_id, value)
return cond

def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int):
def between_2ints(self, prop: Union[int, str, Property], value_a: int, value_b: int) -> obx_qb_cond:
prop_id = self._get_property_id(prop)
obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
return self
cond = obx_qb_between_2ints(self._c_builder, prop_id, value_a, value_b)
return cond

def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: Union[np.ndarray, List[float]],
element_count: int):
Expand All @@ -117,11 +122,21 @@ def nearest_neighbors_f32(self, prop: Union[int, str, Property], query_vector: U
raise Exception(f"query_vector dtype must be float32")
query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
else: # List[float]
query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector)
query_vector_data = py_list_to_c_array(query_vector, ctypes.c_float)

prop_id = self._get_property_id(prop)
obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count)
return self
cond = obx_qb_nearest_neighbors_f32(self._c_builder, prop_id, query_vector_data, element_count)
return cond

def any(self, conditions: List[obx_qb_cond]) -> obx_qb_cond:
c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond)
cond = obx_qb_any(self._c_builder, c_conditions, len(conditions))
return cond

def all(self, conditions: List[obx_qb_cond]) -> obx_qb_cond:
c_conditions = py_list_to_c_pointer(conditions, obx_qb_cond)
cond = obx_qb_all(self._c_builder, c_conditions, len(conditions))
return cond

def build(self) -> Query:
c_query = obx_query(self._c_builder)
Expand Down
30 changes: 15 additions & 15 deletions tests/test_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
assert len(expected_result) == k

# Run ANN with OBX
query_builder = QueryBuilder(db, box)
query_builder.nearest_neighbors_f32("vector", query_point, k)
query = query_builder.build()
qb = box.query()
qb.nearest_neighbors_f32("vector", query_point, k)
query = qb.build()
obx_result = [id_ for id_, score in query.find_ids_with_scores()] # Ignore score
assert len(obx_result) == k

Expand Down Expand Up @@ -100,10 +100,10 @@ def test_combined_nn_search():
assert box.count() == 9

# Test condition + NN search
query = box.query() \
.nearest_neighbors_f32("vector", [4.1, 4.2], 6) \
.contains_string("name", "red", case_sensitive=False) \
.build()
qb = box.query()
qb.nearest_neighbors_f32("vector", [4.1, 4.2], 6)
qb.contains_string("name", "red", case_sensitive=False)
query = qb.build()
# 4, 5, 3, 6, 2, 7
# Filtered: 3, 6, 7
search_results = query.find_with_scores()
Expand All @@ -120,20 +120,20 @@ def test_combined_nn_search():
assert search_results[0][0].name == "Red apple"

# Regular condition + NN search
query = box.query() \
.nearest_neighbors_f32("vector", [9.2, 8.9], 7) \
.starts_with_string("name", "Blue", case_sensitive=True) \
.build()
qb = box.query()
qb.nearest_neighbors_f32("vector", [9.2, 8.9], 7)
qb.starts_with_string("name", "Blue", case_sensitive=True)
query = qb.build()

search_results = query.find_with_scores()
assert len(search_results) == 1
assert search_results[0][0].name == "Blue sea"

# Regular condition + NN search
query = box.query() \
.nearest_neighbors_f32("vector", [7.7, 7.7], 8) \
.contains_string("name", "blue", case_sensitive=False) \
.build()
qb = box.query()
qb.nearest_neighbors_f32("vector", [7.7, 7.7], 8)
qb.contains_string("name", "blue", case_sensitive=False)
query = qb.build()
# 8, 7, 9, 6, 5, 4, 3, 2
# Filtered: 9, 5, 4, 2
search_results = query.find_ids_with_scores()
Expand Down
Loading

0 comments on commit a398911

Please sign in to comment.