The monotonic RNN-T loss for an input sequence $x$ of length $T$ and a target sequence $y = y_1 \dots y_S$ can be written as

$$\mathcal{L} = -\log P(y \mid x) = -\log \sum_{a \,:\, \mathcal{B}(a) = y} \prod_{t=1}^{T} P(a_t \mid t, s_{t-1})$$

where the sum runs over all alignments $a = a_1 \dots a_T$ that emit exactly one symbol per frame (a label or the blank $\varnothing$), $\mathcal{B}$ removes all blanks, and

$$s_t = \lvert \mathcal{B}(a_1 \dots a_t) \rvert$$

is the number of labels emitted up to and including frame $t$ (so $s_0 = 0$).
The loss and gradients can be computed using the forward-backward algorithm. For this, a forward variable

$$\alpha(t, s) = P\bigl(y_1 \dots y_s \text{ emitted in frames } 1, \dots, t\bigr)$$

and a backward variable

$$\beta(t, s) = P\bigl(y_{s+1} \dots y_S \text{ emitted in frames } t, \dots, T\bigr)$$

are introduced. These have the property

$$\sum_{s=0}^{S} \alpha(t, s)\,\beta(t+1, s) = P(y \mid x) \quad \text{for every } t, \qquad \text{in particular} \quad \alpha(T, S) = \beta(1, 0) = P(y \mid x),$$

and adhere to the recursive equations

$$\alpha(t, s) = P(\varnothing \mid t, s)\,\alpha(t-1, s) + P(y_s \mid t, s-1)\,\alpha(t-1, s-1), \qquad \alpha(0, 0) = 1,$$

and

$$\beta(t, s) = P(\varnothing \mid t, s)\,\beta(t+1, s) + P(y_{s+1} \mid t, s)\,\beta(t+1, s+1), \qquad \beta(T+1, S) = 1$$

(excluding edge cases at $s = 0$ and $s = S$, where the respective out-of-range term is dropped).
For the gradients it is straightforward to prove that for any $(t, s)$

$$\frac{\partial P(y \mid x)}{\partial P(\varnothing \mid t, s)} = \alpha(t-1, s)\,\beta(t+1, s), \qquad \frac{\partial P(y \mid x)}{\partial P(y_{s+1} \mid t, s)} = \alpha(t-1, s)\,\beta(t+1, s+1),$$

while the derivative with respect to all other posteriors is zero. And thus

$$\frac{\partial \mathcal{L}}{\partial P(k \mid t, s)} = -\frac{1}{P(y \mid x)}\,\frac{\partial P(y \mid x)}{\partial P(k \mid t, s)},$$

which means for the overall gradient

$$\frac{\partial \mathcal{L}}{\partial P(\varnothing \mid t, s)} = -\frac{\alpha(t-1, s)\,\beta(t+1, s)}{P(y \mid x)}, \qquad \frac{\partial \mathcal{L}}{\partial P(y_{s+1} \mid t, s)} = -\frac{\alpha(t-1, s)\,\beta(t+1, s+1)}{P(y \mid x)}.$$

For expressing the derivative directly with respect to the logits $z_{t,s,k}$ (the softmax inputs), the chain rule through the softmax yields

$$\frac{\partial \mathcal{L}}{\partial z_{t,s,k}} = P(k \mid t, s)\,\bigl(\gamma_{\varnothing}(t, s) + \gamma_{y_{s+1}}(t, s)\bigr) - \gamma_k(t, s)$$

with the occupancies

$$\gamma_{\varnothing}(t, s) = \frac{\alpha(t-1, s)\,P(\varnothing \mid t, s)\,\beta(t+1, s)}{P(y \mid x)}, \qquad \gamma_{y_{s+1}}(t, s) = \frac{\alpha(t-1, s)\,P(y_{s+1} \mid t, s)\,\beta(t+1, s+1)}{P(y \mid x)},$$

and $\gamma_k(t, s) = 0$ for all other labels $k$.
Assume the following model posteriors $P(k \mid t, s)$ for $T = 4$, $S = 2$, a vocabulary consisting of blank ($\varnothing$), label 1 and label 2, and the target sequence $y = 1\,2$ (columns ordered $\varnothing$, 1, 2):
```
// t = 1
0.6, 0.3, 0.1, // s = 0
0.7, 0.1, 0.2, // s = 1
0.5, 0.1, 0.4, // s = 2
// t = 2
0.5, 0.4, 0.1, // s = 0
0.5, 0.1, 0.4, // s = 1
0.8, 0.1, 0.1, // s = 2
// t = 3
0.4, 0.3, 0.3, // s = 0
0.5, 0.1, 0.4, // s = 1
0.7, 0.2, 0.1, // s = 2
// t = 4
0.8, 0.1, 0.1, // s = 0
0.3, 0.1, 0.6, // s = 1
0.8, 0.1, 0.1  // s = 2
```
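To make the following computations easy to reproduce, here is a minimal NumPy sketch of this setup; the array layout (`posteriors[t, s, k]` with zero-based `t` and `k = 0` for blank) is a choice of this sketch, not something prescribed above.

```python
import numpy as np

# Posteriors P(k | t, s) from the table above.
# Index [t, s, k]: t = frame (zero-based), s = labels emitted so far,
# k = 0 (blank), 1 (label "1"), 2 (label "2").
posteriors = np.array([
    [[0.6, 0.3, 0.1], [0.7, 0.1, 0.2], [0.5, 0.1, 0.4]],  # t = 1
    [[0.5, 0.4, 0.1], [0.5, 0.1, 0.4], [0.8, 0.1, 0.1]],  # t = 2
    [[0.4, 0.3, 0.3], [0.5, 0.1, 0.4], [0.7, 0.2, 0.1]],  # t = 3
    [[0.8, 0.1, 0.1], [0.3, 0.1, 0.6], [0.8, 0.1, 0.1]],  # t = 4
])
targets = [1, 2]  # the target label sequence y_1 y_2
T, S = 4, 2
```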
For this target there are six possible alignment paths (writing $\varnothing$ as `.`):
- . . 1 2
- . 1 . 2
- . 1 2 .
- 1 . . 2
- 1 . 2 .
- 1 2 . .
These six paths have the following probabilities:
- 0.6 * 0.5 * 0.3 * 0.6 = 0.0540
- 0.6 * 0.4 * 0.5 * 0.6 = 0.0720
- 0.6 * 0.4 * 0.4 * 0.8 = 0.0768
- 0.3 * 0.5 * 0.5 * 0.6 = 0.0450
- 0.3 * 0.5 * 0.4 * 0.8 = 0.0480
- 0.3 * 0.4 * 0.7 * 0.8 = 0.0672
which sum to a total of 0.363, i.e. $\log 0.363 \approx -1.0134$ in log space.
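The total can be checked by brute force, continuing the sketch above: enumerate all $3^4$ frame-wise emission sequences and keep those that collapse to the target after removing blanks.

```python
from itertools import product

# Sum the probabilities of all alignments that reduce to [1, 2]
# after removing blanks (k = 0); these are exactly the 6 paths above.
total = 0.0
for path in product(range(3), repeat=T):
    if [k for k in path if k != 0] != targets:
        continue
    s, prob = 0, 1.0
    for t, k in enumerate(path):
        prob *= posteriors[t, s, k]
        if k != 0:
            s += 1  # emitting a label advances the target position
    total += prob
print(total, np.log(total))  # 0.363, ~-1.0134
```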
The alphas then are as follows in probability and log space:
- a(0, 0) = 1.0 -> 0.0
- a(1, 0) = 0.6 -> -0.51
- a(1, 1) = 0.3 -> -1.20
- a(2, 0) = 0.5 * a(1, 0) = 0.3 -> -1.20
- a(2, 1) = 0.5 * a(1, 1) + 0.4 * a(1, 0) = 0.39 -> -0.94
- a(2, 2) = 0.4 * a(1, 1) = 0.12 -> -2.12
- a(3, 1) = 0.5 * a(2, 1) + 0.3 * a(2, 0) = 0.285 -> -1.26
- a(3, 2) = 0.7 * a(2, 2) + 0.4 * a(2, 1) = 0.24 -> -1.43
- a(4, 2) = 0.8 * a(3, 2) + 0.6 * a(3, 1) = 0.363 -> -1.01
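In code, the forward pass is a small dynamic program over the $(t, s)$ grid (continuing the sketch; `alpha` uses one-based $t$ so that `alpha[0, 0]` is the initialization $\alpha(0, 0) = 1$):

```python
# Forward variables: alpha[t, s] = P(y_1..y_s emitted in frames 1..t).
alpha = np.zeros((T + 1, S + 1))
alpha[0, 0] = 1.0
for t in range(1, T + 1):
    for s in range(S + 1):
        stay = posteriors[t - 1, s, 0] * alpha[t - 1, s]  # emit blank
        emit = 0.0
        if s > 0:  # emit label y_s, coming from state s - 1
            emit = posteriors[t - 1, s - 1, targets[s - 1]] * alpha[t - 1, s - 1]
        alpha[t, s] = stay + emit
print(alpha[T, S])  # 0.363
```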
And the betas are as follows in probability and log space:
- b(5, 2) = 1.0 -> 0.0
- b(4, 2) = 0.8 -> -0.22
- b(4, 1) = 0.6 -> -0.51
- b(3, 2) = 0.7 * b(4, 2) = 0.56 -> -0.58
- b(3, 1) = 0.5 * b(4, 1) + 0.4 * b(4, 2) = 0.62 -> -0.48
- b(3, 0) = 0.3 * b(4, 1) = 0.18 -> -1.71
- b(2, 1) = 0.5 * b(3, 1) + 0.4 * b(3, 2) = 0.534 -> -0.63
- b(2, 0) = 0.5 * b(3, 0) + 0.4 * b(3, 1) = 0.338 -> -1.08
- b(1, 0) = 0.6 * b(2, 0) + 0.3 * b(2, 1) = 0.363 -> -1.01
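The backward pass mirrors it, running over the frames in reverse (again continuing the sketch; `beta[T + 1, S] = 1` encodes the initialization $\beta(T+1, S) = 1$):

```python
# Backward variables: beta[t, s] = P(y_{s+1}..y_S emitted in frames t..T).
beta = np.zeros((T + 2, S + 1))
beta[T + 1, S] = 1.0
for t in range(T, 0, -1):
    for s in range(S + 1):
        stay = posteriors[t - 1, s, 0] * beta[t + 1, s]  # emit blank
        emit = 0.0
        if s < S:  # emit label y_{s+1}, moving to state s + 1
            emit = posteriors[t - 1, s, targets[s]] * beta[t + 1, s + 1]
        beta[t, s] = stay + emit
print(beta[1, 0])  # 0.363, equal to alpha[T, S]
```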
As we can see, $\alpha(4, 2) = \beta(1, 0) = 0.363$, matching the brute-force sum over all six paths.
Now, the gradients with respect to all the logits can be computed as (rounded to two decimals, columns again ordered $\varnothing$, 1, 2):
```
// t = 1
0.04, -0.14, 0.1,   // s = 0
0.0, 0.0, 0.0,      // s = 1
0.0, 0.0, 0.0,      // s = 2
// t = 2
0.13, -0.19, 0.06,  // s = 0
-0.04, 0.04, -0.01, // s = 1
0.0, 0.0, 0.0,      // s = 2
// t = 3
0.06, -0.1, 0.04,   // s = 0
0.01, 0.07, -0.08,  // s = 1
-0.06, 0.04, 0.02,  // s = 2
// t = 4
0.0, 0.0, 0.0,      // s = 0
0.14, 0.05, -0.19,  // s = 1
-0.11, 0.05, 0.05   // s = 2
```
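These numbers can be reproduced with the occupancy formula from above (final part of the sketch; note that states with $\alpha(t-1, s) = 0$ automatically get a zero gradient, which explains the all-zero rows):

```python
# Gradient of -log P(y|x) w.r.t. the logits z[t, s, k]:
# grad = P(k|t,s) * (gamma_blank + gamma_label) - gamma_k.
P = alpha[T, S]
grad = np.zeros_like(posteriors)
for t in range(1, T + 1):
    for s in range(S + 1):
        g_blank = alpha[t - 1, s] * posteriors[t - 1, s, 0] * beta[t + 1, s] / P
        g_label = 0.0
        if s < S:  # occupancy of emitting label y_{s+1} at (t, s)
            g_label = (alpha[t - 1, s] * posteriors[t - 1, s, targets[s]]
                       * beta[t + 1, s + 1] / P)
        grad[t - 1, s] = posteriors[t - 1, s] * (g_blank + g_label)
        grad[t - 1, s, 0] -= g_blank
        if s < S:
            grad[t - 1, s, targets[s]] -= g_label
print(np.round(grad, 2))  # reproduces the table above
```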