diff --git a/recipes/full_dpo_distributed.py b/recipes/full_dpo_distributed.py
new file mode 100644
index 0000000000..0d645d27f9
--- /dev/null
+++ b/recipes/full_dpo_distributed.py
@@ -0,0 +1,970 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+import os
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import destroy_process_group, init_process_group
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, rlhf, training, utils
+from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
+from torchtune.datasets import ConcatDataset
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.rlhf.loss import SimPOLoss
+
+from tqdm import tqdm
+
+
+log = utils.get_logger("DEBUG")
+
+
+class FullDPORecipeDistributed(FTRecipeInterface):
+    """
+    Full DPO finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
+    distributed training and can be run on a single node (1 to 8 GPUs).
+
+    Features:
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
+
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass. This is especially
+            helpful for larger batch sizes when you're memory constrained. But these savings in memory
+            come at the cost of training performance. In most cases training can slow down quite a bit as
+            a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activation checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activation offloading will drop the activation in the forward to the CPU and bring it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch 2.5 or later and will
+            be enabled by default if an acceptable torch version is found. Activation offloading can be
+            used in conjunction with activation checkpointing.
+
+        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+            most cases this should halve the memory footprint of full precision (fp32) training, without
+            loss in model quality (will depend on the model, training data and other settings). For
+            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+            precision are currently not supported.
+
+        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+            controlled using the ``gradient_accumulation_steps`` flag.
+
+            Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+            total batch size of 64.
+
+            Gradient accumulation is especially useful when you are memory constrained. In this case,
+            accumulating gradients might give you better training speed than enabling activation
+            checkpointing.
+
+        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+            training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
+            only saved at the end of a given epoch and used in case of resuming training.
+
+            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+            currently not supported.
+
+            For more details on the checkpointer, please take a look at
+            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
+
+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
+
+        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+    The following losses are supported in this recipe:
+        - :class:`~torchtune.modules.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
+        - :class:`~torchtune.modules.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
+        - :class:`~torchtune.modules.rlhf.loss.IPO`: Identity Preference Optimization (IPO).
+        - :class:`~torchtune.modules.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
+
+
+    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+    has example commands for how to kick-off training.
+
+    Args:
+        cfg (DictConfig): OmegaConf object parsed from yaml file
+
+    Raises:
+        ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+
+        self._device = utils.get_device(device=cfg.device)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+        if self._dtype == torch.float16:
+            raise ValueError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        if (
+            cfg.get("fsdp_cpu_offload", False)
+            and cfg.optimizer.get("fused", False)
+            and not utils.torch_version_ge("2.4.0")
+        ):
+            raise RuntimeError(
+                "Using fused optimizer on CPU is only supported in PyTorch nightly."
+            )
+
+        # logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+        if self._log_peak_memory_stats and self._device.type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
+        # _is_rank_zero is used primarily for logging. In the future, the logger
+        # should directly take care of this
+        _, rank = training.get_world_size_and_rank()
+        self._is_rank_zero = rank == 0
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd. "
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd. "
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
+
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if self._device.type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif (
+            self._enable_activation_checkpointing
+            and cfg.checkpointer.model_type != "LLAMA3_VISION"
+        ):
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
+
+        # These attributes constitute the recipe state and are updated by ``load_checkpoint``
+        # when ``resume_from_checkpoint`` is ``True``
+        self.seed = training.set_seed(seed=cfg.seed)
+        self.epochs_run = 0
+        self.total_epochs = cfg.epochs
+        self.max_steps_per_epoch = cfg.max_steps_per_epoch
+        self.global_step = 0
+
+
+    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
+        """
+        Extract the checkpoint state from file and validate. If resume_from_checkpoint
+        is True, this also includes the recipe state.
+        """
+        self._checkpointer = config.instantiate(
+            cfg_checkpointer,
+            resume_from_checkpoint=self._resume_from_checkpoint,
+        )
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(checkpoint_dict)
+        return checkpoint_dict
+
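+    # The reference checkpointer below only supplies the frozen reference-model weights
+    # used by reference-based losses (DPO/RSO/IPO); it never carries resumable recipe state.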
+ """ + _ref_checkpointer = config.instantiate(cfg_ref_checkpointer) + checkpoint_dict = _ref_checkpointer.load_checkpoint() + return checkpoint_dict[training.MODEL_KEY] + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + log.info("_metric_logger is initialized.") + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + self._compile = cfg.get("compile", False) + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + log.info("Loading reference model") + self._ref_model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=False, + model_state_dict=self.load_ref_states(cfg.ref_checkpointer), + ) + log.info("Done with loading reference model") + self._ref_model.eval() + for p in self._ref_model.parameters(): + p.requires_grad = False + + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, 
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # set up after all of these are initialized
+        self._sampler, self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+        )
+
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader, the max_steps_per_epoch param set by the user and the
+        # gradient_accumulation_steps param. This value is used for logging and tracking
+        # training state. The computation should happen after the dataloader has been set up
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.global_step = self.epochs_run * self._steps_per_epoch
+        # Learning rate scheduler can only be set up after number of steps
+        # has been computed
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.lr_scheduler,
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if the profiler is not
+            enabled, so that the instrumented training loop does not need to be changed when profiling is disabled.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+            profiler:
+              enabled: bool
+
+              #Output directory of trace artifacts
+              output_dir: str
+
+              #`torch.profiler.ProfilerActivity` types to trace
+              cpu: bool
+              cuda: bool
+
+              #Trace options
+              profile_memory: bool
+              with_stack: bool
+              record_shapes: bool
+              with_flops: bool
+
+              # `torch.profiler.schedule` options:
+              # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+              wait_steps: int
+              warmup_steps: int
+              active_steps: int
+              num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        if self._is_rank_zero:
+            log.info(f" Profiler config after instantiation: {profiler_cfg}")
+
+        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+        if profiler_cfg["enabled"]:
+            self.profiler_wait_steps = profiler_cfg["wait_steps"]
+            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+            self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
+    def _setup_model(
+        self,
+        cfg_model: DictConfig,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        fsdp_cpu_offload: bool,
+        reshard_after_forward: bool,
+        model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
+        ac_mode: Optional[str] = None,
+        ac_option: Optional[int] = None,
+    ) -> nn.Module:
+        """
+        Model initialization has some important considerations:
+            a. To minimize GPU peak memory, we initialize the model on meta device with
+               the right dtype
+            b. All ranks call ``load_state_dict`` without peaking CPU RAM usage since
+               full state dicts are loaded with ``torch.load(mmap=True)``
+        """
+
+        if self._is_rank_zero:
+            log.info(
+                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
+            )
+            init_start = time.perf_counter()
+
+        with training.set_default_dtype(self._dtype), torch.device("meta"):
+            model = config.instantiate(cfg_model)
+
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
+        # We currently have two versions of activation checkpointing in this recipe
+        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
+        # the older version of AC and this behavior is unchanged
+        # ac_mode and ac_option together control selective AC. This is only enabled
+        # when these are set AND ``enable_activation_checkpointing`` is set to False
+        # We'll clean this up as soon as testing of AC is complete
+        if (not enable_activation_checkpointing) and (ac_mode is not None):
+            apply_selective_activation_checkpointing(
+                model,
+                ac_mode,
+                ac_option,
+            )
+
+        # original activation checkpointing (full) - flip the condition above
+        if enable_activation_checkpointing and ac_mode is None:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        # For FSDP sharding
+        fsdp_shard_conditions = [
+            partial(
+                training.get_shard_conditions,
+                names_to_match=custom_sharded_layers,
+            )
+        ]
+        training.shard_model(
+            model=model,
+            shard_conditions=fsdp_shard_conditions,
+            cpu_offload=fsdp_cpu_offload,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+        with training.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                # RoPE is not covered in state dict
+                if hasattr(m, "rope_init"):
+                    m.rope_init()
+
+        # This method will convert the full model state dict into a sharded state
+        # dict and load into the model
+        training.load_from_full_model_state_dict(
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure no params and buffers are on meta device
+        training.validate_no_params_on_meta_device(model)
+
+        if self._is_rank_zero:
+            log.info(
+                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
+            )
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
+
+        # synchronize before training begins
+        torch.distributed.barrier()
+
+        return model
+
+    def _setup_optimizer(
+        self,
+        cfg_optimizer: DictConfig,
+        optimizer_in_bwd: bool = False,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optional[Optimizer]:
+        if optimizer_in_bwd:
+            # Maintain a dict of optims for every parameter.
+            optim_dict = {
+                param: config.instantiate(cfg_optimizer, [param])
+                for param in self._model.parameters()
+            }
+
+            # Register optimizer step hooks on the model to run optimizer in backward.
+            training.register_optim_in_bwd_hooks(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
+            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
+            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
+            # did not use optimizer in backward.
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._optim_ckpt_wrapper.state_dict()[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints. "
+                            "Please make sure the run being restored from was using the in-backward optimizer."
+                        ) from e
+            if self._is_rank_zero:
+                log.info("In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )
+
+            if self._is_rank_zero:
+                log.info("Optimizer and loss are initialized.")
+            return optimizer
+
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: DictConfig,
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optimizer:
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            self._optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+        if self._is_rank_zero:
+            log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
+
+    def _setup_data(
+        self,
+        cfg_dataset: DictConfig,
+        shuffle: bool,
+        batch_size: int,
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        """
+        All data related setup happens here. Currently this recipe only supports the
+        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+        iterable datasets and streaming datasets are not supported.
+        """
+        world_size, rank = training.get_world_size_and_rank()
+
+        if isinstance(cfg_dataset, ListConfig):
+            datasets = [
+                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                for single_cfg_dataset in cfg_dataset
+            ]
+            ds = ConcatDataset(datasets=datasets)
+        else:
+            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+
+        sampler = DistributedSampler(
+            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+        )
+
+        dataloader = DataLoader(
+            dataset=ds,
+            batch_size=batch_size,
+            sampler=sampler,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=partial(
+                padded_collate_dpo,
+                padding_idx=self._tokenizer.pad_id,
+                ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
+            ),
+        )
+
+        if self._is_rank_zero:
+            log.info("Dataset and Sampler are initialized.")
+
+        return sampler, dataloader
+
+    def save_checkpoint(
+        self,
+        epoch: int,
+    ) -> None:
+        """
+        Checkpoint the state of the recipe. The constructed checkpoint state dict
+        contains the following information:
+        - Model weights with key training.MODEL_KEY
+        - Relevant recipe state if training is not complete
+
+        Checkpointer will save the model weights and recipe state in
+        different checkpoint files. To correctly resume training from an intermediate checkpoint,
+        the model weights and recipe state must be provided.
+        """
+        # final dict passed onto the checkpointer
+        checkpoint_dict = {}
+
+        intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+        if self._is_rank_zero:
+            log.info(
+                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+            )
+            start = time.perf_counter()
+
+        # To prevent GPU memory from spiking during checkpoint save,
+        # we consolidate the full model and optim state dicts on CPU for rank 0
+        cpu_state_dict = training.get_full_model_state_dict(
+            self._model,
+            self._is_rank_zero,
+            device=self._device,
+        )
+
+        if self._is_rank_zero:
+            log.info(
+                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
+            )
+
+        if intermediate_checkpoint:
+            start = time.perf_counter()
+            if self._is_rank_zero:
+                log.info("Getting optimizer state dict...")
+            if not self._optimizer_in_bwd:
+                opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._optimizer,
+                    self._is_rank_zero,
+                    device=self._device,
+                )
+            else:
+                opt_state_dict = {}
+                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
+                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
+                        opt, self._is_rank_zero, device=self._device
+                    )
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+                )
+        else:
+            opt_state_dict = None
+
+        # Now that we have the model and opt state dict, create the actual checkpoint dict
+        # to be sent to the checkpointer and ultimately written to file
+
+        if self._is_rank_zero:
+            start = time.perf_counter()
+            checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})
+
+            # if training is in-progress, checkpoint the optimizer state and recipe state
+            # as well.
+            if intermediate_checkpoint:
+                checkpoint_dict.update(
+                    {
+                        training.OPT_KEY: opt_state_dict,
+                        training.SEED_KEY: self.seed,
+                        training.EPOCHS_KEY: self.epochs_run,
+                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
+                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+                    }
+                )
+
+            self._checkpointer.save_checkpoint(
+                checkpoint_dict,
+                epoch=epoch,
+                intermediate_checkpoint=intermediate_checkpoint,
+            )
+            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+        torch.distributed.barrier()
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Run forward pass of the model with chosen and rejected samples concatenated.
+
+        Args:
+            model (nn.Module): The model to be used for the forward pass.
+            batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
+
+        Returns:
+            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
+        """
+        concatenated_input_ids, concatenated_labels = batch
+        concatenated_input_ids = concatenated_input_ids.to(self._device)
+        concatenated_labels = concatenated_labels.to(self._device)
+
+        # formed by concatenating an equal number of "chosen" and "rejected".
+        len_chosen = concatenated_input_ids.shape[0] // 2
+
+        all_logits = model(concatenated_input_ids)
+
+        chosen_log_probs = rlhf.get_batch_log_probs(
+            all_logits[:len_chosen],
+            concatenated_labels[:len_chosen],
+            # see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss`
+            return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
+        )
+
+        rejected_log_probs = rlhf.get_batch_log_probs(
+            all_logits[len_chosen:],
+            concatenated_labels[len_chosen:],
+            # see :class:`~torchtune.modules.rlhf.loss.dpo.SimPOLoss`
+            return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
+        )
+
+        chosen_logits = all_logits[:len_chosen]
+        rejected_logits = all_logits[len_chosen:]
+
+        return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)
+
+    def train(self) -> None:
+        """
+        The core training loop. Supports training on subsets of the dataset using the
+        ``max_steps_per_epoch``.
+        """
+        # clean up before training begins
+        training.cleanup_before_training()
+
+        _, rank = training.get_world_size_and_rank()
+
+        # zero out the gradients before starting training
+        self._optimizer.zero_grad()
+
+        # Initialize tokens count and running loss (for grad accumulation)
+        t0 = time.perf_counter()
+        running_loss = 0
+        num_tokens = 0
+
+        self._profiler.start()
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            for idx, batch in enumerate(self._dataloader):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                # batch is input_ids, labels
+                num_tokens += batch[0].numel()
+                (
+                    policy_chosen_log_probs,
+                    policy_rejected_log_probs,
+                    policy_chosen_logits,
+                    policy_rejected_logits,
+                ) = self.concatenated_forward(self._model, batch)
+
+                policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
+                policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
+
+                # deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
+                del policy_chosen_logits, policy_rejected_logits
+
+                if isinstance(self._loss_fn, SimPOLoss):
+                    loss, chosen_rewards, rejected_rewards = self._loss_fn(
+                        policy_chosen_log_probs, policy_rejected_log_probs
+                    )
+                else:
+                    # reference based losses (e.g. DPO) explicitly regularize the objective fn based on
+                    # the reference model's output - reference-free losses (such as SimPO) don't require this.
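+                    # The reference log probs come from the frozen reference model and are
+                    # computed under no_grad: they only provide the fixed baseline for the loss.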
+                    with torch.no_grad():
+                        (
+                            reference_chosen_log_probs,
+                            reference_rejected_log_probs,
+                            _,
+                            _,
+                        ) = self.concatenated_forward(self._ref_model, batch)
+                    assert not reference_chosen_log_probs.requires_grad
+                    assert not reference_rejected_log_probs.requires_grad
+                    loss, chosen_rewards, rejected_rewards = self._loss_fn(
+                        policy_chosen_log_probs,
+                        policy_rejected_log_probs,
+                        reference_chosen_log_probs,
+                        reference_rejected_log_probs,
+                    )
+                loss = loss.mean()
+                reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+                loss = loss / self._gradient_accumulation_steps
+                running_loss += loss
+                loss.backward()
+
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    self._optimizer.step()
+                    self._optimizer.zero_grad(set_to_none=True)
+                    self._lr_scheduler.step()
+
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    loss_to_log = running_loss.item()
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    # Log per-step metrics
+                    if (
+                        self.global_step % self._log_every_n_steps == 0
+                        and self._is_rank_zero
+                    ):
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": self._optimizer.param_groups[0]["lr"],
+                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                            "rewards/chosen": chosen_rewards.mean().cpu(),
+                            "rewards/rejected": rejected_rewards.mean().cpu(),
+                            "rewards/accuracies": reward_accuracies.mean().cpu(),
+                            "rewards/margins": (chosen_rewards - rejected_rewards)
+                            .mean()
+                            .cpu(),
+                            "log_probs/rejected": policy_rejected_log_probs.detach()
+                            .mean()
+                            .cpu(),
+                            "log_probs/chosen": policy_chosen_log_probs.detach()
+                            .mean()
+                            .cpu(),
+                            "logits/rejected": policy_rejected_logits_mean.cpu(),
+                            "logits/chosen": policy_chosen_logits_mean.cpu(),
+                        }
+                        if self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                # Step profiler
+                # Note that this is called within gradient accumulation block, hence
+                # will include multiple forward / backward passes if gradient accumulation > 1
+                self._profiler.step()
+
+            self.epochs_run += 1
+            self.save_checkpoint(epoch=curr_epoch)
+
+        self._profiler.stop()
+
+    def cleanup(self) -> None:
+        if self._is_rank_zero:
+            self._metric_logger.close()
+        destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    if not training.is_distributed():
+        raise RuntimeError(
+            "Distributed finetune recipe should be run via a distributed launcher. "
+            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
+        )
+
+    if cfg.get("fsdp_cpu_offload", False):
+        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+        # speed up when benchmarking fused AdamW on CPU
+        training.set_torch_num_threads()
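+    # Use the gloo backend for CPU-only runs; NCCL handles GPU collectives.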
This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="FullDPORecipeDistributed", cfg=cfg) + + recipe = FullDPORecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index cdb1d45f01..25a3e1e728 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -128,6 +128,15 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="full_dpo_distributed", + file_path="full_dpo_distributed.py", + configs=[ + Config(name="llama3/8B_full_dpo", file_path="llama3/8B_full_dpo.yaml"), + Config(name="llama3_1/8B_full_dpo", file_path="llama3_1/8B_full_dpo.yaml"), + ], + supports_distributed=True, + ), Recipe( name="lora_finetune_single_device", file_path="lora_finetune_single_device.py",