Skip to content

Commit

Permalink
Support Early Exit Loss and/or Layer Dropout (#1076)
Browse files Browse the repository at this point in the history
Co-authored-by: ebsmothers <[email protected]>
  • Loading branch information
mostafaelhoushi and ebsmothers authored Dec 6, 2024
1 parent f799211 commit f8563dd
Show file tree
Hide file tree
Showing 12 changed files with 2,395 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/source/api_ref_modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Modeling Components and Building Blocks
TransformerCrossAttentionLayer
TransformerDecoder
VisionTransformer
LayerDropout
prepare_layer_dropout

Losses
------
Expand Down
137 changes: 137 additions & 0 deletions recipes/dev/7B_full_early_exit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Config for multi-device full finetuning with early exit loss and/or layer dropout
# in dev/early_exit_finetune_distributed.py using a Llama2 7B model on a small TOPv2
# instruction set.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-2-7b-hf --output-dir /tmp/Llama-2-7b-hf --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# To reproduce experiments of various papers that use early exit loss and/or layer dropout:
# - LayerSkip (https://arxiv.org/abs/2404.16710) on TOPv2:
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss.scale=1.0 early_exit_loss.curriculum=torchtune.modules.early_exit_loss.GradualEarlyExitCurriculum early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.linear_l_loss_scale layer_dropout.prob=0.2 layer_dropout.scale=exp
#
# - LITE (https://arxiv.org/abs/2310.18581):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml layer_dropout=null early_exit_loss.layers=8,12,16,20,24,28 early_exit_loss.scale_fn=torchtune.modules.early_exit_loss.uniform_loss_scale early_exit_loss.curriculum=null epochs=5
#
# - LayerDrop (https://arxiv.org/abs/1909.11556):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.2 layer_dropout.layers=1::2
#
# - Progressive Layer Dropping (https://arxiv.org/abs/2010.13369) (The paper also implements a curriculum for layer drop probability which is not yet implemented.):
# tune run --nnodes 1 --nproc_per_node 4 dev/early_exit_finetune_distributed --config recipes/dev/7B_full_early_exit.yaml early_exit_loss=null layer_dropout.prob=0.5 layer_dropout.scale=exp
#
# This config works best for distributed training, hence when the model is being fine-tuned on 2+ GPUs.
#


# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model
max_seq_len: null

# Dataset
dataset:
_component_: torchtune.datasets.instruct_dataset
source: WillHeld/top_v2
split: train
column_map:
input: utterance
output: semantic_parse

seed: null
shuffle: True

# Model Arguments
model:
_component_: torchtune.models.llama2.llama2_7b

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Llama-2-7b-hf
checkpoint_files: [
pytorch_model-00001-of-00002.bin,
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 8
epochs: 1
optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 2e-5
loss:
_component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
output_dir: /tmp/topv2-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1

# Early Exit Loss
early_exit_loss:
layers: "0::4"
curriculum: torchtune.modules.early_exit_loss.RotationalEarlyExitCurriculum
scale_fn: torchtune.modules.early_exit_loss.sum_l_loss_scale
scale: 1.0

# Layer Dropout
layer_dropout:
prob: 0.2
layers: ":"
layers_scale: "exp"
disable_on_eval: True
Loading

0 comments on commit f8563dd

Please sign in to comment.