Added Support for Rotary Positional Embeddings (both non-fused and fused kernel) #99
base: main_perf
Conversation
# @pytest.mark.parametrize("rotary_fraction", [0.0]) | ||
# @pytest.mark.parametrize("rotary_interleaved", [True]) | ||
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) | ||
@pytest.mark.parametrize("rotary_fraction", [0.5, 1.0]) |
Have this match the original tests as much as possible.
@@ -1921,7 +1922,7 @@ def test_flash_attn_kvcache(
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    batch_size = 4
ditto
Good catch. Oddly enough `batch_size = 4` causes the tests to segfault randomly; `batch_size = 2` however passes just fine. Might be something to explore?
Rotary_interleaved = rotary_interleaved,
Rotary_conjugate = rotary_conjugate,
IS_SEQLEN_OFFSETS_TENSOR = isinstance(cache_seqlens, torch.Tensor),
IS_VARLEN = False,
Why is `IS_VARLEN` manually set to False?
Because with rotary you have the option to use varlen, but decode only uses batched. We don't have a varlen parameter to pass in from the decode tests.
COS,  # tensor of shape (seqlen (m), ro_dim // 2)
SIN,  # tensor of shape (seqlen (m), ro_dim // 2)
SEQLEN_OFFSET,  # we use this as an offset into COS and SIN to apply the correct rotation
SEQLEN_OFFSET_IS_TENSOR: tl.constexpr,  # if seqlen_offset is a tensor it has shape (num_batch, )
Why do we have two versions of SEQLEN_OFFSET? It seems it can be either an int or a tensor, which we then do a load on.
ro_dim_half = rotary_dim // 2  # length of cos/sin

if SEQLEN_OFFSET_IS_TENSOR:
Why do we have two versions?
# Misc
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
TRANSPOSE: tl.constexpr,
I don't think we should do the transpose inside the rotary function. It probably makes sense for the caller to do that if it makes sense for them.
Motivation
Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding
Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.
RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors, based on each token's respective position $$m$$ in the sequence.
To compute attention, we must first compute $$\text{matmul}(Q, K^T)$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K^T$$. Given two tokens at positions $$i$$ and $$j$$, the closer $$i$$ and $$j$$ are to each other, the more similarly their vector embeddings are rotated, and the dot product between the two token embedding vectors is largely unchanged. However, the further apart the two tokens are, the more the transformations applied to their vector embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should be paid more attention than tokens much further away.
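To make the relative-position property concrete, here is the standard identity from the RoFormer paper, written out for a single 2-dimensional pair (one plane of the embedding, rotating at frequency $$\theta$$):

$$q_m = R(m\theta)\,q, \qquad k_n = R(n\theta)\,k, \qquad R(\alpha) = \begin{pmatrix} \cos\alpha & -\sin\alpha \\ \sin\alpha & \cos\alpha \end{pmatrix}$$

$$q_m^\top k_n = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R\big((n-m)\theta\big)\,k$$

The score contributed by this pair therefore depends only on the relative offset $$n - m$$; the decay with distance described above comes from summing these contributions over all the pairs, each rotating at a different frequency $$\theta_d$$ (see the paper for the formal decay argument).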
A more detailed explanation
Fundamentally, RoPEs work by dividing the embedding space of our q and k vectors (the $$\text{head\_dim}$$) into many chunks of two. Each 2-dimensional chunk can be thought of as a vector subcomponent of q and k projected onto a 2-dimensional plane that exists within the higher-dimensional space of the q and k embedding. RoPE "rotates" the planar chunks of our q and k vectors uniquely based on the index of the token in the sequence. Each "chunk" is rotated by a unique amount $$\theta_{m, d/2}$$ determined by the index $$m$$ of the token in the sequence and the dimension $$d$$ of the subcomponents of q and k being rotated.
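As an illustration of the non-fused path, here is a minimal PyTorch sketch of this chunk-and-rotate procedure. It assumes the non-interleaved layout (the two halves of each head vector form the 2-dimensional pairs), a full rotary fraction, and an integer sequence-length offset; the function and argument names are illustrative, not the PR's actual API.

```python
import torch

def apply_rotary_reference(x, cos, sin, seqlen_offset=0):
    """Minimal non-interleaved RoPE reference (illustrative, not the PR's kernel).

    x:   (batch, seqlen, nheads, head_dim) query or key tensor
    cos: (max_seqlen, rotary_dim // 2) precomputed cos(m * theta_d)
    sin: (max_seqlen, rotary_dim // 2) precomputed sin(m * theta_d)
    seqlen_offset: int offset into cos/sin (e.g. the cache length during decode)
    """
    ro_dim_half = cos.shape[-1]  # = rotary_dim // 2
    seqlen = x.shape[1]
    # Slice the rotation tables for the positions actually present.
    cos = cos[seqlen_offset : seqlen_offset + seqlen][None, :, None, :]
    sin = sin[seqlen_offset : seqlen_offset + seqlen][None, :, None, :]
    # Split each head vector into the two halves that form the 2-D planes.
    x1 = x[..., :ro_dim_half]
    x2 = x[..., ro_dim_half : 2 * ro_dim_half]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    # Dimensions beyond rotary_dim (rotary_fraction < 1.0) pass through unchanged.
    return torch.cat([out1, out2, x[..., 2 * ro_dim_half :]], dim=-1)

if __name__ == "__main__":
    torch.manual_seed(0)
    seqlen, head_dim = 16, 64
    # Standard RoPE frequencies: theta_d = 10000^(-2d / head_dim).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = torch.arange(seqlen, dtype=torch.float32)[:, None] * inv_freq[None, :]
    q = torch.randn(2, seqlen, 4, head_dim)
    q_rot = apply_rotary_reference(q, angles.cos(), angles.sin())
    print(q_rot.shape)  # torch.Size([2, 16, 4, 64])
```

The interleaved variant pairs adjacent even/odd dimensions instead of the two halves, which is what the `rotary_interleaved` flag in the tests above toggles.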