Skip to content

Commit

Permalink
[DAPHNE-#772] Fix Constant folding crashes for values of different types
Browse files Browse the repository at this point in the history
- add type promotion to the smaller bit width type
- added testcases for binary op of different types

Closes #813
  • Loading branch information
ldirry authored and corepointer committed Sep 11, 2024
1 parent 5337d14 commit 1e7e3bc
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 41 deletions.
116 changes: 75 additions & 41 deletions src/ir/daphneir/DaphneDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphn
// For families of operations.

// Adapted from "mlir/Dialect/CommonFolders.h"
mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc);

template<
class ArgAttrElementT,
class ResAttrElementT = ArgAttrElementT,
Expand All @@ -595,6 +597,21 @@ mlir::Attribute constFoldBinaryOp(mlir::Location loc, mlir::Type resultType, llv
std::is_same<ResAttrElementT, mlir::IntegerAttr>::value ||
std::is_same<ResAttrElementT, mlir::FloatAttr>::value
) {
mlir::Type l = lhs.getType();
mlir::Type r = rhs.getType();
if ((l.dyn_cast<mlir::IntegerType>() || l.dyn_cast<mlir::FloatType>()) &&
(r.dyn_cast<mlir::IntegerType>() || r.dyn_cast<mlir::FloatType>())) {
auto lhsBitWidth = lhs.getType().getIntOrFloatBitWidth();
auto rhsBitWidth = rhs.getType().getIntOrFloatBitWidth();

if (lhsBitWidth < rhsBitWidth) {
mlir::Attribute promotedLhs = performCast(lhs, rhs.getType(), loc);
lhs = promotedLhs.cast<ArgAttrElementT>();
} else if (rhsBitWidth < lhsBitWidth) {
mlir::Attribute promotedRhs = performCast(rhs, lhs.getType(), loc);
rhs = promotedRhs.cast<ArgAttrElementT>();
}
}
return ResAttrElementT::get(resultType, calculate(lhs.getValue(), rhs.getValue()));
}
else if constexpr(std::is_same<ResAttrElementT, mlir::BoolAttr>::value) {
Expand Down Expand Up @@ -638,66 +655,83 @@ mlir::Attribute constFoldUnaryOp(mlir::Location loc, mlir::Type resultType, llvm
// ****************************************************************************
// Fold implementations
// ****************************************************************************
mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc) {
if (auto intAttr = attr.dyn_cast<mlir::IntegerAttr>()) {
auto apInt = intAttr.getValue();

mlir::OpFoldResult mlir::daphne::CastOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();
if (isTrivialCast()) {
if (operands[0])
return {operands[0]};
else
return {getArg()};
}
if(auto in = operands[0].dyn_cast_or_null<IntegerAttr>()) {
auto apInt = in.getValue();
if(auto outTy = getType().dyn_cast<IntegerType>()) {
// TODO: throw exception if bits truncated?
if(outTy.isUnsignedInteger()) {
if (auto outTy = targetType.dyn_cast<mlir::IntegerType>()) {
// Extend or truncate the integer value based on the target type
if (outTy.isUnsignedInteger()) {
apInt = apInt.zextOrTrunc(outTy.getWidth());
}
else if(outTy.isSignedInteger()) {
apInt = (in.getType().isSignedInteger())
} else if (outTy.isSignedInteger()) {
apInt = (intAttr.getType().isSignedInteger())
? apInt.sextOrTrunc(outTy.getWidth())
: apInt.zextOrTrunc(outTy.getWidth());
}
return IntegerAttr::getChecked(getLoc(), outTy, apInt);
return mlir::IntegerAttr::getChecked(loc, outTy, apInt);
}
if(auto outTy = getType().dyn_cast<IndexType>()) {
return IntegerAttr::getChecked(getLoc(), outTy, apInt);

if (auto outTy = targetType.dyn_cast<mlir::IndexType>()) {
return mlir::IntegerAttr::getChecked(loc, outTy, apInt);
}
if(getType().isF64()) {
if(in.getType().isSignedInteger()) {
return FloatAttr::getChecked(getLoc(),
getType(),
llvm::APIntOps::RoundSignedAPIntToDouble(in.getValue()));

if (targetType.isF64()) {
if (intAttr.getType().isSignedInteger()) {
return mlir::FloatAttr::getChecked(loc, targetType,
llvm::APIntOps::RoundSignedAPIntToDouble(apInt));
}
if(in.getType().isUnsignedInteger() || in.getType().isIndex()) {
return FloatAttr::getChecked(getLoc(), getType(), llvm::APIntOps::RoundAPIntToDouble(in.getValue()));
if (intAttr.getType().isUnsignedInteger() || intAttr.getType().isIndex()) {
return mlir::FloatAttr::getChecked(loc, targetType,
llvm::APIntOps::RoundAPIntToDouble(apInt));
}
}
if(getType().isF32()) {
if(in.getType().isSignedInteger()) {
return FloatAttr::getChecked(getLoc(),
getType(),
llvm::APIntOps::RoundSignedAPIntToFloat(in.getValue()));

if (targetType.isF32()) {
if (intAttr.getType().isSignedInteger()) {
return mlir::FloatAttr::getChecked(loc, targetType,
llvm::APIntOps::RoundSignedAPIntToFloat(apInt));
}
if(in.getType().isUnsignedInteger()) {
return FloatAttr::get(getType(), llvm::APIntOps::RoundAPIntToFloat(in.getValue()));
if (intAttr.getType().isUnsignedInteger()) {
return mlir::FloatAttr::get(targetType,
llvm::APIntOps::RoundAPIntToFloat(apInt));
}
}
}
if(auto in = operands[0].dyn_cast_or_null<FloatAttr>()) {
auto val = in.getValueAsDouble();
if(getType().isF64()) {
return FloatAttr::getChecked(getLoc(), getType(), val);
else if (auto floatAttr = attr.dyn_cast<mlir::FloatAttr>()) {
auto val = floatAttr.getValueAsDouble();

if (targetType.isF64()) {
return mlir::FloatAttr::getChecked(loc, targetType, val);
}
if(getType().isF32()) {
return FloatAttr::getChecked(getLoc(), getType(), static_cast<float>(val));
if (targetType.isF32()) {
return mlir::FloatAttr::getChecked(loc, targetType, static_cast<float>(val));
}
if(getType().isIntOrIndex()) {
if (targetType.isIntOrIndex()) {
auto num = static_cast<int64_t>(val);
return IntegerAttr::getChecked(getLoc(), getType(), num);
return mlir::IntegerAttr::getChecked(loc, targetType, num);
}
}

// If casting is not possible, return the original attribute
return {};
}

mlir::OpFoldResult mlir::daphne::CastOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();

if (isTrivialCast()) {
if (operands[0])
return {operands[0]};
else
return {getArg()};
}

if (operands[0]) {
if (auto castedAttr = performCast(operands[0], getType(), getLoc())) {
return castedAttr;
}
}

return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ TEST_CASE("additive_inverse_canonicalization", TAG_CODEGEN TAG_OPERATIONS) {
compareDaphneParsingSimplifiedToRef(dirPath + testName + ".txt", dirPath + testName + ".daphne");
}

TEST_CASE("binary_operator_casts_constant_folding", TAG_CODEGEN TAG_OPERATIONS) {
const std::string testName = "binary_op_casts_constant_folding";
compareDaphneParsingSimplifiedToRef(dirPath + testName + ".txt", dirPath + testName + ".daphne");
}
16 changes: 16 additions & 0 deletions test/api/cli/operations/binary_op_casts_constant_folding.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
print(1 + as.si8(1));
print(2 + as.si32(1));
print(3 + as.si64(1));
print(4 + as.ui8(1));
print(5 + as.ui32(1));
print(6 + as.ui64(1));

print(10.0 + as.f32(1));
print(11.0 + as.f64(1));

print(as.si32(7) + as.si8(1));
print(as.si64(8) + as.si32(1));
print(as.si8(9) + as.si64(1));

print(as.f64(12.0) + as.f32(1));
print(as.f32(13.0) + as.f64(1));
34 changes: 34 additions & 0 deletions test/api/cli/operations/binary_op_casts_constant_folding.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
IR after parsing and some simplifications:
module {
func.func @main() {
%0 = "daphne.constant"() {value = 1.400000e+01 : f64} : () -> f64
%1 = "daphne.constant"() {value = 1.300000e+01 : f64} : () -> f64
%2 = "daphne.constant"() {value = 1.200000e+01 : f64} : () -> f64
%3 = "daphne.constant"() {value = 10 : si64} : () -> si64
%4 = "daphne.constant"() {value = 9 : si64} : () -> si64
%5 = "daphne.constant"() {value = 8 : si32} : () -> si32
%6 = "daphne.constant"() {value = 1.100000e+01 : f64} : () -> f64
%7 = "daphne.constant"() {value = 7 : ui64} : () -> ui64
%8 = "daphne.constant"() {value = 6 : si64} : () -> si64
%9 = "daphne.constant"() {value = 5 : si64} : () -> si64
%10 = "daphne.constant"() {value = 4 : si64} : () -> si64
%11 = "daphne.constant"() {value = 3 : si64} : () -> si64
%12 = "daphne.constant"() {value = 2 : si64} : () -> si64
%13 = "daphne.constant"() {value = false} : () -> i1
%14 = "daphne.constant"() {value = true} : () -> i1
"daphne.print"(%12, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%11, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%10, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%9, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%8, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%7, %14, %13) : (ui64, i1, i1) -> ()
"daphne.print"(%6, %14, %13) : (f64, i1, i1) -> ()
"daphne.print"(%2, %14, %13) : (f64, i1, i1) -> ()
"daphne.print"(%5, %14, %13) : (si32, i1, i1) -> ()
"daphne.print"(%4, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%3, %14, %13) : (si64, i1, i1) -> ()
"daphne.print"(%1, %14, %13) : (f64, i1, i1) -> ()
"daphne.print"(%0, %14, %13) : (f64, i1, i1) -> ()
"daphne.return"() : () -> ()
}
}

0 comments on commit 1e7e3bc

Please sign in to comment.