Optimize DPO recipe - precomputing reference model log probabilities #25

Status: Open. Wants to merge 3 commits into base: main.
Conversation

@yash12khandelwal yash12khandelwal commented Dec 30, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
The primary purpose of this PR is to add support for precomputing reference log probabilities when using DPO. This speeds up overall training by eliminating the redundant forward passes of the frozen reference model that are otherwise repeated every epoch.

  • CustomPreferenceDataset - A modification of the Preference Dataset that stores the reference log probabilities alongside the data. Every `__getitem__` call returns a dictionary containing the input_ids, labels, and the reference model's chosen and rejected log probabilities.
  • padded_collate_dpo - Modified this function to return the precomputed log probabilities along with the inputs and labels.
  • lora_dpo_distributed - Added support for precomputing the reference log probabilities during data setup. When computing losses, the batch can supply the precomputed values, saving compute. The implementation is inspired by the DPO implementation in the Hugging Face trl repository.
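The precomputation step described above can be sketched roughly as follows. This is a minimal illustration, not the actual recipe code: the helper names (`gather_logps`, `precompute_reference_logps`) and the chosen/rejected batch-halving convention are assumptions for the sketch.

```python
import torch

def gather_logps(logits, labels, ignore_index=-100):
    """Sum the per-token log probabilities of `labels` under `logits`,
    skipping positions marked with `ignore_index`."""
    mask = labels != ignore_index
    safe_labels = labels.clone()
    safe_labels[~mask] = 0  # any valid index; masked out below
    logps = torch.log_softmax(logits, dim=-1)
    token_logps = torch.gather(logps, 2, safe_labels.unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(-1)

@torch.no_grad()
def precompute_reference_logps(ref_model, dataloader):
    """Run the frozen reference model once over the dataset and cache the
    chosen/rejected log probabilities, so later epochs can skip this pass."""
    cache = []
    for batch in dataloader:
        logits = ref_model(batch["input_ids"])
        logps = gather_logps(logits, batch["labels"])
        # Assumed convention: first half of the batch holds the chosen
        # sequences, second half the rejected ones.
        chosen_logps, rejected_logps = logps.chunk(2)
        cache.append((chosen_logps, rejected_logps))
    return cache
```

The cached pairs can then be attached to each dataset item (as in CustomPreferenceDataset above) and returned by the collate function, so the loss computation reads them instead of re-running the reference model.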

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


# Dataset and Sampler
dataset:
_component_: torchtune.datasets.stack_exchange_paired_dataset

Component needs change
