Skip to content

Commit

Permalink
DPAS operand A and operand B conversion to LinearLayout (#1746)
Browse files Browse the repository at this point in the history
This PR adds DPAS -> LinearLayout conversion of operand A and B layouts.
Currently, Triton does not use LinearLayout conversion for DotOperand
layouts (A and B).
I have included operand A/B support in the DPAStoLinearLayout function for
potential future use.
It is tested with the added unit tests.
  • Loading branch information
hwnam831 authored Aug 7, 2024
1 parent e542e39 commit 75eee6f
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@

namespace mlir::triton::gpu {

// DPAS operand A: opIdx=0
// DPAS operand B: opIdx=1
// DPAS operand C (default): opIdx=2
// Operand A and B conversion are not used yet
std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
Attribute layout);
Attribute layout,
unsigned opIdx = 2);

} // namespace mlir::triton::gpu

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,107 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,

// The layout example repeat_count=8, systolic_depth=8,
// execution_size=16 and operands_per_chan=2 for warp size 32.
// For A operand:
// systolic depth = 8
//<----------------------------------------------------->
// opsPerChan=2
//<--------->
// t0  ...  t0   t1  ... t1   ~ t6  ... t6   t7  ... t7   ^
// t8  ...  t8   t9  ... t9   ~ t14 ... t14  t15 ... t15  |
// t16 ...  t16  t17 ... t17  ~ t22 ... t22  t23 ... t23  |
// t24 ...  t24  t25 ... t25  ~ t30 ... t30  t31 ... t31  | repeat count <= 8
// t0  ...  t0   t1  ... t1   ~ t6  ... t6   t7  ... t7   |
// t8  ...  t8   t9  ... t9   ~ t14 ... t14  t15 ... t15  |
// t16 ...  t16  t17 ... t17  ~ t22 ... t22  t23 ... t23  |
// t24 ...  t24  t25 ... t25  ~ t30 ... t30  t31 ... t31  v
// In this case, the LinearLayout bases are:
// Register: {{0,1}, {4,0}}
// Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
//
// Returns the per-register bases of the DPAS A-operand tile. Each basis is a
// {dim0, dim1} offset; output dims are (row, column) of the A tile.
std::vector<std::vector<int32_t>> DPASRegBasesA(int opsPerChannel,
                                                int repeatCount,
                                                int threadsPerWarp,
                                                int systolicDepth) {
  // One warp covers threadsPerWarp / systolicDepth rows of the tile at once
  // (e.g. 32 / 8 = 4 rows in the diagram above).
  int rowsPerWarp = threadsPerWarp / systolicDepth;
  // Number of times the warp's rows repeat to fill repeatCount rows.
  int warpRepeats = repeatCount / rowsPerWarp;
  std::vector<std::vector<int32_t>> regBases;

  // The opsPerChannel values packed per channel advance along the column dim.
  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
    regBases.push_back({0, opc});
  }

  // The remaining repeats step down the rows, rowsPerWarp rows at a time.
  for (int warp = 1; warp < warpRepeats; warp *= 2) {
    regBases.push_back({warp * rowsPerWarp, 0});
  }

  return regBases;
}

// Per-lane bases of the DPAS A-operand tile: the low lane bits (below
// systolicDepth) step along the packed columns, the high lane bits step rows.
std::vector<std::vector<int32_t>>
DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) {
  std::vector<std::vector<int32_t>> bases;

  // Lanes within one systolic-depth group shift by opsPerChannel columns each.
  for (int lane = 1; lane < systolicDepth; lane *= 2) {
    bases.push_back({0, opsPerChannel * lane});
  }

  // Higher lane bits pick the row covered by this warp.
  for (int lane = systolicDepth; lane < threadsPerWarp; lane *= 2) {
    bases.push_back({lane / systolicDepth, 0});
  }

  return bases;
}

// For B operand:
// execution size = 16
//<-------------------------------------------------->
// t0 t1 t2 t3 ~ t12 t13 t14 t15 ^ ^
//. . . . . . . . . | opsPerChan=2 |
// t0 t1 t2 t3 ~ t12 t13 t14 t15 v |
// t16 t17 t18 t19 ~ t28 t29 t30 t31 |
//. . . . . . . . . |
// t16 t17 t18 t19 ~ t28 t29 t30 t31 | systolic depth = 8
// t0 t1 t2 t3 ~ t12 t13 t14 t15 |
//. . . . . . . . . |
// t0 t1 t2 t3 ~ t12 t13 t14 t15 |
// t16 t17 t18 t19 ~ t28 t29 t30 t31 |
//. . . . . . . . . |
// t16 t17 t18 t19 ~ t28 t29 t30 t31 v
// In this case, the LinearLayout bases are:
// Register: {{1,0}, {4,0}, {8,0}}
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}
// Returns the per-register bases of the DPAS B-operand tile. Each basis is a
// {dim0, dim1} offset; registers first cover the opsPerChannel values packed
// in a channel (rows), then repeat the warp's row group down the systolic
// depth. See the B-operand diagram above.
std::vector<std::vector<int32_t>> DPASRegBasesB(int opsPerChannel,
                                                int executionSize,
                                                int threadsPerWarp,
                                                int systolicDepth) {
  // Rows of the B tile covered simultaneously by one warp
  // (e.g. 32 / 16 = 2 rows in the diagram above).
  int rowsPerWarp = threadsPerWarp / executionSize;
  std::vector<std::vector<int32_t>> regBases;

  // The opsPerChannel packed values advance along the row dimension.
  for (int opc = 1; opc < opsPerChannel; opc *= 2) {
    regBases.push_back({opc, 0});
  }
  // Repeat the warp's row group until the systolic depth is filled. (The
  // original computed an unused `warpRepeats = systolicDepth / rowsPerWarp`
  // local here; it has been removed.)
  for (int rid = rowsPerWarp; rid < systolicDepth; rid *= 2) {
    regBases.push_back({rid * opsPerChannel, 0});
  }

  return regBases;
}

// Per-lane bases of the DPAS B-operand tile: lane bits below executionSize
// walk the columns; the remaining bits select a packed group of
// opsPerChannel rows.
std::vector<std::vector<int32_t>>
DPASLaneBasesB(int opsPerChannel, int threadsPerWarp, int executionSize) {
  std::vector<std::vector<int32_t>> bases;

  // Low lane bits advance one column at a time across the execution size.
  for (int lane = 1; lane < executionSize; lane *= 2) {
    bases.push_back({0, lane});
  }

  // High lane bits jump whole opsPerChannel-row groups.
  const int rowsPerWarp = threadsPerWarp / executionSize;
  for (int grp = 1; grp < rowsPerWarp; grp *= 2) {
    bases.push_back({grp * opsPerChannel, 0});
  }

  return bases;
}

// For C operand:
// execution size = 16
//<---------------------------------->
// t0 t1 t2 t3 ~ t12 t13 t14 t15 ^
Expand All @@ -348,15 +448,13 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
// In this case, the LinearLayout bases are:
// Register: {{2,0}, {4,0}}
// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}}
// Currently, LinearLayout is not supported for DotOperandEncoding
// so only Operand C conversion is implemented.
std::vector<std::vector<int32_t>>
DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
int rowsPerWarp = threadsPerWarp / executionSize;

std::vector<std::vector<int32_t>> regBases;

