Showing 12 changed files with 1,057 additions and 1,062 deletions.
tests/torchtune/modules/model_fusion/test_deep_fusion.py (195 additions, 0 deletions)
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, fixed_init_model
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module
from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    set_seed(1)


class DummyModel(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.cache_enabled = False
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, vocab_size)
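        # Register the output projection as a fusion module: DeepFusionModel then
        # treats its parameters as fusion params, which are trainable by default
        # (see test_set_trainable_params below).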
        register_fusion_module(self.output)

    def setup_caches(self, batch_size, dtype, *args, **kwargs):
        self.cache_enabled = True

    def caches_are_setup(self):
        return self.cache_enabled

    def reset_caches(self):
        self.cache_enabled = False

    def forward(
        self,
        tokens,
        *,
        mask=None,
        encoder_input=None,
        encoder_mask=None,
        input_pos=None,
    ):
        x = self.tok_embeddings(tokens)
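        # When encoder features are provided, fuse them into the decoder stream with
        # a single cross-attention step: queries come from the token embeddings,
        # keys/values from the encoder input.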
        if encoder_input is not None:
            q = self.q(x)
            k = self.k(encoder_input) if encoder_input is not None else self.k(x)
            v = self.v(encoder_input) if encoder_input is not None else self.v(x)
            x += nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask
            )
        x = self.output(x)
        return x


class TestDeepFusionModel:
    """
    Class for testing our DeepFusionModel wrapper.
    """

    @pytest.fixture
    def vocab_size(self) -> int:
        return 100

    @pytest.fixture
    def dim(self) -> int:
        return 64

    @pytest.fixture
    def encoder(self, dim, vocab_size) -> nn.Module:
        encoder = nn.Embedding(vocab_size, dim)
        fixed_init_model(encoder)
        return encoder

    @pytest.fixture
    def decoder(self, dim, vocab_size) -> nn.Module:
        decoder = DummyModel(dim, vocab_size)
        fixed_init_model(decoder, max_val=0.1)
        return decoder

    @pytest.fixture
    def fused_model(self, encoder, decoder) -> DeepFusionModel:
        model = DeepFusionModel(
            encoder=encoder,
            decoder=decoder,
        )
        return model

    @pytest.fixture
    def inputs(self, vocab_size):
        batch_size = 2
        seq_len = 10
        tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
        encoder_input = {"input": torch.randint(0, vocab_size, (batch_size, seq_len))}
        encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool()
        input_pos = torch.Tensor([1]).int()
        return tokens, encoder_input, encoder_mask, input_pos

    @torch.no_grad()
    def test_forward(self, fused_model, inputs, vocab_size):
        """
        Test that the forward pass of the DeepFusionModel works as expected.
        """
        tokens, encoder_input, encoder_mask, _ = inputs
        batch_size, seq_len = tokens.shape
        out = fused_model(
            tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
        )

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3)

    @torch.no_grad()
    def test_forward_no_encoding(self, fused_model, inputs, vocab_size):
        """
        Test the forward pass of the DeepFusionModel with no encoder input.
""" | ||
tokens, *_ = inputs | ||
batch_size, seq_len = tokens.shape | ||
out = fused_model(tokens) | ||
|
||
assert out.shape == (batch_size, seq_len, vocab_size) | ||
assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3) | ||
|
||
@torch.no_grad() | ||
def test_decoding_forward(self, fused_model, inputs, vocab_size): | ||
""" | ||
Test that the forward pass of the DeepFusionModel works during decoding. | ||
""" | ||
tokens, encoder_input, encoder_mask, input_pos = inputs | ||
tokens = tokens[:, input_pos] | ||
encoder_mask = encoder_mask[:, input_pos] | ||
batch_size, seq_len = tokens.shape | ||
out = fused_model( | ||
tokens, | ||
encoder_input=encoder_input, | ||
encoder_mask=encoder_mask, | ||
input_pos=input_pos, | ||
) | ||
|
||
assert out.shape == (batch_size, seq_len, vocab_size) | ||
assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3) | ||
|
||
def test_setup_cache(self, fused_model): | ||
""" | ||
        Test that the cache methods work as expected.
""" | ||
fused_model.setup_caches(2, torch.float32) | ||
assert fused_model.caches_are_setup() | ||
fused_model.reset_caches() | ||
assert not fused_model.caches_are_setup() | ||
|
||
def test_set_trainable_params(self, fused_model, encoder, decoder): | ||
""" | ||
Test that the trainable parameters are set correctly. | ||
""" | ||
# Test default case | ||
trainable_params = { | ||
n for n, p in fused_model.named_parameters() if p.requires_grad | ||
} | ||
assert trainable_params == {"decoder.output.weight", "decoder.output.bias"} | ||
|
||
# Test encoder only | ||
model = DeepFusionModel( | ||
encoder=encoder, | ||
decoder=decoder, | ||
encoder_trainable=True, | ||
fusion_trainable=False, | ||
) | ||
trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} | ||
assert trainable_params == {"encoder.weight"} | ||
|
||
# Test decoder only, and confirm fusion layers are removed independently | ||
model = DeepFusionModel( | ||
encoder=encoder, | ||
decoder=decoder, | ||
decoder_trainable=True, | ||
fusion_trainable=False, | ||
) | ||
trainable_params = {n for n, p in model.named_parameters() if p.requires_grad} | ||
assert trainable_params == { | ||
"decoder.q.weight", | ||
"decoder.q.bias", | ||
"decoder.k.weight", | ||
"decoder.k.bias", | ||
"decoder.v.weight", | ||
"decoder.v.bias", | ||
"decoder.tok_embeddings.weight", | ||
} |
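
For orientation, here is a minimal usage sketch (not part of the commit) that mirrors the fixtures above: it wraps the dummy decoder and a stand-in nn.Embedding encoder in DeepFusionModel and runs one fused forward pass. The dimensions, vocab size, and random inputs are illustrative only.

# Usage sketch mirroring the test fixtures; DummyModel is the decoder defined in
# the test file above.
import torch
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel

dim, vocab_size = 64, 100
encoder = nn.Embedding(vocab_size, dim)      # stands in for a real modality encoder
decoder = DummyModel(dim, vocab_size)
model = DeepFusionModel(encoder=encoder, decoder=decoder)

tokens = torch.randint(0, vocab_size, (2, 10))
encoder_input = {"input": torch.randint(0, vocab_size, (2, 10))}  # kwargs for the encoder
logits = model(tokens, encoder_input=encoder_input)
print(logits.shape)  # torch.Size([2, 10, 100]), as asserted in test_forward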