Vector Quantized Embeddings #2040

Merged 5 commits on Dec 3, 2024
114 changes: 114 additions & 0 deletions tests/torchtune/modules/test_vq_embeddings.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected
from torch import tensor
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings


@pytest.fixture(autouse=True)
def random_seed():
    torch.manual_seed(4)


class TestVectorQuantizedEmbeddings:
    @pytest.fixture
    def num_embeddings(self):
        return 4

    @pytest.fixture
    def embedding_dim(self):
        return 5

    @pytest.fixture
    def embedding_weights(self):
        # This is 4x5
        return tensor(
            [
                [1.0, 0.0, -1.0, -1.0, 2.0],
                [2.0, -2.0, 0.0, 0.0, 1.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [-1.0, -2.0, 0.0, 2.0, 0.0],
            ]
        )

    @pytest.fixture
    def codebook(self, num_embeddings, embedding_dim, embedding_weights):
        vq = VectorQuantizedEmbeddings(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        vq.embedding.data = embedding_weights
        return vq

    @pytest.fixture
    def encoded(self):
        # This is 2x3x5
        encoded = tensor(
            [
                [
                    [-1.0, 2.0, 0.0, 0.0, -2.0],
                    [0.0, 1.0, -1.0, 2.0, -1.0],
                    [1.0, 0.0, -1.0, -1.0, 1.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -1.0, 0.0, 2.0, 0.0],
                    [-1.0, -2.0, 0.0, 1.0, 0.0],
                ],
            ]
        )
        encoded.requires_grad_()

        return encoded

    def test_quantized_output(self, codebook, encoded):
        actual = codebook(encoded)

        expected_quantized = tensor(
            [
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [1.0, 0.0, -1.0, -1.0, 2.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -2.0, 0.0, 0.0, 1.0],
                    [-1.0, -2.0, 0.0, 2.0, 0.0],
                ],
            ]
        )
        expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
            torch.LongTensor
        )

        assert_expected(actual[0], expected_quantized)
        assert_expected(actual[1], expected_token_ids)

    def test_decode(self, codebook):
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = codebook.decode(indices_flat)
        actual_quantized = codebook.decode(indices_shaped)
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
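
For reference, the expected token ids in test_quantized_output follow directly from the squared-L2 argmin that the new module's forward pass performs. A minimal standalone sketch (not part of the PR, duplicating the fixture values above) that reproduces them:

import torch

embedding = torch.tensor(
    [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
)
encoded = torch.tensor(
    [
        [[-1.0, 2.0, 0.0, 0.0, -2.0], [0.0, 1.0, -1.0, 2.0, -1.0], [1.0, 0.0, -1.0, -1.0, 1.0]],
        [[2.0, 1.0, 0.0, 1.0, 1.0], [2.0, -1.0, 0.0, 2.0, 0.0], [-1.0, -2.0, 0.0, 1.0, 0.0]],
    ]
)

z_flat = encoded.view(-1, 5)  # flatten batch and sequence dims: (6, 5)
distances = torch.cdist(z_flat, embedding, p=2.0) ** 2  # squared distance to each codebook vector: (6, 4)
token_ids = torch.argmin(distances, dim=1).view(2, 3)  # nearest codebook index per position
print(token_ids)  # tensor([[2, 2, 0], [2, 1, 3]])

Note that the second position of the second batch element is equidistant (squared distance 6) from codebook vectors 1 and 2; torch.argmin returns the first of the tied indices, hence the expected id of 1.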
2 changes: 2 additions & 0 deletions torchtune/modules/__init__.py
@@ -29,6 +29,7 @@
    TransformerSelfAttentionLayer,
)
from .vision_transformer import VisionTransformer
from .vq_embeddings import VectorQuantizedEmbeddings

__all__ = [
    "MultiHeadAttention",
@@ -38,6 +39,7 @@
    "KVCache",
    "RotaryPositionalEmbeddings",
    "VisionRotaryPositionalEmbeddings",
    "VectorQuantizedEmbeddings",
    "RMSNorm",
    "TiedLinear",
    "Fp32LayerNorm",
86 changes: 86 additions & 0 deletions torchtune/modules/vq_embeddings.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class VectorQuantizedEmbeddings(nn.Module):
    """
    Vector quantized embedding layer that takes in the output of an encoder
    and performs a nearest-neighbor lookup in the embedding space.
    Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
    to generate high-fidelity images, videos, and audio data.

    This module currently does not support pre-training of the embeddings via EMA.

    Code was adapted from torchmultimodal's `Codebook module
    <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.

    Args:
        num_embeddings (int): Number of vectors in the embedding space.
        embedding_dim (int): Dimensionality of the embedding vectors.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ) -> None:
        super().__init__()
        self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
                b is batch size, s is sequence length or time, and d is ``embedding_dim``.

        Returns:
            Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.

        Raises:
            ValueError: if input embedding dimension does not match embedding dimension of module
        """
        bsz, seq_len, z_embed_dim = z.shape
        if z_embed_dim != self.embedding_dim:
            raise ValueError(
                f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
            )

        # Flatten into batch dimension
        z_flat = z.view(-1, z_embed_dim)
        # Calculate squared distances from each encoder output vector E(x) to each embedding vector e: ||E(x) - e||^2
        distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2

        # Encoding - select closest embedding vectors, shape [b * s, ]
        token_ids_flat = torch.argmin(distances, dim=1)

        # Quantize - shape [b * s, d]
        quantized_flat = self.decode(token_ids_flat)

        # Straight-through estimator
        quantized_flat = z_flat + (quantized_flat - z_flat).detach()

        # Reshape to original - [b, s, d] and [b, s]
        quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
        token_ids = token_ids_flat.view(bsz, seq_len)

        return quantized, token_ids

    def extra_repr(self) -> str:
        return "num_embeddings={}, embedding_dim={}".format(
            self.num_embeddings, self.embedding_dim
        )

    def decode(self, token_ids: Tensor) -> Tensor:
        # Returns the embeddings of shape [b, s, d]
        return F.embedding(token_ids, self.embedding)
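
For illustration, a minimal usage sketch of the new module (not part of this PR; the sizes and the random initialization below are arbitrary): quantize a batch of encoder outputs and decode the resulting token ids.

import torch
from torchtune.modules import VectorQuantizedEmbeddings

codebook = VectorQuantizedEmbeddings(num_embeddings=1024, embedding_dim=512)
# The embedding table is created with torch.empty, so initialize or load weights first;
# random values are used here purely for illustration.
torch.nn.init.normal_(codebook.embedding)

z = torch.randn(2, 16, 512)             # (batch, seq_len, embedding_dim) encoder outputs
quantized, token_ids = codebook(z)      # quantized: (2, 16, 512), token_ids: (2, 16)
recovered = codebook.decode(token_ids)  # nearest codebook vectors, same values as quantized

Because of the straight-through estimator in forward, quantized carries the selected codebook values in the forward pass while gradients pass back to z unchanged, so the layer can sit between a trainable encoder and decoder.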