RLHF with PPO #1005

Merged: 44 commits, merged Aug 5, 2024. The changes below are shown from 38 of the 44 commits.

Commits (all authored by SalmanMohammadi):
11d88a2  Refactoring TransformerDecoder and adding value-head transformers (May 9, 2024)
2849ec5  adding ppo config and recipe to registry (May 10, 2024)
f0c1410  Merge branch 'pytorch:main' into ppo (May 12, 2024)
57c67bf  implemented ppo recipe structure, advantage and return estimation, tr… (May 15, 2024)
03cba4b  finished first pass implementation of ppo. added tests for ppo loss (May 15, 2024)
f50f047  reverting changes (May 15, 2024)
b034af7  adding lora to ppo recipe, adding lora value head component and model… (May 16, 2024)
466b683  added lora training, added value head checkpointing and recipe resumi… (May 19, 2024)
928037d  removing test model builders, adding batched generation to ppo recipe… (May 21, 2024)
68b6162  fixing bug in _checkpointer.py (May 21, 2024)
65ca12a  Adding support for user-provided masks in attention (May 30, 2024)
9d8c5a8  Merge branch 'pytorch:main' into ppo (May 31, 2024)
b99102c  merging transformer custom masking, adding support for generation wit… (Jun 4, 2024)
a1cde1c  adding functionality for truncation in generation, and further tests … (Jun 4, 2024)
b032778  updated lora recipe to use custom generation (Jun 6, 2024)
f126e9a  Merge branch 'pytorch:main' into ppo (Jun 6, 2024)
04d514a  added support for correct truncation and padding of responses, added … (Jun 7, 2024)
4854908  added correct mask and position id trajectory generation, score rejec… (Jun 8, 2024)
c885833  bugfixing in ppo recipe. refactoring ppo_utils and tests to individua… (Jun 8, 2024)
57d57fa  updating ppo_utils namespace (Jun 8, 2024)
cce5548  fixing bug in collation, updating loss tests (Jun 10, 2024)
c289566  bugfixes in masking and indexing logprobs and values, added fixed kl … (Jun 12, 2024)
a3fa1ea  added loss and value masking (Jun 14, 2024)
c3db142  some refactoring, lots of testing and docs (Jun 16, 2024)
589bf7d  improved early training stability by adding value head init. from rew… (Jun 16, 2024)
346c30b  updating metrics (Jun 18, 2024)
2e9d779  reworking causal masking (Jun 18, 2024)
46b75be  freeing up memory after steps to avoid mem leaks (Jun 18, 2024)
0fd885e  Merge branch 'main' into ppo (Jul 16, 2024)
1942b0f  cleaning up; verifying results; switching to full finetune (Jul 16, 2024)
58d92ab  tidying up (Jul 16, 2024)
1fbb6dc  detaching losses for metric logging (Jul 18, 2024)
65ef9dc  removing 1b, merging main (Jul 25, 2024)
c7bbff1  merging (Jul 25, 2024)
1129f9e  deleting logits in loss (Jul 29, 2024)
fe87dfb  Merge branch 'main' into ppo (Aug 2, 2024)
662ab2c  cleaning conf (Aug 2, 2024)
76b124f  pYdOcLiNt (Aug 2, 2024)
dc4887c  downloading weights (Aug 3, 2024)
ef85dba  addressing comments (Aug 5, 2024)
fd87fe6  updating test (Aug 5, 2024)
ba365a8  let's finish this the way we started... together (Aug 5, 2024)
e76304c  Merge branch 'main' into ppo (Aug 5, 2024)
4e6be43  lInTiNG (Aug 5, 2024)

Files changed

15 changes: 15 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -80,6 +80,7 @@ Loss
    :toctree: generated/
    :nosignatures:

    loss.PPOLoss
    loss.DPOLoss
    loss.RSOLoss
    loss.IPOLoss
@@ -98,3 +99,17 @@ Functions used for preprocessing images.
    transforms.tile_crop
    transforms.find_supported_resolutions
    transforms.VisionCrossAttentionMask

Reinforcement Learning From Human Feedback (RLHF)
--------------------------------------------------
Components for RLHF algorithms like PPO.

.. autosummary::
    :toctree: generated/
    :nosignatures:

    rlhf.estimate_advantages
    rlhf.get_rewards_ppo
    rlhf.truncate_sequence_at_first_stop_token
    rlhf.left_padded_collate
    rlhf.padded_collate_dpo
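
These rlhf components back the new PPO recipe and the config added below. As a rough, hedged illustration of what generalized advantage estimation (GAE) computes from per-token rewards and value predictions, here is a minimal sketch; the function name, arguments, and tensor shapes are assumptions for illustration and may not match rlhf.estimate_advantages exactly:

import torch

def estimate_advantages_sketch(values, rewards, gamma=1.0, lmbda=0.95):
    # values, rewards: [batch_size, response_len], one entry per generated token.
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(rewards.shape[1])):
        # Bootstrap with the next value estimate; zero at the end of the response.
        next_value = values[:, t + 1] if t + 1 < rewards.shape[1] else torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lmbda * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values  # regression targets for the value head
    return advantages, returns

advantages, returns = estimate_advantages_sketch(
    values=torch.randn(2, 5), rewards=torch.randn(2, 5), gamma=1.0, lmbda=0.95
)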
1 change: 0 additions & 1 deletion docs/source/api_ref_utilities.rst
@@ -115,7 +115,6 @@ Utilities for working with data and datasets.
    :nosignatures:

    padded_collate
    padded_collate_dpo

.. _gen_label:

180 changes: 180 additions & 0 deletions recipes/configs/mistral/7B_full_ppo_low_memory.yaml
@@ -0,0 +1,180 @@
# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on a small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns=""
# tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#  tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#

# Tokenizer
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  max_seq_len: null
  split: train
  column: prompt
  add_eos: False

policy_model:
  _component_: torchtune.models.mistral.mistral_7b

# we need to manually build the mistral classifier model
# because our reward model checkpoint has a larger vocabulary size (due to an added padding token)
reward_and_value_model:
  _component_: torchtune.models.mistral._component_builders.mistral_classifier
  attn_dropout: 0.0
  embed_dim: 4096
  intermediate_dim: 14336
  max_seq_len: 32768
  norm_eps: 1.0e-05
  num_classes: 1
  num_heads: 32
  num_kv_heads: 8
  num_layers: 32
  vocab_size: 32001

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files: [
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin"
  ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: MISTRAL

# this should be set up identically to the policy model checkpointer at the start of training
# ensure `checkpoint_files` always points to the original policy weights, even if resuming training
ref_policy_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files: [
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin"
  ]
  output_dir: ${output_dir}/policy
  model_type: MISTRAL

# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint
# since this model will be identical to the reward model, it's helpful to initialise this
# from the trained reward model weights
value_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files: [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD

# checkpointer for the reward model, ensure `checkpoint_files`
# always points to the original reward model weights, even if resuming training
reward_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files: [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD


resume_from_checkpoint: False
output_dir: /tmp/mistral7b-ppo-finetune
seed: null
shuffle: True

# Training env
device: cuda

# Training arguments
batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 32
gradient_accumulation_steps: 1

# Memory management and performance
compile: True
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 3e-6
optimizer_in_bwd: True
log_peak_memory_stats: False
enable_activation_checkpointing: True
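# (Assumed rationale for these settings: optimizer_in_bwd runs the optimizer step
# during the backward pass so gradients can be freed immediately, and the paged
# bitsandbytes optimizer plus activation checkpointing trade extra compute and
# paging for lower peak memory.)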

# Reduced precision
dtype: bf16


# batch size for forward pass during generation
forward_batch_size: 16
max_generated_tokens: 58
temperature: 0.7
top_k: null

# parameter for penalising generations shorter than `min_response_length`
min_response_length: 18
# parameter for penalising generations without a stop token
penalise_no_eos: True
# scalar penalty to apply when penalising
reward_penalty: -3

# tokens to consider as "end of sequence" tokens
stop_token_ids: [
  2, # eos_id
  28723 # mistral "." token
]
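# Under the assumed behaviour of rlhf.truncate_sequence_at_first_stop_token,
# anything generated after the first of these stop tokens is replaced with
# padding before the response is scored by the reward model.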
whiten_rewards: False

# GAE hyperparameters
gamma: 1
lmbda: 0.95

# PPO hyperparameters
loss:
  _component_: torchtune.modules.loss.PPOLoss
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.2
  kl_coeff: 0.01


# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}

log_every_n_steps: 1
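
To make the loss hyperparameters above concrete, the following is a hedged sketch of a PPO-style objective with a clipped policy term (epsilon), a clipped value term (value_clip_range, weighted by value_coeff), and a KL penalty against the reference policy (weighted by kl_coeff). It is an illustration under assumed tensor shapes and names, not torchtune's PPOLoss implementation:

import torch

def ppo_loss_sketch(
    pi_logprobs,    # [B, T] log-probs of generated tokens under the current policy
    old_logprobs,   # [B, T] log-probs under the policy that generated the data
    ref_logprobs,   # [B, T] log-probs under the frozen reference policy
    advantages,     # [B, T] GAE advantages
    returns,        # [B, T] value targets
    values,         # [B, T] current value-head predictions
    old_values,     # [B, T] value predictions at generation time
    epsilon=0.2,
    value_coeff=0.1,
    value_clip_range=0.2,
    kl_coeff=0.01,
):
    # Clipped policy objective (padding tokens would be masked out in practice).
    ratio = torch.exp(pi_logprobs - old_logprobs)
    policy_loss = -torch.minimum(
        ratio * advantages,
        torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages,
    ).mean()

    # Clipped value objective.
    values_clipped = old_values + (values - old_values).clamp(-value_clip_range, value_clip_range)
    value_loss = torch.maximum((values - returns) ** 2, (values_clipped - returns) ** 2).mean()

    # Simple KL penalty keeping the policy close to the reference model; many
    # implementations instead fold this term into the per-token rewards.
    kl_penalty = (pi_logprobs - ref_logprobs).mean()

    return policy_loss + value_coeff * value_loss + kl_coeff * kl_penalty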
3 changes: 2 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -27,6 +27,7 @@
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -449,7 +450,7 @@ def _setup_data(
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
utils.padded_collate_dpo,
rlhf.padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),
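
The change above points the collate_fn at the new rlhf namespace. As a generic illustration of the pattern (functools.partial binding keyword arguments to a collate function so the DataLoader only passes the raw batch), here is a self-contained sketch; my_collate is a stand-in, not torchtune's padded_collate_dpo:

from functools import partial

import torch
from torch.utils.data import DataLoader

def my_collate(batch, padding_idx=0, ignore_idx=-100):
    # Right-pad variable-length token lists and build labels in which padded
    # positions are set to ignore_idx so the loss skips them.
    max_len = max(len(x) for x in batch)
    tokens = torch.full((len(batch), max_len), padding_idx, dtype=torch.long)
    labels = torch.full((len(batch), max_len), ignore_idx, dtype=torch.long)
    for i, x in enumerate(batch):
        tokens[i, : len(x)] = torch.tensor(x)
        labels[i, : len(x)] = torch.tensor(x)
    return tokens, labels

loader = DataLoader(
    [[1, 2, 3], [4, 5]],
    batch_size=2,
    collate_fn=partial(my_collate, padding_idx=0, ignore_idx=-100),
)
tokens, labels = next(iter(loader))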
3 changes: 2 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -19,6 +19,7 @@
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -345,7 +346,7 @@ def _setup_data(
sampler=sampler,
batch_size=batch_size,
collate_fn=partial(
utils.padded_collate_dpo,
rlhf.padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),