diff --git a/i6_models/parts/best_rq/__init__.py b/i6_models/parts/best_rq/__init__.py
new file mode 100644
index 00000000..fb9a81db
--- /dev/null
+++ b/i6_models/parts/best_rq/__init__.py
@@ -0,0 +1,2 @@
+from .mask import *
+from .quantizer import *
diff --git a/i6_models/parts/best_rq/mask.py b/i6_models/parts/best_rq/mask.py
index 59106a8c..176ec2c3 100644
--- a/i6_models/parts/best_rq/mask.py
+++ b/i6_models/parts/best_rq/mask.py
@@ -1,25 +1,51 @@
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
-from typing import Optional
 import numpy as np


+__all__ = ["RandomMask"]
+

 class RandomMask(nn.Module):
-    def __init__(self, input_dim, mask_replace_val):
+    """
+    Randomly masks out spans of consecutive frames along the time dimension; the masked frames
+    are replaced either with zeros or with a learnable embedding.
+    Simplified version of the Fairseq compute_mask_indices function,
+    cf. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        mask_replace_val: str,
+        mask_prob: float,
+        mask_length: int,
+        min_masks: int = 0,
+    ):
+        """
+        :param input_dim: feature dimension of the input
+        :param mask_replace_val: how masked frames are replaced, either "zero" or "lernable" (learnable embedding)
+        :param mask_prob: fraction of frames to be masked out
+        :param mask_length: length of each mask span in frames
+        :param min_masks: minimum number of mask spans per sequence
+        """
+        super().__init__()
+
+        assert mask_replace_val in ["lernable", "zero"], "not implemented yet"
         if mask_replace_val == "lernable":
             self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
-        elif mask_replace_val == 0:
+        elif mask_replace_val == "zero":
             self.mask_emb = torch.zeros(input_dim)
+        self.mask_prob = mask_prob
+        self.mask_length = mask_length
+        self.min_masks = min_masks

     def forward(
         self,
         tensor: torch.tensor,
         padding_mask: Optional[torch.Tensor],
-        mask_prob: float,
-        mask_length: int,
-        min_masks: int = 0,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         ndim_batch, ndim_time, _ = tensor.size()

         mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
@@ -34,22 +60,24 @@ def forward(
             num_mask = int(
                 # add a random number for probabilistic rounding
-                mask_prob * seq_len / float(mask_length)
+                self.mask_prob * seq_len / float(self.mask_length)
                 + np.random.rand()
             )

-            num_mask = max(min_masks, num_mask)
+            num_mask = max(self.min_masks, num_mask)

-            min_len = mask_length
+            min_len = self.mask_length
             if seq_len - min_len <= num_mask:
                 min_len = seq_len - num_mask - 1

             mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)
-            mask_idc = np.asarray([mask_idc[j] + mask_length for j in range(len(mask_idc))])
-            mask_idcs.append(mask_idc)
+            mask_idc = np.asarray(
+                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(self.mask_length)]
+            )
+            mask_idcs.append(mask_idc)

         for i, mask_idc in enumerate(mask_idcs):
             mask[i, mask_idc] = True

-        tensor[mask] = self.mask_emb
+        tensor[mask] = self.mask_emb.to(tensor.device)

-        return tensor
+        return tensor, mask.to(tensor.device)
diff --git a/i6_models/parts/best_rq/quantizer.py b/i6_models/parts/best_rq/quantizer.py
index 0d039ba5..8633eb42 100644
--- a/i6_models/parts/best_rq/quantizer.py
+++ b/i6_models/parts/best_rq/quantizer.py
@@ -3,22 +3,35 @@
 import torch.nn.functional as F
 from torch.linalg import vector_norm

+__all__ = [
+    "RandomProjectionQuantizer",
+]
+

 class RandomProjectionQuantizer(nn.Module):
-    def __init__(self, input_dim, cb_dim, cb_vocab):
+    """
+    Fixed random-projection quantizer from BEST-RQ,
+    cf. https://arxiv.org/pdf/2202.01855 for the theoretical background.
+    Code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
+    """
+
+    def __init__(self, input_dim: int, codebook_dim: int, codebook_num_vars: int):
+        """
+        :param input_dim: feature dimension of the input
+        :param codebook_dim: dimension of each codebook entry
+        :param codebook_num_vars: number of entries (vocabulary size) of the codebook
+        """
         super().__init__()

         self.input_dim = input_dim
-        self.cb_dim = cb_dim
-        self.cb_vocab = cb_vocab

-        # Section 3.1 "projection matrix A use Xavier initialization"
-        P_init = torch.empty((input_dim, cb_dim))
+        # projection matrix uses Xavier initialization
+        P_init = torch.empty((input_dim, codebook_dim))
         self.register_buffer("P", nn.init.xavier_uniform_(P_init))

         # normalize random matrix for codebook
-        self.register_buffer("CB", F.normalize(torch.randn(cb_vocab, cb_dim)))
+        self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.normalize(x @ self.P)

         return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
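
Not part of the diff: a minimal usage sketch of how the two new parts might be combined in a BEST-RQ-style pre-training step. The feature shape, the hyperparameter values, the encoder placeholder, and passing padding_mask=None for fully valid sequences are illustrative assumptions, not APIs introduced by this PR.

    import torch
    from i6_models.parts.best_rq import RandomMask, RandomProjectionQuantizer

    # assumed input: a batch of 3 feature sequences, 100 frames, 80 feature dims
    features = torch.randn(3, 100, 80)

    masker = RandomMask(input_dim=80, mask_replace_val="lernable", mask_prob=0.15, mask_length=10)
    quantizer = RandomProjectionQuantizer(input_dim=80, codebook_dim=16, codebook_num_vars=8192)

    # targets come from the frozen random projection of the *unmasked* features
    targets = quantizer(features)  # [3, 100] codebook indices

    # mask spans of consecutive frames; padding_mask=None assumes all frames are valid
    masked_features, mask = masker(features.clone(), None)

    # a real setup would feed masked_features through an encoder (not part of this PR)
    # and train with cross-entropy against the targets at the masked positions only:
    # logits = encoder(masked_features)  # [3, 100, 8192], hypothetical encoder output
    # loss = torch.nn.functional.cross_entropy(logits[mask], targets[mask])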