Add updated version of fast Baum-Welch algorithm #10

Status: Open · wants to merge 2 commits into base: main
11 changes: 2 additions & 9 deletions i6_native_ops/common/returnn_definitions.h
@@ -16,10 +16,10 @@
#define Ndarray torch::Tensor
#define Ndarray_DEV_DATA(x) ((float*)(x).data_ptr())
#define Ndarray_DEV_DATA_int32(x) ((int32_t*)(x).data_ptr())
#define Ndarray_DEV_DATA_uint32(x) ((uint32_t*)(x).data_ptr())
#define Ndarray_DEV_DATA_int32_scalar(x) (x).scalar<int32>()()
#define Ndarray_HOST_DIMS(x) ((x).sizes())
#define Ndarray_DIMS(x) ((x).sizes())
#define Ndarray_DIMS Ndarray_HOST_DIMS
#define Ndarray_NDIM(x) (x).ndimension()
#define Ndarray_dtype_size(x) torch::elementSize((x).scalar_type())
typedef long long Ndarray_DIM_Type;
@@ -60,19 +60,12 @@ typedef long long Ndarray_DIM_Type;
static const char* _cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";

case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";

case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";

case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";

case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";

case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";

case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";

case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
}

@@ -189,4 +182,4 @@ static inline void device_free(void* ptr) {
cudaFree(ptr);
}

#endif // _COMMON_RETURNN_DEFINITIONS_H
#endif // _COMMON_RETURNN_DEFINITIONS_H
42 changes: 24 additions & 18 deletions i6_native_ops/fbw/__init__.py
@@ -5,51 +5,57 @@

try:
# Package is installed, so ops are already compiled
__version__ = get_distribution('i6_native_ops').version
from . import fbw_core as core
except Exception as e:
__version__ = get_distribution("i6_native_ops").version
from .fbw_core import DebugOptions, fbw
except Exception:
# otherwise try to build locally
from torch.utils.cpp_extension import load

base_path = os.path.dirname(__file__)
core = load(
name="fbw_core",
sources=[
os.path.join(base_path, "fbw_torch.cpp"),
os.path.join(base_path, "fbw_op.cu"),
],
extra_include_paths=[
base_path,
os.path.join(base_path, "..", "common")
],
extra_include_paths=[base_path, os.path.join(base_path, "..", "common")],
)
DebugOptions = core.DebugOptions
fbw = core.fbw


class FastBaumWelchLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, am_scores, fsa, seq_lens):
def forward(ctx, am_scores, fsa, seq_lens, debug_options=None):
num_states, edge_tensor, weight_tensor, start_end_states = fsa

grad, loss = core.fbw(
am_scores, edge_tensor, weight_tensor,
start_end_states, seq_lens, int(num_states),
core.DebugOptions()
debug_options = debug_options or DebugOptions()

grad, loss = fbw(
am_scores,
edge_tensor,
weight_tensor,
start_end_states,
seq_lens,
int(num_states),
debug_options,
)
ctx.save_for_backward(grad)
return loss

@staticmethod
def backward(ctx, grad_loss):
# negative log prob -> prob
grad = ctx.saved_tensors[0].neg().exp()
return grad, None, None
return grad, None, None, None


def fbw_loss(
log_probs: torch.FloatTensor,
fsa: Tuple[int, torch.IntTensor, torch.FloatTensor, torch.IntTensor],
seq_lens: torch.IntTensor
seq_lens: torch.IntTensor,
) -> torch.FloatTensor:
"""
"""
Computes negative log likelihood of an emission model given an HMM finite state automaton.
The corresponding gradient with respect to the emission model is automatically backpropagated.
:param log_probs: log probabilities of emission model as a [B, T, F] tensor
@@ -63,6 +69,6 @@ def fbw_loss(
:param seq_lens: (B,) tensor consisting of the sequence lengths
:return: (B,) tensor of loss values
"""
neg_log_probs = log_probs.neg().transpose(0, 1).contiguous() # [T, B, F]
neg_log_probs = log_probs.neg().transpose(0, 1).contiguous() # [T, B, F]
loss = FastBaumWelchLoss.apply(neg_log_probs, fsa, seq_lens)
return loss
return loss
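
For orientation, a minimal calling sketch of the refactored `fbw_loss` follows. The toy linear automaton, the shapes, and the emission indices are illustrative assumptions, not part of this PR, and a CUDA device is required:

```python
import torch
from i6_native_ops.fbw import fbw_loss

# Illustrative shapes: one sequence, 50 frames, 80 emission classes.
B, T, F = 1, 50, 80
am = torch.randn(B, T, F, device="cuda", requires_grad=True)
log_probs = am.log_softmax(dim=-1)                              # [B, T, F]
seq_lens = torch.tensor([T], dtype=torch.int32, device="cuda")  # (B,)

# Toy automaton 0 -> 1 -> 2 with a self-loop on the final state 2.
# Edge columns: origin state, target state, emission idx, sequence idx.
num_states = 3
edges = torch.tensor(
    [[0, 1, 2],
     [1, 2, 2],
     [5, 7, 7],
     [0, 0, 0]],
    dtype=torch.int32, device="cuda",
)
weights = torch.zeros(edges.shape[1], dtype=torch.float32, device="cuda")  # log-space edge weights
start_end_states = torch.tensor([[0], [2]], dtype=torch.int32, device="cuda")  # (2, B)

loss = fbw_loss(log_probs, (num_states, edges, weights, start_end_states), seq_lens)  # (B,)
loss.sum().backward()  # soft-alignment gradient ends up in am.grad
```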
3 changes: 1 addition & 2 deletions i6_native_ops/fbw/fbw_op.cu
@@ -282,8 +282,7 @@ std::vector<torch::Tensor> fbw_cuda(torch::Tensor& am_scores, torch::Tensor& edg
// initialize buffers
float* d_state_buffer_prev = reinterpret_cast<float*>(device_malloc(n_states * sizeof(float)));
float* d_state_buffer_next = reinterpret_cast<float*>(device_malloc(n_states * sizeof(float)));
float* d_edge_buffer =
reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));
float* d_edge_buffer = reinterpret_cast<float*>(device_malloc(n_edges * n_frames * sizeof(float)));
if (!d_edge_buffer || !d_state_buffer_prev || !d_state_buffer_next) {
HANDLE_LAST_ERROR();
abort();
13 changes: 13 additions & 0 deletions i6_native_ops/fbw2/DebugOptions.h
@@ -0,0 +1,13 @@
#ifndef _DEBUG_OPTIONS_H
#define _DEBUG_OPTIONS_H

typedef struct {
bool dump_edges = false;
bool dump_alignment = false;
bool dump_output = false;
unsigned dump_every = 40u;
float pruning = 20.f;
bool explicit_merge = false;
} DebugOptionsV2;

#endif // _DEBUG_OPTIONS_H
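
For context, these options would be driven from Python roughly as below, assuming the binding in fbw2_torch.cpp (not shown in this diff) exposes the struct fields as read/write attributes:

```python
from i6_native_ops.fbw2 import DebugOptionsV2

opts = DebugOptionsV2()      # defaults per the header: no dumping, dump_every=40, pruning=20.0
opts.dump_alignment = True   # request dumping of the computed alignment
opts.dump_every = 20         # assumed to control how often dumps are written
opts.pruning = 10.0          # tighter pruning threshold
```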
84 changes: 84 additions & 0 deletions i6_native_ops/fbw2/__init__.py
@@ -0,0 +1,84 @@
__all__ = ["DebugOptionsV2", "FastBaumWelch2Loss", "fbw2_loss"]

import os
import torch # needed to find pytorch specific libs
from pkg_resources import get_distribution
from typing import Tuple

try:
# Package is installed, so ops are already compiled
__version__ = get_distribution("i6_native_ops").version
from .fbw2_core import DebugOptionsV2, fbw2
except Exception:
# otherwise try to build locally
from torch.utils.cpp_extension import load

base_path = os.path.dirname(__file__)
core = load(
name="fbw2_core",
sources=[
os.path.join(base_path, "fbw2_torch.cpp"),
os.path.join(base_path, "fbw2_op.cu"),
],
extra_include_paths=[base_path, os.path.join(base_path, "..", "common")],
)
DebugOptionsV2 = core.DebugOptionsV2
fbw2 = core.fbw2


class FastBaumWelch2Loss(torch.autograd.Function):
@staticmethod
def forward(ctx, am_scores, fsa, seq_lens, debug_opts=None):
num_states, num_edges, edge_tensor, weight_tensor, start_end_states = fsa

if debug_opts is None:
debug_opts = DebugOptionsV2()
grad, loss = fbw2(
num_states,
num_edges,
seq_lens,
am_scores,
edge_tensor,
weight_tensor,
start_end_states,
debug_opts,
)
ctx.save_for_backward(grad)
return loss

@staticmethod
def backward(ctx, grad_loss):
# negative log prob -> prob
grad = ctx.saved_tensors[0].neg().exp()
return grad, None, None, None


def fbw2_loss(
log_probs: torch.FloatTensor,
fsa: Tuple[
torch.IntTensor,
torch.IntTensor,
torch.IntTensor,
torch.FloatTensor,
torch.IntTensor,
],
seq_lens: torch.IntTensor,
) -> torch.FloatTensor:
"""
Computes negative log likelihood of an emission model given an HMM finite state automaton.
The corresponding gradient with respect to the emission model is automatically backpropagated.
:param log_probs: log probabilities of emission model as a [B, T, F] tensor
:param fsa: weighted finite state automaton as a tuple consisting of:
* a (B,) tensor with the number of states per automaton
* a (B,) tensor with the number of edges per automaton
* a (4, E) tensor of integers where each column consists of
origin state, target state, emission idx and the index of the sequence
* a (E,) tensor of floats holding the weight of each edge
* a (2, B) tensor of start and end states for each automaton in the batch, where
the first row contains the starting states and the second the corresponding ending states
:param seq_lens: (B,) tensor consisting of the sequence lengths
:return: (B,) tensor of loss values
"""
neg_log_probs = log_probs.neg().transpose(0, 1).contiguous() # [T, B, F]
loss = FastBaumWelch2Loss.apply(neg_log_probs, fsa, seq_lens)
return loss
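
Compared to the first version, the V2 interface takes per-sequence state and edge counts as tensors rather than a single global state count. A hypothetical calling sketch under the same toy-automaton assumptions as above (the device placement of the count tensors is also an assumption):

```python
import torch
from i6_native_ops.fbw2 import fbw2_loss

B, T, F = 1, 50, 80
am = torch.randn(B, T, F, device="cuda", requires_grad=True)
log_probs = am.log_softmax(dim=-1)
seq_lens = torch.tensor([T], dtype=torch.int32, device="cuda")

num_states = torch.tensor([3], dtype=torch.int32, device="cuda")  # (B,) states per automaton
num_edges = torch.tensor([3], dtype=torch.int32, device="cuda")   # (B,) edges per automaton
edges = torch.tensor(
    [[0, 1, 2],   # origin states
     [1, 2, 2],   # target states
     [5, 7, 7],   # emission indices
     [0, 0, 0]],  # sequence index within the batch
    dtype=torch.int32, device="cuda",
)
weights = torch.zeros(edges.shape[1], dtype=torch.float32, device="cuda")
start_end_states = torch.tensor([[0], [2]], dtype=torch.int32, device="cuda")  # (2, B)

fsa = (num_states, num_edges, edges, weights, start_end_states)
loss = fbw2_loss(log_probs, fsa, seq_lens)  # (B,)
loss.sum().backward()
```

`fbw2_loss` itself always uses default debug options; to pass a configured `DebugOptionsV2` (see the sketch after DebugOptions.h above), call `FastBaumWelch2Loss.apply(neg_log_probs, fsa, seq_lens, opts)` directly.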