Refactor wav2vec2 masker #830

Draft · wants to merge 1 commit into base: main
4 changes: 2 additions & 2 deletions src/fairseq2/models/wav2vec2/asr/factory.py
@@ -17,7 +17,7 @@
Wav2Vec2EncoderBuilder,
Wav2Vec2EncoderConfig,
)
from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker
from fairseq2.models.wav2vec2.masker import StandardWav2Vec2Masker, Wav2Vec2Masker
from fairseq2.typing import DataType, Device

WAV2VEC2_ASR_FAMILY: Final = "wav2vec2_asr"
@@ -150,7 +150,7 @@ def build_masker(self) -> Wav2Vec2Masker | None:
if not self._config.use_masking:
return None

return Wav2Vec2Masker(
return StandardWav2Vec2Masker(
self._config.encoder_config.model_dim,
self._config.temporal_mask_span_len,
self._config.max_temporal_mask_prob,
4 changes: 2 additions & 2 deletions src/fairseq2/models/wav2vec2/factory.py
@@ -20,7 +20,7 @@
Wav2Vec2FeatureExtractor,
)
from fairseq2.models.wav2vec2.frontend import Wav2Vec2Frontend
from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker
from fairseq2.models.wav2vec2.masker import StandardWav2Vec2Masker, Wav2Vec2Masker
from fairseq2.models.wav2vec2.model import Wav2Vec2Model
from fairseq2.models.wav2vec2.position_encoder import (
Wav2Vec2PositionEncoder,
@@ -305,7 +305,7 @@ def build_model(self) -> Wav2Vec2Model:

def build_masker(self) -> Wav2Vec2Masker:
"""Build a feature masker."""
return Wav2Vec2Masker(
return StandardWav2Vec2Masker(
self._config.encoder_config.model_dim,
self._config.temporal_mask_span_len,
self._config.max_temporal_mask_prob,
70 changes: 48 additions & 22 deletions src/fairseq2/models/wav2vec2/masker.py
@@ -6,23 +6,51 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import final

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Module, Parameter
from typing_extensions import override

from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.utils.mask import compute_row_mask
from fairseq2.nn.utils.mask import RowMaskFactory, compute_row_mask
from fairseq2.typing import DataType, Device


class Wav2Vec2Masker(Module, ABC):
"""Masks extracted wav2vec 2.0 features."""

@abstractmethod
def forward(
self, seqs: Tensor, padding_mask: PaddingMask | None
) -> tuple[Tensor, Tensor]:
"""
:param seqs:
The sequences to mask. *Shape:* :math:`(N,S,M)`, where :math:`N` is
the batch size, :math:`S` is the sequence length, and :math:`M` is
the dimensionality of the model.
:param padding_mask:
The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N`
is the batch size and :math:`S` is the sequence length.

:returns:
- The input sequences with mask applied. *Shape:* Same as ``seqs``.
- The temporal mask that has been applied to ``seqs``. *Shape:*
:math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is
the sequence length.
"""


@final
class Wav2Vec2Masker(Module):
"""Masks extracted features as described in Section 3.1 of
class StandardWav2Vec2Masker(Wav2Vec2Masker):
"""Masks extracted wav2vec 2.0 features as described in Section 3.1 of
:cite:t:`https://doi.org/10.48550/arxiv.2006.11477`."""

mask_factory: RowMaskFactory
temporal_span_len: int
max_temporal_mask_prob: float
temporal_mask_embed: Parameter
@@ -39,6 +67,7 @@ def __init__(
max_spatial_mask_prob: float = 0.0,
min_num_spatial_mask_spans: int = 2,
*,
mask_factory: RowMaskFactory | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
@@ -56,9 +85,14 @@ def __init__(
:param max_spatial_mask_prob:
The maximum probability of masking a feature. Note that, due to mask
span overlap, the effective probability will be lower.
:param mask_factory:
The row mask factory. If ``None``, :func:`compute_row_mask` will be
used.
"""
super().__init__()

self.mask_factory = mask_factory or compute_row_mask

if max_temporal_mask_prob == 0.0:
raise ValueError("`max_temporal_mask_prob` must be greater than 0.")

@@ -80,29 +114,14 @@ def reset_parameters(self) -> None:
"""Reset the parameters and buffers of the module."""
nn.init.uniform_(self.temporal_mask_embed)

@override
def forward(
self, seqs: Tensor, padding_mask: PaddingMask | None
) -> tuple[Tensor, Tensor]:
"""
:param seqs:
The sequences to mask. *Shape:* :math:`(N,S,M)`, where :math:`N` is
the batch size, :math:`S` is the sequence length, and :math:`M` is
the dimensionality of the model.
:param padding_mask:
The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N`
is the batch size and :math:`S` is the sequence length.

:returns:
- The input sequences with mask applied. *Shape:* Same as ``seqs``.
- The temporal mask that has been applied to ``seqs``. *Shape:*
:math:`(N,S)`, where :math:`N` is the batch size and :math:`S` is
the sequence length.
"""
batch_size, seq_len, model_dim = seqs.shape

# Temporal mask over time steps.
temporal_mask = compute_row_mask(
temporal_mask = self.mask_factory(
shape=(batch_size, seq_len),
span_len=self.temporal_span_len,
max_mask_prob=self.max_temporal_mask_prob,
@@ -118,7 +137,7 @@ def forward(
if self.max_spatial_mask_prob > 0.0:
# Spatial mask over features.
# (N, M)
spatial_mask = compute_row_mask(
spatial_mask = self.mask_factory(
shape=(batch_size, model_dim),
span_len=self.spatial_span_len,
max_mask_prob=self.max_spatial_mask_prob,
@@ -137,7 +156,7 @@

def extra_repr(self) -> str:
""":meta private:"""
return (
s = (
f"temporal_span_len={self.temporal_span_len}, "
f"max_temporal_mask_prob={self.max_temporal_mask_prob}, "
f"min_num_temporal_mask_spans={self.min_num_temporal_mask_spans}, "
@@ -146,6 +165,13 @@ def extra_repr(self) -> str:
f"min_num_spatial_mask_spans={self.min_num_spatial_mask_spans}"
)

if self.mask_factory is not compute_row_mask:
mask_factory = getattr(self.mask_factory, "__name__", self.mask_factory)

s = f"{s}, mask_factory={mask_factory}"

return s


def extract_masked_elements(seqs: Tensor, temporal_mask: Tensor) -> Tensor:
"""Extract masked elements from ``seqs``.
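The new abstract base class turns the masker into an extension point: any `Module` implementing `forward` with the signature above can stand in for the standard masker. A minimal sketch, assuming only the interface shown in this diff (the class itself is hypothetical, e.g. as an ablation baseline):

```python
import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.models.wav2vec2.masker import Wav2Vec2Masker
from fairseq2.nn.padding import PaddingMask


class NoOpWav2Vec2Masker(Wav2Vec2Masker):
    """A hypothetical masker that masks nothing, e.g. for ablation runs."""

    @override
    def forward(
        self, seqs: Tensor, padding_mask: PaddingMask | None
    ) -> tuple[Tensor, Tensor]:
        batch_size, seq_len, _ = seqs.shape

        # An all-False temporal mask means no positions are replaced.
        temporal_mask = torch.zeros(
            (batch_size, seq_len), device=seqs.device, dtype=torch.bool
        )

        return seqs, temporal_mask
```

Since `build_masker()` in both factories is the single construction site, swapping in such a masker would only require overriding that method in a builder subclass.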
4 changes: 1 addition & 3 deletions src/fairseq2/models/wav2vec2/vector_quantizer.py
@@ -8,7 +8,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import final

Check failure on line 11 in src/fairseq2/models/wav2vec2/vector_quantizer.py (GitHub Actions / Lint Python / Lint): 'typing.final' imported but unused

import torch
import torch.nn as nn
@@ -62,7 +62,6 @@
pass


@final
class GumbelVectorQuantizer(VectorQuantizer):
"""Quantizes incoming data using Gumbel-Softmax."""

@@ -144,7 +143,7 @@
self.num_updates.zero_()

@override
def forward(self, x: Tensor) -> "GumbelVectorQuantizerOutput":
def forward(self, x: Tensor) -> GumbelVectorQuantizerOutput:
current_temp = self._compute_current_temp()

bsz, tsz, fsz = x.shape
@@ -221,7 +220,6 @@
nn.init.zeros_(proj.bias)


@final
@dataclass
class GumbelVectorQuantizerOutput(VectorQuantizerOutput):
cb: Tensor
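With `@final` removed, `GumbelVectorQuantizer` and its output dataclass become subclassable, which appears to be the point of this change (note the now-unused `typing.final` import flagged by the lint job above). A hypothetical sketch under that assumption:

```python
from torch import Tensor
from typing_extensions import override

from fairseq2.models.wav2vec2.vector_quantizer import (
    GumbelVectorQuantizer,
    GumbelVectorQuantizerOutput,
)


class InstrumentedGumbelVectorQuantizer(GumbelVectorQuantizer):
    """A hypothetical subclass that hooks into quantization for logging."""

    @override
    def forward(self, x: Tensor) -> GumbelVectorQuantizerOutput:
        output = super().forward(x)

        # Inspect or log codebook statistics here (e.g. code utilization).

        return output
```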
59 changes: 38 additions & 21 deletions src/fairseq2/nn/utils/mask.py
@@ -6,6 +6,8 @@

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor

@@ -28,6 +30,37 @@ def to_float_mask(mask: Tensor, dtype: DataType | None = None) -> Tensor:
return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, -torch.inf)


class RowMaskFactory(Protocol):
def __call__(
self,
shape: tuple[int, int],
span_len: int,
max_mask_prob: float,
row_lens: Tensor | None = None,
min_num_spans: int = 0,
device: Device | None = None,
) -> Tensor | None:
"""Compute a random row mask of the specified shape.

:param shape:
The shape of the mask.
:param span_len:
The length of each mask span.
:param max_mask_prob:
The maximum probability of masking an element in a row.
:param row_lens:
The length of each row. *Shape:* :math:`(R)`, where :math:`R` is the
number of rows.
:param min_num_spans:
The minimum number of mask spans per row.
:param device:
The device on which to initialize the mask.

:returns:
The boolean row mask. *Shape:* ``shape``.
"""


def compute_row_mask(
shape: tuple[int, int],
span_len: int,
@@ -36,27 +69,11 @@
min_num_spans: int = 0,
device: Device | None = None,
) -> Tensor | None:
"""Compute a random row mask of the specified shape.

:param shape:
The shape of the mask.
:param span_len:
The length of each mask span.
:param max_mask_prob:
The maximum probability of masking an element in a row. Note that, due
to mask span overlap, the effective probability will be lower. The
implementation also guarantees that there will always be at least one
unmasked element in each row.
:param row_lens:
The length of each row. *Shape:* :math:`(R)`, where :math:`R` is the
number of rows.
:param min_num_spans:
The minimum number of mask spans per row.
:param device:
The device on which to initialize the mask.

:returns:
The boolean row mask. *Shape:* ``shape``.
"""Implements the :class:`RowMaskFactory` protocol.

Note that, due to mask span overlap, the effective mask probability will be
lower than ``max_mask_prob``. The implementation also guarantees that there
will always be at least one unmasked element in each row.
"""
num_rows, max_row_len = shape

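Because `RowMaskFactory` is a `Protocol`, any callable with a matching signature qualifies; no subclassing or registration is needed. Below is a minimal sketch of a deterministic factory (hypothetical, e.g. for unit tests where random masking makes assertions awkward), assuming only the protocol shown above:

```python
import torch
from torch import Tensor

from fairseq2.typing import Device


def fixed_stride_row_mask(
    shape: tuple[int, int],
    span_len: int,
    max_mask_prob: float,
    row_lens: Tensor | None = None,
    min_num_spans: int = 0,
    device: Device | None = None,
) -> Tensor | None:
    """Deterministically mask every other span of ``span_len`` elements.

    ``max_mask_prob`` and ``min_num_spans`` are accepted for protocol
    compatibility but ignored by this toy factory.
    """
    num_rows, max_row_len = shape

    indices = torch.arange(max_row_len, device=device)

    # Positions whose 0-based span index is even are masked.
    row = (indices // span_len) % 2 == 0

    mask = row.expand(num_rows, -1).clone()

    if row_lens is not None:
        # Never mask beyond the actual length of a row.
        mask &= indices < row_lens.to(device=device).unsqueeze(1)

    return mask
```

Passing `mask_factory=fixed_stride_row_mask` to `StandardWav2Vec2Masker` would then make temporal and spatial masking reproducible across runs.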