Merge branch 'main' into rasbt-patch-1

t-vi authored Mar 21, 2024
2 parents 337a8c5 + f37ac34 commit 274151e
Showing 6 changed files with 78 additions and 295 deletions.
74 changes: 53 additions & 21 deletions README.md
@@ -1,57 +1,94 @@
![](docs/source/_static/images/lightning_thunder_lightmode_nobyline.png)
<div align="center">
<img alt="Thunder" src="docs/source/_static/images/lightning_thunder_lightmode_nobyline.png" width="400px" style="max-width: 100%;">
<br/>
<br/>

**Make PyTorch models Lightning fast.**

______________________________________________________________________

<p align="center">
<a href="https://lightning.ai/">Lightning.ai</a> •
<a href="#performance">Performance</a> •
<a href="#get-started">Get started</a> •
<a href="#install-thunder">Install</a> •
<a href="#hello-world">Examples</a> •
<a href="#features">Features</a> •
<a href="#documentation">Documentation</a> •
</p>

[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE)
[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)

</div>

# Welcome to ⚡ Lightning Thunder

Lightning Thunder is a source-to-source compiler for PyTorch.
**Thunder makes PyTorch models Lightning fast.**

It makes PyTorch programs faster both on single accelerators or in distributed settings.
Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining different hardware executors at once (e.g., nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).

Works on single accelerators and in multi-GPU settings.
Thunder aims to be usable, understandable, and extensible.
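
For instance, the program Thunder generates can be inspected as plain Python. Below is a minimal sketch, assuming the `thunder.jit` and `thunder.last_traces` entry points (see the tutorial Studio in the Get started section for the canonical walkthrough):

```python
import torch
import thunder


def foo(a, b):
    return a + b


jfoo = thunder.jit(foo)
jfoo(torch.randn(2, 2), torch.randn(2, 2))

# last_traces returns the sequence of transformed traces for the most recent call;
# the final entry is the Python program Thunder actually executed.
print(thunder.last_traces(jfoo)[-1])
```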

## Performance

Thunder can achieve significant speedups over standard PyTorch eager code through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).

![](docs/source/_static/images/training_throughput_single.png)
<div align="center">
<img alt="Thunder" src="docs/source/_static/images/training_throughput_single.png" width="800px" style="max-width: 100%;">
</div>

We achieve a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.

Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision; FP8 support for FSDP is underway).

![](docs/source/_static/images/normalized_training_throughput_zero2.png)
<div align="center">
<img alt="Thunder" src="docs/source/_static/images/normalized_training_throughput_zero2.png" width="800px" style="max-width: 100%;">
</div>
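
As an illustration only, here is a minimal multi-GPU sketch. It assumes the `thunder.distributed.ddp` helper, that the helper wraps the module before `thunder.jit`, and a `torchrun` launch; the exact API may differ between versions:

```python
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import ddp  # assumed entry point; check the docs for the current API

# Launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))

model = torch.nn.Linear(128, 128).to(device)
model = ddp(model)           # replicate parameters and insert gradient all-reduces
jmodel = thunder.jit(model)  # compile the DDP-wrapped module with Thunder

out = jmodel(torch.randn(4, 128, device=device))
out.sum().backward()         # gradients are synchronized across ranks during backward
```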

**NOTE: Lightning Thunder is alpha.** Feel free to get involved, but expect a few bumps along the way.

## Start with Thunder
## Get started

Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial).

## Install Thunder

Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly:
Install the [nvFuser](https://github.com/NVIDIA/Fuser) nightly and Thunder together:

```bash
# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
```

Install Thunder:

```bash
# install thunder
pip install lightning-thunder
```

It's actually not a bad idea to install directly from `main`:
<details>
<summary>Advanced install options</summary>
<!-- following section will be skipped from PyPI description -->

### Install from main

```bash
pip install git+https://github.com/Lightning-AI/lightning-thunder.git
```

or from the local repo if you want to tinker with the internals:
### Install to tinker and contribute

Install this way to tinker with the internals and contribute:

```bash
# get the repo and install it in editable mode
git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
pip install -e .
```

</details>
<!-- end skipping PyPI description -->

## Hello World

Here is a simple example of how Thunder lets you compile and run PyTorch code:
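
A minimal sketch of the pattern, assuming the `thunder.jit` entry point (details of the README's own snippet may differ):

```python
import torch
import thunder


def foo(a, b):
    return a + b


# thunder.jit returns a compiled callable; the first call traces and optimizes foo
jfoo = thunder.jit(foo)

a = torch.full((2, 2), 1)
b = torch.full((2, 2), 3)

c = jfoo(a, b)
print(c)  # tensor of 4s, the same result as calling foo(a, b) directly
```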
@@ -102,7 +139,7 @@ python examples/lit-gpt/train_fsdp.py

See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder.

## What's in the box
## Features

Given a Python callable or PyTorch module, Thunder can generate an optimized program that:

@@ -132,7 +169,7 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo

Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations.
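
For example, a jitted module can be dropped into an ordinary PyTorch training step unchanged. A minimal sketch, assuming `thunder.jit` and that the jitted module shares the original module's parameters:

```python
import torch
import thunder

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 8),
)
jmodel = thunder.jit(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # assumes parameters are shared with jmodel

x = torch.randn(16, 64)
loss = jmodel(x).sum()
loss.backward()   # PyTorch autograd drives the backward that Thunder generated
optimizer.step()  # parameters remain ordinary torch.nn.Parameters
```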

## Build the documentation
## Documentation

Docs are not currently hosted publicly. However, you can build them locally quickly:

@@ -168,8 +205,3 @@ Thunder is very thoroughly tested, so expect this to take a while.

Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
See the [LICENSE](LICENSE) file for details.

[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)
86 changes: 0 additions & 86 deletions thunder/core/transforms.py
@@ -1237,11 +1237,6 @@ def _embedding_prim_grad(


def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable:
    # If executor specific `aug_fwd_rule` exists then we will use that,
    # so we return `None` here.
    if get_executor_specific_aug_fwd_rule(bsym):
        return None

    cd = get_compile_data()
    executors_list = cd.executors_list if cd is not None else executors_list
    # Checks if the executor which has priority for this operation has a specific grad transform for it
@@ -2484,15 +2479,6 @@ def zeros_like(x):
}


@dataclass(**default_dataclass_params)
class RuleInfo:
    checker: Callable
    rule: Callable
    fw_fallback: Callable
    bw_fallback: Callable
    executor: Executor


def register_augmented_forward(op):
"""Decorator to register an augmented forward implementation for a symbol.
@@ -2510,40 +2496,6 @@ def decorator(func):
    return decorator


def register_augmented_forward_with_checker(executor, op, checker, rule):
    """Decorator to register an augmented forward implementation for a symbol.
    Args:
        executor (Executor): Executor to which the rule applies.
        op (Ops): Symbol for which to register the augmented forward implementation.
        checker (Callable): Function that checks if the rule should be applied.
        rule (Callable): Function that applies the rule.
    """
    fw_fallback = augmented_forward_impls.get(op, None)
    bw_fallback = backward_impls.get(op, None)
    augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor)


def deregister_augmented_forward_and_backward(op):
    """Deregisters an augmented forward implementation and a backward
    implementation for a symbol.
    Args:
        op (Ops): Symbol for which to deregister the augmented forward
            implementation and the backward implementation.
    Returns:
        None
    """
    # Restore the fallback implementation if it exists
    if isinstance(augmented_forward_impls[op], RuleInfo):
        backward_impls[op] = augmented_forward_impls[op].bw_fallback
        augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback
    else:
        del augmented_forward_impls[op]
        del backward_impls[op]


def register_backward(op):
"""Decorator to register a backward implementation for a symbol.
@@ -3320,31 +3272,6 @@ def uniform_backward(primal, minval, maxval, g):
nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL)


def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None:
    """Get executor specific augmented forward rule.
    Args:
        symbol (BoundSymbol): BoundSymbol to get the rule for.
    Returns:
        RuleInfo: Rule info for the symbol.
    """
    cd = get_compile_data()
    if cd is None:
        return None

    # Search for the executor specific rules. When there are multiple rules
    # for the same symbol, we use the left-most executor in the list (i.e.
    # the one with the highest priority) and we fallback to the next one if
    # the checker returns False.
    for executor in cd.executors_list:
        candidate = augmented_forward_impls.get((executor, symbol.sym.id))
        if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs):
            return candidate

    return None


def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
"""Check if a symbol is constant for the VJP transform.
@@ -3387,19 +3314,10 @@ def vjp_impl_const(symbol, *args, **kwargs):

    # Normal case, we have a proxy tangent
    vjp_impl = augmented_forward_impls.get(symbol.sym.id)
    vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl

    if _get_gradfn(symbol) is not None:
        vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)

    if isinstance(vjp_impl, RuleInfo):
        # We should use this rule only if checker returns True for the current
        # symbol's arguments
        if vjp_impl.checker(*symbol.args, **symbol.kwargs):
            vjp_impl = vjp_impl.rule
        else:
            vjp_impl = vjp_impl.fw_fallback

    if vjp_impl is None:
        # We could not find a VJP for this symbol, so we try to decompose it
        if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
@@ -3567,14 +3485,10 @@ def put_grad(v: Variable, val: Any) -> None:

        backward = backward_impls.get(symbol.sym.id)
        aug_forward = augmented_forward_impls.get(symbol.sym.id)
        aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward

        if _get_gradfn(symbol) is not None:
            aug_forward, backward = make_aug_forward_and_backward(symbol)

        if isinstance(aug_forward, RuleInfo):
            backward = backward_impls[aug_forward.executor, symbol.sym.id]

        if backward is None:
            if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
                # We could not find a backward for this symbol, so we try to decompose it
74 changes: 0 additions & 74 deletions thunder/executors/apex_entropyex.py
@@ -11,10 +11,6 @@
from thunder.core.symbol import Symbol
from thunder.core.utils import check, same_shape
from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims
from thunder.core.transforms import (
    register_augmented_forward_with_checker,
    register_backward,
)

from thunder.extend import OperatorExecutor, register_executor

@@ -197,76 +193,6 @@ def _cross_entropy_checker(
    return True


# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any
# Symbol. We use our new primitives to register a VJP rule for
# torch.nn.functional.cross_entropy. This function is registered as the
# augmented forward rule for torch.nn.functional.cross_entropy below
def apex_cross_entropy_forward_rule(
    a,
    target,
    weight=None,
    size_average=None,
    ignore_index=-100,
    reduce=None,
    reduction="mean",
    label_smoothing=0.0,
):
    loss, max_log_sum_exp = apex_xentropy(
        a,
        target=target,
        reduction=reduction,
        label_smoothing=label_smoothing,
    )
    primal = loss
    saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)
    return primal, saved_for_backward


register_augmented_forward_with_checker(
    apex_ex,
    ltorch.cross_entropy.id,
    _cross_entropy_checker,
    apex_cross_entropy_forward_rule,
)


# This function is the backward rule for torch.nn.functional.cross_entropy. It
# accepts the primal output and saved_for_backward from the forward pass and
# returns the backward output. The backward output is a tuple of the backward
# output for each differentiable Tensor input to the forward pass. In this case,
# the forward pass has 1 such input, so the backward output is a single Tensor.
# This function is registered as the backward rule for
# torch.nn.functional.cross_entropy
@register_backward((apex_ex, ltorch.cross_entropy.id))
def apex_cross_entropy_backward_rule(
    logits,
    labels,
    max_log_sum_exp,
    reduction,
    smoothing,
    grad,
):
    from thunder.core.transforms import mean_backward, sum_backward

    if reduction == "mean":
        grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad)
    elif reduction == "sum":
        grad = sum_backward(max_log_sum_exp.shape, (0,), grad)
    elif reduction == "none":
        pass
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    grad_logits = apex_xentropy_bwd(
        grad,
        logits,
        target=labels,
        max_log_sum_exp=max_log_sum_exp,
        label_smoothing=smoothing,
    )
    return grad_logits


# Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True)
def _cross_entropy_transform(
    a: TensorProxy,