From 08627128f193a2230a105420fe19650464e6a051 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Tue, 1 Jun 2021 11:00:25 +0200 Subject: [PATCH] limit "t" and correct prev non blank for search --- common/models/transducer/transducer_fullsum.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/common/models/transducer/transducer_fullsum.py b/common/models/transducer/transducer_fullsum.py index 93dc815b..204817ce 100644 --- a/common/models/transducer/transducer_fullsum.py +++ b/common/models/transducer/transducer_fullsum.py @@ -229,11 +229,15 @@ def make(self, encoder: LayerRef): blank_idx = self.ctx.blank_idx rec_decoder = { - "am0": {"class": "gather_nd", "from": _base(encoder), "position": "prev:t"}, # [B,D] + "t_": {"class": "eval", "from": ["prev:t", "enc_seq_len"], "eval": 'tf.minimum(source(0), source(1)-1)'}, + "am0": {"class": "gather_nd", "from": _base(encoder), "position": "t_"}, # [B,D] "am": {"class": "copy", "from": "am0" if search else "data:source"}, + "prev_output_wo_b": { + "class": "masked_computation", "unit": {"class": "copy", "initial_output": 0}, + "from": "prev:output_", "mask": "prev:output_emit", "initial_output": 0}, "prev_out_non_blank": { - "class": "reinterpret_data", "from": "prev:output_", "set_sparse_dim": target.get_num_classes()}, + "class": "reinterpret_data", "from": "prev_output_wo_b", "set_sparse_dim": target.get_num_classes()}, "slow_rnn": self.slow_rnn.make( prev_sparse_label_nb="prev_out_non_blank", @@ -252,7 +256,7 @@ def make(self, encoder: LayerRef): "output": { "class": 'choice', - 'target': target.key, # note: wrong! but this is ignored both in full-sum training and in search + 'target': target.key if train else None, 'beam_size': beam_size, 'from': "output_log_prob_wb", "input_type": "log_prob", "initial_output": 0,