
Added Support for Rotary Positional Embeddings (both non-fused and fused kernel) #99

Open · wants to merge 7 commits into main_perf

Conversation

@alexkranias-amd commented Nov 13, 2024

Motivation

Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding

Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.

RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors based on each token's respective position in the sequence $$m$$.

To compute attention, we must first compute $$QK^T$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K$$. Given two tokens at positions $$i$$ and $$j$$: the closer $$i$$ and $$j$$ are to each other, the more similarly their vector embeddings are rotated, and the dot product between them is largely unchanged. The further apart the two tokens are, the more the rotations applied to their embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should receive more attention than tokens much further away.

(Figure: dot product decay as the relative distance between tokens increases)
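Concretely (this follows the RoFormer paper rather than anything specific to this PR): if the 2-D chunks of $$q$$ at position $$m$$ are rotated by $$m\theta$$ and the matching chunks of $$k$$ at position $$n$$ are rotated by $$n\theta$$, their dot product depends only on the relative offset $$n - m$$:

$$\langle R(m\theta)\,q,\; R(n\theta)\,k \rangle = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R\big((n - m)\theta\big)\,k$$

With the per-pair frequencies $$\theta_i = 10000^{-2i/d}$$ used in the paper, summing this over all pairs produces the decay with relative distance shown above.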

A more detailed explanation

Fundamentally, RoPE works by dividing the embedding space of our q and k vectors (the $$\text{head\_dim}$$) into chunks of two dimensions. Each 2-dimensional chunk can be thought of as a subcomponent of q and k projected onto a 2-dimensional plane within the higher-dimensional embedding space. RoPE "rotates" each of these planar chunks by a unique angle $$\theta_{m, d/2}$$ that depends on the index $$m$$ of the token in the sequence and on which pair of dimensions $$d$$ of q and k is being rotated.

(Figure: RoPE implementation details)
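As a rough illustration of the mechanism described above, here is a minimal non-fused PyTorch sketch (the names, shapes, and non-interleaved layout are assumptions for illustration, not the exact kernels added in this PR):

```python
import torch

def apply_rope_ref(x, cos, sin, seqlen_offset=0):
    """Minimal non-interleaved RoPE reference (illustrative only).

    x:        (batch, seqlen, nheads, head_dim) query or key tensor
    cos, sin: (max_seqlen, rotary_dim // 2) precomputed tables
    """
    _, seqlen, _, _ = x.shape
    rotary_dim = 2 * cos.shape[-1]

    # select the rows of cos/sin for the positions this slice of tokens occupies
    cos = cos[seqlen_offset : seqlen_offset + seqlen][None, :, None, :]
    sin = sin[seqlen_offset : seqlen_offset + seqlen][None, :, None, :]

    # only the first rotary_dim dimensions are rotated; the rest pass through
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)  # the two halves of each 2-D "chunk"

    # rotate each (x1_i, x2_i) pair by its per-position, per-frequency angle
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```

With a rotary fraction of 0.5 (as in the tests touched below), rotary_dim would be half of head_dim and the remaining dimensions would pass through unrotated.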

@alexkranias-amd changed the title from Added Rotary Embedding (non-fused kernel) to Added Rotary Embedding (both non-fused and fused kernel) on Nov 14, 2024
@alexkranias-amd changed the title from Added Rotary Embedding (both non-fused and fused kernel) to Added Support for Rotary Positional Embeddings (both non-fused and fused kernel) on Nov 14, 2024
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("rotary_interleaved", [True])
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
@pytest.mark.parametrize("rotary_fraction", [0.5, 1.0])
@micmelesse (Collaborator) Nov 15, 2024

Have this match the original tests as much as possible.

@@ -1921,7 +1922,7 @@ def test_flash_attn_kvcache(
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size = 4
Collaborator
ditto

Author (alexkranias-amd)
Good catch. Oddly enough, batch_size = 4 causes the tests to segfault randomly; batch_size = 2, however, passes just fine. Might be something to explore?

Rotary_interleaved = rotary_interleaved,
Rotary_conjugate = rotary_conjugate,
IS_SEQLEN_OFFSETS_TENSOR = isinstance(cache_seqlens, torch.Tensor),
IS_VARLEN = False,
Collaborator
Why is IS_VARLEN manually set to False?

@alexkranias-amd (Author) Nov 15, 2024

Because with rotary you have the option to use varlen, but decode only uses batched. We don't have a varlen parameter to pass in from the decode tests.

COS, # tensor of shape (seqlen (m), ro_dim // 2)
SIN, # tensor of shape (seqlen (m), ro_dim // 2)
SEQLEN_OFFSET, # we use this as an offset into COS and SIN to apply the correct rotation
SEQLEN_OFFSET_IS_TENSOR: tl.constexpr, # if seqlen_offset is a tensor it has shape (num_batch, )
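For context, a minimal PyTorch-level sketch (not the Triton kernel itself) of how a seqlen offset, whether a plain int or a per-batch tensor, could be used to pick rows of COS/SIN; the names here are illustrative:

```python
import torch

def gather_cos_sin(cos, sin, seqlen_offset, batch_size, seqlen):
    # cos/sin: (max_seqlen, ro_dim // 2); seqlen_offset: int or (batch_size,) tensor
    if isinstance(seqlen_offset, torch.Tensor):
        # per-batch offsets: each batch element starts its rotation at its own position
        pos = seqlen_offset[:, None] + torch.arange(seqlen)                   # (batch_size, seqlen)
    else:
        pos = (seqlen_offset + torch.arange(seqlen)).expand(batch_size, -1)   # broadcast the int offset
    return cos[pos], sin[pos]  # each (batch_size, seqlen, ro_dim // 2)
```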
Collaborator
Why do we have two versions of SEQLEN_OFFSET? It seems to be either an int or a tensor, which we do a load on.


ro_dim_half = rotary_dim // 2 # length of cos/sin

if SEQLEN_OFFSET_IS_TENSOR:
Collaborator
Why do we have two versions?

# Misc
INTERLEAVED: tl.constexpr,
CONJUGATE: tl.constexpr,
TRANSPOSE: tl.constexpr,
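For context, INTERLEAVED typically selects how dimensions are paired before rotation; a hedged sketch of the two common layouts (the exact convention in this kernel may differ):

```python
import torch

def rope_pairs(x, interleaved):
    # x: (..., rotary_dim) tensor; returns the (x1, x2) halves that are rotated together
    if interleaved:
        # adjacent pairs: (x[0], x[1]), (x[2], x[3]), ...
        return x[..., 0::2], x[..., 1::2]
    # half-split (GPT-NeoX style) pairs: (x[i], x[i + rotary_dim // 2])
    x1, x2 = x.chunk(2, dim=-1)
    return x1, x2
```

CONJUGATE presumably negates the sine term (i.e. applies the inverse rotation), though that is inferred from the flag name rather than from this diff.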
Collaborator
I don't think we should do the transpose inside the rotary function. It probably makes sense for the caller to do that, if it makes sense for them.
