Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No more workarounds for string ops, enabled through guarded instantiation of unary/binary ops #795

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 4 additions & 67 deletions src/ir/daphneir/DaphneDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,57 +950,18 @@ mlir::OpFoldResult mlir::daphne::EwConcatOp::fold(FoldAdaptor adaptor) {
return {};
}

// TODO This is duplicated from EwConcatOp. Actually, ConcatOp itself is only
// a temporary workaround, so it should be removed altogether later.
mlir::OpFoldResult mlir::daphne::ConcatOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();

if (operands.size() != 2)
throw ErrorHandler::compilerError(
this->getLoc(), "CanonicalizerPass (mlir::daphne::ConcatOp::fold)",
"binary op takes two operands but " + std::to_string(operands.size()) + " were given");

if(!operands[0] || !operands[1])
return {};

if(llvm::isa<StringAttr>(operands[0]) && isa<StringAttr>(operands[1])) {
auto lhs = operands[0].cast<StringAttr>();
auto rhs = operands[1].cast<StringAttr>();

auto concated = lhs.getValue().str() + rhs.getValue().str();
return StringAttr::get(concated, getType());
}
return {};
}

mlir::OpFoldResult mlir::daphne::StringEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();

if (operands.size() != 2)
throw ErrorHandler::compilerError(
this->getLoc(), "CanonicalizerPass (mlir::daphne::StringEqOp::fold)",
"binary op takes two operands but " + std::to_string(operands.size()) + " were given");

if (!operands[0] || !operands[1] || !llvm::isa<StringAttr>(operands[0]) ||
!isa<StringAttr>(operands[1])) {
return {};
}

auto lhs = operands[0].cast<StringAttr>();
auto rhs = operands[1].cast<StringAttr>();

return mlir::BoolAttr::get(getContext(), lhs.getValue() == rhs.getValue());
}

mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();
auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; };
auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a == b; };
auto strOp = [](const llvm::StringRef &a, const llvm::StringRef &b) { return a == b; };
// TODO: fix bool return
if(auto res = constFoldBinaryOp<FloatAttr>(getLoc(), getType(), operands, floatOp))
return res;
if(auto res = constFoldBinaryOp<IntegerAttr>(getLoc(), getType(), operands, intOp))
return res;
if(auto res = constFoldBinaryOp<StringAttr, IntegerAttr>(getLoc(), IntegerType::get(getContext(), 64, IntegerType::SignednessSemantics::Signed), operands, strOp))
return res;
return {};
}

Expand Down Expand Up @@ -1278,30 +1239,6 @@ struct SimplifyDistributeRead : public mlir::OpRewritePattern<mlir::daphne::Dist
}
};

// The EwBinarySca kernel does not handle string types in any way. In order to
// support simple string equivalence checks this canonicalizer rewrites the
// EwEqOp to the StringEqOp if one of the operands is of daphne::StringType.
mlir::LogicalResult mlir::daphne::EwEqOp::canonicalize(
mlir::daphne::EwEqOp op, PatternRewriter &rewriter) {
mlir::Value lhs = op.getLhs();
mlir::Value rhs = op.getRhs();

const bool lhsIsStr = llvm::isa<mlir::daphne::StringType>(lhs.getType());
const bool rhsIsStr = llvm::isa<mlir::daphne::StringType>(rhs.getType());

if (!lhsIsStr && !rhsIsStr) return mlir::failure();

mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext());
if (!lhsIsStr)
lhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, lhs);
if (!rhsIsStr)
rhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, rhs);

rewriter.replaceOpWithNewOp<mlir::daphne::StringEqOp>(
op, rewriter.getI1Type(), lhs, rhs);
return mlir::success();
}

