Use MAP macro to shorten code
tridao committed Jul 2, 2019
1 parent d952308 commit c003602
Showing 1 changed file with 28 additions and 126 deletions.
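The commit collapses 24 hand-written switch arms (forward cases 1..14 and backward cases 1..10, each with a true/false branch) into a single CASE_LOG_N macro applied over the argument list by MAP. The repo's map.h is not shown in this diff; for reference, a minimal fixed-arity sketch of such a "for_each over the arguments" macro (illustrative names, not the actual header):

#define MAP_1(F, x) F(x)
#define MAP_2(F, x, ...) F(x) MAP_1(F, __VA_ARGS__)
#define MAP_3(F, x, ...) F(x) MAP_2(F, __VA_ARGS__)
// ... extend to the maximum arity needed (14 in this file) ...
#define GET_MAP(_1, _2, _3, NAME, ...) NAME
#define MAP(F, ...) GET_MAP(__VA_ARGS__, MAP_3, MAP_2, MAP_1, 0)(F, __VA_ARGS__)

// Example: with  #define SQUARE_CASE(n) case n: return n * n;
// MAP(SQUARE_CASE, 1, 2, 3) expands to three consecutive case labels.

Production map.h headers (e.g. the widely used recursive "map-macro" implementation) support arbitrary arity via deferred expansion rather than a fixed dispatch table like the sketch above.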
154 changes: 28 additions & 126 deletions butterfly/factor_multiply_fast/butterfly_multiply_cuda.cu
@@ -7,6 +7,7 @@
#include <thrust/complex.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include "map.h" // For the MAP macro, i.e. for_each over the arguments

#define thc_cos std::cos
#define thc_sin std::sin
@@ -23,7 +24,8 @@ static constexpr int MAX_BLOCK_SIZE = 1024;
// static constexpr int MAX_N_FACTORS = 10;
static constexpr int ITEMS_PER_THREAD_FORWARD[14] = {4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 13, 10, 4, 4};
static constexpr int ITEMS_PER_THREAD_BACKWARD[14] = {16, 16, 16, 16, 16, 16, 16, 16, 16, 4, 1, 1, 1, 1};
static constexpr int MIN_BLOCKS_PER_MP[14] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1};
static constexpr int MIN_BLOCKS_PER_MP_FORWARD[14] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1};
static constexpr int MIN_BLOCKS_PER_MP_BACKWARD[14] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

template <typename T, size_t N>
using CudaAcsr = at::PackedTensorAccessor<T, N, at::RestrictPtrTraits, int32_t>;
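The ITEMS_PER_THREAD_* and MIN_BLOCKS_PER_MP_* tables are indexed by log_n - 1 and consumed as default non-type template arguments, so each kernel instantiation bakes its tuning constants in at compile time. A minimal sketch of that pattern, with hypothetical names (TUNE and tuned_kernel are not from the repo):

static constexpr int TUNE[3] = {4, 8, 16};

// The default template argument reads the constexpr table at compile time,
// so the unroll count below is a true compile-time constant per instantiation.
template <int log_n, int items_per_thread = TUNE[log_n - 1]>
__global__ void tuned_kernel(float *out) {
  #pragma unroll
  for (int i = 0; i < items_per_thread; ++i) {
    out[threadIdx.x * items_per_thread + i] = 0.0f;
  }
}

// tuned_kernel<2><<<1, 4>>>(buf);  // items_per_thread defaults to TUNE[1] == 8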
@@ -194,7 +196,7 @@ __device__ __forceinline__ void b_untied_forward(const CudaAcsr<scalar_t, 4> twi

template <int log_n, bool increasing_stride,
int items_per_thread=ITEMS_PER_THREAD_FORWARD[log_n - 1],
int min_blocks_per_mp=MIN_BLOCKS_PER_MP[log_n - 1],
int min_blocks_per_mp=MIN_BLOCKS_PER_MP_FORWARD[log_n - 1],
int max_smem_per_thread=items_per_thread, typename scalar_t>
// C10_LAUNCH_BOUNDS_2 supposedly takes min(1 << log_n, 1024)
// https://github.com/pytorch/pytorch/blob/v1.1.0/c10/macros/Macros.h
@@ -259,78 +261,17 @@ void butterfly_multiply_untied_forward_fast_cuda(const at::Tensor &twiddle,
auto stream = at::cuda::getCurrentCUDAStream();
switch (log_n)
{
case 1:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<1, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<1, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 2:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<2, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<2, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 3:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<3, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<3, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 4:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<4, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<4, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 5:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<5, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<5, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 6:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<6, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<6, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 7:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<7, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<7, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 8:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<8, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<8, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 9:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<9, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<9, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 10:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<10, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<10, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 11:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<11, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<11, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 12:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<12, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<12, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 13:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<13, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<13, false>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;
case 14:
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<14, true>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
: butterfly_multiply_untied_forward_fast_cuda_kernel<14, false>
#define CASE_LOG_N(log_n_val) case log_n_val: \
increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<log_n_val, true> \
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size) \
: butterfly_multiply_untied_forward_fast_cuda_kernel<log_n_val, false> \
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size); break;

MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)
}
});
// Have to keep this #undef outside the AT_DISPATCH_FLOATING_TYPES macro for it to work
#undef CASE_LOG_N
AT_CHECK(cudaGetLastError() == cudaSuccess,
"butterfly_multiply_untied_forward_fast_cuda failed with error code ",
cudaGetLastError());
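For reference, each argument passed to MAP(CASE_LOG_N, ...) expands to one switch arm identical to the hand-written ones deleted above; e.g. CASE_LOG_N(5) yields:

case 5:
  increasing_stride ? butterfly_multiply_untied_forward_fast_cuda_kernel<5, true>
                        <<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size)
                    : butterfly_multiply_untied_forward_fast_cuda_kernel<5, false>
                        <<<grid, block, 0, stream>>>(twiddle_a, input_reader, output_writer, batch_size);
  break;

Because log_n_val is a compile-time constant in each arm, every arm instantiates a distinct kernel whose items_per_thread and min_blocks_per_mp defaults come from the tuning tables at the top of the file.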
@@ -398,13 +339,15 @@ __device__ __forceinline__ void b_untied_forward_backward(const CudaAcsr<scalar_
}
}

template <int log_n, bool increasing_stride, int items_per_thread,
template <int log_n, bool increasing_stride,
int items_per_thread=ITEMS_PER_THREAD_BACKWARD[log_n - 1],
int max_reg_storage_per_thread=items_per_thread,
int min_blocks_per_mp=1, int max_smem_per_thread=items_per_thread,
typename scalar_t>
// C10_LAUNCH_BOUNDS_2 already takes min(1 << log_n, 1024)
int min_blocks_per_mp=MIN_BLOCKS_PER_MP_BACKWARD[log_n - 1],
int max_smem_per_thread=items_per_thread, typename scalar_t>
// C10_LAUNCH_BOUNDS_2 supposedly takes min(1 << log_n, 1024)
// https://github.com/pytorch/pytorch/blob/v1.1.0/c10/macros/Macros.h
C10_LAUNCH_BOUNDS_2(1 << log_n, min_blocks_per_mp)
// However, it doesn't seem to work correctly so I have to take min explicitly.
C10_LAUNCH_BOUNDS_2(MIN_MACRO(1 << log_n, MAX_BLOCK_SIZE), min_blocks_per_mp)
__global__ void butterfly_multiply_untied_forward_backward_fast_cuda_kernel(const CudaAcsr<scalar_t, 4> twiddle_a,
InputReader<scalar_t> input_reader,
InputReader<scalar_t> grad_reader,
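MIN_MACRO is used in the launch-bounds line above but its definition sits outside this diff; presumably it is a plain preprocessor minimum along these lines (an assumption, not the file's actual definition):

#define MIN_MACRO(a, b) (((a) <= (b)) ? (a) : (b))

// This yields an integral constant expression that __launch_bounds__ accepts,
// e.g. MIN_MACRO(1 << 14, 1024) caps max threads per block at 1024.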
@@ -481,59 +424,18 @@ void butterfly_multiply_untied_forward_backward_fast_cuda(const at::Tensor &twid
auto stream = at::cuda::getCurrentCUDAStream();
switch (log_n)
{
case 1:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<1, true, ITEMS_PER_THREAD_BACKWARD[0]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<1, false, ITEMS_PER_THREAD_BACKWARD[0]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 2:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<2, true, ITEMS_PER_THREAD_BACKWARD[1]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<2, false, ITEMS_PER_THREAD_BACKWARD[1]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 3:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<3, true, ITEMS_PER_THREAD_BACKWARD[2]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<3, false, ITEMS_PER_THREAD_BACKWARD[2]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 4:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<4, true, ITEMS_PER_THREAD_BACKWARD[3]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<4, false, ITEMS_PER_THREAD_BACKWARD[3]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 5:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<5, true, ITEMS_PER_THREAD_BACKWARD[4]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<5, false, ITEMS_PER_THREAD_BACKWARD[4]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 6:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<6, true, ITEMS_PER_THREAD_BACKWARD[5]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<6, false, ITEMS_PER_THREAD_BACKWARD[5]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 7:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<7, true, ITEMS_PER_THREAD_BACKWARD[6]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<7, false, ITEMS_PER_THREAD_BACKWARD[6]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 8:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<8, true, ITEMS_PER_THREAD_BACKWARD[7]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<8, false, ITEMS_PER_THREAD_BACKWARD[7]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 9:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<9, true, ITEMS_PER_THREAD_BACKWARD[8]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<9, false, ITEMS_PER_THREAD_BACKWARD[8]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;
case 10:
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<10, true, ITEMS_PER_THREAD_BACKWARD[9]>
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size)
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<10, false, ITEMS_PER_THREAD_BACKWARD[9]>
#define CASE_LOG_N(log_n_val) case log_n_val: \
increasing_stride ? butterfly_multiply_untied_forward_backward_fast_cuda_kernel<log_n_val, true> \
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size) \
: butterfly_multiply_untied_forward_backward_fast_cuda_kernel<log_n_val, false> \
<<<grid, block, 0, stream>>>(twiddle_a, input_reader, grad_reader, d_twiddle_a, d_input_writer, batch_size); break;

MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}
});
// Have to keep this #undef outside the AT_DISPATCH_FLOATING_TYPES macro for it to work
#undef CASE_LOG_N
AT_CHECK(cudaGetLastError() == cudaSuccess,
"butterfly_multiply_untied_forward_backward_fast_cuda failed with error code ",
cudaGetLastError());
}
}
