Skip to content

Commit

Permalink
clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
bkarsin committed Jan 15, 2025
1 parent 8a11ea0 commit e4b557d
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 83 deletions.
7 changes: 3 additions & 4 deletions cpp/src/neighbors/detail/vamana/greedy_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ __global__ void GreedySearchKernel(

static __shared__ Point<T, accT> s_query;


union ShmemLayout {
// All blocksort sizes have same alignment (16)
typename cub::BlockMergeSort<DistPair<IdxT, accT>, 32, 1>::TempStorage sort_mem;
Expand All @@ -113,15 +112,15 @@ __global__ void GreedySearchKernel(
DistPair<IdxT, accT> candidate_queue;
};

int align_padding = (((dim-1)/alignof(ShmemLayout))+1)*alignof(ShmemLayout) - dim;
int align_padding = (((dim - 1) / alignof(ShmemLayout)) + 1) * alignof(ShmemLayout) - dim;

// Dynamic shared memory used for blocksort, temp vector storage, and neighborhood list
extern __shared__ __align__(alignof(ShmemLayout)) char smem[];

size_t smem_offset = sort_smem_size; // temp sorting memory takes first chunk

T* s_coords = reinterpret_cast<T*>(&smem[smem_offset]);
smem_offset += (dim+align_padding) * sizeof(T);
smem_offset += (dim + align_padding) * sizeof(T);

Node<accT>* topk_pq = reinterpret_cast<Node<accT>*>(&smem[smem_offset]);
smem_offset += topk * sizeof(Node<accT>);
Expand Down Expand Up @@ -173,7 +172,7 @@ __global__ void GreedySearchKernel(

if (threadIdx.x == 0) { heap_queue.insert_back(medoid_dist, medoid_id); }
__syncthreads();

while (cand_q_size != 0) {
__syncthreads();

Expand Down
8 changes: 4 additions & 4 deletions cpp/src/neighbors/detail/vamana/robust_prune.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ __global__ void RobustPruneKernel(
// Dynamic shared memory used for blocksort, temp vector storage, and neighborhood list
extern __shared__ __align__(alignof(ShmemLayout)) char smem[];

int align_padding = (((dim-1)/alignof(ShmemLayout))+1)*alignof(ShmemLayout) - dim;
int align_padding = (((dim - 1) / alignof(ShmemLayout)) + 1) * alignof(ShmemLayout) - dim;

T* s_coords = reinterpret_cast<T*>(&smem[sort_smem_size]);
DistPair<IdxT, accT>* new_nbh_list =
reinterpret_cast<DistPair<IdxT, accT>*>(&smem[(dim+align_padding) * sizeof(T) + sort_smem_size]);
T* s_coords = reinterpret_cast<T*>(&smem[sort_smem_size]);
DistPair<IdxT, accT>* new_nbh_list = reinterpret_cast<DistPair<IdxT, accT>*>(
&smem[(dim + align_padding) * sizeof(T) + sort_smem_size]);

static __shared__ Point<T, accT> s_query;
s_query.coords = s_coords;
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/neighbors/vamana.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ index<T, IdxT> build(
const index_params& params,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, Accessor> dataset)
{
return cuvs::neighbors::vamana::detail::build<T, IdxT, Accessor>(
res, params, dataset);
return cuvs::neighbors::vamana::detail::build<T, IdxT, Accessor>(res, params, dataset);
}

template <typename T, typename IdxT>
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/neighbors/vamana_build_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@

namespace cuvs::neighbors::vamana {

#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
}

RAFT_INST_VAMANA_BUILD(float, uint32_t);
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/neighbors/vamana_build_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@

namespace cuvs::neighbors::vamana {

#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
}

RAFT_INST_VAMANA_BUILD(int8_t, uint32_t);
Expand Down
30 changes: 15 additions & 15 deletions cpp/src/neighbors/vamana_build_uint8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@

namespace cuvs::neighbors::vamana {

#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
#define RAFT_INST_VAMANA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::vamana::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::vamana::index<T, IdxT> \
{ \
return cuvs::neighbors::vamana::build<T, IdxT>(handle, params, dataset); \
}

RAFT_INST_VAMANA_BUILD(uint8_t, uint32_t);
Expand Down
5 changes: 2 additions & 3 deletions cpp/src/neighbors/vamana_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ namespace cuvs::neighbors::vamana {
#define CUVS_INST_VAMANA_SERIALIZE(DTYPE) \
void serialize(raft::resources const& handle, \
const std::string& file_prefix, \
const cuvs::neighbors::vamana::index<DTYPE, uint32_t>& index_) \
const cuvs::neighbors::vamana::index<DTYPE, uint32_t>& index_) \
{ \
cuvs::neighbors::vamana::detail::serialize<DTYPE, uint32_t>( \
handle, file_prefix, index_); \
cuvs::neighbors::vamana::detail::serialize<DTYPE, uint32_t>(handle, file_prefix, index_); \
};

/** @} */ // end group vamana
Expand Down
58 changes: 33 additions & 25 deletions examples/cpp/src/vamana_example.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@
#include "common.cuh"

template <typename T>
void vamana_build_and_write(raft::device_resources const &dev_resources,
void vamana_build_and_write(raft::device_resources const& dev_resources,
raft::device_matrix_view<const T, int64_t> dataset,
std::string out_fname, int degree, int visited_size,
float max_fraction, int iters) {
std::string out_fname,
int degree,
int visited_size,
float max_fraction,
int iters)
{
using namespace cuvs::neighbors;

// use default index parameters
Expand All @@ -46,23 +50,24 @@ void vamana_build_and_write(raft::device_resources const &dev_resources,

auto start = std::chrono::system_clock::now();
auto index = vamana::build(dev_resources, index_params, dataset);
auto end = std::chrono::system_clock::now();
auto end = std::chrono::system_clock::now();
std::chrono::duration<double> elapsed_seconds = end - start;

std::cout << "Vamana index has " << index.size() << " vectors" << std::endl;
std::cout << "Vamana graph has degree " << index.graph_degree()
<< ", graph size [" << index.graph().extent(0) << ", "
<< index.graph().extent(1) << "]" << std::endl;
std::cout << "Vamana graph has degree " << index.graph_degree() << ", graph size ["
<< index.graph().extent(0) << ", " << index.graph().extent(1) << "]" << std::endl;

std::cout << "Time to build index: " << elapsed_seconds.count() << "s\n";

// Output index to file
serialize(dev_resources, out_fname, index);
}

void usage() {
printf("Usage: ./vamana_example <data filename> <output filename> <graph "
"degree> <visited_size> <max_fraction> <iterations> \n");
void usage()
{
printf(
"Usage: ./vamana_example <data filename> <output filename> <graph "
"degree> <visited_size> <max_fraction> <iterations> \n");
printf("Input file expected to be binary file of fp32 vectors.\n");
printf("Graph degree sizes supported: 32, 64, 128, 256\n");
printf("Visited_size must be > degree and a power of 2.\n");
Expand All @@ -71,13 +76,14 @@ void usage() {
exit(1);
}

int main(int argc, char *argv[]) {
int main(int argc, char* argv[])
{
raft::device_resources dev_resources;

// Set pool memory resource with 1 GiB initial pool size. All allocations use
// the same pool.
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr(
rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull);
rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull);
rmm::mr::set_current_device_resource(&pool_mr);

// Alternatively, one could define a pool allocator for temporary arrays (used
Expand All @@ -87,22 +93,24 @@ int main(int argc, char *argv[]) {
// limit. raft::resource::set_workspace_to_pool_resource(dev_resources, 2 *
// 1024 * 1024 * 1024ull);

if (argc != 7)
usage();
if (argc != 7) usage();

std::string data_fname = (std::string)(argv[1]); // Input filename
std::string out_fname = (std::string)argv[2]; // Output index filename
int degree = atoi(argv[3]);
int max_visited = atoi(argv[4]);
float max_fraction = atof(argv[5]);
int iters = atoi(argv[6]);
std::string data_fname = (std::string)(argv[1]); // Input filename
std::string out_fname = (std::string)argv[2]; // Output index filename
int degree = atoi(argv[3]);
int max_visited = atoi(argv[4]);
float max_fraction = atof(argv[5]);
int iters = atoi(argv[6]);

// Read in binary dataset file
auto dataset =
read_bin_dataset<uint8_t, int64_t>(dev_resources, data_fname, INT_MAX);
auto dataset = read_bin_dataset<uint8_t, int64_t>(dev_resources, data_fname, INT_MAX);

// Simple build example to create graph and write to a file
vamana_build_and_write<uint8_t>(
dev_resources, raft::make_const_mdspan(dataset.view()), out_fname, degree,
max_visited, max_fraction, iters);
vamana_build_and_write<uint8_t>(dev_resources,
raft::make_const_mdspan(dataset.view()),
out_fname,
degree,
max_visited,
max_fraction,
iters);
}

0 comments on commit e4b557d

Please sign in to comment.