address comments
RdoubleA committed Nov 8, 2024
1 parent c8bd056 commit d0b1ab0
Showing 12 changed files with 1,057 additions and 1,062 deletions.
195 changes: 195 additions & 0 deletions tests/torchtune/modules/model_fusion/test_deep_fusion.py
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest

import torch
from tests.test_utils import assert_expected, fixed_init_model
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel, register_fusion_module
from torchtune.training.seed import set_seed


@pytest.fixture(autouse=True)
def random():
    set_seed(1)


class DummyModel(nn.Module):
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.cache_enabled = False
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, vocab_size)
        register_fusion_module(self.output)

    def setup_caches(self, batch_size, dtype, *args, **kwargs):
        self.cache_enabled = True

    def caches_are_setup(self):
        return self.cache_enabled

    def reset_caches(self):
        self.cache_enabled = False

    def forward(
        self,
        tokens,
        *,
        mask=None,
        encoder_input=None,
        encoder_mask=None,
        input_pos=None,
    ):
        x = self.tok_embeddings(tokens)
        if encoder_input is not None:
            q = self.q(x)
            k = self.k(encoder_input) if encoder_input is not None else self.k(x)
            v = self.v(encoder_input) if encoder_input is not None else self.v(x)
            x += nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask
            )
        x = self.output(x)
        return x


class TestDeepFusionModel:
"""
Class for testing our DeepFusionModel wrapper.
"""

@pytest.fixture
def vocab_size(self) -> int:
return 100

@pytest.fixture
def dim(self) -> int:
return 64

@pytest.fixture
def encoder(self, dim, vocab_size) -> nn.Module:
encoder = nn.Embedding(vocab_size, dim)
fixed_init_model(encoder)
return encoder

@pytest.fixture
def decoder(self, dim, vocab_size) -> nn.Module:
decoder = DummyModel(dim, vocab_size)
fixed_init_model(decoder, max_val=0.1)
return decoder

@pytest.fixture
def fused_model(self, encoder, decoder) -> DeepFusionModel:
model = DeepFusionModel(
encoder=encoder,
decoder=decoder,
)
return model

@pytest.fixture
def inputs(self, vocab_size):
batch_size = 2
seq_len = 10
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
encoder_input = {"input": torch.randint(0, vocab_size, (batch_size, seq_len))}
encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool()
input_pos = torch.Tensor([1]).int()
return tokens, encoder_input, encoder_mask, input_pos

@torch.no_grad()
def test_forward(self, fused_model, inputs, vocab_size):
"""
Test that the forward pass of the DeepFusionModel works as expected.
"""
tokens, encoder_input, encoder_mask, _ = inputs
batch_size, seq_len = tokens.shape
out = fused_model(
tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
)

assert out.shape == (batch_size, seq_len, vocab_size)
assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3)

    @torch.no_grad()
    def test_forward_no_encoding(self, fused_model, inputs, vocab_size):
        """
        Test the forward pass of the DeepFusionModel with no encoder input.
        """
        tokens, *_ = inputs
        batch_size, seq_len = tokens.shape
        out = fused_model(tokens)

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3)

    @torch.no_grad()
    def test_decoding_forward(self, fused_model, inputs, vocab_size):
        """
        Test that the forward pass of the DeepFusionModel works during decoding.
        """
        tokens, encoder_input, encoder_mask, input_pos = inputs
        tokens = tokens[:, input_pos]
        encoder_mask = encoder_mask[:, input_pos]
        batch_size, seq_len = tokens.shape
        out = fused_model(
            tokens,
            encoder_input=encoder_input,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3)

    def test_setup_cache(self, fused_model):
        """
        Test that the cache methods work as expected.
        """
        fused_model.setup_caches(2, torch.float32)
        assert fused_model.caches_are_setup()
        fused_model.reset_caches()
        assert not fused_model.caches_are_setup()

    def test_set_trainable_params(self, fused_model, encoder, decoder):
        """
        Test that the trainable parameters are set correctly.
        """
        # Test default case
        trainable_params = {
            n for n, p in fused_model.named_parameters() if p.requires_grad
        }
        assert trainable_params == {"decoder.output.weight", "decoder.output.bias"}

        # Test encoder only
        model = DeepFusionModel(
            encoder=encoder,
            decoder=decoder,
            encoder_trainable=True,
            fusion_trainable=False,
        )
        trainable_params = {n for n, p in model.named_parameters() if p.requires_grad}
        assert trainable_params == {"encoder.weight"}

        # Test decoder only, and confirm fusion layers are removed independently
        model = DeepFusionModel(
            encoder=encoder,
            decoder=decoder,
            decoder_trainable=True,
            fusion_trainable=False,
        )
        trainable_params = {n for n, p in model.named_parameters() if p.requires_grad}
        assert trainable_params == {
            "decoder.q.weight",
            "decoder.q.bias",
            "decoder.k.weight",
            "decoder.k.bias",
            "decoder.v.weight",
            "decoder.v.bias",
            "decoder.tok_embeddings.weight",
        }
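
For orientation, the sketch below mirrors the fixtures and forward call exercised by the new test file; it is a minimal usage example, not part of the commit. The toy encoder, the dimensions, and the tensor shapes are illustrative stand-ins, and DummyModel refers to the test helper defined above.

import torch
from torch import nn
from torchtune.modules.model_fusion import DeepFusionModel

# Toy modules standing in for a real encoder and decoder (illustrative only).
encoder = nn.Embedding(100, 64)
decoder = DummyModel(dim=64, vocab_size=100)  # DummyModel as defined in the test above

# Fuse them; per the default case in test_set_trainable_params, only the
# registered fusion module (decoder.output) remains trainable.
model = DeepFusionModel(encoder=encoder, decoder=decoder)

tokens = torch.randint(0, 100, (2, 10))
encoder_input = {"input": torch.randint(0, 100, (2, 10))}
encoder_mask = torch.randint(0, 2, (2, 10, 10)).bool()

# Forward pass returns logits of shape (batch, seq_len, vocab_size) == (2, 10, 100).
logits = model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)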
@@ -11,11 +11,7 @@
import torch
from tests.test_utils import assert_expected, fixed_init_model
from torch import nn
from torchtune.modules.model_fusion import (
    DeepFusionModel,
    EarlyFusionModel,
    register_fusion_module,
)
from torchtune.modules.model_fusion import EarlyFusionModel, register_fusion_module
from torchtune.training.seed import set_seed


@@ -65,167 +61,6 @@ def forward(
        return x


class TestDeepFusionModel:
    """
    Class for testing our DeepFusionModel wrapper.
    """

    @pytest.fixture
    def vocab_size(self) -> int:
        return 100

    @pytest.fixture
    def dim(self) -> int:
        return 64

    @pytest.fixture
    def encoder(self, dim, vocab_size) -> nn.Module:
        encoder = nn.Embedding(vocab_size, dim)
        fixed_init_model(encoder)
        return encoder

    @pytest.fixture
    def decoder(self, dim, vocab_size) -> nn.Module:
        decoder = DummyModel(dim, vocab_size)
        fixed_init_model(decoder, max_val=0.1)
        return decoder

    @pytest.fixture
    def fused_model(self, encoder, decoder) -> DeepFusionModel:
        model = DeepFusionModel(
            encoders={"encoder": encoder},
            decoder=decoder,
        )
        return model

    @pytest.fixture
    def inputs(self, vocab_size):
        batch_size = 2
        seq_len = 10
        tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
        encoder_input = {
            "encoder": {"input": torch.randint(0, vocab_size, (batch_size, seq_len))}
        }
        encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool()
        input_pos = torch.Tensor([1]).int()
        return tokens, encoder_input, encoder_mask, input_pos

    @torch.no_grad()
    def test_forward(self, fused_model, inputs, vocab_size):
        """
        Test that the forward pass of the DeepFusionModel works as expected.
        """
        tokens, encoder_input, encoder_mask, _ = inputs
        batch_size, seq_len = tokens.shape
        out = fused_model(
            tokens, encoder_input=encoder_input, encoder_mask=encoder_mask
        )

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(8.5584), atol=1e-3, rtol=1e-3)

    @torch.no_grad()
    def test_forward_no_encoding(self, fused_model, inputs, vocab_size):
        """
        Test that the forward pass of the DeepFusionModel with no encoder input.
        """
        tokens, *_ = inputs
        batch_size, seq_len = tokens.shape
        out = fused_model(tokens)

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(0.2271), atol=1e-3, rtol=1e-3)

    @torch.no_grad()
    def test_decoding_forward(self, fused_model, inputs, vocab_size):
        """
        Test that the forward pass of the DeepFusionModel works during decoding.
        """
        tokens, encoder_input, encoder_mask, input_pos = inputs
        tokens = tokens[:, input_pos]
        encoder_mask = encoder_mask[:, input_pos]
        batch_size, seq_len = tokens.shape
        out = fused_model(
            tokens,
            encoder_input=encoder_input,
            encoder_mask=encoder_mask,
            input_pos=input_pos,
        )

        assert out.shape == (batch_size, seq_len, vocab_size)
        assert_expected(out.mean(), torch.tensor(9.0072), atol=1e-3, rtol=1e-3)

    def test_setup_cache(self, fused_model):
        """
        Test that the cache methods works as expected.
        """
        fused_model.setup_caches(2, torch.float32)
        assert fused_model.caches_are_setup()
        fused_model.reset_caches()
        assert not fused_model.caches_are_setup()

    def test_set_trainable_params(self, fused_model, encoder, decoder):
        """
        Test that the trainable parameters are set correctly.
        """
        # Test default case
        trainable_params = {
            n for n, p in fused_model.named_parameters() if p.requires_grad
        }
        assert trainable_params == {"decoder.output.weight", "decoder.output.bias"}

        # Test encoder only
        model = DeepFusionModel(
            encoders={"encoder": encoder},
            decoder=decoder,
            encoders_trainable=True,
            fusion_trainable=False,
        )
        trainable_params = {n for n, p in model.named_parameters() if p.requires_grad}
        assert trainable_params == {"encoders.encoder.weight"}

        # Test decoder only, and confirm fusion layers are removed independently
        model = DeepFusionModel(
            encoders={"encoder": encoder},
            decoder=decoder,
            decoder_trainable=True,
            fusion_trainable=False,
        )
        trainable_params = {n for n, p in model.named_parameters() if p.requires_grad}
        assert trainable_params == {
            "decoder.q.weight",
            "decoder.q.bias",
            "decoder.k.weight",
            "decoder.k.bias",
            "decoder.v.weight",
            "decoder.v.bias",
            "decoder.tok_embeddings.weight",
        }

    def test_incorrect_number_of_encoders(self, decoder):
        with pytest.raises(ValueError):
            _ = DeepFusionModel(
                encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()},
                decoder=decoder,
            )

    def test_mismatched_encoder_keys(self, decoder):
        with pytest.raises(ValueError):
            _ = DeepFusionModel(
                encoders={"encoder": nn.Identity()},
                decoder=decoder,
                encoders_trainable={"encoder2": True},
            )

    def test_mismatched_encoder_input(self, fused_model, inputs):
        tokens, _, _, _ = inputs
        with pytest.raises(ValueError):
            _ = fused_model(
                tokens,
                encoder_input={"encoder2": {"input": torch.tensor([1])}},
            )


class TestEarlyFusionModel:
    @pytest.fixture
    def vocab_size(self) -> int: