From 80330ffc60e54766b87decb5dd4b7da45c48ce73 Mon Sep 17 00:00:00 2001 From: Garic152 Date: Fri, 19 Jul 2024 07:22:31 +0200 Subject: [PATCH 1/2] feat: added new SparsityOp --- src/ir/daphneir/DaphneDialect.cpp | 22 ++++++++++++++ src/ir/daphneir/DaphneOps.td | 7 +++++ src/parser/daphnedsl/DaphneDSLBuiltins.cpp | 6 ++++ src/runtime/local/kernels/Sparsity.h | 31 +++++++++++++++++++ src/runtime/local/kernels/kernels.json | 35 ++++++++++++++++++++++ 5 files changed, 101 insertions(+) create mode 100644 src/runtime/local/kernels/Sparsity.h diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index a2e75e32a..52cf2cb5e 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -1176,6 +1176,28 @@ mlir::LogicalResult mlir::daphne::NumCellsOp::canonicalize( return mlir::failure(); } +/** + * @brief Replaces SparsityOp by a constant, if the sparsity of the input is known + * (e.g., due to sparsity inference). + */ +mlir::LogicalResult mlir::daphne::SparsityOp::canonicalize( + mlir::daphne::SparsityOp op, PatternRewriter &rewriter +) { + double sparsity = -1.0; + + mlir::Type inTy = op.getArg().getType(); + if(auto t = inTy.dyn_cast()) + sparsity = t.getSparsity(); + + if(sparsity != -1) { + rewriter.replaceOpWithNewOp( + op, rewriter.getF64Type(), rewriter.getFloatAttr(rewriter.getF64Type(), sparsity) + ); + return mlir::success(); + } + return mlir::failure(); +} + /** * @brief Replaces a `DistributeOp` by a `DistributedReadOp`, if its input * value (a) is defined by a `ReadOp`, and (b) is not used elsewhere. diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 471c1ba9a..d5b507a5e 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -161,6 +161,13 @@ def Daphne_SeqOp : Daphne_Op<"seq", [ // Matrix/frame dimensions // **************************************************************************** +def SparsityOp : Daphne_Op<"sparsity", [DataTypeSca]> { + let arguments = (ins MatrixOf<[AnyScalar]>:$arg); + let results = (outs FloatScalar:$res); + + let hasCanonicalizeMethod = 1; +} + class Daphne_NumOp traits = []> : Daphne_Op { diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index d7bf6d483..9d7906e10 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -464,6 +464,12 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f return createNumOp(loc, func, args); if(func == "ncell") return createNumOp(loc, func, args); + + if(func == "sparsity") { + checkNumArgsExact(loc, func, numArgs, 1); + mlir::Value arg = args[0]; + return static_cast(builder.create(loc, builder.getF64Type(), arg)); + } // ******************************************************************** // Elementwise unary diff --git a/src/runtime/local/kernels/Sparsity.h b/src/runtime/local/kernels/Sparsity.h new file mode 100644 index 000000000..87034a988 --- /dev/null +++ b/src/runtime/local/kernels/Sparsity.h @@ -0,0 +1,31 @@ +/* + * Copyright 2021 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_RUNTIME_LOCAL_KERNELS_SPARSITY_H +#define SRC_RUNTIME_LOCAL_KERNELS_SPARSITY_H + +#include + +// **************************************************************************** +// Convenience function +// **************************************************************************** + +template +double sparsity(const DTArg * arg, DCTX(ctx)) { + return -1.0; +} + +#endif //SRC_RUNTIME_LOCAL_KERNELS_SPARSITY_H \ No newline at end of file diff --git a/src/runtime/local/kernels/kernels.json b/src/runtime/local/kernels/kernels.json index 8ac24bfda..a7b616e0a 100644 --- a/src/runtime/local/kernels/kernels.json +++ b/src/runtime/local/kernels/kernels.json @@ -2946,6 +2946,41 @@ "opCodes": ["MINUS", "SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN", "SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH", "ISNAN"] }, + { + "kernelTemplate": { + "header": "Sparsity.h", + "opName": "sparsity", + "returnType": "double", + "templateParams": [ + { + "name": "DTArg", + "isDataType": true + } + ], + "runtimeParams": [ + { + "type": "const DTArg *", + "name": "arg" + } + ] + }, + "instantiations": [ + [["DenseMatrix", "double"]], + [["DenseMatrix", "float"]], + [["DenseMatrix", "int64_t"]], + [["DenseMatrix", "int32_t"]], + [["DenseMatrix", "int8_t"]], + [["DenseMatrix", "uint64_t"]], + [["DenseMatrix", "uint32_t"]], + [["DenseMatrix", "uint8_t"]], + [["DenseMatrix", "bool"]], + [["DenseMatrix", "size_t"]], + [["CSRMatrix", "double"]], + [["CSRMatrix", "float"]], + [["CSRMatrix", "int64_t"]], + [["CSRMatrix", "uint8_t"]] + ] + }, { "kernelTemplate": { "header": "NumCols.h", From 014d09835b24766b214e619ba8535b2438d41f23 Mon Sep 17 00:00:00 2001 From: Patrick Damme Date: Mon, 5 Aug 2024 22:51:44 +0200 Subject: [PATCH 2/2] Minor polishing. --- doc/DaphneDSL/Builtins.md | 5 +++++ src/ir/daphneir/DaphneDialect.cpp | 2 +- src/ir/daphneir/DaphneOps.td | 14 +++++++------- src/parser/daphnedsl/DaphneDSLBuiltins.cpp | 1 - 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/doc/DaphneDSL/Builtins.md b/doc/DaphneDSL/Builtins.md index ba02e5421..438f84aa3 100644 --- a/doc/DaphneDSL/Builtins.md +++ b/doc/DaphneDSL/Builtins.md @@ -100,6 +100,11 @@ The following built-in functions allow to find out the shape/dimensions of matri Returns the number of cells in `arg`. This is the product of the number of rows and the number of columns. + +- **`sparsity`**`(arg:matrix)` + + Returns the DAPHNE compiler's *estimate* of the argument's sparsity. + Note that this value may deviate from the *actual* sparsity of the data at run-time. ## Elementwise unary diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 52cf2cb5e..31b006dd2 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -1191,7 +1191,7 @@ mlir::LogicalResult mlir::daphne::SparsityOp::canonicalize( if(sparsity != -1) { rewriter.replaceOpWithNewOp( - op, rewriter.getF64Type(), rewriter.getFloatAttr(rewriter.getF64Type(), sparsity) + op, sparsity ); return mlir::success(); } diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index d5b507a5e..647dc7375 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -161,13 +161,6 @@ def Daphne_SeqOp : Daphne_Op<"seq", [ // Matrix/frame dimensions // **************************************************************************** -def SparsityOp : Daphne_Op<"sparsity", [DataTypeSca]> { - let arguments = (ins MatrixOf<[AnyScalar]>:$arg); - let results = (outs FloatScalar:$res); - - let hasCanonicalizeMethod = 1; -} - class Daphne_NumOp traits = []> : Daphne_Op { @@ -181,6 +174,13 @@ def Daphne_NumRowsOp : Daphne_NumOp<"numRows">; def Daphne_NumColsOp : Daphne_NumOp<"numCols">; def Daphne_NumCellsOp : Daphne_NumOp<"numCells">; +def SparsityOp : Daphne_Op<"sparsity", [DataTypeSca]> { + let arguments = (ins MatrixOf<[AnyScalar]>:$arg); + let results = (outs FloatScalar:$res); + + let hasCanonicalizeMethod = 1; +} + // **************************************************************************** // Matrix multiplication // **************************************************************************** diff --git a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp index 9d7906e10..194d63523 100644 --- a/src/parser/daphnedsl/DaphneDSLBuiltins.cpp +++ b/src/parser/daphnedsl/DaphneDSLBuiltins.cpp @@ -464,7 +464,6 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f return createNumOp(loc, func, args); if(func == "ncell") return createNumOp(loc, func, args); - if(func == "sparsity") { checkNumArgsExact(loc, func, numArgs, 1); mlir::Value arg = args[0];