/**
* @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string,
* and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame).
Expand Down Expand Up @@ -1331,7 +1268,7 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(
lhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, lhs);
if(!rhsIsStr)
rhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, rhs);
rewriter.replaceOpWithNewOp<mlir::daphne::ConcatOp>(op, strTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::daphne::EwConcatOp>(op, strTy, lhs, rhs);
return mlir::success();
}
else {
Expand Down
19 changes: 1 addition & 18 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -333,21 +333,6 @@ def Daphne_EwBitwiseAndOp : Daphne_EwBinaryOp<"ewBitwiseAnd", NumScalar, [Com

def Daphne_EwConcatOp : Daphne_EwBinaryOp<"ewConcat", StrScalar, [CastArgsToResType]>;

// TODO This is just a quick and dirty solution that should be properly
// integrated with EwConcatOp above.
def Daphne_ConcatOp : Daphne_Op<"concat", [DataTypeSca, ValueTypeStr]> {
let arguments = (ins StrScalar:$lhs, StrScalar:$rhs);
let results = (outs StrScalar:$res);

let hasFolder = 1;
}

def Daphne_StringEqOp : Daphne_Op<"stringEq", [ValueTypeStr]> {
let arguments = (ins StrScalar:$lhs, StrScalar:$rhs);
let results = (outs BoolScalar:$res);
let hasFolder = 1;
}

// ----------------------------------------------------------------------------
// Comparisons
// ----------------------------------------------------------------------------
Expand All @@ -359,9 +344,7 @@ class Daphne_EwCmpOp<string name, Type inputScalarType, list<Trait> traits = []>
//let results = (outs AnyTypeOf<[MatrixOf<[BoolScalar]>, BoolScalar, Unknown]>:$res);
}

def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]> {
let hasCanonicalizeMethod = 1;
}
def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]>;
def Daphne_EwNeqOp : Daphne_EwCmpOp<"ewNeq", AnyScalar, [Commutative, CUDASupport]>;
def Daphne_EwLtOp : Daphne_EwCmpOp<"ewLt" , AnyScalar>;
def Daphne_EwLeOp : Daphne_EwCmpOp<"ewLe" , AnyScalar>;
Expand Down
2 changes: 1 addition & 1 deletion src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f

if(func == "concat") {
checkNumArgsExact(loc, func, numArgs, 2);
return static_cast<mlir::Value>(builder.create<ConcatOp>(
return static_cast<mlir::Value>(builder.create<EwConcatOp>(
loc, StringType::get(builder.getContext()), args[0], args[1]
));
}
Expand Down
6 changes: 3 additions & 3 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1338,11 +1338,11 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont

if(op == "+")
// Note that we use '+' for both addition (EwAddOp) and concatenation
// (ConcatOp). The choice is made based on the types of the operands
// (if one operand is a string, we choose ConcatOp). However, the types
// (EwConcatOp). The choice is made based on the types of the operands
// (if one operand is a string, we choose EwConcatOp). However, the types
// might not be known at this point in time. Thus, we always create an
// EwAddOp here. Note that EwAddOp has a canonicalize method rewriting
// it to ConcatOp if necessary.
// it to EwConcatOp if necessary.
return utils.retValWithInferedType(builder.create<mlir::daphne::EwAddOp>(loc, lhs, rhs));
if(op == "-")
return utils.retValWithInferedType(builder.create<mlir::daphne::EwSubOp>(loc, lhs, rhs));
Expand Down
1 change: 1 addition & 0 deletions src/runtime/local/datastructures/ValueTypeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ template<> const std::string ValueTypeUtils::cppNameFor<uint64_t>;
template<> const std::string ValueTypeUtils::cppNameFor<float>;
template<> const std::string ValueTypeUtils::cppNameFor<double>;
template<> const std::string ValueTypeUtils::cppNameFor<bool>;
template<> const std::string ValueTypeUtils::cppNameFor<char*>;

template<> const std::string ValueTypeUtils::irNameFor<int8_t>;
template<> const std::string ValueTypeUtils::irNameFor<int32_t>;
Expand Down
127 changes: 121 additions & 6 deletions src/runtime/local/kernels/BinaryOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#pragma once

// ****************************************************************************
// Enum for binary op codes and their names
// ****************************************************************************

enum class BinaryOpCode {
// Arithmetic.
ADD, // addition
Expand All @@ -25,26 +29,137 @@ enum class BinaryOpCode {
POW, // to the power of
MOD, // modulus
LOG, // logarithm

// Comparisons.
EQ, // equal
NEQ, // not equal
LT, // less than
LE, // less equal
GT, // greater than
GE, // greater equal

// Min/max.
MIN,
MAX,

// Logical.
AND,
OR,

// Bitwise.
BITWISE_AND,
// Strings.
CONCAT
};

static std::string_view binary_op_codes[] = {"ADD", "SUB", "MUL", "DIV", "POW", "MOD", "LOG", "EQ", "NEQ", "LT", "LE",
"GT", "GE", "MIN", "MAX", "AND", "OR", "BITWISE_AND"};
/**
* @brief Array of the "names" of the `BinaryOpCode`s.
*
* Must contain the same elements as `BinaryOpCode` in the same order,
* such that we can obtain the name corresponding to a `BinaryOpCode` `opCode`
* by `binary_op_codes[static_cast<int>(opCode)]`.
*/
static std::string_view binary_op_codes[] = {
// Arithmetic.
"ADD", "SUB", "MUL", "DIV", "POW", "MOD", "LOG",
// Comparisons.
"EQ", "NEQ", "LT", "LE", "GT", "GE",
// Min/max.
"MIN", "MAX",
// Logical.
"AND", "OR",
// Bitwise.
"BITWISE_AND",
// Strings.
"CONCAT"
};

