Commit
DPO Activation Offloading (#2087)
SalmanMohammadi authored Nov 28, 2024
1 parent 67130d9 commit 5e5a349
Showing 7 changed files with 279 additions and 6 deletions.
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_lora_dpo.yaml
@@ -83,4 +83,7 @@ log_peak_memory_stats: True
# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
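Both flags can also be toggled per run via torchtune's key=value command-line overrides, without editing the YAML. A hypothetical invocation for the distributed config above (flag names as defined in this commit, config name following the path convention):

tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama2/7B_lora_dpo enable_activation_checkpointing=True enable_activation_offloading=True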
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -80,4 +80,7 @@ log_peak_memory_stats: True
# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
92 changes: 92 additions & 0 deletions recipes/configs/llama3_1/8B_lora_dpo.yaml
@@ -0,0 +1,92 @@
# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py
# using a Llama 3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA DPO alignment please use llama3_1/8B_lora_dpo_single_device

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 5e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8  # Use to increase virtual batch size
compile: False  # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/lora_dpo_output/
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
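For reference, the loss stanza in the config above selects torchtune's DPOLoss with beta=0.1 and label_smoothing=0. Assuming it implements the standard DPO objective (Rafailov et al., 2023) with conservative-DPO-style label smoothing, the per-pair loss is:

\mathcal{L}_{\mathrm{DPO}} = -(1-\epsilon)\,\log \sigma(\beta h) - \epsilon\,\log \sigma(-\beta h),
\qquad
h = \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)},

where y_w and y_l are the chosen and rejected completions, \epsilon is label_smoothing (0 recovers vanilla DPO), and \beta scales the implicit KL penalty toward the reference policy.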
89 changes: 89 additions & 0 deletions recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
@@ -0,0 +1,89 @@
# Config for single device LoRA DPO alignment in lora_dpo_single_device.py
# using a Llama 3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on a single device, run the following command from root:
# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only when training on a single device.

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8  # higher increases accuracy and memory
  lora_alpha: 16  # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 5e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.rlhf.loss.DPOLoss

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8  # Use to increase virtual batch size
compile: False  # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/lora_dpo_output/
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
49 changes: 46 additions & 3 deletions recipes/lora_dpo_distributed.py
@@ -59,6 +59,18 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
          come at the cost of training performance. In most cases training can slow down quite a bit as
          a result of this activation recomputation.
        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
          flag. Activation offloading is a technique similar to activation checkpointing that helps
          reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
          checkpointing drops the activation in the forward pass to recompute it later in the backward pass,
          activation offloading moves the activation to the CPU in the forward pass and brings it
          back to the GPU during the backward pass. As always, there is a tradeoff: these memory savings
          can come at the cost of training performance and CPU resources. To recover some runtime cost,
          we've added an option to offload on a separate stream so the copies can overlap with
          computation. This option is currently only available on PyTorch 2.5 or later and will
          be enabled by default if an acceptable torch version is found. Activation offloading can be
          used in conjunction with activation checkpointing. (A minimal sketch of the underlying
          mechanism follows this docstring.)
        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
          flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
          most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -109,6 +121,8 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
        ValueError: If ``dtype`` is set to fp16.
        ValueError: If world_size is 1
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
    """

    def __init__(self, cfg: DictConfig) -> None:
@@ -135,8 +149,28 @@ def __init__(self, cfg: DictConfig) -> None:
        )
        self._log_peak_memory_stats = False

        # training attributes
        self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
            "enable_activation_checkpointing", False
        )
        self._enable_activation_offloading = cfg.get(
            "enable_activation_offloading", False
        )
        if self._enable_activation_offloading:
            if self._device.type != "cuda":
                raise RuntimeError(
                    "enable_activation_offloading should only be True when training on CUDA"
                )
            if not self._enable_activation_checkpointing:
                raise RuntimeError(
                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
                )
        elif self._enable_activation_checkpointing:
            utils.log_rank_zero(
                log,
                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
                "Enabling activation offloading should reduce memory further.",
            )

        # These attributes constitute the recipe state and are updated by ``load_checkpoint``
        # when ``resume_from_checkpoint`` is ``True``
@@ -232,6 +266,7 @@ def setup(self, cfg: DictConfig) -> None:
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
@@ -294,6 +329,7 @@ def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        fsdp_cpu_offload: bool,
        reshard_after_forward: bool,
        base_model_state_dict: Dict[str, Any],
@@ -397,6 +433,12 @@ def _setup_model(
            lora_unexpected=lora_unexpected,
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        # Ensure no params and buffers are on meta device
        training.validate_no_params_on_meta_device(model)
        utils.log_rank_zero(
            log,
@@ -582,7 +624,8 @@ def concatenated_forward(
        # formed by concatenating an equal number of "chosen" and "rejected".
        len_chosen = concatenated_input_ids.shape[0] // 2

        all_logits = model(concatenated_input_ids)
        with self.activations_handling_ctx:
            all_logits = model(concatenated_input_ids)

        all_log_probs = rlhf.get_batch_log_probs(
            all_logits,
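Note that concatenated_forward wraps only the forward call: pack hooks fire while activations are being saved in the forward pass, and the backward pass restores them regardless of any active context. For the call site to stay unconditional, get_act_offloading_ctx_manager presumably returns a no-op context when offloading is disabled; a hypothetical stand-in (the real helper takes the model as an argument and uses torchtune's own offloading type):

import contextlib

from torch.autograd.graph import save_on_cpu

def make_offload_ctx(enable_activation_offloading: bool):
    # Hypothetical stand-in for training.get_act_offloading_ctx_manager:
    # return a real offloading context when enabled and a no-op otherwise,
    # so recipes can always write `with self.activations_handling_ctx:`.
    if enable_activation_offloading:
        return save_on_cpu(pin_memory=True)
    return contextlib.nullcontext()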
41 changes: 38 additions & 3 deletions recipes/lora_dpo_single_device.py
@@ -44,9 +44,11 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
    This recipe supports:
        - Activation checkpointing. This is enabled by default but is configurable.
        - Activation offloading - this is disabled by default and should only be used alongside
          activation checkpointing.
        - Full bf16 training for supported HW architectures. We currently check bf16 support via
          the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
          setting `dtype=bf16` in configuration.
        - Checkpointing of LoRA adapter parameters and their optimizer states. When resuming
          from a checkpoint, the adapter parameters are loaded from the checkpoint along
          with the base model weights. Note that intra-epoch resumption is not supported.
@@ -74,6 +76,8 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
    """

@@ -101,6 +105,29 @@ def __init__(self, cfg: DictConfig) -> None:
        )
        self._log_peak_memory_stats = False

        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
            "enable_activation_checkpointing", False
        )
        self._enable_activation_offloading = cfg.get(
            "enable_activation_offloading", False
        )
        if self._enable_activation_offloading:
            if self._device.type != "cuda":
                raise RuntimeError(
                    "enable_activation_offloading should only be True when training on CUDA"
                )
            if not self._enable_activation_checkpointing:
                raise RuntimeError(
                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
                )
        elif self._enable_activation_checkpointing:
            utils.log_rank_zero(
                log,
                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
                "Enabling activation offloading should reduce memory further.",
            )

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = training.set_seed(seed=cfg.seed)
@@ -190,6 +217,7 @@ def setup(self, cfg: DictConfig) -> None:
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            compile_model=cfg.compile,
            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
            lora_weights_state_dict=(
@@ -251,6 +279,7 @@ def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        compile_model: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -289,6 +318,11 @@ def _setup_model(
            lora_unexpected=lora_unexpected,
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        log.info(f"Model is initialized with precision {self._dtype}.")

        # Compile model, if enabled.
@@ -443,7 +477,8 @@ def concatenated_forward(
        # formed by concatenating an equal number of "chosen" and "rejected".
        len_chosen = concatenated_input_ids.shape[0] // 2

        all_logits = model(concatenated_input_ids)
        with self.activations_handling_ctx:
            all_logits = model(concatenated_input_ids)

        all_log_probs = rlhf.get_batch_log_probs(
            all_logits,
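Both recipes then reduce the concatenated logits to a single log-probability per sequence via rlhf.get_batch_log_probs, which this diff truncates. Conceptually it is the usual shift-gather-mask reduction; a sketch under that assumption (the real helper's signature and masking details may differ):

import torch
import torch.nn.functional as F

def batch_log_probs(
    logits: torch.Tensor,  # (batch, seq_len, vocab)
    labels: torch.Tensor,  # (batch, seq_len)
    ignore_index: int = -100,
) -> torch.Tensor:
    # Shift so that tokens before position t predict the token at position t.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    mask = labels != ignore_index
    labels[labels == ignore_index] = 0  # dummy index; zeroed out by the mask below
    per_token = torch.gather(
        F.log_softmax(logits, dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return (per_token * mask).sum(-1)  # (batch,) summed log-probs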