CUDA codes [butterfly_multiply_untied_forward_max5_fast*] #30

Open
lloo099 opened this issue Apr 15, 2022 · 0 comments

Comments


lloo099 commented Apr 15, 2022

Hi, I am studying your CUDA implementation of butterfly, and my CUDA experience is limited. I read the code around your load_max5 call, which uses remaining_input_idx, low_bits, high_bits, etc. It is not clear to me why the data is read this way. Would you mind explaining it? Thanks very much.

  const int t_idx = threadIdx.x;
  const int batch_idx = (threadIdx.y + (blockIdx.x >> (log_n - nsteps)) * blockDim.y) * items_per_thread;
  const int remaining_input_idx = blockIdx.x & ((1 << (log_n - nsteps)) - 1);
  const int low_bits = remaining_input_idx & ((1 << input_idx_start_bit) - 1);
  const int high_bits = (remaining_input_idx >> input_idx_start_bit) << (input_idx_start_bit + nsteps);
  // All threads with the same t_idx should have the same input_idx
  const int input_idx = high_bits | (t_idx << input_idx_start_bit) | low_bits;
  const int input_idx_stride = (1 << input_idx_start_bit) * warpSize;
  twiddle_reader.load_max5<nsteps, increasing_stride>(s_twiddle, input_idx_start_bit, low_
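
To make the question concrete, here is a rough host-side C++ sketch that mirrors only the bit arithmetic of the kernel lines above. It is my reading of the code, not an explanation from the maintainers. The helper names `compose_input_idx` and `covers_all_indices` and the struct `IndexParts` are hypothetical, and the sketch assumes `t_idx` ranges over `[0, 2^nsteps)`. Under that assumption, `t_idx` fills the `nsteps` bits starting at `input_idx_start_bit`, while the remaining `log_n - nsteps` bits taken from `blockIdx.x` are split into a part below that window (`low_bits`) and a part above it (`high_bits`).

```cpp
#include <cassert>
#include <vector>

// Hypothetical container for the intermediate values (not in the original code).
struct IndexParts {
  int low_bits;
  int high_bits;
  int input_idx;
};

// Mirrors the kernel arithmetic; block_idx_x and t_idx stand in for
// blockIdx.x and threadIdx.x so the logic can run on the host.
IndexParts compose_input_idx(int block_idx_x, int t_idx, int log_n, int nsteps,
                             int input_idx_start_bit) {
  // Keep the low (log_n - nsteps) bits of the block index.
  const int remaining_input_idx = block_idx_x & ((1 << (log_n - nsteps)) - 1);
  // Bits of remaining_input_idx that sit below the window owned by t_idx.
  const int low_bits = remaining_input_idx & ((1 << input_idx_start_bit) - 1);
  // The other bits, shifted up to leave an nsteps-wide gap for t_idx.
  const int high_bits = (remaining_input_idx >> input_idx_start_bit)
                        << (input_idx_start_bit + nsteps);
  const int input_idx = high_bits | (t_idx << input_idx_start_bit) | low_bits;
  return {low_bits, high_bits, input_idx};
}

// Assumption: t_idx ranges over [0, 2^nsteps). Returns true when the pairs
// (remaining_input_idx, t_idx) hit every index in [0, 2^log_n) exactly once.
bool covers_all_indices(int log_n, int nsteps, int input_idx_start_bit) {
  std::vector<int> hits(1 << log_n, 0);
  for (int b = 0; b < (1 << (log_n - nsteps)); ++b) {
    for (int t = 0; t < (1 << nsteps); ++t) {
      ++hits[compose_input_idx(b, t, log_n, nsteps, input_idx_start_bit).input_idx];
    }
  }
  for (int h : hits) {
    if (h != 1) return false;
  }
  return true;
}
```

For example, with `log_n = 10`, `nsteps = 5`, `input_idx_start_bit = 2`, `block_idx_x = 0b10110`, and `t_idx = 0b01101`, the sketch gives `low_bits = 0b10`, `high_bits = 0b101 << 7`, and `input_idx = 0b1010110110`: the three fields occupy disjoint bit ranges, so OR-ing them is the same as concatenating them.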


Could you explain the load_max5 function in detail? For example, why are the variables low_order_bits, remainder, and s_idx used to compute the indices?

  template<int nsteps, bool increasing_stride>
  __device__ __forceinline__ void load_max5(scalar_t s_twiddle[nsteps][2][1 << nsteps],
                                            int input_idx_start_bit, int low_bits, int high_bits) {
    constexpr int span = 1 << nsteps;
    const int s = blockIdx.y + gridDim.y * blockIdx.z;  // For conv2d butterfly as well
    for (int t = threadIdx.x + threadIdx.y * blockDim.x; t < nsteps * (span / 2); t += blockDim.x * blockDim.y) {
      const int step = t / (span / 2);
      const int s_twiddle_stride = 1 << (increasing_stride ? step : nsteps - 1 - step);
      const int remainder = t % (span / 2);
      const int low_order_bits = remainder & (s_twiddle_stride - 1);
      const int s_idx = 2 * (remainder - low_order_bits) + low_order_bits;
      const int idx = (high_bits >> 1) | (remainder << input_idx_start_bit) | low_bits;
      s_twiddle[step][0][s_idx] = twiddle_a[s][step][idx][0][0];
      s_twiddle[step][1][s_idx] = twiddle_a[s][step][idx][0][1];
      s_twiddle[step][1][s_idx + s_twiddle_stride] = twiddle_a[s][step][idx][1][0];
      s_twiddle[step][0][s_idx + s_twiddle_stride] = twiddle_a[s][step][idx][1][1];
    }
  }
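
Here is a similar host-side C++ sketch of the index arithmetic inside load_max5, again only my reading of the code and not a confirmed explanation. The function names `pair_low_index`, `twiddle_idx`, and `pairs_cover_span` are hypothetical. As far as I can tell, `s_idx` is `remainder` with a zero bit inserted at the position of `s_twiddle_stride`, so `s_idx` and `s_idx + s_twiddle_stride` are the two elements of one butterfly pair. The global index `idx` has one bit fewer than `input_idx`, which would explain the `high_bits >> 1`.

```cpp
#include <cassert>
#include <vector>

// Same arithmetic as s_idx in load_max5: split remainder at the stride bit
// and shift the upper part left by one, leaving a zero at that bit.
int pair_low_index(int remainder, int s_twiddle_stride) {
  const int low_order_bits = remainder & (s_twiddle_stride - 1);
  return 2 * (remainder - low_order_bits) + low_order_bits;
}

// Same arithmetic as idx in load_max5. remainder has nsteps - 1 bits, so
// high_bits moves down by one bit to sit directly above it.
int twiddle_idx(int high_bits, int remainder, int input_idx_start_bit, int low_bits) {
  return (high_bits >> 1) | (remainder << input_idx_start_bit) | low_bits;
}

// Returns true when the pairs (s_idx, s_idx + stride), taken over all
// remainder values in [0, span / 2), touch every slot in [0, span) once.
bool pairs_cover_span(int nsteps, int s_twiddle_stride) {
  const int span = 1 << nsteps;
  std::vector<int> hits(span, 0);
  for (int remainder = 0; remainder < span / 2; ++remainder) {
    const int s_idx = pair_low_index(remainder, s_twiddle_stride);
    ++hits[s_idx];                     // element of the pair with the stride bit clear
    ++hits[s_idx + s_twiddle_stride];  // its partner, with the stride bit set
  }
  for (int h : hits) {
    if (h != 1) return false;
  }
  return true;
}
```

For example, `pair_low_index(0b0110, 4)` gives `0b01010`, and its partner is `0b01110`. In the original loop, slot `[0]` of `s_twiddle` then receives the diagonal entries of the 2x2 twiddle block (`[0][0]` at `s_idx`, `[1][1]` at the partner) and slot `[1]` receives the off-diagonal entries (`[0][1]` at `s_idx`, `[1][0]` at the partner).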