Skip to content

Commit

Permalink
Merge branch 'branch-24.10' into fea-persistent-cagra
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin authored Aug 19, 2024
2 parents b7f5106 + 9cf4800 commit 05100ce
Show file tree
Hide file tree
Showing 21 changed files with 546 additions and 532 deletions.
1 change: 1 addition & 0 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ struct params : base_params {
* Simple object to specify hyper-parameters to the balanced k-means algorithm.
*
* The following metrics are currently supported in k-means balanced:
* - CosineExpanded
* - InnerProduct
* - L2Expanded
* - L2SqrtExpanded
Expand Down
72 changes: 72 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ struct index : cuvs::neighbors::index {
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -327,6 +333,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -351,6 +363,12 @@ void build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -374,6 +392,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -398,6 +422,12 @@ void build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -421,6 +451,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
Expand All @@ -445,6 +481,12 @@ void build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down Expand Up @@ -475,6 +517,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down Expand Up @@ -506,6 +554,12 @@ void build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down Expand Up @@ -536,6 +590,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down Expand Up @@ -567,6 +627,12 @@ void build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down Expand Up @@ -597,6 +663,12 @@ auto build(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
* - CosineExpanded
*
* Note, if index_params.add_data_on_build is set to true, the user can set a
* stream pool in the input raft::resource with at least one stream to enable kernel and copy
* overlapping.
Expand Down
115 changes: 97 additions & 18 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <raft/linalg/add.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/matrix_vector.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/normalize.cuh>
Expand Down Expand Up @@ -141,6 +142,53 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::CosineExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::rowNorm<MathT, IdxT>(centroidsNorm.data_handle(),
centers,
dim,
n_clusters,
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op{});

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
false,
false,
true,
params.metric,
0.0f,
stream);
// Copy keys to output labels
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::InnerProduct: {
// TODO: pass buffer
rmm::device_uvector<MathT> distances(n_rows * n_clusters, stream, mr);
Expand Down Expand Up @@ -320,13 +368,14 @@ void calc_centers_and_sizes(const raft::resources& handle,
}

/** Computes the L2 norm of the dataset, converting to MathT if necessary */
template <typename T, typename MathT, typename IdxT, typename MappingOpT>
template <typename T, typename MathT, typename IdxT, typename MappingOpT, typename FinOpT>
void compute_norm(const raft::resources& handle,
MathT* dataset_norm,
const T* dataset,
IdxT dim,
IdxT n_rows,
MappingOpT mapping_op,
FinOpT norm_fin_op,
std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("compute_norm");
Expand All @@ -347,7 +396,7 @@ void compute_norm(const raft::resources& handle,
}

raft::linalg::rowNorm<MathT, IdxT>(
dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream);
dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream, norm_fin_op);
}

/**
Expand Down Expand Up @@ -394,7 +443,8 @@ void predict(const raft::resources& handle,
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
bool need_compute_norm =
dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded);
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded);
rmm::device_uvector<MathT> cur_dataset_norm(
need_compute_norm ? max_minibatch_size : 0, stream, mem_res);
const MathT* dataset_norm_ptr = nullptr;
Expand All @@ -411,8 +461,24 @@ void predict(const raft::resources& handle,

// Compute the norm now if it hasn't been pre-computed.
if (need_compute_norm) {
compute_norm(
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res);
if (params.metric == cuvs::distance::DistanceType::CosineExpanded)
compute_norm(handle,
cur_dataset_norm.data(),
cur_dataset_ptr,
dim,
minibatch_size,
mapping_op,
raft::sqrt_op{},
mr);
else
compute_norm(handle,
cur_dataset_norm.data(),
cur_dataset_ptr,
dim,
minibatch_size,
mapping_op,
raft::identity_op{},
mr);
dataset_norm_ptr = cur_dataset_norm.data();
} else if (dataset_norm != nullptr) {
dataset_norm_ptr = dataset_norm + offset;
Expand Down Expand Up @@ -904,7 +970,8 @@ auto build_fine_clusters(const raft::resources& handle,
cub::TransformInputIterator<MathT, MappingOpT, const T*> mapping_itr(dataset_mptr, mapping_op);
raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream);
if (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded) {
thrust::gather(raft::resource::get_thrust_policy(handle),
mc_trainset_ids,
mc_trainset_ids + k,
Expand Down Expand Up @@ -963,7 +1030,8 @@ void build_hierarchical(const raft::resources& handle,
IdxT n_rows,
MathT* cluster_centers,
IdxT n_clusters,
MappingOpT mapping_op)
MappingOpT mapping_op,
const MathT* dataset_norm = nullptr)
{
auto stream = raft::resource::get_cuda_stream(handle);
using LabelT = uint32_t;
Expand All @@ -980,21 +1048,32 @@ void build_hierarchical(const raft::resources& handle,
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

// Precompute the L2 norm of the dataset if relevant.
const MathT* dataset_norm = nullptr;
// Precompute the L2 norm of the dataset if relevant and not yet computed.
rmm::device_uvector<MathT> dataset_norm_buf(0, stream, device_memory);
if (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded)) {
dataset_norm_buf.resize(n_rows, stream);
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
IdxT minibatch_size = std::min<IdxT>(max_minibatch_size, n_rows - offset);
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
device_memory);
if (params.metric == cuvs::distance::DistanceType::CosineExpanded)
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
raft::sqrt_op{},
device_memory);
else
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
raft::identity_op{},
device_memory);
}
dataset_norm = (const MathT*)dataset_norm_buf.data();
}
Expand Down
Loading

0 comments on commit 05100ce

Please sign in to comment.