for (int rid = rowsPerWarp; rid < repeatCount; rid = rid * 2) {
for (int rid = rowsPerWarp; rid < repeatCount; rid *= 2) {
regBases.push_back({rid, 0});
}

Expand All @@ -365,25 +463,24 @@ DPASRegBasesC(int repeatCount, int executionSize, int threadsPerWarp) {

// Per-lane bases of the DPAS C-operand tile: lanes first sweep the columns up
// to executionSize, then any remaining lane bits step down the rows.
// Note: repeatCount is unused here; it is kept for signature symmetry with
// DPASRegBasesC.
// (This span contained unified-diff residue — duplicated old/new loop lines —
// reconstructed here as the post-change function.)
std::vector<std::vector<int32_t>>
DPASLaneBasesC(int repeatCount, int executionSize, int threadsPerWarp) {
  std::vector<std::vector<int32_t>> laneBases;

  // Low lane bits advance one column at a time.
  for (int tid = 1; tid < executionSize; tid *= 2) {
    laneBases.push_back({0, tid});
  }
  // High lane bits select the row within the warp's row group.
  int rowsPerWarp = threadsPerWarp / executionSize;
  for (int row = 1; row < rowsPerWarp; row *= 2) {
    laneBases.push_back({row, 0});
  }

  return laneBases;
}

std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,
Attribute layout) {

std::optional<LinearLayout>
DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout, unsigned opIdx) {
assert(opIdx < 3 && opIdx >= 0);
auto dpas = dyn_cast<DpasEncodingAttr>(layout);
assert(dpas && "Must be DPAS Operand C layout");
assert(dpas && "Must be DPAS layout");

int rank = shape.size();
assert(rank == dpas.getWarpsPerCTA().size());
Expand All @@ -397,33 +494,80 @@ std::optional<LinearLayout> DPAStoLinearLayout(ArrayRef<int64_t> shape,

const SmallVector<unsigned> warpsPerCTA = dpas.getWarpsPerCTA();
int threadsPerWarp = triton::gpu::getWarpSize(dpas);
unsigned opsPerChannel = dpas.getOpsPerChannel();
auto repCluster = dpas.getRepCluster();
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, 2);
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);

auto tileLayout = LinearLayout::empty();
int systolicDepth = dpas.getSystolicDepth();
int repeatCount = dpas.getRepeatCount();
int executionSize = dpas.getExecutionSize();
unsigned KDim = 0;
unsigned nonKDim = 0;
if (opIdx == 0) { // Operand A
auto regBasesA = DPASRegBasesA(opsPerChannel, repeatCount, threadsPerWarp,
systolicDepth);
auto laneBasesA =
DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
outDimNames);
// A only repeats by repCluster[0]
tileLayout *=
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
nonKDim = 0;
KDim = 1;
} else if (opIdx == 1) { // Operand B
auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
systolicDepth);
auto laneBasesB =
DPASLaneBasesB(opsPerChannel, threadsPerWarp, executionSize);
tileLayout = LinearLayout({{kRegister, regBasesB}, {kLane, laneBasesB}},
outDimNames);
// B only repeats by repCluster[1]
tileLayout *=
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
nonKDim = 1;
KDim = 0;
} else { // opIdx=2 -> Operand C
auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
auto laneBasesC =
DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
outDimNames);
// The per-inst layout is repeated at each repCluster.
// Hence, multiply with the identity layouts starting from the
// least significant dimension.
tileLayout *=
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
tileLayout *=
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
nonKDim = 0;
KDim = 1;
}

auto regBases = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
auto laneBases = DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
tileLayout =
LinearLayout({{kRegister, regBases}, {kLane, laneBases}}, outDimNames);

// The per-inst layout is repeated at each repCluster.
// Hence, multiply with the identity layouts starting from the
// least significant dimension.
tileLayout *=
LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
// Operand A/B repeats through the K-dimension first then repeats
// through non-K dimension.
tileLayout *=
LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);

// Then, it is repeated by DPASRepetitions to form per-Warp layout.
tileLayout *= LinearLayout::identity1D(numReps[1], kRegister, outDimNames[1]);
tileLayout *= LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);

// Finally, per-warp layout is repeated among the warps in the CTA.
LinearLayout warpLayout =
identityND(S("warp"), dpas.getWarpsPerCTA(), {0, 1}, outDimNames);
LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
outDimNames[nonKDim]);

