docs
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar committed Mar 22, 2024
1 parent 907cbd9 commit 857243c
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions thunder/distributed/transforms/fsdp.py
@@ -433,14 +433,22 @@ def visit(bsym: BoundSymbol) -> VISIT_TYPE:


 class FSDPCommBucketing:
-    """Apply communication bucketing to FSDP traces.
+    """Apply transformations to FSDP forward/backward traces.
+    This class is in charge of
+    - introducing bucketing into the FSDP traces.
+    - modifying an FSDP backward trace by removing gradient sync's when
+      :func:`~thunder.distributed.get_skip_data_parallel_grad_sync` is :obj:`True`.
-    This class is in charge of introducing bucketing into the FSDP traces.
     A given forward trace will be updated so that it has fewer ``AllGather``'s by
     concatenating sharded parameters beforehand and slicing and reshaping unsharded concatenated parameters afterward.
     The backward trace, the counterpart of the forward, will be updated so that it has fewer ``ReduceScatter``'s.
     ``AllGather``s are also updated if ``sharding_strategy`` is ``FSDPType.ZERO3``.
+    When :func:`~thunder.ThunderModule.no_sync` is used, this removes :class:`~thunder.core.symbol.BoundSymbol`s of
+    :func:`~thunder.distributed.prims.reduce_scatter` and :func:`~thunder.distributed.prims.wait`,
+    and inserts ones to attach or accumulate unsharded, unsynchronized gradients to parameters as ``_thunder_fsdp_unsharded_grad`` attr.
     """

     def __init__(
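
The docstring above describes forward-trace bucketing as concatenating sharded parameters before a single ``AllGather`` and then slicing and reshaping the result. The sketch below illustrates that idea only; it is not Thunder's implementation. It uses plain Python lists in place of tensors and a simulated collective, and every name in it (`all_gather`, `shard`, `bucketed_gather`, `calls`) is hypothetical. Parameters are assumed to be already flattened and evenly divisible by the world size, so the final reshape to each parameter's original shape is left out.

```python
# Hypothetical, framework-free illustration of AllGather bucketing.
# Lists stand in for flattened tensors; nothing here is a Thunder API.

calls = {"all_gather": 0}  # counts how many collectives are issued


def all_gather(per_rank_buffers):
    """Simulated collective: returns every rank's buffer, in rank order."""
    calls["all_gather"] += 1
    return [list(buf) for buf in per_rank_buffers]


def shard(flat_param, world_size):
    """Split a flattened parameter into equal contiguous chunks, one per rank."""
    chunk = len(flat_param) // world_size
    return [flat_param[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def unbucketed_gather(shards_per_param):
    """One collective per parameter."""
    return [
        [x for piece in all_gather(shards) for x in piece]
        for shards in shards_per_param
    ]


def bucketed_gather(shards_per_param, world_size):
    """One collective for all parameters: concatenate, gather, then slice."""
    sizes = [len(shards[0]) for shards in shards_per_param]
    # Each rank concatenates its local shards of every parameter into one bucket.
    buckets = [
        [x for shards in shards_per_param for x in shards[r]]
        for r in range(world_size)
    ]
    gathered = all_gather(buckets)
    # Slice each rank's bucket at the parameter's offset and stitch the
    # slices together in rank order to recover the unsharded parameter.
    params, offset = [], 0
    for size in sizes:
        params.append(
            [x for r in range(world_size) for x in gathered[r][offset:offset + size]]
        )
        offset += size
    return params


world_size = 2
full_params = [[1, 2, 3, 4], [10, 20], [5, 6, 7, 8, 9, 0]]
shards_per_param = [shard(p, world_size) for p in full_params]

calls["all_gather"] = 0
unbucketed = unbucketed_gather(shards_per_param)
unbucketed_calls = calls["all_gather"]

calls["all_gather"] = 0
bucketed = bucketed_gather(shards_per_param, world_size)
bucketed_calls = calls["all_gather"]
```

Both paths recover the same unsharded parameters, but the bucketed path issues one collective instead of one per parameter. The backward counterpart mentioned in the docstring (fewer ``ReduceScatter``'s) would presumably apply the same concatenate-then-split pattern to gradients.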
