Support CUDA Graph for MoE models #1233
base: main
Conversation
Signed-off-by: Robin Zhang <[email protected]> Co-authored-by: Yifei Song <[email protected]>
Technically this seems mostly reasonable, although I have questions and stylistic suggestions. Have you tested that it works with Mcore?
@ptrendx @ksivaman @sbhavani What is our priority for this feature? The custom Mcore logic in make_graphed_callables is already messy and fragile, and this PR does exacerbate those problems.
for m_chunk in range(num_model_chunks):
    for _ in range(num_microbatches):
        for l_no in range(num_layers):
            per_callable_module_params.append(
                tuple(callables[m_chunk * num_layers + l_no].parameters())
                if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
                else ()
            )
This change seems correct to me, but it's odd if the Mcore integration was working before. @ksivaman Have we run this with Mcore, or did we run with num_microbatches=1?
This changes the interpretation of per_callable_module_params from (num_chunks, layers_per_chunk, num_microbatches) to (num_chunks, num_microbatches, layers_per_chunk). This matches the interpretation of the per_callable_* lists when capturing graphs:
TransformerEngine/transformer_engine/pytorch/graph.py
Lines 237 to 239 in 3b89c36
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
    fwd_idx[m_chunk] * num_layers + l_no
)
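For reference, a minimal standalone sketch (plain Python with made-up sizes, not TE code) that checks the new (chunk, microbatch, layer) append order against this flattened index, with the microbatch counter standing in for fwd_idx[m_chunk]:

num_model_chunks, num_microbatches, num_layers = 2, 3, 4

# Build the per-callable list in the new (chunk, microbatch, layer) order.
flat_order = []
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            flat_order.append((m_chunk, microbatch, l_no))

# Recompute per_callable_fwd_idx and confirm it points at the right entry.
for m_chunk in range(num_model_chunks):
    for microbatch in range(num_microbatches):
        for l_no in range(num_layers):
            idx = (m_chunk * num_microbatches * num_layers) + (microbatch * num_layers + l_no)
            assert flat_order[idx] == (m_chunk, microbatch, l_no)
print("flattened ordering matches per_callable_fwd_idx")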
for module in func.modules():
    if hasattr(module, "is_first_microbatch"):
        module.is_first_microbatch = True
TE modules don't set or read the is_first_microbatch attr. It's a kwarg in the forward function. Also, this assumes callables contains torch.nn.Modules.
Thank you for the reminder. I also believe that modifications are needed here. The issue we originally intended to address is that the FP8 weight-caching behavior in MoE leads to different behavior for the first microbatch compared to other microbatches. If we do not include this piece of code, the warmup process will update is_first_microbatch to False, causing all captured graphs to exhibit non-first-microbatch behavior, which does not align with our requirements. Therefore, we chose to reset this parameter after warmup.
In summary, we need either to prevent is_first_microbatch from being updated during warmup or to reset it after warmup. Choosing the former may require adding a flag to inform MoE that it is currently in the warmup phase, while choosing the latter might require making this code a bit more general. Do you have any input on the modification plan that could serve as a reference for us?
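As a rough sketch of the second option (resetting after warmup), assuming the callables are torch.nn.Module instances and that the Mcore-side MoE wrapper reads an is_first_microbatch attribute; the helper name is hypothetical:

import torch

def reset_is_first_microbatch(callables):
    # Hypothetical helper illustrating the "reset after warmup" option: warmup
    # iterations flip the Mcore-side flag to False, so restore it before capture
    # so the captured graphs keep first-microbatch (FP8 weight caching) behavior.
    for func in callables:
        if not isinstance(func, torch.nn.Module):
            continue  # guard against non-Module callables
        for module in func.modules():
            if hasattr(module, "is_first_microbatch"):
                module.is_first_microbatch = True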
I see, so this is Mcore-specific logic. It's uncomfortable that it's made its way into TE, but it's a tricky problem and I can't think of a better solution either, at least without significant changes in Mcore.
We should document what this is doing, especially for TE developers with no knowledge of Mcore.
Done. I've tried to explain the reason for this MoE-specific logic.
transformer_engine/pytorch/graph.py (outdated)
if (
    not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
    and hasattr(m, "attention_dropout")
    and m.deterministic
):
    continue
Why are we skipping the FP8 scale update logic for this case?
This is for a deterministic test with FP8. Even without CUDA graphs, we find that FP8_DPA leads to random output, even in deterministic mode. With CUDA graphs, we not only need to set DPA to BF16 but also need to skip the FP8 meta update for DPA. After that, we finally obtain stable output under the deterministic test.
It seems like we're covering up a correctness bug. This may also affect convergence, since we are no longer doing amax reductions within MHA (e.g. for the FP8 casts before qkv and proj).
It makes sense, and I think we should 'continue' whenever fp8_mha and fp8_dpa are disabled, regardless of whether the test is deterministic or not.
Actually, this change will cause correctness issues. We do a max all-reduce on the amaxes so that FP8 scaling factors are synchronized over the TP group. If we skip the amax reduction, then the MHA's TP communication will be wrong. In particular, we'll all-gather FP8 inputs to the qkv GEMM, but the scaling factors will be different for each TP rank. In pseudocode:
def mha_qkv(x_local, w_local, x_fp8_scale):
    x_local_fp8, x_amax = cast_to_fp8(x_local, x_fp8_scale)
    x_fp8 = all_gather(x_local_fp8)
    y_local = gemm(x_fp8, w_local, x_fp8_scale)
    max_all_reduce(x_amax)  # Without this, x_fp8_scale is different between ranks
    update_fp8_scale(x_fp8_scale, x_amax)
    return y_local
Turns out this isn't relevant, since MultiheadAttention is not a TransformerEngineBaseModule.
As far as I can tell, this logic is specific to DotProductAttention. We could make the intent much more obvious with:
if (
    isinstance(m, DotProductAttention)
    and not fp8_recipe.fp8_mha
    and not fp8_recipe.fp8_dpa
):
    # Don't need to update FP8 meta for non-FP8 DPA
    continue
Thanks! I have modified this.
/te-ci pytorch
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Robin Zhang <[email protected]> Co-authored-by: Yifei Song <[email protected]>
Signed-off-by: Robin Zhang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci pytorch
Yes, we also made some changes in Mcore, together with the TE changes in this PR, to enable MoE cudagraph. You can refer to issue 193 in our Megatron-LM repo.
Signed-off-by: Xin Yao <[email protected]>
/te-ci pytorch
Signed-off-by: Robin Zhang <[email protected]>
Signed-off-by: Yifei Song <[email protected]>
Description
Different from non-MoE models like Llama 2, MoE models have dynamically shaped activations in their FFN layers, so one CUDA graph can only capture part of a transformer layer instead of covering the whole layer. We call this the "breaking-layer" CUDA graph mode. This PR adds breaking-layer CUDA graph support for MoE models on the TE side and fixes several related bugs in TE.
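As a rough illustration of the breaking-layer idea only (a minimal sketch with stand-in modules, using stock torch.cuda.make_graphed_callables rather than TE's implementation, and assuming a CUDA device):

import torch

# Stand-in modules: attention has static shapes and can be graph-captured;
# the MoE FFN sees a varying number of dispatched tokens and stays eager.
attn = torch.nn.Linear(1024, 1024).cuda()
moe_ffn = torch.nn.Linear(1024, 1024).cuda()

sample = torch.randn(8, 1024, device="cuda", requires_grad=True)
graphed_attn = torch.cuda.make_graphed_callables(attn, (sample,))

def layer_forward(hidden_states, dispatched_tokens):
    # Replay the captured graph for the static-shape part of the layer...
    attn_out = graphed_attn(hidden_states)
    # ...and run the dynamically shaped expert FFN outside the graph.
    expert_out = moe_ffn(dispatched_tokens)
    return attn_out, expert_out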
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
- Add is_initialized() method in CudaRNGStatesTracker to align with what is already done in MCore.
- Fix per_callable_module_params order bug in _make_graphed_callables when _order is given.
- Support breaking-layer CUDA graph capture for MoE models in _make_graphed_callables when _order is given.
- Add fp8_group argument to make_graphed_callables() and modify the is_first_microbatch, skip_fp8_weight_update and fp8_meta code.
Checklist: