From 2a4ce1843bbf1f179ae7a2781c50b9199ec2530f Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Mon, 11 Nov 2024 15:06:20 +0000
Subject: [PATCH] [XPU][OptEW] Allow use-def graphs of elementwise optimizable
 operations

Allow operands to be used by other optimizable operations to enable
elementwise operation graph optimizations.

Signed-off-by: victor-eds
---
 test/TritonIntelGPU/optimize-elementwise.mlir | 31 ++++++++
 .../OptimizeElementwiseParallelism.cpp        | 75 ++++++++++++++++---
 2 files changed, 94 insertions(+), 12 deletions(-)

diff --git a/test/TritonIntelGPU/optimize-elementwise.mlir b/test/TritonIntelGPU/optimize-elementwise.mlir
index f01863e3c..069d37067 100644
--- a/test/TritonIntelGPU/optimize-elementwise.mlir
+++ b/test/TritonIntelGPU/optimize-elementwise.mlir
@@ -151,3 +151,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
     tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
   }
 }
+
+// -----
+
+// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
+// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+// CHECK-LABEL: tt.func @test_multi_user(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_2:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>)
+  tt.func @test_multi_user(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg2: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
+// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_4:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : tensor<16xf32, #[[$ATTR_0]]>
+    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_6:.*]] = triton_gpu.convert_layout %[[VAL_5]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_7:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_8:.*]] = triton_gpu.convert_layout %[[VAL_2]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : tensor<16xf32, #[[$ATTR_0]]>
+    %1 = arith.addf %arg0, %arg2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_9]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: %[[VAL_11:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_12:.*]] = triton_gpu.convert_layout %[[VAL_10]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
+// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] : tensor<16xf32, #[[$ATTR_0]]>
+    %2 = arith.addf %0, %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+// CHECK: %[[VAL_14:.*]] = triton_gpu.convert_layout %[[VAL_13]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+// CHECK: tt.return %[[VAL_14]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
+    tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
index 5172a58f2..9f5e8cf57 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeElementwiseParallelism.cpp
@@ -77,6 +77,47 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
                                    linearLayout, numWorkGroupPos, rewriter);
 }
 
+/// Generic checks for the operation not looking at the tensor type.
+bool isCandidateOp(Operation *op) {
+  // Rely on this for a simpler pass.
+  if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
+      op->getNumResults() != 1)
+    return false;
+
+  // Skip complex operations.
+  if (op->hasSuccessors() || op->getNumRegions() != 0)
+    return false;
+
+  return true;
+}
+
+bool optimizationDoesNotWorsenRegisterPressure(
+    Value value, RankedTensorType newType, SmallPtrSetImpl<Value> &visited) {
+  if (!visited.insert(value).second)
+    return true;
+  // All users must be operations we will optimize too or layout conversions we
+  // will introduce later.
+  return llvm::all_of(value.getUses(), [&visited, newType](OpOperand &operand) {
+    Operation *owner = operand.getOwner();
+
+    // We will be introducing just this operation later.
+    if (auto convertLayout = dyn_cast<ConvertLayoutOp>(owner))
+      return convertLayout.getResult().getType() == newType;
+
+    // Only allow candidates. Check only operation constraints. We do not have
+    // to check the type as we did already.
+    if (!owner->hasTrait<OpTrait::Elementwise>() || !isCandidateOp(owner))
+      return false;
+
+    // Check other operands fit the constraints.
+    return llvm::all_of(owner->getOperands(),
+                        [&visited, newType](Value operand) {
+                          return optimizationDoesNotWorsenRegisterPressure(
+                              operand, newType, visited);
+                        });
+  });
+}
+
 /// Get optimized unbroadcasted tensor type.
 ///
 /// Get optimized ranked tensor type after unbroadcasting. As we only support 1D
@@ -110,13 +151,10 @@ struct ElementwiseOptPattern final
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final {
-    // Rely on this for a simpler pass.
-    if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
-        op->getNumResults() != 1)
-      return failure();
+    LLVM_DEBUG(llvm::dbgs() << "Checking operation:\n" << *op << "\n");
 
-    // Skip complex operations.
-    if (op->hasSuccessors() || op->getNumRegions() != 0)
+    // Rely on this for a simpler pass.
+    if (!isCandidateOp(op))
       return failure();
 
     // Layout optimizations only apply to tensors.
@@ -132,19 +170,30 @@ struct ElementwiseOptPattern final
       return failure();
 
     std::optional<LinearLayout> linearLayout = toLinearLayout(type.getShape(), layout);
-    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
-      return failure();
 
-    // Check the operands are not used by other operations. This will prevent
-    // register pressure increase:
-    if (!llvm::all_of(op->getOperands(),
-                      [](Value val) { return val.hasOneUse(); }))
+    LLVM_DEBUG(llvm::dbgs() << "Checking linear layout:\n"
+                            << linearLayout << "\n");
+
+    if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
       return failure();
 
     // As we are dealing with 1D tensors, we can do a simple transform to obtain
     // a more optimized operation.
     Location loc = op->getLoc();
     RankedTensorType newType = getOptimizedType(type, *linearLayout, rewriter);
+
+    LLVM_DEBUG(llvm::dbgs() << "Would convert to type:\n" << newType << "\n");
+
+    // Check the operands are not used by other operations. This will prevent
+    // register pressure increase:
+    if (SmallPtrSet<Value, 8> visited;
+        !llvm::all_of(op->getOperands(), [&visited, newType](Value operand) {
+          return optimizationDoesNotWorsenRegisterPressure(operand, newType,
+                                                           visited);
+        }))
+      return failure();
+
+    // Obtain converted operands.
     SmallVector<Value> newOperands(op->getNumOperands());
     llvm::transform(op->getOperands(), std::begin(newOperands),
                     [&rewriter, loc, newType](Value operand) {
@@ -164,6 +213,8 @@ struct ElementwiseOptPattern final
     Value newValue = newElementwiseOp->getResult(0);
     rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, type, newValue);
 
+    LLVM_DEBUG(llvm::dbgs() << "Conversion took place.\n");
+
     return success();
   }
 };
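For illustration, a reduced sketch of the use-def graph this change admits, distilled from the test_multi_user test above (the #slice placeholder stands in for the actual #triton_gpu.slice/DPAS layout annotations, elided here for brevity):

    %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #slice>
    %1 = arith.addf %arg0, %arg2 : tensor<16xf32, #slice>
    %2 = arith.addf %0, %1 : tensor<16xf32, #slice>

Previously the pattern bailed out because %arg0 (used by both %0 and %1) and the intermediate results %0 and %1 (both used by %2) fail the hasOneUse() check. With optimizationDoesNotWorsenRegisterPressure(), the pass instead walks the use-def graph and accepts a use when it is either another optimizable elementwise operation or a layout conversion back to the unbroadcast type, so all three arith.addf operations can be rewritten in the optimized layout with convert_layout ops only at the graph boundaries.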