We have implemented a Triton kernel for matmul operations involving a telescoping cache in the telescoping-kernel branch. These kernels pass their respective correctness checks (also included), but deploying to our training pipeline is not straightforward because Triton does not support atomic-add in bf16 (see here).
We instead cast to fp16 before this op, but loss curves on a test llama3-1.8B model diverge when we do this:
Loss curves do not diverge when running the kernels in fp32. Unfortunately this sacrifices our speed gains. We're currently evaluating performing only the atomic-adds in fp32 (keeping the rest of the kernel in bf16), and will update here.
Running these matmuls in fp16 also breaks the vanilla PyTorch code, so this is almost certainly a precision issue. If internal fp32 casting does not fix the diverging loss, can the kernel code be massaged to avoid these issues?
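For reference, a minimal sketch of the fp32 atomic-add workaround under evaluation, using a toy elementwise accumulation for illustration (the kernel and function names here are placeholders, not the actual matmul kernels in the telescoping-kernel branch): load the bf16 inputs, widen to fp32 inside the kernel, atomically add into an fp32 buffer, and cast back to bf16 on the host only after all atomics have landed.

```python
# Sketch only: illustrates the fp32-accumulator pattern, not the real
# telescoping-cache matmul kernels. All names here are hypothetical.
import torch
import triton
import triton.language as tl


@triton.jit
def partial_accumulate_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load bf16 inputs and widen to fp32 before the atomic add,
    # since tl.atomic_add does not support bf16.
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    tl.atomic_add(out_ptr + offsets, x, mask=mask)


def accumulate_fp32(x_bf16: torch.Tensor) -> torch.Tensor:
    # fp32 accumulator; cast back to bf16 only after accumulation is complete.
    out = torch.zeros_like(x_bf16, dtype=torch.float32)
    n = x_bf16.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    partial_accumulate_kernel[grid](x_bf16, out, n, BLOCK_SIZE=1024)
    return out.to(torch.bfloat16)
```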
daviswer changed the title from "Telescoping cache precision and convergence issues" to "Telescoping cache precision/throughput issues" on Jul 11, 2024.
Update: performing the atomic-adds in fp32 produces the desired behavior, with minimal extra speed/memory overhead. Now it's a question of optimizing throughput: we're currently getting ~3850 tokens/sec/gpu for this particular training setup, compared to ~2550 for the pure PyTorch baseline and ~10600 for Flash Attention.
Update 2: it turns out that the way we implemented the forward pass around the above kernels also makes it amenable to standard attention with a custom mask (visualized below for seq len 512). So we're now running telescoping cache training - stably and relatively quickly - at 8B scale, using memory-efficient attention through PyTorch SDPA (since the SDPA flash-attention backend apparently still doesn't support custom masks). Further speedups will be possible if we can enable Flash Attention with custom masks in this context.
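A minimal sketch of the SDPA path with a custom boolean mask (shapes and the mask itself are illustrative only; a simple causal mask stands in for the actual telescoping mask, which is defined by the cache layout):

```python
# Sketch only: standard attention with a custom boolean mask through PyTorch
# SDPA. With an arbitrary mask the flash backend is unavailable, so we restrict
# dispatch to the memory-efficient backend (with math as a fallback).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, S, D = 1, 8, 512, 64  # illustrative shapes
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Placeholder custom mask (causal); the real telescoping mask replaces this.
# True = position may be attended to.
custom_mask = torch.tril(torch.ones(S, S, device="cuda", dtype=torch.bool))

with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=custom_mask)
```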