From d35969c98ecfdb1b02f986917c851b91508494c1 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Fri, 15 Nov 2024 13:15:10 +0000 Subject: [PATCH] Use ArrayRef in AxisInfo constructors Signed-off-by: Anatoly Myachev --- include/triton/Analysis/AxisInfo.h | 7 +++-- third_party/intel/lib/Analysis/AxisInfo.cpp | 34 ++++++++------------- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index aad4503b48..1bf9c8a690 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -27,11 +27,12 @@ class AxisInfo { public: AxisInfo() : AxisInfo({}, {}, {}) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy) : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} - AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy, - std::optional constantValue) + AxisInfo(ArrayRef contiguity, ArrayRef divisibility, + ArrayRef constancy, std::optional constantValue) : contiguity(contiguity), divisibility(divisibility), constancy(constancy), constantValue(constantValue) { assert(divisibility.size() == contiguity.size()); diff --git a/third_party/intel/lib/Analysis/AxisInfo.cpp b/third_party/intel/lib/Analysis/AxisInfo.cpp index 463fb4522b..5215147a6a 100644 --- a/third_party/intel/lib/Analysis/AxisInfo.cpp +++ b/third_party/intel/lib/Analysis/AxisInfo.cpp @@ -123,8 +123,7 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); } } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), constantValue); + return AxisInfo(contiguity, divisibility, constancy, constantValue); } protected: @@ -544,8 +543,7 @@ class SplatOpAxisInfoVisitor final divisibility.push_back(opInfo.getDivisibility(0)); constancy.push_back(retTy.getShape()[d]); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), + return AxisInfo(contiguity, divisibility, constancy, operands[0]->getValue().getConstantValue()); } }; @@ -576,8 +574,7 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl { maskInfo.has_value() ? maskInfo->getConstancy(d) : 0)); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy)); + return AxisInfo(contiguity, divisibility, constancy); } }; @@ -611,8 +608,7 @@ class ExpandDimsOpAxisInfoVisitor final contiguity.insert(contiguity.begin() + op.getAxis(), 1); divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); constancy.insert(constancy.begin() + op.getAxis(), 1); - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), + return AxisInfo(contiguity, divisibility, constancy, operands[0]->getValue().getConstantValue()); } }; @@ -641,8 +637,7 @@ class BroadcastOpAxisInfoVisitor final constancy.push_back(opShape[d] == 1 ? retShape[d] : opInfo.getConstancy(d)); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), + return AxisInfo(contiguity, divisibility, constancy, operands[0]->getValue().getConstantValue()); } }; @@ -717,8 +712,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { contiguity.push_back(1); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), constantValue); + return AxisInfo(contiguity, divisibility, constancy, constantValue); } private: @@ -846,8 +840,7 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { constantValue = lhsInfo.getConstantValue(); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), constantValue); + return AxisInfo(contiguity, divisibility, constancy, constantValue); } }; @@ -1000,8 +993,7 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { contiguity.push_back( std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d))); } - return AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy), std::nullopt); + return AxisInfo(contiguity, divisibility, constancy, std::nullopt); } } }; @@ -1046,8 +1038,7 @@ class MakeTensorPtrOpAxisInfoVisitor final constancy.push_back(1); } - auto axisInfo = AxisInfo(std::move(contiguity), std::move(divisibility), - std::move(constancy)); + auto axisInfo = AxisInfo(contiguity, divisibility, constancy); LLVM_DEBUG({ std::string axisStr; @@ -1152,8 +1143,8 @@ LogicalResult AxisInfoAnalysis::visitOperation( auto vals = cast(attr).getValues(); newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); } - curr = AxisInfo(std::move(newContiguity), std::move(newDivisibility), - std::move(newConstancy), curr.getConstantValue()); + curr = AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); // join all lattice elements for (auto *result : results) propagateIfChanged(result, result->join(curr)); @@ -1173,8 +1164,7 @@ void AxisInfoAnalysis::visitForOpInductionVar( AxisInfo::DimVectorT knownConstancy(1, 1); knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0)); auto inductionVar = - AxisInfo(std::move(knownContiguity), std::move(knownDivisibility), - std::move(knownConstancy)); + AxisInfo(knownContiguity, knownDivisibility, knownConstancy); (void)argLattices[0]->join(inductionVar); }