Plan for packed dims #1645

Open
albertz opened this issue Nov 13, 2024 · 2 comments

@albertz
Member

albertz commented Nov 13, 2024

This here is open for discussion on what we want in RETURNN.

Packed tensors / packed sequences / ragged tensors / jagged tensors, whatever you want to call them, mean the following: instead of having a padded tensor [B,T,D], where the T dim has dynamic seq lens that differ per sequence, you have a tensor [B_T,D], where B_T = sum(T.seq_lens). I.e. the padding frames are removed, and the B and T dims are packed together. (The shape of dyn_size_ext of T is [B] here. That is why you can pack B with T.)
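To make the shapes concrete, a minimal NumPy sketch of the packing itself. `pack_padded_np` is a hypothetical helper for illustration, not RF's `pack_padded`, and its signature is an assumption:

```python
import numpy as np

def pack_padded_np(x, seq_lens):
    """Pack a padded array x of shape [B, T, D] into shape [B_T, D].

    Hypothetical helper. seq_lens has shape [B] and B_T = sum(seq_lens).
    Padding frames are dropped; frames stay in (b, t) order.
    """
    max_time = x.shape[1]
    # mask[b, t] is True for real frames, False for padding frames
    mask = np.arange(max_time)[None, :] < np.asarray(seq_lens)[:, None]
    # Boolean indexing over the first two axes flattens them row-major, i.e. packs B and T
    return x[mask]

seq_lens = np.array([3, 1, 2])
x = np.arange(3 * 3 * 2).reshape(3, 3, 2)  # [B=3, T=3, D=2]
packed = pack_padded_np(x, seq_lens)
print(packed.shape)  # (6, 2)
```

Because B is the outer axis, the frames of each sequence stay contiguous in the packed axis, which is what the varlen kernels further below rely on.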

In RF (the RETURNN frontend), we have pack_padded to pack and pad_packed to unpack.
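The inverse direction can be sketched the same way. Again, `pad_packed_np` is a hypothetical NumPy helper for illustration, not RF's `pad_packed`; it assumes the packed frames are in (b, t) order and fills padding with a constant:

```python
import numpy as np

def pad_packed_np(packed, seq_lens, pad_value=0):
    """Unpack a packed [B_T, D] array into a padded [B, T, D] array.

    Hypothetical helper. T is taken as max(seq_lens); padding frames get pad_value.
    """
    seq_lens = np.asarray(seq_lens)
    batch = len(seq_lens)
    max_time = int(seq_lens.max()) if batch > 0 else 0
    out = np.full((batch, max_time) + packed.shape[1:], pad_value, dtype=packed.dtype)
    mask = np.arange(max_time)[None, :] < seq_lens[:, None]  # [B, T], True for real frames
    out[mask] = packed  # scatter frames back in the same (b, t) order used for packing
    return out

seq_lens = np.array([3, 1, 2])
packed = np.arange(6 * 2).reshape(6, 2)  # [B_T=6, D=2]
padded = pad_packed_np(packed, seq_lens)
print(padded.shape)  # (3, 3, 2)
```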

In RETURNN (as in most other ML frameworks), there is the somewhat implicit assumption that padding is the canonical way to represent a tensor with multiple dims where one dim is dynamic. Maybe it really is (it's definitely the simplest). However, the packed representation is also valid.

The main question here in this issue: Should we have more direct support for packed tensors?

I imagine something like this:

  • The packed dim (B_T here) knows about what dims it consists of.
  • All operations on dims, like reduce, matmul, etc., would directly support working on the dims contained in the packed dims, e.g. you can do y = reduce(x, axis=T) even though x has shape [B_T,D], and you would get y with shape [B,D]. Some ops can provide efficient code to handle this directly; otherwise, the fallback would be to unpack whenever one of those dims is referred to. But it should always work.
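As an illustration of an op that handles a packed dim directly, a NumPy sketch of y = reduce(x, axis=T) (sum) on x of shape [B_T,D], giving [B,D] without unpacking. The helper name is hypothetical, and the packed dim's knowledge about its components is reduced here to just the seq lens of T:

```python
import numpy as np

def reduce_sum_over_time_packed(packed, seq_lens):
    """Sum a packed [B_T, D] array over the T dim, giving [B, D].

    Hypothetical helper. This is a segment sum: frame i belongs to sequence
    seg_ids[i]. Zero-length sequences give an all-zero row.
    """
    seq_lens = np.asarray(seq_lens)
    seg_ids = np.repeat(np.arange(len(seq_lens)), seq_lens)  # [B_T], sequence index per frame
    out = np.zeros((len(seq_lens),) + packed.shape[1:], dtype=packed.dtype)
    np.add.at(out, seg_ids, packed)  # unbuffered scatter-add into each sequence's row
    return out

seq_lens = np.array([3, 0, 2])
packed = np.arange(5 * 2, dtype=np.float64).reshape(5, 2)  # [B_T=5, D=2]
y = reduce_sum_over_time_packed(packed, seq_lens)
print(y)  # rows: [6, 9], [0, 0], [14, 16]
```

The result is the same as summing over T on the zero-padded [B,T,D] tensor, but no padding frames are ever touched.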

Effectively, this means it would make no real difference to your code whether some dims are packed or not, just as the order of the dims is also not relevant.

Some things are still a bit unclear:

  • What should x.dims return? Should it give the low-level view, i.e. the packed dim itself? I think yes. But then, how do we check whether some dim is inside a packed dim? A separate x.unpacked_dims? In any case, this means that code changes would be needed, also in user code, to really support packed tensors in the same way as padded tensors.
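To illustrate the two views, a minimal sketch with hypothetical classes `Dim`, `PackedDim` and `TensorInfo` (these are not RETURNN's actual `Dim`/`Tensor` classes; the `components` field is an assumption). `dims` is the low-level view, and `unpacked_dims` expands packed dims, recursively in case of nested packing:

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Tuple

@dataclass(frozen=True)
class Dim:
    name: str

@dataclass(frozen=True)
class PackedDim(Dim):
    # Hypothetical: the dims packed together, in packing order, e.g. (B, T)
    components: Tuple[Dim, ...] = ()

def _flatten(dims: Iterable[Dim]) -> Iterator[Dim]:
    """Yield dims, replacing each packed dim by its components (recursively)."""
    for d in dims:
        if isinstance(d, PackedDim):
            yield from _flatten(d.components)
        else:
            yield d

@dataclass
class TensorInfo:
    dims: Tuple[Dim, ...]  # low-level view: a packed dim is a single entry

    @property
    def unpacked_dims(self) -> Tuple[Dim, ...]:
        """Logical view: every packed dim replaced by its component dims."""
        return tuple(_flatten(self.dims))

B, T, D = Dim("batch"), Dim("time"), Dim("feature")
BT = PackedDim("batch*time", components=(B, T))
x = TensorInfo(dims=(BT, D))
print(T in x.dims, T in x.unpacked_dims)  # False True
```

This shows the code-change problem directly: any user code that checks `T in x.dims` silently behaves differently for packed tensors.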

But also, getting back to the main question: Do we really want this? It adds a lot of complexity, specifically in the backend. Maybe it's also error-prone? It also adds some overhead everywhere: whenever some dim is not directly found in dims, we also need to check within the packed dims (or cache that somehow, which adds even more complexity...).
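A sketch of where that overhead comes from: a hypothetical `find_axis` lookup with a fast path for directly present dims and a slow path that searches inside packed dims, memoized with `functools.lru_cache`. Dims are modeled as plain strings and packed dims as tuples of strings, purely as an assumption to keep it self-contained:

```python
from functools import lru_cache
from typing import Optional, Tuple, Union

# Hypothetical representation: a plain dim is a str, a packed dim is a tuple of the dims packed into it
DimLike = Union[str, Tuple[str, ...]]

@lru_cache(maxsize=None)
def find_axis(dims: Tuple[DimLike, ...], dim: str) -> Tuple[int, Optional[int]]:
    """Return (axis, None) if dim is directly in dims, or
    (axis, index_within_packed_dim) if it is only inside a packed dim."""
    # Fast path: the dim is directly present (the usual padded case)
    for axis, d in enumerate(dims):
        if d == dim:
            return axis, None
    # Slow path: search inside packed dims; this is the extra cost on every lookup miss
    for axis, d in enumerate(dims):
        if isinstance(d, tuple) and dim in d:
            return axis, d.index(dim)
    raise KeyError(f"{dim!r} not in {dims!r}")

dims = (("batch", "time"), "feature")
print(find_axis(dims, "feature"), find_axis(dims, "time"))  # (1, None) (0, 1)
```

The cache only works because the dims tuple is hashable and immutable; with mutable dim objects, invalidating such a cache would be exactly the additional complexity mentioned above.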

Thus, I tend to think we should not do this. But I keep thinking about whether there might be easy solutions here.

In any case, for certain kinds of models, specifically for training with large batches, we need this when we want a diverse set of sequence lengths in a batch while avoiding the memory and compute overhead of operating on the padded seqs. Note that, e.g., for a Transformer, most of the ops don't care about the T axis (e.g. all the FF matmuls), so they work fine on packed tensors as-is. Only the attention needs special care. But I think there are existing implementations (TorchTune, HuggingFace Transformers), and maybe it's not too difficult to come up with our own using FlexAttention for PyTorch.
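What "special care" means for attention can be sketched in NumPy: on packed [B_T,D] inputs, each frame may only attend to frames of its own sequence, i.e. a block-diagonal mask over the packed axis. This is a naive illustration with hypothetical function names, not the TorchTune/HuggingFace/FlexAttention code. It materializes a [B_T,B_T] matrix, so it is quadratic in the total number of tokens; as far as I know, FlexAttention's document masking and FlashAttention's varlen kernels instead skip the masked blocks entirely.

```python
import numpy as np

def packed_self_attention(q, k, v, seq_lens):
    """Single-head softmax self-attention on packed [B_T, D] inputs.

    A frame may only attend to frames of its own sequence, i.e. a
    block-diagonal mask over the packed B_T axis. No padding is involved.
    """
    seg_ids = np.repeat(np.arange(len(seq_lens)), seq_lens)  # [B_T], sequence index per frame
    same_seq = seg_ids[:, None] == seg_ids[None, :]          # [B_T, B_T] block-diagonal mask
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(same_seq, scores, -np.inf)
    # Each row has at least its own position unmasked, so the row max is finite
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)  # masked entries become exp(-inf) = 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def attention(q, k, v):
    """Plain attention over a single unpadded sequence, for comparison."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
seq_lens = np.array([3, 1, 2])
q, k, v = (rng.standard_normal((int(seq_lens.sum()), 4)) for _ in range(3))
out = packed_self_attention(q, k, v, seq_lens)

# Same result as running attention on each sequence separately
offsets = np.concatenate([[0], np.cumsum(seq_lens)])
expected = np.concatenate([attention(q[s:e], k[s:e], v[s:e])
                           for s, e in zip(offsets[:-1], offsets[1:])])
print(np.allclose(out, expected))  # True
```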

So, this issue here is to collect some thoughts on this.

@albertz
Member Author

albertz commented Nov 13, 2024

Note that the TF layers backend had some partial support for this, but it was also quite problematic. It was only intended for the batch dim, i.e. a batch dim could actually consist of other dims packed into it, and the BatchInfo structure contained information about which dims were packed into it. See #466, #467, #920, #975 (look for "flatten").

@albertz
Member Author

albertz commented Nov 13, 2024

Note, FlashAttention has flash_attn_varlen_qkvpacked_func (API), with args:

        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
           of the sequences in the batch, used to index into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
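The cu_seqlens / max_seqlen arguments follow directly from T.seq_lens in the packed layout described above. A NumPy sketch of that bookkeeping (the flash-attn call itself is not shown; in PyTorch one would build the same values as a torch.int32 tensor on the device, as the docstring requires):

```python
import numpy as np

seq_lens = np.array([3, 1, 2], dtype=np.int32)  # T.seq_lens, shape [B]

# cu_seqlens has shape [B + 1]; cu_seqlens[b] is the offset of sequence b in the packed axis
cu_seqlens = np.zeros(len(seq_lens) + 1, dtype=np.int32)
cu_seqlens[1:] = np.cumsum(seq_lens)  # cast back to int32 on assignment
max_seqlen = int(seq_lens.max())
total = int(cu_seqlens[-1])  # = B_T, the "total" in qkv: (total, 3, nheads, headdim)

# The frames of sequence b are packed[cu_seqlens[b]:cu_seqlens[b + 1]]
packed = np.arange(total)
per_seq = [packed[cu_seqlens[b]:cu_seqlens[b + 1]] for b in range(len(seq_lens))]
print(cu_seqlens.tolist(), max_seqlen, total)  # [0, 3, 4, 6] 3 6
```

So in RF terms, cu_seqlens is just the exclusive-to-inclusive cumsum of dyn_size_ext of T, and total is the size of the packed dim B_T.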
