
Vector Quantized Embeddings #2040

Merged: 5 commits merged into pytorch:main on Dec 3, 2024

Conversation

RdoubleA (Contributor) commented Nov 21, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Vector quantization is an old method introduced with the VQ-VAE (https://arxiv.org/abs/1711.00937) for achieving high-fidelity image generation. It remains a core component in many landmark models:

  • the DALL-E models use it for image generation by autoregressively predicting image token IDs
  • Latent diffusion uses a VAE (not sure if it was VQ or continuous) to perform the diffusion process in the latent space
  • Muse by Google uses VQVAE similarly for image generation
  • Make-A-Video by Meta uses it for video generation from text

To continue supporting images and more modalities in the future, we need to add the core VQ embedding layer as a module. This module maintains the codebook embeddings and performs the lookup that discretizes the encoder's output.
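
For readers unfamiliar with the operation: the core of vector quantization is a nearest-neighbor lookup against a codebook. The sketch below is illustrative only and does not reflect the exact tensor names or signature used in this PR.

import torch


def quantize(z_e: torch.Tensor, codebook: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Map each encoder output vector to its nearest codebook entry.

    z_e:      encoder output of shape [batch, seq_len, embedding_dim]
    codebook: embedding table of shape [num_embeddings, embedding_dim]
    Returns the quantized embeddings and the discrete token IDs.
    """
    flat = z_e.reshape(-1, z_e.shape[-1])           # [batch * seq_len, embedding_dim]
    # Squared L2 distance from every input vector to every codebook entry
    distances = torch.cdist(flat, codebook) ** 2    # [batch * seq_len, num_embeddings]
    token_ids = distances.argmin(dim=-1)            # index of the nearest code per vector
    z_q = codebook[token_ids].reshape_as(z_e)       # discretized embeddings
    return z_q, token_ids.reshape(z_e.shape[:-1])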

I've adapted the Codebook layer from torchmultimodal (found here; it happens to share an author with this PR...). To simplify the code, I've removed features such as initializing the embeddings from the encoder input for faster convergence. I've kept the core embedding lookup functionality and the EMA update so the module can still be trained if users want to experiment.
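
For context on the EMA update mentioned above, here is a rough sketch of the exponential-moving-average codebook update from Oord et al., where the code_usage and code_avg buffers correspond to N and m in the paper. This is an illustration under assumptions, not the code in this PR.

import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_update(
    codebook: torch.Tensor,     # [num_embeddings, embedding_dim]
    code_usage: torch.Tensor,   # N in Oord et al., shape [num_embeddings]
    code_avg: torch.Tensor,     # m in Oord et al., shape [num_embeddings, embedding_dim]
    flat_inputs: torch.Tensor,  # encoder outputs flattened to [num_vectors, embedding_dim]
    token_ids: torch.Tensor,    # nearest-code indices, shape [num_vectors]
    decay: float = 0.99,
    eps: float = 1e-5,
) -> None:
    num_embeddings = codebook.shape[0]
    onehot = F.one_hot(token_ids, num_embeddings).type_as(flat_inputs)
    # N_i: exponentially weighted count of vectors assigned to each code
    code_usage.mul_(decay).add_(onehot.sum(dim=0), alpha=1 - decay)
    # m_i: exponentially weighted sum of the vectors assigned to each code
    code_avg.mul_(decay).add_(onehot.t() @ flat_inputs, alpha=1 - decay)
    # Laplace smoothing keeps rarely used codes from dividing by ~0
    n = code_usage.sum()
    smoothed_usage = (code_usage + eps) / (n + num_embeddings * eps) * n
    codebook.copy_(code_avg / smoothed_usage.unsqueeze(-1))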

Changelog

What are the changes made in this PR?

  • Add VectorQuantizedEmbeddings and requisite tests

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings
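
As a stand-in for the requested dummy example, here is a hedged usage sketch. The constructor arguments and return values below are assumptions based on the diff context; the merged docstring is the source of truth.

import torch
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings

# Hypothetical settings; check the class docstring for the exact argument names.
vq = VectorQuantizedEmbeddings(num_embeddings=1024, embedding_dim=256)

encoder_out = torch.randn(2, 16, 256)    # [batch, seq_len, embedding_dim]
quantized, token_ids = vq(encoder_out)   # discretized embeddings plus codebook indices
recovered = vq.lookup(token_ids)         # look up embeddings from token IDs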

pytorch-bot (bot) commented Nov 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2040

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a0841a with merge base 32e265d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Nov 21, 2024
pbontrager (Contributor) left a comment:

This looks like a nice implementation. I left some nit comments and questions, but my larger question is whether all of this is needed in torchtune. How much of this code is needed just to support training a codebook? Do we want/need that, or should we just house the minimal amount needed for inference?

Review thread on torchtune/modules/vq_embeddings.py (resolved):
# code_usage and code_avg correspond with N and m, respectively, from Oord et al.
randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
self.register_buffer("embedding", randn_init_embedding.clone())
if learnable:
Contributor commented:

I don't really like this variable because it could disagree with the nn.Module.training attribute. Is this only needed for EMA? Why wouldn't the EMA code handle this?
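
To make the suggestion concrete, a minimal sketch of gating the codebook update on the standard nn.Module.training flag rather than a separate attribute might look like the following; names and structure are illustrative, not the class in this PR.

import torch
from torch import nn


class CodebookSketch(nn.Module):
    """Illustrative codebook that relies on train()/eval() to decide whether
    the EMA update runs, so no extra flag can disagree with self.training."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.register_buffer("embedding", torch.randn(num_embeddings, embedding_dim))

    def forward(self, z_e: torch.Tensor) -> torch.Tensor:
        flat = z_e.reshape(-1, z_e.shape[-1])
        token_ids = torch.cdist(flat, self.embedding).argmin(dim=-1)
        if self.training:
            # train()/eval() controls the codebook update; no separate learnable attribute
            self._update_codebook(flat, token_ids)
        return self.embedding[token_ids].reshape_as(z_e)

    @torch.no_grad()
    def _update_codebook(self, flat: torch.Tensor, token_ids: torch.Tensor) -> None:
        # Placeholder for the EMA update sketched earlier in this thread
        pass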

self.num_embeddings, self.embedding_dim
)

def lookup(self, token_ids: Tensor) -> Tensor:
Contributor commented:

nit: do we want to use "lookup" here, or something more consistent with the rest of the library like decode?

RdoubleA (Contributor, Author) commented Dec 2, 2024

Do we want/need that or should we just house the minimal amount needed for inference?

Open to either; I was thinking that since I'm putting this in modules, it should at least be trainable. But it would simplify the class significantly if we only supported inference.

pbontrager (Contributor) commented:

It's up to you here. It could be nice to already have training logic, but I'm not sure we know what a training recipe for encoders would look like for us. If we decide to support training an encoder in the future, we might make different decisions about how to split the loss logic from the module logic. If we wanted to make this module finetunable along with the rest of the LLM, would we use these same methods or allow a full override of the codebook?

RdoubleA merged commit e9b9ea5 into pytorch:main on Dec 3, 2024; 17 checks passed.
rahul-sarvam added a commit to sarvamai/torchtune referencing this pull request on Dec 4, 2024: Vector Quantized Embeddings (pytorch#2040)