From 75eee6ffc41dcd8dce2e5d7c0f7c2fe380d3943d Mon Sep 17 00:00:00 2001
From: HyoungWook Nam
Date: Wed, 7 Aug 2024 09:12:18 -0500
Subject: [PATCH] DPAS operand A and operand B conversion to LinearLayout (#1746)

This PR adds the DPAS -> LinearLayout conversion for the operand A and B
layouts. Currently, Triton does not use the LinearLayout conversion for
DotOperand layouts (A and B); I have included operand A/B support in the
DPAStoLinearLayout function for potential future use. It is tested with the
added unit tests.
---
 .../IR/LinearLayoutConversions.h            |   7 +-
 .../IR/LinearLayoutConversions.cpp          | 206 +++++++++++++++---
 .../TritonGPU/DPAStoLinearLayoutTest.cpp    |  73 +++++++
 3 files changed, 254 insertions(+), 32 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
index ba497942ae..8153393a1b 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h
@@ -11,8 +11,13 @@
 
 namespace mlir::triton::gpu {
 
+// DPAS operand A: opIdx=0
+// DPAS operand B: opIdx=1
+// DPAS operand C (default): opIdx=2
+// Operand A and B conversions are not used yet.
 std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
-                                               Attribute layout);
+                                               Attribute layout,
+                                               unsigned opIdx = 2);
 
 } // namespace mlir::triton::gpu
 
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index 2758ed1cfe..763efe291a 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -336,7 +336,107 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
 
 // The layout example repeat_count=8, systolic_depth=8,
 // execution_size=16 and operands_per_chan=2 for warp size 32.
-// DPASInst layout of C operand:
+// For A operand:
+//                   systolic depth = 8
+//<----------------------------------------------------->
+// opsPerChan=2
+//<--------->
+// t0  ... t0   t1  ... t1   ~  t6  ... t6   t7  ... t7   ^
+// t8  ... t8   t9  ... t9   ~  t14 ... t14  t15 ... t15  |
+// t16 ... t16  t17 ... t17  ~  t22 ... t22  t23 ... t23  |
+// t24 ... t24  t25 ... t25  ~  t30 ... t30  t31 ... t31  | repeat count <= 8
+// t0  ... t0   t1  ... t1   ~  t6  ... t6   t7  ... t7   |
+// t8  ... t8   t9  ... t9   ~  t14 ... t14  t15 ... t15  |
+// t16 ... t16  t17 ... t17  ~  t22 ... t22  t23 ... t23  |
+// t24 ... t24  t25 ... t25  ~  t30 ... t30  t31 ... t31  v
+// In this case, the LinearLayout bases are:
+// Register: {{0,1}, {4,0}}
+// Lane:     {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
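// ---- Illustrative aside (not part of the patch) ---------------------------
// A minimal, self-contained sketch that mirrors the loops of DPASRegBasesA
// and DPASLaneBasesA introduced below, for the configuration in the comment
// above (ops_per_chan=2, repeat_count=8, systolic_depth=8, warp size 32).
// Running it should print exactly the bases quoted above.
#include <cstdio>
#include <vector>

int main() {
  const int opsPerChannel = 2, repeatCount = 8, threadsPerWarp = 32,
            systolicDepth = 8;
  const int rowPerWarp = threadsPerWarp / systolicDepth; // 4 rows per warp
  const int warpRepeats = repeatCount / rowPerWarp;      // 2 row repeats

  std::vector<std::vector<int>> regBases, laneBases;
  // Register bases: first step along dim1 inside a channel, then the row
  // repeats along dim0.
  for (int opc = 1; opc < opsPerChannel; opc *= 2)
    regBases.push_back({0, opc});
  for (int warp = 1; warp < warpRepeats; warp *= 2)
    regBases.push_back({warp * rowPerWarp, 0});
  // Lane bases: the low lane bits move along dim1 in steps of opsPerChannel;
  // the remaining lane bits move along dim0.
  for (int tid = 1; tid < systolicDepth; tid *= 2)
    laneBases.push_back({0, opsPerChannel * tid});
  for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2)
    laneBases.push_back({tid / systolicDepth, 0});

  auto dump = [](const char *name, const std::vector<std::vector<int>> &b) {
    std::printf("%s:", name);
    for (const auto &v : b)
      std::printf(" {%d,%d}", v[0], v[1]);
    std::printf("\n");
  };
  dump("Register", regBases); // Register: {0,1} {4,0}
  dump("Lane", laneBases);    // Lane: {0,2} {0,4} {0,8} {1,0} {2,0}
  return 0;
}
// ---- end aside -------------------------------------------------------------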
+std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
+                                                int repeatCount,
+                                                int threadsPerWarp,
+                                                int systolicDepth) {
+  int rowPerWarp = threadsPerWarp / systolicDepth;
+  int warpRepeats = repeatCount / rowPerWarp;
+  std::vector<std::vector<int32_t>> regBases;
+
+  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
+    regBases.push_back({0, opc});
+  }
+
+  for (int warp = 1; warp < warpRepeats; warp *= 2) {
+    regBases.push_back({warp * rowPerWarp, 0});
+  }
+
+  return regBases;
+}
+
+std::vector<std::vector<int32_t>>
+DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
+  std::vector<std::vector<int32_t>> laneBases;
+
+  for (int tid = 1; tid < systolicDepth; tid *= 2) {
+    laneBases.push_back({0, opsPerChannel * tid});
+  }
+  for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2) {
+    laneBases.push_back({tid / systolicDepth, 0});
+  }
+
+  return laneBases;
+}
+
+// For B operand:
+//               execution size = 16
+//<-------------------------------------------------->
+// t0  t1  t2  t3  ~  t12 t13 t14 t15   ^              ^
+// .   .   .   .   .  .   .   .   .     | opsPerChan=2 |
+// t0  t1  t2  t3  ~  t12 t13 t14 t15   v              |
+// t16 t17 t18 t19 ~  t28 t29 t30 t31                  |
+// .   .   .   .   .  .   .   .   .                    |
+// t16 t17 t18 t19 ~  t28 t29 t30 t31                  | systolic depth = 8
+// t0  t1  t2  t3  ~  t12 t13 t14 t15                  |
+// .   .   .   .   .  .   .   .   .                    |
+// t0  t1  t2  t3  ~  t12 t13 t14 t15                  |
+// t16 t17 t18 t19 ~  t28 t29 t30 t31                  |
+// .   .   .   .   .  .   .   .   .                    |
+// t16 t17 t18 t19 ~  t28 t29 t30 t31                  v
+// In this case, the LinearLayout bases are:
+// Register: {{1,0}, {4,0}, {8,0}}
+// Lane:     {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}
+std::vector<std::vector<int32_t>> DPASRegBasesB(int opsPerChannel,
+                                                int executionSize,
+                                                int threadsPerWarp,
+                                                int systolicDepth) {
+  int rowsPerWarp = threadsPerWarp / executionSize;
+  int warpRepeats = systolicDepth / rowsPerWarp;
+  std::vector<std::vector<int32_t>> regBases;
+
+  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
+    regBases.push_back({opc, 0});
+  }
+  for (int rid = rowsPerWarp; rid < systolicDepth; rid *= 2) {
+    regBases.push_back({rid * opsPerChannel, 0});
+  }
+
+  return regBases;
+}
+
+std::vector<std::vector<int32_t>>
+DPASLaneBasesB(int opsPerChannel, int threadsPerWarp, int executionSize) {
+  std::vector<std::vector<int32_t>> laneBases;
+
+  for (int tid = 1; tid < executionSize; tid *= 2) {
+    laneBases.push_back({0, tid});
+  }
+  int rowsPerWarp = threadsPerWarp / executionSize;
+  for (int row = 1; row < rowsPerWarp; row *= 2) {
+    laneBases.push_back({row * opsPerChannel, 0});
+  }
+
+  return laneBases;
+}
+
+// For C operand:
 //           execution size = 16
 //<---------------------------------->
 // t0  t1  t2  t3  ~  t12 t13 t14 t15   ^
 // t16 t17 t18 t19 ~  t28 t29 t30 t31   |
@@ -348,15 +448,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
 // In this case, the LinearLayout bases are:
 // Register: {{2,0}, {4,0}}
 // Lane:     {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
-// Currently, LinearLayout is not supported for DotOperandEncoding
-// so only Operand C conversion is implemented.
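// ---- Illustrative aside (not part of the patch) ---------------------------
// How bases turn into tensor coordinates, assuming Triton's LinearLayout
// convention that an output coordinate is the XOR of the bases selected by
// the set bits of each input index (for the disjoint bases used here, XOR
// coincides with plain addition). Evaluating the operand C bases quoted
// above for lane 13, register 2 places that element at (row 4, column 13).
#include <cstdio>
#include <utility>
#include <vector>

// Combine the bases selected by the set bits of a hardware index.
static std::pair<int, int> apply(const std::vector<std::pair<int, int>> &bases,
                                 int id) {
  int d0 = 0, d1 = 0;
  for (std::size_t b = 0; b < bases.size(); ++b)
    if (id & (1 << b)) {
      d0 ^= bases[b].first;
      d1 ^= bases[b].second;
    }
  return {d0, d1};
}

int main() {
  // Operand C bases from the comment above.
  const std::vector<std::pair<int, int>> regC = {{2, 0}, {4, 0}};
  const std::vector<std::pair<int, int>> laneC = {{0, 1}, {0, 2}, {0, 4},
                                                  {0, 8}, {1, 0}};
  // Lane 13 (bits 0, 2, 3) contributes column 1+4+8 = 13; register 2 selects
  // the {4,0} basis, i.e. a hop of four rows.
  std::pair<int, int> lanePart = apply(laneC, 13);
  std::pair<int, int> regPart = apply(regC, 2);
  std::printf("lane 13, register 2 -> (%d, %d)\n",
              lanePart.first ^ regPart.first,
              lanePart.second ^ regPart.second); // prints (4, 13)
  return 0;
}
// ---- end aside -------------------------------------------------------------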
 
 std::vector<std::vector<int32_t>>
 DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
   int rowsPerWarp = threadsPerWarp / executionSize;
   std::vector<std::vector<int32_t>> regBases;
 
-  for (int rid = rowsPerWarp; rid < repeatCount; rid = rid * 2) {
+  for (int rid = rowsPerWarp; rid < repeatCount; rid *= 2) {
     regBases.push_back({rid, 0});
   }
 
@@ -365,25 +463,24 @@ DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
 
 std::vector<std::vector<int32_t>>
 DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
-
   std::vector<std::vector<int32_t>> laneBases;
-  for (int tid = 1; tid < executionSize; tid = tid * 2) {
+  for (int tid = 1; tid < executionSize; tid *= 2) {
     laneBases.push_back({0, tid});
   }
 
   int rowsPerWarp = threadsPerWarp / executionSize;
-  for (int row = 1; row < rowsPerWarp; row = row * 2) {
+  for (int row = 1; row < rowsPerWarp; row *= 2) {
     laneBases.push_back({row, 0});
   }
 
   return laneBases;
 }
 
-std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
-                                               Attribute layout) {
-
+std::optional<LinearLayout>
+DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
+  assert(opIdx < 3 && opIdx >= 0);
   auto dpas = dyn_cast<DpasEncodingAttr>(layout);
-  assert(dpas && "Must be DPAS Operand C layout");
+  assert(dpas && "Must be DPAS layout");
 
   int rank = shape.size();
   assert(rank == dpas.getWarpsPerCTA().size());
@@ -397,33 +494,80 @@ std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
   const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
   int threadsPerWarp = triton::gpu::getWarpSize(dpas);
+  unsigned opsPerChannel = dpas.getOpsPerChannel();
   auto repCluster = dpas.getRepCluster();
-  SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, 2);
+  SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
   auto tileLayout = LinearLayout::empty();
 
+  int systolicDepth = dpas.getSystolicDepth();
   int repeatCount = dpas.getRepeatCount();
   int executionSize = dpas.getExecutionSize();
+  unsigned KDim = 0;
+  unsigned nonKDim = 0;
+  if (opIdx == 0) { // Operand A
+    auto regBasesA = DPASRegBasesA(opsPerChannel, repeatCount, threadsPerWarp,
+                                   systolicDepth);
+    auto laneBasesA =
+        DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
+    tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
+                              outDimNames);
+    // A only repeats by repCluster[0].
+    tileLayout *=
+        LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
+    nonKDim = 0;
+    KDim = 1;
+  } else if (opIdx == 1) { // Operand B
+    auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
+                                   systolicDepth);
+    auto laneBasesB =
+        DPASLaneBasesB(opsPerChannel, threadsPerWarp, executionSize);
+    tileLayout = LinearLayout({{kRegister, regBasesB}, {kLane, laneBasesB}},
+                              outDimNames);
+    // B only repeats by repCluster[1].
+    tileLayout *=
+        LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
+    nonKDim = 1;
+    KDim = 0;
+  } else { // opIdx=2 -> Operand C
+    auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
+    auto laneBasesC =
+        DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
+    tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
+                              outDimNames);
+    // The per-inst layout is repeated at each repCluster.
+    // Hence, multiply with the identity layouts starting from the
+    // least significant dimension.
+    tileLayout *=
+        LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
+    tileLayout *=
+        LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
+    nonKDim = 0;
+    KDim = 1;
+  }
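// ---- Illustrative aside (not part of the patch) ---------------------------
// A quick check of the register counts implied by the branches above, for the
// configuration used in the unit tests below (repeat_count=8, systolic_depth=8,
// execution_size=16, ops_per_chan=2, warp size 32, repCluster = {4, 2}).
// Every register basis doubles the register count, so operands A and B should
// end up with 16 registers per thread, i.e. the four register bases listed in
// the DPAS_withRepCluster expectations.
#include <cstdio>

int main() {
  const int repeatCount = 8, systolicDepth = 8, executionSize = 16,
            opsPerChannel = 2, threadsPerWarp = 32;
  const int repCluster[2] = {4, 2};

  // Elements owned by one thread for a single DPAS instruction.
  const int perInstA =
      repeatCount * systolicDepth * opsPerChannel / threadsPerWarp;    // 4
  const int perInstB =
      systolicDepth * opsPerChannel * executionSize / threadsPerWarp;  // 8
  const int perInstC = repeatCount * executionSize / threadsPerWarp;   // 4

  // A repeats by repCluster[0], B by repCluster[1], C by both.
  std::printf("A: %d registers per thread\n", perInstA * repCluster[0]); // 16
  std::printf("B: %d registers per thread\n", perInstB * repCluster[1]); // 16
  std::printf("C: %d registers per thread\n",
              perInstC * repCluster[0] * repCluster[1]);                 // 32
  return 0;
}
// ---- end aside -------------------------------------------------------------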
 
-  auto regBases = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
-  auto laneBases = DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
-  tileLayout =
-      LinearLayout({{kRegister, regBases}, {kLane, laneBases}}, outDimNames);
-
-  // The per-inst layout is repeated at each repCluster.
-  // Hence, multiply with the identity layouts starting from the
-  // least significant dimension.
-  tileLayout *=
-      LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
+  // Operand A/B repeats through the K dimension first, then through the
+  // non-K dimension.
   tileLayout *=
-      LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
-
-  // Then, it is repeated by DPASRepetitions to form per-Warp layout.
-  tileLayout *= LinearLayout::identity1D(numReps[1], kRegister, outDimNames[1]);
-  tileLayout *= LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
-
-  // Finally, per-warp layout is repeated among the warps in the CTA.
-  LinearLayout warpLayout =
-      identityND(S("warp"), dpas.getWarpsPerCTA(), {0, 1}, outDimNames);
+      LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
+  tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
+                                         outDimNames[nonKDim]);
+
+  // For Operand C, the warps split the tensor identically along both
+  // dimensions. For Operand A and B, the warps along the K dimension share
+  // the same data, so their warp bases for the K dimension are zero.
+  LinearLayout warpLayout = LinearLayout::empty();
+  StringAttr kWarp = S("warp");
+  if (opIdx == 0) {
+    warpLayout =
+        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+    warpLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
+  } else if (opIdx == 1) {
+    warpLayout = LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+    warpLayout *=
+        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
+  } else { /* opIdx == 2 */
+    warpLayout = identityND(kWarp, warpsPerCTA, {0, 1}, outDimNames);
+  }
 
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout,
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
diff --git a/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp b/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
index 19c2c31ccd..caf374e3f9 100644
--- a/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
+++ b/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
@@ -39,6 +39,7 @@ class DPAStoLinearLayoutTest : public ::testing::Test {
 };
 
 TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
+  // Default: Operand C
   EXPECT_EQ(DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32)),
             LinearLayout(
                 {
@@ -57,6 +58,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
                     {S("block"), {}},
                 },
                 {S("dim0"), S("dim1")}));
+  // Test Operand A (opIdx=0)
+  EXPECT_EQ(
+      DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 0),
+      LinearLayout(
+          {
+              {S("register"), {{0, 1}, {4, 0}}},
+              {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
+              {S("warp"), {}},
+              {S("block"), {}},
+          },
+          {S("dim0"), S("dim1")}));
+  // Test Operand B (opIdx=1)
+  EXPECT_EQ(
+      DPAStoLinearLayout({16, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 1),
+      LinearLayout(
+          {
+              {S("register"), {{1, 0}, {4, 0}, {8, 0}}},
+              {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
+              {S("warp"), {}},
+              {S("block"), {}},
+          },
+          {S("dim0"), 
S("dim1")})); } TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) { @@ -70,6 +93,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); + // Test Operand A (opIdx=0) + EXPECT_EQ( + DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 0), + LinearLayout( + { + {S("register"), {{0, 1}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + // Test Operand B (opIdx=1) + EXPECT_EQ( + DPAStoLinearLayout({16, 32}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 1), + LinearLayout( + { + {S("register"), {{1, 0}, {4, 0}, {8, 0}, {0, 16}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); EXPECT_EQ(DPAStoLinearLayout({32, 32}, dpas({1, 1}, 8, 8, 16, 1, {4, 2}, 16)), LinearLayout( { @@ -103,6 +148,34 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarp) { {S("dim0"), S("dim1")})); } +TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) { + EXPECT_EQ( + DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 0), + LinearLayout( + { + {S("register"), + {{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {S("warp"), {{32, 0}, {0, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + +TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandB) { + EXPECT_EQ( + DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 1), + LinearLayout( + { + {S("register"), + {{1, 0}, {4, 0}, {8, 0}, {0, 16}, {16, 0}, {32, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}}, + {S("warp"), {{0, 0}, {0, 32}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); +} + TEST_F(DPAStoLinearLayoutTest, DPAS_withDPASRepetitions) { EXPECT_EQ(DPAStoLinearLayout({64, 64}, dpas({2, 1}, 8, 8, 16, 2, {4, 2}, 32)), LinearLayout(