
feat: add gemma2b variants #1835

Merged: 12 commits merged into pytorch:main on Nov 8, 2024

Conversation

@Optimox (Contributor) commented Oct 15, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This is related to adding gemma2 support #1813

Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1835

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 53eed40 with merge base 57ab583:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 15, 2024
@Optimox Optimox mentioned this pull request Oct 15, 2024
@ebsmothers (Contributor) left a comment

Thank you for adding this! Just took a quick and very non-exhaustive first pass to leave a few comments, will get back to it with a full review later today.

torchtune/modules/attention.py (outdated review thread, resolved)
torchtune/modules/attention.py (outdated review thread, resolved)
recipes/configs/gemma2/27B_full.yaml (outdated review thread, resolved)
torchtune/models/gemma2/_component_builders.py (outdated review thread, resolved)
@joecummings mentioned this pull request Oct 15, 2024
logger = logging.getLogger(__name__)


class Gemma2Attention(nn.Module):
Contributor

Since we support flex attention, which supports soft capping, would it make sense to just force gemma2 users to use flex attention instead of implementing this module?

Contributor

Flex Attention is only supported on A100 or better, right? I don't think we can make the assumption that our users will have that.

@Optimox (Contributor, Author) commented Oct 22, 2024

Hello everyone,

I just pushed a new commit which includes all changes discussed with @ebsmothers.
I also implemented a flex attention version but I could not make it work properly.

The default implementation (not using FlexAttention) seems to be working (I only launched the single lora pipeline, please see the attached logs log_gemma2-2b-single-lora_1729498141.txt).

I would appreciate some help with the FlexAttention implementation. Here is where I am struggling:

If I run the following code on my A6000 GPU with torch 2.5:

import torch

from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention)


WINDOW_SIZE = None
CAPPING = 50.0
SCALE = 12.0


def get_gemma2_flex_score_mask(sliding_window_size, softcapping, query_pre_attn_scalar):
    
    def sliding_window_causal_mask(b, h, q_idx, kv_idx):
        """Causal mask and sliding window as proposed here:
        https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
        """
        causal_mask = q_idx >= kv_idx
        if sliding_window_size is None:
            # if no sliding window return causal mask
            return causal_mask
        else:
            windowed_mask = q_idx - kv_idx <= sliding_window_size

            return causal_mask & windowed_mask
    
    def soft_capping_with_scaling(score, b, h, q_idx, kv_idx):
        if query_pre_attn_scalar is None:
            # usual scaling included in FlexAttention
            score = score / softcapping
            score = torch.tanh(score) #tanh_approx(score)
            return score * softcapping
        else:
            score = score / softcapping * query_pre_attn_scalar**-0.5
            score = torch.tanh(score) #tanh_approx(score)
            return score * softcapping
    
    return sliding_window_causal_mask, soft_capping_with_scaling

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

B=4
H=8
S=117
D=256 #256

mask_mod, score_mod = get_gemma2_flex_score_mask(WINDOW_SIZE, CAPPING, SCALE)

query = torch.randn(
        B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
    )
key = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
value = torch.randn(
    B, H, S, D, device="cuda", dtype=torch.float16, requires_grad=True
)
gradOut = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)


block_mask = create_block_mask(mask_mod=mask_mod,
                                   B=1,
                                   H=1,
                                   Q_LEN=S,
                                   KV_LEN=S,
                                   device=query.device)

out = flex_attention(
    query, key, value, score_mod=score_mod, block_mask=block_mask
)
print(out.shape)

The code runs fine if I don't compile flex attention (i.e. if I comment out the flex_attention = torch.compile(flex_attention, dynamic=False) line), but it raises this error otherwise:
BackendCompilerFailed: backend='inductor' raised: OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or num_stages may help.

So I disabled compilation and the code runs, but very slowly (48 s per iteration vs. 1-2 s with the non-flex implementation).

Maybe you could help me understand what is going on?
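
(For reference, one possible workaround is to shrink the Triton tile sizes used by the compiled kernel via flex_attention's kernel_options argument, so that it fits in the smaller shared memory of consumer GPUs. A minimal sketch building on the snippet above; the specific keys and values are assumptions and may need per-GPU tuning:)

# Hedged sketch: pass smaller block sizes through kernel_options so the compiled
# flex_attention kernel fits within ~100 KB of shared memory (values illustrative;
# the backward pass may need additional BLOCK_* overrides).
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
out = compiled_flex_attention(
    query,
    key,
    value,
    score_mod=score_mod,
    block_mask=block_mask,
    kernel_options={"BLOCK_M": 64, "BLOCK_N": 64},
)
print(out.shape)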

Contributor

Tagging @RdoubleA and @felipemello1 for their thoughts.

Just checking: which size Gemma-2 model are you testing with?

Contributor (Author)

The logs I shared are from Gemma 2 2B; the code snippet is independent of the Gemma architecture, it's just a toy example.
I am currently running the QLoRA single-device pipeline with 9B (without flex attention). I'll share the logs tomorrow (I'll also push the changes to the recipe, as there are typos in the output path, etc.).

Contributor (Author)

With the given kernel options, flex attention can be compiled and the code runs (9B LoRA single-device training). However, it is terribly slow (29 tokens per second) and the loss turns to NaN after one batch:

Step 1 | loss:87.6104507446289 lr:2.0000000000000003e-06 tokens_per_second_per_gpu:21.504354449638843 
Step 2 | loss:nan lr:4.000000000000001e-06 tokens_per_second_per_gpu:29.156293709460176 

I don't understand what I am doing wrong. The only obvious optimisation I see is to create one block mask and share it across all layers, whereas I am currently recreating the same block mask for every layer (line 593 in gemma2/_attention.py). Nevertheless, I do not think that is the current bottleneck.

Wouldn't it be better to go with the simpler implementation for now and switch to FlexAttention once it works on more GPUs? Or at least leave the choice of implementation to the end user and default to the classical one?
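
(As a side note on the block-mask point above, here is a minimal sketch of building the mask once and reusing it across layers, reusing names from the earlier snippet; the layer count is an illustrative assumption:)

# Hedged sketch: create the BlockMask once and pass the same object to every
# layer's attention call instead of rebuilding it per layer.
shared_block_mask = create_block_mask(
    mask_mod=mask_mod, B=1, H=1, Q_LEN=S, KV_LEN=S, device="cuda"
)

num_layers = 26  # illustrative layer count
for _ in range(num_layers):
    out = flex_attention(
        query, key, value, score_mod=score_mod, block_mask=shared_block_mask
    )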

Contributor

Yeah, I see what you're saying. I repro'd this on my end too, so it is not a function of any custom kernel configs you're using. Let me look into this a bit more, but in the meantime it seems like we shouldn't enable the flex version until we figure this out.

Contributor (Author)

I have updated the code to keep the flex attention implementation but disable it for now, until we have found a solution.


Hey sorry just catching up:

So 2.5 should not require a multiple of 128 for sequence length. It is unfortunately pretty common for consumer GPUs to hit the shared-memory issue. I have a PR (pytorch/pytorch#137959) to drop the default block sizes, but I still need to debug the failing test.

As for the slowness, that is expected: the plain tanh instruction is very slow compared to the inline-assembly variant: https://github.com/pytorch-labs/attention-gym/blob/36f8bd5ded5b3469f7892099590bb2405cc8f744/attn_gym/mods/softcapping.py#L92.

It is actually quite hard, generically, to know what block sizes should be used, since the amount of shared memory depends on the captured buffers. I am working on a better solution, but that is going to take some time, unfortunately.


Have you all figured out a solution to the out of resource: shared memory? Seems like any large hidden dim >=128 causes issues for me.

@felipemello1 felipemello1 self-assigned this Oct 23, 2024
@Optimox (Contributor, Author) commented Oct 24, 2024

I have pushed changes to the 9B and 27B recipes (typos in the folder names).
I also ran the single-device LoRA recipe for Gemma 2 9B and everything ran OK (with flex attention disabled). Nevertheless, the loss looks better with the 2B model; maybe that's just because the larger model overfits more quickly.
logs_gemma2-9b-lora-single_1729686341.txt

# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full
Contributor

Did some quick math: I guess this will take at least 216 GB of total memory (54 GB params + 54 GB gradients + 108 GB optimizer states for AdamW), which means that to run on 4 devices we'd need people to be using A100s. I wonder whether we can use an 8-bit optimizer + optimizer-in-backward to get us down to a more reasonable peak VRAM here.
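
(A back-of-the-envelope version of that estimate, assuming bf16 weights, gradients, and optimizer states, and ignoring activations:)

# Rough memory estimate for a full finetune of a 27B-parameter model, mirroring
# the numbers above: 2 bytes per element for bf16, two AdamW states per parameter.
params = 27e9
bytes_per_el = 2
params_gb = params * bytes_per_el / 1e9      # ~54 GB
grads_gb = params_gb                         # ~54 GB
optim_gb = 2 * params_gb                     # ~108 GB (exp_avg + exp_avg_sq)
total_gb = params_gb + grads_gb + optim_gb   # ~216 GB before activations
print(f"~{total_gb:.0f} GB total, ~{total_gb / 4:.0f} GB per GPU on 4 devices")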

Contributor

does 8bit work with distributed?

Contributor

Oh yeah, duh.. there may be some issues with bitsandbytes optimizers on that front. I just tried out the torchao low-precision optimizers and they seem to work (though I haven't resumed from an intermediate checkpoint). Also, there may be a compile dependency there. Anyway, if it's too much hassle we can consider it separately; I don't wanna increase the scope of this already substantial PR more than necessary.

Contributor (Author)

What should I do here? Change something, or expect users to change parameters according to their hardware?

Contributor

Sorry missed this comment before now. I think it's fine to leave this as you have it and revisit these details in a later PR


checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/gemma-2b/
Contributor

Suggested change
checkpoint_dir: /tmp/gemma-2b/
checkpoint_dir: /tmp/gemma-2-2b/

Contributor (Author)

Done!

if query_pre_attn_scalar is not None:
self.scaling = query_pre_attn_scalar**-0.5
else:
self.scaling = self.head_dim**-0.5
Contributor

I think you need to add self.cache_enabled=False here (then set it to True at the end of setup_cache), otherwise this will error out. But this is kind of a gotcha, it's not obvious that you need this. cc @SalmanMohammadi we should think about how to make this more obvious to someone adding their own attention layer

@SalmanMohammadi (Collaborator) commented Oct 25, 2024

Hmm, I added a comment to indicate why it's in the init (maybe @Optimox forked before then?)

        # this flag indicates whether to update the kv-cache during forward
        # passes. when disabled, we can have the cache setup but still
        # perform normal forward passes
        self.cache_enabled = False

Could we be clearer here? I agree we could use a comment in setup_caches explaining that you actually need to do this if you'd like to use the caches you've just set up.
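
(To illustrate the pattern being discussed, here is a minimal, hypothetical sketch of a custom attention module that defines the flag in __init__ and only flips it on at the end of cache setup; names mirror the discussion rather than any exact torchtune API:)

import torch.nn as nn

class MyAttention(nn.Module):  # hypothetical custom attention layer
    def __init__(self):
        super().__init__()
        self.kv_cache = None
        # this flag indicates whether to update the kv-cache during forward
        # passes; when disabled, the cache can be set up but normal forward
        # passes still work
        self.cache_enabled = False

    def setup_cache(self, batch_size, dtype, max_seq_len):
        self.kv_cache = ...  # allocate the cache here (omitted in this sketch)
        # only enable cache updates once the cache actually exists
        self.cache_enabled = True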

Contributor (Author)

yes I think I forked before this change, will make the change tomorrow thank you!

Contributor (Author)

Done!

k = self.k_norm(k)

# Update key-value cache
if self.kv_cache is not None:
Collaborator

Suggested change
if self.kv_cache is not None:
if self.kv_cache is not None and self.cache_enabled:

should complement the cache enabled stuff earlier to match the other attention module

Contributor (Author)

Done!

_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Fine-tuning arguments
batch_size: 8
Contributor

Are we confident this'll fit on a single device?

Contributor (Author)

Changed batch size to 2 and gradient accumulation to 8. What is the expected GPU? Is there a CI running everything? Otherwise, I guess each user should be responsible for adjusting the batch size to suit their GPU, no?

Collaborator

Generally we try to ship configs which we know will work on some common hardware configuration (see examples here: https://github.com/pytorch/torchtune?tab=readme-ov-file#memory-and-training-speed), so users can maintain the expectation that they can get started without any painful OOMs. Then they are free to play with the configs. We should make sure this config works with e.g. 1xA100 - let me know if you need a hand here.

Contributor (Author)

@SalmanMohammadi I do not have easy access to an A100; I would appreciate it if someone could run the code for the 27B-parameter model and let me know what batch size I should set.

Collaborator

I'll have a quick look when we're ready to land. We can also reasonably mirror the batch size from the config of another similarly sized model already in the codebase.

# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token <HF_TOKEN>
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full
Contributor

Maybe it's just me but when I try to run these distributed recipes I am hitting AssertionError: FSDP requires named DeviceMesh dims for ND parallelism. It looks to me like we are actually entering _init_sharded_param with a DTensor (see here), which does not happen with our other recipes. Need to figure out why this would be happening

Contributor

Ah I think I cracked the case. See here

Contributor (Author)

Big mistake, thank you for catching that!

Contributor (Author)

The logs after this fix look much better than previously for the 9b single lora pipeline!
log_gemma2-2b-single-lora_1729937021.txt

Contributor (Author)

@ebsmothers aren't the losses too low? Could it be because of the (non-)causal sliding-window attention problem?

Contributor

Sorry I missed this comment before. I am gonna run some of your configs on my end now so will get back to you

"""
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)

mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
Contributor

This needs to be inside the for loop

Contributor (Author)

Done!

path: /tmp/gemma-2-27b/tokenizer.model

# Dataset
dataset:
Collaborator

Sorry to potentially be a pain in the ass here. We have a parallel PR (#1872) which is helping standardize our configs and better expose the features we have. This means we always have packed: False in dataset, and log_peak_memory_stats: True and compile: False below, for every one of our configs.

Would it be annoying to ask if we could update these in the same way while we're here, please?

Contributor (Author)

Done, I have updated all the configs to match the other PR!

@@ -27,6 +27,4 @@
"lora_gemma_7b",
"qlora_gemma_2b",
"qlora_gemma_7b",
"gemma_hf_to_tune",
Collaborator

Good catch : )

flex_causal_sliding_window,
flex_tanh_soft_capping_with_scaling,
)
logger = logging.getLogger(__name__)
Contributor

nit: Why is this style of logger proliferating? We should be calling get_logger from our utils.

Contributor (Author)

It's just a copy-paste on my side; let me know if you want me to change that in this PR.

Contributor

You can just change to torchtune.utils.get_logger, but no strong preference here. Either way we should clean up other usages in a follow-up

@ebsmothers (Contributor)

Hi @Optimox sorry for the delay here. Given that the flex attention version is still not working properly, how do you feel about pulling it out of this PR? Then we can revisit in a follow-up. For context we are going to be cutting a release soon (targeting code freeze tomorrow) so don't want to block getting this in on something that we can address in a follow-up. Let me know if this makes sense to you.

@Optimox (Contributor, Author) commented Oct 29, 2024

@ebsmothers yes, no problem! What is the best way of handling this? Adding a new commit deleting the flex attention part of this branch? Or creating a new PR without the flex attention part?

@ebsmothers (Contributor)

@Optimox honestly whatever is easiest for you. I imagine just a commit deleting the flex code would be simplest, but feel free to do whatever makes sense to you!

@Optimox (Contributor, Author) commented Oct 30, 2024

@ebsmothers I have removed the flex attention implementation from the code, let me know if there are still other changes to make!

@Optimox changed the title from "(WIP) feat: add gemma2b variants" to "feat: add gemma2b variants" on Oct 30, 2024
Comment on lines +14 to +18
Gemma 2 and the original Gemma implementations use different normalizations that share the same name, so it is mandatory to differentiate their state dicts in order to map the weights correctly.
They are essentially the same except for the "model.layers.{}.post_attention_layernorm.weight" key.
See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251
Contributor

Thanks for documenting this

Comment on lines 302 to 310
sliding_mask = torch.triu(
all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)

if self.softcapping is not None:
output = output / self.softcapping
output = torch.tanh(output)
output = output * self.softcapping
Contributor

Can we add code comments explaining sliding window and the softcapping? (Also one for the magic value in the torch.where line wouldn't hurt)

Contributor (Author)

About this part of the code, I actually blindly followed the official PyTorch implementation from Google here.

I am not sure why they used this magic number instead of -torch.inf ...

About the sliding_mask, I am now worried that something is wrong here because of the way I defined the causal mask...

s_x = 10
sliding_window_size = 5
mask = torch.tril(
                torch.ones(
                    size=(s_x, s_x),
                    dtype=torch.bool,
                )
            )
print(mask)

all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
                all_ones, -1 * sliding_window_size + 1
            ) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
print(mask)

The final mask here does not seem to be causal anymore: some future tokens inside the sliding window are now accessible somehow...

Something like the following would seem better to me, but is there a difference in the way masks are defined between the official Gemma 2 code and torchtune?

s_x = 10
sliding_window_size = 5
mask = torch.tril(
                torch.ones(
                    size=(s_x, s_x),
                    dtype=torch.bool,
                )
            )

mask = torch.where(mask==0, -torch.inf, 1)
print(mask)

all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
                all_ones, -1 * sliding_window_size + 1
            ) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -torch.inf)
print(mask)

This seems concerning... what do you think?

Contributor

@Optimox yeah good catch, that first mask definitely does not look right. I'm not that familiar with the official implementation, but looks like they are treating it as an additive mask here. In that case there should definitely not be anything that's not -torch.inf (or a very large negative number) above the diagonal.

This seems like a bug to me, maybe you can open an issue on the gemma_pytorch repo to confirm? Your second implementation looks correct to me.

@SalmanMohammadi (Collaborator) commented Nov 3, 2024

Sorry to add another cook to the kitchen here so late. I thought I'd share some of my findings from the last time I worked on masking + SDPA. I'd recommend checking out this PR, pytorch/pytorch#133882, and the linked issues.

TL;DR: this is the correct approach, i.e. use the attention mask like a "bias" by adding a very large negative number to the masked score positions. However, using -inf as this negative number has been shown to produce NaN gradients in some rare corner cases (e.g. when an entire row is masked out). In Transformers, the approach is to use something like torch.finfo(dtype).min [1] (which is maybe where the original magic number is coming from?)

import torch
x = torch.Tensor([[[float("-inf"), float("-inf"), float("-inf")]]])
softmax = torch.nn.Softmax(dim=-1)
softmax(x)
# tensor([[[nan, nan, nan]]])
dtype = torch.bfloat16
min_value = torch.finfo(dtype).min
# -3.3895313892515355e+38 - on MPS, this will vary depending on the hardware you're using
x = torch.Tensor([[[min_value, min_value, min_value]]])
softmax = torch.nn.Softmax(dim=-1)
softmax(x)
# tensor([[[0.3340, 0.3340, 0.3340]]]) 

As an aside, as of torch 2.5 this is handled internally slightly differently: -inf is used in the mask and softmax is performed, but then any rows of the output corresponding to entirely masked-out rows are explicitly set to zero.

[1] Transformers follows this approach for their Gemma2 implementation. However, this apparently still causes issues for some dtypes, so it has also been suggested to use torch.finfo(dtype).min / 2 (see huggingface/transformers#32390), or to just attend equally to all tokens in a row containing only padding tokens (huggingface/transformers@e22d913), but I'm not 100% sure how this works.

Contributor

Wait so I am a bit confused here (this is also in reference to the issue opened by @Optimox on the gemma_pytorch repo). There are two separate questions here, right?

(1) Should values above the diagonal be unmasked?
(2) What is the right masked value for an additive mask?

I think both the gemma_pytorch discussion and @SalmanMohammadi's comments address (2), and that I am not so worried about. But I think (1) is more fundamental, and I can directly copy-paste the snippet from gemma_pytorch to get this:

import torch

s_x = 5
sliding_window_size = 3
mask = torch.tril(
                torch.ones(
                    size=(s_x, s_x),
                    dtype=torch.bool,
                )
            )

all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
    all_ones, -1 * sliding_window_size + 1
) * torch.tril(all_ones, sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
print(mask)
...

tensor([[ 1.0000e+00,  0.0000e+00,  0.0000e+00, -2.3820e+38, -2.3820e+38],
        [ 1.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00, -2.3820e+38],
        [ 1.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
        [-2.3820e+38,  1.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00],
        [-2.3820e+38, -2.3820e+38,  1.0000e+00,  1.0000e+00,  1.0000e+00]])

This demonstrates that there are values above the diagonal that are unmasked, no? (I will also reopen the issue on there just to confirm I am not missing something here)

Contributor (Author)

What I understood is that there are two ways of defining masks:

  • additive mask: 0 for reachable positions, -inf (or a very large negative number) for unreachable ones -> this is the expected input mask for the gemma_pytorch implementation
  • boolean mask (True for reachable, False otherwise) + torch.nn.functional.scaled_dot_product_attention, which internally converts boolean masks to the first form: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))

So here, since all previous torchtune models used torch SDPA, there was a mismatch between the two implementations, which is why I added this conversion in the latest changes. This conversion should also work for the block mask currently defined as a boolean mask in the torchtune code.
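
(A minimal sketch of that boolean-to-additive conversion, assuming the incoming mask is a boolean tensor of shape [b, s, s]; the function name and the use of torch.finfo(dtype).min instead of -inf are illustrative choices:)

import torch

def bool_to_additive_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True -> 0.0 (attend), False -> large negative (do not attend); using
    # torch.finfo(dtype).min rather than -inf to avoid NaNs on fully masked rows
    additive = torch.zeros_like(mask, dtype=dtype)
    return additive.masked_fill(~mask, torch.finfo(dtype).min)

# example usage with a causal boolean mask
bool_mask = torch.tril(torch.ones(1, 5, 5, dtype=torch.bool))
bias = bool_to_additive_mask(bool_mask, torch.float32)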

Contributor

Thanks, I also caught up with @SalmanMohammadi offline about this. It seems like the initial definition of mask in the code snippet I shared in my last comment did not match what they do in gemma_pytorch, so that was a misunderstanding on my part. I think your approach makes sense, will look at the code more closely to confirm the BlockMask case (though I guess if we're not supporting packed or flex yet it doesn't matter?)

Comment on lines 287 to 288
q.mul_(self.scaling)
output = torch.matmul(q, k.transpose(2, 3))
Contributor

Let's add shape comments here too

x: torch.Tensor,
y: Optional[torch.Tensor] = None,
*,
mask: Optional[_MaskType] = None,
Contributor

Have you run any of the configs with packed=True (i.e. when mask is a BlockMask)?

Contributor (Author)

I haven't run a full training, but I've checked that the code does not throw any error.

Contributor (Author)

See my latest comment: packed=True won't work at the moment!

@ebsmothers (Contributor) left a comment

A few more comments and questions but overall this is looking great!

@joecummings (Contributor)

@Optimox Any chance you can get to these last comments today?

@codecov-commenter

Codecov Report

Attention: Patch coverage is 1.23967% with 239 lines in your changes missing coverage. Please review.

Project coverage is 67.44%. Comparing base (54673b7) to head (599c828).
Report is 60 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/models/gemma2/_attention.py 0.00% 100 Missing ⚠️
torchtune/models/gemma2/_component_builders.py 0.00% 71 Missing ⚠️
torchtune/models/gemma2/_convert_weights.py 0.00% 35 Missing ⚠️
torchtune/models/gemma2/_model_builders.py 0.00% 24 Missing ⚠️
torchtune/models/gemma2/__init__.py 0.00% 5 Missing ⚠️
torchtune/training/checkpointing/_checkpointer.py 33.33% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1835      +/-   ##
==========================================
+ Coverage   67.05%   67.44%   +0.38%     
==========================================
  Files         305      316      +11     
  Lines       15937    17143    +1206     
==========================================
+ Hits        10687    11562     +875     
- Misses       5250     5581     +331     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Optimox (Contributor, Author) commented Nov 4, 2024

I have pushed some minimal changes which take into account the fact that we are not using SDPA with Gemma 2, so the masks must be converted from boolean to additive (True -> 0, False -> -inf), as discussed here.

There is one last identified concern on my side. @ebsmothers, I think I answered somewhere that the packed dataset was working, but actually it is not.
The first reason is that when I run the code with torch 2.5, flex attention is considered available, so block masks arrive in flex-attention format, which is a problem without a flex attention implementation for Gemma 2.

If I manually disable flex attention (I don't know how to disable it locally or automatically), then I hit a broadcasting issue, which has been solved in my latest commit.

So currently the code won't work with a packed dataset on torch >= 2.5.


# Model Arguments
model:
_component_: torchtune.models.gemma2.gemma_27b
Contributor

Suggested change
_component_: torchtune.models.gemma2.gemma_27b
_component_: torchtune.models.gemma2.gemma2_27b


# Model Arguments
model:
_component_: torchtune.models.gemma2.qlora_gemma_27b
Contributor

Suggested change
_component_: torchtune.models.gemma2.qlora_gemma_27b
_component_: torchtune.models.gemma2.qlora_gemma2_27b

checkpoint_dir: /tmp/gemma-2-27b/
checkpoint_files:
filename_format: model-{}-of-{}.safetensors
max_filename: 00024
Contributor

I think all usages of max_filename should look like this instead

Suggested change
max_filename: 00024
max_filename: "00024"

@ebsmothers (Contributor)

@Optimox thanks for your patience in the review process here. Is this comment about unusually low losses still a concern? I ran a few of your configs on my end and the loss does increase pretty dramatically (though I also don't have a baseline). For reference here are some loss curves:

(Screenshot: loss curves, 2024-11-07)

Otherwise, regarding the fact that packed is not yet supported: it seems like we should raise an error somewhere if that's the case. Maybe inside the attention if we receive a BlockMask there? (Though I'm open to other thoughts you have on this.)
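
(A minimal sketch of the kind of guard being suggested, assuming the mask argument can be a flex-attention BlockMask when packed=True; the helper name, message, and placement are illustrative:)

from torch.nn.attention.flex_attention import BlockMask

def _check_mask_supported(mask) -> None:
    # reject flex-attention BlockMasks (i.e. packed=True) until a flex
    # implementation exists for Gemma2 attention
    if isinstance(mask, BlockMask):
        raise NotImplementedError(
            "packed=True (BlockMask) is not yet supported for Gemma2 attention"
        )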

@Optimox (Contributor, Author) commented Nov 8, 2024

@ebsmothers yes, that comment was from before the fix of the sliding-window attention mask.

I have added a NotImplementedError for BlockMasks and made the small fixes in the YAML files you pointed out.

Let me know if you think some other changes should be made!

@ebsmothers (Contributor) left a comment

Thank you for this great new feature! And thanks for your patience and diligence throughout the review process. Very excited that we're now able to support Gemma 2 in torchtune.

@ebsmothers ebsmothers merged commit aa96cae into pytorch:main Nov 8, 2024
17 checks passed
@ebsmothers mentioned this pull request Nov 26, 2024
@ebsmothers mentioned this pull request Jan 2, 2025