[XPU][OptEW] Allow multiple warps in non-sliced dimension
Allow multiple warps in the non-sliced dimension as long as there are
`n*sub_group_size` contiguous elements per warp in that dimension.

Signed-off-by: victor-eds <[email protected]>
victor-eds committed Nov 12, 2024
1 parent 9e41b65 commit a971b85
Showing 2 changed files with 123 additions and 8 deletions.
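
As context for the diffs below, here is a minimal standalone sketch of the acceptance condition described in the commit message: the first warp basis of the sliced linear layout must be a nonzero multiple of the sub-group size, and every later warp basis must double the previous one, so each warp owns one contiguous `n*sub_group_size`-element chunk. The function name, plain-integer bases, and example values are illustrative only; the actual pass works on Triton's `LinearLayout`.

```cpp
// Illustrative sketch only: plain integers stand in for LinearLayout warp
// bases; the function and variable names here are hypothetical.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Accept iff the first warp basis is a nonzero multiple of the sub-group size
// and each following basis doubles the previous one, i.e. every warp owns a
// contiguous run of n * subGroupSize elements.
bool isMultiWarpUnbroadcastable(const std::vector<int32_t> &warpBases,
                                int32_t subGroupSize) {
  if (warpBases.empty())
    return true; // single warp: nothing to check
  if (warpBases[0] == 0 || warpBases[0] % subGroupSize != 0)
    return false;
  int32_t expected = warpBases[0] * 2;
  for (std::size_t pos = 1; pos < warpBases.size(); ++pos) {
    if (warpBases[pos] != expected)
      return false;
    expected *= 2;
  }
  return true;
}

int main() {
  // Two warps, sub-group size 16, warp 1 starting at element 16: accepted
  // (roughly the shape of the first new test below).
  std::cout << isMultiWarpUnbroadcastable({16}, 16) << '\n';     // 1
  // Four warps, each owning a 32-element chunk: accepted.
  std::cout << isMultiWarpUnbroadcastable({32, 64}, 16) << '\n'; // 1
  // Warp stride not a multiple of the sub-group size: rejected.
  std::cout << isMultiWarpUnbroadcastable({8}, 16) << '\n';      // 0
}
```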
88 changes: 88 additions & 0 deletions test/TritonIntelGPU/optimize-elementwise.mlir
@@ -63,3 +63,91 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return %0 : tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_blocked_multi_warp(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
tt.func @test_blocked_multi_warp(%arg0: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<32xf32, #[[$ATTR_1]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<32xf32, #[[$ATTR_1]]>
%0 = arith.addf %arg0, %arg1 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<32xf32, #[[$ATTR_1]]> -> tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
tt.return %0 : tensor<32xf32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_gpu.blocked<{sizePerThread = [32], threadsPerWarp = [16], warpsPerCTA = [4], order = [0]}>

#blocked = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_blocked_multi_warp_double_stride(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> {
tt.func @test_blocked_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>> -> tensor<128xf16, #[[$ATTR_1]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_1]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_1]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_0]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [8], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
tt.func @test_mma_multi_warp_double_stride(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_mma_multi_warp_double_stride_repeat(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> {
tt.func @test_mma_multi_warp_double_stride_repeat(%arg0: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
// CHECK: %[[VAL_2:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<128xf16, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : tensor<128xf16, #[[$ATTR_0]]>
%0 = arith.addf %arg0, %arg1 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_5:.*]] = triton_gpu.convert_layout %[[VAL_4]] : tensor<128xf16, #[[$ATTR_0]]> -> tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}
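
All four new tests share the same rewrite: both operands are converted from the sliced layout into a 1-D blocked layout, the `arith.addf` runs there, and the result is converted back. The snippet below is illustrative only; it re-derives, from the CHECK layouts above, the doubling warp bases implied by each case (one basis per power of two in the warp count, starting at the per-warp contiguous chunk, i.e. the `sizePerThread` of the 1-D layout).

```cpp
// Illustrative only: test names are copied from above; everything else is a
// plain re-derivation from the CHECK layouts, not output of the pass.
#include <cstdio>

int main() {
  struct Case {
    const char *name;
    int chunk;    // contiguous elements per warp (sizePerThread above)
    int numWarps; // warpsPerCTA above
  } cases[] = {
      {"test_blocked_multi_warp", 16, 2},
      {"test_blocked_multi_warp_double_stride", 32, 4},
      {"test_mma_multi_warp_double_stride", 16, 8},
      {"test_mma_multi_warp_double_stride_repeat", 16, 2},
  };
  for (const Case &c : cases) {
    std::printf("%s: warp bases =", c.name);
    // One basis per power of two in numWarps; each doubles the previous one.
    for (int basis = c.chunk, w = c.numWarps; w > 1; w /= 2, basis *= 2)
      std::printf(" %d", basis);
    std::printf("\n");
  }
}
```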
43 changes: 35 additions & 8 deletions
@@ -24,15 +24,36 @@ namespace mlir::triton::gpu::intel {
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"

namespace {
+bool isMultiWarpValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
+                                          int32_t numWorkGroupPos,
+                                          PatternRewriter &rewriter) {
+  StringAttr kLane = rewriter.getStringAttr("lane");
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+  int32_t subGroupSize = linearLayout.getInDimSize(kLane);
+  ArrayRef<int32_t> numContiguousPerWarp = linearLayout.getBasis(kWarp, 0);
+  // Check the warp dimension hasn't been sliced away and we have n *
+  // sub_group_size contiguous elements per warp.
+  if (numContiguousPerWarp == ArrayRef<int32_t>{0} ||
+      numContiguousPerWarp[0] % subGroupSize != 0)
+    return false;
+  int32_t expectedValue = numContiguousPerWarp[0] * 2;
+  for (int32_t pos = 1; pos < numWorkGroupPos; ++pos) {
+    if (linearLayout.getBasis(kWarp, pos) != ArrayRef<int32_t>{expectedValue})
+      return false;
+    expectedValue *= 2;
+  }
+  return true;
+}

/// Return whether the input linear layout can be unbroadcasted.
///
/// A layout is valid for being "unbroadcasted" along its lanes if:
/// - The 'lane' input dimension is zero: this means the lane dimension has been
/// sliced.
/// - The size of the input 'block' dimension is 1. This is true for XPU
/// backend.
-/// - The size of the input 'warp' dimension is 1. This is a limitation to keep
-/// things simple for now.
+/// - The size of the input 'warp' dimension is 1 or there are n*sub_group_size
+/// contiguous elements per warp.
///
/// Broadcasted layouts are layouts with sliced lane, warp or block (not
/// possible for XPU backend) dimensions, i.e., the same data is owned by
@@ -49,8 +70,11 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
  // Only single block for now.
  if (linearLayout.getInDimSize(kBlock) != 1)
    return false;
-  // Only single warp for now.
-  return linearLayout.getInDimSize(kWarp) == 1;
+  // 'warp' dimension hasn't been sliced away and there are n*sub_group_size
+  // contiguous elements in each warp (or there is a single warp).
+  int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
+  return numWorkGroupPos == 0 || isMultiWarpValidLayoutForUnbroadcast(
+                                     linearLayout, numWorkGroupPos, rewriter);
}

/// Get optimized unbroadcasted tensor type.
@@ -61,18 +85,21 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
RankedTensorType getOptimizedType(RankedTensorType type,
                                  const LinearLayout &linearLayout,
                                  PatternRewriter &rewriter) {
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+
  auto encoding = cast<DistributedEncodingTrait>(type.getEncoding());
  unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
-  [[maybe_unused]] unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
-  assert(warpsPerCTA == 1 && "Expecting single warp");
+  unsigned warpsPerCTA = product(encoding.getWarpsPerCTA());
  [[maybe_unused]] unsigned ctaSplitNum = product(encoding.getCTASplitNum());
  assert(ctaSplitNum == 1 && "Expecting single CTA");

  RankedTensorType::Builder builder(type);
+  int32_t numWorkGroupPos = linearLayout.getInDimSizeLog2(kWarp);
+  unsigned sizePerThread =
+      numWorkGroupPos == 0 ? 1 : linearLayout.getBasis(kWarp, 0)[0];
  CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(rewriter.getContext(), 1);
  auto newEncoding = rewriter.getAttr<BlockedEncodingAttr>(
-      /*sizePerThread=*/1, threadsPerWarp, /*warpsPerCTA=*/1, /*order=*/0,
-      ctaLayout);
+      sizePerThread, threadsPerWarp, warpsPerCTA, /*order=*/0, ctaLayout);
  builder.setEncoding(newEncoding);
  return builder;
}
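
For reference, a standalone sketch of how `getOptimizedType` now picks the 1-D blocked parameters: `sizePerThread` becomes the first warp basis (the per-warp contiguous chunk), falling back to 1 in the single-warp case, while `threadsPerWarp` and `warpsPerCTA` are carried over from the source encoding with `order = [0]`. The struct and function below are hypothetical stand-ins, not Triton's API.

```cpp
// Hypothetical stand-in for the encoding construction above; it mirrors the
// parameter choices only, not the MLIR attribute machinery.
#include <cstdint>
#include <cstdio>

struct Blocked1D {
  unsigned sizePerThread, threadsPerWarp, warpsPerCTA; // order = [0]
};

Blocked1D deriveOptimizedEncoding(int32_t firstWarpBasis,
                                  unsigned threadsPerWarp,
                                  unsigned warpsPerCTA) {
  // Single warp (no warp bases): keep one element per thread, as before.
  unsigned sizePerThread =
      warpsPerCTA == 1 ? 1u : static_cast<unsigned>(firstWarpBasis);
  return {sizePerThread, threadsPerWarp, warpsPerCTA};
}

int main() {
  // First new test: warp basis 16, sub-group size 16, two warps ->
  // #blocked<{sizePerThread = [16], threadsPerWarp = [16], warpsPerCTA = [2]}>.
  Blocked1D enc = deriveOptimizedEncoding(16, 16, 2);
  std::printf("sizePerThread=%u threadsPerWarp=%u warpsPerCTA=%u\n",
              enc.sizePerThread, enc.threadsPerWarp, enc.warpsPerCTA);
}
```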
