Skip to content

Commit

Permalink
Consolidates the dimension checking logic into a single method.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569865640
  • Loading branch information
tensorflower-gardener committed Oct 1, 2023
1 parent 825dcfb commit ad39c95
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3731,7 +3731,7 @@ void AnnotateShardingWithSimpleHeuristic(
const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers();
// const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions();
// const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions();
std::vector<int64_t> lhs_space_dims, rhs_space_dims;
tsl::protobuf::RepeatedField<int64_t> lhs_space_dims, rhs_space_dims;
std::tie(lhs_space_dims, rhs_space_dims) =
GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ limitations under the License.
namespace xla {
namespace spmd {

using DimMap = StableHashMap</*instr. dim idx*/ int, /* mesh dim idx*/ int>;

void AppendNewStrategy(const HloInstruction* ins, const std::string& name,
const HloSharding& output_spec,
absl::Span<const HloSharding> input_specs,
Expand Down Expand Up @@ -108,22 +110,27 @@ class DotHandler {
CHECK_EQ(lhs_batch_dims_.size(), rhs_batch_dims_.size());
}

// Decides whether every (instruction dim, mesh dim) pairing in `dim_map` is an
// acceptable split for `ins`.
//
// `dim_map` keys are positions into `instr_dims`; the value stored at that
// position is used as a dimension index into `ins->shape()`. `dim_map` values
// are dimension indices into `device_mesh_`.
//
// Returns false as soon as one pairing has a shape dimension smaller than the
// paired mesh dimension, or -- when
// `solver_option_.only_allow_divisible_intermediate` is set -- `IsDivisible`
// rejects the (shape dimension, mesh dimension) pair. Returns true otherwise,
// including for an empty `dim_map`.
bool CheckDims(const HloInstruction* ins,
               const tsl::protobuf::RepeatedField<int64_t>& instr_dims,
               const DimMap& dim_map) const {
  const bool require_divisible =
      solver_option_.only_allow_divisible_intermediate;
  for (const auto& [dims_pos, mesh_axis] : dim_map) {
    const auto shape_axis = instr_dims.at(dims_pos);
    const auto extent = ins->shape().dimensions().at(shape_axis);
    const auto mesh_extent = device_mesh_.dim(mesh_axis);
    // The size check comes first so IsDivisible only sees pairs that already
    // passed it, matching the short-circuit order callers rely on.
    const bool too_small = extent < mesh_extent;
    if (too_small ||
        (require_divisible && !IsDivisible(extent, mesh_extent))) {
      return false;
    }
  }
  return true;
}

void SplitLhsSpaceRhsSpace(int mesh_dim0, int mesh_dim1) {
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < rhs_space_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) ||
!CheckDims(rhs_, rhs_space_dims_, {{j, mesh_dim1}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
continue;
}
std::string name =
absl::StrFormat("SS = SR x RS @ {%d,%d}", mesh_dim0, mesh_dim1);
HloSharding output_spec =
Expand All @@ -146,19 +153,8 @@ class DotHandler {
void SplitLhsSpaceOnly(int mesh_dim0, int mesh_dim1) {
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = i + 1; j < lhs_space_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
lhs_->shape().dimensions().at(lhs_space_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}, {j, mesh_dim1}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
continue;
}
std::string name =
absl::StrFormat("SSR = SSR x RR @ {%d,%d}", mesh_dim0, mesh_dim1);
HloSharding output_spec =
Expand All @@ -178,19 +174,8 @@ class DotHandler {
void SplitRhsSpaceOnly(int mesh_dim0, int mesh_dim1) {
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = i + 1; j < rhs_space_dims_.size(); ++j) {
if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim0}, {j, mesh_dim1}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
continue;
}
std::string name =
absl::StrFormat("RSS = RR x RSS @ {%d,%d}", mesh_dim0, mesh_dim1);
HloSharding output_spec = Tile(
Expand All @@ -217,20 +202,9 @@ class DotHandler {
mesh_dim1, mesh_dim1);
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) ||
!CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dim1}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(
lhs_->shape().dimensions().at(lhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
continue;
}

HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + i},
{mesh_dim0}, device_mesh_);
Expand Down Expand Up @@ -258,20 +232,9 @@ class DotHandler {
mesh_dim1, mesh_dim0);
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) {
if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim1) ||
lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) <
device_mesh_.dim(mesh_dim0)) {
if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim1}}) ||
!CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dim0}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(
rhs_->shape().dimensions().at(rhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim1)) ||
!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)),
device_mesh_.dim(mesh_dim0)))) {
continue;
}
HloSharding output_spec =
Tile(ins_->shape(),
{space_base_dim_ +
Expand Down Expand Up @@ -299,15 +262,7 @@ class DotHandler {
[](int64_t size) { return size > 1; }) == 1) {
for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) {
for (int64_t j = 0; j < device_mesh_.num_dimensions(); ++j) {
if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) <
device_mesh_.dim(j)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)),
device_mesh_.dim(j))) {
continue;
}
if (!CheckDims(lhs_, lhs_batch_dims_, {{i, j}})) continue;
std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", i, j);
HloSharding output_spec = Tile(ins_->shape(), {i}, {j}, device_mesh_);
HloSharding lhs_spec =
Expand All @@ -325,19 +280,8 @@ class DotHandler {
void SplitTwoBatchDims(int mesh_dim0, int mesh_dim1) {
if (lhs_batch_dims_.size() == 2 && device_mesh_.dim(mesh_dim0) > 1 &&
device_mesh_.dim(mesh_dim1) > 1) {
if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)) <
device_mesh_.dim(mesh_dim0) ||
lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)) <
device_mesh_.dim(mesh_dim1)) {
return;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)),
device_mesh_.dim(mesh_dim1)))) {
if (!CheckDims(lhs_, lhs_batch_dims_, {{0, mesh_dim0}, {1, mesh_dim1}}))
return;
}
std::string name =
absl::StrFormat("Sb = Sb x Sb @ {%d,%d}", mesh_dim0, mesh_dim1);
HloSharding output_spec =
Expand All @@ -360,21 +304,9 @@ class DotHandler {
absl::StrFormat("SbSi = SbSi x SbR @ {%d,%d}", mesh_dim0, mesh_dim1);
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(
lhs_->shape().dimensions().at(lhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dim0}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim1}}))
continue;
}
HloSharding output_spec =
Tile(ins_->shape(), {j, space_base_dim_ + i},
{mesh_dim0, mesh_dim1}, device_mesh_);
Expand All @@ -398,21 +330,9 @@ class DotHandler {
absl::StrFormat("SbSj = SbR x SbSj @ {%d,%d}", mesh_dim0, mesh_dim1);
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) <
device_mesh_.dim(mesh_dim1) ||
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) <
device_mesh_.dim(mesh_dim0)) {
if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dim1}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim0}}))
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(
rhs_->shape().dimensions().at(rhs_space_dims_.at(i)),
device_mesh_.dim(mesh_dim1)) ||
!IsDivisible(
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)),
device_mesh_.dim(mesh_dim0)))) {
continue;
}
HloSharding output_spec =
Tile(ins_->shape(),
{j, space_base_dim_ +
Expand All @@ -439,20 +359,9 @@ class DotHandler {
mesh_dim0, mesh_dim1, mesh_dim1);
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) <
device_mesh_.dim(mesh_dim1) ||
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) <
device_mesh_.dim(mesh_dim0)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)),
device_mesh_.dim(mesh_dim1)) ||
!IsDivisible(
lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)),
device_mesh_.dim(mesh_dim0)))) {
if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dim1}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dim0}}))
continue;
}
HloSharding output_spec =
Tile(ins_->shape(), {j}, {mesh_dim0}, device_mesh_);
HloSharding lhs_spec =
Expand Down Expand Up @@ -482,27 +391,10 @@ class DotHandler {
mesh_dim0, mesh_dim1, mesh_dim0, mesh_dim1);
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
for (int64_t j = i + 1; j < lhs_con_dims_.size(); ++j) {
if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) <
device_mesh_.dim(mesh_dim1) ||
rhs_->shape().dimensions().at(rhs_con_dims_.at(i)) <
device_mesh_.dim(mesh_dim0) ||
rhs_->shape().dimensions().at(rhs_con_dims_.at(j)) <
device_mesh_.dim(mesh_dim1)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
(!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)),
device_mesh_.dim(mesh_dim1)) ||
!IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(i)),
device_mesh_.dim(mesh_dim0)) ||
!IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(j)),
device_mesh_.dim(mesh_dim1)))) {
if (!CheckDims(lhs_, lhs_con_dims_,
{{i, mesh_dim0}, {j, mesh_dim1}}) ||
!CheckDims(rhs_, rhs_con_dims_, {{i, mesh_dim0}, {j, mesh_dim1}}))
continue;
}
HloSharding output_spec = HloSharding::Replicate();
HloSharding lhs_spec =
Tile(lhs_->shape(), {lhs_con_dims_[i], lhs_con_dims_[j]},
Expand All @@ -526,15 +418,7 @@ class DotHandler {
std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)",
mesh_dim0, mesh_dim0);
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) <
device_mesh_.dim(mesh_dim0)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)),
device_mesh_.dim(mesh_dim0))) {
continue;
}
if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dim0}})) continue;
HloSharding output_spec = HloSharding::Replicate();
HloSharding lhs_spec =
Tile(lhs_->shape(), {lhs_con_dims_[i]}, {mesh_dim0}, device_mesh_);
Expand Down Expand Up @@ -614,25 +498,19 @@ class DotHandler {
[](int64_t size) { return size > 1; }) > 1) {
int mesh_dim = 0;
for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) {
if (rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) <
device_mesh_.dim(mesh_dim)) {
continue;
}
if (solver_option_.only_allow_divisible_intermediate &&
!IsDivisible(rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)),
device_mesh_.dim(mesh_dim))) {
if (!CheckDims(lhs_, lhs_batch_dims_, {{i, mesh_dim}}) ||
!CheckDims(rhs_, rhs_batch_dims_, {{i, mesh_dim}}))
continue;
}
std::string name =
absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim);
HloSharding output_spec =
Tile(ins_->shape(), {i}, {mesh_dim}, device_mesh_1d_);
HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_batch_dims_[i]},
{mesh_dim}, device_mesh_1d_);
HloSharding rhs_spec = Tile(rhs_->shape(), {rhs_batch_dims_[i]},
{mesh_dim}, device_mesh_1d_);
AppendNewStrategy(ins_, name, output_spec, {lhs_spec, rhs_spec}, 0, 0,
cluster_env_, strategy_map_, strategies_);
std::string name =
absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim);
HloSharding output_spec =
Tile(ins_->shape(), {i}, {mesh_dim}, device_mesh_1d_);
HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_batch_dims_[i]},
{mesh_dim}, device_mesh_1d_);
HloSharding rhs_spec = Tile(rhs_->shape(), {rhs_batch_dims_[i]},
{mesh_dim}, device_mesh_1d_);
AppendNewStrategy(ins_, name, output_spec, {lhs_spec, rhs_spec}, 0, 0,
cluster_env_, strategy_map_, strategies_);
}
}
}
Expand Down Expand Up @@ -748,7 +626,7 @@ class DotHandler {
// Dimension information
const DotDimensionNumbers& dot_dnums_;
int64_t space_base_dim_;
std::vector<int64_t> lhs_space_dims_, rhs_space_dims_;
tsl::protobuf::RepeatedField<int64_t> lhs_space_dims_, rhs_space_dims_;
const tsl::protobuf::RepeatedField<int64_t>& lhs_con_dims_;
const tsl::protobuf::RepeatedField<int64_t>& rhs_con_dims_;
const tsl::protobuf::RepeatedField<int64_t>& lhs_batch_dims_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ void BatchDimMapForward(const std::vector<HloInstruction*>& instructions,
ins->dot_dimension_numbers().lhs_batch_dimensions();
const auto& rhs_batch_dims =
ins->dot_dimension_numbers().rhs_batch_dimensions();
std::vector<int64_t> lhs_space_dims, rhs_space_dims;
tsl::protobuf::RepeatedField<int64_t> lhs_space_dims, rhs_space_dims;
std::tie(lhs_space_dims, rhs_space_dims) =
GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);
// This part assumes that the dot has been through the dot decomposer,
Expand Down Expand Up @@ -759,7 +759,7 @@ void BatchDimMapBackward(const std::vector<HloInstruction*>& instructions,
ins->dot_dimension_numbers().lhs_batch_dimensions();
const auto& rhs_batch_dims =
ins->dot_dimension_numbers().rhs_batch_dimensions();
std::vector<int64_t> lhs_space_dims, rhs_space_dims;
tsl::protobuf::RepeatedField<int64_t> lhs_space_dims, rhs_space_dims;
std::tie(lhs_space_dims, rhs_space_dims) =
GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums);

Expand Down
Loading

0 comments on commit ad39c95

Please sign in to comment.