diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml
index abf1b43138..0f21b03206 100644
--- a/recipes/configs/llama2/7B_lora_dpo.yaml
+++ b/recipes/configs/llama2/7B_lora_dpo.yaml
@@ -83,4 +83,7 @@ log_peak_memory_stats: True
 # Environment
 device: cuda
 dtype: bf16
+
+# Memory management
 enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
index 7543cb5d6f..c6d8d4bbba 100644
--- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
+++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -80,4 +80,7 @@ log_peak_memory_stats: True
 # Environment
 device: cuda
 dtype: bf16
+
+# Memory management
 enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
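
Both llama2 hunks above are config-only: they add the new `enable_activation_offloading` key, defaulting to False. Like any torchtune config key, it can also be flipped from the command line via the override syntax the configs themselves document; a hypothetical invocation (the recipe changes below reject offloading unless activation checkpointing is also True):

    tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device \
        enable_activation_offloading=True
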
diff --git a/recipes/configs/llama3_1/8B_lora_dpo.yaml b/recipes/configs/llama3_1/8B_lora_dpo.yaml
new file mode 100644
index 0000000000..6f94b7d09d
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_lora_dpo.yaml
@@ -0,0 +1,92 @@
+# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py
+# using a Llama 3.1 8B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+#   tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# For single device LoRA DPO alignment please use llama3_1/8B_lora_dpo_single_device
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+  apply_lora_to_mlp: True
+  apply_lora_to_output: False
+  lora_rank: 8  # higher increases accuracy and memory
+  lora_alpha: 16  # usually alpha=2*rank
+  lora_dropout: 0.0
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+  max_seq_len: null
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+  checkpoint_files: [
+    model-00001-of-00004.safetensors,
+    model-00002-of-00004.safetensors,
+    model-00003-of-00004.safetensors,
+    model-00004-of-00004.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+  model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+  _component_: torchtune.datasets.stack_exchange_paired_dataset
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  weight_decay: 0.05
+  lr: 5e-4
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+
+loss:
+  _component_: torchtune.rlhf.loss.DPOLoss
+  beta: 0.1
+  label_smoothing: 0
+
+# Training
+epochs: 1
+max_steps_per_epoch: 1000
+gradient_accumulation_steps: 8  # Use to increase virtual batch size
+compile: False  # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/lora_dpo_output/
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
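
For context on the `loss` block above (`beta: 0.1`, `label_smoothing: 0`): DPO penalizes the policy when its chosen-over-rejected log-ratio falls behind the reference model's. A minimal, self-contained sketch of that objective, assuming torchtune's `DPOLoss` follows the standard conservative-DPO formulation (an illustration of the math, not the library's code):

    import torch
    import torch.nn.functional as F

    def dpo_loss(
        policy_chosen_logps: torch.Tensor,    # log p_theta(chosen)
        policy_rejected_logps: torch.Tensor,  # log p_theta(rejected)
        ref_chosen_logps: torch.Tensor,       # log p_ref(chosen)
        ref_rejected_logps: torch.Tensor,     # log p_ref(rejected)
        beta: float = 0.1,
        label_smoothing: float = 0.0,
    ) -> torch.Tensor:
        # How much the policy prefers chosen over rejected, relative to the reference.
        logits = (policy_chosen_logps - policy_rejected_logps) - (
            ref_chosen_logps - ref_rejected_logps
        )
        # label_smoothing > 0 hedges against noisy preference labels;
        # at 0 this reduces to -logsigmoid(beta * logits).
        return (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )  # per-example losses; reduce with .mean()
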
diff --git a/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
new file mode 100644
index 0000000000..638a4efe12
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
@@ -0,0 +1,89 @@
+# Config for single device LoRA DPO alignment in lora_dpo_single_device.py
+# using a Llama 3.1 8B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on a single device, run the following command from root:
+#   tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works only for training on single device.
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+  apply_lora_to_mlp: True
+  apply_lora_to_output: False
+  lora_rank: 8  # higher increases accuracy and memory
+  lora_alpha: 16  # usually alpha=2*rank
+  lora_dropout: 0.0
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+  max_seq_len: null
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+  checkpoint_files: [
+    model-00001-of-00004.safetensors,
+    model-00002-of-00004.safetensors,
+    model-00003-of-00004.safetensors,
+    model-00004-of-00004.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+  model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+  _component_: torchtune.datasets.stack_exchange_paired_dataset
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  weight_decay: 0.05
+  lr: 5e-4
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+
+loss:
+  _component_: torchtune.rlhf.loss.DPOLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: 1000
+gradient_accumulation_steps: 8  # Use to increase virtual batch size
+compile: False  # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/lora_dpo_output/
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
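
One sizing note on the two new configs: with `batch_size: 4` and `gradient_accumulation_steps: 8`, each optimizer step sees 32 preference pairs per device, and each forward pass processes twice the batch size in sequences, since chosen and rejected completions are concatenated (see `concatenated_forward` in the recipe diffs below). Rough arithmetic, assuming one pair per batch element:

    batch_size = 4                   # pairs per micro-batch (from the config)
    gradient_accumulation_steps = 8  # micro-batches per optimizer step
    pairs_per_step = batch_size * gradient_accumulation_steps  # 32 pairs
    rows_per_forward = 2 * batch_size                          # 8 sequences
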
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index d1b54764b2..52498a9625 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -59,6 +59,18 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
         come at the cost of training performance. In most cases training can slow down quite a bit as
         a result of this activation recomputation.
 
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+          flag. Activation offloading is a technique similar to activation checkpointing that helps
+          reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+          checkpointing drops the activation in the forward to recompute it later in the backward,
+          activation offloading will drop the activation in the forward to the CPU and bring it
+          back during the backward pass. As always, there is a tradeoff: these savings in memory can
+          come at the cost of training performance and CPU resources. To recover some runtime cost,
+          we've added an option to enable offloading on a different stream to permit overlapping with
+          the computation. This option is currently only available on PyTorch 2.5 or later and will
+          be enabled by default if an acceptable torch version is found. Activation offloading can be
+          used in conjunction with activation checkpointing.
+
         - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the
           ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in
           bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -109,6 +121,8 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
         ValueError: If ``dtype`` is set to fp16.
         ValueError: If world_size is 1
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
     """
 
     def __init__(self, cfg: DictConfig) -> None:
@@ -135,8 +149,28 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        # training attributes
-        self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if self._device.type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif self._enable_activation_checkpointing:
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
 
         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
@@ -232,6 +266,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
-            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
             custom_sharded_layers=cfg.get("custom_sharded_layers", None),
             fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
             reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
@@ -294,6 +329,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
@@ -397,6 +433,12 @@ def _setup_model(
             lora_unexpected=lora_unexpected,
         )
-        # Ensure no params and buffers are on meta device
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
         utils.log_rank_zero(
             log,
@@ -582,7 +624,8 @@ def concatenated_forward(
         # formed by concatenating an equal number of "chosen" and "rejected".
         len_chosen = concatenated_input_ids.shape[0] // 2
 
-        all_logits = model(concatenated_input_ids)
+        with self.activations_handling_ctx:
+            all_logits = model(concatenated_input_ids)
 
         all_log_probs = rlhf.get_batch_log_probs(
             all_logits,
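
The recipe change is two-part: build `activations_handling_ctx` once in `_setup_model`, then enter it around the model's forward. The recipes enter it unconditionally, so `training.get_act_offloading_ctx_manager` is presumably a no-op context when the flag is False. A toy sketch of that build-once/wrap-every-forward pattern (the offloading machinery itself is elided; only the control flow is shown):

    import contextlib

    import torch
    import torch.nn as nn

    def get_offloading_ctx(enable: bool):
        # Stand-in for training.get_act_offloading_ctx_manager: assumed to return
        # a real offloading context when enabled, and a null context otherwise.
        return contextlib.nullcontext()

    model = nn.Linear(8, 8)
    activations_handling_ctx = get_offloading_ctx(enable=False)  # built once at setup

    with activations_handling_ctx:  # wraps every forward, as in concatenated_forward
        out = model(torch.randn(2, 8))
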
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index 53b3b67be5..a527150aae 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -44,9 +44,11 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
 
     This recipe supports:
         - Activation checkpointing. This is enabled by default but is configurable.
+        - Activation offloading - this is disabled by default and should only be used alongside
+          activation checkpointing.
         - Full bf16 training for supported HW architectures. We currently check bf16 support via
-        the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
-        setting `dtype=bf16` in configuration.
+          the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
+          setting `dtype=bf16` in configuration.
         - Checkpointing: of LoRA adapter parameters and their optimizer states. When resuming
           from a checkpoint, the adapter parameters are loaded from the checkpoint along
           with the base model weights. Note that intra-epoch resumption is not supported.
@@ -74,6 +76,8 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
     """
 
@@ -101,6 +105,29 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if self._device.type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif self._enable_activation_checkpointing:
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
+
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
@@ -190,6 +217,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
-            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
             compile_model=cfg.compile,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
@@ -251,6 +279,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -289,6 +318,11 @@ def _setup_model(
             lora_unexpected=lora_unexpected,
         )
 
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
         log.info(f"Model is initialized with precision {self._dtype}.")
 
         # Compile model, if enabled.
@@ -443,7 +477,8 @@ def concatenated_forward(
         # formed by concatenating an equal number of "chosen" and "rejected".
         len_chosen = concatenated_input_ids.shape[0] // 2
 
-        all_logits = model(concatenated_input_ids)
+        with self.activations_handling_ctx:
+            all_logits = model(concatenated_input_ids)
 
         all_log_probs = rlhf.get_batch_log_probs(
             all_logits,
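
As in the distributed recipe, only the policy model's forward is wrapped, and the batch it sees is chosen completions stacked on top of rejected ones, which is why `len_chosen` is just half the leading dimension. A self-contained sketch of that split (the real recipe computes per-sequence log-probs via `rlhf.get_batch_log_probs` first):

    import torch

    def split_chosen_rejected(
        all_log_probs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Batch layout: [chosen_0 .. chosen_{B-1}, rejected_0 .. rejected_{B-1}],
        # so halving the leading dim recovers the two groups.
        len_chosen = all_log_probs.shape[0] // 2
        return all_log_probs[:len_chosen], all_log_probs[len_chosen:]

    chosen, rejected = split_chosen_rejected(torch.randn(8))  # 4 pairs

diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index daf84be1b7..2a4bf25a8b 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -302,6 +302,10 @@ class Recipe:
             name="llama2/7B_lora_dpo_single_device",
             file_path="llama2/7B_lora_dpo_single_device.yaml",
         ),
+        Config(
+            name="llama3_1/8B_lora_dpo_single_device",
+            file_path="llama3_1/8B_lora_dpo_single_device.yaml",
+        ),
     ],
     supports_distributed=False,
 ),
@@ -313,6 +317,10 @@ class Recipe:
             name="llama2/7B_lora_dpo",
             file_path="llama2/7B_lora_dpo.yaml",
         ),
+        Config(
+            name="llama3_1/8B_lora_dpo",
+            file_path="llama3_1/8B_lora_dpo.yaml",
+        ),
     ],
     supports_distributed=True,
 ),

Once these registry entries land, the new configs should show up through the standard CLI listing; a quick check (exact output format aside):

    tune ls | grep 8B_lora_dpo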