
Commit

remove EMA training
RdoubleA committed Dec 2, 2024
1 parent 01cbd6b commit 1a0841a
Showing 2 changed files with 28 additions and 186 deletions.
115 changes: 22 additions & 93 deletions tests/torchtune/modules/test_vq_embeddings.py
@@ -27,15 +27,24 @@ def embedding_dim(self):
return 5

@pytest.fixture
def codebook(self, num_embeddings, embedding_dim):
def vq(learnable):
return VectorQuantizedEmbeddings(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
decay=0.3,
learnable=learnable,
)
def embedding_weights(self):
# This is 4x5
return tensor(
[
[1.0, 0.0, -1.0, -1.0, 2.0],
[2.0, -2.0, 0.0, 0.0, 1.0],
[2.0, 1.0, 0.0, 1.0, 1.0],
[-1.0, -2.0, 0.0, 2.0, 0.0],
]
)

@pytest.fixture
def codebook(self, num_embeddings, embedding_dim, embedding_weights):
vq = VectorQuantizedEmbeddings(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
)
vq.embedding.data = embedding_weights
return vq

@pytest.fixture
@@ -59,22 +68,8 @@ def encoded(self):

return encoded

@pytest.fixture
def embedding_weights(self):
# This is 4x5
return tensor(
[
[1.0, 0.0, -1.0, -1.0, 2.0],
[2.0, -2.0, 0.0, 0.0, 1.0],
[2.0, 1.0, 0.0, 1.0, 1.0],
[-1.0, -2.0, 0.0, 2.0, 0.0],
]
)

def test_quantized_output(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=False)
vq.embedding = embedding_weights
actual = vq(encoded)
def test_quantized_output(self, codebook, encoded):
actual = codebook(encoded)

expected_quantized = tensor(
[
@@ -97,77 +92,11 @@ def test_quantized_output(self, codebook, encoded, embedding_weights):
assert_expected(actual[0], expected_quantized)
assert_expected(actual[1], expected_token_ids)

def test_ema_update_embedding(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=True)
vq.embedding = embedding_weights
encoded_flat = encoded.view(-1, encoded.shape[-1])
distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2
codebook_indices = torch.argmin(distances, dim=1)
vq._ema_update_embedding(encoded_flat, codebook_indices)

actual_weight = vq.embedding
expected_weight = tensor(
[
[2.0000, -1.0000, 0.0000, 2.0000, 0.0000],
[2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
[0.5647, 1.3760, -0.3936, 1.1213, -0.7635],
[1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
]
)
assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)

actual_code_avg = vq.code_avg
expected_code_avg = tensor(
[
[0.4176, 0.3790, -0.7551, -0.6548, 1.0419],
[1.3309, -0.3437, 0.2303, 1.1865, 0.1305],
[1.1859, 2.8897, -0.8265, 2.3547, -1.6033],
[-0.9834, -0.7490, -0.3521, 0.5825, 0.4301],
]
)
assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)

actual_code_usage = vq.code_usage
expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000])
assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)

def test_codebook_restart(self, codebook, encoded, embedding_weights):
vq = codebook(learnable=True)
vq.embedding = embedding_weights
# Use only embedding vector at index = 1 and force restarts.
# Slightly modify encoded_flat to make sure vectors restart to something new
encoded_flat = encoded.view(-1, encoded.shape[-1])
encoded_noise = encoded_flat + torch.randn_like(encoded_flat)
codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage)

# Check if embedding contains restarts
for i, emb in enumerate(vq.embedding):
# We used only emb vector with index = 1, so check it was not restarted
if i == 1:
assert_expected(
emb,
vq.code_avg[1] / vq.code_usage[1],
rtol=0,
atol=1e-4,
)
# Compare each embedding vector to each encoded vector.
# If at least one match, then restart happened.
else:
assert any(
[
torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
for enc in encoded_noise
]
), "embedding restarted from encoder output incorrectly"

def test_lookup(self, codebook, embedding_weights):
vq = codebook(learnable=False)
vq.embedding = embedding_weights
def test_decode(self, codebook):
indices_flat = tensor([[0, 1]]) # (b, seq_len)
indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape)
actual_quantized_flat = vq.lookup(indices_flat)
actual_quantized = vq.lookup(indices_shaped)
actual_quantized_flat = codebook.decode(indices_flat)
actual_quantized = codebook.decode(indices_shaped)
expected_quantized_flat = tensor(
[[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
)
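For reference, a minimal sketch of how the updated fixtures exercise the module after this change. It assumes only what the diff shows: torch, the tensor constructor, and the torchtune.modules.vq_embeddings import path; the encoded input below is a random placeholder rather than the fixture's exact values.

import torch
from torch import tensor
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings

num_embeddings, embedding_dim = 4, 5
codebook = VectorQuantizedEmbeddings(
    num_embeddings=num_embeddings,
    embedding_dim=embedding_dim,
)
# Weights are now preset directly instead of being learned via EMA, mirroring the new fixture.
codebook.embedding.data = tensor(
    [
        [1.0, 0.0, -1.0, -1.0, 2.0],
        [2.0, -2.0, 0.0, 0.0, 1.0],
        [2.0, 1.0, 0.0, 1.0, 1.0],
        [-1.0, -2.0, 0.0, 2.0, 0.0],
    ]
)

encoded = torch.randn(2, 3, embedding_dim)  # placeholder encoder output, shape (batch, seq_len, dim)
quantized, token_ids = codebook(encoded)    # nearest-neighbor lookup + straight-through estimator
assert quantized.shape == encoded.shape

indices = tensor([[0, 1]])                  # (batch, seq_len) integer codebook ids, as in test_decode
looked_up = codebook.decode(indices)        # embeddings of shape (batch, seq_len, embedding_dim)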
99 changes: 6 additions & 93 deletions torchtune/modules/vq_embeddings.py
@@ -17,108 +17,26 @@ class VectorQuantizedEmbeddings(nn.Module):
and performs a nearest-neighbor lookup in the embedding space.
Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
to generate high-fidelity images, videos, and audio data.
The embedding weights are trained with exponential moving average updates as described
in the original paper.
This module currently does not support pre-training of the embeddings via EMA.
Code was adapted from torchmultimodal's `Codebook module
<https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.
Args:
num_embeddings (int): Number of vectors in the embedding space.
embedding_dim (int): Dimensionality of the embedding vectors.
decay (float, optional): Factor used in exponential moving average update of the embeddings.
Defaults to ``0.99``.
codebook_usage_threshold (float, optional): Threshold for the average number of times an embedding vector
is chosen below which it will be re-initialized. Defaults to ``1.0``.
learnable (bool): If True, register embedding weights, codebook usage, and codebook average to buffer
for EMA updates during training. If False, only register embedding weights as an nn.Parameter, for use
in a frozen module. Default is False.
epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``.
"""

def __init__(
self,
num_embeddings: int,
embedding_dim: int,
decay: float = 0.99,
codebook_usage_threshold: float = 1.0,
learnable: bool = False,
epsilon: float = 1e-7,
) -> None:
super().__init__()
# Embedding weights and parameters for EMA update will be registered to buffer, as they
# will not be updated by the optimizer but are still model parameters.
# code_usage and code_avg correspond with N and m, respectively, from Oord et al.
randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
self.register_buffer("embedding", randn_init_embedding.clone())
if learnable:
self.register_buffer("code_usage", torch.zeros(num_embeddings))
self.register_buffer("code_avg", randn_init_embedding.clone())

self.embedding_dim = embedding_dim
self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
self.num_embeddings = num_embeddings
self.learnable = learnable

self._decay = decay
# Used in Laplace smoothing of code usage
self._epsilon = epsilon

# Threshold for randomly resetting unused embedding vectors
self.codebook_usage_threshold = codebook_usage_threshold

def _tile(self, x: Tensor, n: int) -> Tensor:
# Repeat vectors in x if x has less than n vectors
num_vectors, num_channels = x.shape
if num_vectors < n:
num_repeats = (n + num_vectors - 1) // num_vectors
# Add a small amount of noise to repeated vectors
std = 0.01 / torch.sqrt(torch.tensor(num_channels))
x = x.repeat(num_repeats, 1)
x = x + torch.randn_like(x) * std
return x

def _get_random_vectors(self, x: Tensor, n: int) -> Tensor:
# Gets n random row vectors from 2D tensor x
x_tiled = self._tile(x, n)
idx = torch.randperm(x_tiled.shape[0])
x_rand = x_tiled[idx][:n]
return x_rand

def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None:
# Closed form solution of codebook loss, ||e - E(x)||^2, is simply the average
# of the encoder output. However, we can't compute this in minibatches, so we
# must use exponential moving average.

# Convert indices to one hot encoding
codebook_onehot = nn.functional.one_hot(
codebook_indices, num_classes=self.num_embeddings
).type(torch.float)
# Count how often each embedding vector was looked up
codebook_selection_count = torch.sum(codebook_onehot, 0)
# Update usage value for each embedding vector
self.code_usage.mul_(self._decay).add_(
codebook_selection_count, alpha=(1 - self._decay)
)
# Laplace smoothing of codebook usage - to prevent zero counts
n = torch.sum(self.code_usage)
self.code_usage.add_(self._epsilon).divide_(
n + self.num_embeddings * self._epsilon
).mul_(n)
# Get all encoded vectors attracted to each embedding vector
encoded_per_codebook = torch.matmul(codebook_onehot.t(), z)
# Update each embedding vector with new encoded vectors that are attracted to it,
# divided by its usage to yield the mean of encoded vectors that choose it
self.code_avg.mul_(self._decay).add_(
encoded_per_codebook, alpha=(1 - self._decay)
)
self.embedding = self.code_avg / self.code_usage.unsqueeze(1)
# Reset any embedding vectors that fall below threshold usage with random encoded vectors
z_rand = self._get_random_vectors(z, self.num_embeddings)
self.embedding = torch.where(
self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold,
self.embedding,
z_rand,
)
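For reference, the bookkeeping that the removed _ema_update_embedding implemented follows the exponential moving average update from Oord et al. (2017). With decay $\gamma$, per-code usage $N_i$ (code_usage), running sum $m_i$ (code_avg), and $n_i$ encoder outputs assigned to code $i$ in the batch, the update is, aside from the Laplace smoothing and dead-code restarts above:

$$N_i \leftarrow \gamma N_i + (1-\gamma)\, n_i, \qquad m_i \leftarrow \gamma m_i + (1-\gamma) \sum_{j:\, c_j = i} z_j, \qquad e_i = \frac{m_i}{N_i}$$

where $c_j$ is the codebook index chosen for encoder output $z_j$ and $e_i$ is the resulting embedding vector.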
self.embedding_dim = embedding_dim

def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
"""
@@ -147,12 +65,7 @@ def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
token_ids_flat = torch.argmin(distances, dim=1)

# Quantize - shape [b * s, d]
quantized_flat = self.lookup(token_ids_flat)

# Use exponential moving average to update the embedding instead of a codebook loss,
# as suggested by Oord et al. 2017 and Razavi et al. 2019.
if self.training and self.learnable:
self._ema_update_embedding(z_flat, token_ids_flat)
quantized_flat = self.decode(token_ids_flat)

# Straight through estimator
quantized_flat = z_flat + (quantized_flat - z_flat).detach()
@@ -168,6 +81,6 @@ def extra_repr(self) -> str:
self.num_embeddings, self.embedding_dim
)

def lookup(self, token_ids: Tensor) -> Tensor:
def decode(self, token_ids: Tensor) -> Tensor:
# Returns the embeddings of shape [b, s, d]
return F.embedding(token_ids, self.embedding)
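
With EMA updates removed and the embedding now registered as an nn.Parameter, the codebook loss mentioned in the removed comments (Oord et al. 2017, also Razavi et al. 2019) is one way to train the weights going forward. A purely illustrative sketch, not part of this commit: it assumes forward returns integer token ids usable by decode, and commitment_cost is a hypothetical hyperparameter.

import torch.nn.functional as F

def vq_loss(z, codebook, commitment_cost=0.25):
    # Straight-through output: gradients flow back to the encoder through quantized.
    quantized, token_ids = codebook(z)
    # Raw codebook vectors (no straight-through), so gradients reach the embedding weights.
    codes = codebook.decode(token_ids)
    # Codebook loss ||sg[E(x)] - e||^2 pulls embeddings toward the encoder outputs.
    codebook_loss = F.mse_loss(codes, z.detach())
    # Commitment loss ||E(x) - sg[e]||^2 keeps encoder outputs close to their chosen codes.
    commitment_loss = F.mse_loss(z, codes.detach())
    return quantized, codebook_loss + commitment_cost * commitment_loss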
