Skip to content

Commit

Permalink
No more workarounds for string ops (equality, concat).
Browse files Browse the repository at this point in the history
- Motivation:
  - Originally, we intended to express string equality (like equality for any value type) through EwEqOp and string concatenation through EwConcatOp.
  - However, due to problems when instantiating elementwise binary kernels with the string value type, we introduced two new operations StringEqOp and ConcatOp as workarounds some while ago.
  - These workarounds are problematic in the long term because:
    - they are not consistent with the other elementwise unary/binary DaphneIR operations, and
    - in the future, we want to support more value types (through extensibility), which would require more workarounds of this kind
  - In fact, the underlying technical problem was fixed in a recent commit ("Guarded instantiation of unary/binary ops.").
- This commit removes the workarounds for string ops by enabling string equality through EwEqOp (as for any other value type) and string concatenation through EwConcatOp.
- Concrete changes:
  - DaphneDSL parser:
    - Creates EwEqOp and EwConcatOp instead of the workaround ops now.
  - DaphneIR:
    - Removed StringEqOp and ConcatOp, we use EwEqOp and EwConcatOp instead now.
  - DAPHNE compiler:
    - Removed constant folding of StringEqOp and ConcatOp, this functionality is now achieved through the constant folding of EwEqOp and EwConcatOp.
    - Removed the rewrite from EwEqOp to StringEqOp when the value type is string, we can use EwEqOp end-to-end now.
  - DAPHNE runtime:
    - Added a new binary op code CONCAT.
    - Specified that EQ and CONCAT should be supported on string value types.
    - Removed the stringEq and concat-kernels, their essential code was moved to the ewBinarySca-kernel.
    - Removed the instantiations of the stringEq and concat-kernels from kernels.json.
    - Added new instantiations of the ewBinary-kernels for string value type and the CONCAT op code to kernels.json.
  - test cases:
    - Removed the codegen test case "stringeq.mlir", since it became obsolete through the removal of StringEqOp.
  • Loading branch information
pdamme committed Sep 2, 2024
1 parent 333b752 commit a7226dc
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 245 deletions.
71 changes: 4 additions & 67 deletions src/ir/daphneir/DaphneDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -950,57 +950,18 @@ mlir::OpFoldResult mlir::daphne::EwConcatOp::fold(FoldAdaptor adaptor) {
return {};
}

// TODO This is duplicated from EwConcatOp. Actually, ConcatOp itself is only
// a temporary workaround, so it should be removed altogether later.
mlir::OpFoldResult mlir::daphne::ConcatOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();

if (operands.size() != 2)
throw ErrorHandler::compilerError(
this->getLoc(), "CanonicalizerPass (mlir::daphne::ConcatOp::fold)",
"binary op takes two operands but " + std::to_string(operands.size()) + " were given");

if(!operands[0] || !operands[1])
return {};

if(llvm::isa<StringAttr>(operands[0]) && isa<StringAttr>(operands[1])) {
auto lhs = operands[0].cast<StringAttr>();
auto rhs = operands[1].cast<StringAttr>();

auto concated = lhs.getValue().str() + rhs.getValue().str();
return StringAttr::get(concated, getType());
}
return {};
}

mlir::OpFoldResult mlir::daphne::StringEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();

if (operands.size() != 2)
throw ErrorHandler::compilerError(
this->getLoc(), "CanonicalizerPass (mlir::daphne::StringEqOp::fold)",
"binary op takes two operands but " + std::to_string(operands.size()) + " were given");

if (!operands[0] || !operands[1] || !llvm::isa<StringAttr>(operands[0]) ||
!isa<StringAttr>(operands[1])) {
return {};
}

auto lhs = operands[0].cast<StringAttr>();
auto rhs = operands[1].cast<StringAttr>();

return mlir::BoolAttr::get(getContext(), lhs.getValue() == rhs.getValue());
}

mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) {
ArrayRef<Attribute> operands = adaptor.getOperands();
auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; };
auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a == b; };
auto strOp = [](const llvm::StringRef &a, const llvm::StringRef &b) { return a == b; };
// TODO: fix bool return
if(auto res = constFoldBinaryOp<FloatAttr>(getLoc(), getType(), operands, floatOp))
return res;
if(auto res = constFoldBinaryOp<IntegerAttr>(getLoc(), getType(), operands, intOp))
return res;
if(auto res = constFoldBinaryOp<StringAttr, IntegerAttr>(getLoc(), IntegerType::get(getContext(), 64, IntegerType::SignednessSemantics::Signed), operands, strOp))
return res;
return {};
}

Expand Down Expand Up @@ -1278,30 +1239,6 @@ struct SimplifyDistributeRead : public mlir::OpRewritePattern<mlir::daphne::Dist
}
};

// The EwBinarySca kernel does not handle string types in any way. In order to
// support simple string equivalence checks this canonicalizer rewrites the
// EwEqOp to the StringEqOp if one of the operands is of daphne::StringType.
mlir::LogicalResult mlir::daphne::EwEqOp::canonicalize(
mlir::daphne::EwEqOp op, PatternRewriter &rewriter) {
mlir::Value lhs = op.getLhs();
mlir::Value rhs = op.getRhs();

const bool lhsIsStr = llvm::isa<mlir::daphne::StringType>(lhs.getType());
const bool rhsIsStr = llvm::isa<mlir::daphne::StringType>(rhs.getType());

if (!lhsIsStr && !rhsIsStr) return mlir::failure();

mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext());
if (!lhsIsStr)
lhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, lhs);
if (!rhsIsStr)
rhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, rhs);

rewriter.replaceOpWithNewOp<mlir::daphne::StringEqOp>(
op, rewriter.getI1Type(), lhs, rhs);
return mlir::success();
}

