diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py
index 825bbd6cd9..b8c1e83286 100644
--- a/tests/torchtune/modules/test_vq_embeddings.py
+++ b/tests/torchtune/modules/test_vq_embeddings.py
@@ -27,15 +27,24 @@ def embedding_dim(self):
         return 5
 
     @pytest.fixture
-    def codebook(self, num_embeddings, embedding_dim):
-        def vq(learnable):
-            return VectorQuantizedEmbeddings(
-                num_embeddings=num_embeddings,
-                embedding_dim=embedding_dim,
-                decay=0.3,
-                learnable=learnable,
-            )
-
+    def embedding_weights(self):
+        # This is 4x5
+        return tensor(
+            [
+                [1.0, 0.0, -1.0, -1.0, 2.0],
+                [2.0, -2.0, 0.0, 0.0, 1.0],
+                [2.0, 1.0, 0.0, 1.0, 1.0],
+                [-1.0, -2.0, 0.0, 2.0, 0.0],
+            ]
+        )
+
+    @pytest.fixture
+    def codebook(self, num_embeddings, embedding_dim, embedding_weights):
+        vq = VectorQuantizedEmbeddings(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+        )
+        vq.embedding.data = embedding_weights
         return vq
 
     @pytest.fixture
@@ -59,22 +68,8 @@ def encoded(self):
 
         return encoded
 
-    @pytest.fixture
-    def embedding_weights(self):
-        # This is 4x5
-        return tensor(
-            [
-                [1.0, 0.0, -1.0, -1.0, 2.0],
-                [2.0, -2.0, 0.0, 0.0, 1.0],
-                [2.0, 1.0, 0.0, 1.0, 1.0],
-                [-1.0, -2.0, 0.0, 2.0, 0.0],
-            ]
-        )
-
-    def test_quantized_output(self, codebook, encoded, embedding_weights):
-        vq = codebook(learnable=False)
-        vq.embedding = embedding_weights
-        actual = vq(encoded)
+    def test_quantized_output(self, codebook, encoded):
+        actual = codebook(encoded)
 
         expected_quantized = tensor(
             [
@@ -97,77 +92,11 @@ def test_quantized_output(self, codebook, encoded, embedding_weights):
         assert_expected(actual[0], expected_quantized)
         assert_expected(actual[1], expected_token_ids)
 
-    def test_ema_update_embedding(self, codebook, encoded, embedding_weights):
-        vq = codebook(learnable=True)
-        vq.embedding = embedding_weights
-        encoded_flat = encoded.view(-1, encoded.shape[-1])
-        distances = torch.cdist(encoded_flat, vq.embedding, p=2.0) ** 2
-        codebook_indices = torch.argmin(distances, dim=1)
-        vq._ema_update_embedding(encoded_flat, codebook_indices)
-
-        actual_weight = vq.embedding
-        expected_weight = tensor(
-            [
-                [2.0000, -1.0000, 0.0000, 2.0000, 0.0000],
-                [2.0000, 1.0000, 0.0000, 1.0000, 1.0000],
-                [0.5647, 1.3760, -0.3936, 1.1213, -0.7635],
-                [1.0000, 0.0000, -1.0000, -1.0000, 1.0000],
-            ]
-        )
-        assert_expected(actual_weight, expected_weight, rtol=0.0, atol=1e-4)
-
-        actual_code_avg = vq.code_avg
-        expected_code_avg = tensor(
-            [
-                [0.4176, 0.3790, -0.7551, -0.6548, 1.0419],
-                [1.3309, -0.3437, 0.2303, 1.1865, 0.1305],
-                [1.1859, 2.8897, -0.8265, 2.3547, -1.6033],
-                [-0.9834, -0.7490, -0.3521, 0.5825, 0.4301],
-            ]
-        )
-        assert_expected(actual_code_avg, expected_code_avg, rtol=0.0, atol=1e-4)
-
-        actual_code_usage = vq.code_usage
-        expected_code_usage = tensor([0.7000, 0.7000, 2.1000, 0.7000])
-        assert_expected(actual_code_usage, expected_code_usage, rtol=0.0, atol=1e-4)
-
-    def test_codebook_restart(self, codebook, encoded, embedding_weights):
-        vq = codebook(learnable=True)
-        vq.embedding = embedding_weights
-        # Use only embedding vector at index = 1 and force restarts.
-        # Slightly modify encoded_flat to make sure vectors restart to something new
-        encoded_flat = encoded.view(-1, encoded.shape[-1])
-        encoded_noise = encoded_flat + torch.randn_like(encoded_flat)
-        codebook_indices_low_usage = torch.ones(encoded_flat.shape[0], dtype=torch.long)
-        vq._ema_update_embedding(encoded_noise, codebook_indices_low_usage)
-
-        # Check if embedding contains restarts
-        for i, emb in enumerate(vq.embedding):
-            # We used only emb vector with index = 1, so check it was not restarted
-            if i == 1:
-                assert_expected(
-                    emb,
-                    vq.code_avg[1] / vq.code_usage[1],
-                    rtol=0,
-                    atol=1e-4,
-                )
-            # Compare each embedding vector to each encoded vector.
-            # If at least one match, then restart happened.
-            else:
-                assert any(
-                    [
-                        torch.isclose(emb, enc, rtol=0, atol=1e-4).all()
-                        for enc in encoded_noise
-                    ]
-                ), "embedding restarted from encoder output incorrectly"
-
-    def test_lookup(self, codebook, embedding_weights):
-        vq = codebook(learnable=False)
-        vq.embedding = embedding_weights
+    def test_decode(self, codebook):
         indices_flat = tensor([[0, 1]])  # (b, seq_len)
         indices_shaped = tensor([[[0, 1], [2, 3]]])  # (b, shape)
-        actual_quantized_flat = vq.lookup(indices_flat)
-        actual_quantized = vq.lookup(indices_shaped)
+        actual_quantized_flat = codebook.decode(indices_flat)
+        actual_quantized = codebook.decode(indices_shaped)
         expected_quantized_flat = tensor(
             [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
         )
diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py
index 12a4918f73..14d6cef995 100644
--- a/torchtune/modules/vq_embeddings.py
+++ b/torchtune/modules/vq_embeddings.py
@@ -17,8 +17,8 @@ class VectorQuantizedEmbeddings(nn.Module):
     and performs a nearest-neighbor lookup in the embedding space.
     Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
     to generate high-fidelity images, videos, and audio data.
-    The embedding weights are trained with exponential moving average updates as described
-    in original paper.
+
+    This module currently does not support pre-training of the embeddings via EMA.
 
     Code was adapted from torchmultimodal's `Codebook module
     <https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/modules/layers/codebook.py>`_.
@@ -26,99 +26,17 @@ class VectorQuantizedEmbeddings(nn.Module):
     Args:
         num_embeddings (int): Number of vectors in the embedding space.
         embedding_dim (int): Dimensionality of the embedding vectors.
-        decay (float, optional): Factor used in exponential moving average update of the embeddings.
-            Defaults to ``0.99``.
-        codebook_usage_threshold (float, optional): Threshold for the average number of times an embedding vector
-            is chosen below which it will be re-initialized. Defaults to ``1.0``.
-        learnable (bool): If True, register embedding weights, codebook usage, and codebook average to buffer
-            for EMA updates during training. If False, only register embedding weights as an nn.Parameter, for use
-            in a frozen module. Default is False.
-        epsilon (float, optional): Noise used in Laplace smoothing of codebook usage. Defaults to ``1e-7``.
     """
 
     def __init__(
         self,
         num_embeddings: int,
         embedding_dim: int,
-        decay: float = 0.99,
-        codebook_usage_threshold: float = 1.0,
-        learnable: bool = False,
-        epsilon: float = 1e-7,
     ) -> None:
         super().__init__()
-        # Embedding weights and parameters for EMA update will be registered to buffer, as they
-        # will not be updated by the optimizer but are still model parameters.
-        # code_usage and code_avg correspond with N and m, respectively, from Oord et al.
-        randn_init_embedding = torch.randn(num_embeddings, embedding_dim)
-        self.register_buffer("embedding", randn_init_embedding.clone())
-        if learnable:
-            self.register_buffer("code_usage", torch.zeros(num_embeddings))
-            self.register_buffer("code_avg", randn_init_embedding.clone())
-
-        self.embedding_dim = embedding_dim
+        self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
         self.num_embeddings = num_embeddings
-        self.learnable = learnable
-
-        self._decay = decay
-        # Used in Laplace smoothing of code usage
-        self._epsilon = epsilon
-
-        # Threshold for randomly reseting unused embedding vectors
-        self.codebook_usage_threshold = codebook_usage_threshold
-
-    def _tile(self, x: Tensor, n: int) -> Tensor:
-        # Repeat vectors in x if x has less than n vectors
-        num_vectors, num_channels = x.shape
-        if num_vectors < n:
-            num_repeats = (n + num_vectors - 1) // num_vectors
-            # Add a small amount of noise to repeated vectors
-            std = 0.01 / torch.sqrt(torch.tensor(num_channels))
-            x = x.repeat(num_repeats, 1)
-            x = x + torch.randn_like(x) * std
-        return x
-
-    def _get_random_vectors(self, x: Tensor, n: int) -> Tensor:
-        # Gets n random row vectors from 2D tensor x
-        x_tiled = self._tile(x, n)
-        idx = torch.randperm(x_tiled.shape[0])
-        x_rand = x_tiled[idx][:n]
-        return x_rand
-
-    def _ema_update_embedding(self, z: Tensor, codebook_indices: Tensor) -> None:
-        # Closed form solution of codebook loss, ||e - E(x)||^2, is simply the average
-        # of the encoder output. However, we can't compute this in minibatches, so we
-        # must use exponential moving average.
-
-        # Convert indices to one hot encoding
-        codebook_onehot = nn.functional.one_hot(
-            codebook_indices, num_classes=self.num_embeddings
-        ).type(torch.float)
-        # Count how often each embedding vector was looked up
-        codebook_selection_count = torch.sum(codebook_onehot, 0)
-        # Update usage value for each embedding vector
-        self.code_usage.mul_(self._decay).add_(
-            codebook_selection_count, alpha=(1 - self._decay)
-        )
-        # Laplace smoothing of codebook usage - to prevent zero counts
-        n = torch.sum(self.code_usage)
-        self.code_usage.add_(self._epsilon).divide_(
-            n + self.num_embeddings * self._epsilon
-        ).mul_(n)
-        # Get all encoded vectors attracted to each embedding vector
-        encoded_per_codebook = torch.matmul(codebook_onehot.t(), z)
-        # Update each embedding vector with new encoded vectors that are attracted to it,
-        # divided by its usage to yield the mean of encoded vectors that choose it
-        self.code_avg.mul_(self._decay).add_(
-            encoded_per_codebook, alpha=(1 - self._decay)
-        )
-        self.embedding = self.code_avg / self.code_usage.unsqueeze(1)
-        # Reset any embedding vectors that fall below threshold usage with random encoded vectors
-        z_rand = self._get_random_vectors(z, self.num_embeddings)
-        self.embedding = torch.where(
-            self.code_usage.unsqueeze(1) >= self.codebook_usage_threshold,
-            self.embedding,
-            z_rand,
-        )
+        self.embedding_dim = embedding_dim
 
     def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
         """
@@ -147,12 +65,7 @@ def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
         token_ids_flat = torch.argmin(distances, dim=1)
 
         # Quantize - shape [b * s, d]
-        quantized_flat = self.lookup(token_ids_flat)
-
-        # Use exponential moving average to update the embedding instead of a codebook loss,
-        # as suggested by Oord et al. 2017 and Razavi et al. 2019.
-        if self.training and self.learnable:
-            self._ema_update_embedding(z_flat, token_ids_flat)
+        quantized_flat = self.decode(token_ids_flat)
 
         # Straight through estimator
         quantized_flat = z_flat + (quantized_flat - z_flat).detach()
@@ -168,6 +81,6 @@ def extra_repr(self) -> str:
             self.num_embeddings, self.embedding_dim
         )
 
-    def lookup(self, token_ids: Tensor) -> Tensor:
+    def decode(self, token_ids: Tensor) -> Tensor:
         # Returns the embeddings of shape [b, s, d]
         return F.embedding(token_ids, self.embedding)
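
Below is a minimal usage sketch of VectorQuantizedEmbeddings after this change, mirroring what the updated tests exercise. The shapes and randomly generated values are illustrative only; since the codebook is now an uninitialized nn.Parameter, callers are expected to load or assign trained weights (the test fixture sets vq.embedding.data directly).

import torch
from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings

# Codebook with 4 embedding vectors of dimension 5, matching the test fixtures.
vq = VectorQuantizedEmbeddings(num_embeddings=4, embedding_dim=5)
# The embedding is created with torch.empty, so populate it before use
# (random values here for illustration; real usage would load trained weights).
vq.embedding.data = torch.randn(4, 5)

z = torch.randn(2, 3, 5)      # encoder output: [batch, seq_len, embedding_dim]
quantized, token_ids = vq(z)  # nearest-neighbor quantization with straight-through gradients
codes = vq.decode(token_ids)  # map token ids back to codebook vectors (replaces the old lookup())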