// For Operand C, warps split the tensor identically.
// For Operand A and B, warps in the K-dimension share the same data.
// In these cases, the warp hops for K-dimensions are zero.
LinearLayout warpLayout = LinearLayout::empty();
StringAttr kWarp = S("warp");
if (opIdx == 0) {
warpLayout =
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
warpLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
} else if (opIdx == 1) {
warpLayout = LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
warpLayout *=
LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
} else { /* opIdx == 2 */
warpLayout = identityND(kWarp, warpsPerCTA, {0, 1}, outDimNames);
}
LinearLayout ctaLayout = tileLayout * warpLayout;

return combineCtaCgaWithShape(ctaLayout, CTALayoutAttr::getDefault(ctx, rank),
Expand Down
73 changes: 73 additions & 0 deletions unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DPAStoLinearLayoutTest : public ::testing::Test {
};

TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
// Default: Operand C
EXPECT_EQ(DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32)),
LinearLayout(
{
Expand All @@ -57,6 +58,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) {
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
// Test Operand A (opIdx=0)
EXPECT_EQ(
DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 0),
LinearLayout(
{
{S("register"), {{0, 1}, {4, 0}}},
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
// Test Operand B (opIdx=1)
EXPECT_EQ(
DPAStoLinearLayout({16, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 1),
LinearLayout(
{
{S("register"), {{1, 0}, {4, 0}, {8, 0}}},
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}

TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
Expand All @@ -70,6 +93,28 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) {
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
// Test Operand A (opIdx=0)
EXPECT_EQ(
DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 0),
LinearLayout(
{
{S("register"), {{0, 1}, {4, 0}, {8, 0}, {16, 0}}},
{S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
// Test Operand B (opIdx=1)
EXPECT_EQ(
DPAStoLinearLayout({16, 32}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 1),
LinearLayout(
{
{S("register"), {{1, 0}, {4, 0}, {8, 0}, {0, 16}}},
{S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
{S("warp"), {}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
EXPECT_EQ(DPAStoLinearLayout({32, 32}, dpas({1, 1}, 8, 8, 16, 1, {4, 2}, 16)),
LinearLayout(
{
Expand Down Expand Up @@ -103,6 +148,34 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarp) {
{S("dim0"), S("dim1")}));
}

TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) {
  // Operand A (opIdx = 0) on a 64x64 tensor with warpsPerCTA = {2, 2} and
  // repCluster = {4, 2}. The zero warp basis {0, 0} on dim1 reflects that
  // warps along the K dimension share the same slice of A.
  const auto actual =
      DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 0);
  const LinearLayout expected(
      {
          {S("register"),
           {{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}},
          {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
          {S("warp"), {{32, 0}, {0, 0}}},
          {S("block"), {}},
      },
      {S("dim0"), S("dim1")});
  EXPECT_EQ(actual, expected);
}

TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandB) {
  // Operand B (opIdx = 1) on a 64x64 tensor with warpsPerCTA = {2, 2} and
  // repCluster = {4, 2}. The zero warp basis {0, 0} on dim0 reflects that
  // warps along the K dimension share the same slice of B.
  const auto actual =
      DPAStoLinearLayout({64, 64}, dpas({2, 2}, 8, 8, 16, 2, {4, 2}, 32), 1);
  const LinearLayout expected(
      {
          {S("register"),
           {{1, 0}, {4, 0}, {8, 0}, {0, 16}, {16, 0}, {32, 0}}},
          {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {2, 0}}},
          {S("warp"), {{0, 0}, {0, 32}}},
          {S("block"), {}},
      },
      {S("dim0"), S("dim1")});
  EXPECT_EQ(actual, expected);
}

TEST_F(DPAStoLinearLayoutTest, DPAS_withDPASRepetitions) {
EXPECT_EQ(DPAStoLinearLayout({64, 64}, dpas({2, 1}, 8, 8, 16, 2, {4, 2}, 32)),
LinearLayout(
Expand Down

0 comments on commit 75eee6f

Please sign in to comment.