Currently, token sampling for the MoD Infini-Former at inference time can produce sequences of different lengths for each observation in the batch. The present workaround is to force the batch size to one and loop over the observations, which is highly inefficient.
There are two main options for handling this efficiently:

1. Pad the sampled sequences to the longest sequence length in the batch, in such a way that the additional tokens contribute nothing to downstream calculations.
2. Wait for PyTorch to implement a ragged tensor type.
I'm likely to pursue the first because there's no telling how long it'll be before the PyTorch devs add ragged tensors.
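For reference, here's a minimal sketch of the padding approach (option 1). The helper below is hypothetical and not part of this repo; it pads a batch of variable-length sampled sequences to the longest length and returns a boolean mask marking the real positions, so downstream code has what it needs to neutralize the padded tokens:

```python
import torch

def pad_sampled_sequences(samples, pad_value=0.0):
    """Pad a list of (seq_len_i, dim) tensors to a common length.

    Hypothetical helper: returns the padded batch plus a boolean mask
    that is True at real (non-padded) positions, so downstream ops can
    ignore the padding regardless of what pad_value is chosen.
    """
    max_len = max(s.size(0) for s in samples)
    dim = samples[0].size(-1)
    batch = samples[0].new_full((len(samples), max_len, dim), pad_value)
    mask = torch.zeros(len(samples), max_len, dtype=torch.bool,
                       device=samples[0].device)
    for i, s in enumerate(samples):
        batch[i, : s.size(0)] = s
        mask[i, : s.size(0)] = True
    return batch, mask
```

(`torch.nn.utils.rnn.pad_sequence` already handles the padding step itself, but it doesn't produce the mask, which is the part that matters for keeping the padded tokens out of downstream calculations.)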
Unfortunately, the fix you introduced assumes that .forward_() produces a valid result when called on the original input. What actually needs to happen during inference is for .forward() to use sample_mask_seg to pad the samples along the token dimension until they all have the same length. The part I haven't gotten around to is working through the math to determine a choice of padding token that doesn't affect downstream calculations.
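For the attention path at least, the padding value can be made irrelevant by masking rather than by a clever choice of token: if padded key positions are set to -inf before the softmax, they get exactly zero attention weight, so whatever value sits in those slots never reaches the output. The sketch below illustrates that idea; it's an assumption on my part that attention is the downstream calculation in question (other paths, like the Infini-Former memory update, would still need to be checked), and `masked_attention` is an illustrative helper, not anything in this repo:

```python
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, key_mask):
    """Scaled dot-product attention that zeroes out padded keys.

    key_mask: (batch, seq_len) bool, True at real positions. Padded
    keys get -inf scores, so softmax assigns them zero weight and the
    padding value never influences the result.
    """
    scores = q @ k.transpose(-2, -1) / k.size(-1) ** 0.5
    scores = scores.masked_fill(~key_mask[:, None, :], float("-inf"))
    return F.softmax(scores, dim=-1) @ v
```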
For the moment, I'm going to revert the change, just to maintain functionality (slow as it is). I really appreciate your putting in time on this, though!