Skip to content

Commit

Permalink
[xla][gpu] Implement pipelined-p2p-rewriter.
Browse files Browse the repository at this point in the history
This pass rewrites pipelined point-to-point communication by rotating the
SendDone and RecvDone operations in a while-body to the beginning of the next iteration.
The SendDone and RecvDone operations for the last iteration are moved to the
while-op calling computation, after the while-op.

Add the pass to the GPU post-scheduler pipeline.

This is another approach to achieving the code pattern that pipelines two
Send-Recv chains decomposed from a collective-permute with a source-target pair
cycle, for performance. The pipelined Send-Recv pattern puts SendDone and
RecvDone before Send and Recv in the while-body. If we generate such a code
pattern too early in the GPU compilation pipeline, copy-insertion may generate
copies of Send, causing Send and SendDone to use different buffers and thus a
correctness problem.

PiperOrigin-RevId: 630121252
  • Loading branch information
bixia1 authored and copybara-github committed May 2, 2024
1 parent adf109d commit a4e712c
Show file tree
Hide file tree
Showing 6 changed files with 1,519 additions and 0 deletions.
41 changes: 41 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3399,6 +3399,7 @@ cc_library(
]),
deps = if_gpu_is_configured([
":gpu_p2p_pipeliner",
":pipelined_p2p_rewriter",
":collective_permute_cycle_decomposer",
":address_computation_fusion_rewriter",
":algorithm_checker",
Expand Down Expand Up @@ -3648,6 +3649,7 @@ xla_test(
"//xla/service:pattern_matcher_gmock",
"//xla/service:xla_debug_info_manager",
"//xla/service/gpu:autotuner_util",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"@com_google_absl//absl/log",
Expand Down Expand Up @@ -6142,3 +6144,42 @@ xla_test(
"@com_google_googletest//:gtest_main",
],
)

# Library target for the pipelined point-to-point rewriter, built from
# pipelined_p2p_rewriter.cc and pipelined_p2p_rewriter.h.
cc_library(
name = "pipelined_p2p_rewriter",
srcs = ["pipelined_p2p_rewriter.cc"],
hdrs = ["pipelined_p2p_rewriter.h"],
deps = [
"//xla:shape_util",
"//xla:status",
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:collective_ops_utils",
"//xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

# Test target for :pipelined_p2p_rewriter, built from
# pipelined_p2p_rewriter_test.cc. It depends on the FileCheck helper and the
# HLO test base in //xla/tests.
xla_cc_test(
name = "pipelined_p2p_rewriter_test",
srcs = ["pipelined_p2p_rewriter_test.cc"],
deps = [
":pipelined_p2p_rewriter",
"//xla/hlo/ir:hlo",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test_main",
],
)
7 changes: 7 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ limitations under the License.
#include "xla/service/gpu/model/gpu_cost_model_stats_collection.h"
#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
#include "xla/service/gpu/move_copy_to_users.h"
#include "xla/service/gpu/pipelined_p2p_rewriter.h"
#include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h"
#include "xla/service/gpu/reduction_degenerate_dim_remover.h"
#include "xla/service/gpu/reduction_dimension_grouper.h"
Expand Down Expand Up @@ -2199,6 +2200,12 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
{
HloPassPipeline pipeline("post-scheduling-passes");

// Add PipelinedP2PRewriter when either pipelining debug option is enabled.
// The pass rotates the SendDone/RecvDone of a pipelined Send/Recv in a
// while-body to the beginning of the next iteration. It runs in this
// post-scheduling pipeline because producing that pattern earlier may let
// copy-insertion copy the Send, leaving Send and SendDone on different
// buffers.
if (module->config()
.debug_options()
.xla_gpu_enable_pipelined_collectives() ||
module->config().debug_options().xla_gpu_enable_pipelined_p2p()) {
pipeline.AddPass<PipelinedP2PRewriter>();
}
HloPredicate is_nop =
HloPredicateIsOp<HloOpcode::kParameter, HloOpcode::kConstant,
HloOpcode::kBitcast, HloOpcode::kGetTupleElement>;
Expand Down
127 changes: 127 additions & 0 deletions xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ limitations under the License.
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/service/xla_debug_info_manager.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -395,6 +396,132 @@ ENTRY main {
triton_disabled_module->computation_count());
}

// Optimizes and schedules a module whose while-body contains a
// collective-permute, with pipelined P2P enabled, and FileChecks the printed
// result. The expected output has the loop body begin with recv-done/send-done
// on tuple elements of the loop parameter and end with a new recv/send placed
// in the root tuple; the entry computation issues the first recv/send before
// the while and the final recv-done/send-done after it.
TEST_F(GpuCompilerTest, CollectivePermuteDecompositionAndPipelining) {
// Input HLO: a while loop (bound 11 in the condition) whose body feeds the
// collective-permute result into the computation of the next send-data.
const char* kModuleStr = R"(
HloModule cp
cond {
param = (u32[], f32[1, 1024, 1024]) parameter(0)
count = get-tuple-element(%param), index=0
ub = u32[] constant(11)
ROOT result = pred[] compare(count, ub), direction=LT
}
body {
param = (u32[], f32[1, 1024, 1024]) parameter(0)
count = get-tuple-element(%param), index=0
send-data = get-tuple-element(%param), index=1
recv-data = f32[1, 1024, 1024] collective-permute(send-data),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}, channel_id=1
// The computation code that uses the current recv-data and
// produces the send-data for the next iteration.
c1 = u32[] constant(1)
new_count = u32[] add(count, c1)
replica = u32[] replica-id()
c10 = u32[] constant(10)
sum = u32[] add(replica, c10)
sum2 = u32[] add(sum, count)
conv = f32[] convert(sum2)
p = f32[1, 1024, 1024] broadcast(conv), dimensions={}
b = f32[1, 1024, 1024] add(p, recv-data)
c = f32[1, 1024, 1024] multiply(b, b)
d = f32[1, 1024, 1024] tan(c)
s = f32[1, 1024, 1024] dot(c, d), lhs_batch_dims={0},
lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1}
ROOT result = (u32[], f32[1, 1024, 1024]) tuple(new_count, s)
}
ENTRY test_computation {
c0 = u32[] constant(0)
f0 = f32[] constant(0.0)
init = f32[1, 1024, 1024] broadcast(f0), dimensions={}
while_init = (u32[], f32[1, 1024, 1024]) tuple(c0, init)
while_result = (u32[], f32[1, 1024, 1024]) while(while_init), body=body, condition=cond
ROOT result = f32[1, 1024, 1024] get-tuple-element(while_result), index=1
}
)";

// In the expected string, we skip some detail on the while-init tuple due to
// b/333572009.
// NOTE(review): the CHECK lines pin exact instruction names (e.g.
// %get-tuple-element.38), so they presumably need updating whenever upstream
// passes change instruction numbering -- confirm before relying on them.
const char* kExpected = R"(
CHECK: %body.1 (param.2.0: (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[])) -> (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[]) {
CHECK: %param.2.0 = parameter(0)
CHECK: %get-tuple-element.38 = get-tuple-element(%param.2.0), index=2
CHECK: %get-tuple-element.39 = get-tuple-element(%param.2.0), index=3
CHECK-DAG: %get-tuple-element.22 = get-tuple-element(%param.2.0), index=0
CHECK-DAG: %recv-done.3 = recv-done(%get-tuple-element.38), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
CHECK-DAG: %get-tuple-element.25 = get-tuple-element(%recv-done.3), index=0
CHECK: %loop_multiply_tan_fusion = fusion
CHECK: %send-done.3 = send-done(%get-tuple-element.39), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
CHECK: %custom-call.1.0 = custom-call
CHECK: %after-all.3.0 = after-all()
CHECK{LITERAL}: %recv.2.0 = recv(%after-all.3.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%custom-call.1.0}
CHECK{LITERAL}: %send.2.0 = send(%bitcast.119.0, %after-all.3.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%recv.2.0}
CHECK: %loop_add_fusion = fusion
CHECK: %loop_add_fusion.1 = fusion
CHECK: ROOT %tuple.13 = tuple(%loop_add_fusion.1, %bitcast.119.0, %recv.2.0, %send.2.0, %loop_add_fusion)
CHECK: }
CHECK: %cond.1 (cond_param.1: (u32[], f32[1,1024,1024], (f32[1,1024,1024], u32[], token[]), (f32[1,1024,1024], u32[], token[]), u32[])) -> pred[] {
CHECK: %cond_param.1 = parameter(0)
CHECK: %get-tuple-element.5.0 = get-tuple-element(%cond_param.1), index=0
CHECK: ROOT %loop_compare_fusion = fusion(%get-tuple-element.5.0), kind=kLoop, calls=%fused_compare
CHECK: }
CHECK: ENTRY %test_computation () -> f32[1,1024,1024] {
CHECK: %after-all.1.0 = after-all()
CHECK: %loop_broadcast_fusion = fusion
CHECK{LITERAL}: %recv.1.0 = recv(%after-all.1.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}
CHECK{LITERAL}: %send.1.0 = send(%loop_broadcast_fusion, %after-all.1.0), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, control-predecessors={%recv.1.0}
CHECK: %copy_fusion = fusion
CHECK: %get-tuple-element.36 = get-tuple-element(%copy_fusion), index=0
CHECK: %get-tuple-element.37 = get-tuple-element(%copy_fusion), index=1
CHECK: %bitcast.170 = bitcast(%get-tuple-element.36)
CHECK: %bitcast.171 = bitcast(%get-tuple-element.37)
CHECK: %while-init = tuple
CHECK-SAME: %recv.1.0, %send.1.0
CHECK{LITERAL}: %while-result = while(%while-init), condition=%cond.1, body=%body.1, backend_config={"known_trip_count":{"n":"10"}}
CHECK: %get-tuple-element.40 = get-tuple-element(%while-result), index=2
CHECK: %get-tuple-element.41 = get-tuple-element(%while-result), index=3
CHECK: %recv-done.4 = recv-done(%get-tuple-element.40), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
CHECK: %get-tuple-element.7.0 = get-tuple-element(%recv-done.4), index=0
CHECK: %loop_multiply_tan_fusion.1 = fusion
CHECK: %get-tuple-element.13 = get-tuple-element(%loop_multiply_tan_fusion.1), index=0
CHECK: %get-tuple-element.14 = get-tuple-element(%loop_multiply_tan_fusion.1), index=1
CHECK: %bitcast.150.0 = bitcast(%get-tuple-element.13)
CHECK: %bitcast.155.0 = bitcast(%get-tuple-element.14)
CHECK: %send-done.4 = send-done(%get-tuple-element.41), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"}
CHECK: %custom-call.3.0 = custom-call(%bitcast.150.0, %bitcast.155.0), custom_call_target="__cublas$gemm"
CHECK: %get-tuple-element.10.0 = get-tuple-element(%custom-call.3.0), index=0
CHECK: ROOT %bitcast.5.0 = bitcast(%get-tuple-element.10.0)
CHECK: }
)";

// Debug options for this test: latency hiding scheduler on, the
// collective-permute decomposer threshold set to 1, pipelined P2P on, and
// Triton GEMM off.
// NOTE(review): Triton GEMM is presumably disabled so the dot becomes the
// "__cublas$gemm" custom-call that kExpected checks for -- confirm.
HloModuleConfig config;
DebugOptions debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
debug_options.set_xla_gpu_collective_permute_decomposer_threshold(1);
debug_options.set_xla_gpu_enable_pipelined_p2p(true);
debug_options.set_xla_gpu_enable_triton_gemm(false);
config.set_debug_options(debug_options);
// Parse and verify the input, obtain the optimized module, then schedule it.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
GetOptimizedModule(std::move(module)));
TF_ASSERT_OK(Schedule(optimized_module.get()));

// Print without operand or result shapes; the CHECK lines for individual
// instructions in kExpected are written without shape annotations.
HloPrintOptions options;
options.set_print_operand_shape(false);
options.set_print_result_shape(false);
TF_ASSERT_OK_AND_ASSIGN(
bool filecheck_matched,
RunFileCheck(optimized_module->ToString(options), kExpected));
EXPECT_TRUE(filecheck_matched);
}

} // namespace
} // namespace gpu
} // namespace xla
Loading

0 comments on commit a4e712c

Please sign in to comment.