Skip to content

Commit

Permalink
Merge branch 'branch-25.02' into extend-c-api
Browse files Browse the repository at this point in the history
  • Loading branch information
ajit283 authored Jan 14, 2025
2 parents 12ac03d + 28d9990 commit 1fc6795
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 65 deletions.
28 changes: 28 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,20 @@ auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

/**
* @brief Build the index from a dataset held in a host matrix view, for efficient search.
*
* Overload of build() for single-precision (float) data that takes the dataset as a
* raft::host_matrix_view instead of a device matrix view.
*
* @param[in] handle the raft resources handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a host matrix view of a row-major matrix [n_rows, dim]
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
const cuvs::neighbors::brute_force::index_params& index_params,
raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<float, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
Expand Down Expand Up @@ -231,6 +245,20 @@ auto build(raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

/**
* @brief Build the index from a dataset held in a host matrix view, for efficient search.
*
* Overload of build() for half-precision (half) data that takes the dataset as a
* raft::host_matrix_view instead of a device matrix view. Distances use float.
*
* @param[in] handle the raft resources handle
* @param[in] index_params parameters such as the distance metric to use
* @param[in] dataset a host matrix view of a row-major matrix [n_rows, dim]
*
* @return the constructed brute-force index
*/
auto build(raft::resources const& handle,
const cuvs::neighbors::brute_force::index_params& index_params,
raft::host_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::brute_force::index<half, float>;

[[deprecated]] auto build(
raft::resources const& handle,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset,
Expand Down
45 changes: 45 additions & 0 deletions cpp/include/cuvs/neighbors/refine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,51 @@ void refine(raft::resources const& handle,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search.
*
* Refinement is an operation that follows an approximate NN search. The approximate search has
* already selected n_candidates neighbor candidates for each query. We narrow it down to k
* neighbors. For each query, we calculate the exact distance between the query and each of its
* n_candidates neighbor candidates, and select the k nearest ones.
*
* The k nearest neighbors and distances are returned.
*
* This overload takes the candidate indices and writes the refined indices as uint32_t.
*
* Example usage
* @code{.cpp}
* using namespace cuvs::neighbors;
* // use default index parameters
* ivf_pq::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = ivf_pq::build(handle, index_params, dataset);
* // use default search parameters
* ivf_pq::search_params search_params;
* // search m = 4 * k nearest neighbours for each of the N queries
* ivf_pq::search(handle, search_params, index, queries, neighbor_candidates,
* out_dists_tmp);
* // refine them to the k nearest ones
* refine(handle, dataset, queries, neighbor_candidates, out_indices, out_dists,
* index.metric());
* @endcode
*
*
* @param[in] handle the raft handle
* @param[in] dataset device matrix that stores the dataset [n_rows, dims]
* @param[in] queries device matrix of the queries [n_queries, dims]
* @param[in] neighbor_candidates indices of candidate vectors [n_queries, n_candidates], where
* n_candidates >= k
* @param[out] indices device matrix that stores the refined indices [n_queries, k]
* @param[out] distances device matrix that stores the refined distances [n_queries, k]
* @param[in] metric distance metric to use. Euclidean (L2) is used by default
*/
void refine(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<const uint32_t, int64_t, raft::row_major> neighbor_candidates,
raft::device_matrix_view<uint32_t, int64_t, raft::row_major> indices,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded);

/**
* @brief Refine nearest neighbor search.
*
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ void index<T, DistT>::update_dataset(
{ \
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
} \
auto build(raft::resources const& res, \
const cuvs::neighbors::brute_force::index_params& index_params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::brute_force::index<T, DistT> \
{ \
return detail::build<T, DistT>(res, dataset, index_params.metric, index_params.metric_arg); \
} \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::distance::DistanceType metric, \
Expand Down
22 changes: 18 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "./knn_utils.cuh"

#include <raft/core/bitmap.cuh>
#include <raft/core/copy.cuh>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
Expand Down Expand Up @@ -750,10 +751,10 @@ void search(raft::resources const& res,
}
}

template <typename T, typename DistT, typename LayoutT = raft::row_major>
template <typename T, typename DistT, typename AccessorT, typename LayoutT = raft::row_major>
cuvs::neighbors::brute_force::index<T, DistT> build(
raft::resources const& res,
raft::device_matrix_view<const T, int64_t, LayoutT> dataset,
mdspan<const T, matrix_extent<int64_t>, LayoutT, AccessorT> dataset,
cuvs::distance::DistanceType metric,
DistT metric_arg)
{
Expand All @@ -764,18 +765,31 @@ cuvs::neighbors::brute_force::index<T, DistT> build(
if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
auto dataset_storage = std::optional<device_matrix<T, int64_t, LayoutT>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, row_major>>) {
return dataset;
} else {
dataset_storage =
make_device_matrix<T, int64_t, LayoutT>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();

norms = raft::make_device_vector<DistT, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ void index<T, IdxT>::check_consistency()
"inconsistent number of lists (clusters)");
}

template struct index<float, uint32_t>; // Used for refine function
template struct index<float, int64_t>;
template struct index<half, int64_t>;
template struct index<int8_t, int64_t>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@
}

instantiate_cuvs_neighbors_refine_d(int64_t, float, float, int64_t);
instantiate_cuvs_neighbors_refine_d(uint32_t, float, float, int64_t);

#undef instantiate_cuvs_neighbors_refine_d
13 changes: 7 additions & 6 deletions cpp/src/neighbors/refine/refine_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,13 @@ void refine_device(
cuvs::neighbors::ivf_flat::index<data_t, idx_t> refinement_index(
handle, cuvs::distance::DistanceType(metric), n_queries, false, true, dim);

cuvs::neighbors::ivf_flat::detail::fill_refinement_index(handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
n_queries,
n_candidates);
cuvs::neighbors::ivf_flat::detail::fill_refinement_index<data_t, idx_t>(
handle,
&refinement_index,
dataset.data_handle(),
neighbor_candidates.data_handle(),
static_cast<idx_t>(n_queries),
static_cast<uint32_t>(n_candidates));
uint32_t grid_dim_x = 1;

// the neighbor ids will be computed in uint32_t as offset
Expand Down
142 changes: 87 additions & 55 deletions cpp/test/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cuvs/selection/select_k.hpp>

#include <cuvs/neighbors/brute_force.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
Expand Down Expand Up @@ -210,14 +211,15 @@ struct RandomKNNInputs {
int k;
cuvs::distance::DistanceType metric;
bool row_major;
bool host_dataset;
};

std::ostream& operator<<(std::ostream& os, const RandomKNNInputs& input)
{
return os << "num_queries:" << input.num_queries << " num_vecs:" << input.num_db_vecs
<< " dim:" << input.dim << " k:" << input.k
<< " metric:" << cuvs::neighbors::print_metric{input.metric}
<< " row_major:" << input.row_major;
<< " row_major:" << input.row_major << " host_dataset:" << input.host_dataset;
}

template <typename T, typename DistT = T>
Expand Down Expand Up @@ -399,12 +401,15 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>

cuvs::neighbors::brute_force::search_params search_params;

if (params_.row_major) {
auto idx =
cuvs::neighbors::brute_force::build(handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t>(
database.data(), params_.num_db_vecs, params_.dim));
if (params_.host_dataset) {
// test building from a dataset in host memory
auto host_database =
raft::make_host_matrix<T, int64_t, raft::row_major>(params_.num_db_vecs, params_.dim);
raft::copy(
host_database.data_handle(), database.data(), params_.num_db_vecs * params_.dim, stream_);

auto idx = cuvs::neighbors::brute_force::build(
handle_, index_params, raft::make_const_mdspan(host_database.view()));

cuvs::neighbors::brute_force::search(
handle_,
Expand All @@ -416,21 +421,39 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>
distances,
cuvs::neighbors::filtering::none_sample_filter{});
} else {
auto idx = cuvs::neighbors::brute_force::build(
handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
database.data(), params_.num_db_vecs, params_.dim));
if (params_.row_major) {
auto idx =
cuvs::neighbors::brute_force::build(handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t>(
database.data(), params_.num_db_vecs, params_.dim));

cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
} else {
auto idx = cuvs::neighbors::brute_force::build(
handle_,
index_params,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
database.data(), params_.num_db_vecs, params_.dim));

cuvs::neighbors::brute_force::search(
handle_,
search_params,
idx,
raft::make_device_matrix_view<const T, int64_t, raft::col_major>(
search_queries.data(), params_.num_queries, params_.dim),
indices,
distances,
cuvs::neighbors::filtering::none_sample_filter{});
}
}

ASSERT_TRUE(cuvs::neighbors::devArrMatchKnnPair(ref_indices_.data(),
Expand Down Expand Up @@ -480,42 +503,51 @@ class RandomBruteForceKNNTest : public ::testing::TestWithParam<RandomKNNInputs>

const std::vector<RandomKNNInputs> random_inputs = {
// test each distance metric on a small-ish input, with row-major inputs
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true},
{100, 256, 2, 65, cuvs::distance::DistanceType::L2Expanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, true, false},
// test each distance metric with col-major inputs
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false},
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Unexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CorrelationExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::CosineExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::LpUnexpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::JensenShannon, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{256, 512, 16, 8, cuvs::distance::DistanceType::Canberra, false, false},
// larger tests on different sized data / k values
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false},
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true},
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false}};
{10000, 40000, 32, 30, cuvs::distance::DistanceType::L2Expanded, false, false},
{345, 1023, 16, 128, cuvs::distance::DistanceType::CosineExpanded, true, false},
{789, 20516, 64, 256, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, true, false},
{1000, 200000, 128, 128, cuvs::distance::DistanceType::L2Expanded, false, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::LpUnexpanded, true, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::L2SqrtExpanded, false, false},
{1000, 5000, 128, 128, cuvs::distance::DistanceType::InnerProduct, false, false},
// test with datasets on host memory
{256, 512, 16, 8, cuvs::distance::DistanceType::L2Expanded, true, true},
{256, 512, 32, 16, cuvs::distance::DistanceType::L2Unexpanded, true, true},
{256, 512, 8, 8, cuvs::distance::DistanceType::L2SqrtExpanded, true, true},
{256, 128, 32, 8, cuvs::distance::DistanceType::L2SqrtUnexpanded, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::L1, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::Linf, true, true},
{256, 512, 16, 8, cuvs::distance::DistanceType::InnerProduct, true, true},
{256, 512, 16, 7, cuvs::distance::DistanceType::L2Expanded, true, true}};

typedef RandomBruteForceKNNTest<float, float> RandomBruteForceKNNTestF;
TEST_P(RandomBruteForceKNNTestF, BruteForce) { this->testBruteForce(); }
Expand Down

0 comments on commit 1fc6795

Please sign in to comment.