From 92a561a51625b058d22f8eaf64cbe0c011405654 Mon Sep 17 00:00:00 2001 From: blisc Date: Tue, 21 Jan 2025 18:48:17 +0000 Subject: [PATCH 1/2] Apply isort and black reformatting Signed-off-by: blisc --- tests/collections/tts/modules/test_transformer_2501.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/collections/tts/modules/test_transformer_2501.py b/tests/collections/tts/modules/test_transformer_2501.py index b3d4022150c7..feecb7ef51a8 100644 --- a/tests/collections/tts/modules/test_transformer_2501.py +++ b/tests/collections/tts/modules/test_transformer_2501.py @@ -18,9 +18,8 @@ import pytest import torch -from nemo.collections.tts.modules.transformer_2501 import ( - Transformer, -) +from nemo.collections.tts.modules.transformer_2501 import Transformer + def set_seed(seed): torch.manual_seed(seed) @@ -28,6 +27,7 @@ def set_seed(seed): np.random.seed(seed) random.seed(seed) + @pytest.mark.unit class TestTransformer: @classmethod From 4f92fb05a0e31915a252fbe6d4ad004df4a2c1d3 Mon Sep 17 00:00:00 2001 From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:10:07 -0800 Subject: [PATCH 2/2] bugfix for unit test: fixed speech masks as torch.ones instead of torch.zeros. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --- tests/collections/tts/modules/test_transformer_2501.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/tts/modules/test_transformer_2501.py b/tests/collections/tts/modules/test_transformer_2501.py index feecb7ef51a8..da235b8e0582 100644 --- a/tests/collections/tts/modules/test_transformer_2501.py +++ b/tests/collections/tts/modules/test_transformer_2501.py @@ -73,7 +73,7 @@ def test_forward_causal_self_attn_and_no_xattn(self): max_length_causal_mask=self.max_length_causal_mask, ) - mask_tensor = torch.zeros(1, self.max_length_causal_mask).bool() + mask_tensor = torch.ones(1, self.max_length_causal_mask).bool() with torch.no_grad(): output_dict = model(x=self.input_tensor, x_mask=mask_tensor)