From d37eebc5d2f5232633c37fc833afb65cecd5d40a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 21 Mar 2024 11:18:44 -0400 Subject: [PATCH 1/3] simplify readme 1/n (#31) * Update README.md * pre-commit: running and fixing... --------- Co-authored-by: github-actions[bot] --- README.md | 76 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 8f461fb518..d3d325a155 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,94 @@ -![](docs/source/_static/images/lightning_thunder_lightmode_nobyline.png) +
+Thunder +
+
+ +**Make PyTorch models Lightning fast.** + +______________________________________________________________________ + +

+ Lightning.ai • + Performance • + Get started • + Install • + Examples • + Features • + Documentation • +

+ +[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE) +[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) +[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) +[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) +[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main) + +
# Welcome to ⚡ Lightning Thunder -Lightning Thunder is a source-to-source compiler for PyTorch. +**Thunder makes PyTorch models Lightning fast.** -It makes PyTorch programs faster both on single accelerators or in distributed settings. +Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8). +Works on single accelerators and in multi-GPU settings. Thunder aims to be usable, understandable, and extensible. ## Performance Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt). -![](docs/source/_static/images/training_throughput_single.png) +
+Thunder +
-We achieve a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8. +Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8. Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway). -![](docs/source/_static/images/normalized_training_throughput_zero2.png) +
+Thunder +
**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way. -## Start with Thunder +## Get started Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial). ## Install Thunder -Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly: +Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together ```bash +# install nvFuser which installs the matching nightly PyTorch pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com -``` -Install Thunder: - -```bash +# install thunder pip install lightning-thunder ``` -It's actually not a bad idea to install directly from `main`: +
+ Advanced install options + + +### Install from main ```bash pip install git+https://github.com/Lightning-AI/lightning-thunder.git ``` -or from the local repo if you want to tinker with the internals: +### Install to tinker and contribute + +Install this way to tinker with the internals and contribute: ```bash pip install -e . ``` +
+ + ## Hello World Here is a simple example of how Thunder lets you compile and run PyTorch code: @@ -82,7 +119,7 @@ print(result) The compiled function `jfoo` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs. -## Running training +## Train models Thunder is in its early stages, it should not be used for production runs yet. @@ -102,7 +139,7 @@ python examples/lit-gpt/train_fsdp.py See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder. -## What's in the box +## Features Given a python callable or PyTorch module, Thunder can generate an optimized program that: @@ -132,7 +169,7 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations. -## Build the documentation +## Documentation Docs are currently not hosted publicly. However you can build them locally really quickly: @@ -168,8 +205,3 @@ Thunder is very thoroughly tested, so expect this to take a while. Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license. See LICENSE file for details. - -[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml) -[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml) -[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest) -[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main) From f37ac34399c8f039637b527a79064d106abb4bc2 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 21 Mar 2024 17:20:36 +0200 Subject: [PATCH 2/3] Remove deregister_augmented_forward_and_backward (#30) --- thunder/core/transforms.py | 86 ----------------- thunder/executors/apex_entropyex.py | 74 --------------- thunder/executors/cudnnex.py | 108 +++------------------- thunder/executors/transformer_engineex.py | 25 +++-- thunder/tests/test_cudnn_executor.py | 6 +- 5 files changed, 25 insertions(+), 274 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index c24e44a3ce..772e65a84d 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1237,11 +1237,6 @@ def _embedding_prim_grad( def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable: - # If executor specific `aug_fwd_rule` exists then we will use that, - # so we return `None` here. - if get_executor_specific_aug_fwd_rule(bsym): - return None - cd = get_compile_data() executors_list = cd.executors_list if cd is not None else executors_list # Checks if the executor which has priority for this operation has a specific grad transform for it @@ -2484,15 +2479,6 @@ def zeros_like(x): } -@dataclass(**default_dataclass_params) -class RuleInfo: - checker: Callable - rule: Callable - fw_fallback: Callable - bw_fallback: Callable - executor: Executor - - def register_augmented_forward(op): """Decorator to register an augmented forward implementation for a symbol. @@ -2510,40 +2496,6 @@ def decorator(func): return decorator -def register_augmented_forward_with_checker(executor, op, checker, rule): - """Decorator to register an augmented forward implementation for a symbol. - - Args: - executor (Executor): Executor to which the rule applies. - op (Ops): Symbol for which to register the augmented forward implementation. - checker (Callable): Function that checks if the rule should be applied. - rule (Callable): Function that applies the rule. - """ - fw_fallback = augmented_forward_impls.get(op, None) - bw_fallback = backward_impls.get(op, None) - augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor) - - -def deregister_augmented_forward_and_backward(op): - """Deregisters an augmented forward implementation and a backward - implementation for a symbol. - - Args: - op (Ops): Symbol for which to deregister the augmented forward - implementation and the backward implementation. - - Returns: - None - """ - # Restore the fallback implementation if it exists - if isinstance(augmented_forward_impls[op], RuleInfo): - backward_impls[op] = augmented_forward_impls[op].bw_fallback - augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback - else: - del augmented_forward_impls[op] - del backward_impls[op] - - def register_backward(op): """Decorator to register a backward implementation for a symbol. @@ -3320,31 +3272,6 @@ def uniform_backward(primal, minval, maxval, g): nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL) -def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None: - """Get executor specific augmented forward rule. - - Args: - symbol (BoundSymbol): BoundSymbol to get the rule for. - - Returns: - RuleInfo: Rule info for the symbol. - """ - cd = get_compile_data() - if cd is None: - return None - - # Search for the executor specific rules. When there are multiple rules - # for the same symbol, we use the left-most executor in the list (i.e. - # the one with the highest priority) and we fallback to the next one if - # the checker returns False. - for executor in cd.executors_list: - candidate = augmented_forward_impls.get((executor, symbol.sym.id)) - if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs): - return candidate - - return None - - def is_constant_for_vjp(symbol: prims.Symbol) -> bool: """Check if a symbol is constant for the VJP transform. @@ -3387,19 +3314,10 @@ def vjp_impl_const(symbol, *args, **kwargs): # Normal case, we have a proxy tangent vjp_impl = augmented_forward_impls.get(symbol.sym.id) - vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl if _get_gradfn(symbol) is not None: vjp_impl, backward_fn = make_aug_forward_and_backward(symbol) - if isinstance(vjp_impl, RuleInfo): - # We should use this rule only if checker returns True for the current - # symbol's arguments - if vjp_impl.checker(*symbol.args, **symbol.kwargs): - vjp_impl = vjp_impl.rule - else: - vjp_impl = vjp_impl.fw_fallback - if vjp_impl is None: # We could not find a VJP for this symbol, so we try to decompose it if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs): @@ -3567,14 +3485,10 @@ def put_grad(v: Variable, val: Any) -> None: backward = backward_impls.get(symbol.sym.id) aug_forward = augmented_forward_impls.get(symbol.sym.id) - aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward if _get_gradfn(symbol) is not None: aug_forward, backward = make_aug_forward_and_backward(symbol) - if isinstance(aug_forward, RuleInfo): - backward = backward_impls[aug_forward.executor, symbol.sym.id] - if backward is None: if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs): # We could not find a backward for this symbol, so we try to decompose it diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py index 8a82e04e20..818199ad5b 100644 --- a/thunder/executors/apex_entropyex.py +++ b/thunder/executors/apex_entropyex.py @@ -11,10 +11,6 @@ from thunder.core.symbol import Symbol from thunder.core.utils import check, same_shape from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) from thunder.extend import OperatorExecutor, register_executor @@ -197,76 +193,6 @@ def _cross_entropy_checker( return True -# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any -# Symbol. We use our new primitives to register a VJP rule for -# torch.nn.functional.cross_entropy. This function is registered as the -# augmented forward rule for torch.nn.functional.cross_entropy below -def apex_cross_entropy_forward_rule( - a, - target, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -): - loss, max_log_sum_exp = apex_xentropy( - a, - target=target, - reduction=reduction, - label_smoothing=label_smoothing, - ) - primal = loss - saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing) - return primal, saved_for_backward - - -register_augmented_forward_with_checker( - apex_ex, - ltorch.cross_entropy.id, - _cross_entropy_checker, - apex_cross_entropy_forward_rule, -) - - -# This function is the backward rule for torch.nn.functional.cross_entropy. It -# accepts the primal output and saved_for_backward from the forward pass and -# returns the backward output. The backward output is a tuple of the backward -# output for each differentiable Tensor input to the forward pass. In this case, -# the forward pass has 1 such input, so the backward output is a single Tensor. -# This function is registered as the backward rule for -# torch.nn.functional.cross_entropy -@register_backward((apex_ex, ltorch.cross_entropy.id)) -def apex_cross_entropy_backward_rule( - logits, - labels, - max_log_sum_exp, - reduction, - smoothing, - grad, -): - from thunder.core.transforms import mean_backward, sum_backward - - if reduction == "mean": - grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad) - elif reduction == "sum": - grad = sum_backward(max_log_sum_exp.shape, (0,), grad) - elif reduction == "none": - pass - else: - raise ValueError(f"Invalid reduction: {reduction}") - - grad_logits = apex_xentropy_bwd( - grad, - logits, - target=labels, - max_log_sum_exp=max_log_sum_exp, - label_smoothing=smoothing, - ) - return grad_logits - - # Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True) def _cross_entropy_transform( a: TensorProxy, diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 9fb6c50a48..75494cff5a 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -35,8 +35,6 @@ def cudnn_available() -> bool: get_grad, put_grad, put_grads, - register_augmented_forward_with_checker, - register_backward, ) from thunder.extend import OperatorExecutor, register_executor import thunder.torch as ltorch @@ -338,7 +336,11 @@ def _cudnn_sdpa_forward_checker( if d % 8 != 0 or d > 128: return False - return True + is_backward_supported = _cudnn_sdpa_backward_checker( + query, key, value, attn_mask, dropout_p, is_causal, scale=scale + ) + + return True and is_backward_supported @langctx("torch") @@ -601,99 +603,6 @@ def cudnn_sdpa_bwd_impl( ) -@langctx("torch") -def cudnn_sdpa_aug_fw_rule_checker( - query: TensorProxy, - key: TensorProxy, - value: TensorProxy, - attn_mask: None | TensorProxy, - dropout_p: float, - is_causal: bool, - *, - scale: None | float, -) -> bool: - from thunder.core.compile_data import get_compile_data - - cd = get_compile_data() - if cudnn_ex in cd.executors_list: - is_forward_supported = _cudnn_sdpa_forward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - is_backward_supported = _cudnn_sdpa_backward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - return is_forward_supported and is_backward_supported - return False - - -def cudnn_sdpa_aug_fw_rule( - query, - key, - value, - attn_mask=None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -): - output, softmax_stats, seed, offset = cudnn_sdpa_fwd( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - saved_for_backward = ( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - output, - softmax_stats, - seed, - offset, - ) - return output, saved_for_backward - - -register_augmented_forward_with_checker( - cudnn_ex, - "torch.nn.functional.scaled_dot_product_attention", - cudnn_sdpa_aug_fw_rule_checker, - cudnn_sdpa_aug_fw_rule, -) - - -@register_backward((cudnn_ex, "torch.nn.functional.scaled_dot_product_attention")) -def cudnn_sdpa_backward_rule( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float, - is_causal: bool, - scale: None | float, - out: Proxy, - softmax_stats: Proxy, - seed: Proxy, - offset: Proxy, - grad_out: Proxy, -): - return cudnn_sdpa_bwd( - grad_out, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - out, - softmax_stats, - seed, - offset, - scale=scale, - ) - - @langctx("torch") def _cudnn_sdpa_transform( query: TensorProxy, @@ -726,7 +635,7 @@ def _cudnn_sdpa_grad( ) g = get_grad(primal) - grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd( + grads = cudnn_sdpa_bwd( g, query, key, @@ -740,6 +649,11 @@ def _cudnn_sdpa_grad( offset, scale=scale, ) + if attn_mask is None: + grad_query, grad_key, grad_val = grads + else: + grad_query, grad_key, grad_val, grad_attn_mask = grads + put_grads((query, key, value), (grad_query, grad_key, grad_val)) if attn_mask is not None: put_grad(attn_mask, grad_attn_mask) diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index 91d32e7a88..ced5d8fdb1 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -19,10 +19,6 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, CollectionProxy from thunder.core.symbol import Symbol -from thunder.core.transforms import ( - register_augmented_forward_with_checker, - register_backward, -) from thunder.extend import OperatorExecutor, register_executor __all__ = [ @@ -411,15 +407,6 @@ def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | Ten return False -register_augmented_forward_with_checker( - transformer_engine_ex, - prims.linear.id, - linear_forward_rule_checker, - linear_forwad_rule, -) - - -@register_backward((transformer_engine_ex, prims.linear.id)) def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad): return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx) @@ -429,9 +416,21 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False) +def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: + out, saved_for_backward = linear_forwad_rule(a, w, b) + g = prims.get_grad(out) + ga, gw, gb = linear_backward_rule(*saved_for_backward, g) + prims.put_grad(a, ga) + prims.put_grad(w, gw) + if b is not None: + prims.put_grad(b, gb) + return out + + # Registers the implementation for torch.nn.functional.linear transformer_engine_ex.register_implementation( prims.linear, checker=_linear_checker, execution_transform=_linear_transform, + grad_transform=_linear_grad, ) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index c9bd6277ec..4128e02914 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -110,10 +110,8 @@ def test_cudnn_sdpa(): query = 1 * (torch.randn(shape_Q, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) key = 2 * (torch.randn(shape_K, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) value = 3 * (torch.randn(shape_V, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5) - is_causal = False - attn_mask = torch.randn( - s_q, s_kv, requires_grad=False, device="cuda", dtype=thunder.torch.to_torch_dtype(dtype) - ) + is_causal = True + attn_mask = None expected = torch.nn.functional.scaled_dot_product_attention( query, key, value, is_causal=is_causal, attn_mask=attn_mask From 735b8758262aba9f8d6c7f9308ca6aae76395d6a Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 21 Mar 2024 12:03:36 -0500 Subject: [PATCH 3/3] Minor Readme cosmetics (#28) --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d3d325a155..7d60063864 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Thunder aims to be usable, understandable, and extensible. ## Performance -Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best in class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt). +Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
Thunder @@ -121,9 +121,9 @@ The compiled function `jfoo` takes and returns PyTorch tensors, just like the or ## Train models -Thunder is in its early stages, it should not be used for production runs yet. +Thunder is in its early stages and should not be used for production runs yet. -However, it can already deliver outstanding performance on models supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama2, Gemma, Falcon, and derivatives. +However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. Run training loop for Llama, single-GPU: @@ -141,7 +141,7 @@ See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with T ## Features -Given a python callable or PyTorch module, Thunder can generate an optimized program that: +Given a Python callable or PyTorch module, Thunder can generate an optimized program that: - Computes its forward and backward passes - Coalesces operations into efficient fusion regions @@ -204,4 +204,4 @@ Thunder is very thoroughly tested, so expect this to take a while. ## License Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license. -See LICENSE file for details. +See the [LICENSE](LICENSE) file for details.