diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 5388842fc..23f427fee 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -950,57 +950,18 @@ mlir::OpFoldResult mlir::daphne::EwConcatOp::fold(FoldAdaptor adaptor) { return {}; } -// TODO This is duplicated from EwConcatOp. Actually, ConcatOp itself is only -// a temporary workaround, so it should be removed altogether later. -mlir::OpFoldResult mlir::daphne::ConcatOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - - if (operands.size() != 2) - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::ConcatOp::fold)", - "binary op takes two operands but " + std::to_string(operands.size()) + " were given"); - - if(!operands[0] || !operands[1]) - return {}; - - if(llvm::isa(operands[0]) && isa(operands[1])) { - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - auto concated = lhs.getValue().str() + rhs.getValue().str(); - return StringAttr::get(concated, getType()); - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::StringEqOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - - if (operands.size() != 2) - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::StringEqOp::fold)", - "binary op takes two operands but " + std::to_string(operands.size()) + " were given"); - - if (!operands[0] || !operands[1] || !llvm::isa(operands[0]) || - !isa(operands[1])) { - return {}; - } - - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - return mlir::BoolAttr::get(getContext(), lhs.getValue() == rhs.getValue()); -} - mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) { ArrayRef operands = adaptor.getOperands(); auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; }; auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a == b; }; + auto strOp = [](const llvm::StringRef &a, const llvm::StringRef &b) { return a == b; }; // TODO: fix bool return if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) return res; if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) return res; + if(auto res = constFoldBinaryOp(getLoc(), IntegerType::get(getContext(), 64, IntegerType::SignednessSemantics::Signed), operands, strOp)) + return res; return {}; } @@ -1278,30 +1239,6 @@ struct SimplifyDistributeRead : public mlir::OpRewritePattern(lhs.getType()); - const bool rhsIsStr = llvm::isa(rhs.getType()); - - if (!lhsIsStr && !rhsIsStr) return mlir::failure(); - - mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext()); - if (!lhsIsStr) - lhs = rewriter.create(op.getLoc(), strTy, lhs); - if (!rhsIsStr) - rhs = rewriter.create(op.getLoc(), strTy, rhs); - - rewriter.replaceOpWithNewOp( - op, rewriter.getI1Type(), lhs, rhs); - return mlir::success(); -} - /** * @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string, * and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame). @@ -1331,7 +1268,7 @@ mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize( lhs = rewriter.create(op.getLoc(), strTy, lhs); if(!rhsIsStr) rhs = rewriter.create(op.getLoc(), strTy, rhs); - rewriter.replaceOpWithNewOp(op, strTy, lhs, rhs); + rewriter.replaceOpWithNewOp(op, strTy, lhs, rhs); return mlir::success(); } else { diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 63f1cc671..de6095bfb 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -338,21 +338,6 @@ def Daphne_EwBitwiseAndOp : Daphne_EwBinaryOp<"ewBitwiseAnd", NumScalar, [Com def Daphne_EwConcatOp : Daphne_EwBinaryOp<"ewConcat", StrScalar, [CastArgsToResType]>; -// TODO This is just a quick and dirty solution that should be properly -// integrated with EwConcatOp above. -def Daphne_ConcatOp : Daphne_Op<"concat", [DataTypeSca, ValueTypeStr]> { - let arguments = (ins StrScalar:$lhs, StrScalar:$rhs); - let results = (outs StrScalar:$res); - - let hasFolder = 1; -} - -def Daphne_StringEqOp : Daphne_Op<"stringEq", [ValueTypeStr]> { - let arguments = (ins StrScalar:$lhs, StrScalar:$rhs); - let results = (outs BoolScalar:$res); - let hasFolder = 1; -} - // ---------------------------------------------------------------------------- // Comparisons // ---------------------------------------------------------------------------- @@ -364,9 +349,7 @@ class Daphne_EwCmpOp traits = []> //let results = (outs AnyTypeOf<[MatrixOf<[BoolScalar]>, BoolScalar, Unknown]>:$res); } -def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]> { - let hasCanonicalizeMethod = 1; -} +def Daphne_EwEqOp : Daphne_EwCmpOp<"ewEq" , AnyScalar, [Commutative]>; def Daphne_EwNeqOp : Daphne_EwCmpOp<"ewNeq", AnyScalar, [Commutative, CUDASupport]>; def Daphne_EwLtOp : Daphne_EwCmpOp<"ewLt" , AnyScalar>; def Daphne_EwLeOp : Daphne_EwCmpOp<"ewLe" , AnyScalar>; diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index 7866c74f9..d6a80d105 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -566,7 +566,7 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f if(func == "concat") { checkNumArgsExact(loc, func, numArgs, 2); - return static_cast(builder.create( + return static_cast(builder.create( loc, StringType::get(builder.getContext()), args[0], args[1] )); } diff --git a/src/parser/daphnedsl/DaphneDSLVisitor.cpp b/src/parser/daphnedsl/DaphneDSLVisitor.cpp index 6b5687c79..697b0bde2 100644 --- a/src/parser/daphnedsl/DaphneDSLVisitor.cpp +++ b/src/parser/daphnedsl/DaphneDSLVisitor.cpp @@ -1338,11 +1338,11 @@ antlrcpp::Any DaphneDSLVisitor::visitAddExpr(DaphneDSLGrammarParser::AddExprCont if(op == "+") // Note that we use '+' for both addition (EwAddOp) and concatenation - // (ConcatOp). The choice is made based on the types of the operands - // (if one operand is a string, we choose ConcatOp). However, the types + // (EwConcatOp). The choice is made based on the types of the operands + // (if one operand is a string, we choose EwConcatOp). However, the types // might not be known at this point in time. Thus, we always create an // EwAddOp here. Note that EwAddOp has a canonicalize method rewriting - // it to ConcatOp if necessary. + // it to EwConcatOp if necessary. return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); if(op == "-") return utils.retValWithInferedType(builder.create(loc, lhs, rhs)); diff --git a/src/runtime/local/datastructures/ValueTypeUtils.h b/src/runtime/local/datastructures/ValueTypeUtils.h index 390d87d36..73ba5a0a3 100644 --- a/src/runtime/local/datastructures/ValueTypeUtils.h +++ b/src/runtime/local/datastructures/ValueTypeUtils.h @@ -70,6 +70,7 @@ template<> const std::string ValueTypeUtils::cppNameFor; template<> const std::string ValueTypeUtils::cppNameFor; template<> const std::string ValueTypeUtils::cppNameFor; template<> const std::string ValueTypeUtils::cppNameFor; +template<> const std::string ValueTypeUtils::cppNameFor; template<> const std::string ValueTypeUtils::irNameFor; template<> const std::string ValueTypeUtils::irNameFor; diff --git a/src/runtime/local/kernels/BinaryOpCode.h b/src/runtime/local/kernels/BinaryOpCode.h index 39c022a69..cbc5832e2 100644 --- a/src/runtime/local/kernels/BinaryOpCode.h +++ b/src/runtime/local/kernels/BinaryOpCode.h @@ -44,6 +44,8 @@ enum class BinaryOpCode { OR, // Bitwise. BITWISE_AND, + // Strings. + CONCAT }; /** @@ -63,7 +65,9 @@ static std::string_view binary_op_codes[] = { // Logical. "AND", "OR", // Bitwise. - "BITWISE_AND" + "BITWISE_AND", + // Strings. + "CONCAT" }; // **************************************************************************** @@ -147,6 +151,8 @@ SUPPORT_NUMERIC_INT(int8_t) SUPPORT_NUMERIC_INT(uint64_t) SUPPORT_NUMERIC_INT(uint32_t) SUPPORT_NUMERIC_INT(uint8_t) +template<> constexpr bool supportsBinaryOp = true; +template<> constexpr bool supportsBinaryOp = true; // Undefine helper macros. #undef SUPPORT diff --git a/src/runtime/local/kernels/Concat.h b/src/runtime/local/kernels/Concat.h deleted file mode 100644 index b090bb4aa..000000000 --- a/src/runtime/local/kernels/Concat.h +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright 2021 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include - -// **************************************************************************** -// Convenience function -// **************************************************************************** - -void concat(char *& res, const char * lhs, const char * rhs, DCTX(ctx)) { - const auto lenLhs = std::string_view(lhs).size(); - const auto lenRhs = std::string_view(rhs).size(); - const auto lenRes = lenLhs + lenRhs; - - if(res == nullptr) - res = new char[lenRes + 1]; - - std::memcpy(res , lhs, lenLhs); - std::memcpy(res + lenLhs, rhs, lenRhs); - res[lenRes] = '\0'; -} diff --git a/src/runtime/local/kernels/EwBinarySca.h b/src/runtime/local/kernels/EwBinarySca.h index 5b49647dd..99c0ab414 100644 --- a/src/runtime/local/kernels/EwBinarySca.h +++ b/src/runtime/local/kernels/EwBinarySca.h @@ -91,6 +91,8 @@ EwBinaryScaFuncPtr getEwBinaryScaFuncPtr(BinaryOpCode opCod // Logical. MAKE_CASE(BinaryOpCode::AND) MAKE_CASE(BinaryOpCode::OR) + // Strings. + MAKE_CASE(BinaryOpCode::CONCAT) #undef MAKE_CASE default: throw std::runtime_error( @@ -162,12 +164,35 @@ MAKE_EW_BINARY_SCA(BinaryOpCode::LT , lhs < rhs) MAKE_EW_BINARY_SCA(BinaryOpCode::LE , lhs <= rhs) MAKE_EW_BINARY_SCA(BinaryOpCode::GT , lhs > rhs) MAKE_EW_BINARY_SCA(BinaryOpCode::GE , lhs >= rhs) +template +struct EwBinarySca { + inline static TRes apply(const char * lhs, const char * rhs, DCTX(ctx)) { + return std::string_view(lhs) == std::string_view(rhs); + } +}; // Min/max. MAKE_EW_BINARY_SCA(BinaryOpCode::MIN, std::min(lhs, rhs)) MAKE_EW_BINARY_SCA(BinaryOpCode::MAX, std::max(lhs, rhs)) // Logical. MAKE_EW_BINARY_SCA(BinaryOpCode::AND, lhs && rhs) MAKE_EW_BINARY_SCA(BinaryOpCode::OR , lhs || rhs) +// Strings. +template<> +struct EwBinarySca { + inline static const char * apply(const char * lhs, const char * rhs, DCTX(ctx)) { + const auto lenLhs = std::string_view(lhs).size(); + const auto lenRhs = std::string_view(rhs).size(); + const auto lenRes = lenLhs + lenRhs; + + char* res = new char[lenRes + 1]; + + std::memcpy(res , lhs, lenLhs); + std::memcpy(res + lenLhs, rhs, lenRhs); + res[lenRes] = '\0'; + + return res; + } +}; #undef MAKE_EW_BINARY_SCA diff --git a/src/runtime/local/kernels/StringEq.h b/src/runtime/local/kernels/StringEq.h deleted file mode 100644 index 14df80743..000000000 --- a/src/runtime/local/kernels/StringEq.h +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2024 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include - -inline bool stringEq(const char *lhs, const char *rhs, DCTX(ctx)) { - return std::string_view(lhs) == std::string_view(rhs); -} diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 8465cbef7..10fa02e2f 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -865,52 +865,6 @@ } ] }, - { - "kernelTemplate": { - "header": "StringEq.h", - "opName": "stringEq", - "returnType": "bool", - "templateParams": [], - "runtimeParams": [ - { - "type": "const char *", - "name": "lhs" - }, - { - "type": "const char *", - "name": "rhs" - } - ] - }, - "instantiations": [ - [] - ] - }, - { - "kernelTemplate": { - "header": "Concat.h", - "opName": "concat", - "returnType": "void", - "templateParams": [], - "runtimeParams": [ - { - "type": "char *&", - "name": "res" - }, - { - "type": "const char *", - "name": "lhs" - }, - { - "type": "const char *", - "name": "rhs" - } - ] - }, - "instantiations": [ - [] - ] - }, { "kernelTemplate": { "header": "CreateDaphneContext.h", @@ -1430,9 +1384,11 @@ ["int64_t", "int64_t", "int64_t"], ["uint64_t", "uint64_t", "uint64_t"], ["uint32_t", "uint32_t", "uint32_t"], - ["size_t", "size_t", "size_t"] + ["size_t", "size_t", "size_t"], + ["const char *", "const char *", "const char *"], + ["int64_t", "const char *", "const char *"] ], - "opCodes": ["ADD", "SUB", "MUL", "DIV", "POW", "LOG", "MOD", "EQ", "NEQ", "LT", "LE", "GT", "GE", "MIN", "MAX", "AND", "OR", "BITWISE_AND"] + "opCodes": ["ADD", "SUB", "MUL", "DIV", "POW", "LOG", "MOD", "EQ", "NEQ", "LT", "LE", "GT", "GE", "MIN", "MAX", "AND", "OR", "BITWISE_AND", "CONCAT"] }, { "kernelTemplate": { diff --git a/test/codegen/stringeq.mlir b/test/codegen/stringeq.mlir deleted file mode 100644 index 3a5d18fe8..000000000 --- a/test/codegen/stringeq.mlir +++ /dev/null @@ -1,43 +0,0 @@ -// RUN: daphne-opt --canonicalize %s | FileCheck %s - -func.func @string_string() { - %0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String - %1 = "daphne.constant"() {value = "debug"} : () -> !daphne.String - // CHECK-NOT: daphne.ewEq - %2 = "daphne.ewEq"(%0, %1) : (!daphne.String, !daphne.String) -> !daphne.String - %3 = "daphne.cast"(%2) : (!daphne.String) -> i1 - "daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> () - "daphne.return"() : () -> () -} - -func.func @string_int() { - // CHECK-NOT: daphne.eqEq - // CHECK: daphne.cast - // CHECK: daphne.stringEq - %0 = "daphne.constant"() {value = "debug"} : () -> !daphne.String - %1 = "daphne.constant"() {value = 5 : si64} : () -> si64 - %2 = "daphne.ewEq"(%0, %1) : (!daphne.String, si64) -> !daphne.String - %3 = "daphne.cast"(%2) : (!daphne.String) -> i1 - "daphne.print"(%0, %3, %3) : (!daphne.String, i1, i1) -> () - "daphne.return"() : () -> () -} - -func.func @int_int_do_not_canonicalize() { - %0 = "daphne.constant"() {value = 2 : si64} : () -> si64 - %1 = "daphne.constant"() {value = 5 : si64} : () -> si64 - %2 = "daphne.ewEq"(%0, %1) : (si64, si64) -> si64 - // CHECK-NOT: daphne.stringEq - %3 = "daphne.cast"(%2) : (si64) -> i1 - scf.if %3 { - %4 = "daphne.constant"() {value = "debug"} : () -> !daphne.String - %5 = "daphne.constant"() {value = true} : () -> i1 - %6 = "daphne.constant"() {value = false} : () -> i1 - "daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> () - } else { - %4 = "daphne.constant"() {value = "release"} : () -> !daphne.String - %5 = "daphne.constant"() {value = true} : () -> i1 - %6 = "daphne.constant"() {value = false} : () -> i1 - "daphne.print"(%4, %5, %6) : (!daphne.String, i1, i1) -> () - } - "daphne.return"() : () -> () -}