
Telescoping cache precision/throughput issues #1

Open · daviswer opened this issue Jul 10, 2024 · 2 comments

@daviswer

We have implemented Triton kernels for the matmul operations involving a telescoping cache in the telescoping-kernel branch. These kernels pass their respective correctness checks (also included), but deploying them to our training pipeline is not straightforward because Triton does not support atomic-add in bf16 (see here).
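For reference, a minimal sketch of the pattern in question, not the actual kernel from the telescoping-kernel branch: a tiled matmul whose output tile is scatter-added into a shared buffer with `tl.atomic_add`. Names and block sizes here are hypothetical; the point is where the atomic-add dtype constraint bites.

```python
import triton
import triton.language as tl

@triton.jit
def tiled_matmul_atomic_add_kernel(
    a_ptr, b_ptr, out_ptr,                 # out accumulates across programs
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Each program computes one (BLOCK_M, BLOCK_N) tile of A @ B,
    # accumulating internally in fp32, then scatter-adds it into `out`.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak,
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
                    mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # tl.atomic_add accepts fp16/fp32 pointers but not bf16 (the limitation
    # linked above), so the fp32 accumulator is cast to whatever dtype `out`
    # was allocated in -- fp16 in the workaround described next.
    tl.atomic_add(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=out_mask)
```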

We instead cast to fp16 before this op, but loss curves on a test llama3-1.8B model diverge when we do this:

[Figure: loss_curve]

Loss curves do not diverge when running the kernels in fp32, but that sacrifices our speed gains. We're currently evaluating a variant that uses fp32 for the atomic-adds only, and will update here.

Running these matmuls in fp16 also breaks the vanilla PyTorch code, so this is almost certainly a precision issue. If internal fp32 casting does not fix the diverging loss, can the kernel code be massaged to avoid these issues?

daviswer changed the title from "Telescoping cache precision and convergence issues" to "Telescoping cache precision/throughput issues" on Jul 11, 2024
@daviswer (Author)

Update: performing the atomic-adds in fp32 produces the desired behavior, with minimal extra speed/memory overhead. Now it's a question of optimizing throughput: we're currently getting ~3850 tokens/sec/GPU for this particular training setup, compared to ~2550 for the pure PyTorch baseline and ~10600 for Flash Attention.

[Figure: loss_curve]
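Concretely, the fix amounts to the host-side pattern below. This is a hypothetical wrapper around the sketch kernel from the first comment, not the branch's actual API: keep the model in bf16, let the atomics accumulate into a zero-initialized fp32 buffer, and cast back afterwards.

```python
import torch
import triton

def telescoping_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K), b: (K, N), both typically bf16.
    M, K = a.shape
    _, N = b.shape
    # Atomic-adds land in fp32, avoiding the fp16 round-off that caused the
    # divergence above; the extra buffer is the "minimal overhead" mentioned.
    out_fp32 = torch.zeros((M, N), device=a.device, dtype=torch.float32)
    grid = (triton.cdiv(M, 64), triton.cdiv(N, 64))
    tiled_matmul_atomic_add_kernel[grid](
        a, b, out_fp32,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        out_fp32.stride(0), out_fp32.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    )
    return out_fp32.to(a.dtype)  # back to bf16 for the rest of the model
```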

@daviswer (Author)

Update 2: it turns out that the way we implemented the forward pass around the above kernels also makes it amenable to standard attention with a custom mask (visualized below for sequence length 512). So we're now running telescoping cache training, stably and relatively quickly, at 8B scale, using memory-efficient attention through PyTorch SDPA (SDPA's Flash Attention backend apparently still doesn't support custom masks). Further speedups will be possible if we can enable Flash Attention with custom masks in this context.

[Figure: telescoping attention mask, sequence length 512]
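The SDPA route can be pinned to the memory-efficient backend roughly like this (PyTorch ≥ 2.3 API; `telescope_mask` is a hypothetical name for the boolean mask visualized above, with True meaning "may attend"):

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def telescoping_attention(q, k, v, telescope_mask):
    # q, k, v: (batch, heads, seq_len, head_dim); telescope_mask broadcasts
    # to (batch, heads, seq_len, seq_len). The flash backend rejects arbitrary
    # attn_mask tensors, so restrict SDPA to memory-efficient attention.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=telescope_mask)
```

On older PyTorch versions the (since-deprecated) equivalent is the `torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)` context manager.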
