Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
Signed-off-by: Jason <[email protected]>
  • Loading branch information
blisc committed Jan 21, 2025
2 parents 0aeb4cb + 4f92fb0 commit 85cf416
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/collections/tts/modules/test_transformer_2501.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
import pytest
import torch

from nemo.collections.tts.modules.transformer_2501 import (
Transformer,
)
from nemo.collections.tts.modules.transformer_2501 import Transformer


def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)


@pytest.mark.unit
class TestTransformer:
@classmethod
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_forward_causal_self_attn_and_no_xattn(self):
assert torch.isclose(torch.mean(model.layers[0].pos_ff.o_net.weight), 0.)
assert torch.isclose(torch.std(model.layers[0].pos_ff.o_net.weight), 0.02/math.sqrt(2.))

mask_tensor = torch.zeros(1, self.max_length_causal_mask).bool()
mask_tensor = torch.ones(1, self.max_length_causal_mask).bool()
with torch.no_grad():
output_dict = model(x=self.input_tensor, x_mask=mask_tensor)

Expand Down

0 comments on commit 85cf416

Please sign in to comment.