This issue is open for discussion on what we want in RETURNN.
Packed tensors / packed sequences / ragged tensors / jagged tensors, whatever you want to call them: instead of having a padded tensor [B,T,D], where the T dim has different dynamic seq lens, you have a tensor [B_T,D], where B_T = sum(T.seq_lens), i.e. the padding frames are removed, and the B and T dims are packed together. (The shape of dyn_size_ext of T is [B] here. That is why you can pack B with T.)
In RF, we have pack_padded to pack and also pad_packed to unpack.
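To illustrate the representation itself (not the RF API), a minimal sketch in plain PyTorch of what packing and unpacking mean, using a boolean padding mask:

```python
import torch

B, T, D = 3, 5, 4
seq_lens = torch.tensor([5, 3, 2])  # dyn_size_ext of T, shape [B]
x = torch.randn(B, T, D)            # padded tensor [B,T,D]

# mask[b,t] is True for valid (non-padding) frames
mask = torch.arange(T)[None, :] < seq_lens[:, None]  # [B,T]

# pack: drop the padding frames, B and T are merged into B_T = sum(seq_lens)
packed = x[mask]                    # [B_T,D], here B_T = 10

# unpack: scatter the packed frames back into a padded tensor (padding is zeros here)
unpacked = x.new_zeros(B, T, D)
unpacked[mask] = packed
```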
We have the somewhat implicit assumption in RETURNN (as do most other ML frameworks) that padding is the canonical way to represent a tensor with multiple dims where one dim is dynamic. Maybe it really is (it's definitely the simplest). However, the packed representation is also valid.
The main question here in this issue: Should we have more direct support for packed tensors?
I imagine sth like this:
The packed dim (B_T here) knows about what dims it consists of.
All operations on dims, like reduce, matmul, whatever, would directly support working on the dims that are contained in the packed dims, e.g. you can do y = reduce(x, axis=T) even though x has shape [B_T,D], and you would get y with shape [B,D]. Some ops can provide efficient code to directly handle this, otherwise the fallback would be to unpack when some of those dims are referred to. But it should always work.
Effectively, it means it doesn't really make a difference for your code whether some dims are packed or not, just as the order of the dims is also not relevant. (See the sketch below for how a reduce over a packed dim could be handled directly.)
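For example, a minimal sketch in plain PyTorch (the function name is hypothetical) of how a backend could implement the reduce-sum over T directly on the packed representation, via a segment sum with index_add_, without unpacking:

```python
import torch

def packed_reduce_sum_over_t(x_packed: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Sum over T for a packed tensor: x_packed [B_T,D] with B_T = sum(seq_lens) -> [B,D]."""
    num_seqs = seq_lens.numel()
    # segment id of each packed frame, i.e. which batch entry it belongs to
    seg = torch.repeat_interleave(torch.arange(num_seqs, device=x_packed.device), seq_lens)
    out = x_packed.new_zeros(num_seqs, x_packed.shape[-1])
    out.index_add_(0, seg, x_packed)  # scatter-add each frame into its batch entry
    return out

seq_lens = torch.tensor([5, 3, 2])
x_packed = torch.randn(int(seq_lens.sum()), 4)    # [B_T,D]
y = packed_reduce_sum_over_t(x_packed, seq_lens)  # [B,D] == [3,4]
```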
There are some things that are a bit unclear:
What should x.dims return? Should this give the low-level view on the packed dim? I think yes. But then, how to check whether some dim is maybe inside a packed dim? A separate x.unpacked_dims? In any case, this means that there need to be code changes, also in user code, to really support packed tensors just in the same way as padded tensors.
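Purely to make the design question concrete (everything here is hypothetical, nothing like this exists): the packed dim could record which dims it consists of, and the tensor could expose both views:

```python
# Hypothetical sketch, not existing RETURNN API.
class DimSketch:
    def __init__(self, name, packed_from=()):
        self.name = name
        self.packed_from = packed_from  # e.g. (B, T) for the packed dim B_T

class TensorSketch:
    def __init__(self, dims):
        self.dims = dims  # low-level view, e.g. (B_T, D)

    @property
    def unpacked_dims(self):
        # view where packed dims are expanded into their components, e.g. (B, T, D)
        out = []
        for d in self.dims:
            out.extend(d.packed_from or (d,))
        return tuple(out)

B, T, D = DimSketch("B"), DimSketch("T"), DimSketch("D")
B_T = DimSketch("B_T", packed_from=(B, T))
x = TensorSketch((B_T, D))
assert x.dims == (B_T, D)
assert x.unpacked_dims == (B, T, D)
```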
But also, getting back to the main question: Do we really want this? This adds a lot of complexity, specifically in the backend. Maybe it's also prone to errors? Also, it adds some overhead everywhere. Whenever some dim is not directly found in dims, we need to also check within the packed dims (or cache that somehow, even more complexity...).
Thus, I tend to think we should not do this. But I still keep thinking about whether there are maybe easy solutions here.
In any case, for certain kinds of models, specifically for training with large batches, we need this when we want a diverse set of sequence lengths in a batch and want to avoid the memory and compute overhead of operating on the padded seqs. Note, e.g. for a Transformer, most of the ops don't care about the T axis (e.g. all the FF matmuls), so they are just fine. Only the attention needs special care. But I think there are existing implementations (TorchTune, HuggingFace Transformers), and maybe it's not too difficult to come up with our own using FlexAttention for PyTorch.
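For the attention part, a sketch of the FlexAttention route (assuming PyTorch >= 2.5; in practice flex_attention would be wrapped in torch.compile, and the details are not verified against any RETURNN code): document masking over the packed token axis, so that tokens only attend within their own sequence:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 3, 2], device=device)
total = int(seq_lens.sum())  # B_T, the packed token axis
nheads, headdim = 4, 16

# document id per packed token: which sequence each token belongs to
doc_id = torch.repeat_interleave(torch.arange(seq_lens.numel(), device=device), seq_lens)

def doc_mask(b, h, q_idx, kv_idx):
    # allow attention only within the same sequence (add q_idx >= kv_idx for causal)
    return doc_id[q_idx] == doc_id[kv_idx]

block_mask = create_block_mask(doc_mask, B=None, H=None, Q_LEN=total, KV_LEN=total, device=device)

# q/k/v over the packed axis, with a dummy batch dim of size 1
q, k, v = (torch.randn(1, nheads, total, headdim, device=device) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)  # [1, nheads, B_T, headdim]
```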
So, this issue here is to collect some thoughts on this.
Note, for the TF layers backend, we had some partial support for this, but it was also quite problematic. It was only intended for the batch dim, i.e. a batch dim could actually consist of other dims packed into it, and the BatchInfo structure contained information about what dims were packed into it. See: #466, #467, #920, #975, look for "flatten".
Note, FlashAttention has flash_attn_varlen_qkvpacked_func (API), with args:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
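For reference, a rough usage sketch (assuming the flash-attn package with this function, which requires CUDA and fp16/bf16 inputs), mainly to show how cu_seqlens is built from the per-sequence lengths:

```python
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

seq_lens = torch.tensor([5, 3, 2], dtype=torch.int32, device="cuda")
total = int(seq_lens.sum())  # total number of tokens in the batch (B_T)
nheads, headdim = 4, 64

# cumulative seq lens, shape (batch_size + 1,): here [0, 5, 8, 10]
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seq_lens.max())

qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, causal=True)
# out: (total, nheads, headdim)
```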