Commit

Merge pull request #61 from ROCm/rocm-jaxlib-v0.4.31-qa-buffer-comp-fix
[ROCm] Fixed an issue with "Failed to launch ROCm kernel" errors.
Ruturaj4 authored Nov 6, 2024
2 parents 9baff69 + 8d48929 commit 8eb4ca9
Showing 2 changed files with 152 additions and 90 deletions.
6 changes: 6 additions & 0 deletions xla/service/gpu/buffer_comparator.cc
@@ -88,8 +88,14 @@ static absl::StatusOr<bool> DeviceCompare(std::string_view kernel_name,
  const se::DeviceDescription& gpu_device_info =
      executor->GetDeviceDescription();

+#ifdef GOOGLE_CUDA
  LaunchDimensions dim =
      CalculateLaunchDimensions(*params.shape, gpu_device_info);
+#else
+  LaunchDimensions dim =
+      CalculateLaunchDimensions(*params.shape, gpu_device_info,
+                                {128 / sizeof(ElementT)});
+#endif  // GOOGLE_CUDA

  se::DeviceMemory<uint64_t> as_uint64(out.memory());
  TF_RETURN_IF_ERROR(params.stream->ThenLaunch(
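The added ROCm branch is the substance of the fix: passing an unroll factor of 128 / sizeof(ElementT) makes each thread cover a 128-byte chunk, shrinking the computed grid, which is presumably what avoids the "Failed to launch ROCm kernel" error on large buffers. A minimal sketch of that arithmetic, with hypothetical names (not XLA's actual CalculateLaunchDimensions signature):

#include <cstdint>

// Hypothetical sketch: with an unroll factor u, each thread handles u
// elements, so the block count needed to cover the buffer shrinks by u.
struct Dims { std::uint64_t blocks, threads_per_block; };

Dims ComputeDims(std::uint64_t num_elements,
                 std::uint64_t threads_per_block,  // e.g. 256 (assumed)
                 std::uint64_t unroll) {           // e.g. 128 / sizeof(ElementT)
  std::uint64_t per_block = threads_per_block * unroll;  // elements per block
  std::uint64_t blocks = (num_elements + per_block - 1) / per_block;  // ceil-div
  return {blocks, threads_per_block};
}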
236 changes: 146 additions & 90 deletions xla/service/gpu/buffer_comparator.cu.cc
@@ -108,141 +108,197 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
-  elem_a_fp8.__x = buffer_a[idx];
-  elem_b_fp8.__x = buffer_b[idx];
-  float elem_a = static_cast<float>(elem_a_fp8);
-  float elem_b = static_cast<float>(elem_b_fp8);
-  elem_a = Canonicalize(elem_a);
-  elem_b = Canonicalize(elem_b);
-  if (isnan(elem_a) && isnan(elem_b)) return;
-
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
+      elem_a_fp8.__x = buffer_a[idx+i];
+      elem_b_fp8.__x = buffer_b[idx+i];
+      float elem_a = static_cast<float>(elem_a_fp8);
+      float elem_b = static_cast<float>(elem_b_fp8);
+      elem_a = Canonicalize(elem_a);
+      elem_b = Canonicalize(elem_b);
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}
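Two properties of the rewritten kernels are worth noting: each thread now issues at most one atomicAdd (accumulating mismatches in a local mcount first), which cuts contention on mismatch_count, and the indexing only covers the whole buffer if the grid was sized with the matching unroll factor. A hypothetical standalone launcher consistent with that contract (block size and wrapper name are assumptions, not part of this change):

#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#include <cstdint>

// Sketch only: size the grid over chunks of block_size * unroll elements,
// mirroring the kernel's internal `128 / sizeof(*buffer_a)` unroll.
void LaunchFp8Compare(__hip_fp8_storage_t* d_a, __hip_fp8_storage_t* d_b,
                      float tol, uint64_t n, int* d_mismatch) {
  const uint64_t block_size = 256;  // assumed threads per block
  const uint64_t unroll = 128 / sizeof(__hip_fp8_storage_t);  // 128 for fp8
  const uint64_t blocks = (n + block_size * unroll - 1) / (block_size * unroll);
  hipLaunchKernelGGL(xla_fp8_e4m3fnuz_comparison, dim3(blocks),
                     dim3(block_size), 0, 0, d_a, d_b, tol, n, d_mismatch);
}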

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
-  elem_a_fp8.__x = buffer_a[idx];
-  elem_b_fp8.__x = buffer_b[idx];
-  float elem_a = static_cast<float>(elem_a_fp8);
-  float elem_b = static_cast<float>(elem_b_fp8);
-  elem_a = Canonicalize(elem_a);
-  elem_b = Canonicalize(elem_b);
-  if (isnan(elem_a) && isnan(elem_b)) return;
-
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
+      elem_a_fp8.__x = buffer_a[idx+i];
+      elem_b_fp8.__x = buffer_b[idx+i];
+      float elem_a = static_cast<float>(elem_a_fp8);
+      float elem_b = static_cast<float>(elem_b_fp8);
+      elem_a = Canonicalize(elem_a);
+      elem_b = Canonicalize(elem_b);
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  float elem_a = __half2float(buffer_a[idx]);
-  float elem_b = __half2float(buffer_b[idx]);
-  elem_a = Canonicalize(elem_a);
-  elem_b = Canonicalize(elem_b);
-  if (isnan(elem_a) && isnan(elem_b)) return;
-
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      float elem_a = __half2float(buffer_a[idx+i]);
+      float elem_b = __half2float(buffer_b[idx+i]);
+      elem_a = Canonicalize(elem_a);
+      elem_b = Canonicalize(elem_b);
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}

__global__ void xla_fp32_comparison(float* buffer_a, float* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  float elem_a = buffer_a[idx];
-  float elem_b = buffer_b[idx];
-  if (isnan(elem_a) && isnan(elem_b)) return;
-  if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
-    return;
-
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      float elem_a = buffer_a[idx+i];
+      float elem_b = buffer_b[idx+i];
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+      if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
+        continue;
+
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}
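As a usage illustration, a hypothetical round-trip check of the fp32 kernel: perturbing a single element should report exactly one mismatch. All host-side code here is an assumption for demonstration; in XLA these kernels are driven by DeviceCompare, not called directly.

#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const uint64_t n = 1 << 20;
  std::vector<float> a(n, 1.0f), b(n, 1.0f);
  b[12345] = 2.0f;  // one deliberate mismatch

  float *d_a, *d_b;
  int *d_count, count = 0;
  hipMalloc(&d_a, n * sizeof(float));
  hipMalloc(&d_b, n * sizeof(float));
  hipMalloc(&d_count, sizeof(int));
  hipMemcpy(d_a, a.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(d_b, b.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemset(d_count, 0, sizeof(int));

  const uint64_t block = 256, unroll = 128 / sizeof(float);  // 32 elems/thread
  const uint64_t blocks = (n + block * unroll - 1) / (block * unroll);
  hipLaunchKernelGGL(xla_fp32_comparison, dim3(blocks), dim3(block), 0, 0,
                     d_a, d_b, /*rel_error_threshold=*/0.1f, n, d_count);

  hipMemcpy(&count, d_count, sizeof(int), hipMemcpyDeviceToHost);
  printf("mismatches: %d\n", count);  // expected: 1
  hipFree(d_a); hipFree(d_b); hipFree(d_count);
  return 0;
}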

__global__ void xla_fp64_comparison(double* buffer_a, double* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-
-  double elem_a = buffer_a[idx];
-  double elem_b = buffer_b[idx];
-  if (isnan(elem_a) && isnan(elem_b)) return;
-  if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
-    return;
-  double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      double elem_a = buffer_a[idx+i];
+      double elem_b = buffer_b[idx+i];
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+      if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
+        continue;
+      double rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}

__global__ void xla_bf16_comparison(bfloat16* buffer_a, bfloat16* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  float elem_a = BF16_TO_F32(buffer_a[idx]);
-  float elem_b = BF16_TO_F32(buffer_b[idx]);
-  elem_a = Canonicalize(elem_a);
-  elem_b = Canonicalize(elem_b);
-  if (isnan(elem_a) && isnan(elem_b)) return;
-
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      float elem_a = BF16_TO_F32(buffer_a[idx+i]);
+      float elem_b = BF16_TO_F32(buffer_b[idx+i]);
+      elem_a = Canonicalize(elem_a);
+      elem_b = Canonicalize(elem_b);
+      if (isnan(elem_a) && isnan(elem_b)) continue;
+
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}

// TODO(b/191520348): The comparison below requires exact equality.
__global__ void xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  float a = buffer_a[idx];
-  float b = buffer_b[idx];
-  float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1);
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128;
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      float a = buffer_a[idx+i];
+      float b = buffer_b[idx+i];
+      float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1);
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}
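For int8 the unroll is the literal 128 rather than 128 / sizeof(*buffer_a), but since sizeof(int8_t) == 1 the two are equal: every kernel covers a 128-byte chunk per thread. A compile-time sanity sketch of that invariant:

#include <cstdint>

// Each thread processes one 128-byte chunk regardless of element width.
static_assert(128 / sizeof(int8_t) == 128, "int8: 128 elements per thread");
static_assert(128 / sizeof(uint16_t) == 64, "fp16/bf16: 64 elements per thread");
static_assert(128 / sizeof(float) == 32, "fp32/int32: 32 elements per thread");
static_assert(128 / sizeof(double) == 16, "fp64: 16 elements per thread");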

__global__ void xla_int32_comparison(int* buffer_a, int* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  if (idx >= buffer_length) return;
-  float elem_a = static_cast<float>(buffer_a[idx]);
-  float elem_b = static_cast<float>(buffer_b[idx]);
-  float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
-  if (rel_error > rel_error_threshold || isnan(rel_error))
-    atomicAdd(mismatch_count, 1);
+  int mcount = 0;
+  uint64_t unroll = 128 / sizeof(*buffer_a);
+  uint64_t idx = (threadIdx.x + blockIdx.x * blockDim.x) * unroll;
+  for (unsigned i = 0; i < unroll; ++i) {
+    if ((idx+i) < buffer_length) {
+      float elem_a = static_cast<float>(buffer_a[idx+i]);
+      float elem_b = static_cast<float>(buffer_b[idx+i]);
+      float rel_error = abs(elem_a - elem_b) / (max(abs(elem_a), abs(elem_b)) + 1);
+      if (rel_error > rel_error_threshold || isnan(rel_error))
+        mcount++;
+    }
+  }
+  if (mcount)
+    atomicAdd(mismatch_count, mcount);
}

} // namespace
