Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add FactoredTriphoneBlockV1 #58

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions i6_models/parts/factored_hybrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
"FactoredDiphoneBlockV1",
"FactoredDiphoneBlockV2Config",
"FactoredDiphoneBlockV2",
"FactoredTriphoneBlockV1Config",
"FactoredTriphoneBlockV1",
"BoundaryClassV1",
]

from .diphone import *
from .triphone import *
from .util import BoundaryClassV1
6 changes: 6 additions & 0 deletions i6_models/parts/factored_hybrid/diphone.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ def forward_factored(
return logits_center, logits_left, contexts_embedded_left

def forward_joint(self, features: Tensor) -> Tensor:
    """
    Alias for `forward_joint_diphone`; see there for details.

    :param features: Main encoder output. shape B, T, F. F=num_inputs
    :return: log probabilities for p(c,l|x).
    """
    return self.forward_joint_diphone(features)

def forward_joint_diphone(self, features: Tensor) -> Tensor:
"""
Computes log p(c,l|h(x)), i.e. forwards the network for the full diphone joint.

:param features: Main encoder output. shape B, T, F. F=num_inputs
:return: log probabilities for p(c,l|x).
"""
Expand Down
79 changes: 79 additions & 0 deletions i6_models/parts/factored_hybrid/triphone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Factored-hybrid triphone output block, built on top of the factored diphone parts.
"""

__all__ = [
    "FactoredTriphoneBlockV1Config",
    "FactoredTriphoneBlockV1",
]

from dataclasses import dataclass
from typing import Tuple

import torch
from torch import nn, Tensor

from .diphone import FactoredDiphoneBlockV1, FactoredDiphoneBlockV2Config
from .util import get_mlp


@dataclass
class FactoredTriphoneBlockV1Config(FactoredDiphoneBlockV2Config):
    """
    Configuration for `FactoredTriphoneBlockV1`.

    Attributes:
        Identical to those of FactoredDiphoneBlockV2Config; this class adds no fields.
    """


class FactoredTriphoneBlockV1(FactoredDiphoneBlockV1):
    """
    Triphone FH model output block.

    Consumes the output h(x) of a main encoder model and computes factored logits/probabilities
    for p(c|l,h(x)), p(l|h(x)) and p(r|c,l,h(x)).
    """

    def __init__(self, cfg: FactoredTriphoneBlockV1Config):
        super().__init__(cfg)

        # Embedding for the (sparse) center-state IDs, concatenated onto the encoder output below.
        self.center_state_embedding = nn.Embedding(self.num_center, cfg.center_state_embedding_dim)
        # MLP producing right-context logits from [h(x); emb(center state); emb(left context)].
        self.right_context_encoder = get_mlp(
            num_input=cfg.num_inputs + cfg.center_state_embedding_dim + cfg.left_context_embedding_dim,
            num_output=self.num_contexts,
            hidden_dim=cfg.context_mix_mlp_dim,
            num_layers=cfg.context_mix_mlp_num_layers,
            dropout=cfg.dropout,
            activation=cfg.activation,
        )

    # Overridden only to narrow the return-type annotation to the five-tuple below.
    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        return super().forward(*args, **kwargs)

    def forward_factored(
        self,
        features: Tensor,  # B, T, F
        contexts_left: Tensor,  # B, T
        contexts_center: Tensor,  # B, T
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        :param features: Main encoder output. shape B, T, F. F=num_inputs
        :param contexts_left: The left contexts used to compute p(c|l,x), shape B, T.
            Valid values range from [0, num_contexts).
        :param contexts_center: The center states used to compute p(r|c,l,x), shape B, T.
            Given that the center state also contains the word-end class and HMM state ID, the valid values
            range from [0, num_center_states), where num_center_states >= num_contexts.
        :return: tuple of logits for p(c|l,x), p(l|x), p(r|c,l,x) and the embedded left context and center state values.
        """

        logits_center, logits_left, embedded_left = super().forward_factored(features, contexts_left)

        # Similar to FactoredDiphoneBlockV2.forward, but not identical: that class
        # computes `p(r|c,h(x))`, while here the right context is additionally
        # conditioned on the left context, i.e. `p(r|c,l,h(x))`.
        embedded_center = self.center_state_embedding(contexts_center)  # B, T, E'
        right_encoder_input = torch.cat((features, embedded_center, embedded_left), dim=-1)  # B, T, F+E'+E
        logits_right = self.right_context_encoder(right_encoder_input)  # B, T, C

        return logits_center, logits_left, logits_right, embedded_left, embedded_center

    def forward_joint(self, features: Tensor) -> Tensor:
        raise NotImplementedError(
            "It is computationally infeasible to forward the full triphone joint, "
            "only the diphone joint can be computed via forward_joint_diphone."
        )
54 changes: 54 additions & 0 deletions tests/test_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
FactoredDiphoneBlockV1Config,
FactoredDiphoneBlockV2,
FactoredDiphoneBlockV2Config,
FactoredTriphoneBlockV1,
FactoredTriphoneBlockV1Config,
)
from i6_models.parts.factored_hybrid.util import get_center_dim

Expand Down Expand Up @@ -96,3 +98,55 @@ def test_v2_output_shape_and_norm():
assert output_right.shape == (b, t, n_ctx)
cdim = get_center_dim(n_ctx, states_per_ph, we_class)
assert output_center.shape == (b, t, cdim)


def test_tri_output_shape_and_norm():
    """
    Checks, over all boundary-class/HMM-state configurations and several batch/time sizes, that
    FactoredTriphoneBlockV1 produces correctly shaped factored outputs, that forwarding the full
    triphone joint raises, and that the diphone joint output is a normalized log-prob distribution.
    """
    n_ctx = 42
    n_in = 32

    for we_class, states_per_ph in product(
        [BoundaryClassV1.none, BoundaryClassV1.word_end, BoundaryClassV1.boundary],
        [1, 3],
    ):
        tri_block = FactoredTriphoneBlockV1(
            FactoredTriphoneBlockV1Config(
                activation=nn.ReLU,
                context_mix_mlp_dim=64,
                context_mix_mlp_num_layers=2,
                dropout=0.1,
                left_context_embedding_dim=32,
                center_state_embedding_dim=128,
                num_contexts=n_ctx,
                num_hmm_states_per_phone=states_per_ph,
                num_inputs=n_in,
                boundary_class=we_class,
            )
        )
        # The center dim depends only on the block config, so hoist it out of the b/t loop.
        cdim = get_center_dim(n_ctx, states_per_ph, we_class)

        for b, t in product([10, 50, 100], [10, 50, 100]):
            contexts_left = torch.randint(0, n_ctx, (b, t))
            contexts_center = torch.randint(0, tri_block.num_center, (b, t))
            encoder_output = torch.rand((b, t, n_in))
            output_center, output_left, output_right, _, _ = tri_block(
                features=encoder_output, contexts_left=contexts_left, contexts_center=contexts_center
            )
            assert output_left.shape == (b, t, n_ctx)
            assert output_center.shape == (b, t, cdim)
            assert output_right.shape == (b, t, n_ctx)

            # The full triphone joint is deliberately unsupported.
            try:
                tri_block.forward_joint(encoder_output)
            except NotImplementedError:
                pass
            else:
                assert False, "expected NotImplementedError from forward_joint, did not get any"

            output = tri_block.forward_joint_diphone(features=encoder_output)
            assert output.shape == (b, t, cdim * n_ctx)
            # Log-probabilities over the joint (center, left) label set must sum to one.
            totals = torch.exp(output).sum(dim=-1)
            assert torch.allclose(totals, torch.ones_like(totals), atol=1e-3)
Loading