Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose search function with pre-filter for ANN #302

Merged
merged 54 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
0dbe5b2
[WIP] CAGRA - separable compilation for distance computation
achirkin Aug 16, 2024
93b0439
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 16, 2024
ba52b13
Fix style
achirkin Aug 16, 2024
434e50a
Add missing multi-kernel implementation
achirkin Aug 19, 2024
6352550
Move common code out of virtual functions scope (aiming for more inli…
achirkin Aug 19, 2024
d161f79
Make small descriptor functions into fields
achirkin Aug 20, 2024
35c3813
Minor updates to improve reg count
achirkin Aug 20, 2024
4b5dcd3
Refactor distance_core -> compute_distance, and update the instance list
achirkin Aug 21, 2024
e5878db
Merge remote-tracking branch 'rapidsai/branch-24.10' into enh-cagra-s…
achirkin Aug 21, 2024
385a8c4
Make the compute_distance instances controlled from a single place
achirkin Aug 21, 2024
3f77cda
Refactor usage of init_kernel to make sure it instantiated in the sam…
achirkin Aug 22, 2024
ddb0488
Reduce the register usage in distance functions
achirkin Aug 22, 2024
c244ead
Partially implemented manual dispatch
achirkin Aug 23, 2024
7eb6a27
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 23, 2024
ff2fdbe
Finish manual dispatch
achirkin Aug 23, 2024
78a9809
Change instance generator to have blockdim/team_size ratio 16
achirkin Aug 23, 2024
6082bf7
Trying various minor things to reduce register spilling
achirkin Aug 23, 2024
fc7d832
Move the metric parameter to the compute_distance template
achirkin Aug 26, 2024
6763bf7
Expose search() with optional filter for ANN
lowener Aug 26, 2024
118808e
Further reduce register pressure by moving code out of the non-inlina…
achirkin Aug 27, 2024
4e254fc
Merge branch 'branch-24.10' into 24.10-search-filter
lowener Aug 27, 2024
abec125
Manually unroll device::team_sum
achirkin Aug 27, 2024
cf0101c
Remove the test of a compute_distance instance that is not compiled (…
achirkin Aug 28, 2024
b3e6d26
Hide previously not hidden kernels
achirkin Aug 28, 2024
f231828
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 28, 2024
e4cb424
Fix CAGRA filter test
lowener Aug 29, 2024
dc75f7a
Reduce register usage by minimizing the part of descriptor struct pas…
achirkin Sep 2, 2024
6630a99
Further reduce the size size of the dataset descriptor and add explic…
achirkin Sep 2, 2024
790e79c
Cache dataset descriptors to recover small batch performance
achirkin Sep 2, 2024
7599331
Reduce the register usage in compute_distance_standard further
achirkin Sep 3, 2024
4d9241e
Reduce the generated code volume
achirkin Sep 3, 2024
5fdcdd0
More explicit ldg cache behavior and a few smaller things
achirkin Sep 4, 2024
5984596
Simplify vpq indexing arithmetics a bit
achirkin Sep 4, 2024
337d990
Fix style
lowener Sep 4, 2024
6eb34be
Merge branch 'branch-24.10' into 24.10-search-filter
lowener Sep 4, 2024
af0cc12
Bring back the fatbin.ld link option
achirkin Sep 5, 2024
9023e68
relax the config for checking the raft_cutlass symbol exclusion (see …
achirkin Sep 5, 2024
99d2bd3
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 6, 2024
75a2dac
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 9, 2024
6a1b898
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 10, 2024
c1eed0e
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 10, 2024
d4673cf
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 11, 2024
a78797f
Merge branch 'branch-24.10' into 24.10-search-filter
lowener Sep 11, 2024
0046a73
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 11, 2024
267902e
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 16, 2024
4ae3fa5
Merge remote-tracking branch 'achirkin/enh-cagra-separable-compilatio…
lowener Sep 16, 2024
b145d6d
Add base_filter to ANN API
lowener Sep 23, 2024
d357b99
Merge branch 'branch-24.10' into 24.10-search-filter
lowener Sep 25, 2024
66f633c
Fix details, finalize merge
lowener Sep 25, 2024
9e3a4ca
Fix documentation and parameter names
lowener Sep 26, 2024
e6e9f3b
Use references in public API of pre-filtering
lowener Sep 30, 2024
9bdd8f6
Merge branch 'branch-24.10' into 24.10-search-filter
lowener Sep 30, 2024
9386a9e
Unify none_ivf_sample_filter with none_cagra_sample_filter
lowener Oct 1, 2024
e258866
Add Bruteforce Prefilter API
lowener Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ add_library(
src/cluster/kmeans_balanced_predict_int8.cu
src/cluster/kmeans_transform_float.cu
src/cluster/single_linkage_float.cu
src/core/bitset.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_half_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu
Expand Down Expand Up @@ -405,9 +406,6 @@ add_library(
src/neighbors/ivf_pq/detail/ivf_pq_search_float_int64_t.cu
src/neighbors/ivf_pq/detail/ivf_pq_search_int8_t_int64_t.cu
src/neighbors/ivf_pq/detail/ivf_pq_search_uint8_t_int64_t.cu
src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_float_int64_t.cu
src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_int8_t_int64_t.cu
src/neighbors/ivf_pq/detail/ivf_pq_search_with_filter_uint8_t_int64_t.cu
src/neighbors/nn_descent.cu
src/neighbors/nn_descent_float.cu
src/neighbors/nn_descent_int8.cu
Expand Down
6 changes: 6 additions & 0 deletions cpp/include/cuvs/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

#include <raft/core/bitset.hpp>

extern template struct raft::core::bitset<uint8_t, uint32_t>;
extern template struct raft::core::bitset<uint16_t, uint32_t>;
extern template struct raft::core::bitset<uint32_t, uint32_t>;
extern template struct raft::core::bitset<uint32_t, int64_t>;
extern template struct raft::core::bitset<uint64_t, int64_t>;

namespace cuvs::core {
/* To use bitset functions containing CUDA code, include <raft/core/bitset.cuh> */

Expand Down
28 changes: 16 additions & 12 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,9 +951,6 @@ void extend(
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
Expand All @@ -962,23 +959,24 @@ void extend(
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/

void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<float, uint32_t>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to support both bitset and bitmap using inheritance (and without having to instantiate two explicit templates), right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could work on the host API side if we do dynamic casts on the host side down the line (before passing it to device) and do not add any virtual methods to the filter struct (so that the objects we pass to device don't have virtual tables).
(otherwise it gets very complicated and potentially slow)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, inheritance doesn't work as well on device side so the filter can't dynamically choose the right virtual function. Doing a dynamic cast on the host side can work so I will try that approach

sample_filter = std::nullopt);

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
Expand All @@ -987,22 +985,23 @@ void search(raft::resources const& res,
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<int8_t, uint32_t>& index,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
sample_filter = std::nullopt);

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] index cagra index
Expand All @@ -1011,13 +1010,18 @@ void search(raft::resources const& res,
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& res,
cuvs::neighbors::cagra::search_params const& params,
const cuvs::neighbors::cagra::index<uint8_t, uint32_t>& index,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
sample_filter = std::nullopt);

/**
* @}
*/
Expand Down
153 changes: 34 additions & 119 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "common.hpp"
#include <cstdint>
#include <cuvs/neighbors/common.hpp>
#include <optional>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>

Expand Down Expand Up @@ -1163,13 +1164,17 @@ void extend(raft::resources const& handle,
* dataset [n_queries, k]
* @param[out] distances raft::device_matrix_view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::ivf_flat::search_params& params,
cuvs::neighbors::ivf_flat::index<float, int64_t>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
sample_filter = std::nullopt);

/**
* @brief Search ANN using the constructed index.
Expand Down Expand Up @@ -1200,13 +1205,17 @@ void search(raft::resources const& handle,
* dataset [n_queries, k]
* @param[out] distances raft::device_matrix_view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::ivf_flat::search_params& params,
cuvs::neighbors::ivf_flat::index<int8_t, int64_t>& index,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
sample_filter = std::nullopt);

/**
* @brief Search ANN using the constructed index.
Expand Down Expand Up @@ -1237,112 +1246,18 @@ void search(raft::resources const& handle,
* dataset [n_queries, k]
* @param[out] distances raft::device_matrix_view to the distances to the selected neighbors
* [n_queries, k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search(raft::resources const& handle,
const cuvs::neighbors::ivf_flat::search_params& params,
cuvs::neighbors::ivf_flat::index<uint8_t, int64_t>& index,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);

/**
* @brief Search ANN using the constructed index with the given filter.
*
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
* The exact size of the temporary buffer depends on multiple factors and is an implementation
* detail. However, you can safely specify a small initial size for the memory pool, so that only a
* few allocations happen to grow it during the first invocations of the `search`.
*
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-flat constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search_with_filtering(
raft::resources const& handle,
const search_params& params,
index<float, int64_t>& idx,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t> sample_filter);
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t>>
sample_filter = std::nullopt);

/**
* @brief Search ANN using the constructed index with the given filter.
*
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
* The exact size of the temporary buffer depends on multiple factors and is an implementation
* detail. However, you can safely specify a small initial size for the memory pool, so that only a
* few allocations happen to grow it during the first invocations of the `search`.
*
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-flat constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search_with_filtering(
raft::resources const& handle,
const search_params& params,
index<int8_t, int64_t>& idx,
raft::device_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t> sample_filter);

/**
* @brief Search ANN using the constructed index with the given filter.
*
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
* The exact size of the temporary buffer depends on multiple factors and is an implementation
* detail. However, you can safely specify a small initial size for the memory pool, so that only a
* few allocations happen to grow it during the first invocations of the `search`.
*
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-flat constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] sample_filter a device bitset filter function that greenlights samples for a given
* query.
*/
void search_with_filtering(
raft::resources const& handle,
const search_params& params,
index<uint8_t, int64_t>& idx,
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::neighbors::filtering::bitset_filter<uint32_t, int64_t> sample_filter);
/**
* @}
*/
Expand Down Expand Up @@ -2039,18 +1954,18 @@ void reset_index(const raft::resources& res, index<uint8_t, int64_t>* index);
* using namespace cuvs::neighbors;
* raft::resources res;
* // use default index parameters
* ivf_pq::index_params index_params;
* ivf_flat::index_params index_params;
* // initialize an empty index
* ivf_pq::index<int64_t> index(res, index_params, D);
* ivf_pq::helpers::reset_index(res, &index);
* ivf_flat::index<uint8_t, int64_t> index(res, index_params, D);
* ivf_flat::helpers::reset_index(res, &index);
* // resize the first IVF list to hold 5 records
* auto spec = list_spec<uint32_t, int64_t>{
* index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()};
* auto spec = list_spec<uint32_t, uint8_t, int64_t>{
* index->dim(), index->conservative_memory_allocation()};
* uint32_t new_size = 5;
* ivf::resize_list(res, list, spec, new_size, 0);
* raft::update_device(index.list_sizes(), &new_size, 1, stream);
* // recompute the internal state of the index
* ivf_pq::helpers::recompute_internal_state(res, index);
* ivf_flat::helpers::recompute_internal_state(res, index);
* @endcode
*
* @param[in] res raft resource
Expand All @@ -2067,18 +1982,18 @@ void recompute_internal_state(const raft::resources& res, index<float, int64_t>*
* using namespace cuvs::neighbors;
* raft::resources res;
* // use default index parameters
* ivf_pq::index_params index_params;
* ivf_flat::index_params index_params;
* // initialize an empty index
* ivf_pq::index<int64_t> index(res, index_params, D);
* ivf_pq::helpers::reset_index(res, &index);
* ivf_flat::index<uint8_t, int64_t> index(res, index_params, D);
* ivf_flat::helpers::reset_index(res, &index);
* // resize the first IVF list to hold 5 records
* auto spec = list_spec<uint32_t, int64_t>{
* index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()};
* auto spec = list_spec<uint32_t, uint8_t, int64_t>{
* index->dim(), index->conservative_memory_allocation()};
* uint32_t new_size = 5;
* ivf::resize_list(res, list, spec, new_size, 0);
* raft::update_device(index.list_sizes(), &new_size, 1, stream);
* // recompute the internal state of the index
* ivf_pq::helpers::recompute_internal_state(res, index);
* ivf_flat::helpers::recompute_internal_state(res, index);
* @endcode
*
* @param[in] res raft resource
Expand All @@ -2095,18 +2010,18 @@ void recompute_internal_state(const raft::resources& res, index<int8_t, int64_t>
* using namespace cuvs::neighbors;
* raft::resources res;
* // use default index parameters
* ivf_pq::index_params index_params;
* ivf_flat::index_params index_params;
* // initialize an empty index
* ivf_pq::index<int64_t> index(res, index_params, D);
* ivf_pq::helpers::reset_index(res, &index);
* ivf_flat::index<uint8_t, int64_t> index(res, index_params, D);
* ivf_flat::helpers::reset_index(res, &index);
* // resize the first IVF list to hold 5 records
* auto spec = list_spec<uint32_t, int64_t>{
* index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()};
* auto spec = list_spec<uint32_t, uint8_t, int64_t>{
* index->dim(), index->conservative_memory_allocation()};
* uint32_t new_size = 5;
* ivf::resize_list(res, list, spec, new_size, 0);
* raft::update_device(index.list_sizes(), &new_size, 1, stream);
* // recompute the internal state of the index
* ivf_pq::helpers::recompute_internal_state(res, index);
* ivf_flat::helpers::recompute_internal_state(res, index);
* @endcode
*
* @param[in] res raft resource
Expand Down
Loading
Loading