Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache entries for reshape ops in SPMD. #21579

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

Add cache entries for reshape ops in SPMD.

We may have two compatible sharding pairs when handling reshape. If we have two pairs, we use the first one. We can still use the second one to add as a sharding cache. Given the following reshape,

p0 = bf16[8,8] parameter(0), sharding={replicated}
reshape = bf16[64] reshape(p0), sharding={devices=[4]<=[4]}

there are two compatible sharding pairs

1.1 reshard input to sharding={devices=[4,1]<=[4]}
1.2 reshape

2.1 reshape input to bf16[8,8] with sharding {replicated}
2.2 reshard to final shardingsharding={devices=[4]<=[4]}

Before this change, we only add 1.1 and 1.2. This change also adds 2.1 as a reshard cache, which can be used directly without reshard the result of 1.2. If the cache is not used, it will be removed by DCE pass.

Given the following input

ENTRY %reshape {
  p0 = bf16[8,8] parameter(0), sharding={replicated}
  reshape = bf16[64] reshape(p0), sharding={devices=[4]<=[4]}
  abs = bf16[64] abs(reshape), sharding={replicated}
  ROOT tuple = (bf16[64], bf16[64]) tuple(reshape, abs), sharding={{devices=[4]<=[4]}, {replicated}}
}

Before this change, we have "expensive" all-gather

ENTRY %reshape_spmd (param: bf16[8,8]) -> (bf16[16], bf16[64]) {
  %param = bf16[8,8]{1,0} parameter(0), sharding={replicated}
  %constant = s32[4]{0} constant({0, 2, 4, 6})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice)
  %constant.1 = s32[] constant(0)
  %dynamic-slice.1 = bf16[2,8]{1,0} dynamic-slice(bf16[8,8]{1,0} %param, s32[] %reshape.1, s32[] %constant.1), dynamic_slice_sizes={2,8}
  %reshape.2 = bf16[16]{0} reshape(bf16[2,8]{1,0} %dynamic-slice.1)
  %all-gather = bf16[64]{0} all-gather(bf16[16]{0} %reshape.2), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true
  %abs.1 = bf16[64]{0} abs(bf16[64]{0} %all-gather)
  ROOT %tuple.1 = (bf16[16]{0}, bf16[64]{0}) tuple(bf16[16]{0} %reshape.2, bf16[64]{0} %abs.1)
}

With this change, we replace reshard with reshape

ENTRY %reshape_spmd (param: bf16[8,8]) -> (bf16[16], bf16[64]) {
  %param = bf16[8,8]{1,0} parameter(0), sharding={replicated}
  %constant = s32[4]{0} constant({0, 2, 4, 6})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice)
  %constant.1 = s32[] constant(0)
  %dynamic-slice.1 = bf16[2,8]{1,0} dynamic-slice(bf16[8,8]{1,0} %param, s32[] %reshape.1, s32[] %constant.1), dynamic_slice_sizes={2,8}
  %reshape.2 = bf16[16]{0} reshape(bf16[2,8]{1,0} %dynamic-slice.1)
  %reshape.3 = bf16[64]{0} reshape(bf16[8,8]{1,0} %param)
  %abs.1 = bf16[64]{0} abs(bf16[64]{0} %reshape.3)
  ROOT %tuple.1 = (bf16[16]{0}, bf16[64]{0}) tuple(bf16[16]{0} %reshape.2, bf16[64]{0} %abs.1)
}

We may have two compatible sharding pairs when handling reshape. If we have two pairs, we use the first one. We can still use the second one to add as a sharding cache. Given the following reshape,
```
p0 = bf16[8,8] parameter(0), sharding={replicated}
reshape = bf16[64] reshape(p0), sharding={devices=[4]<=[4]}
```
there are two compatible sharding pairs
```
1.1 reshard input to sharding={devices=[4,1]<=[4]}
1.2 reshape

2.1 reshape input to bf16[8,8] with sharding {replicated}
2.2 reshard to final shardingsharding={devices=[4]<=[4]}
```

Before this change, we only add 1.1 and 1.2. This change also adds 2.1 as a reshard cache, which can be used directly without reshard the result of 1.2. If the cache is not used, it will be removed by DCE pass.

Given the following input
```
ENTRY %reshape {
  p0 = bf16[8,8] parameter(0), sharding={replicated}
  reshape = bf16[64] reshape(p0), sharding={devices=[4]<=[4]}
  abs = bf16[64] abs(reshape), sharding={replicated}
  ROOT tuple = (bf16[64], bf16[64]) tuple(reshape, abs), sharding={{devices=[4]<=[4]}, {replicated}}
}
```

Before this change, we have "expensive" all-gather
```
ENTRY %reshape_spmd (param: bf16[8,8]) -> (bf16[16], bf16[64]) {
  %param = bf16[8,8]{1,0} parameter(0), sharding={replicated}
  %constant = s32[4]{0} constant({0, 2, 4, 6})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice)
  %constant.1 = s32[] constant(0)
  %dynamic-slice.1 = bf16[2,8]{1,0} dynamic-slice(bf16[8,8]{1,0} %param, s32[] %reshape.1, s32[] %constant.1), dynamic_slice_sizes={2,8}
  %reshape.2 = bf16[16]{0} reshape(bf16[2,8]{1,0} %dynamic-slice.1)
  %all-gather = bf16[64]{0} all-gather(bf16[16]{0} %reshape.2), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true
  %abs.1 = bf16[64]{0} abs(bf16[64]{0} %all-gather)
  ROOT %tuple.1 = (bf16[16]{0}, bf16[64]{0}) tuple(bf16[16]{0} %reshape.2, bf16[64]{0} %abs.1)
}
```

With this change, we replace reshard with reshape
```
ENTRY %reshape_spmd (param: bf16[8,8]) -> (bf16[16], bf16[64]) {
  %param = bf16[8,8]{1,0} parameter(0), sharding={replicated}
  %constant = s32[4]{0} constant({0, 2, 4, 6})
  %partition-id = u32[] partition-id()
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[4]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.1 = s32[] reshape(s32[1]{0} %dynamic-slice)
  %constant.1 = s32[] constant(0)
  %dynamic-slice.1 = bf16[2,8]{1,0} dynamic-slice(bf16[8,8]{1,0} %param, s32[] %reshape.1, s32[] %constant.1), dynamic_slice_sizes={2,8}
  %reshape.2 = bf16[16]{0} reshape(bf16[2,8]{1,0} %dynamic-slice.1)
  %reshape.3 = bf16[64]{0} reshape(bf16[8,8]{1,0} %param)
  %abs.1 = bf16[64]{0} abs(bf16[64]{0} %reshape.3)
  ROOT %tuple.1 = (bf16[16]{0}, bf16[64]{0}) tuple(bf16[16]{0} %reshape.2, bf16[64]{0} %abs.1)
}
```

PiperOrigin-RevId: 716846842
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant