diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py
index d299734d..0a1c8d2c 100644
--- a/i6_models/parts/conformer/mhsa_rel_pos.py
+++ b/i6_models/parts/conformer/mhsa_rel_pos.py
@@ -240,6 +240,11 @@ def _sinusoidal_pe(pos_seq: torch.Tensor, embed_dim: int):
         inv_freq = 1 / (10000 ** (torch.arange(0.0, embed_dim, 2.0, device=pos_seq.device) / embed_dim))
 
         sinusoid_input = torch.outer(pos_seq, inv_freq)
-        pos_emb = torch.cat([sinusoid_input.sin(), sinusoid_input.cos()], dim=-1)  # [num. positions, embed_dim]
+
+        # [num. positions, embed_dim]; allocated on the device/dtype of the inputs, sin/cos get interleaved below
+        pos_emb = torch.zeros(pos_seq.shape[0], embed_dim, device=pos_seq.device, dtype=sinusoid_input.dtype)
+
+        pos_emb[:, 0::2] = sinusoid_input.sin()
+        pos_emb[:, 1::2] = sinusoid_input.cos()
 
         return pos_emb
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 60dd5826..b5584a47 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,2 +1,3 @@
 onnx
-onnxruntime
\ No newline at end of file
+onnxruntime
+espnet
diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py
index 88b2dba3..08d85d64 100644
--- a/tests/test_conformer_rel_pos.py
+++ b/tests/test_conformer_rel_pos.py
@@ -129,3 +129,83 @@ def get_output_shape(
             with_linear_pos=with_linear_pos,
             separate_pos_emb_per_head=separate_pos_emb_per_head,
         ) == [4, 15, 32]
+
+
+def test_ConformerMHSARelPosV1_against_Espnet():
+    """
+    Compare the output of ConformerMHSARelPosV1 with ESPnet's RelPositionMultiHeadedAttention
+    when both modules use the same parameters and are in eval mode.
+    """
+    from espnet2.asr_transducer.encoder.modules.attention import RelPositionMultiHeadedAttention
+    from espnet2.asr_transducer.encoder.modules.positional_encoding import RelPositionalEncoding
+
+    num_heads = 4
+    embed_size = 256
+    dropout_rate = 0.1
+    batch_dim_size = 4
+    time_dim_size = 50
+
+    espnet_mhsa_module = RelPositionMultiHeadedAttention(
+        num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate
+    )
+    espnet_mhsa_module.eval()
+    espnet_pos_enc_module = RelPositionalEncoding(embed_size, dropout_rate=dropout_rate)
+    espnet_pos_enc_module.eval()
+
+    cfg = ConformerMHSARelPosV1Config(
+        input_dim=embed_size,
+        num_att_heads=num_heads,
+        with_bias=True,
+        att_weights_dropout=dropout_rate,
+        dropout=dropout_rate,
+        learnable_pos_emb=False,
+        with_linear_pos=True,
+        separate_pos_emb_per_head=True,
+        rel_pos_clip=None,
+        with_pos_bias=True,
+        pos_emb_dropout=dropout_rate,
+        dropout_broadcast_axes=None,
+    )
+    own_mhsa_module = ConformerMHSARelPosV1(cfg)
+    own_mhsa_module.eval()
+    # reuse the ESPnet parameters; its separate q/k/v projections are concatenated into our fused qkv projection
+    own_mhsa_module.linear_pos = espnet_mhsa_module.linear_pos
+    own_mhsa_module.pos_bias_u = espnet_mhsa_module.pos_bias_u
+    own_mhsa_module.pos_bias_v = espnet_mhsa_module.pos_bias_v
+    own_mhsa_module.out_proj = espnet_mhsa_module.linear_out
+    own_mhsa_module.qkv_proj.weight = nn.Parameter(
+        torch.cat(
+            [
+                espnet_mhsa_module.linear_q.weight,
+                espnet_mhsa_module.linear_k.weight,
+                espnet_mhsa_module.linear_v.weight,
+            ],
+            dim=0,
+        )
+    )
+    own_mhsa_module.qkv_proj.bias = nn.Parameter(
+        torch.cat(
+            [espnet_mhsa_module.linear_q.bias, espnet_mhsa_module.linear_k.bias, espnet_mhsa_module.linear_v.bias],
+            dim=0,
+        )
+    )
+
+    input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size))
+    sequence_mask = torch.ones((batch_dim_size, time_dim_size))
+    inv_sequence_mask = torch.logical_not(sequence_mask)
+
+    # the ESPnet module is fed the input after our layer norm, our module gets the raw input
+    input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor)
+
+    espnet_pos_enc = espnet_pos_enc_module(input_tensor_layernorm)
+    espnet_output_tensor = espnet_mhsa_module(
+        query=input_tensor_layernorm,
+        key=input_tensor_layernorm,
+        value=input_tensor_layernorm,
+        pos_enc=espnet_pos_enc,
+        mask=inv_sequence_mask,
+    )
+
+    own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask)
+
+    assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03)