Adding MM eval tests / attention bugfixes (#1989)
SalmanMohammadi authored Nov 13, 2024
1 parent 51b31c8 commit 18d97f0
Showing 8 changed files with 248 additions and 34 deletions.
3 changes: 3 additions & 0 deletions tests/cache_artifacts.sh
@@ -18,6 +18,9 @@ SMALL_MODEL_URLS=(
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-03082024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-tune-llama3-05052024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-reward-07122024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-meta-vision-10172024.pt"
"https://ossci-datasets.s3.amazonaws.com/torchtune/small-ckpt-hf-vision-10172024.pt"

)
FULL_MODEL_URL=("s3://pytorch-multimodal/llama2-7b-torchtune.pt")
TOKENIZER_URLS=(
135 changes: 117 additions & 18 deletions tests/recipes/test_eleuther_eval.py
@@ -13,11 +13,40 @@
import pytest

from tests.common import TUNE_PATH
from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS
from tests.recipes.utils import (
llama2_test_config,
llama3_2_vision_test_config,
write_hf_ckpt_config,
write_hf_vision_ckpt_config,
)
from tests.test_utils import CKPT_MODEL_PATHS, gpu_test


class TestEleutherEval:
@pytest.fixture
def hide_correct_version_number(self, monkeypatch):
import importlib.metadata

import_orig = importlib.metadata.version

def mocked_import(name, *args, **kwargs):
if name == "lm-eval":
return "0.4.4" # Hardcode wrong version number
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(importlib.metadata, "version", mocked_import)

@pytest.fixture
def expected_vision_acc(self):
return {
"Science": 0.35,
"Biology": 0.25,
"Chemistry": 0.25,
"Geography": 0.5,
"Math": 0.0,
"Physics": 0.75,
}

@pytest.mark.parametrize(
"eval_name, expected_acc, bsz",
[
@@ -74,22 +103,9 @@ def test_torchtune_checkpoint_eval_results(
acc_result = float(search_results.group(1))
assert math.isclose(acc_result, expected_acc, abs_tol=0.05)

@pytest.fixture
def hide_correct_version_number(self, monkeypatch):
import importlib.metadata

import_orig = importlib.metadata.version

def mocked_import(name, *args, **kwargs):
if name == "lm-eval":
return "0.4.4" # Hardcode wrong version number
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(importlib.metadata, "version", mocked_import)

@pytest.mark.integration_test
@pytest.mark.usefixtures("hide_correct_version_number")
def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
def test_eval_recipe_errors_without_lm_eval(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -123,7 +139,7 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):

@pytest.mark.integration_test
def test_eval_recipe_errors_with_quantization_hf_checkpointer(
self, capsys, monkeypatch, tmpdir
self, monkeypatch, tmpdir
):
ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -162,7 +178,7 @@ def test_eval_recipe_errors_with_quantization_hf_checkpointer(
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
def test_eval_recipe_errors_with_qat_quantizer(self, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
@@ -194,3 +210,86 @@ def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir
match="QAT quantizers should only be used during quantization aware training",
):
runpy.run_path(TUNE_PATH, run_name="__main__")

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_meta_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
ckpt = "llama3_2_vision_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

cmd = f"""
tune run eleuther_eval \
--config llama3_2_vision/11B_evaluation \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelMetaCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
~checkpointer.checkpoint_files.filename_format \
~checkpointer.checkpoint_files.max_filename \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3_VISION \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
limit=4 \
dtype=bf16 \
device=cuda \
""".split()

model_config = llama3_2_vision_test_config()
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

out = caplog.text

pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

matches = re.findall(pattern, out, re.MULTILINE)
for task_name, _, accuracy in matches:
assert math.isclose(float(accuracy), expected_vision_acc[task_name])

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_hf_eval_vision(self, caplog, monkeypatch, tmpdir, expected_vision_acc):
ckpt = "llama3_2_vision_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

# Config file needed for model conversion.
write_hf_vision_ckpt_config(ckpt_dir)

cmd = f"""
tune run eleuther_eval \
--config llama3_2_vision/11B_evaluation \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
~checkpointer.checkpoint_files.filename_format \
~checkpointer.checkpoint_files.max_filename \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3_VISION \
tokenizer.path=/tmp/test-artifacts/tokenizer_llama3.model \
tokenizer.prompt_template=null \
limit=4 \
dtype=bf16 \
device=cuda \
""".split()

model_config = llama3_2_vision_test_config()
cmd = cmd + model_config

monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

out = caplog.text

pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

matches = re.findall(pattern, out, re.MULTILINE)
for task_name, _, accuracy in matches:
assert math.isclose(float(accuracy), expected_vision_acc[task_name])
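
For reference, the accuracy-parsing step used by both vision tests can be exercised on its own. The snippet below is a standalone sketch rather than part of the commit; the sample results row is hypothetical and only mimics the layout of lm-eval's logged table.

import math
import re

# Same pattern as in the tests above.
pattern = r"^\|\s*(?:-\s*)?([^\|]+?)\s*\|\s*(\d+)\s*\|.*?\|.*?\|acc\s*\|\s*↑\s*\|\s*([\d.]+)"

# Hypothetical line in the style of lm-eval's results table (spacing is illustrative only).
sample_log = "|  - Science|      0|none  |     0|acc   |↑  |0.3500|±  |0.2449|"

matches = re.findall(pattern, sample_log, re.MULTILINE)
assert matches == [("Science", "0", "0.3500")]

task_name, _, accuracy = matches[0]
assert math.isclose(float(accuracy), 0.35)
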
73 changes: 73 additions & 0 deletions tests/recipes/utils.py
@@ -128,6 +128,58 @@ def llama3_test_config() -> List[str]:
]


def llama3_2_vision_test_config() -> List[str]:
return [
"model=tests.recipes.utils.dummy_vision_model",
"tokenizer._component_=torchtune.models.llama3_2_vision._transform.Llama3VisionTransform",
"tokenizer.patch_size=9",
"tokenizer.max_num_tiles=2",
"tokenizer.tile_size=18",
"tokenizer.max_seq_len=4096",
]


def dummy_vision_model():
from torchtune.models.llama3_2_vision._component_builders import (
llama3_2_vision_decoder,
llama3_2_vision_encoder,
)
from torchtune.modules.model_fusion import DeepFusionModel

vision_encoder = llama3_2_vision_encoder(
clip_embed_dim=128,
clip_num_layers=4,
num_heads=4,
tile_size=18,
patch_size=9,
max_num_tiles=2,
in_channels=3,
clip_hidden_states=[0, 1],
num_layers_projection=2,
decoder_embed_dim=128,
)
vision_decoder = llama3_2_vision_decoder(
vocab_size=128256,
num_layers=4,
fusion_interval=2,
num_special_tokens=2,
num_heads=8,
num_kv_heads=4,
embed_dim=128,
max_seq_len=4096,
encoder_max_seq_len=4096,
)

model = DeepFusionModel(
encoder=vision_encoder,
decoder=vision_decoder,
encoder_trainable=False,
decoder_trainable=False,
fusion_trainable=False,
)
return model


def lora_llama2_test_config(
lora_attn_modules,
apply_lora_to_mlp: bool = False,
@@ -199,6 +251,27 @@ def write_hf_ckpt_config(ckpt_dir: str):
json.dump(config, f)


def write_hf_vision_ckpt_config(ckpt_dir: str):
config = {
"text_config": {
"num_attention_heads": 8,
"num_key_value_heads": 4,
"hidden_size": 128,
"vocab_size": 128256,
"cross_attention_layers": [1, 4],
},
"vision_config": {
"hidden_size": 128,
"image_size": 18,
"max_num_tiles": 2,
"supported_aspect_ratios": [[1, 1], [1, 2], [2, 1]],
},
}
config_file = Path.joinpath(Path(ckpt_dir), "config.json")
with config_file.open("w") as f:
json.dump(config, f)


MODEL_TEST_CONFIGS = {
"llama2": llama2_test_config(),
"llama3": llama3_test_config(),
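
As an aside, dummy_vision_model above is referenced by dotted path from the test config (model=tests.recipes.utils.dummy_vision_model), so the recipe instantiates it like any other model builder. A rough, hypothetical smoke check of the builder outside the recipe might look like the following; it assumes the component builders accept the arguments shown in the diff.

# Hypothetical smoke check (not part of the commit) for the dummy vision model above.
from tests.recipes.utils import dummy_vision_model

model = dummy_vision_model()
model.eval()

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"total params: {total}, trainable params: {trainable}")
# With encoder, decoder, and fusion all marked non-trainable, trainable should come out to 0.
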
2 changes: 2 additions & 0 deletions tests/test_utils.py
@@ -34,6 +34,8 @@
"llama2_reward_hf": "/tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt",
"llama3_tune": "/tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt",
"llama2_7b": "/tmp/test-artifacts/llama2-7b-torchtune.pt",
"llama3_2_vision_hf": "/tmp/test-artifacts/small-ckpt-hf-vision-10172024.pt",
"llama3_2_vision_meta": "/tmp/test-artifacts/small-ckpt-meta-vision-10172024.pt",
}

TOKENIZER_PATHS = {
39 changes: 38 additions & 1 deletion tests/torchtune/modules/test_transformer_decoder.py
@@ -183,11 +183,48 @@ def transformer_layer(
transformer_layer.eval()
return transformer_layer

@mps_ignored_test()
def test_forward_kv_cache(
self,
input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
transformer_layer: TransformerCrossAttentionLayer,
input_params: Tuple[int, int, int, int],
):

b, _, encoder_seq_len, _ = input_params
transformer_layer.setup_caches(
batch_size=b,
dtype=torch.float32,
encoder_max_seq_len=encoder_seq_len,
decoder_max_seq_len=None,
)
input_x, input_y, mask = input
with torch.no_grad():
# make an initial forward pass which should fill the encoder cache
first_output = transformer_layer(
input_x,
encoder_input=input_y,
encoder_mask=mask,
)
# the second pass should just retrieve from the kv-cache and produce
# identical outputs
output = transformer_layer(
input_x,
encoder_input=None,
encoder_mask=mask,
)

assert_expected(output.mean(), torch.tensor(1.7762), atol=1e-8, rtol=1e-3)
assert_expected(output.shape, input_x.shape)

assert_expected(first_output.shape, output.shape)
assert_expected(first_output.mean(), output.mean())

@mps_ignored_test()
def test_forward(
self,
input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
transformer_layer: TransformerSelfAttentionLayer,
transformer_layer: TransformerCrossAttentionLayer,
) -> None:
input_x, input_y, mask = input
with torch.no_grad():
2 changes: 1 addition & 1 deletion torchtune/models/gemma2/_attention.py
@@ -240,7 +240,7 @@ def forward(
q = self.q_norm(q)

if y is None:
if self.kv_cache is None:
if self.kv_cache is None or not self.cache_enabled:
raise ValueError(
"Must provide y input or use kv_cache to enable streaming decoding"
)
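
The one-line change above tightens the guard for streaming decoding: omitting y is only valid when a KV cache is both allocated and currently enabled. A simplified, standalone illustration of that condition (plain Python stand-ins, not the torchtune module):

# Illustrative only: the three cache states the stricter guard distinguishes.
from typing import Optional

class FakeKVCache:
    """Stand-in object; only its presence or absence matters here."""

def must_provide_y(kv_cache: Optional[FakeKVCache], cache_enabled: bool) -> bool:
    return kv_cache is None or not cache_enabled

assert must_provide_y(None, cache_enabled=False)              # no cache allocated
assert must_provide_y(FakeKVCache(), cache_enabled=False)     # allocated but disabled: the case the new check catches
assert not must_provide_y(FakeKVCache(), cache_enabled=True)  # usable cache: y may be omitted
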
26 changes: 13 additions & 13 deletions torchtune/modules/attention.py
@@ -195,7 +195,7 @@ def forward(
and before the softmax. Either:
A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
is used by default.
@@ -249,7 +249,7 @@ def forward(
q = self.q_norm(q)

if y is None:
if self.kv_cache is None:
if self.kv_cache is None or not self.cache_enabled:
raise ValueError(
"Must provide y input or use kv_cache to enable streaming decoding"
)
@@ -273,21 +273,21 @@
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None and self.cache_enabled:
k, v = self.kv_cache.update(k, v)

# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
expand_shape = (-1, -1, q_per_kv, -1, -1)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

# Normalize k
if self.k_norm is not None:
k = self.k_norm(k)
# If needed, expand the key and value tensors to have the same shape
# as the query tensor by copying values across the relevant dim
# k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
if self.num_heads != self.num_kv_heads:
expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim)
k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

output = self._attention_call(
q,
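
To make the reordered block easier to follow: k and v leave the projection with num_kv_heads heads and, when num_heads != num_kv_heads, are copied across q_per_kv groups so they line up with the query heads. Below is a standalone sketch with toy shapes (not torchtune internals) of that expansion, using the explicit shape from the new code.

# Illustrative grouped-query expansion: [b, n_kv, s, h_d] -> [b, n_h, s, h_d].
import torch

b, n_kv, s, h_d = 2, 4, 16, 32   # batch, kv heads, seq len, head dim
n_h = 8                          # query heads
q_per_kv = n_h // n_kv

k = torch.randn(b, n_kv, s, h_d)

expand_shape = (b, n_kv, q_per_kv, -1, h_d)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k_expanded.shape == (b, n_h, s, h_d)
# Query heads 0 and 1 share KV head 0, so their keys are identical copies.
assert torch.equal(k_expanded[:, 0], k[:, 0])
assert torch.equal(k_expanded[:, 1], k[:, 0])
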
2 changes: 1 addition & 1 deletion torchtune/modules/transformer.py
@@ -781,7 +781,7 @@ def setup_caches(
isinstance(l, TransformerCrossAttentionLayer) for l in self.modules()
)
has_decoder_layers = any(
isinstance(l, TransformerSelfAttentionLayer) for l in self.layers
isinstance(l, TransformerSelfAttentionLayer) for l in self.modules()
)
if has_encoder_layers:
if encoder_max_seq_len is not None:
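
The change to has_decoder_layers mirrors the has_encoder_layers check just above it: walking self.modules() recurses into wrapper modules, whereas iterating self.layers only sees the top-level entries, so self-attention layers nested inside fusion wrappers would not be counted. A simplified illustration with stand-in classes (not torchtune types):

# Illustrative only: recursive module walk vs. top-level layer iteration.
import torch.nn as nn

class SelfAttnLayer(nn.Module):
    pass

class FusionWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = SelfAttnLayer()  # nested one level down

layers = nn.ModuleList([FusionWrapper(), FusionWrapper()])

# Top-level iteration sees only the wrappers...
assert not any(isinstance(l, SelfAttnLayer) for l in layers)
# ...while modules() also yields the nested self-attention layers.
assert any(isinstance(m, SelfAttnLayer) for m in layers.modules())
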