/**
* @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string,
* and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame).
Expand Down Expand Up @@ -1331,7 +1268,7 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(
lhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, lhs);
if(!rhsIsStr)
rhs = rewriter.create<mlir::daphne::CastOp>(op.getLoc(), strTy, rhs);
rewriter.replaceOpWithNewOp<mlir::daphne::ConcatOp>(op, strTy, lhs, rhs);
rewriter.replaceOpWithNewOp<mlir::daphne::EwConcatOp>(op, strTy, lhs, rhs);
return mlir::success();
}
else {
Expand Down
19 changes: 1 addition & 18 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -338,21 +338,6 @@ def Daphne_EwBitwiseAndOp : Daphne_EwBinaryOp<"ewBitwiseAnd", NumScalar, [Com

def Daphne_EwConcatOp : Daphne_EwBinaryOp<"ewConcat", StrScalar, [CastArgsToResType]>;

// TODO This is just a quick and dirty solution that should be properly
// integrated with EwConcatOp above.
def Daphne_ConcatOp : Daphne_Op<"concat", [DataTypeSca, ValueTypeStr]> {
let arguments = (ins StrScalar:$lhs, StrScalar:$rhs);
let results = (outs StrScalar:$res);

let hasFolder = 1;
}

def Daphne_StringEqOp : Daphne_Op<"stringEq", [ValueTypeStr]> {
let arguments = (ins StrScalar:$lhs, StrScalar:$rhs);
let results = (outs BoolScalar:$res);
let hasFolder = 1;
}

// ----------------------------------------------------------------------------
// Comparisons
// ----------------------------------------------------------------------------
Expand All @@ -364,9 +349,7 @@ class Daphne_EwCmpOp<string name, Type inputScalarType, list<Trait> traits = []>
//let results = (outs AnyTypeOf<[MatrixOf<[BoolScalar]>, BoolScalar, Unknown]>:$res);
}

def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]> {
let hasCanonicalizeMethod = 1;
}
def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]>;
def Daphne_EwNeqOp : Daphne_EwCmpOp<"ewNeq", AnyScalar, [Commutative, CUDASupport]>;
def Daphne_EwLtOp : Daphne_EwCmpOp<"ewLt" , AnyScalar>;
def Daphne_EwLeOp : Daphne_EwCmpOp<"ewLe" , AnyScalar>;
Expand Down
2 changes: 1 addition & 1 deletion src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f

if(func == "concat") {
checkNumArgsExact(loc, func, numArgs, 2);
return static_cast<mlir::Value>(builder.create<ConcatOp>(
return static_cast<mlir::Value>(builder.create<EwConcatOp>(
loc, StringType::get(builder.getContext()), args[0], args[1]
));
}
Expand Down
6 changes: 3 additions & 3 deletions src/parser/daphnedsl/DaphneDSLVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1338,11 +1338,11 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont

if(op == "+")
// Note that we use '+' for both addition (EwAddOp) and concatenation
// (ConcatOp). The choice is made based on the types of the operands
// (if one operand is a string, we choose ConcatOp). However, the types
// (EwConcatOp). The choice is made based on the types of the operands
// (if one operand is a string, we choose EwConcatOp). However, the types
// might not be known at this point in time. Thus, we always create an
// EwAddOp here. Note that EwAddOp has a canonicalize method rewriting
// it to ConcatOp if necessary.
// it to EwConcatOp if necessary.
return utils.retValWithInferedType(builder.create<mlir::daphne::EwAddOp>(loc, lhs, rhs));
if(op == "-")
return utils.retValWithInferedType(builder.create<mlir::daphne::EwSubOp>(loc, lhs, rhs));
Expand Down
1 change: 1 addition & 0 deletions src/runtime/local/datastructures/ValueTypeUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ template<> const std::string ValueTypeUtils::cppNameFor<uint64_t>;
template<> const std::string ValueTypeUtils::cppNameFor<float>;
template<> const std::string ValueTypeUtils::cppNameFor<double>;
template<> const std::string ValueTypeUtils::cppNameFor<bool>;
template<> const std::string ValueTypeUtils::cppNameFor<char*>;

template<> const std::string ValueTypeUtils::irNameFor<int8_t>;
template<> const std::string ValueTypeUtils::irNameFor<int32_t>;
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/local/kernels/BinaryOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ enum class BinaryOpCode {
OR,
// Bitwise.
BITWISE_AND,
// Strings.
CONCAT
};

/**
Expand All @@ -63,7 +65,9 @@ static std::string_view binary_op_codes[] = {
// Logical.
"AND", "OR",
// Bitwise.
"BITWISE_AND"
"BITWISE_AND",
// Strings.
"CONCAT"
};

// ****************************************************************************
Expand Down Expand Up @@ -147,6 +151,8 @@ SUPPORT_NUMERIC_INT(int8_t)
SUPPORT_NUMERIC_INT(uint64_t)
SUPPORT_NUMERIC_INT(uint32_t)
SUPPORT_NUMERIC_INT(uint8_t)
template<> constexpr bool supportsBinaryOp<BinaryOpCode::CONCAT, const char *, const char *, const char *> = true;
template<> constexpr bool supportsBinaryOp<BinaryOpCode::EQ, int64_t, const char *, const char *> = true;

// Undefine helper macros.
#undef SUPPORT
Expand Down
39 changes: 0 additions & 39 deletions src/runtime/local/kernels/Concat.h

This file was deleted.

25 changes: 25 additions & 0 deletions src/runtime/local/kernels/EwBinarySca.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ EwBinaryScaFuncPtr<VTRes, VTLhs, VTRhs> getEwBinaryScaFuncPtr(BinaryOpCode opCod
// Logical.
MAKE_CASE(BinaryOpCode::AND)
MAKE_CASE(BinaryOpCode::OR)
// Strings.
MAKE_CASE(BinaryOpCode::CONCAT)
#undef MAKE_CASE
default:
throw std::runtime_error(
Expand Down Expand Up @@ -162,12 +164,35 @@ MAKE_EW_BINARY_SCA(BinaryOpCode::LT , lhs < rhs)
MAKE_EW_BINARY_SCA(BinaryOpCode::LE , lhs <= rhs)
MAKE_EW_BINARY_SCA(BinaryOpCode::GT , lhs > rhs)
MAKE_EW_BINARY_SCA(BinaryOpCode::GE , lhs >= rhs)
template<typename TRes>
struct EwBinarySca<BinaryOpCode::EQ, TRes, const char *, const char *> {
inline static TRes apply(const char * lhs, const char * rhs, DCTX(ctx)) {
return std::string_view(lhs) == std::string_view(rhs);
}
};
// Min/max.
MAKE_EW_BINARY_SCA(BinaryOpCode::MIN, std::min(lhs, rhs))
MAKE_EW_BINARY_SCA(BinaryOpCode::MAX, std::max(lhs, rhs))
// Logical.
MAKE_EW_BINARY_SCA(BinaryOpCode::AND, lhs && rhs)
MAKE_EW_BINARY_SCA(BinaryOpCode::OR , lhs || rhs)
// Strings.
template<>
struct EwBinarySca<BinaryOpCode::CONCAT, const char *, const char *, const char *> {
inline static const char * apply(const char * lhs, const char * rhs, DCTX(ctx)) {
const auto lenLhs = std::string_view(lhs).size();
const auto lenRhs = std::string_view(rhs).size();
const auto lenRes = lenLhs + lenRhs;

char* res = new char[lenRes + 1];

std::memcpy(res , lhs, lenLhs);
std::memcpy(res + lenLhs, rhs, lenRhs);
res[lenRes] = '\0';

return res;
}
};

#undef MAKE_EW_BINARY_SCA

Expand Down
25 changes: 0 additions & 25 deletions src/runtime/local/kernels/StringEq.h

This file was deleted.

52 changes: 4 additions & 48 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -865,52 +865,6 @@
}
]
},
{
"kernelTemplate": {
"header": "StringEq.h",
"opName": "stringEq",
"returnType": "bool",
"templateParams": [],
"runtimeParams": [
{
"type": "const char *",
"name": "lhs"
},
{
"type": "const char *",
"name": "rhs"
}
]
},
"instantiations": [
[]
]
},
{
"kernelTemplate": {
"header": "Concat.h",
"opName": "concat",
"returnType": "void",
"templateParams": [],
"runtimeParams": [
{
"type": "char *&",
"name": "res"
},
{
"type": "const char *",
"name": "lhs"
},
{
"type": "const char *",
"name": "rhs"
}
]
},
"instantiations": [
[]
]
},
{
"kernelTemplate": {
"header": "CreateDaphneContext.h",
Expand Down Expand Up @@ -1430,9 +1384,11 @@
["int64_t", "int64_t", "int64_t"],
["uint64_t", "uint64_t", "uint64_t"],
["uint32_t", "uint32_t", "uint32_t"],
["size_t", "size_t", "size_t"]
["size_t", "size_t", "size_t"],
["const char *", "const char *", "const char *"],
["int64_t", "const char *", "const char *"]
],
"opCodes": ["ADD", "SUB", "MUL", "DIV", "POW", "LOG", "MOD", "EQ", "NEQ", "LT", "LE", "GT", "GE", "MIN", "MAX", "AND", "OR", "BITWISE_AND"]
"opCodes": ["ADD", "SUB", "MUL", "DIV", "POW", "LOG", "MOD", "EQ", "NEQ", "LT", "LE", "GT", "GE", "MIN", "MAX", "AND", "OR", "BITWISE_AND", "CONCAT"]
},
{
"kernelTemplate": {
Expand Down
Loading

0 comments on commit a7226dc

Please sign in to comment.