From 1ef30a907ef3ad7598cf8d0ca43d115be543a384 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Wed, 31 Jul 2024 04:21:46 +0000
Subject: [PATCH] annotate triton constexpr #56

---
 .../fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py | 2 +-
 .../src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
index ebf6f3d0..cade9c85 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
@@ -186,7 +186,7 @@ def _cross_entropy_backward(
 pass
 
 
-MAX_FUSED_SIZE = 65536 # 2**16
+MAX_FUSED_SIZE: tl.constexpr = 65536 # 2**16
 
 class Fast_CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
index 3577b586..c97c8cfc 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
@@ -17,7 +17,7 @@ import torch
 from .utils import calculate_settings
 
-ROPE_GROUP_SIZE = 4
+ROPE_GROUP_SIZE: tl.constexpr = 4
 
 
 @triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
 @triton.jit
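
Context for the change: recent Triton releases only let @triton.jit kernels read module-level globals that are annotated as tl.constexpr; an unannotated global such as MAX_FUSED_SIZE or ROPE_GROUP_SIZE fails at kernel compile time. Below is a minimal sketch of the pattern, assuming a recent Triton and a CUDA device; the kernel _clamp_kernel and its launch parameters are illustrative and do not appear in the patched files.

# Illustrative sketch (kernel name and launch values are hypothetical, not
# from the patched files): a module-level constant annotated with
# tl.constexpr so that @triton.jit code may read it.
import torch
import triton
import triton.language as tl

MAX_FUSED_SIZE: tl.constexpr = 65536  # 2**16, same constant as in the patch

@triton.jit
def _clamp_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # MAX_FUSED_SIZE is folded in at compile time; without the tl.constexpr
    # annotation on the global, recent Triton rejects this access.
    tl.store(out_ptr + offsets, tl.minimum(x, MAX_FUSED_SIZE), mask=mask)

# Launch: one program instance per BLOCK_SIZE-element tile.
x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
_clamp_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)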