Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FactoredTriphoneBlockV1 #58

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion i6_models/parts/factored_hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
__all__ = ["FactoredDiphoneBlockV1Config", "FactoredDiphoneBlockV1", "BoundaryClassV1"]
__all__ = [
"FactoredDiphoneBlockV1Config",
"FactoredDiphoneBlockV1",
"FactoredDiphoneBlockV2Config",
"FactoredDiphoneBlockV2",
"FactoredTriphoneBlockV1Config",
"FactoredTriphoneBlockV1",
"BoundaryClassV1",
]

from .diphone import *
from .triphone import *
from .util import BoundaryClassV1
71 changes: 71 additions & 0 deletions i6_models/parts/factored_hybrid/diphone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__all__ = [
"FactoredDiphoneBlockV1Config",
"FactoredDiphoneBlockV1",
"FactoredDiphoneBlockV2Config",
"FactoredDiphoneBlockV2",
]

from dataclasses import dataclass
Expand Down Expand Up @@ -108,6 +110,7 @@ def forward_factored(
"""
:param features: Main encoder output. shape B, T, F. F=num_inputs
:param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
Valid values range from [0, num_contexts).
:return: tuple of logits for p(c|l,x), p(l|x) and the embedded left context values.
"""

Expand Down Expand Up @@ -148,3 +151,71 @@ def forward_joint(self, features: Tensor) -> Tensor:
) # B, T, F'*C

return joint_log_probs


@dataclass
class FactoredDiphoneBlockV2Config(FactoredDiphoneBlockV1Config):
    """
    Configuration for :class:`FactoredDiphoneBlockV2`.

    Attributes:
        Same attributes as parent class. In addition:

        center_state_embedding_dim: embedding dimension of the center state
            values. Good choice is in the order of num_center_states.
    """

    center_state_embedding_dim: int

    def __post_init__(self):
        # run the parent sanity checks first, then validate the new field
        super().__post_init__()
        assert self.center_state_embedding_dim > 0


