From 1e7e3bceff1e71af165018c1facb69103a9f851a Mon Sep 17 00:00:00 2001 From: Lorenz Dirry Date: Mon, 19 Aug 2024 02:14:06 +0200 Subject: [PATCH] [DAPHNE-#772] Fix Constant folding crashes for values of different types - add type promotion to the smaller bit width type - added testcases for binary op of different types Closes #813 --- src/ir/daphneir/DaphneDialect.cpp | 116 +++++++++++------- .../CanonicalizationConstantFoldingOpTest.cpp | 4 + .../binary_op_casts_constant_folding.daphne | 16 +++ .../binary_op_casts_constant_folding.txt | 34 +++++ 4 files changed, 129 insertions(+), 41 deletions(-) create mode 100644 test/api/cli/operations/binary_op_casts_constant_folding.daphne create mode 100644 test/api/cli/operations/binary_op_casts_constant_folding.txt diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 23f427fee..542e66810 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -569,6 +569,8 @@ mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphn // For families of operations. // Adapted from "mlir/Dialect/CommonFolders.h" +mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc); + template< class ArgAttrElementT, class ResAttrElementT = ArgAttrElementT, @@ -595,6 +597,21 @@ mlir::Attribute constFoldBinaryOp(mlir::Location loc, mlir::Type resultType, llv std::is_same::value || std::is_same::value ) { + mlir::Type l = lhs.getType(); + mlir::Type r = rhs.getType(); + if ((l.dyn_cast() || l.dyn_cast()) && + (r.dyn_cast() || r.dyn_cast())) { + auto lhsBitWidth = lhs.getType().getIntOrFloatBitWidth(); + auto rhsBitWidth = rhs.getType().getIntOrFloatBitWidth(); + + if (lhsBitWidth < rhsBitWidth) { + mlir::Attribute promotedLhs = performCast(lhs, rhs.getType(), loc); + lhs = promotedLhs.cast(); + } else if (rhsBitWidth < lhsBitWidth) { + mlir::Attribute promotedRhs = performCast(rhs, lhs.getType(), loc); + rhs = promotedRhs.cast(); + } + } return ResAttrElementT::get(resultType, calculate(lhs.getValue(), rhs.getValue())); } else if constexpr(std::is_same::value) { @@ -638,66 +655,83 @@ mlir::Attribute constFoldUnaryOp(mlir::Location loc, mlir::Type resultType, llvm // **************************************************************************** // Fold implementations // **************************************************************************** +mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc) { + if (auto intAttr = attr.dyn_cast()) { + auto apInt = intAttr.getValue(); -mlir::OpFoldResult mlir::daphne::CastOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - if (isTrivialCast()) { - if (operands[0]) - return {operands[0]}; - else - return {getArg()}; - } - if(auto in = operands[0].dyn_cast_or_null()) { - auto apInt = in.getValue(); - if(auto outTy = getType().dyn_cast()) { - // TODO: throw exception if bits truncated? - if(outTy.isUnsignedInteger()) { + if (auto outTy = targetType.dyn_cast()) { + // Extend or truncate the integer value based on the target type + if (outTy.isUnsignedInteger()) { apInt = apInt.zextOrTrunc(outTy.getWidth()); - } - else if(outTy.isSignedInteger()) { - apInt = (in.getType().isSignedInteger()) + } else if (outTy.isSignedInteger()) { + apInt = (intAttr.getType().isSignedInteger()) ? apInt.sextOrTrunc(outTy.getWidth()) : apInt.zextOrTrunc(outTy.getWidth()); } - return IntegerAttr::getChecked(getLoc(), outTy, apInt); + return mlir::IntegerAttr::getChecked(loc, outTy, apInt); } - if(auto outTy = getType().dyn_cast()) { - return IntegerAttr::getChecked(getLoc(), outTy, apInt); + + if (auto outTy = targetType.dyn_cast()) { + return mlir::IntegerAttr::getChecked(loc, outTy, apInt); } - if(getType().isF64()) { - if(in.getType().isSignedInteger()) { - return FloatAttr::getChecked(getLoc(), - getType(), - llvm::APIntOps::RoundSignedAPIntToDouble(in.getValue())); + + if (targetType.isF64()) { + if (intAttr.getType().isSignedInteger()) { + return mlir::FloatAttr::getChecked(loc, targetType, + llvm::APIntOps::RoundSignedAPIntToDouble(apInt)); } - if(in.getType().isUnsignedInteger() || in.getType().isIndex()) { - return FloatAttr::getChecked(getLoc(), getType(), llvm::APIntOps::RoundAPIntToDouble(in.getValue())); + if (intAttr.getType().isUnsignedInteger() || intAttr.getType().isIndex()) { + return mlir::FloatAttr::getChecked(loc, targetType, + llvm::APIntOps::RoundAPIntToDouble(apInt)); } } - if(getType().isF32()) { - if(in.getType().isSignedInteger()) { - return FloatAttr::getChecked(getLoc(), - getType(), - llvm::APIntOps::RoundSignedAPIntToFloat(in.getValue())); + + if (targetType.isF32()) { + if (intAttr.getType().isSignedInteger()) { + return mlir::FloatAttr::getChecked(loc, targetType, + llvm::APIntOps::RoundSignedAPIntToFloat(apInt)); } - if(in.getType().isUnsignedInteger()) { - return FloatAttr::get(getType(), llvm::APIntOps::RoundAPIntToFloat(in.getValue())); + if (intAttr.getType().isUnsignedInteger()) { + return mlir::FloatAttr::get(targetType, + llvm::APIntOps::RoundAPIntToFloat(apInt)); } } } - if(auto in = operands[0].dyn_cast_or_null()) { - auto val = in.getValueAsDouble(); - if(getType().isF64()) { - return FloatAttr::getChecked(getLoc(), getType(), val); + else if (auto floatAttr = attr.dyn_cast()) { + auto val = floatAttr.getValueAsDouble(); + + if (targetType.isF64()) { + return mlir::FloatAttr::getChecked(loc, targetType, val); } - if(getType().isF32()) { - return FloatAttr::getChecked(getLoc(), getType(), static_cast(val)); + if (targetType.isF32()) { + return mlir::FloatAttr::getChecked(loc, targetType, static_cast(val)); } - if(getType().isIntOrIndex()) { + if (targetType.isIntOrIndex()) { auto num = static_cast(val); - return IntegerAttr::getChecked(getLoc(), getType(), num); + return mlir::IntegerAttr::getChecked(loc, targetType, num); } } + + // If casting is not possible, return the original attribute + return {}; +} + +mlir::OpFoldResult mlir::daphne::CastOp::fold(FoldAdaptor adaptor) { + ArrayRef operands = adaptor.getOperands(); + + if (isTrivialCast()) { + if (operands[0]) + return {operands[0]}; + else + return {getArg()}; + } + + if (operands[0]) { + if (auto castedAttr = performCast(operands[0], getType(), getLoc())) { + return castedAttr; + } + } + return {}; } diff --git a/test/api/cli/operations/CanonicalizationConstantFoldingOpTest.cpp b/test/api/cli/operations/CanonicalizationConstantFoldingOpTest.cpp index 1b5aeb68b..19125d733 100644 --- a/test/api/cli/operations/CanonicalizationConstantFoldingOpTest.cpp +++ b/test/api/cli/operations/CanonicalizationConstantFoldingOpTest.cpp @@ -40,3 +40,7 @@ TEST_CASE("additive_inverse_canonicalization", TAG_CODEGEN TAG_OPERATIONS) { compareDaphneParsingSimplifiedToRef(dirPath + testName + ".txt", dirPath + testName + ".daphne"); } +TEST_CASE("binary_operator_casts_constant_folding", TAG_CODEGEN TAG_OPERATIONS) { + const std::string testName = "binary_op_casts_constant_folding"; + compareDaphneParsingSimplifiedToRef(dirPath + testName + ".txt", dirPath + testName + ".daphne"); +} diff --git a/test/api/cli/operations/binary_op_casts_constant_folding.daphne b/test/api/cli/operations/binary_op_casts_constant_folding.daphne new file mode 100644 index 000000000..9280b9511 --- /dev/null +++ b/test/api/cli/operations/binary_op_casts_constant_folding.daphne @@ -0,0 +1,16 @@ +print(1 + as.si8(1)); +print(2 + as.si32(1)); +print(3 + as.si64(1)); +print(4 + as.ui8(1)); +print(5 + as.ui32(1)); +print(6 + as.ui64(1)); + +print(10.0 + as.f32(1)); +print(11.0 + as.f64(1)); + +print(as.si32(7) + as.si8(1)); +print(as.si64(8) + as.si32(1)); +print(as.si8(9) + as.si64(1)); + +print(as.f64(12.0) + as.f32(1)); +print(as.f32(13.0) + as.f64(1)); \ No newline at end of file diff --git a/test/api/cli/operations/binary_op_casts_constant_folding.txt b/test/api/cli/operations/binary_op_casts_constant_folding.txt new file mode 100644 index 000000000..85e1df74e --- /dev/null +++ b/test/api/cli/operations/binary_op_casts_constant_folding.txt @@ -0,0 +1,34 @@ +IR after parsing and some simplifications: +module { + func.func @main() { + %0 = "daphne.constant"() {value = 1.400000e+01 : f64} : () -> f64 + %1 = "daphne.constant"() {value = 1.300000e+01 : f64} : () -> f64 + %2 = "daphne.constant"() {value = 1.200000e+01 : f64} : () -> f64 + %3 = "daphne.constant"() {value = 10 : si64} : () -> si64 + %4 = "daphne.constant"() {value = 9 : si64} : () -> si64 + %5 = "daphne.constant"() {value = 8 : si32} : () -> si32 + %6 = "daphne.constant"() {value = 1.100000e+01 : f64} : () -> f64 + %7 = "daphne.constant"() {value = 7 : ui64} : () -> ui64 + %8 = "daphne.constant"() {value = 6 : si64} : () -> si64 + %9 = "daphne.constant"() {value = 5 : si64} : () -> si64 + %10 = "daphne.constant"() {value = 4 : si64} : () -> si64 + %11 = "daphne.constant"() {value = 3 : si64} : () -> si64 + %12 = "daphne.constant"() {value = 2 : si64} : () -> si64 + %13 = "daphne.constant"() {value = false} : () -> i1 + %14 = "daphne.constant"() {value = true} : () -> i1 + "daphne.print"(%12, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%11, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%10, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%9, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%8, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%7, %14, %13) : (ui64, i1, i1) -> () + "daphne.print"(%6, %14, %13) : (f64, i1, i1) -> () + "daphne.print"(%2, %14, %13) : (f64, i1, i1) -> () + "daphne.print"(%5, %14, %13) : (si32, i1, i1) -> () + "daphne.print"(%4, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%3, %14, %13) : (si64, i1, i1) -> () + "daphne.print"(%1, %14, %13) : (f64, i1, i1) -> () + "daphne.print"(%0, %14, %13) : (f64, i1, i1) -> () + "daphne.return"() : () -> () + } +}