// ****************************************************************************
// Specification which binary ops should be supported on which value types
// ****************************************************************************

/**
* @brief Template constant specifying if the given binary operation
* should be supported on arguments of the given value types.
*
* @tparam VTRes The result value type.
* @tparam VTLhs The left-hand-side argument value type.
* @tparam VTRhs The right-hand-side argument value type.
* @tparam op The binary operation.
*/
template<BinaryOpCode op, typename VTRes, typename VTLhs, typename VTRhs>
static constexpr bool supportsBinaryOp = false;

// Macros for concisely specifying which binary operations should be
// supported on which value types.

// Generates code specifying that the binary operation `Op` should be supported on
// the value type `VT` (for the result and the two arguments, for simplicity).
#define SUPPORT(Op, VT) \
template<> constexpr bool supportsBinaryOp<BinaryOpCode::Op, VT, VT, VT> = true;

// Generates code specifying that all binary operations of a certain category should be
// supported on the given value type `VT` (for the result and the two arguments, for simplicity).
#define SUPPORT_ARITHMETIC(VT) \
/* Arithmetic. */ \
SUPPORT(ADD, VT) \
SUPPORT(SUB, VT) \
SUPPORT(MUL, VT) \
SUPPORT(DIV, VT) \
SUPPORT(POW, VT) \
SUPPORT(MOD, VT) \
SUPPORT(LOG, VT)
#define SUPPORT_EQUALITY(VT) \
/* Comparisons. */ \
SUPPORT(EQ , VT) \
SUPPORT(NEQ, VT)
#define SUPPORT_COMPARISONS(VT) \
/* Comparisons. */ \
SUPPORT(LT, VT) \
SUPPORT(LE, VT) \
SUPPORT(GT, VT) \
SUPPORT(GE, VT) \
/* Min/max. */ \
SUPPORT(MIN, VT) \
SUPPORT(MAX, VT)
#define SUPPORT_LOGICAL(VT) \
/* Logical. */ \
SUPPORT(AND, VT) \
SUPPORT(OR , VT)
#define SUPPORT_BITWISE(VT) \
/* Bitwise. */ \
SUPPORT(BITWISE_AND, VT)

// Generates code specifying that all binary operations typically supported on a certain
// category of value types should be supported on the given value type `VT`
// (for the result and the two arguments, for simplicity).
#define SUPPORT_NUMERIC_FP(VT) \
SUPPORT_ARITHMETIC(VT) \
SUPPORT_EQUALITY(VT) \
SUPPORT_COMPARISONS(VT) \
SUPPORT_LOGICAL(VT)
#define SUPPORT_NUMERIC_INT(VT) \
SUPPORT_ARITHMETIC(VT) \
SUPPORT_EQUALITY(VT) \
SUPPORT_COMPARISONS(VT) \
SUPPORT_LOGICAL(VT) \
SUPPORT_BITWISE(VT)

// Concise specification of which binary operations should be supported on
// which value types.
SUPPORT_NUMERIC_FP(double)
SUPPORT_NUMERIC_FP(float)
SUPPORT_NUMERIC_INT(int64_t)
SUPPORT_NUMERIC_INT(int32_t)
SUPPORT_NUMERIC_INT(int8_t)
SUPPORT_NUMERIC_INT(uint64_t)
SUPPORT_NUMERIC_INT(uint32_t)
SUPPORT_NUMERIC_INT(uint8_t)
template<> constexpr bool supportsBinaryOp<BinaryOpCode::CONCAT, const char *, const char *, const char *> = true;
template<> constexpr bool supportsBinaryOp<BinaryOpCode::EQ, int64_t, const char *, const char *> = true;

// Undefine helper macros.
#undef SUPPORT
#undef SUPPORT_ARITHMETIC
#undef SUPPORT_EQUALITY
#undef SUPPORT_COMPARISONS
#undef SUPPORT_LOGICAL
#undef SUPPORT_BITWISE
#undef SUPPORT_NUMERIC_FP
#undef SUPPORT_NUMERIC_INT
39 changes: 0 additions & 39 deletions src/runtime/local/kernels/Concat.h

This file was deleted.

Loading
Loading