diff --git a/README.md b/README.md index 7d60063864..f89dce73b0 100644 --- a/README.md +++ b/README.md @@ -125,19 +125,7 @@ Thunder is in its early stages and should not be used for production runs yet. However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. -Run training loop for Llama, single-GPU: - -```bash -python examples/lit-gpt/train.py -``` - -Run training loop for Llama, multi-GPU, using FSDP: - -```bash -python examples/lit-gpt/train_fsdp.py -``` - -See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder. +Check out [the LitGPT integration](https://github.com/Lightning-AI/litgpt/tree/main/extensions/thunder) to learn about running LitGPT and Thunder together. ## Features diff --git a/examples/lit-gpt/.gitignore b/examples/lit-gpt/.gitignore deleted file mode 100644 index c3d41546e1..0000000000 --- a/examples/lit-gpt/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -checkpoints - -download.py -convert_hf_checkpoint.py diff --git a/examples/lit-gpt/1_forward.py b/examples/lit-gpt/1_forward.py deleted file mode 100644 index c0392e3382..0000000000 --- a/examples/lit-gpt/1_forward.py +++ /dev/null @@ -1,57 +0,0 @@ -import time - -import lightning as L -import torch -import torch._dynamo.config -import torch._inductor.config - -from thunder.tests.lit_gpt_model import GPT - - -@torch.inference_mode() -def main(name: str = "open_llama_7b", num_samples: int = 10, compile: str = "eager") -> None: - torch.set_float32_matmul_precision("high") - torch.set_default_dtype(torch.bfloat16) - device = torch.device("cuda") - - with device: - model = GPT.from_name(name) - encoded = torch.randint(0, model.config.padded_vocab_size, (10, model.max_seq_length)) - - model.eval() - - if compile == "inductor": - torch._dynamo.config.automatic_dynamic_shapes = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.coordinate_descent_tuning = True - model = torch.compile(model, fullgraph=True) - elif compile == "thunder": - import thunder - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.torch_compile import torch_compile_executor - - model = thunder.jit( - model, - disable_torch_autograd=True, - executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor], - ) - elif compile != "eager": - raise ValueError(compile) - - values = [] - L.seed_everything(1234) - for i in range(num_samples): - t0 = time.perf_counter() - _ = model(encoded) - torch.cuda.synchronize() - t = time.perf_counter() - t0 - values.append(t) - print(f"Time for inference {i + 1}: {t:.05f} sec total") - print(f"Best: {min(values):05f}") - print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -if __name__ == "__main__": - from jsonargparse import CLI - - CLI(main) diff --git a/examples/lit-gpt/README.md b/examples/lit-gpt/README.md deleted file mode 100644 index bf0b46199b..0000000000 --- a/examples/lit-gpt/README.md +++ /dev/null @@ -1,93 +0,0 @@ -# Lit-GPT benchmarks - -## Setup - -```bash -wget -nc https://raw.githubusercontent.com/Lightning-AI/lit-gpt/1a5e7c/scripts/download.py -pip install jsonargparse huggingface_hub sentencepiece tokenizers -pip install git+https://github.com/Lightning-AI/lit-gpt@1a5e7c -``` - -## [1 forward](1_forward.py) - -```bash -python 1_forward.py --compile thunder -``` - -Runs a single forward call with a (B=10 x T=2048) tensor: - -| Method | Time ↓ | 
Memory ↓ | -| -------- | ------ | -------- | -| Inductor | 1.18 s | 17.38 GB | -| Thunder | 1.27 s | 16.32 GB | -| Eager | 1.48 s | 17.44 GB | - -## [Single-device training](train.py) - -```shell -# setup -python download.py --repo_id openlm-research/open_llama_3b --tokenizer_only true -# run -python train.py --compile thunder --dynamic false -``` - -Static shapes (45 iters) - -| Method | Time ↓ | Memory ↓ | -| -------- | ------ | -------- | -| Inductor | 20.1 s | 20.95 GB | -| Thunder | 21.9 s | 23.75 GB | -| Eager | 24.6 s | 24.28 GB | - -Dynamic shapes (45 iters) - -| Method | Time ↓ | Memory ↓ | -| -------- | ------- | -------- | -| Inductor | 17.0 s | 20.69 GB | -| Eager | 17.6 s | 23.91 GB | -| Thunder | ~5715 s | - | - -## [Multi-device training](train_fsdp.py) - -```shell -# setup -python download.py --repo_id openlm-research/open_llama_3b --tokenizer_only true -# run -python train_fsdp.py --devices 2 --compile thunder --stage 2 --bucketing_strategy BLOCK -``` - -Static shapes (45 iters) - -| Stage | Bucketing | Method | Time ↓ | Memory ↓ | -| ----- | --------- | -------- | ------- | -------- | -| 2 | No | Inductor | Error | Error | -| 2 | No | Thunder | 23.29 s | 26.99 GB | -| 2 | No | Eager | 27.76 s | 27.61 GB | -| | | | | | -| 2 | Block | Inductor | 21.71 s | 24.31 GB | -| 2 | Block | Thunder | 24.30 s | 26.96 GB | -| 2 | Block | Eager | 26.05 s | 27.67 GB | -| | | | | | -| 3 | No | Inductor | Error | Error | -| 3 | No | Thunder | 24.39 s | 20.25 GB | -| 3 | No | Eager | 28.56 s | 20.75 GB | -| | | | | | -| 3 | Block | Inductor | 21.76 s | 17.86 GB | -| 3 | Block | Thunder | 24.11 s | 26.93 GB | -| 3 | Block | Eager | 26.33 s | 21.23 GB | - -## Setup - -```text -Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime) -Is debug build: False -CUDA used to build PyTorch: 12.1 -CUDA runtime version: 12.3.107 -GPU 0: NVIDIA A100-SXM4-40GB -Nvidia driver version: 545.23.08 - -pytorch-triton==3.0.0+901819d2b6 -torch==2.3.0.dev20240225+cu121 -lightning-thunder==51993f9a6894f59f3779b30485e72b93d5e7b150 -nvfuser_cu121==0.1.6.dev20240226 -``` diff --git a/examples/lit-gpt/_ddp_thunder.py b/examples/lit-gpt/_ddp_thunder.py deleted file mode 100644 index 1bd07619df..0000000000 --- a/examples/lit-gpt/_ddp_thunder.py +++ /dev/null @@ -1,226 +0,0 @@ -"""Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually.""" - -from contextlib import nullcontext -from datetime import timedelta -from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only -from torch import Tensor -from torch.nn import Module -from typing_extensions import override - -from lightning.fabric.accelerators.accelerator import Accelerator -from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout -from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning.fabric.plugins.precision import Precision -from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher -from lightning.fabric.strategies.parallel import ParallelStrategy -from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl -from lightning.fabric.utilities.distributed import ( - ReduceOp, - 
_distributed_is_initialized, - _get_default_process_group_backend_for_device, - _init_dist_connection, - _sync_ddp_if_available, -) -from lightning.fabric.utilities.rank_zero import rank_zero_only - -if TYPE_CHECKING: - from thunder import Executor - - -_THUNDER_AVAILABLE = RequirementCache("lightning-thunder", "thunder") - - -class DDPThunderStrategy(ParallelStrategy): - def __init__( - self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, - executors: Optional[Tuple[Union["Executor", str], ...]] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, - **kwargs: Any, - ): - if not _THUNDER_AVAILABLE: - raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) - super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) - self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - - self.executors = _validate_executors(executors) - self._num_nodes = 1 - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout - self._backward_sync_control = _DDPBackwardSyncControl() - self._ddp_kwargs = kwargs - - @property - @override - def root_device(self) -> torch.device: - assert self.parallel_devices is not None - return self.parallel_devices[self.local_rank] - - @property - def num_nodes(self) -> int: - return self._num_nodes - - @num_nodes.setter - def num_nodes(self, num_nodes: int) -> None: - # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks - self._num_nodes = num_nodes - - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - - @property - @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: - return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - if not self.cluster_environment.creates_processes_externally: - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @property - def process_group_backend(self) -> Optional[str]: - return self._process_group_backend - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @override - def setup_environment(self) -> None: - super().setup_environment() - self._setup_distributed() - - @override - def setup_module(self, module: Module) -> Module: - import thunder - - module = thunder.distributed.ddp(module, **self._ddp_kwargs) - - return thunder.jit(module, executors=self.executors) - - @override - def module_to_device(self, module: Module) -> None: - module.to(self.root_device) - - @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: - if isinstance(tensor, Tensor): - return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - - @override - def barrier(self, *args: Any, **kwargs: Any) -> None: - if not _distributed_is_initialized(): - return - if torch.distributed.get_backend() 
== "nccl": - torch.distributed.barrier(device_ids=[self.root_device.index]) - else: - torch.distributed.barrier() - - @override - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not _distributed_is_initialized(): - return obj - - obj = [obj] - torch.distributed.broadcast_object_list(obj, src) - return obj[0] - - def _setup_distributed(self) -> None: - self._set_world_ranks() - self._process_group_backend = self._get_process_group_backend() - assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) - - def _get_process_group_backend(self) -> str: - return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) - - def _set_world_ranks(self) -> None: - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail - # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter - rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - - -def _validate_executors(executors: Optional[Tuple[Union["Executor", str], ...]]) -> Optional[Tuple["Executor", ...]]: - """Converts string executors into it's respective ``Executor`` object.""" - if executors is None: - return None - from thunder import get_all_executors - - final = [] - issues = [] - all = get_all_executors() - for executor in executors: - if isinstance(executor, str): - for existing in all: - if executor == existing.name: - final.append(existing) - break - else: - issues.append(executor) - else: - final.append(executor) - if issues: - raise ValueError(f"Did not find the executors {issues} in {all}") - return tuple(final) - - -class _DDPBackwardSyncControl(_BackwardSyncControl): - def __init__(self): - self._enabled = False - - @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: - if not getattr(module, "use_ddp", False): - raise TypeError( - "Blocking backward sync is only possible if the module passed to" - f" `{self.__class__.__name__}.no_backward_sync` is applied DDP." - f" Got: {module.__class__.__name__}." 
- ) - - # issue "Limitations of the current DDP no_sync implementation" has - # details on why we cannot just return `module.no_sync()` - from thunder.distributed import skip_data_parallel_grad_sync - - previous, self._enabled = self._enabled, enabled - if enabled: - return skip_data_parallel_grad_sync() - if not enabled and previous: - return _AllReduceGradsContextManager(module) - return nullcontext() - - -class _AllReduceGradsContextManager: - def __init__(self, module: Module) -> None: - self._module = module - - @override - def __enter__(self) -> None: - from thunder.distributed import _sync_grads - - _sync_grads(self._module) - - @override - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - pass diff --git a/examples/lit-gpt/_fsdp_thunder.py b/examples/lit-gpt/_fsdp_thunder.py deleted file mode 100644 index 133c40b1f2..0000000000 --- a/examples/lit-gpt/_fsdp_thunder.py +++ /dev/null @@ -1,420 +0,0 @@ -"""Fabric Strategy to support Thunder FSDP: To be upstreamed into Fabric eventually.""" - -import shutil -from contextlib import ExitStack, nullcontext -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Tuple, Union - -import torch -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only -from torch import Tensor -from torch.nn import Module -from torch.optim import Optimizer -from typing_extensions import override - -from lightning.fabric.accelerators.accelerator import Accelerator -from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning.fabric.plugins.precision import Precision -from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher -from lightning.fabric.strategies.parallel import ParallelStrategy -from lightning.fabric.strategies.strategy import TBroadcast, _apply_filter, _Sharded, _validate_keys_for_strict_loading -from lightning.fabric.utilities.distributed import ( - ReduceOp, - _distributed_is_initialized, - _get_default_process_group_backend_for_device, - _init_dist_connection, - _sync_ddp_if_available, -) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from lightning.fabric.utilities.load import _METADATA_FILENAME, _move_state_into -from lightning.fabric.utilities.rank_zero import rank_zero_only -from lightning.fabric.utilities.seed import reset_seed -from lightning.fabric.utilities.types import _PATH, _Stateful - -if TYPE_CHECKING: - from thunder import Executor - from thunder.distributed import FSDPBucketingStrategy, FSDPType - from thunder.distributed.checkpoint import StateDictOptions - - _FSDP_TYPE = Union[FSDPType, Literal["ZERO2", "ZERO3"]] - _BUCKETING_STRATEGY = Union[FSDPBucketingStrategy, Literal["NONE", "LAYER", "BLOCK"]] - - -_THUNDER_AVAILABLE = RequirementCache("lightning-thunder", "thunder") - - -class FSDPThunderStrategy(ParallelStrategy, _Sharded): - def __init__( - self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, - sharding_strategy: "_FSDP_TYPE" = "ZERO3", - bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE", - executors: Optional[Tuple[Union["Executor", str], ...]] = None, - 
state_dict_type: Literal["full", "sharded"] = "sharded", - **kwargs: Any, - ): - if not _TORCH_GREATER_EQUAL_2_2: - raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.") - if not _THUNDER_AVAILABLE: - raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) - super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) - self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - from thunder.distributed import FSDPBucketingStrategy, FSDPType - - self.sharding_strategy = ( - FSDPType[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy - ) - self.bucketing_strategy = ( - FSDPBucketingStrategy[bucketing_strategy.upper()] - if isinstance(bucketing_strategy, str) - else bucketing_strategy - ) - self.executors = _validate_executors(executors) - self._state_dict_type = state_dict_type - self._fsdp_kwargs = kwargs - - @property - @override - def root_device(self) -> torch.device: - assert self.parallel_devices is not None - return self.parallel_devices[self.local_rank] - - @property - def num_nodes(self) -> int: - return 1 - - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - - @property - @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: - return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - if not self.cluster_environment.creates_processes_externally: - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @override - def setup_environment(self) -> None: - super().setup_environment() - self._setup_distributed() - - @override - def setup_module(self, module: Module) -> Module: - import thunder - - module = thunder.distributed.fsdp( - module, - device=self.root_device, - sharding_strategy=self.sharding_strategy, - bucketing_strategy=self.bucketing_strategy, - **self._fsdp_kwargs, - ) - - # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call. - # we would still `jit(fsdp(undo_jit(jit(model))))` internally - return thunder.jit(module, executors=self.executors) - - @override - def module_to_device(self, module: Module) -> None: - pass - - @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: - precision_init_ctx = self.precision.module_init_context() - module_sharded_ctx = self.module_sharded_context() - stack = ExitStack() - if empty_init: - # Materialization happens in `setup`. 
When modules get wrapped by FSDP - stack.enter_context(torch.device("meta")) - stack.enter_context(precision_init_ctx) - stack.enter_context(module_sharded_ctx) - return stack - - @override - def module_sharded_context(self) -> ContextManager: - return nullcontext() - - @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: - if isinstance(tensor, Tensor): - return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - - @override - def barrier(self, *args: Any, **kwargs: Any) -> None: - if not _distributed_is_initialized(): - return - if torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=[self.root_device.index]) - else: - torch.distributed.barrier() - - @override - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not _distributed_is_initialized(): - return obj - - obj = [obj] - torch.distributed.broadcast_object_list(obj, src) - return obj[0] - - @override - def clip_gradients_norm( - self, - module: Module, - optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = True, - ) -> Tensor: - raise NotImplementedError - - @override - def save_checkpoint( - self, - path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, - ) -> None: - if storage_options is not None: - raise TypeError( - "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" - " `FSDPStrategy` does not use the `CheckpointIO`." - ) - if filter is not None: - raise NotImplementedError("Filtering checkpoint paths is not implemented") - - # broadcast the path from rank 0 to ensure all the states are saved in a common path - path = Path(self.broadcast(path)) - if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): - raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") - - from thunder.distributed.checkpoint import save, has_fsdp_modules, StateDictOptions - - modules = [module for module in state.values() if has_fsdp_modules(module)] - if len(modules) == 0: - raise ValueError( - "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple FSDP models in the given state. Saving checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To save multiple models, call the" - " save method for each model separately with a different path." 
- ) - - if self._state_dict_type == "sharded": - if _is_full_checkpoint(path): - path.unlink() - path.mkdir(parents=True, exist_ok=True) - - options = StateDictOptions(full_state_dict=False, cpu_offload=True, rank0_only=False) - converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) - save(converted_state, path) - if self.global_rank == 0: - torch.save(metadata, path / _METADATA_FILENAME) - - elif self._state_dict_type == "full": - if _is_sharded_checkpoint(path): - shutil.rmtree(path) - - options = StateDictOptions(full_state_dict=True, cpu_offload=True, rank0_only=True) - converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) - converted_state.update(metadata) - if self.global_rank == 0: - torch.save(converted_state, path) - else: - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") - - @override - def load_checkpoint( - self, - path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, - strict: bool = True, - ) -> Dict[str, Any]: - if not state: - raise ValueError( - f"Got `FSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least" - " a model instance to reload is required. Pass it in like so:" - " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`" - ) - # broadcast the path from rank 0 to ensure all the states are loaded from a common path - path = Path(self.broadcast(path)) - - from thunder.distributed.checkpoint import has_fsdp_modules, StateDictOptions, load_model_state_dict, load - - if isinstance(state, Module): - if not _is_full_checkpoint(path): - raise ValueError( - "Failed to load checkpoint directly into the model. The given path must be a single file" - f" containing the full state dict: {path}" - ) - state_dict = torch.load(str(path), mmap=True, map_location="cpu") - options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) - load_model_state_dict(state_dict, _unwrap_tom(state), options, self.local_rank) - return {} - - if isinstance(state, Optimizer): - raise NotImplementedError( - "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy." - ) - - modules = {key: module for key, module in state.items() if has_fsdp_modules(module)} - if len(modules) == 0: - raise ValueError( - "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before loading the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple FSDP models in the given state. Loading checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To load multiple models, call the" - " load method for each model separately with a different path." 
- ) - optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)} - module_key, module = list(modules.items())[0] - module = _unwrap_tom(module) - - if _is_sharded_checkpoint(path): - options = StateDictOptions(full_state_dict=False, cpu_offload=True, strict=strict, rank0_only=False) - # Load the DCP state dict, which requires a holder state dict - converted_state, _ = _get_state_dict(state, None, options, self.local_rank) - load(converted_state, path) - load_model_state_dict(converted_state[module_key], module, options, self.local_rank) - - # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) - requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() - _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) - for key in requested_metadata_keys: - if key not in metadata: - continue - state[key] = metadata.pop(key) - # return the remaining metadata that wasn't requested as part of `state` - return metadata - - if _is_full_checkpoint(path): - options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) - if not options.rank0_only or self.local_rank == 0: - map_location = "cpu" if options.cpu_offload else None - checkpoint = torch.load(str(path), mmap=True, map_location=map_location) - load_model_state_dict(checkpoint[module_key], module, options, self.local_rank) - else: - checkpoint = {} - - requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() - _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) - # Load metadata (anything not a module or optimizer) - _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys) - # return the remaining metadata that wasn't requested as part of `state` - return checkpoint - - raise ValueError( - f"The path {str(path)!r} does not point to a valid checkpoint. Make sure the path points to either a" - " directory with FSDP checkpoint shards, or a single file with a full checkpoint." 
- ) - - def _setup_distributed(self) -> None: - reset_seed() - self._set_world_ranks() - process_group_backend = _get_default_process_group_backend_for_device(self.root_device) - assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, process_group_backend) - - def _set_world_ranks(self) -> None: - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail - # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter - rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - - -def _is_sharded_checkpoint(path: Path) -> bool: - """A heuristic check to determine whether the path points to a directory with checkpoint shards.""" - return path.is_dir() and (path / _METADATA_FILENAME).is_file() - - -def _is_full_checkpoint(path: Path) -> bool: - return path.is_file() - - -def _validate_executors(executors: Optional[Tuple[Union["Executor", str], ...]]) -> Optional[Tuple["Executor", ...]]: - """Converts string executors into it's respective ``Executor`` object.""" - if executors is None: - return None - from thunder import get_all_executors - - final = [] - issues = [] - all = get_all_executors() - for executor in executors: - if isinstance(executor, str): - for existing in all: - if executor == existing.name: - final.append(existing) - break - else: - issues.append(executor) - else: - final.append(executor) - if issues: - raise ValueError(f"Did not find the executors {issues} in {all}") - return tuple(final) - - -def _get_state_dict( - state: Dict[str, Any], - filter: Optional[Dict[str, Callable[[str, Any], bool]]], - options: "StateDictOptions", - rank: int, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: - from thunder.distributed.checkpoint import get_model_state_dict - - # replace the modules and optimizer objects in the state with their local state dict - # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} - for key, obj in state.items(): - converted: Any - if isinstance(obj, Module): - converted = get_model_state_dict(_unwrap_tom(obj), options, rank) - target_dict = converted_state - elif isinstance(obj, Optimizer): - # TODO: optimizer support - converted = obj.state_dict() - target_dict = converted_state - else: # everything not a module or optimizer is considered metadata - converted = obj.state_dict() if isinstance(obj, _Stateful) else obj - target_dict = metadata - _apply_filter(key, filter or {}, converted, target_dict) - - return converted_state, metadata - - -def _unwrap_tom(obj: object) -> object: - # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder - from thunder import ThunderModule - - if isinstance(obj, ThunderModule): - return obj._model - return obj diff --git a/examples/lit-gpt/test_ddp_thunder.py b/examples/lit-gpt/test_ddp_thunder.py deleted file mode 100644 index 860d14a8d4..0000000000 --- a/examples/lit-gpt/test_ddp_thunder.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -import torch -from _ddp_thunder import DDPThunderStrategy - -from lightning import Fabric -# from tests.tests_fabric.helpers.runif import RunIf - - -# @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) -@pytest.mark.parametrize("strategy", ["ddp", 
DDPThunderStrategy()]) -def test_no_backward_sync(strategy): - fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) - fabric.launch() - - model = torch.nn.Linear(1, 1, bias=False, device=fabric.device) - x = torch.randn(1, 1, device=fabric.device) - model = fabric.setup(model) - - # 6 iters, 3 grad accumulation iters - for i, enabled in enumerate((True, True, False, True, True, False), 1): - x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) - - with fabric.no_backward_sync(model, enabled): - y = model(x) - y.backward() - if not enabled: - # Math for the first 3 iters - # - # DistributedDataParallel - # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9 - # ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^ - # rank0 rank1 allreduce - # - # thunder.distributed.ddp - # ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9 - # ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^ - # rank0 rank1 allreduce1 rank0 rank1 allreduce2 - assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) - model.weight.grad = None diff --git a/examples/lit-gpt/test_fsdp_thunder.py b/examples/lit-gpt/test_fsdp_thunder.py deleted file mode 100644 index a49b9131f7..0000000000 --- a/examples/lit-gpt/test_fsdp_thunder.py +++ /dev/null @@ -1,294 +0,0 @@ -from _fsdp_thunder import FSDPThunderStrategy, _validate_executors -from lightning.fabric import Fabric -import torch -import pytest -import re -import os -from typing import Optional, Tuple, Union -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 - - -def test_thunder_strategy_input_parsing(): - from thunder.distributed import FSDPBucketingStrategy, FSDPType - from thunder import pythonex - - strategy = FSDPThunderStrategy(bucketing_strategy="BlOcK", executors_list=("python",), sharding_strategy="zero3") - assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK - assert strategy.executors_list == (pythonex,) - assert strategy.sharding_strategy is FSDPType.ZERO3 - - -def test_validate_executors(): - from thunder import pythonex, pytorch_executor - - assert _validate_executors(None) is None - assert _validate_executors((pythonex, pytorch_executor)) == (pythonex, pytorch_executor) - assert _validate_executors(("python", "torch")) == (pythonex, pytorch_executor) - assert _validate_executors(("python", pytorch_executor)) == (pythonex, pytorch_executor) - with pytest.raises(ValueError, match=re.escape("not find the executors ['foo', 'bar'] in")): - assert _validate_executors(("python", "foo", pytorch_executor, "bar")) - - -def test_save_checkpoint_invalid_settings_raise(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="full") - with pytest.raises(TypeError, match="not supported"): - strategy.save_checkpoint(tmp_path, {}, storage_options=object()) - - with pytest.raises(IsADirectoryError, match="path exists"): - strategy.save_checkpoint(tmp_path, {}) - - model = torch.nn.Linear(1, 1) - with pytest.raises(ValueError, match="Could not find"): - strategy.save_checkpoint(tmp_path / "foo", {}) - - model.use_fsdp = True - with pytest.raises(ValueError, match="Found multiple"): - strategy.save_checkpoint(tmp_path / "foo", {"model1": model, "model2": model}) - - with pytest.raises(ValueError, match="at least a model"): - strategy.load_checkpoint(tmp_path / "foo", {}) - - with pytest.raises(ValueError, match="must be a single file"): - strategy.load_checkpoint(tmp_path, model) - - optimizer = torch.optim.Adam(model.parameters()) - with pytest.raises(NotImplementedError, match="not supported"): - strategy.load_checkpoint(tmp_path, optimizer) - 
- with pytest.raises(ValueError, match="Found multiple"): - strategy.load_checkpoint(tmp_path / "foo", {"model1": model, "model2": model}) - - with pytest.raises(ValueError, match="Could not find"): - strategy.load_checkpoint(tmp_path / "foo", {"foo": 1}) - - -class Submodule(torch.nn.Module): - def __init__(self, h: int): - super().__init__() - self.l = torch.nn.Linear(4, h * 2, bias=False) - - def forward(self, x): - # defined just because preprocessing fails otherwise - ... - - -class MyModel(torch.nn.Module): - def __init__(self, h: int): - super().__init__() - self.register_buffer("buf", torch.tensor(0)) - self.l = torch.nn.Linear(2, h) - self.inner = Submodule(h) - - def forward(self): - # defined just because preprocessing fails otherwise - ... - - def reset_parameters(self): - self.buf = torch.empty_like(self.buf) - - -def test_materialize_meta_tensors(): - strategy = FSDPThunderStrategy() - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - with fabric.init_module(empty_init=True): - model = MyModel(2) - - model = fabric.setup(model) - # all parameters were moved - assert len(list(model.parameters())) == 3 - assert all(p.device.type == "cuda" for p in model.parameters()) - # buffers were moved too - assert model.buf.device.type == "cuda" - - -class StatefulThing: - def state_dict(self): - return {"thing": 1} - - def load_state_dict(self, state_dict): - assert state_dict == self.state_dict() - - -class TensorLike: - def __init__(self, device: Optional[Union[str, torch.device]] = None, shape: Optional[Tuple[int, ...]] = None): - self.device = torch.device(device) if device is not None else None - self.shape = torch.Size(shape) if shape is not None else None - - def __eq__(self, other): - return ( - isinstance(other, torch.Tensor) - and (self.device is None or other.device == self.device) - and (self.shape is None or other.shape == self.shape) - ) - - -def test_save_load_full_checkpoint(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="full", broadcast_from=0) - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - model = MyModel(4) - expected = model.state_dict() - - # save a sharded model - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 123} - checkpoint_path = tmp_path / "foo" - fabric.save(checkpoint_path, state) - - # assert the file contents - if fabric.global_rank == 0: - checkpoint = torch.load(checkpoint_path) - # cpu_offload is enabled by default - assert checkpoint == { - "model": { - "buf": TensorLike("cpu", tuple()), - "inner.l.weight": TensorLike("cpu", (8, 4)), - "l.bias": TensorLike("cpu", (4,)), - "l.weight": TensorLike("cpu", (4, 2)), - }, - "stateful": {"thing": 1}, - "primitive": 123, - } - torch.testing.assert_close(checkpoint["model"], expected) - - # load its weights into a different sharded model - model = MyModel(4) - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 321} - fabric.load(checkpoint_path, state) - - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - # we loaded rank 0's weights, so this would fail in the other ranks - if fabric.global_rank == 0: - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = 
actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - assert state["primitive"] == 123 - - -def test_load_full_checkpoint_only_model(tmp_path): - strategy = FSDPThunderStrategy() - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - checkpoint_path = tmp_path / "foo" - checkpoint_path = fabric.broadcast(checkpoint_path) - if fabric.global_rank == 0: - model = MyModel(4) - expected = model.state_dict() - torch.save(expected, checkpoint_path) - fabric.barrier() - expected = torch.load(checkpoint_path) - - # before sharding - model = MyModel(4) - fabric.load_raw(checkpoint_path, model) - torch.testing.assert_close(model.state_dict(), expected) - - # after sharding - model = MyModel(4) - model = fabric.setup(model) - fabric.load_raw(checkpoint_path, model) - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - - -def distributed_ckpt_to_regular(path): - """From ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save``.""" - from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - from torch.distributed.checkpoint import FileSystemReader - - if _TORCH_GREATER_EQUAL_2_3: - from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner - else: - from torch.distributed.checkpoint._traverse import set_element - from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner - from torch.distributed.checkpoint.metadata import TensorStorageMetadata - - class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def set_up_planner(self, state_dict, metadata, is_coordinator): - assert not state_dict - # rebuild the state dict from the metadata - for k, v in metadata.state_dict_metadata.items(): - if isinstance(v, TensorStorageMetadata): - v = torch.empty(v.size, dtype=v.properties.dtype) - if k in metadata.planner_data: - set_element(state_dict, metadata.planner_data[k], v) - else: - state_dict[k] = v - super().set_up_planner(state_dict, metadata, is_coordinator) - - state_dict = {} - storage_reader = FileSystemReader(path) - _load_state_dict(state_dict, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True) - return state_dict - - -def test_save_load_sharded_checkpoint(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="sharded", broadcast_from=0) - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - model = MyModel(4) - expected = model.state_dict() - - # save a sharded model - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 123} - fabric.save(tmp_path, state) - - # assert the file contents - if fabric.global_rank == 0: - assert set(os.listdir(tmp_path)) == {"meta.pt", "__1_0.distcp", "__0_0.distcp", ".metadata"} - - metadata = torch.load(tmp_path / "meta.pt") - assert metadata == {"stateful": {"thing": 1}, "primitive": 123} - - checkpoint = distributed_ckpt_to_regular(tmp_path) - # cpu_offload is enabled by default - assert checkpoint == { - "model": { - "buf": TensorLike("cpu", tuple()), - "inner.l.weight": 
TensorLike("cpu", (8, 4)), - "l.bias": TensorLike("cpu", (4,)), - "l.weight": TensorLike("cpu", (4, 2)), - } - } - torch.testing.assert_close(checkpoint["model"], expected) - - # load its weights into a different sharded model - model = MyModel(4) - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 321} - fabric.load(tmp_path, state) - - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - # we loaded rank 0's weights, so this would fail in the other ranks - if fabric.global_rank == 0: - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - assert state["primitive"] == 123 diff --git a/examples/lit-gpt/train.py b/examples/lit-gpt/train.py deleted file mode 100644 index 412711ce5a..0000000000 --- a/examples/lit-gpt/train.py +++ /dev/null @@ -1,111 +0,0 @@ -import time - -import lightning as L -import torch -from torch.utils.data import DataLoader, IterableDataset - -from thunder.tests.lit_gpt_model import GPT, Config - -model_name = "open_llama_3b" -learning_rate = 6e-4 -micro_batch_size = 2 -max_iters = 50 - - -def main(compile: str = "eager", dynamic: bool = False) -> None: - fabric = L.Fabric(devices=1, precision="bf16-true") - - fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) - - config = Config.from_name(model_name) - print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - with fabric.init_module(): - og_model = model = GPT(config) - print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - - if compile == "inductor": - model = torch.compile(model, fullgraph=True, mode="reduce-overhead", dynamic=dynamic) - elif compile == "thunder": - import thunder - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.torch_compile import torch_compile_executor - - model = thunder.jit( - model, - executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor], - # TODO: we'd want to enable CUDAGraphs for parity with `torch.compile` but it goes OOM - ) - model.max_seq_length = og_model.max_seq_length - elif compile != "eager": - raise ValueError(compile) - - model = fabric.setup(model) - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-1, foreach=False) - optimizer = fabric.setup_optimizers(optimizer) - - train_data = DummyDataset(model.max_seq_length, dynamic) - train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2, collate_fn=pad_collate) - train_dataloader = fabric.setup_dataloaders(train_dataloader) - - train(fabric, model, optimizer, train_dataloader) - print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -def train( - fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_dataloader: DataLoader -) -> None: - train_iter = iter(train_dataloader) - t0 = None - assert max_iters > 5 - for i in range(max_iters): - iter_t0 = time.perf_counter() - if i == 5: # warmup - t0 = iter_t0 - input_ids, targets = next(train_iter) - - logits = model(input_ids) - logits = logits.reshape(-1, logits.size(-1)) - targets = targets.reshape(-1) - loss = torch.nn.functional.cross_entropy(logits, targets, 
ignore_index=-1) - fabric.backward(loss) - optimizer.step() - optimizer.zero_grad() - - loss_item = loss.item() # synchronization - t1 = time.perf_counter() - print(f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}") - print(f"Total time: {(t1 - t0):.2f}s") - - -class DummyDataset(IterableDataset): - def __init__(self, max_seq_length: int, dynamic: bool): - super().__init__() - self.max_seq_length = max_seq_length - self.dynamic = dynamic - - def __iter__(self): - while True: - if self.dynamic: - t = torch.randint(10, self.max_seq_length + 1, (1,)) - else: - t = self.max_seq_length - data = torch.randint(0, 100, (t + 1,), dtype=torch.int64) - x = data[:t] - y = data[1 : t + 1] - yield x, y - - -def pad_collate(batch): - x, y = zip(*batch) - x_padded = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0) - y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1) - return x_padded, y_padded - - -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - - from jsonargparse import CLI - - CLI(main) diff --git a/examples/lit-gpt/train_fsdp.py b/examples/lit-gpt/train_fsdp.py deleted file mode 100644 index e896d52ef3..0000000000 --- a/examples/lit-gpt/train_fsdp.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -import re -import time -from typing import Literal - -import lightning as L -import torch -from lightning.fabric.strategies import FSDPStrategy -from torch.distributed.fsdp.wrap import always_wrap_policy -from torch.utils.data import DataLoader, IterableDataset - -from _fsdp_thunder import FSDPThunderStrategy -from thunder.tests.lit_gpt_model import GPT, Block, Config - - -model_name = "open_llama_3b" -learning_rate = 6e-4 -micro_batch_size = 2 -max_iters = 50 - - -def main( - compile: str = "eager", devices: int = 2, stage: str = "2", bucketing_strategy: Literal["NONE", "BLOCK"] = "NONE" -) -> None: - fsdp_type = {"2": "ZERO2", "3": "ZERO3"}[stage] - sharding_strategy = {"2": "SHARD_GRAD_OP", "3": "FULL_SHARD"}[stage] - auto_wrap_policy = always_wrap_policy if bucketing_strategy.lower() == "none" else {Block} - strategy = ( - FSDPThunderStrategy( - sharding_strategy=fsdp_type, - bucketing_strategy=bucketing_strategy, - executors=("sdpa", "torchcompile", "nvfuser", "torch"), - ) - if compile == "thunder" - else FSDPStrategy(auto_wrap_policy=auto_wrap_policy, sharding_strategy=sharding_strategy) - ) - - fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true") - fabric.launch() - - fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) - - config = Config.from_name(model_name) - fabric.print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - with fabric.init_module(empty_init=True): - og_model = model = GPT(config) - fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - - if compile == "inductor": - # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 - pattern = re.compile(".*Profiler function .* will be ignored") - logging.getLogger("torch._dynamo.variables.torch").addFilter( - lambda record: not pattern.search(record.getMessage()) - ) - - model = torch.compile(model) - elif compile == "thunder": - pass # fabric.setup does this - elif compile != "eager": - raise ValueError(compile) - - model = fabric.setup(model) - if compile == "thunder": - model.max_seq_length = og_model.max_seq_length - optimizer = torch.optim.SGD(model.parameters(), 
lr=learning_rate, weight_decay=1e-1, foreach=False) - optimizer = fabric.setup_optimizers(optimizer) - - train_data = DummyDataset(model.max_seq_length) - train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2) - train_dataloader = fabric.setup_dataloaders(train_dataloader) - - train(fabric, model, optimizer, train_dataloader) - fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -def train( - fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_dataloader: DataLoader -) -> None: - train_iter = iter(train_dataloader) - t0 = None - assert max_iters > 5 - for i in range(max_iters): - iter_t0 = time.perf_counter() - if i == 5: # warmup - t0 = iter_t0 - input_ids, targets = next(train_iter) - - logits = model(input_ids) - logits = logits.reshape(-1, logits.size(-1)) - targets = targets.reshape(-1) - loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) - fabric.backward(loss) - optimizer.step() - optimizer.zero_grad() - - loss_item = loss.item() # synchronization - t1 = time.perf_counter() - fabric.print(f"iter {i}: loss {loss_item :.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms") - fabric.print(f"Total time: {(t1 - t0):.2f}s") - - -class DummyDataset(IterableDataset): - def __init__(self, max_seq_length: int): - super().__init__() - self.max_seq_length = max_seq_length - - def __iter__(self): - t = self.max_seq_length - while True: - data = torch.randint(0, 100, (t + 1,), dtype=torch.int64) - x = data[:t] - y = data[1 : t + 1] - yield x, y - - -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - - from jsonargparse import CLI - - CLI(main) diff --git a/requirements/test.txt b/requirements/test.txt index a1402bd69b..c882272833 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -12,10 +12,10 @@ hypothesis ==6.99.10 # for test_ddp.py numpy # for test_ops.py einops # for test_einops.py lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5 -absl-py # for test_parametrized.py in examples/lit-gpt -pandas # for test_parametrized.py in examples/lit-gpt -xlsxwriter # for test_parametrized.py in examples/lit-gpt -jsonargparse # for benchmarking_litgpt.py in thunder/benchmarks +absl-py # thunder/benchmarks/test_benchmark_litgpt.py +pandas # thunder/benchmarks/test_benchmark_litgpt.py +xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py +jsonargparse # thunder/benchmarks/benchmark_litgpt.py # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation diff --git a/examples/lit-gpt/test_parametrized.py b/thunder/benchmarks/test_benchmark_litgpt.py similarity index 66% rename from examples/lit-gpt/test_parametrized.py rename to thunder/benchmarks/test_benchmark_litgpt.py index 5e658b6447..cb76e48221 100644 --- a/examples/lit-gpt/test_parametrized.py +++ b/thunder/benchmarks/test_benchmark_litgpt.py @@ -1,4 +1,4 @@ -''' +""" Script to run all lit-GPT models available as a parametrized test using abseil's unittest framework. Runs a parametrized product over all configs specified, compiler options, distributed modes etc. Uses environment variables to modify default behavior @@ -8,7 +8,7 @@ between each test. BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented. Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'. 
-''' +""" import torch from absl.testing import parameterized @@ -20,21 +20,18 @@ import pandas as pd from datetime import datetime + class Runner: - ''' + """ Benchmark Runner class to a) Launch the training benchmarking run, b) Store results from all tests, c) Compile results as xlsx file - ''' - - def __init__(self, - benchmark_file, - mid_benchmark_out, - output_format): + """ + def __init__(self, benchmark_file, mid_benchmark_out, output_format): self.dataframe_data = [] - self.json_file_path = '/tmp/benchmark_litgpt_data.json' + self.json_file_path = "/tmp/benchmark_litgpt_data.json" self.benchmark_file = benchmark_file self.mid_benchmark_out = mid_benchmark_out self.output_format = output_format @@ -44,40 +41,64 @@ def __enter__(self): def add_to_dataframe(self): if self.perf_metrics_dict: - if 'tokens_per_sec_per_gpu' not in self.perf_metrics_dict.keys(): #In case of OutofMemory error, this is already marked 'OOM' - self.perf_metrics_dict['tokens_per_sec_per_gpu'] = self.perf_metrics_dict['tokens_per_sec'] / self.perf_metrics_dict['Num GPUS'] + if ( + "tokens_per_sec_per_gpu" not in self.perf_metrics_dict.keys() + ): # In case of OutofMemory error, this is already marked 'OOM' + self.perf_metrics_dict["tokens_per_sec_per_gpu"] = ( + self.perf_metrics_dict["tokens_per_sec"] / self.perf_metrics_dict["Num GPUS"] + ) self.dataframe_data.append(self.perf_metrics_dict) def complete_dataframe(self, is_teardown): if not self.dataframe_data: # The benchmark probably failed return - #Called when tearing down the parametrized test - #This generates a summarized dataframe for each perf metric and saves as a xlsx file + # Called when tearing down the parametrized test + # This generates a summarized dataframe for each perf metric and saves as a xlsx file df = pd.DataFrame(self.dataframe_data) - df['Sharding Size'] = df['Sharding Size'].fillna('none') #Convert None Type to string so that pivot table can group. - index_list = ['model_name', 'Num GPUS', 'Seq Len', 'Micro BS', 'Global BS', 'GA', 'Distributed Mode', 'Sharding Size'] + df["Sharding Size"] = df["Sharding Size"].fillna( + "none" + ) # Convert None Type to string so that pivot table can group. 
+ index_list = [ + "model_name", + "Num GPUS", + "Seq Len", + "Micro BS", + "Global BS", + "GA", + "Distributed Mode", + "Sharding Size", + ] - self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='average_iter_time', aggfunc='first').reset_index() - self.tokens_per_sec_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec', aggfunc='first').reset_index() - self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index() - self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index() + self.iter_time_df = df.pivot_table( + index=index_list, columns="compiler", values="average_iter_time", aggfunc="first" + ).reset_index() + self.tokens_per_sec_df = df.pivot_table( + index=index_list, columns="compiler", values="tokens_per_sec", aggfunc="first" + ).reset_index() + self.tokens_per_sec_per_gpu_df = df.pivot_table( + index=index_list, columns="compiler", values="tokens_per_sec_per_gpu", aggfunc="first" + ).reset_index() + self.memory_used_GB_df = df.pivot_table( + index=index_list, columns="compiler", values="memory_used_GB", aggfunc="first" + ).reset_index() if self.output_format == "xlsx": - output_ext = {'xlsx': '.xlsx', }[self.output_format] + output_ext = { + "xlsx": ".xlsx", + }[self.output_format] if not is_teardown: - filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) + filename = "mid_output_parameterized_results" + str(output_ext) else: - current_time = datetime.now().strftime('%Y-%m-%d_%H-%M') + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M") filename = f"{current_time}_litgpt_benchmark" + str(output_ext) - filename = 'examples/lit-gpt/' + str(filename) - - with pd.ExcelWriter(filename, engine='xlsxwriter') as writer: - self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)') - self.tokens_per_sec_df.to_excel(writer, sheet_name='Tokens per sec') - self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name='Tokens per sec per GPU') - self.memory_used_GB_df.to_excel(writer, sheet_name='Memory allocated GB') - elif self.output_format == 'print': + + with pd.ExcelWriter(filename, engine="xlsxwriter") as writer: + self.iter_time_df.to_excel(writer, sheet_name="Average Iter Time (ms)") + self.tokens_per_sec_df.to_excel(writer, sheet_name="Tokens per sec") + self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name="Tokens per sec per GPU") + self.memory_used_GB_df.to_excel(writer, sheet_name="Memory allocated GB") + elif self.output_format == "print": print("\nAVERAGE ITERATION TIME (ms)") print(self.iter_time_df) print("\nTHROUGHPUT (tokens/s)") @@ -91,12 +112,24 @@ def run_benchmark(self, kwargs): command_list = [] for key, val in kwargs.items(): command_list.append("--" + str(key) + "=" + str(val)) - if kwargs['distributed_mode'] != 'none': + if kwargs["distributed_mode"] != "none": nproc_per_node = torch.cuda.device_count() - subprocess_cmd = ["torchrun", f"--nproc_per_node={nproc_per_node}", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + subprocess_cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--nnodes=1", + f"{self.benchmark_file}", + "--return_metrics_as_json=True", + f"--json_path={self.json_file_path}", + ] subprocess_cmd.extend(command_list) else: - subprocess_cmd = ["python", "{}".format(self.benchmark_file), 
"--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + subprocess_cmd = [ + "python", + f"{self.benchmark_file}", + "--return_metrics_as_json=True", + f"--json_path={self.json_file_path}", + ] subprocess_cmd.extend(command_list) print(f'Running {" ".join(subprocess_cmd)!r}') @@ -104,13 +137,13 @@ def run_benchmark(self, kwargs): self.perf_metrics_dict = {} if os.path.exists(self.json_file_path): - with open(self.json_file_path, 'r') as file: + with open(self.json_file_path) as file: self.perf_metrics_dict = json.load(file) # Cleanup after the benchmark finishes. It might have failed before creating this os.remove(self.json_file_path) if proc_output.returncode: - if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: + if "CUDA out of memory" in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: defaultdict_oom = defaultdict(lambda: "OOM") defaultdict_oom.update(self.perf_metrics_dict) self.perf_metrics_dict = defaultdict_oom @@ -124,26 +157,28 @@ def run_benchmark(self, kwargs): class Test(parameterized.TestCase): - @classmethod def setUpClass(cls): - super(Test, cls).setUpClass() + super().setUpClass() def get_installed_thunder_path(): import thunder + thunder_init = thunder.__file__ - thunder_benchmark_file = str(thunder_init).replace('__init__.py', 'benchmarks/benchmark_litgpt.py') + thunder_benchmark_file = str(thunder_init).replace("__init__.py", "benchmarks/benchmark_litgpt.py") return thunder_benchmark_file benchmark_file = os.getenv("BENCHMARK_FILE", get_installed_thunder_path()) mid_benchmark_out = bool(os.getenv("MID_BENCHMARK_OUT", 0)) - output_format = str(os.getenv("BENCHMARK_OUT_FORMAT", "xlsx")) # Can take none, print, xlsx as of 03/12 - cls.runner = Runner(benchmark_file=benchmark_file, mid_benchmark_out=mid_benchmark_out, output_format=output_format) + output_format = str(os.getenv("BENCHMARK_OUT_FORMAT", "xlsx")) # Can take none, print, xlsx as of 03/12 + cls.runner = Runner( + benchmark_file=benchmark_file, mid_benchmark_out=mid_benchmark_out, output_format=output_format + ) @classmethod def tearDownClass(cls): cls.runner.complete_dataframe(is_teardown=True) - super(Test, cls).tearDownClass() + super().tearDownClass() # @parameterized.product( # (dict(distributed_mode = "fsdp", shard_mode = "zero2"), @@ -184,16 +219,23 @@ def tearDownClass(cls): # ) @parameterized.product( - distributed_mode = ("fsdp", ), - shard_mode = ("zero2", ), - model_name = ("Llama-2-7b-hf", ), - micro_batch_size = (1, 4, ), - compile = ("eager", "inductor", "thunder", "thunder_inductor",) + distributed_mode=("fsdp",), + shard_mode=("zero2",), + model_name=("Llama-2-7b-hf",), + micro_batch_size=( + 1, + 4, + ), + compile=( + "eager", + "inductor", + "thunder", + "thunder_inductor", + ), ) - def test(self, **kwargs): - kwargs['nsys_enabled'] = False - kwargs['dynamic'] = False + kwargs["nsys_enabled"] = False + kwargs["dynamic"] = False self.__file__ = __file__ try: @@ -210,5 +252,6 @@ def test(self, **kwargs): else: self.fail(run_msg) -if __name__ == '__main__': + +if __name__ == "__main__": absltest.main()