diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index d2b4dfcbb3f615..a35840689a8592 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -3731,7 +3731,7 @@ void AnnotateShardingWithSimpleHeuristic( const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); // const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions(); // const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions(); - std::vector lhs_space_dims, rhs_space_dims; + tsl::protobuf::RepeatedField lhs_space_dims, rhs_space_dims; std::tie(lhs_space_dims, rhs_space_dims) = GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 7090a33dc24fda..522848a3416798 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -39,6 +39,8 @@ limitations under the License. namespace xla { namespace spmd { +using DimMap = StableHashMap; + void AppendNewStrategy(const HloInstruction* ins, const std::string& name, const HloSharding& output_spec, absl::Span input_specs, @@ -108,22 +110,27 @@ class DotHandler { CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size()); } + bool CheckDims(const HloInstruction* ins, + const tsl::protobuf::RepeatedField& instr_dims, + const DimMap& dim_map) const { + for (const auto& [instr_dim_idx, mesh_dim_idx] : dim_map) { + auto instr_dim = instr_dims.at(instr_dim_idx); + auto shape_dim = ins->shape().dimensions().at(instr_dim); + auto mesh_dim = device_mesh_.dim(mesh_dim_idx); + if (shape_dim < mesh_dim) return false; + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(shape_dim, mesh_dim)) + return false; + } + return true; + } + void SplitLhsSpaceRhsSpace(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < rhs_space_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { + if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) || + !CheckDims(rhs_, rhs_space_dims_, {{j, mesh_dim1}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { - continue; - } std::string name = absl::StrFormat("SS = SR x RS @ {%d,%d}", mesh_dim0, mesh_dim1); HloSharding output_spec = @@ -146,19 +153,8 @@ class DotHandler { void SplitLhsSpaceOnly(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = i + 1; j < lhs_space_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - lhs_->shape().dimensions().at(lhs_space_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { + if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}, {j, mesh_dim1}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { - continue; - } std::string name = absl::StrFormat("SSR = SSR x RR @ {%d,%d}", mesh_dim0, mesh_dim1); HloSharding output_spec = @@ -178,19 +174,8 @@ class DotHandler { void SplitRhsSpaceOnly(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = i + 1; j < rhs_space_dims_.size(); ++j) { - if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { + if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim0}, {j, mesh_dim1}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { - continue; - } std::string name = absl::StrFormat("RSS = RR x RSS @ {%d,%d}", mesh_dim0, mesh_dim1); HloSharding output_spec = Tile( @@ -217,20 +202,9 @@ class DotHandler { mesh_dim1, mesh_dim1); for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { + if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) || + !CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dim1}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible( - lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { - continue; - } HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + i}, {mesh_dim0}, device_mesh_); @@ -258,20 +232,9 @@ class DotHandler { mesh_dim1, mesh_dim0); for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) { - if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim1) || - lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < - device_mesh_.dim(mesh_dim0)) { + if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim1}}) || + !CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dim0}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible( - rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim0)))) { - continue; - } HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + @@ -299,15 +262,7 @@ class DotHandler { [](int64_t size) { return size > 1; }) == 1) { for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { for (int64_t j = 0; j < device_mesh_.num_dimensions(); ++j) { - if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) < - device_mesh_.dim(j)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), - device_mesh_.dim(j))) { - continue; - } + if (!CheckDims(lhs_, lhs_batch_dims_, {{i, j}})) continue; std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", i, j); HloSharding output_spec = Tile(ins_->shape(), {i}, {j}, device_mesh_); HloSharding lhs_spec = @@ -325,19 +280,8 @@ class DotHandler { void SplitTwoBatchDims(int mesh_dim0, int mesh_dim1) { if (lhs_batch_dims_.size() == 2 && device_mesh_.dim(mesh_dim0) > 1 && device_mesh_.dim(mesh_dim1) > 1) { - if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)) < - device_mesh_.dim(mesh_dim0) || - lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)) < - device_mesh_.dim(mesh_dim1)) { - return; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)), - device_mesh_.dim(mesh_dim1)))) { + if (!CheckDims(lhs_, lhs_batch_dims_, {{0, mesh_dim0}, {1, mesh_dim1}})) return; - } std::string name = absl::StrFormat("Sb = Sb x Sb @ {%d,%d}", mesh_dim0, mesh_dim1); HloSharding output_spec = @@ -360,21 +304,9 @@ class DotHandler { absl::StrFormat("SbSi = SbSi x SbR @ {%d,%d}", mesh_dim0, mesh_dim1); for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible( - lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible( - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { + if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) || + !CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim1}})) continue; - } HloSharding output_spec = Tile(ins_->shape(), {j, space_base_dim_ + i}, {mesh_dim0, mesh_dim1}, device_mesh_); @@ -398,21 +330,9 @@ class DotHandler { absl::StrFormat("SbSj = SbR x SbSj @ {%d,%d}", mesh_dim0, mesh_dim1); for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < - device_mesh_.dim(mesh_dim1) || - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < - device_mesh_.dim(mesh_dim0)) { + if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim1}}) || + !CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim0}})) continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible( - rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible( - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - device_mesh_.dim(mesh_dim0)))) { - continue; - } HloSharding output_spec = Tile(ins_->shape(), {j, space_base_dim_ + @@ -439,20 +359,9 @@ class DotHandler { mesh_dim0, mesh_dim1, mesh_dim1); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < - device_mesh_.dim(mesh_dim1) || - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < - device_mesh_.dim(mesh_dim0)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible( - lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - device_mesh_.dim(mesh_dim0)))) { + if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dim1}}) || + !CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim0}})) continue; - } HloSharding output_spec = Tile(ins_->shape(), {j}, {mesh_dim0}, device_mesh_); HloSharding lhs_spec = @@ -482,27 +391,10 @@ class DotHandler { mesh_dim0, mesh_dim1, mesh_dim0, mesh_dim1); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { for (int64_t j = i + 1; j < lhs_con_dims_.size(); ++j) { - if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < - device_mesh_.dim(mesh_dim1) || - rhs_->shape().dimensions().at(rhs_con_dims_.at(i)) < - device_mesh_.dim(mesh_dim0) || - rhs_->shape().dimensions().at(rhs_con_dims_.at(j)) < - device_mesh_.dim(mesh_dim1)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1)))) { + if (!CheckDims(lhs_, lhs_con_dims_, + {{i, mesh_dim0}, {j, mesh_dim1}}) || + !CheckDims(rhs_, rhs_con_dims_, {{i, mesh_dim0}, {j, mesh_dim1}})) continue; - } HloSharding output_spec = HloSharding::Replicate(); HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_con_dims_[i], lhs_con_dims_[j]}, @@ -526,15 +418,7 @@ class DotHandler { std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", mesh_dim0, mesh_dim0); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { - if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < - device_mesh_.dim(mesh_dim0)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim0))) { - continue; - } + if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dim0}})) continue; HloSharding output_spec = HloSharding::Replicate(); HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_con_dims_[i]}, {mesh_dim0}, device_mesh_); @@ -614,25 +498,19 @@ class DotHandler { [](int64_t size) { return size > 1; }) > 1) { int mesh_dim = 0; for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { - if (rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) < - device_mesh_.dim(mesh_dim)) { - continue; - } - if (solver_option_.only_allow_divisible_intermediate && - !IsDivisible(rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), - device_mesh_.dim(mesh_dim))) { + if (!CheckDims(lhs_, lhs_batch_dims_, {{i, mesh_dim}}) || + !CheckDims(rhs_, rhs_batch_dims_, {{i, mesh_dim}})) continue; - } - std::string name = - absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); - HloSharding output_spec = - Tile(ins_->shape(), {i}, {mesh_dim}, device_mesh_1d_); - HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_batch_dims_[i]}, - {mesh_dim}, device_mesh_1d_); - HloSharding rhs_spec = Tile(rhs_->shape(), {rhs_batch_dims_[i]}, - {mesh_dim}, device_mesh_1d_); - AppendNewStrategy(ins_, name, output_spec, {lhs_spec, rhs_spec}, 0, 0, - cluster_env_, strategy_map_, strategies_); + std::string name = + absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); + HloSharding output_spec = + Tile(ins_->shape(), {i}, {mesh_dim}, device_mesh_1d_); + HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_batch_dims_[i]}, + {mesh_dim}, device_mesh_1d_); + HloSharding rhs_spec = Tile(rhs_->shape(), {rhs_batch_dims_[i]}, + {mesh_dim}, device_mesh_1d_); + AppendNewStrategy(ins_, name, output_spec, {lhs_spec, rhs_spec}, 0, 0, + cluster_env_, strategy_map_, strategies_); } } } @@ -748,7 +626,7 @@ class DotHandler { // Dimension information const DotDimensionNumbers& dot_dnums_; int64_t space_base_dim_; - std::vector lhs_space_dims_, rhs_space_dims_; + tsl::protobuf::RepeatedField lhs_space_dims_, rhs_space_dims_; const tsl::protobuf::RepeatedField& lhs_con_dims_; const tsl::protobuf::RepeatedField& rhs_con_dims_; const tsl::protobuf::RepeatedField& lhs_batch_dims_; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 05cc51657dcc96..25692d4564e8cb 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -498,7 +498,7 @@ void BatchDimMapForward(const std::vector& instructions, ins->dot_dimension_numbers().lhs_batch_dimensions(); const auto& rhs_batch_dims = ins->dot_dimension_numbers().rhs_batch_dimensions(); - std::vector lhs_space_dims, rhs_space_dims; + tsl::protobuf::RepeatedField lhs_space_dims, rhs_space_dims; std::tie(lhs_space_dims, rhs_space_dims) = GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums); // This part assumes that the dot has been through the dot decomposer, @@ -759,7 +759,7 @@ void BatchDimMapBackward(const std::vector& instructions, ins->dot_dimension_numbers().lhs_batch_dimensions(); const auto& rhs_batch_dims = ins->dot_dimension_numbers().rhs_batch_dimensions(); - std::vector lhs_space_dims, rhs_space_dims; + tsl::protobuf::RepeatedField lhs_space_dims, rhs_space_dims; std::tie(lhs_space_dims, rhs_space_dims) = GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 38926d9fa35afa..c6aac8d53abcea 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -185,18 +185,18 @@ inline bool DimensionsEqual(const Shape& a, const Shape& b) { * HloInstruction Utility */ // Get the space dimensions of a dot instruction. -inline std::pair, std::vector> GetSpaceDims( - const Shape& lhs_shape, const Shape& rhs_shape, - const DotDimensionNumbers& dnums) { - std::vector lhs_space_dims; - std::vector rhs_space_dims; +inline std::pair, + tsl::protobuf::RepeatedField> +GetSpaceDims(const Shape& lhs_shape, const Shape& rhs_shape, + const DotDimensionNumbers& dnums) { + tsl::protobuf::RepeatedField lhs_space_dims, rhs_space_dims; for (int64_t i = 0; i < lhs_shape.rank(); ++i) { if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { continue; } - lhs_space_dims.push_back(i); + lhs_space_dims.Add(i); } for (int64_t i = 0; i < rhs_shape.rank(); ++i) { @@ -204,7 +204,7 @@ inline std::pair, std::vector> GetSpaceDims( absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { continue; } - rhs_space_dims.push_back(i); + rhs_space_dims.Add(i); } return std::make_pair(std::move(lhs_space_dims), std::move(rhs_space_dims)); }