Commit: update
Jingjing Xu committed Nov 18, 2024
1 parent 1972d99 commit 8b3ed95
Showing 3 changed files with 64 additions and 21 deletions.
2 changes: 2 additions & 0 deletions i6_models/parts/best_rq/__init__.py
@@ -0,0 +1,2 @@
from .mask import *
from .quantizer import *
56 changes: 42 additions & 14 deletions i6_models/parts/best_rq/mask.py
@@ -1,25 +1,51 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from typing import Optional
import numpy as np

__all__ = ["RandomMask"]


class RandomMask(nn.Module):
def __init__(self, input_dim, mask_replace_val):
"""
randomly mask out consecutive frames in the time dimension; the masked frames can either be
replaced with zeros or with learnable embeddings.
simplified version of the Fairseq compute_mask_indices function,
cf. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
"""

def __init__(
self,
input_dim: int,
mask_replace_val: str,
mask_prob: float,
mask_length: int,
min_masks: int = 0,
):
"""
:param input_dim: feature dimension of the input
:param mask_replace_val: how to replace masked frames, either with zeros or with a learnable embedding
:param mask_prob: fraction of frames to be masked out
:param mask_length: the length of each mask span
:param min_masks: minimum number of masks
"""
super().__init__()

assert mask_replace_val in ["learnable", "zero"], "not implemented yet"
if mask_replace_val == "learnable":
self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
elif mask_replace_val == 0:
elif mask_replace_val == "zero":
self.mask_emb = torch.zeros(input_dim)
self.mask_prob = mask_prob
self.mask_length = mask_length
self.min_masks = min_masks

def forward(
self,
tensor: torch.Tensor,
padding_mask: Optional[torch.Tensor],
mask_prob: float,
mask_length: int,
min_masks: int = 0,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
ndim_batch, ndim_time, _ = tensor.size()

mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)
@@ -34,22 +60,24 @@ def forward(

num_mask = int(
# add a random number for probabilistic rounding
mask_prob * seq_len / float(mask_length)
self.mask_prob * seq_len / float(self.mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
num_mask = max(self.min_masks, num_mask)

min_len = mask_length
min_len = self.mask_length
if seq_len - min_len <= num_mask:
min_len = seq_len - num_mask - 1
mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)

mask_idc = np.asarray([mask_idc[j] + mask_length for j in range(len(mask_idc))])
mask_idcs.append(mask_idc)
mask_idc = np.asarray(
[mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(self.mask_length)]
)
mask_idcs.append(mask_idc)

for i, mask_idc in enumerate(mask_idcs):
mask[i, mask_idc] = True

tensor[mask] = self.mask_emb
tensor[mask] = self.mask_emb.to(tensor.device)

return tensor
return tensor, torch.tensor(mask).to(tensor.device)
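For orientation, a minimal usage sketch of the changed RandomMask interface: the masking hyperparameters (mask_prob, mask_length, min_masks) now live in __init__ rather than forward, and forward returns both the masked tensor and the boolean mask. The shapes, hyperparameter values, and the all-False padding mask below are illustrative assumptions, not part of the commit.

import torch
from i6_models.parts.best_rq import RandomMask

# Illustrative sizes only (not from the commit).
batch, time, feat = 4, 100, 80
features = torch.randn(batch, time, feat)
padding_mask = torch.zeros(batch, time, dtype=torch.bool)  # assumed convention: True marks padded frames; none here

masking = RandomMask(
    input_dim=feat,
    mask_replace_val="learnable",  # or "zero"
    mask_prob=0.15,   # target fraction of frames used as mask-span starts
    mask_length=10,   # each start index is expanded to a span of 10 frames
)

masked_features, mask = masking(features, padding_mask)
# masked_features: (batch, time, feat) with masked frames replaced by the mask embedding
# mask: boolean (batch, time) tensor marking the masked positions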
27 changes: 20 additions & 7 deletions i6_models/parts/best_rq/quantizer.py
@@ -3,22 +3,35 @@
import torch.nn.functional as F
from torch.linalg import vector_norm

__all__ = [
"RandomProjectionQuantizer",
]


class RandomProjectionQuantizer(nn.Module):
def __init__(self, input_dim, cb_dim, cb_vocab):
"""
implements the fixed random projection quantizer from BestRQ
cf. https://arxiv.org/pdf/2202.01855 for the theoretical background
code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
"""

def __init__(self, input_dim, codebook_dim, codebook_num_vars):
"""
:param input_dim: feature dimension of the input
:param codebook_dim: dimension of each codebook entry
:param codebook_num_vars: number of entries (vocab size) of the codebook
"""
super().__init__()

self.input_dim = input_dim
self.cb_dim = cb_dim
self.cb_vocab = cb_vocab

# Section 3.1 "projection matrix A use Xavier initialization"
P_init = torch.empty((input_dim, cb_dim))
# projection matrix uses Xavier initialization
P_init = torch.empty((input_dim, codebook_dim))
self.register_buffer("P", nn.init.xavier_uniform_(P_init))

# normalize random matrix for codebook
self.register_buffer("CB", F.normalize(torch.randn(cb_vocab, cb_dim)))
self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.normalize(x @ self.P)
return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
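A corresponding sketch for RandomProjectionQuantizer with the renamed constructor arguments (cb_dim/cb_vocab became codebook_dim/codebook_num_vars). The sizes below are made up for illustration and are not taken from the commit.

import torch
from i6_models.parts.best_rq import RandomProjectionQuantizer

# Illustrative sizes only (not from the commit).
batch, time, feat = 4, 100, 80
features = torch.randn(batch, time, feat)

quantizer = RandomProjectionQuantizer(
    input_dim=feat,
    codebook_dim=16,         # dimension of each codebook entry
    codebook_num_vars=8192,  # number of codebook entries (target vocabulary size)
)

# Each frame is projected with the fixed matrix P, L2-normalized, and assigned
# the index of the nearest (cosine-closest) codebook entry.
targets = quantizer(features)  # (batch, time) tensor of codebook indices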
