
Commit

Merge branch 'rocm-jaxlib-v0.4.30-qa' into rocm-jaxlib-v0.4.30-qa-cleanup
hsharsha authored Aug 29, 2024
2 parents 339dde0 + 8c73dfe commit 8ae1de7
Showing 15 changed files with 267 additions and 187 deletions.
13 changes: 11 additions & 2 deletions xla/debug_options_flags.cc
@@ -49,7 +49,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_llvm_enable_invariant_load_metadata(true);
opts.set_xla_llvm_disable_expensive_passes(false);
opts.set_xla_backend_optimization_level(3);
opts.set_xla_gpu_autotune_level(4);
opts.set_xla_gpu_autotune_level(5);
opts.set_xla_gpu_autotune_max_solutions(0);
opts.set_xla_cpu_multi_thread_eigen(true);
opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
@@ -270,6 +270,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_shard_autotuning(false);

opts.set_xla_gpu_autotune_gemm_rtol(0.1f);

return opts;
}

@@ -816,13 +818,20 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
debug_options->xla_gpu_autotune_level(),
"Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
"on+init; 3 = on+init+reinit; 4 = on+init+reinit+check; "
" 5 = on+init+reinit+check and skip WRONG_RESULT solutions. See also "
" the related flag xla_gpu_autotune_gemm_rtol."));
flag_list->push_back(tsl::Flag(
"xla_gpu_autotune_max_solutions",
int64_setter_for(&DebugOptions::set_xla_gpu_autotune_max_solutions),
debug_options->xla_gpu_autotune_max_solutions(),
"Maximal number of GEMM solutions to consider for autotuning: 0 means "
"consider all solutions returned by the GEMM library."));
flag_list->push_back(tsl::Flag(
"xla_gpu_autotune_gemm_rtol",
float_setter_for(&DebugOptions::set_xla_gpu_autotune_gemm_rtol),
debug_options->xla_gpu_autotune_gemm_rtol(),
"Relative precision for comparing GEMM solutions vs the reference one"));
flag_list->push_back(tsl::Flag(
"xla_force_host_platform_device_count",
int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
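Taken together, this file raises the default GEMM/convolution autotuning level from 4 to 5 and exposes the comparison tolerance as its own flag. A minimal sketch of setting both knobs in code, assuming the usual GetDebugOptionsFromFlags() entry point from xla/debug_options_flags.h (the setters are the ones added above):

    #include "xla/debug_options_flags.h"

    // Sketch only: configure the new autotuning knobs programmatically
    // instead of via --xla_gpu_autotune_level / --xla_gpu_autotune_gemm_rtol.
    xla::DebugOptions MakeAutotuneOptions() {
      xla::DebugOptions opts = xla::GetDebugOptionsFromFlags();
      opts.set_xla_gpu_autotune_level(5);          // level 4 + skip WRONG_RESULT
      opts.set_xla_gpu_autotune_gemm_rtol(0.05f);  // tighter than the 0.1 default
      return opts;
    }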
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -1727,6 +1727,7 @@ xla_test(
":backend_configs_cc",
":gemm_algorithm_picker",
":gemm_rewriter",
":variant_visitor",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
1 change: 1 addition & 0 deletions xla/service/gpu/autotuner_util.h
@@ -101,6 +101,7 @@ class AutotuneConfig {
bool should_init_buffers() const { return autotune_level_ >= 2; }
bool should_reinit_output_buffer() const { return autotune_level_ >= 3; }
bool should_check_correctness() const { return autotune_level_ >= 4; }
bool should_skip_wrong_results() const { return autotune_level_ >= 5; }
bool should_crash_on_check_failure() const {
return should_crash_on_check_failure_;
}
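The new level-5 predicate composes with the existing thresholds: skipping wrong results only makes sense once correctness checking (level 4) is active. An illustrative sketch of how a caller might fold it into a candidate filter; `candidate` and `reference` are hypothetical names, and the real consumers are the GEMM and conv algorithm pickers:

    // Illustrative sketch only, not the actual picker code.
    absl::StatusOr<bool> KeepCandidate(const AutotuneConfig& config,
                                       const BufferComparator& comparator,
                                       se::Stream* stream,
                                       se::DeviceMemoryBase candidate,
                                       se::DeviceMemoryBase reference) {
      if (!config.should_check_correctness()) return true;
      TF_ASSIGN_OR_RETURN(bool equal,
                          comparator.CompareEqual(stream, candidate, reference));
      // At level 5 a mismatching solution is dropped instead of merely logged.
      return equal || !config.should_skip_wrong_results();
    }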
127 changes: 66 additions & 61 deletions xla/service/gpu/buffer_comparator.cc
@@ -47,27 +47,35 @@ using ComparisonKernelT =
se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
float, uint64_t, se::DeviceMemory<uint64_t>>;

struct ComparisonParams {
double relative_tol = 0.1f;
bool verbose = true;
const Shape *shape = nullptr;
se::Stream* stream = nullptr;
se::DeviceMemoryBase current{};
se::DeviceMemoryBase expected{};
};

// Compares two buffers on the GPU.
//
// Returns `true` if two buffers are equal, `false` otherwise.
template <typename ElementT>
absl::StatusOr<bool> BufferComparator::DeviceCompare(
se::Stream* stream, se::DeviceMemoryBase current,
se::DeviceMemoryBase expected, std::string_view kernel_name,
void* kernel_symbol) const {
se::StreamExecutor* executor = stream->parent();
static absl::StatusOr<bool> DeviceCompare(
std::string_view kernel_name, void* kernel_symbol,
const ComparisonParams& params) {
se::StreamExecutor* executor = params.stream->parent();

se::DeviceMemoryHandle out_param(executor,
se::DeviceMemoryHandle out(executor,
executor->AllocateScalar<uint64_t>());

TF_RETURN_IF_ERROR(stream->MemZero(out_param.memory_ptr(), sizeof(uint64_t)));
if (current.size() != expected.size()) {
TF_RETURN_IF_ERROR(params.stream->MemZero(out.memory_ptr(), sizeof(uint64_t)));
if (params.current.size() != params.expected.size()) {
return Internal("Mismatched buffer size: %d bytes vs. %d bytes",
current.size(), expected.size());
params.current.size(), params.expected.size());
}

se::DeviceMemory<ElementT> current_typed(current);
se::DeviceMemory<ElementT> expected_typed(expected);
se::DeviceMemory<ElementT> current_typed(params.current);
se::DeviceMemory<ElementT> expected_typed(params.expected);
uint64_t buffer_size = current_typed.ElementCount();

TF_ASSIGN_OR_RETURN(
@@ -80,19 +88,20 @@ absl::StatusOr<bool> BufferComparator::DeviceCompare(
const se::DeviceDescription& gpu_device_info =
executor->GetDeviceDescription();

LaunchDimensions dim = CalculateLaunchDimensions(shape_, gpu_device_info);
LaunchDimensions dim =
CalculateLaunchDimensions(*params.shape, gpu_device_info);

se::DeviceMemory<uint64_t> as_uint64(out_param.memory());
TF_RETURN_IF_ERROR(stream->ThenLaunch(
se::DeviceMemory<uint64_t> as_uint64(out.memory());
TF_RETURN_IF_ERROR(params.stream->ThenLaunch(
dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel,
current_typed, expected_typed, static_cast<float>(tolerance_),
current_typed, expected_typed, static_cast<float>(params.relative_tol),
buffer_size, as_uint64));

uint64_t result = -1;
CHECK_EQ(out_param.memory().size(), sizeof(result));
CHECK_EQ(out.memory().size(), sizeof(result));
TF_RETURN_IF_ERROR(
stream->Memcpy(&result, out_param.memory(), sizeof(result)));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
params.stream->Memcpy(&result, out.memory(), sizeof(result)));
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());
return result == 0;
}
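After the refactor all per-call state travels in a single ComparisonParams aggregate, so invoking the now-static helper reduces to one brace initialization. A sketch for the fp32 case (field order matches the struct above; fp32_comparison() is one of the kernel symbols registered in the buffer_comparator namespace):

    static absl::StatusOr<bool> CompareFp32(se::Stream* stream,
                                            const Shape& shape,
                                            se::DeviceMemoryBase current,
                                            se::DeviceMemoryBase expected) {
      // Field order matches ComparisonParams above.
      ComparisonParams params{/*relative_tol=*/0.1, /*verbose=*/true,
                              &shape, stream, current, expected};
      return DeviceCompare<float>("fp32_comparison",
                                  buffer_comparator::fp32_comparison(), params);
    }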

@@ -101,16 +110,16 @@ absl::StatusOr<bool> BufferComparator::DeviceCompare(
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
absl::StatusOr<bool> BufferComparator::HostCompare(
se::Stream* stream, se::DeviceMemoryBase current,
se::DeviceMemoryBase expected) const {
int64_t n = current.size() / sizeof(ElementType);
static absl::StatusOr<bool> HostCompare(const ComparisonParams& params) {
int64_t n = params.current.size() / sizeof(ElementType);
std::vector<ElementType> host_current(n), host_expected(n);
TF_RETURN_IF_ERROR(
stream->Memcpy(host_current.data(), current, current.size()));
params.stream->Memcpy(host_current.data(), params.current,
params.current.size()));
TF_RETURN_IF_ERROR(
stream->Memcpy(host_expected.data(), expected, expected.size()));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
params.stream->Memcpy(host_expected.data(), params.expected,
params.expected.size()));
TF_RETURN_IF_ERROR(params.stream->BlockHostUntilDone());

const auto canonicalize = [](ComparisonType a) -> ComparisonType {
if (std::is_same<ElementType, Eigen::half>::value && a) {
@@ -123,6 +132,7 @@ absl::StatusOr<bool> BufferComparator::HostCompare(
return a;
};
int differences_seen = 0;

for (int64_t i = 0; i < n && differences_seen < 10; ++i) {
auto current_value = static_cast<ComparisonType>(host_current[i]);
auto expected_value = static_cast<ComparisonType>(host_expected[i]);
@@ -142,32 +152,30 @@ absl::StatusOr<bool> BufferComparator::HostCompare(
!(std::abs(current_value_canonical - expected_value_canonical) /
(std::max(std::abs(current_value_canonical),
std::abs(expected_value_canonical)) +
1) <
tolerance_)) {
1) < params.relative_tol)) {
if (!params.verbose) return false; // Return immediately if not verbose.
++differences_seen;
LOG(ERROR) << "Difference at " << i << ": " << current_value
<< ", expected " << expected_value;
<< ", expected " << expected_value;
}
}
return differences_seen == 0;
}
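The mismatch criterion is a damped relative error: the +1 in the denominator keeps near-zero magnitudes from inflating the ratio. Isolated as a sketch (the NaN/infinity canonicalization performed by HostCompare is omitted for brevity):

    #include <algorithm>  // std::max
    #include <cmath>      // std::abs

    // Sketch of the host-side predicate used above.
    bool WithinTolerance(double current, double expected, double rtol) {
      const double denom = std::max(std::abs(current), std::abs(expected)) + 1.0;
      return std::abs(current - expected) / denom < rtol;
    }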

template <typename ElementT, typename ComparisonT>
absl::StatusOr<bool> BufferComparator::CompareEqualParameterized(
se::Stream* stream, se::DeviceMemoryBase current,
se::DeviceMemoryBase expected, std::string_view kernel_name,
void* kernel_symbol) const {
static absl::StatusOr<bool> CompareEqualParameterized(
std::string_view kernel_name, void* kernel_symbol,
const ComparisonParams& params) {
XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
TF_ASSIGN_OR_RETURN(bool result,
DeviceCompare<ElementT>(stream, current, expected,
kernel_name, kernel_symbol));
TF_ASSIGN_OR_RETURN(
bool result, DeviceCompare<ElementT>(kernel_name, kernel_symbol, params));

if (result) {
return true;
}

TF_ASSIGN_OR_RETURN(bool host_return, (HostCompare<ElementT, ComparisonT>(
stream, current, expected)));
TF_ASSIGN_OR_RETURN(bool host_return,
(HostCompare<ElementT, ComparisonT>(params)));
CHECK_EQ(host_return, result)
<< "Host comparison succeeded even though GPU comparison failed.";
return false;
@@ -176,60 +184,57 @@ absl::StatusOr<bool> BufferComparator::CompareEqualParameterized(
absl::StatusOr<bool> BufferComparator::CompareEqual(
se::Stream* stream, se::DeviceMemoryBase current,
se::DeviceMemoryBase expected) const {

ComparisonParams params{
relative_tol_, verbose_, &shape_, stream, current, expected};

switch (shape_.element_type()) {
#if GOOGLE_CUDA // not available for ROCm yet..
case xla::F8E4M3FN:
return CompareEqualParameterized<tsl::float8_e4m3fn, float>(
stream, current, expected, "fp8_e4m3fn_comparison",
buffer_comparator::fp8_e4m3fn_comparison());
"fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison(),
params);
case xla::F8E5M2:
return CompareEqualParameterized<tsl::float8_e5m2, float>(
stream, current, expected, "fp8_e5m2_comparison",
buffer_comparator::fp8_e5m2_comparison());
"fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison(),
params);
#endif // GOOGLE_CUDA
#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F8E4M3FNUZ:
return CompareEqualParameterized<tsl::float8_e4m3fnuz, float>(
stream, current, expected, "fp8_e4m3fnuz_comparison",
buffer_comparator::fp8_e4m3fnuz_comparison());
"fp8_e4m3fnuz_comparison",
buffer_comparator::fp8_e4m3fnuz_comparison(), params);
case xla::F8E5M2FNUZ:
return CompareEqualParameterized<tsl::float8_e5m2fnuz, float>(
stream, current, expected, "fp8_e5m2fnuz_comparison",
buffer_comparator::fp8_e5m2fnuz_comparison());
"fp8_e5m2fnuz_comparison",
buffer_comparator::fp8_e5m2fnuz_comparison(), params);
#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
case xla::F16:
return CompareEqualParameterized<Eigen::half, float>(
stream, current, expected, "fp16_comparison",
buffer_comparator::fp16_comparison());
"fp16_comparison", buffer_comparator::fp16_comparison(), params);
case xla::BF16:
return CompareEqualParameterized<Eigen::bfloat16, float>(
stream, current, expected, "bf16_comparison",
buffer_comparator::bf16_comparison());
"bf16_comparison", buffer_comparator::bf16_comparison(), params);
case xla::F32:
return CompareEqualParameterized<float, float>(
stream, current, expected, "fp32_comparison",
buffer_comparator::fp32_comparison());
"fp32_comparison", buffer_comparator::fp32_comparison(), params);
case xla::F64:
return CompareEqualParameterized<double, double>(
stream, current, expected, "fp64_comparison",
buffer_comparator::fp64_comparison());
"fp64_comparison", buffer_comparator::fp64_comparison(), params);
case xla::S8:
return CompareEqualParameterized<int8_t, float>(
stream, current, expected, "int8_comparison",
buffer_comparator::int8_comparison());
"int8_comparison", buffer_comparator::int8_comparison(), params);
case xla::S32:
return CompareEqualParameterized<int32_t, float>(
stream, current, expected, "int32_comparison",
buffer_comparator::int32_comparison());
"int32_comparison", buffer_comparator::int32_comparison(), params);
default:
return Unimplemented("Unimplemented element type");
}
}

BufferComparator::BufferComparator(const Shape& shape,
const HloModuleConfig& config,
double tolerance)
: shape_(shape), config_(config), tolerance_(tolerance) {
BufferComparator::BufferComparator(const Shape& shape, double tolerance,
bool verbose) :
shape_(shape), relative_tol_(tolerance), verbose_(verbose) {
// Normalize complex shapes: since we treat the passed array as contiguous
// storage, it does not matter which dimension we double.
auto double_dim_size = [&]() {
28 changes: 4 additions & 24 deletions xla/service/gpu/buffer_comparator.h
@@ -34,8 +34,8 @@ class BufferComparator {
BufferComparator(const BufferComparator&) = delete;
BufferComparator(BufferComparator&&) = default;

BufferComparator(const Shape& shape, const HloModuleConfig& config,
double tolerance = 0.1);
explicit BufferComparator(const Shape& shape, double tolerance = 0.1,
bool verbose = true);

// Returns true if the two buffers compare equal. The definition of "equal"
// is:
@@ -49,30 +49,10 @@ class BufferComparator {
absl::StatusOr<bool> CompareEqual(se::Stream* stream,
se::DeviceMemoryBase current,
se::DeviceMemoryBase expected) const;

private:
template <typename ElementT, typename ComparisonT>
absl::StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
se::DeviceMemoryBase current,
se::DeviceMemoryBase expected,
std::string_view kernel_name,
void* kernel_symbol) const;

template <typename ElementType, typename ComparisonType>
absl::StatusOr<bool> HostCompare(se::Stream* stream,
se::DeviceMemoryBase current,
se::DeviceMemoryBase expected) const;

template <typename ElementT>
absl::StatusOr<bool> DeviceCompare(se::Stream* stream,
se::DeviceMemoryBase current,
se::DeviceMemoryBase expected,
std::string_view kernel_name,
void* kernel_symbol) const;

Shape shape_;
HloModuleConfig config_;
double tolerance_;
double relative_tol_; // relative tolerance for comparison
bool verbose_; // whether to print out error message on mismatch
};

namespace buffer_comparator {
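With HloModuleConfig gone from the interface, constructing and using a comparator needs only a shape, a tolerance, and a verbosity switch. A usage sketch; `stream`, `current_buffer` and `expected_buffer` are hypothetical locals:

    // Usage sketch for the slimmed-down interface.
    BufferComparator comparator(ShapeUtil::MakeShape(xla::F32, {1024}),
                                /*tolerance=*/0.05, /*verbose=*/false);
    absl::StatusOr<bool> equal =
        comparator.CompareEqual(stream, current_buffer, expected_buffer);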
15 changes: 12 additions & 3 deletions xla/service/gpu/buffer_comparator_test.cc
@@ -75,7 +75,7 @@ class BufferComparatorTest : public testing::Test {
ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<ElementType>(),
{static_cast<int64_t>(current.size())}),
HloModuleConfig(), tolerance);
tolerance);
return comparator
.CompareEqual(stream.get(), current_buffer.memory(),
expected_buffer.memory())
@@ -261,6 +261,16 @@ TEST_F(BufferComparatorTest, TestNumbers) {
EXPECT_TRUE(CompareEqualFloatBuffers<tsl::float8_e5m2>({11}, {12}));
EXPECT_TRUE(CompareEqualFloatBuffers<tsl::float8_e5m2>({12}, {11}));
#endif // GOOGLE_CUDA

// Rerun selected cases with a tighter relative tolerance.
const double tol = 0.001;
EXPECT_FALSE(CompareEqualFloatBuffers<Eigen::half>({0.9}, {1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<Eigen::half>({0.9}, {0.901}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<float>({10}, {10.1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<float>({10}, {10.01}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<int8_t>({100}, {101}, tol));
EXPECT_FALSE(CompareEqualFloatBuffers<double>({20}, {20.1}, tol));
EXPECT_TRUE(CompareEqualFloatBuffers<double>({20}, {20.01}, tol));
}
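The new expectations follow directly from the comparator's damped relative-error formula; for the double cases, a quick arithmetic check:

    #include <algorithm>
    #include <cassert>
    #include <cmath>

    // Worked check: |a - b| / (max(|a|, |b|) + 1) decides equality
    // against rtol = 1e-3.
    int main() {
      auto rel_err = [](double a, double b) {
        return std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1.0);
      };
      assert(rel_err(20.0, 20.1) > 1e-3);   // ~4.7e-3 -> reported different
      assert(rel_err(20.0, 20.01) < 1e-3);  // ~4.8e-4 -> reported equal
    }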

TEST_F(BufferComparatorTest, TestMultiple) {
@@ -384,8 +394,7 @@ TEST_F(BufferComparatorTest, BF16) {
stream_exec_->AllocateArray<Eigen::bfloat16>(element_count));
InitializeBuffer(stream.get(), BF16, &rng_state, rhs.memory());

BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}),
HloModuleConfig());
BufferComparator comparator(ShapeUtil::MakeShape(BF16, {element_count}));
EXPECT_FALSE(comparator.CompareEqual(stream.get(), lhs.memory(), rhs.memory())
.value());
}
7 changes: 4 additions & 3 deletions xla/service/gpu/conv_algorithm_picker.cc
@@ -675,8 +675,11 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(

if (reference_result->has_value()) {
XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);

const DebugOptions& debug_options =
runtime_arguments.hlo_module_config.debug_options();
BufferComparator comparator(runtime_arguments.rz_buffers.output_shape(),
runtime_arguments.hlo_module_config);
debug_options.xla_gpu_autotune_gemm_rtol());
for (int i = 0; i < result_buffers.size(); ++i) {
absl::StatusOr<bool> compare_result = comparator.CompareEqual(
stream, (*reference_result)->buffers[i], result_buffers[i]);
Expand All @@ -690,8 +693,6 @@ absl::StatusOr<AutotuneResult> GpuConvAlgorithmPicker::AutotuneOneConvRunner(
// Possibly OOM. Propagate the error.
return compare_result.status();
}
const DebugOptions& debug_options =
runtime_arguments.hlo_module_config.debug_options();
CHECK(!debug_options.xla_gpu_crash_on_verification_failures());
} else if (!compare_result.value()) {
LOG(ERROR)
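Since the comparator's tolerance now comes straight from xla_gpu_autotune_gemm_rtol, a mismatch here is judged against the flag value. For context, a hedged sketch of how such a failure is typically recorded against a candidate; AutotuneResult and its FailureKind come from xla/autotuning.proto, and the exact bookkeeping inside AutotuneOneConvRunner may differ:

    // Hedged sketch only, not the actual picker bookkeeping.
    void MarkWrongResult(AutotuneResult& result) {
      auto* failure = result.mutable_failure();
      failure->set_kind(AutotuneResult::WRONG_RESULT);
      failure->set_msg("Results mismatch vs. the reference convolution result");
    }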
