Skip to content

Commit

Permalink
More explicit ldg cache behavior and a few smaller things
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Sep 4, 2024
1 parent 4d9241e commit 5fdcdd0
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 72 deletions.
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ add_library(
src/neighbors/detail/cagra/search_single_cta_half_uint64.cu
)

file(GLOB_RECURSE compute_distance_sources "src/neighbors/detail/cagra/*.cu")
file(GLOB_RECURSE compute_distance_sources "src/neighbors/detail/cagra/compute_distance_*.cu")
set_source_files_properties(${compute_distance_sources} PROPERTIES COMPILE_FLAGS -maxrregcount=64)

set_target_properties(
Expand Down
26 changes: 12 additions & 14 deletions cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,8 @@
#include "compute_distance_standard.hpp"

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/common.hpp>
#include <raft/core/logger-macros.hpp>
#include <raft/core/operators.hpp>
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/pow2_utils.cuh>
#include <raft/util/vectorized.cuh>

#include <type_traits>

Expand Down Expand Up @@ -187,17 +183,19 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker(
constexpr auto kTeamSize = DescriptorT::kTeamSize;
constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim;
constexpr auto vlen = device::get_vlen<LOAD_T, DATA_T>();
constexpr auto reg_nelem = raft::ceildiv<uint32_t>(kDatasetBlockDim, kTeamSize * vlen);
constexpr auto reg_nelem =
raft::div_rounding_up_unsafe<uint32_t>(kDatasetBlockDim, kTeamSize * vlen);

DISTANCE_T r = 0;
for (uint32_t elem_offset = (threadIdx.x % kTeamSize) * vlen; elem_offset < dim;
elem_offset += kDatasetBlockDim) {
raft::TxN_t<DATA_T, vlen> dl_buff[reg_nelem];
DATA_T data[reg_nelem][vlen];
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = e * (kTeamSize * vlen) + elem_offset;
if (k >= dim) break;
dl_buff[e].load(dataset_ptr, k);
device::ldg_cg(reinterpret_cast<LOAD_T&>(data[e]),
reinterpret_cast<const LOAD_T*>(dataset_ptr + k));
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
Expand All @@ -212,7 +210,7 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker(
DISTANCE_T d;
device::lds(d, query_smem_ptr + sizeof(QUERY_T) * device::swizzling(k + v));
r += dist_op<DISTANCE_T, DescriptorT::kMetric>(
d, cuvs::spatial::knn::detail::utils::mapping<DISTANCE_T>{}(dl_buff[e].val.data[v]));
d, cuvs::spatial::knn::detail::utils::mapping<DISTANCE_T>{}(data[e][v]));
}
}
}
Expand All @@ -236,12 +234,12 @@ template <cuvs::distance::DistanceType Metric,
typename DataT,
typename IndexT,
typename DistanceT>
__launch_bounds__(1, 1) __global__ void standard_dataset_descriptor_init_kernel(
dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
const DataT* ptr,
IndexT size,
uint32_t dim,
uint32_t ld)
RAFT_KERNEL __launch_bounds__(1, 1)
standard_dataset_descriptor_init_kernel(dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
const DataT* ptr,
IndexT size,
uint32_t dim,
uint32_t ld)
{
using desc_type =
standard_dataset_descriptor_t<Metric, TeamSize, DatasetBlockDim, DataT, IndexT, DistanceT>;
Expand Down
120 changes: 65 additions & 55 deletions cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
#include "compute_distance_vpq.hpp"

#include <cuvs/distance/distance.hpp>
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>

#include <type_traits>

namespace cuvs::neighbors::cagra::detail {

template <cuvs::distance::DistanceType Metric,
Expand Down Expand Up @@ -229,9 +230,12 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq(const DescriptorT* that,
}

template <typename DescriptorT>
_RAFT_DEVICE __noinline__ auto compute_distance_vpq(
const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) ->
typename DescriptorT::DISTANCE_T
_RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker(
const uint8_t* __restrict__ dataset_ptr,
const typename DescriptorT::CODE_BOOK_T* __restrict__ vq_code_book_ptr,
uint32_t dim,
uint32_t pq_codebook_ptr,
uint32_t n_subspace) -> typename DescriptorT::DISTANCE_T
{
using DISTANCE_T = typename DescriptorT::DISTANCE_T;
using LOAD_T = typename DescriptorT::LOAD_T;
Expand All @@ -242,52 +246,48 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
constexpr auto PQ_BITS = DescriptorT::kPqBits;
constexpr auto PQ_LEN = DescriptorT::kPqLen;

const uint32_t pq_codebook_ptr = args.smem_ws_ptr;
const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes;
const auto* __restrict__ node_ptr =
DescriptorT::encoded_dataset_ptr(args) +
(static_cast<std::uint64_t>(DescriptorT::encoded_dataset_dim(args)) * dataset_index);
const unsigned lane_id = threadIdx.x % TeamSize;
// const uint32_t& vq_code = *reinterpret_cast<const std::uint32_t*>(node_ptr);
uint32_t vq_code;
raft::ldg(vq_code, reinterpret_cast<const std::uint32_t*>(node_ptr));
const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes;
{
uint32_t vq_code;
device::ldg_cg(vq_code, reinterpret_cast<const std::uint32_t*>(dataset_ptr));
vq_code_book_ptr += dim * vq_code;
}
static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment.");
constexpr uint32_t vlen = 4; // **** DO NOT CHANGE ****
constexpr uint32_t nelem =
raft::div_rounding_up_unsafe<uint32_t>(DatasetBlockDim / PQ_LEN, TeamSize * vlen);
DISTANCE_T norm = 0;
for (uint32_t elem_offset = 0; elem_offset < args.dim; elem_offset += DatasetBlockDim) {
constexpr unsigned vlen = 4; // **** DO NOT CHANGE ****
constexpr unsigned nelem =
raft::div_rounding_up_unsafe<unsigned>(DatasetBlockDim / PQ_LEN, TeamSize * vlen);
for (uint32_t elem_offset = (threadIdx.x % TeamSize) * (vlen * PQ_LEN); elem_offset < dim;
elem_offset += DatasetBlockDim) {
// Loading PQ codes
uint32_t pq_codes[nelem];
#pragma unroll
for (std::uint32_t e = 0; e < nelem; e++) {
const std::uint32_t k = (lane_id + (TeamSize * e)) * vlen + elem_offset / PQ_LEN;
if (k >= DescriptorT::n_subspace(args)) break;
const std::uint32_t k = e * (TeamSize * vlen) + elem_offset / PQ_LEN;
if (k >= n_subspace) break;
// Loading 4 x 8-bit PQ-codes using 32-bit load ops (from device memory)
raft::ldg(pq_codes[e], reinterpret_cast<const std::uint32_t*>(node_ptr + 4 + k));
device::ldg_cg(pq_codes[e], reinterpret_cast<const std::uint32_t*>(dataset_ptr + 4 + k));
}
//
if constexpr (PQ_LEN % 2 == 0) {
// **** Use half2 for distance computation ****
#pragma unroll 1
#pragma unroll
for (std::uint32_t e = 0; e < nelem; e++) {
const std::uint32_t k = (lane_id + (TeamSize * e)) * vlen + elem_offset / PQ_LEN;
if (k >= DescriptorT::n_subspace(args)) break;
const std::uint32_t k = e * (TeamSize * vlen) + elem_offset / PQ_LEN;
if (k >= n_subspace) break;
// Loading VQ code-book
raft::TxN_t<half2, vlen / 2> vq_vals[PQ_LEN];
half2 vq_vals[PQ_LEN][vlen / 2];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_LEN; m += 1) {
const uint32_t d = (vlen * m) + (PQ_LEN * k);
if (d >= args.dim) break;
vq_vals[m].load(reinterpret_cast<const half2*>(DescriptorT::vq_code_book_ptr(args) + d +
(args.dim * vq_code)),
0);
if (d >= dim) break;
device::ldg_ca(vq_vals[m], vq_code_book_ptr + d);
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_LEN * (v + k) >= args.dim) break;
if (PQ_LEN * (v + k) >= dim) break;
#pragma unroll
for (std::uint32_t m = 0; m < PQ_LEN; m += 2) {
const std::uint32_t d1 = m + (PQ_LEN * v);
Expand All @@ -300,9 +300,9 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
// Loading PQ code book from smem
device::lds(c2,
pq_codebook_ptr + sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * (m / 2) +
(2 * (pq_code & 0xff))));
(2 * (pq_codes[e] & 0xff))));
// L2 distance
auto dist = q2 - c2 - vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2];
auto dist = q2 - c2 - vq_vals[d1 / vlen][(d1 % vlen) / 2];
dist = dist * dist;
norm += static_cast<DISTANCE_T>(dist.x + dist.y);
}
Expand All @@ -313,37 +313,33 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
// **** Use float for distance computation ****
#pragma unroll
for (std::uint32_t e = 0; e < nelem; e++) {
const std::uint32_t k = (lane_id + (TeamSize * e)) * vlen + elem_offset / PQ_LEN;
if (k >= DescriptorT::n_subspace(args)) break;
const std::uint32_t k = e * (TeamSize * vlen) + elem_offset / PQ_LEN;
if (k >= n_subspace) break;
// Loading VQ code-book
raft::TxN_t<CODE_BOOK_T, vlen> vq_vals[PQ_LEN];
CODE_BOOK_T vq_vals[PQ_LEN][vlen];
#pragma unroll
for (std::uint32_t m = 0; m < PQ_LEN; m++) {
const std::uint32_t d = (vlen * m) + (PQ_LEN * k);
if (d >= args.dim) break;
// Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device
// memory)
vq_vals[m].load(reinterpret_cast<const half2*>(DescriptorT::vq_code_book_ptr(args) + d +
(args.dim * vq_code)),
0);
if (d >= dim) break;
// Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device memory)
device::ldg_ca(vq_vals[m], vq_code_book_ptr + d);
}
// Compute distance
std::uint32_t pq_code = pq_codes[e];
#pragma unroll
for (std::uint32_t v = 0; v < vlen; v++) {
if (PQ_LEN * (v + k) >= args.dim) break;
raft::TxN_t<CODE_BOOK_T, PQ_LEN> pq_vals;
device::lds(*pq_vals.vectorized_data(),
pq_codebook_ptr + sizeof(CODE_BOOK_T) * PQ_LEN * (pq_code & 0xff));
if (PQ_LEN * (v + k) >= dim) break;
CODE_BOOK_T pq_vals[PQ_LEN];
device::lds(pq_vals, pq_codebook_ptr + sizeof(CODE_BOOK_T) * PQ_LEN * (pq_code & 0xff));
#pragma unroll
for (std::uint32_t m = 0; m < PQ_LEN; m++) {
const std::uint32_t d1 = m + (PQ_LEN * v);
const std::uint32_t d = d1 + (PQ_LEN * k);
// if (d >= dataset_dim) break;
DISTANCE_T diff;
device::lds(diff, query_ptr + sizeof(QUERY_T) * d);
diff -= static_cast<DISTANCE_T>(pq_vals.data[m]);
diff -= static_cast<DISTANCE_T>(vq_vals[d1 / vlen].val.data[d1 % vlen]);
diff -= static_cast<DISTANCE_T>(pq_vals[m]);
diff -= static_cast<DISTANCE_T>(vq_vals[d1 / vlen][d1 % vlen]);
norm += diff * diff;
}
pq_code >>= 8;
Expand All @@ -354,6 +350,20 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq(
return norm;
}

template <typename DescriptorT>
_RAFT_DEVICE __noinline__ auto compute_distance_vpq(
const typename DescriptorT::args_t args, const typename DescriptorT::INDEX_T dataset_index) ->
typename DescriptorT::DISTANCE_T
{
return compute_distance_vpq_worker<DescriptorT>(
DescriptorT::encoded_dataset_ptr(args) +
(static_cast<std::uint64_t>(DescriptorT::encoded_dataset_dim(args)) * dataset_index),
DescriptorT::vq_code_book_ptr(args),
args.dim,
args.smem_ws_ptr,
DescriptorT::n_subspace(args));
}

template <cuvs::distance::DistanceType Metric,
uint32_t TeamSize,
uint32_t DatasetBlockDim,
Expand All @@ -363,15 +373,15 @@ template <cuvs::distance::DistanceType Metric,
typename DataT,
typename IndexT,
typename DistanceT>
__launch_bounds__(1, 1) __global__
void vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
const std::uint8_t* encoded_dataset_ptr,
uint32_t encoded_dataset_dim,
uint32_t n_subspace,
const CodebookT* vq_code_book_ptr,
const CodebookT* pq_code_book_ptr,
IndexT size,
uint32_t dim)
RAFT_KERNEL __launch_bounds__(1, 1)
vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
const std::uint8_t* encoded_dataset_ptr,
uint32_t encoded_dataset_dim,
uint32_t n_subspace,
const CodebookT* vq_code_book_ptr,
const CodebookT* pq_code_book_ptr,
IndexT size,
uint32_t dim)
{
using desc_type = cagra_q_dataset_descriptor_t<Metric,
TeamSize,
Expand Down
94 changes: 94 additions & 0 deletions cpp/src/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,25 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(half2& x, uint32_t addr)
{
asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(reinterpret_cast<uint32_t&>(x)) : "r"(addr));
}
RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[1], uint32_t addr)
{
asm volatile("ld.shared.u16 {%0}, [%1];" : "=h"(*reinterpret_cast<uint16_t*>(x)) : "r"(addr));
}
RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[2], uint32_t addr)
{
asm volatile("ld.shared.v2.u16 {%0, %1}, [%2];"
: "=h"(*reinterpret_cast<uint16_t*>(x)), "=h"(*reinterpret_cast<uint16_t*>(x + 1))
: "r"(addr));
}
RAFT_DEVICE_INLINE_FUNCTION void lds(half (&x)[4], uint32_t addr)
{
asm volatile("ld.shared.v4.u16 {%0, %1, %2, %3}, [%4];"
: "=h"(*reinterpret_cast<uint16_t*>(x)),
"=h"(*reinterpret_cast<uint16_t*>(x + 1)),
"=h"(*reinterpret_cast<uint16_t*>(x + 2)),
"=h"(*reinterpret_cast<uint16_t*>(x + 3))
: "r"(addr));
}

RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, uint32_t addr)
{
Expand All @@ -260,5 +279,80 @@ RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x)
"h"(reinterpret_cast<const uint16_t&>(x.y)));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint4& x, const uint4* addr)
{
asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w)
: "l"(addr));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint4& x, const uint4* addr)
{
asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(x.x), "=r"(x.y), "=r"(x.z), "=r"(x.w)
: "l"(addr));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(uint32_t& x, const uint32_t* addr)
{
asm volatile("ld.global.ca.u32 %0, [%1];" : "=r"(x) : "l"(addr));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_cg(uint32_t& x, const uint32_t* addr)
{
asm volatile("ld.global.cg.u32 %0, [%1];" : "=r"(x) : "l"(addr));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half& x, const half* addr)
{
asm volatile("ld.global.ca.u16 {%0}, [%1];"
: "=h"(reinterpret_cast<uint16_t&>(x))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}
RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[1], const half* addr)
{
asm volatile("ld.global.ca.u16 {%0}, [%1];"
: "=h"(*reinterpret_cast<uint16_t*>(x))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}
RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[2], const half* addr)
{
asm volatile("ld.global.ca.v2.u16 {%0, %1}, [%2];"
: "=h"(*reinterpret_cast<uint16_t*>(x)), "=h"(*reinterpret_cast<uint16_t*>(x + 1))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}
RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half (&x)[4], const half* addr)
{
asm volatile("ld.global.ca.v4.u16 {%0, %1, %2, %3}, [%4];"
: "=h"(*reinterpret_cast<uint16_t*>(x)),
"=h"(*reinterpret_cast<uint16_t*>(x + 1)),
"=h"(*reinterpret_cast<uint16_t*>(x + 2)),
"=h"(*reinterpret_cast<uint16_t*>(x + 3))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}

RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2& x, const half* addr)
{
asm volatile("ld.global.ca.v2.u16 {%0, %1}, [%2];"
: "=h"(reinterpret_cast<uint16_t&>(x.x)), "=h"(reinterpret_cast<uint16_t&>(x.y))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}
RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[1], const half* addr)
{
asm volatile("ld.global.ca.v2.u16 {%0, %1}, [%2];"
: "=h"(reinterpret_cast<uint16_t&>(x[0].x)),
"=h"(reinterpret_cast<uint16_t&>(x[0].y))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}
RAFT_DEVICE_INLINE_FUNCTION void ldg_ca(half2 (&x)[2], const half* addr)
{
asm volatile("ld.global.ca.v4.u16 {%0, %1, %2, %3}, [%4];"
: "=h"(reinterpret_cast<uint16_t&>(x[0].x)),
"=h"(reinterpret_cast<uint16_t&>(x[0].y)),
"=h"(reinterpret_cast<uint16_t&>(x[1].x)),
"=h"(reinterpret_cast<uint16_t&>(x[1].y))
: "l"(reinterpret_cast<const uint16_t*>(addr)));
}

} // namespace device
} // namespace cuvs::neighbors::cagra::detail
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements]
// multiple CTAs per single query
//
template <std::uint32_t MAX_ELEMENTS, class DATASET_DESCRIPTOR_T, class SAMPLE_FILTER_T>
RAFT_KERNEL search_kernel(
RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel(
typename DATASET_DESCRIPTOR_T::INDEX_T* const
result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size]
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const
Expand Down
Loading

0 comments on commit 5fdcdd0

Please sign in to comment.