Skip to content

Commit

Permalink
[XPU][OptEW] Allow use-def graphs of elementwise optimizable operations
Browse files Browse the repository at this point in the history
Allow operands that are also used by other optimizable operations, enabling
optimization of whole use-def graphs of elementwise operations.

Signed-off-by: victor-eds <[email protected]>
  • Loading branch information
victor-eds committed Nov 13, 2024
1 parent a971b85 commit 2a4ce18
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 12 deletions.
31 changes: 31 additions & 0 deletions test/TritonIntelGPU/optimize-elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 :
tt.return %0 : tensor<128xf16, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}

// -----

// CHECK: #[[$ATTR_0:.+]] = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}>
// CHECK: #[[$ATTR_1:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>

// Values with more than one user (%arg0 feeds both adds; %0 and %1 feed the
// final add) must still be optimized when every user is itself an optimizable
// elementwise op or a layout conversion to the optimized layout. The CHECK
// lines expect each addf to be rewritten in the blocked (#[[$ATTR_0]]) layout
// with convert_layout ops inserted around it.
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK-LABEL: tt.func @test_multi_user(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_1:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>, %[[VAL_2:.*]]: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>)
tt.func @test_multi_user(%arg0: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg1: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, %arg2: tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> {
// CHECK: %[[VAL_3:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_4:.*]] = triton_gpu.convert_layout %[[VAL_1]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : tensor<16xf32, #[[$ATTR_0]]>
  %0 = arith.addf %arg0, %arg1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_6:.*]] = triton_gpu.convert_layout %[[VAL_5]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: %[[VAL_7:.*]] = triton_gpu.convert_layout %[[VAL_0]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_8:.*]] = triton_gpu.convert_layout %[[VAL_2]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : tensor<16xf32, #[[$ATTR_0]]>
  %1 = arith.addf %arg0, %arg2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_10:.*]] = triton_gpu.convert_layout %[[VAL_9]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: %[[VAL_11:.*]] = triton_gpu.convert_layout %[[VAL_6]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_12:.*]] = triton_gpu.convert_layout %[[VAL_10]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>> -> tensor<16xf32, #[[$ATTR_0]]>
// CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_11]], %[[VAL_12]] : tensor<16xf32, #[[$ATTR_0]]>
  %2 = arith.addf %0, %1 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
// CHECK: %[[VAL_14:.*]] = triton_gpu.convert_layout %[[VAL_13]] : tensor<16xf32, #[[$ATTR_0]]> -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
// CHECK: tt.return %[[VAL_14]] : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #[[$ATTR_1]]}>>
  tt.return %2 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,47 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
linearLayout, numWorkGroupPos, rewriter);
}

/// Generic checks for the operation not looking at the tensor type.
///
/// Returns true when \p op has the simple shape this pass relies on: a single
/// result with the same type as all operands, no successors and no regions.
bool isCandidateOp(Operation *op) {
  // Rely on this trait for a simpler pass: result and operand types match, so
  // no per-operand type checking is needed.
  const bool simpleSignature =
      op->hasTrait<OpTrait::SameOperandsAndResultType>() &&
      op->getNumResults() == 1;
  // Skip complex operations: anything carrying control flow (successors) or
  // nested regions.
  const bool simpleStructure =
      !op->hasSuccessors() && op->getNumRegions() == 0;
  return simpleSignature && simpleStructure;
}

/// Check that rewriting \p value to layout-converted type \p newType cannot
/// increase register pressure.
///
/// Recursively walks the use-def graph: every user of \p value must be either
/// a layout conversion producing exactly \p newType (which we would introduce
/// anyway) or another elementwise candidate operation whose remaining operands
/// satisfy the same property. \p visited guards against revisiting values and
/// thus against infinite recursion on cyclic use-def chains.
bool optimizationDoesNotWorsenRegisterPressure(
    Value value, RankedTensorType newType, SmallPtrSetImpl<Value> &visited) {
  // Already checked (or currently being checked) via another path.
  if (!visited.insert(value).second)
    return true;

  auto useIsAcceptable = [&visited, newType](OpOperand &use) {
    Operation *user = use.getOwner();

    // We will be introducing just this conversion later anyway.
    if (auto convert = dyn_cast<ConvertLayoutOp>(user))
      return convert.getResult().getType() == newType;

    // Only allow elementwise candidates. Operation constraints only; the
    // tensor type was already checked.
    if (!user->hasTrait<OpTrait::Elementwise>() || !isCandidateOp(user))
      return false;

    // Every operand of the user must fit the constraints as well.
    for (Value sibling : user->getOperands())
      if (!optimizationDoesNotWorsenRegisterPressure(sibling, newType,
                                                     visited))
        return false;
    return true;
  };

  // All users must be operations we will optimize too or layout conversions
  // we will introduce later.
  return llvm::all_of(value.getUses(), useIsAcceptable);
}

/// Get optimized unbroadcasted tensor type.
///
/// Get optimized ranked tensor type after unbroadcasting. As we only support 1D
Expand Down Expand Up @@ -110,13 +151,10 @@ struct ElementwiseOptPattern final

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final {
// Rely on this for a simpler pass.
if (!op->hasTrait<OpTrait::SameOperandsAndResultType>() ||
op->getNumResults() != 1)
return failure();
LLVM_DEBUG(llvm::dbgs() << "Checking operation:\n" << *op << "\n");

// Skip complex operations.
if (op->hasSuccessors() || op->getNumRegions() != 0)
// Rely on this for a simpler pass.
if (!isCandidateOp(op))
return failure();

// Layout optimizations only apply to tensors.
Expand All @@ -132,19 +170,30 @@ struct ElementwiseOptPattern final
return failure();
std::optional<LinearLayout> linearLayout =
toLinearLayout(type.getShape(), layout);
if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
return failure();

// Check the operands are not used by other operations. This will prevent
// register pressure increase:
if (!llvm::all_of(op->getOperands(),
[](Value val) { return val.hasOneUse(); }))
LLVM_DEBUG(llvm::dbgs() << "Checking linear layout:\n"
<< linearLayout << "\n");

if (!linearLayout || !isValidLayoutForUnbroadcast(*linearLayout, rewriter))
return failure();

// As we are dealing with 1D tensors, we can do a simple transform to obtain
// a more optimized operation.
Location loc = op->getLoc();
RankedTensorType newType = getOptimizedType(type, *linearLayout, rewriter);

LLVM_DEBUG(llvm::dbgs() << "Would convert to type:\n" << newType << "\n");

// Check the operands are not used by other operations. This will prevent
// register pressure increase:
if (SmallPtrSet<Value, 2> visited;
!llvm::all_of(op->getOperands(), [&visited, newType](Value operand) {
return optimizationDoesNotWorsenRegisterPressure(operand, newType,
visited);
}))
return failure();

// Obtain converted operands.
SmallVector<Value> newOperands(op->getNumOperands());
llvm::transform(op->getOperands(), std::begin(newOperands),
[&rewriter, loc, newType](Value operand) {
Expand All @@ -164,6 +213,8 @@ struct ElementwiseOptPattern final
Value newValue = newElementwiseOp->getResult(0);
rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, type, newValue);

LLVM_DEBUG(llvm::dbgs() << "Conversion took place.\n");

return success();
}
};
Expand Down

0 comments on commit 2a4ce18

Please sign in to comment.