Merge pull request #2 from pytorch/main
Vector Quantized Embeddings (pytorch#2040)
rahul-sarvam authored Dec 4, 2024
2 parents efa91bf + e9b9ea5 commit 3098447
Showing 3 changed files with 202 additions and 0 deletions.
114 changes: 114 additions & 0 deletions tests/torchtune/modules/test_vq_embeddings.py
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected
from torch import tensor
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings


@pytest.fixture(autouse=True)
def random_seed():
    torch.manual_seed(4)


class TestVectorQuantizedEmbeddings:
    @pytest.fixture
    def num_embeddings(self):
        return 4

    @pytest.fixture
    def embedding_dim(self):
        return 5

    @pytest.fixture
    def embedding_weights(self):
        # This is 4x5
        return tensor(
            [
                [1.0, 0.0, -1.0, -1.0, 2.0],
                [2.0, -2.0, 0.0, 0.0, 1.0],
                [2.0, 1.0, 0.0, 1.0, 1.0],
                [-1.0, -2.0, 0.0, 2.0, 0.0],
            ]
        )

    @pytest.fixture
    def codebook(self, num_embeddings, embedding_dim, embedding_weights):
        vq = VectorQuantizedEmbeddings(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        vq.embedding.data = embedding_weights
        return vq

    @pytest.fixture
    def encoded(self):
        # This is 2x3x5
        encoded = tensor(
            [
                [
                    [-1.0, 2.0, 0.0, 0.0, -2.0],
                    [0.0, 1.0, -1.0, 2.0, -1.0],
                    [1.0, 0.0, -1.0, -1.0, 1.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -1.0, 0.0, 2.0, 0.0],
                    [-1.0, -2.0, 0.0, 1.0, 0.0],
                ],
            ]
        )
        encoded.requires_grad_()

        return encoded

    def test_quantized_output(self, codebook, encoded):
        actual = codebook(encoded)

        expected_quantized = tensor(
            [
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [1.0, 0.0, -1.0, -1.0, 2.0],
                ],
                [
                    [2.0, 1.0, 0.0, 1.0, 1.0],
                    [2.0, -2.0, 0.0, 0.0, 1.0],
                    [-1.0, -2.0, 0.0, 2.0, 0.0],
                ],
            ]
        )
        expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
            torch.LongTensor
        )

        assert_expected(actual[0], expected_quantized)
        assert_expected(actual[1], expected_token_ids)

    def test_decode(self, codebook):
        indices_flat = tensor([[0, 1]])  # (b, seq_len)
        indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
        actual_quantized_flat = codebook.decode(indices_flat)
        actual_quantized = codebook.decode(indices_shaped)
        expected_quantized_flat = tensor(
            [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
        )
        expected_quantized = tensor(
            [
                [
                    [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
                    [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
                ]
            ]
        )
        assert_expected(
            actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
        )
        assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
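As a quick sanity check on the expected values above, the nearest-neighbor lookup can be reproduced by hand. The following standalone sketch (not part of the commit) recomputes the token id for the first encoded vector: its squared distances to the four codebook rows are 26, 34, 20, and 24, so row 2 is selected, matching expected_token_ids[0][0].

# Standalone sanity check (not part of the commit)
import torch

weights = torch.tensor(
    [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
)
z = torch.tensor([[-1.0, 2.0, 0.0, 0.0, -2.0]])  # encoded[0, 0] from the fixture
distances = torch.cdist(z, weights, p=2.0) ** 2  # tensor([[26., 34., 20., 24.]])
print(distances.argmin(dim=1))  # tensor([2])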
2 changes: 2 additions & 0 deletions torchtune/modules/__init__.py
@@ -29,6 +29,7 @@
    TransformerSelfAttentionLayer,
)
from .vision_transformer import VisionTransformer
from .vq_embeddings import VectorQuantizedEmbeddings

__all__ = [
"MultiHeadAttention",
@@ -38,6 +39,7 @@
"KVCache",
"RotaryPositionalEmbeddings",
"VisionRotaryPositionalEmbeddings",
"VectorQuantizedEmbeddings",
"RMSNorm",
"TiedLinear",
"Fp32LayerNorm",
86 changes: 86 additions & 0 deletions torchtune/modules/vq_embeddings.py
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F


class VectorQuantizedEmbeddings(nn.Module):
"""
Vector quantized embedding layer that takes in the output of an encoder
and performs a nearest-neighbor lookup in the embedding space.
Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
to generate high-fidelity images, videos, and audio data.
This module currently does not support pre-training of the embeddings via EMA.
Code was adapted from torchmultimodal's `Codebook module
<https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.
Args:
num_embeddings (int): Number of vectors in the embedding space.
embedding_dim (int): Dimensionality of the embedding vectors.
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
) -> None:
super().__init__()
self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim

def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
b is batch size, s is sequence length or time, and d is ``embedding_dim``.
Returns:
Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.
Raises:
ValueError: if input embedding dimension does not match embedding dimension of module
"""
bsz, seq_len, z_embed_dim = z.shape
if z_embed_dim != self.embedding_dim:
raise ValueError(
f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
)

# Flatten into batch dimension
z_flat = z.view(-1, z_embed_dim)
# Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2
distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2

# Encoding - select closest embedding vectors, shape [b * s, ]
token_ids_flat = torch.argmin(distances, dim=1)

# Quantize - shape [b * s, d]
quantized_flat = self.decode(token_ids_flat)

# Straight through estimator
quantized_flat = z_flat + (quantized_flat - z_flat).detach()

# Reshape to original - [b, s, d] and [b, s]
quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
token_ids = token_ids_flat.view(bsz, seq_len)

return quantized, token_ids

def extra_repr(self) -> str:
return "num_embeddings={}, embedding_dim={}".format(
self.num_embeddings, self.embedding_dim
)

def decode(self, token_ids: Tensor) -> Tensor:
# Returns the embeddings of shape [b, s, d]
return F.embedding(token_ids, self.embedding)
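For reference, a minimal usage sketch (not part of the diff). The codebook initialization is an assumption by the reader: the module itself leaves self.embedding as an uninitialized torch.empty parameter, so the caller must set or train it.

# Minimal usage sketch (not part of the diff); the normal_ init is an
# assumption -- the module leaves the codebook weights uninitialized.
import torch
from torchtune.modules import VectorQuantizedEmbeddings

vq = VectorQuantizedEmbeddings(num_embeddings=4, embedding_dim=5)
torch.nn.init.normal_(vq.embedding)

z = torch.randn(2, 3, 5, requires_grad=True)  # (batch, seq_len, embedding_dim)
quantized, token_ids = vq(z)                  # shapes (2, 3, 5) and (2, 3)
recovered = vq.decode(token_ids)              # same values as quantized

# The straight-through estimator lets gradients reach z even though
# argmin itself is not differentiable:
quantized.sum().backward()
print(z.grad.shape)  # torch.Size([2, 3, 5])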
