Vector Quantized Embeddings #2040
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2040. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 1a0841a with merge base 32e265d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This looks like a nice implementation. I left some nit comments and questions. But my larger question is whether all of this is needed in torchtune. How much of this code is needed only to support training a codebook? Do we want/need that, or should we just house the minimal amount needed for inference?
torchtune/modules/vq_embeddings.py
Outdated
# code_usage and code_avg correspond with N and m, respectively, from Oord et al.
randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
self.register_buffer("embedding", randn_init_embedding.clone())
if learnable:
I don't really like this variable because it could disagree with the `nn.Module.training` attribute. Is this only needed for EMA? Why wouldn't the EMA code handle this?
torchtune/modules/vq_embeddings.py
Outdated
    self.num_embeddings, self.embedding_dim
)

def lookup(self, token_ids: Tensor) -> Tensor:
nit: do we want to use "lookup" here, or something more consistent with the rest of the library like `decode`?
Open to either. I was thinking that since I'm putting this in modules, it should at least be trainable. But it would simplify the class significantly if we only supported inference for this.
It's up to you here. It could be nice to already have training logic, but I'm not sure we know what a training recipe for encoders would look like for us. If we decide to support training an encoder in the future, we might make different decisions on how to split the loss logic from the module logic. If we wanted to make this module finetunable with the rest of the LLM, would we use these same methods or allow a full override of the codebook?
Vector Quantized Embeddings (pytorch#2040)
Context
What is the purpose of this PR?
Vector quantization is a long-standing method, introduced with the VQ-VAE (https://arxiv.org/abs/1711.00937), for achieving high-fidelity image generation. It remains a core component in many landmark models.
To continue to support images and more modalities in the future, we need to add the core VQ embedding layer as a module. This module will maintain the codebook embeddings and perform the lookup for discretizing an encoder input.
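For context, the core lookup described above is a nearest-neighbor search over the codebook, typically paired with a straight-through gradient estimator. The sketch below is illustrative only (the function name and signature are assumptions, not the actual torchtune API):

```python
# Hypothetical sketch of a VQ embedding lookup: quantize each encoder
# output vector to its nearest codebook entry.
import torch


def quantize(z: torch.Tensor, codebook: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """z: [batch, dim]; codebook: [num_embeddings, dim]."""
    # Squared L2 distance between each input vector and every codebook entry.
    dists = torch.cdist(z, codebook) ** 2  # [batch, num_embeddings]
    token_ids = dists.argmin(dim=-1)       # discrete codes, [batch]
    quantized = codebook[token_ids]        # [batch, dim]
    # Straight-through estimator: in the backward pass, gradients flow
    # through z as if the quantization step were the identity.
    quantized = z + (quantized - z).detach()
    return quantized, token_ids
```

The returned `token_ids` are what a downstream model would consume as discrete tokens, while `quantized` keeps the encoder differentiable end to end.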
I've adapted the `Codebook` layer from torchmultimodal (found here; apparently it shares authors with the author of this PR...). I've removed features such as initializing embeddings with the encoder input for faster convergence, to help simplify the code. I've kept the core embedding lookup functionality and the EMA update so the module can still be trained if users want to experiment.

Changelog

What are the changes made in this PR?
- Add `VectorQuantizedEmbeddings` and requisite tests

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- `pre-commit install`
- `pytest tests`
- `pytest tests -m integration_test`
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.
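As background on the EMA update kept in this PR: in Oord et al., each codebook entry is maintained as a running average of the encoder vectors assigned to it, with `code_usage` and `code_avg` tracking N and m respectively. A minimal sketch, assuming free-standing buffers rather than the actual module internals (names and signature are illustrative):

```python
# Hypothetical sketch of the EMA codebook update from Oord et al. (2017).
import torch


def ema_update(codebook, code_usage, code_avg, z, token_ids, decay=0.99, eps=1e-5):
    """codebook, code_avg: [num_embeddings, dim]; code_usage: [num_embeddings];
    z: [batch, dim]; token_ids: [batch] (int64)."""
    num_embeddings = codebook.shape[0]
    onehot = torch.nn.functional.one_hot(token_ids, num_embeddings).type_as(z)
    # N_i <- decay * N_i + (1 - decay) * (count of vectors assigned to code i)
    code_usage.mul_(decay).add_(onehot.sum(dim=0), alpha=1 - decay)
    # m_i <- decay * m_i + (1 - decay) * (sum of vectors assigned to code i)
    code_avg.mul_(decay).add_(onehot.t() @ z, alpha=1 - decay)
    # Laplace smoothing of usage counts to avoid division by zero.
    n = code_usage.sum()
    usage = (code_usage + eps) / (n + num_embeddings * eps) * n
    codebook.copy_(code_avg / usage.unsqueeze(-1))
```

Because the update is a buffer rewrite rather than a gradient step, it sidesteps the codebook-gradient term of the original VQ-VAE loss, which is why the module can support training without any loss logic of its own.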