PPO Performance Improvements #2066

Status: Open · wants to merge 8 commits into main

Conversation

@SalmanMohammadi (Collaborator) commented Nov 25, 2024:

Closes #1425

This PR provides various performance improvements to our PPO single device recipe.

| Branch | Total training time (hours)* | Peak memory allocated (GB) |
| --- | --- | --- |
| Main | 13.1 | 69.6 |
| This branch | 5.4 | 69.5 |
| This branch + compile | 4.6 | 68.6 |

*The models were trained over approx. 37M tokens (~65k samples w/ max_seq_len=512) on a single A100 GPU.

[training curves image] Due to the non-determinism of the training process, the curves may look slightly different.

Changelog:

  • KV-caching is now supported during trajectory generation - this significantly speeds up training.
  • generation.generate now only returns logits over the generated tokens rather than the whole sequence - this significantly reduces peak memory usage. Tests have been updated.
  • Added profiler support to the recipe.
  • Various changes in trajectory estimation/reward estimation which improve performance.
  • Added parents=True to output_dir.mkdir in our checkpointers. We use nested checkpoint folders for PPO, e.g. output_dir/policy/, output_dir/value/.
  • The addition of various performance improvements in main since the original baseline means we can bump the default batch size in the configs.
  • Compile support. We have two options here:
    1. Compile the trajectory estimation functions separately - minimizes recompiles but results in a small warmup overhead.
    2. Compile each model using training.compile_model - this results in ~10 recompile warnings, which means we need to increase the compile cache size limit - I've added torch._dynamo.config.cache_size_limit = 16 at the top of the recipe.

I landed on option 2 - it's similar to how we integrate compile with the rest of our recipes, and it eliminates the small warmup overhead. To fully realize compile speedups it's recommended to do a small warm-up run of the recipe with compile enabled.
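For illustration, a minimal sketch of what option 2 looks like in the recipe setup. The function and model names below are illustrative, not the recipe's exact code; only `training.compile_model` and the dynamo cache-size setting come from the description above:

```python
import torch
from torchtune import training

# Raise the dynamo cache size limit so the ~10 recompiles from compiling each
# model separately don't exceed the default limit (set at the top of the recipe).
torch._dynamo.config.cache_size_limit = 16


def setup_compile(policy_model, value_model, reward_model, ref_policy_model, enabled: bool):
    """Compile each PPO model with training.compile_model, mirroring other torchtune recipes."""
    if not enabled:
        return
    for model in (policy_model, value_model, reward_model, ref_policy_model):
        training.compile_model(model)
```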

pytorch-bot (bot) commented Nov 25, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2066

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 92927c4 with merge base f2bd4bc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@felipemello1 (Contributor) commented:

Geez! A >3x improvement is no joke. I don't think I will have time to review it this week, but I am very curious to see the changes.

@joecummings (Contributor) left a comment:

YOU SHALL NOT PASS

Resolved review thread on recipes/configs/mistral/7B_full_ppo_low_memory.yaml (outdated).
@@ -94,15 +94,15 @@ def generate_next_token(
         - tokens (torch.Tensor): tensor with the generated tokens,
             with shape [bsz x 1].
         - logits (torch.Tensor): tensor with the logits associated with the generated tokens,
-            with shape [bsz x seq_length x vocab_size].
+            with shape [bsz x 1 x vocab_size].
Contributor:

Unfortunately, this is a BC breaking change for a public API which means we need to deprecate accordingly. Can you make this a flag that is enabled for the PPO use case, then add a deprecation warning?

@SalmanMohammadi (Collaborator, Author) commented Dec 11, 2024:

I'm afraid this is going to be challenging to do without introducing graph breaks during compilation. I generally agree with you, though in this case I'm not 100% sure who would be using the logits returned from this function outside of PPO.
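For context, the flag-plus-deprecation pattern being suggested would look roughly like the sketch below. The return_full_logits name and the simplified signature are hypothetical (the real torchtune function takes more arguments), and as noted above the extra branch risks graph breaks under compile:

```python
import warnings
import torch


def generate_next_token(model, input_pos, x, *, return_full_logits: bool = False):
    # Hypothetical sketch of the suggested flag; not the actual torchtune signature.
    logits = model(x, input_pos=input_pos)  # [bsz, seq_len, vocab_size]
    tokens = torch.argmax(logits[:, -1], dim=-1, keepdim=True)  # greedy sampling for illustration
    if return_full_logits:
        warnings.warn(
            "Returning logits for the full sequence is deprecated; in a future "
            "version only the logits for the generated token will be returned.",
            FutureWarning,
        )
        return tokens, logits  # [bsz, seq_len, vocab_size]
    return tokens, logits[:, -1:, :]  # [bsz, 1, vocab_size]
```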

@@ -355,8 +355,8 @@ def generate(
         # if incremental decoding is enabled, we can use the current position
         # otherwise, we take the whole sequence up to the current position
         if incremental_decoding:
-            curr_input_pos = input_pos[:, curr_pos]
             curr_masks = masks[:, curr_pos, None, :]
+            curr_input_pos = input_pos[:, curr_pos].contiguous()
Contributor:

This is making a copy of the tensor? So is it slower when not compiling?

@SalmanMohammadi (Collaborator, Author):

I found this was faster under compile, as it avoids recompiles triggered by guards on the mask and input_pos strides.

Contributor:

Right, I definitely believe that, but how does it compare when not compiling?

@SalmanMohammadi (Collaborator, Author):

It's still pretty minimal but YMMV depending on bsz and sequence length. When I last profiled it:

[profiler trace screenshot]

The little black sliver is the .contiguous() call, which takes around 50µs every step (some napkin math puts the total overhead at roughly 50µs × (max_generated_tokens - 1)), compared to the 4.562s for generating the entire sequence, so it's a minimal portion of the time.
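For readers following along, a tiny standalone illustration of the trade-off: the column slice is a strided view, and .contiguous() pays for a small copy (bsz elements) in exchange for standard strides. The shapes here are made up:

```python
import torch

bsz, seq_len = 4, 512
input_pos = torch.arange(bsz * seq_len).reshape(bsz, seq_len)  # contiguous [bsz, seq_len]
curr_pos = 37

view = input_pos[:, curr_pos]               # shape [bsz], stride (seq_len,): a non-contiguous view
copy = input_pos[:, curr_pos].contiguous()  # shape [bsz], stride (1,): a tiny copy

print(view.is_contiguous(), copy.is_contiguous())  # False True
# Under torch.compile, guards created on the view's strides can trigger recompiles;
# the copy keeps strides stable at the cost of a ~50µs copy per decoding step here.
```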

@@ -189,7 +189,7 @@ def get_position_ids_from_padding_mask(
     return ((padding_mask.cumsum(-1) - 1) * padding_mask).to(torch.int)


-@torch.inference_mode()
+@torch.no_grad()
Contributor:

Why? Truly don't fully understand the difference here lol.

@SalmanMohammadi (Collaborator, Author):

inference_mode changes the metadata attributes of the tensors it produces, which will trigger unnecessary recompiles without really being that useful here.

Contributor:

Is it possible to define "without really being that useful"?

@SalmanMohammadi (Collaborator, Author):

I'll need to double-check where I found this - I think it was in a PyTorch dev podcast. When inference_mode was released, PyTorch folks mentioned gains of up to ~5% on models deployed internally at FB, but the HF PRs for including inference_mode in generation didn't find speedups to the same degree, so they still use no_grad for generation. To expand on my point above: under compile, inference_mode tensors have different metadata properties, and we trigger recompiles when guards are created on these properties, which results in increased warmup time.
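A tiny illustration of the metadata difference being referred to (this is standard PyTorch behavior, not torchtune-specific):

```python
import torch

with torch.inference_mode():
    a = torch.ones(4)

with torch.no_grad():
    b = torch.ones(4)

print(a.is_inference())  # True  -> inference tensors carry different metadata
print(b.is_inference())  # False -> behaves like a regular tensor w.r.t. compile guards
```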

Review threads on torchtune/modules/transformer.py (resolved), torchtune/rlhf/loss/ppo.py (outdated, resolved), and torchtune/rlhf/rewards.py (one outdated, one resolved).
)
# note that if mask_sum == 1, then there is a division by zero issue
# to avoid it you just need to use a larger minibatch_size
mask_sum = mask.sum() + 1e-8
Contributor:

At this point maybe we make the added value configurable rather than 1e-8?

@SalmanMohammadi (Collaborator, Author):

I'm not sure when someone would want to consider configuring this value.
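For context, a generic sketch of the pattern the epsilon protects when computing masked statistics over trajectories. These helpers are illustrative, not the exact torchtune functions:

```python
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # The small epsilon keeps the division finite even if the mask selects nothing.
    return (x * mask).sum() / (mask.sum() + 1e-8)


def masked_var(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    mask_sum = mask.sum() + 1e-8
    mean = (x * mask).sum() / mask_sum
    centered = (x - mean) * mask
    # With a Bessel-corrected denominator, mask_sum == 1 leaves ~1e-8 in the
    # denominator (the "division by zero issue" noted in the diff); using a
    # larger minibatch_size avoids the degenerate case entirely.
    return centered.pow(2).sum() / (mask_sum - 1)
```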

@SalmanMohammadi mentioned this pull request on Dec 10, 2024.
@SalmanMohammadi (Collaborator, Author) commented:

Is there anything else blocking this from landing? @joecummings

@joecummings added the rlhf label on Dec 13, 2024.
Successfully merging this pull request may close: [RFC] PPO Performance Optimizations (or: PPOPO).