class FactoredDiphoneBlockV2(FactoredDiphoneBlockV1):
    """
    Like FactoredDiphoneBlockV1, but computes an additional diphone output on the right context `p(r|c,x)`.

    This additional output is ignored when computing the joint, and only used in training.
    """

    def __init__(self, cfg: FactoredDiphoneBlockV2Config):
        super().__init__(cfg)

        # embeds the sparse center-state labels so they can be mixed with the encoder output
        self.center_state_embedding = nn.Embedding(self.num_center, cfg.center_state_embedding_dim)
        # MLP producing the right-context logits from [encoder output; center-state embedding]
        self.right_context_encoder = get_mlp(
            num_input=cfg.num_inputs + cfg.center_state_embedding_dim,
            num_output=self.num_contexts,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        # override exists only to narrow the return-type annotation; computation is inherited
        return super().forward(*args, **kwargs)

    def forward_factored(
        self,
        features: Tensor,  # B, T, F
        contexts_left: Tensor,  # B, T
        contexts_center: Tensor,  # B, T
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        :param features: Main encoder output. shape B, T, F. F=num_inputs
        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
            Valid values range from [0, num_contexts).
        :param contexts_center: The center states used to compute p(r|c,x), shape B, T.
            Given that the center state also contains the word-end class and HMM state ID, the valid values
            range from [0, num_center_states), where num_center_states >= num_contexts.
        :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,x) and the embedded left context and center state values.
        """

        logits_center, logits_left, contexts_left_embedded = super().forward_factored(features, contexts_left)

        # in training we forward exactly one context per T, so: B, T, E
        embedded_center_states = self.center_state_embedding(contexts_center)
        right_encoder_input = torch.cat([features, embedded_center_states], dim=-1)  # B, T, F+E
        logits_right = self.right_context_encoder(right_encoder_input)

        return logits_center, logits_left, logits_right, contexts_left_embedded, embedded_center_states
69 changes: 69 additions & 0 deletions i6_models/parts/factored_hybrid/triphone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
__all__ = [
"FactoredTriphoneBlockV1Config",
"FactoredTriphoneBlockV1",
]

from dataclasses import dataclass
from typing import Tuple

import torch
from torch import nn, Tensor

from .diphone import FactoredDiphoneBlockV1, FactoredDiphoneBlockV2Config
from .util import get_mlp


@dataclass
class FactoredTriphoneBlockV1Config(FactoredDiphoneBlockV2Config):
    """
    Configuration for :class:`FactoredTriphoneBlockV1`.

    Attributes:
        Same as the FactoredDiphoneBlockV2Config.
    """
michelwi marked this conversation as resolved.
Show resolved Hide resolved


class FactoredTriphoneBlockV1(FactoredDiphoneBlockV1):
    """
    Triphone FH model output block.

    Consumes the output h(x) of a main encoder model and computes factored logits/probabilities
    for p(c|l,h(x)), p(l|h(x)) and p(r|c,l,h(x)).
    """

    def __init__(self, cfg: FactoredTriphoneBlockV1Config):
        super().__init__(cfg)

        # embeds the sparse center-state labels for the right-context computation
        self.center_state_embedding = nn.Embedding(self.num_center, cfg.center_state_embedding_dim)
        # unlike the diphone V2 block, the right context is conditioned on the left context as
        # well, so the MLP consumes [encoder output; center-state embedding; left-context embedding]
        self.right_context_encoder = get_mlp(
            num_input=cfg.num_inputs + cfg.center_state_embedding_dim + cfg.left_context_embedding_dim,
            num_output=self.num_contexts,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )

    def forward(
        self,
        features: Tensor,  # B, T, F
        contexts_left: Tensor,  # B, T
        contexts_center: Tensor,  # B, T
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        :param features: Main encoder output. shape B, T, F. F=num_inputs
        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
            Valid values range from [0, num_contexts).
        :param contexts_center: The center states used to compute p(r|c,l,x), shape B, T.
            Given that the center state also contains the word-end class and HMM state ID, the valid values
            range from [0, num_center_states), where num_center_states >= num_contexts.
        :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,l,x) and the embedded left context and center state values.
        """

        logits_center, logits_left, contexts_left_embedded = super().forward(features, contexts_left)

        # This logic is very similar to FactoredDiphoneBlockV2.forward, but not the same.
        # This class computes `p(r|c,l,h(x))` while FactoredDiphoneBlockV2 computes `p(r|c,h(x))`.
        embedded_center_states = self.center_state_embedding(contexts_center)  # B, T, E'
        right_encoder_input = torch.cat([features, embedded_center_states, contexts_left_embedded], dim=-1)  # B, T, F+E'+E
        logits_right = self.right_context_encoder(right_encoder_input)  # B, T, C

        return logits_center, logits_left, logits_right, contexts_left_embedded, embedded_center_states
michelwi marked this conversation as resolved.
Show resolved Hide resolved
86 changes: 84 additions & 2 deletions tests/test_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
import torch
import torch.nn as nn

from i6_models.parts.factored_hybrid import BoundaryClassV1, FactoredDiphoneBlockV1, FactoredDiphoneBlockV1Config
from i6_models.parts.factored_hybrid import (
BoundaryClassV1,
FactoredDiphoneBlockV1,
FactoredDiphoneBlockV1Config,
FactoredDiphoneBlockV2,
FactoredDiphoneBlockV2Config,
FactoredTriphoneBlockV1,
FactoredTriphoneBlockV1Config,
)
from i6_models.parts.factored_hybrid.util import get_center_dim


Expand All @@ -16,7 +24,7 @@ def test_dim_calcs():
assert get_center_dim(n_ctx, 3, BoundaryClassV1.boundary) == 504


def test_output_shape_and_norm():
def test_v1_output_shape_and_norm():
n_ctx = 42
n_in = 32

Expand Down Expand Up @@ -54,3 +62,77 @@ def test_output_shape_and_norm():
ones_hopefully = torch.sum(output_p, dim=-1)
close_to_one = torch.abs(1 - ones_hopefully).flatten() < 1e-3
assert all(close_to_one)


def test_v2_output_shape_and_norm():
    """Check that FactoredDiphoneBlockV2 produces outputs of the expected shapes."""
    num_contexts = 42
    num_inputs = 32

    boundary_classes = [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary]
    for boundary_class, hmm_states in product(boundary_classes, [1, 3]):
        cfg = FactoredDiphoneBlockV2Config(
            activation=nn.ReLU,
            context_mix_mlp_dim=64,
            context_mix_mlp_num_layers=2,
            dropout=0.1,
            left_context_embedding_dim=32,
            center_state_embedding_dim=128,
            num_contexts=num_contexts,
            num_hmm_states_per_phone=hmm_states,
            num_inputs=num_inputs,
            boundary_class=boundary_class,
        )
        block = FactoredDiphoneBlockV2(cfg)

        for b, t in product([10, 50, 100], repeat=2):
            contexts_left = torch.randint(0, num_contexts, (b, t))
            contexts_center = torch.randint(0, block.num_center, (b, t))
            encoder_output = torch.rand((b, t, num_inputs))
            output_center, output_left, output_right, _, _ = block(
                features=encoder_output, contexts_left=contexts_left, contexts_center=contexts_center
            )
            assert output_left.shape == (b, t, num_contexts)
            assert output_right.shape == (b, t, num_contexts)
            center_dim = get_center_dim(num_contexts, hmm_states, boundary_class)
            assert output_center.shape == (b, t, center_dim)


def test_tri_output_shape_and_norm():
    """
    Check that FactoredTriphoneBlockV1 produces outputs of the expected shapes
    for all boundary-class / HMM-state configurations and several batch geometries.
    """
    n_ctx = 42
    n_in = 32

    for we_class, states_per_ph in product(
        [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
        [1, 3],
    ):
        tri_block = FactoredTriphoneBlockV1(
            FactoredTriphoneBlockV1Config(
                activation=nn.ReLU,
                context_mix_mlp_dim=64,
                context_mix_mlp_num_layers=2,
                dropout=0.1,
                left_context_embedding_dim=32,
                center_state_embedding_dim=128,
                num_contexts=n_ctx,
                num_hmm_states_per_phone=states_per_ph,
                num_inputs=n_in,
                boundary_class=we_class,
            )
        )

        # the center dim depends only on the block configuration, not on the batch
        # geometry, so compute it once per configuration instead of per (b, t) pair
        cdim = get_center_dim(n_ctx, states_per_ph, we_class)

        for b, t in product([10, 50, 100], [10, 50, 100]):
            contexts_left = torch.randint(0, n_ctx, (b, t))
            contexts_center = torch.randint(0, tri_block.num_center, (b, t))
            encoder_output = torch.rand((b, t, n_in))
            output_center, output_left, output_right, _, _ = tri_block(
                features=encoder_output,
                contexts_left=contexts_left,
                contexts_center=contexts_center,
            )
            assert output_center.shape == (b, t, cdim)
            assert output_left.shape == (b, t, n_ctx)
            assert output_right.shape == (b, t, n_ctx)
Loading