diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..ed5e6cc42 --- /dev/null +++ b/.clang-format @@ -0,0 +1,3 @@ +BasedOnStyle: LLVM +IndentWidth: 4 +ColumnLimit: 120 diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..ab68da3aa --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,12 @@ +# These commits are ignored by `git blame`. +# https://git-scm.com/docs/git-blame +# Run this command to configure git to use this file. +# `$ git config blame.ignoreRevsFile .git-blame-ignore-revs` + +# clang-format src/ +cb999f20b6f2934ad7c94b10d2b02f6acf74aab4 +b8d91db545fba0f2e85070dc438d2447528b619e + +# clang-format test/ +9a1b93e4fea27e91d14c18821be2c940adf63bf4 +68f7925044ba8777f6a7f41bf5704915de13f608 diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml new file mode 100644 index 000000000..35ddc05b0 --- /dev/null +++ b/.github/workflows/clang-format-check.yml @@ -0,0 +1,20 @@ +name: check-clang-format + +on: + push: + branches: + - main + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: DoozyX/clang-format-lint-action@v0.18.1 + with: + source: 'src test' + exclude: './third_party ./external' + extensions: 'h,cpp' + clangFormatVersion: 18.1.3 diff --git a/CMakeLists.txt b/CMakeLists.txt index e25b7642e..fb57026ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.20) +cmake_minimum_required(VERSION 3.25.2) # Build release version by default (override with -DCMAKE_BUILD_TYPE=Debug in your initial cmake invocation) # This needs to be set *before* the project() command @@ -32,6 +32,7 @@ set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ standard to conform to") set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_OPTIMIZE_DEPENDENCIES 1) set(CMAKE_CXX_FLAGS_DEBUG="${CMAKE_CXX_FLAGS_DEBUG} -g") set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") @@ -112,12 +113,17 @@ if(USE_PAPI) endif() ########## +find_package(fmt) +add_definitions(-DSPDLOG_FMT_EXTERNAL) + option(USE_CUDA "Whether to activate compilation of CUDA features" OFF) include(CheckLanguage) check_language(CUDA) if(USE_CUDA AND CMAKE_CUDA_COMPILER) enable_language(CUDA) find_package(CUDAToolkit REQUIRED) + set(CMAKE_CUDA_STANDARD 20) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) if(${CMAKE_COMPILER_IS_GNUCXX}) set(GCC_EXPECTED_VERSION 11.3.0) @@ -137,7 +143,7 @@ if(USE_CUDA AND CMAKE_CUDA_COMPILER) add_definitions(-DUSE_CUDA) message(STATUS "Note: disabled CUSPARSE_DEPRECATED in main CMakeLists.txt") add_definitions(-DDISABLE_CUSPARSE_DEPRECATED) - set(CMAKE_CUDA_STANDARD 17) + set(CMAKE_CUDA_STANDARD 20) set(CMAKE_CUDA_STANDARD_REQUIRED ON) message(STATUS "CUDA enabled (version ${CMAKE_CUDA_COMPILER_VERSION})") if(DEFINED ENV{CUDAHOSTCXX}) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 56ba82fe0..5e27ad91b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,9 +79,13 @@ That is, please try your best to make a good-quality contribution and we will he Please choose an expressive title and provide a short description of your changes. Feel free to mark your pull request "WIP: " or "Draft: " in the title. Note that you can add more commits to your pull request after you created it. -7. 
You **receive feedback** on your proposed contribution. + Ideally, the PR contains only the changes you made for it, + e.g., by rebasing your branch on top of the target branch. This makes it easier for others to + review your PR. +7. [Resolve any open conflicts](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/addressing-merge-conflicts/about-merge-conflicts) with the target branch of the PR. +8. You **receive feedback** on your proposed contribution. You may be asked to apply certain changes, or we might apply straightforward adjustments ourselves before the integration. -8. If it looks good (potentially after some help), **your contribution becomes a part of DAPHNE**. +9. If it looks good (potentially after some help), **your contribution becomes a part of DAPHNE**. ### Experienced DAPHNE Contributors (Collaborators) @@ -108,3 +112,18 @@ At the same time, this freedom comes with certain responsibilities, which are ro - actually merging a pull request in Balancing the handling of pull requests is important to *keep the development process scalable*. + + +### Code Style + +Before contributing, please make sure to run `clang-format` on your C++ (.h and +.cpp) files. The codebase is currently formatted with `clang-format` version +`18.1.3`. This is the default `clang-format` version when installing via `apt` +on Ubuntu 24.04, and it can easily be installed via `python -m pip install clang-format==18.1.3` +on other systems. +We provide a `.clang-format` file at the root of the repository. Most text +editors and IDEs have some kind of integration for detecting that file +and automatically applying `clang-format`. `git-clang-format` can be used to +format staged files. +For more information about `clang-format`, `git-clang-format`, and text editor +integration, please see [ClangFormat](https://clang.llvm.org/docs/ClangFormat.html). diff --git a/build.sh b/build.sh index 7b0813dcf..5bdf6203f 100755 --- a/build.sh +++ b/build.sh @@ -60,6 +60,7 @@ function printHelp { echo " --fpgaopencl Compile with support for Intel PAC D5005 FPGA" echo " --mpi Compile with support for MPI" echo " --hdfs Compile with support for HDFS" + echo " --io-uring Compile with support for io_uring" echo " --no-papi Compile without support for PAPI" } @@ -451,6 +452,7 @@ BUILD_FPGAOPENCL="-DUSE_FPGAOPENCL=OFF" BUILD_DEBUG="-DCMAKE_BUILD_TYPE=Release" BUILD_MPI="-DUSE_MPI=OFF" BUILD_HDFS="-DUSE_HDFS=OFF" +BUILD_IO_URING="-DUSE_IO_URING=OFF" BUILD_PAPI="-DUSE_PAPI=ON" WITH_DEPS=1 WITH_SUBMODULE_UPDATE=1 @@ -504,6 +506,10 @@ while [[ $# -gt 0 ]]; do echo using HDFS export BUILD_HDFS="-DUSE_HDFS=ON" ;; + --io-uring) + echo using io_uring + export BUILD_IO_URING="-DUSE_IO_URING=ON" + ;; --no-papi) echo not using PAPI export BUILD_PAPI="-DUSE_PAPI=OFF" ;; @@ -655,7 +661,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.3 Antlr4 (parser) + # Antlr4 (parser) #------------------------------------------------------------------------------ antlrJarName="antlr-${antlrVersion}-complete.jar" @@ -708,7 +714,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.4 catch2 (unit test framework) + # catch2 (unit test framework) #------------------------------------------------------------------------------ # Download catch2 release zip (if necessary), and unpack the single header file # (if necessary).
@@ -734,7 +740,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.5 OpenBLAS (basic linear algebra subprograms) + # OpenBLAS (basic linear algebra subprograms) #------------------------------------------------------------------------------ openBlasDirName="OpenBLAS-$openBlasVersion" @@ -744,7 +750,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then if ! is_dependency_downloaded "${dep_openBlas[@]}"; then daphne_msg "Get OpenBlas version ${openBlasVersion}" - wget "https://github.com/xianyi/OpenBLAS/releases/download/v${openBlasVersion}/${openBlasZipName}" \ + wget "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${openBlasVersion}/${openBlasZipName}" \ -qO "${cacheDir}/${openBlasZipName}" unzip -q "$cacheDir/$openBlasZipName" -d "$sourcePrefix" dependency_download_success "${dep_openBlas[@]}" @@ -761,7 +767,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.6 nlohmann/json (library for JSON parsing) + # nlohmann/json (library for JSON parsing) #------------------------------------------------------------------------------ nlohmannjsonDirName=nlohmannjson @@ -779,7 +785,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.7 abseil (compiled separately to apply a patch) + # abseil (compiled separately to apply a patch) #------------------------------------------------------------------------------ abslPath=$sourcePrefix/abseil-cpp @@ -808,7 +814,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.8 MPI (Default is MPI library is OpenMPI but cut can be any) + # MPI (the default MPI library is OpenMPI, but it can be any) #------------------------------------------------------------------------------ MPIZipName=openmpi-$openMPIVersion.tar.gz @@ -834,7 +840,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.9 gRPC + # gRPC #------------------------------------------------------------------------------ grpcDirName="grpc" @@ -875,7 +881,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.10 Arrow / Parquet + # Arrow / Parquet #------------------------------------------------------------------------------ arrowDirName="apache-arrow-$arrowVersion" @@ -913,22 +919,44 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # 8.11 spdlog + # fmt + #------------------------------------------------------------------------------ + + fmtDirName="fmt-$fmtVersion" + fmtArtifactFileName=$fmtDirName.zip + if ! is_dependency_downloaded "fmt_v${fmtVersion}"; then + rm -rf "${sourcePrefix:?}/${fmtDirName}" + wget "https://github.com/fmtlib/fmt/releases/download/${fmtVersion}/$fmtArtifactFileName" -qO "$cacheDir/$fmtArtifactFileName" + unzip -q "$cacheDir/$fmtArtifactFileName" -d "$sourcePrefix" + dependency_download_success "fmt_v${fmtVersion}" + fi + if !
is_dependency_installed "fmt_v${fmtVersion}"; then + cmake -G Ninja -S "${sourcePrefix}/${fmtDirName}" -B "${buildPrefix}/${fmtDirName}" \ + -DCMAKE_INSTALL_PREFIX="${installPrefix}" -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DFMT_MASTER_PROJECT=OFF + cmake --build "${buildPrefix}/${fmtDirName}" --target install/strip + dependency_install_success "fmt_v${fmtVersion}" + else + daphne_msg "No need to build fmt again." + fi + + #------------------------------------------------------------------------------ + # spdlog #------------------------------------------------------------------------------ spdlogDirName="spdlog-$spdlogVersion" spdlogArtifactFileName=$spdlogDirName.tar.gz if ! is_dependency_downloaded "spdlog_v${spdlogVersion}"; then rm -rf "${sourcePrefix:?}/${spdlogDirName}" - wget "https://github.com/gabime/spdlog/archive/refs/tags/v$spdlogVersion.tar.gz" -qO \ + # changed URL scheme due to temporarily use tip of main branch (2024-10-03) +# wget "https://github.com/gabime/spdlog/archive/refs/tags/v$spdlogVersion.tar.gz" -qO \ + wget https://github.com/gabime/spdlog/archive/$spdlogVersion.tar.gz -qO \ "$cacheDir/$spdlogArtifactFileName" tar xzf "$cacheDir/$spdlogArtifactFileName" --directory="$sourcePrefix" dependency_download_success "spdlog_v${spdlogVersion}" fi - if ! is_dependency_installed "spdlog_v${spdlogVersion}"; then cmake -G Ninja -S "${sourcePrefix}/${spdlogDirName}" -B "${buildPrefix}/${spdlogDirName}" \ - -DCMAKE_INSTALL_PREFIX="${installPrefix}" -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DSPDLOG_FMT_EXTERNAL=ON -DCMAKE_INSTALL_PREFIX="${installPrefix}" -DCMAKE_POSITION_INDEPENDENT_CODE=ON cmake --build "${buildPrefix}/${spdlogDirName}" --target install/strip dependency_install_success "spdlog_v${spdlogVersion}" else @@ -936,7 +964,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # 8.12 Eigen + # Eigen #------------------------------------------------------------------------------ eigenDirName="eigen-${eigenVersion}" @@ -957,7 +985,38 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.13 Build MLIR + # HAWQ (libhdfs3) + #------------------------------------------------------------------------------ + + hawqDirName="hawq-rel-v$hawqVersion" + hawqDlTarName="v${hawqVersion}.tar.gz" + hawqTarName="${hawqDirName}.tar.gz" + hawqInstDirName=$installPrefix + + if [ $BUILD_HDFS == "-DUSE_HDFS=ON" ]; then + if ! is_dependency_downloaded "hawq_v${hawqVersion}"; then + daphne_msg "Get HAWQ (libhdfs3) version ${hawqVersion}" + wget "https://github.com/apache/hawq/archive/refs/tags/rel/${hawqDlTarName}" \ + -qO "${cacheDir}/${hawqTarName}" + tar -xf "$cacheDir/$hawqTarName" -C "$sourcePrefix" + daphne_msg "Applying 0005-libhdfs3-remove-gtest-dep.patch" + patch -Np1 -i "${patchDir}/0005-libhdfs3-remove-gtest-dep.patch" -d "$sourcePrefix/$hawqDirName" + daphne_msg "Applying 0006-libhdfs3-add-cstdint-include.patch" + patch -Np1 -i "${patchDir}/0006-libhdfs3-add-cstdint-include.patch" -d "$sourcePrefix/$hawqDirName" + dependency_download_success "hawq_v${hawqVersion}" + fi + if ! 
is_dependency_installed "hawq_v${hawqVersion}"; then + cmake -G Ninja -S "$sourcePrefix/$hawqDirName/depends/libhdfs3" -B "${buildPrefix}/${hawqDirName}" \ + -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$installPrefix" + cmake --build "${buildPrefix}/${hawqDirName}" --target install/strip + dependency_install_success "hawq_v${hawqVersion}" + else + daphne_msg "No need to build HAWQ (libhdfs3) again." + fi + fi + + #------------------------------------------------------------------------------ + # Build MLIR #------------------------------------------------------------------------------ # We rarely need to build MLIR/LLVM, only during the first build of the # prototype and after upgrades of the LLVM sub-module. To avoid unnecessary @@ -1016,34 +1075,7 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then fi #------------------------------------------------------------------------------ - # #8.14 HAWQ (libhdfs3) - #------------------------------------------------------------------------------ - - hawqDirName="hawq-rel-v$hawqVersion" - hawqTarName="v${hawqVersion}.tar.gz" - hawqInstDirName=$installPrefix - if ! is_dependency_downloaded "hawq_v${hawqVersion}"; then - daphne_msg "Get HAWQ (libhdfs3) version ${hawqVersion}" - wget "https://github.com/apache/hawq/archive/refs/tags/rel/${hawqTarName}" \ - -qO "${cacheDir}/${hawqTarName}" - tar -xf "$cacheDir/$hawqTarName" -C "$sourcePrefix" - daphne_msg "Applying 0005-libhdfs3-remove-gtest-dep.patch" - patch -Np1 -i "${patchDir}/0005-libhdfs3-remove-gtest-dep.patch" -d "$sourcePrefix/$hawqDirName" - daphne_msg "Applying 0006-libhdfs3-add-cstdint-include.patch" - patch -Np1 -i "${patchDir}/0006-libhdfs3-add-cstdint-include.patch" -d "$sourcePrefix/$hawqDirName" - dependency_download_success "hawq_v${hawqVersion}" - fi - if ! is_dependency_installed "hawq_v${hawqVersion}"; then - cmake -G Ninja -S "$sourcePrefix/$hawqDirName/depends/libhdfs3" -B "${buildPrefix}/${hawqDirName}" \ - -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX="$installPrefix" - cmake --build "${buildPrefix}/${hawqDirName}" --target install/strip - dependency_install_success "hawq_v${hawqVersion}" - else - daphne_msg "No need to build HAWQ (libhdfs3) again." - fi - - #------------------------------------------------------------------------------ - # 8.15 Liburing + # Liburing #------------------------------------------------------------------------------ liburingDirName="liburing-$liburingVersion" @@ -1053,28 +1085,30 @@ if [ $BUILD_PAPI == "-DUSE_PAPI=ON" ]; then liburing_cc=$([ "$CC" = "" ] && echo "gcc" || echo "$CC") liburing_cxx=$([ "$CXX" = "" ] && echo "g++" || echo "$CXX") - if ! is_dependency_downloaded "liburing_v${liburingVersion}"; then - daphne_msg "Get liburing version ${liburingVersion}" - wget "https://github.com/axboe/liburing/archive/refs/tags/${liburingTarName}" \ - -qO "${cacheDir}/${liburingTarName}" - mkdir "$sourcePrefix/$liburingDirName" - tar -xf "$cacheDir/$liburingTarName" -C "$sourcePrefix/$liburingDirName" --strip-components=1 - dependency_download_success "liburing_v${liburingVersion}" - fi - if ! is_dependency_installed "liburing_v${liburingVersion}"; then - cd "$sourcePrefix/$liburingDirName" - ./configure --cc="$liburing_cc" --cxx="$liburing_cxx" --prefix="$liburingInstDirName" - make -j"$(nproc)" - cp ./src/liburing.a "$installPrefix/lib/" - cp -r ./src/include/* "$installPrefix/include" - cd - > /dev/null - dependency_install_success "liburing_v${liburingVersion}" - else - daphne_msg "No need to build liburing again." 
+ if [ $BUILD_IO_URING == "-DUSE_IO_URING=ON" ]; then + if ! is_dependency_downloaded "liburing_v${liburingVersion}"; then + daphne_msg "Get liburing version ${liburingVersion}" + wget "https://github.com/axboe/liburing/archive/refs/tags/${liburingTarName}" \ + -qO "${cacheDir}/${liburingTarName}" + mkdir "$sourcePrefix/$liburingDirName" + tar -xf "$cacheDir/$liburingTarName" -C "$sourcePrefix/$liburingDirName" --strip-components=1 + dependency_download_success "liburing_v${liburingVersion}" + fi + if ! is_dependency_installed "liburing_v${liburingVersion}"; then + cd "$sourcePrefix/$liburingDirName" + ./configure --cc="$liburing_cc" --cxx="$liburing_cxx" --prefix="$liburingInstDirName" + make -j"$(nproc)" + cp ./src/liburing.a "$installPrefix/lib/" + cp -r ./src/include/* "$installPrefix/include" + cd - > /dev/null + dependency_install_success "liburing_v${liburingVersion}" + else + daphne_msg "No need to build liburing again." + fi fi #------------------------------------------------------------------------------ - # 8.16 Fetch bitstreams + # Fetch bitstreams #------------------------------------------------------------------------------ if [[ $BUILD_FPGAOPENCL = *"ON"* ]]; then diff --git a/containers/build-containers.sh b/containers/build-containers.sh index 10e65020e..42d842c38 100755 --- a/containers/build-containers.sh +++ b/containers/build-containers.sh @@ -85,6 +85,7 @@ DAPHNE_TARGET=daphne-deps BASE_IMAGE=ubuntu:${ubuntuVersion} DAPHNE_TAG=$TIMESTAMP_DATE_${ARCH} IMAGE_REPO=daphneeu/$DAPHNE_TARGET +DAPHNE_BUILD_FLAGS="--hdfs --mpi" #bulid deps stage build_daphne -deps @@ -106,7 +107,6 @@ BASE_IMAGE=ubuntu:${ubuntuVersion} DAPHNE_TAG=${TIMESTAMP_DATE}_${ARCH}_BASE_ubuntu${ubuntuVersion} IMAGE_REPO=daphneeu/$DAPHNE_TARGET build_daphne -dev - $USE_SUDO docker tag $IMAGE_REPO:$DAPHNE_TAG daphneeu/daphne-dev:latest_${ARCH}_BASE #------------------------------------------------------------------------------ @@ -118,19 +118,8 @@ BASE_IMAGE=nvidia/cuda:$CUDA_TAG DAPHNE_TAG=${TIMESTAMP_DATE}_${ARCH}_CUDA_${CUDA_TAG} IMAGE_REPO=daphneeu/$DAPHNE_TARGET build_daphne -dev - $USE_SUDO docker tag $IMAGE_REPO:$DAPHNE_TAG daphneeu/daphne-dev:latest_${ARCH}_CUDA -#----------------------------------------------------------------------------- -# Images for DAPHNE development (OneAPI) -#------------------------------------------------------------------------------ -#DAPHNE_TARGET=daphne-dev -#ONEAPI_TAG=2023.1.0-devel-ubuntu${ubuntuVersion} -#BASE_IMAGE=intel/oneapi:$ONEAPI_TAG -#DAPHNE_TAG=${TIMESTAMP_DATE}_${ONEAPI_TAG} -#IMAGE_REPO=daphneeu/$DAPHNE_TARGET -#build_daphne -dev - #------------------------------------------------------------------------------ # Images for running DAPHNE #------------------------------------------------------------------------------ @@ -139,7 +128,7 @@ BASE_IMAGE=daphneeu/daphne-deps FINAL_BASE_IMAGE=ubuntu:${ubuntuVersion} DAPHNE_TAG=${TIMESTAMP_DATE}_${ARCH}_BASE_ubuntu${ubuntuVersion} IMAGE_REPO=daphneeu/$DAPHNE_TARGET -DAPHNE_BUILD_FLAGS="--mpi" +DAPHNE_BUILD_FLAGS="--hdfs --mpi" build_daphne $USE_SUDO docker tag $IMAGE_REPO:$DAPHNE_TAG daphneeu/daphne:latest_${ARCH}_BASE @@ -152,8 +141,19 @@ DAPHNE_TAG=${TIMESTAMP_DATE}_${ARCH}_CUDA_${CUDA_TAG} IMAGE_REPO=daphneeu/$DAPHNE_TARGET BASE_IMAGE=daphneeu/daphne-dev FINAL_BASE_IMAGE=nvidia/cuda:$CUDA_TAG -DAPHNE_BUILD_FLAGS="--mpi --cuda" +DAPHNE_BUILD_FLAGS="--hdfs --mpi --cuda" build_daphne $USE_SUDO docker tag $IMAGE_REPO:$DAPHNE_TAG daphneeu/daphne:latest_${ARCH}_CUDA 
+#----------------------------------------------------------------------------- +# Images for conversion to singularity for DAPHNE compilation +#------------------------------------------------------------------------------ +DAPHNE_TARGET=daphne-dev-hpc +CUDA_TAG=${cudaVersion}-cudnn-devel-ubuntu${ubuntuVersion} +BASE_IMAGE=nvidia/cuda:$CUDA_TAG +DAPHNE_TAG=${TIMESTAMP_DATE}_${ARCH}_CUDA_${CUDA_TAG} +IMAGE_REPO=daphneeu/$DAPHNE_TARGET +build_daphne -dev-hpc +$USE_SUDO docker tag $IMAGE_REPO:$DAPHNE_TAG daphneeu/daphne-dev:latest_${ARCH}_HPC + set +e diff --git a/containers/daphne-deps.Dockerfile b/containers/daphne-deps.Dockerfile index b5b93de55..9a9131068 100644 --- a/containers/daphne-deps.Dockerfile +++ b/containers/daphne-deps.Dockerfile @@ -62,9 +62,10 @@ FROM build-cmake AS build ARG DAPHNE_DIR=/daphne ARG DAPHNE_REPO=https://github.com/daphne-eu/daphne.git ARG DAPHNE_BRANCH=main +ARG DAPHNE_BUILD_FLAGS="--mpi --hdfs" RUN git clone --depth=1 --single-branch --branch=$DAPHNE_BRANCH $DAPHNE_REPO $DAPHNE_DIR WORKDIR $DAPHNE_DIR -RUN ./build.sh --no-fancy --no-submodule-update --installPrefix /usr/local +RUN PATH=/usr/local/bin:$PATH LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH ./build.sh --no-fancy --no-submodule-update --installPrefix /usr/local $DAPHNE_BUILD_FLAGS RUN find /usr/local -exec file {} \; | grep -e "not stripped" | cut -d ":" -f 1 | xargs strip --strip-unneeded RUN rm -rf $DAPHNE_DIR RUN ldconfig diff --git a/containers/daphne-dev-hpc.Dockerfile b/containers/daphne-dev-hpc.Dockerfile new file mode 100644 index 000000000..5b4516d8e --- /dev/null +++ b/containers/daphne-dev-hpc.Dockerfile @@ -0,0 +1,51 @@ +# syntax=docker/dockerfile:1 + +# Copyright 2023 The DAPHNE Consortium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# This Dockerfile provides a basic DAPHNE compilation environment with all +# third-party dependencies precompiled (use `./build.sh --no-deps --installPrefix /usr/local` to compile DAPHNE) + +ARG BASE_IMAGE=ubuntu:20.04 +#ARG FINAL_BASE_IMAGE=ubuntu:20.04 +ARG CMAKE_VERSION=3.29.3 +ARG TIMESTAMP=0 +ARG TZ=Etc/UTC + +FROM ${BASE_IMAGE} AS daphne-dev-hpc +ARG DEBIAN_FRONTEND="noninteractive" +ARG TZ +RUN apt-get -qq -y update && apt-get -y upgrade && apt-get -y --no-install-recommends install \ + ca-certificates file git openssh-client unzip wget tar \ + libomp-dev libpfm4-dev libssl-dev libxml2-dev uuid-dev zlib1g-dev libgsasl-dev libkrb5-dev \ + build-essential clang gfortran lld llvm llvm-18-tools ninja-build openjdk-11-jdk-headless pkg-config python3-numpy python3-pandas \ + vim nano rsync sudo iputils-ping virtualenv openssh-server iproute2 git htop gdb lldb lld gpg-agent net-tools \ + software-properties-common ca-certificates file unzip wget tar zstd \ + ccache python3-pip python3-networkx python3-dev graphviz-dev clang-format \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +COPY --from=daphneeu/daphne-deps /usr/local/bin/ /usr/local/bin/ +COPY --from=daphneeu/daphne-deps /usr/local/include/ /usr/local/include/ +COPY --from=daphneeu/daphne-deps /usr/local/lib/ /usr/local/lib/ +COPY --from=daphneeu/daphne-deps /usr/local/share/ /usr/local/share/ +RUN ldconfig +# this is a temporary workaround to make the lit code (from the llvm-*-tools package) available to some pre-Ubuntu24 \ +# test cases when run locally in the dev container +RUN ln -s /usr/lib/llvm-18 /usr/lib/llvm-10 +RUN ln -fs /usr/share/zoneinfo/$TZ /etc/localtime +#COPY entrypoint-interactive.sh / +#RUN mkdir -p /var/run/sshd +#EXPOSE 22 +#ENTRYPOINT [ "/entrypoint-interactive.sh"] diff --git a/containers/daphne-dev.Dockerfile b/containers/daphne-dev.Dockerfile index d56fb8002..c50a7c375 100644 --- a/containers/daphne-dev.Dockerfile +++ b/containers/daphne-dev.Dockerfile @@ -33,8 +33,9 @@ RUN apt-get -qq -y update && apt-get -y upgrade && apt-get -y --no-install-recom build-essential clang gfortran lld llvm llvm-18-tools ninja-build openjdk-11-jdk-headless pkg-config python3-numpy python3-pandas \ vim nano rsync sudo iputils-ping virtualenv openssh-server iproute2 git htop gdb lldb lld gpg-agent net-tools \ software-properties-common ca-certificates file unzip wget tar zstd \ - ccache python3-pip python3-networkx python3-dev graphviz-dev \ + ccache python3-pip python3-networkx python3-dev graphviz-dev clang-format \ && apt-get clean && rm -rf /var/lib/apt/lists/* + COPY --from=daphneeu/daphne-deps /usr/local/bin/ /usr/local/bin/ COPY --from=daphneeu/daphne-deps /usr/local/include/ /usr/local/include/ COPY --from=daphneeu/daphne-deps /usr/local/lib/ /usr/local/lib/ diff --git a/containers/publish.sh b/containers/publish.sh index 0de241a25..cd6d32e06 100755 --- a/containers/publish.sh +++ b/containers/publish.sh @@ -44,8 +44,8 @@ fi $USE_SUDO docker push -a daphneeu/github-action # cuda dev image -$USE_SUDO docker tag daphneeu/daphne-dev:${TIMESTAMP_DATE}_${ARCH}_CUDA_${cudaVersion}-cudnn8-devel-ubuntu${ubuntuVersion} daphneeu/daphne-dev:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn8-devel-ubuntu${ubuntuVersion} -$USE_SUDO docker push daphneeu/daphne-dev:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn8-devel-ubuntu${ubuntuVersion} +$USE_SUDO docker tag daphneeu/daphne-dev:${TIMESTAMP_DATE}_${ARCH}_CUDA_${cudaVersion}-cudnn-devel-ubuntu${ubuntuVersion}
daphneeu/daphne-dev:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn-devel-ubuntu${ubuntuVersion} +$USE_SUDO docker push daphneeu/daphne-dev:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn-devel-ubuntu${ubuntuVersion} $USE_SUDO docker push daphneeu/daphne-dev:latest_${ARCH}_CUDA # base dev image @@ -54,8 +54,8 @@ $USE_SUDO docker push daphneeu/daphne-dev:${VERSION}_${ARCH}_BASE_ubuntu${ubuntu $USE_SUDO docker push daphneeu/daphne-dev:latest_${ARCH}_BASE # cuda run image -$USE_SUDO docker tag daphneeu/daphne:${TIMESTAMP_DATE}_${ARCH}_CUDA_${cudaVersion}-cudnn8-runtime-ubuntu${ubuntuVersion} daphneeu/daphne:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn8-runtime-ubuntu${ubuntuVersion} -$USE_SUDO docker push daphneeu/daphne:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn8-runtime-ubuntu${ubuntuVersion} +$USE_SUDO docker tag daphneeu/daphne:${TIMESTAMP_DATE}_${ARCH}_CUDA_${cudaVersion}-cudnn-runtime-ubuntu${ubuntuVersion} daphneeu/daphne:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn-runtime-ubuntu${ubuntuVersion} +$USE_SUDO docker push daphneeu/daphne:${VERSION}_${ARCH}_CUDA_${cudaVersion}-cudnn-runtime-ubuntu${ubuntuVersion} $USE_SUDO docker push daphneeu/daphne:latest_${ARCH}_CUDA # base run image diff --git a/doc/GettingStarted.md b/doc/GettingStarted.md index 07e69abbd..dd9eebd41 100644 --- a/doc/GettingStarted.md +++ b/doc/GettingStarted.md @@ -233,6 +233,7 @@ launching DAPHNE via Docker (see below) should work the same way as in a native | java (e.g. openjdk) | 11 (1.7 should be fine) | | | jq | | json commandline processor used in docker image generation scripts. | | libpfm4-dev | 4.10 | This dependency is needed for profiling support [DAPHNE-#479] | +| gRPC | 1.38.0 | | | libssl-dev | 1.1.1 | Dependency introduced while optimizing grpc build (which used to build ssl unnecessarily) | | lld | 10.0.0 | | | llvm-10-tools | 10, 15, 18 | `apt` provides up to `llvm-10-tools` for Ubuntu 20.04 whereas 22.04 / 24.04 require a newer version such as `llvm-15-tools`. | diff --git a/doc/HDFS-Usage.md b/doc/HDFS-Usage.md new file mode 100644 index 000000000..429e40bfe --- /dev/null +++ b/doc/HDFS-Usage.md @@ -0,0 +1,233 @@ + +# HDFS Usage + +About employing HDFS as a distributed file system. + +This document shows how a DAPHNE user can execute DAPHNE scripts using HDFS as a file system, +which is optimized for performance on big data through distributed computing. +This document assumes that DAPHNE was built with the `--hdfs` option; if that is not the case, please rebuild DAPHNE via +`./build.sh --hdfs` + +The DAPHNE build script uses [HAWQ (libhdfs3)](https://github.com/apache/hawq/archive/refs/tags/rel/v3.0.0.0.tar.gz). + +## Configuring DAPHNE for HDFS + +In order for DAPHNE to utilize the HDFS file system, certain command-line arguments need to be passed +(or included in the config file). + +- `--enable-hdfs`: A flag to enable HDFS. +- `--hdfs-ip=`: The IP and port HDFS listens to. +- `--hdfs-username=`: The username used to connect to HDFS. + +## Reading from HDFS + +In order to read a file from HDFS, some preprocessing must be done. Assuming the +file is named `FILE_NAME`, a user needs to: + +1. Upload the file into HDFS. DAPHNE expects the file to be located inside a directory with some specific naming conventions.
The path can be any path under HDFS; however, the file must be named with the following convention: + +``` +/path/to/hdfs/file/FILE_NAME.FILE_TYPE/FILE_NAME.FILE_TYPE_segment_1 +``` + +`FILE_TYPE` is either `.csv` or `.dbdf` (DAPHNE binary data format) followed by `.hdfs`, e.g. `myfile.csv.hdfs`. + +The suffix `_segment_1` is necessary since we support multiple writers at once (see more below); the writers need to write into different files (different segments). +In the case where the user pre-uploads the file, it needs to be in the same format, but with just one segment. + +Each segment must also have its own .meta file within HDFS. This is a JSON +file containing information about the size of the segment as well as the value type. +For example, `myfile.csv.hdfs_segment_1.meta`: + +```json +{ + "numCols": 10, + "numRows": 10, + "valueType": "f64" +} +``` + +2. We also need to create a .meta file for the file within the local file system (from where DAPHNE is invoked). + As for any other file read by DAPHNE, this is a JSON file containing information about where the + file is, its rows/cols, etc. The file should be named `FILE_NAME.FILE_TYPE.meta`, e.g., + `myfile.csv.hdfs.meta`. The meta file contains all the regular information any DAPHNE meta file contains, but in addition it also states whether this is an HDFS file and where it is located within HDFS: + +```json +{ + "hdfs": { + "HDFSFilename": "/path/to/hdfs/file/FILE_NAME.FILE_TYPE", + "isHDFS": true + }, + "numCols": 10, + "numRows": 10, + "valueType": "f64" +} +``` + +### Example + +Let's say we have a dataset called `training_data.csv` which we want to upload to HDFS and use with DAPHNE. + +1. Upload the file under the path `/datasets` and create the segment .meta file. HDFS should look like this: + +```bash +$ hdfs dfs -ls / +/datasets/training_data.csv.hdfs/training_data.csv.hdfs_segment_1 +/datasets/training_data.csv.hdfs/training_data.csv.hdfs_segment_1.meta + +$ hdfs dfs -cat /datasets/training_data.csv.hdfs/training_data.csv.hdfs_segment_1.meta +{"numCols":10,"numRows":10,"valueType":"f64"} +``` + +2. Create the local `.meta` file: + +```bash +$ cat ./training_data.csv.hdfs.meta +{"hdfs":{"HDFSFilename":"/datasets/training_data.csv.hdfs","isHDFS":true},"numCols":10,"numRows":10,"valueType":"f64"} +``` + +3. DAPHNE script: + +``` +X = readMatrix("training_data.csv.hdfs"); +print(X); +``` + +4. Run DAPHNE + +``` +./bin/daphne --enable-hdfs --hdfs-ip= --hdfs-username=ubuntu code.daph +``` + +## Writing to HDFS + +In order to write to HDFS, we just need to use the `writeMatrix` function like we would for any other file type and specify the `.hdfs` suffix. For example: + +1. Code + +``` +X = rand(10, 10, 0.0, 1.0, 1.0, 1); +writeMatrix(X, "randomSet.csv.hdfs"); +``` + +2.
Call DAPHNE + +```bash +./bin/daphne --enable-hdfs --hdfs-ip= --hdfs-username=ubuntu code.daph +``` + +This will create the following files inside HDFS: + +```bash +$ hdfs dfs -ls / +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1 +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1.meta + +$ hdfs dfs -cat /randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1.meta +{"numCols":10,"numRows":10,"valueType":"f64"} +``` + +It also creates the `.meta` file within the local file system, named `randomSet.csv.hdfs.meta`: + +```json +{ + "hdfs": { + "HDFSFilename": "/randomSet.csv.hdfs", + "isHDFS": true + }, + "numCols": 10, + "numRows": 10, + "valueType": "f64" +} +``` + +### Limitations + +For now, writing to a specific directory within HDFS through DAPHNE is not supported. DAPHNE will always try to write under the root HDFS directory `/..hdfs`. + +## Distributed Runtime + +Both read and write operations are supported by the distributed runtime. + +### Read + +Exactly the same preprocessing must be done, creating one file inside HDFS with the +appropriate naming conventions. Users can then run DAPHNE using the +[distributed runtime](DistributedRuntime.md) and, depending on the generated pipeline, DAPHNE's distributed workers will read their +corresponding part of the data, speeding up I/O significantly. For example: + +1. DAPHNE script: + +``` +X = readMatrix("training_data.csv.hdfs"); +print(X+X); +``` + +2. Run DAPHNE + +```bash +$ export DISTRIBUTED_WORKERS=worker-1::worker-2: +$ ./bin/daphne --distributed --dist_backend=sync-gRPC --enable-hdfs --hdfs-ip= --hdfs-username=ubuntu code.daph +``` + +### Write + +Similar to read, nothing really changes; users just need to call DAPHNE using the distributed-runtime flags. Notice that since we have multiple workers/writers, more than +one segment is generated inside HDFS: + +1. Code + +``` +X = rand(10, 10, 0.0, 1.0, 1.0, 1); +writeMatrix(X, "randomSet.csv.hdfs"); +``` + +2. Call DAPHNE + +```bash +$ export DISTRIBUTED_WORKERS=worker-1::worker-2: +$ ./bin/daphne --distributed --dist_backend=sync-gRPC --enable-hdfs --hdfs-ip= --hdfs-username=ubuntu code.daph +``` + +Assuming 2 distributed workers: + +```bash +$ hdfs dfs -ls / +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1 # First part of the matrix +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1.meta +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_2 # Second part of the matrix +/randomSet.csv.hdfs/randomSet.csv.hdfs_segment_2.meta + +$ hdfs dfs -cat /randomSet.csv.hdfs/randomSet.csv.hdfs_segment_1.meta +{"numCols":10,"numRows":5,"valueType":"f64"} +$ hdfs dfs -cat /randomSet.csv.hdfs/randomSet.csv.hdfs_segment_2.meta +{"numCols":10,"numRows":5,"valueType":"f64"} +``` + +The `.meta` file within the local file system, named `randomSet.csv.hdfs.meta`, is also created. + +### Notes + +It does not matter how many segments are generated or exist. DAPHNE is designed to read +the segments according to the current state (distributed or not, and how many distributed +workers are being used). + +For example, if we use 4 distributed workers to write a matrix, +DAPHNE will generate 4 different segments. DAPHNE can later read the same matrix either in +local execution (no distributed runtime) or using a different number of workers, regardless of the number of segments generated earlier.
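+The upload conventions described above can also be scripted. The following is a
+minimal sketch (not part of DAPHNE; it assumes the 10x10 `f64` example data used
+throughout this document and a Hadoop CLI where `hdfs dfs -put -` reads from
+stdin) of pre-uploading a local CSV as a single-segment HDFS file together with
+the two `.meta` files:
+
+```bash
+FILE=training_data.csv          # local CSV, assumed to be 10x10 f64
+HDFS_DIR=/datasets/${FILE}.hdfs # directory inside HDFS, named after the file
+
+# upload the data as the one and only segment
+hdfs dfs -mkdir -p "${HDFS_DIR}"
+hdfs dfs -put "${FILE}" "${HDFS_DIR}/${FILE}.hdfs_segment_1"
+
+# per-segment .meta file inside HDFS (dimensions and value type of the segment)
+echo '{"numCols":10,"numRows":10,"valueType":"f64"}' \
+    | hdfs dfs -put - "${HDFS_DIR}/${FILE}.hdfs_segment_1.meta"
+
+# local .meta file next to the DaphneDSL script, pointing into HDFS
+cat > "./${FILE}.hdfs.meta" <<EOF
+{"hdfs":{"HDFSFilename":"${HDFS_DIR}","isHDFS":true},"numCols":10,"numRows":10,"valueType":"f64"}
+EOF
+```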
diff --git a/doc/development/ImplementBuiltinKernel.md b/doc/development/ImplementBuiltinKernel.md index 4ab3f487d..0dc093aa3 100644 --- a/doc/development/ImplementBuiltinKernel.md +++ b/doc/development/ImplementBuiltinKernel.md @@ -193,3 +193,27 @@ It is recommended to exceptions such as `throw std::runtime_error` in a kernel in case the code runs into an unresolvable issue. We catch these exceptions in our surrounding code to the kernel and provide, whenever possible, additional information about the source of the error in the DaphneDSL script. + + +### Experimental Kernels + +As an alternative to implementing a new kernel that is directly integrated into +DAPHNE, one can also work on kernel implementations using the [kernel catalog](doc/Extensions.md). +These should reside in [experimental/op/](src/runtime/local/kernels/experimental/op/) where `op` is +the mnemonic of the DaphneIR operation that the kernel is implementing. + +Experimental kernels are not directly integrated into DAPHNE and are neither +compiled nor executed by default. They can be used to test new ideas and +provide an easier way of prototyping kernel implementations. One can easily +test multiple different implementations of the same DAPHNE kernel using a +single DaphneDSL script which calls all the kernel implementations. + +There are fewer restrictions on experimental kernels than on built-in +kernels, e.g., they are not tested as part of the CI pipeline. You are also +free to introduce new dependencies that are handled by the accompanying +`Makefile` or build script. Testing and dependency management will have to be +resolved before the experimental kernel is integrated into DAPHNE as a built-in +kernel. + +Check out [Extensions.md](doc/Extensions.md) for more information on how to +implement experimental kernels. diff --git a/software-package-versions.txt b/software-package-versions.txt index 30e00f49f..cda7b2841 100644 --- a/software-package-versions.txt +++ b/software-package-versions.txt @@ -19,7 +19,7 @@ abslVersion=20230802.1 antlrVersion=4.9.2 arrowVersion=13.0.0 catch2Version=2.13.8 -cmakeVersion=3.30.3 +cmakeVersion=3.30.5 cudaVersion=12.6.1 eigenVersion=3.4.0 grpcVersion=1.38.0 @@ -28,7 +28,10 @@ nlohmannjsonVersion=3.10.5 openBlasVersion=0.3.23 openMPIVersion=4.1.5 papiVersion=7.0.1 -spdlogVersion=1.11.0 +# temporarily (2024-10-03) use the tip of the default branch for spdlog due to a compilation issue of the latest +# release 1.14.1 in combination with external fmt 11.0.2 +spdlogVersion=e593f6695c6065e6b345fe2862f04a519ed484e0 ubuntuVersion=24.04 hawqVersion=3.0.0.0 liburingVersion=2.7 +fmtVersion=11.0.2 diff --git a/src/api/cli/DaphneUserConfig.h b/src/api/cli/DaphneUserConfig.h index 071e15812..d33a18e30 100644 --- a/src/api/cli/DaphneUserConfig.h +++ b/src/api/cli/DaphneUserConfig.h @@ -14,29 +14,28 @@ * limitations under the License. */ - #pragma once #include #include -#include #include -#include #include +#include class DaphneLogger; -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include /* * Container to pass around user configuration */ struct DaphneUserConfig { - // Remember to update UserConfig.json accordingly! + // Remember to update UserConfig.json accordingly!
bool use_cuda = false; bool use_vectorized_exec = false; bool use_distributed = false; @@ -44,11 +43,11 @@ struct DaphneUserConfig { bool use_ipa_const_propa = true; bool use_phy_op_selection = true; bool use_mlir_codegen = false; - int matmul_vec_size_bits = 0; + int matmul_vec_size_bits = 0; bool matmul_tile = false; int matmul_unroll_factor = 1; - int matmul_unroll_jam_factor=4; - int matmul_num_vec_registers=16; + int matmul_unroll_jam_factor = 4; + int matmul_num_vec_registers = 16; bool matmul_use_fixed_tile_sizes = false; std::vector matmul_fixed_tile_sizes = {4, 4}; bool matmul_invert_loops = false; @@ -81,9 +80,12 @@ struct DaphneUserConfig { SelfSchedulingScheme taskPartitioningScheme = STATIC; QueueTypeOption queueSetupScheme = CENTRALIZED; - VictimSelectionLogic victimSelection = SEQPRI; - ALLOCATION_TYPE distributedBackEndSetup= ALLOCATION_TYPE::DIST_MPI; // default value - size_t max_distributed_serialization_chunk_size = std::numeric_limits::max() - 1024; // 2GB (-1KB to make up for gRPC headers etc.) - which is the maximum size allowed by gRPC / MPI. TODO: Investigate what might be the optimal. + VictimSelectionLogic victimSelection = SEQPRI; + ALLOCATION_TYPE distributedBackEndSetup = ALLOCATION_TYPE::DIST_MPI; // default value + size_t max_distributed_serialization_chunk_size = + std::numeric_limits::max() - 1024; // 2GB (-1KB to make up for gRPC headers etc.) - which is the + // maximum size allowed by gRPC / MPI. TODO: Investigate what + // might be the optimal. int numberOfThreads = -1; int minimumTaskSize = 1; @@ -92,14 +94,16 @@ struct DaphneUserConfig { std::string hdfs_Address = ""; std::string hdfs_username = ""; - // minimum considered log level (e.g., no logging below ERROR (essentially suppressing WARN, INFO, DEBUG and TRACE) + // minimum considered log level (e.g., no logging below ERROR (essentially + // suppressing WARN, INFO, DEBUG and TRACE) spdlog::level::level_enum log_level_limit = spdlog::level::err; std::vector loggers; - DaphneLogger* log_ptr{}; + DaphneLogger *log_ptr{}; float sparsity_threshold = 0.25; #ifdef USE_CUDA - // User config holds once context atm for convenience until we have proper system infrastructure + // User config holds once context atm for convenience until we have proper + // system infrastructure // CUDA device IDs (future work, as we create only one context atm) std::vector cuda_devices; @@ -110,28 +114,27 @@ struct DaphneUserConfig { #ifdef USE_FPGAOPENCL std::vector fpga_devices; #endif - - + std::string libdir = "{exedir}/../lib"; std::map> daphnedsl_import_paths; + // TODO Maybe the DaphneLib result should better reside in the + // DaphneContext, but having it here is simpler for now. + DaphneLibResult *result_struct = nullptr; - // TODO Maybe the DaphneLib result should better reside in the DaphneContext, - // but having it here is simpler for now. - DaphneLibResult* result_struct = nullptr; - KernelCatalog kernelCatalog; /** - * @brief Replaces the prefix `"{exedir}/"` in the field `libdir` by the path - * of the directory in which the currently running executable resides. + * @brief Replaces the prefix `"{exedir}/"` in the field `libdir` by the + * path of the directory in which the currently running executable resides. * - * Note that the current executable is not necessarily `daphne`. It could also - * be a distributed worker (e.g., `DistributedWorker`) or Python (`python3`). + * Note that the current executable is not necessarily `daphne`. 
It could + * also be a distributed worker (e.g., `DistributedWorker`) or Python + * (`python3`). */ void resolveLibDir() { const std::string exedirPlaceholder = "{exedir}/"; - if(libdir.substr(0, exedirPlaceholder.size()) == exedirPlaceholder) { + if (libdir.substr(0, exedirPlaceholder.size()) == exedirPlaceholder) { // This next line adds to our Linux platform lock-in. std::filesystem::path daphneExeDir(std::filesystem::canonical("/proc/self/exe").parent_path()); libdir = daphneExeDir / libdir.substr(exedirPlaceholder.size()); diff --git a/src/api/cli/StatusCode.h b/src/api/cli/StatusCode.h index f705f2e71..2be5be2b8 100644 --- a/src/api/cli/StatusCode.h +++ b/src/api/cli/StatusCode.h @@ -19,7 +19,7 @@ /** * @brief Possible status codes returned by the command line interface. - * + * * Note that this is deliberately not an `enum class`, because we frequently * need to use it as an integer. */ @@ -30,4 +30,4 @@ enum StatusCode { EXECUTION_ERROR, }; -#endif //SRC_API_CLI_STATUSCODE_H \ No newline at end of file +#endif // SRC_API_CLI_STATUSCODE_H \ No newline at end of file diff --git a/src/api/cli/daphne.cpp b/src/api/cli/daphne.cpp index 347ff81bc..bd83fbcce 100644 --- a/src/api/cli/daphne.cpp +++ b/src/api/cli/daphne.cpp @@ -16,6 +16,4 @@ #include -int main(int argc, const char** argv) { - return mainInternal(argc, argv, nullptr); -} +int main(int argc, const char **argv) { return mainInternal(argc, argv, nullptr); } diff --git a/src/api/daphnelib/DaphneLibResult.h b/src/api/daphnelib/DaphneLibResult.h index 4d346c8cb..2a7a99158 100644 --- a/src/api/daphnelib/DaphneLibResult.h +++ b/src/api/daphnelib/DaphneLibResult.h @@ -19,17 +19,16 @@ #include #include - struct DaphneLibResult { // For matrices. - void* address; + void *address; int64_t rows; int64_t cols; int64_t vtc; // For frames. - int64_t* vtcs; - char** labels; - void** columns; + int64_t *vtcs; + char **labels; + void **columns; // To pass error messages to Python code. std::string error_message; }; \ No newline at end of file diff --git a/src/api/daphnelib/daphnelib.cpp b/src/api/daphnelib/daphnelib.cpp index 1e6f99735..16c129e4c 100644 --- a/src/api/daphnelib/daphnelib.cpp +++ b/src/api/daphnelib/daphnelib.cpp @@ -25,15 +25,14 @@ DaphneLibResult daphneLibRes; /** * @brief Returns the result of a DaphneLib invocation. */ -extern "C" DaphneLibResult getResult() { - return daphneLibRes; -} +extern "C" DaphneLibResult getResult() { return daphneLibRes; } /** - * @brief Invokes DAPHNE with the specified DaphneDSL script and path to lib dir. + * @brief Invokes DAPHNE with the specified DaphneDSL script and path to lib + * dir. 
*/ -extern "C" int daphne(const char* libDirPath, const char* scriptPath) { - const char * argv[] = {"daphne", "--libdir", libDirPath, scriptPath}; +extern "C" int daphne(const char *libDirPath, const char *scriptPath) { + const char *argv[] = {"daphne", "--libdir", libDirPath, scriptPath}; int argc = 4; return mainInternal(argc, argv, &daphneLibRes); diff --git a/src/api/internal/daphne_internal.cpp b/src/api/internal/daphne_internal.cpp index d406f954d..5b533601b 100644 --- a/src/api/internal/daphne_internal.cpp +++ b/src/api/internal/daphne_internal.cpp @@ -15,19 +15,20 @@ */ #include "runtime/local/datastructures/IAllocationDescriptor.h" -#include + #ifdef USE_MPI - #include "runtime/distributed/worker/MPIWorker.h" +#include "runtime/distributed/worker/MPIWorker.h" #endif -#include -#include + +#include "compiler/execution/DaphneIrExecutor.h" #include +#include #include -#include -#include "compiler/execution/DaphneIrExecutor.h" -#include +#include #include #include +#include +#include #include #include #include @@ -38,7 +39,7 @@ #include "llvm/Support/CommandLine.h" #ifdef USE_CUDA - #include +#include #endif #include @@ -47,9 +48,8 @@ #include #include -#include #include -#include +#include #include #include @@ -60,49 +60,50 @@ using namespace std; using namespace mlir; using namespace llvm::cl; -void parseScriptArgs(const llvm::cl::list& scriptArgsCli, unordered_map& scriptArgsFinal) { - for(const std::string& pair : scriptArgsCli) { +void parseScriptArgs(const llvm::cl::list &scriptArgsCli, unordered_map &scriptArgsFinal) { + for (const std::string &pair : scriptArgsCli) { size_t pos = pair.find('='); - if(pos == string::npos) - throw std::runtime_error("script arguments must be specified as name=value, but found '" + pair + "'"); + if (pos == string::npos) + throw std::runtime_error("script arguments must be specified as " + "name=value, but found '" + + pair + "'"); const string argName = pair.substr(0, pos); const string argValue = pair.substr(pos + 1, pair.size()); - if(scriptArgsFinal.count(argName)) + if (scriptArgsFinal.count(argName)) throw runtime_error("script argument: '" + argName + "' was provided more than once"); scriptArgsFinal.emplace(argName, argValue); } } -void printVersion(llvm::raw_ostream& os) { +void printVersion(llvm::raw_ostream &os) { // TODO Include some of the important build flags into the version string. - os - << "DAPHNE Version 0.3\n" - << "An Open and Extensible System Infrastructure for Integrated Data Analysis Pipelines\n" - << "https://github.com/daphne-eu/daphne\n"; + os << "DAPHNE Version 0.3\n" + << "An Open and Extensible System Infrastructure for Integrated Data " + "Analysis Pipelines\n" + << "https://github.com/daphne-eu/daphne\n"; } -namespace -{ - volatile std::sig_atomic_t gSignalStatus; - jmp_buf return_from_handler; -} +namespace { +volatile std::sig_atomic_t gSignalStatus; +jmp_buf return_from_handler; +} // namespace void handleSignals(int signal) { constexpr int callstackMaxSize = 25; - void* callstack[callstackMaxSize]; + void *callstack[callstackMaxSize]; auto callstacksReturned = backtrace(callstack, callstackMaxSize); backtrace_symbols_fd(callstack, callstacksReturned, STDOUT_FILENO); gSignalStatus = signal; longjmp(return_from_handler, gSignalStatus); } -void logErrorDaphneLibAware(DaphneLibResult * daphneLibRes, std::string msg) { - if(daphneLibRes != nullptr) // For DaphneLib (Python API), error message is handled later in script.py. 
+void logErrorDaphneLibAware(DaphneLibResult *daphneLibRes, std::string msg) { + if (daphneLibRes != nullptr) // For DaphneLib (Python API), error message is handled later in script.py. daphneLibRes->error_message = msg; else spdlog::error(msg); } -int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int *id, DaphneUserConfig& user_config){ +int startDAPHNE(int argc, const char **argv, DaphneLibResult *daphneLibRes, int *id, DaphneUserConfig &user_config) { using clock = std::chrono::high_resolution_clock; clock::time_point tpBeg = clock::now(); @@ -113,330 +114,240 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int // ************************************************************************ // Parse command line arguments // ************************************************************************ - + // ------------------------------------------------------------------------ // Define options // ------------------------------------------------------------------------ // All the variables concerned with the LLVM command line parser (those of // type OptionCategory, opt, ...) must be declared static here, because - // this function may run multiple times in the context of DaphneLib (DAPHNE's - // Python API). Without static, the second invocation of this function - // crashes because the options set in the first invocation are still present - // in some global state. This must be due to the way the LLVM command line - // library handles its internal state. - + // this function may run multiple times in the context of DaphneLib + // (DAPHNE's Python API). Without static, the second invocation of this + // function crashes because the options set in the first invocation are + // still present in some global state. This must be due to the way the LLVM + // command line library handles its internal state. + // Option categories ------------------------------------------------------ - + // TODO We will probably subdivide the options into multiple groups later. 
static OptionCategory daphneOptions("DAPHNE Options"); static OptionCategory schedulingOptions("Advanced Scheduling Knobs"); static OptionCategory distributedBackEndSetupOptions("Distributed Backend Knobs"); static OptionCategory HDFSOptions("HDFS Knobs"); - // Options ---------------------------------------------------------------- // Distributed backend Knobs - static opt distributedBackEndSetup("dist_backend", cat(distributedBackEndSetupOptions), - desc("Choose the options for the distribution backend:"), - values( - clEnumValN(ALLOCATION_TYPE::DIST_MPI, "MPI", "Use message passing interface for internode data exchange (default)"), - clEnumValN(ALLOCATION_TYPE::DIST_GRPC_SYNC, "sync-gRPC", "Use remote procedure call (synchronous gRPC with threading) for internode data exchange"), - clEnumValN(ALLOCATION_TYPE::DIST_GRPC_ASYNC, "async-gRPC", "Use remote procedure call (asynchronous gRPC) for internode data exchange") - ), - init(ALLOCATION_TYPE::DIST_MPI) - ); - static opt maxDistrChunkSize("max-distr-chunk-size", cat(distributedBackEndSetupOptions), - desc( - "Define the maximum chunk size per message for the distributed runtime (in bytes)" - "(default is close to maximum allowed ~2GB)" - ), - init(std::numeric_limits::max() - 1024) - ); + static opt distributedBackEndSetup( + "dist_backend", cat(distributedBackEndSetupOptions), desc("Choose the options for the distribution backend:"), + values(clEnumValN(ALLOCATION_TYPE::DIST_MPI, "MPI", + "Use message passing interface for internode data " + "exchange (default)"), + clEnumValN(ALLOCATION_TYPE::DIST_GRPC_SYNC, "sync-gRPC", + "Use remote procedure call (synchronous gRPC with " + "threading) for internode data exchange"), + clEnumValN(ALLOCATION_TYPE::DIST_GRPC_ASYNC, "async-gRPC", + "Use remote procedure call (asynchronous gRPC) for " + "internode data exchange")), + init(ALLOCATION_TYPE::DIST_MPI)); + static opt maxDistrChunkSize("max-distr-chunk-size", cat(distributedBackEndSetupOptions), + desc("Define the maximum chunk size per message for the distributed " + "runtime (in bytes)" + "(default is close to maximum allowed ~2GB)"), + init(std::numeric_limits::max() - 1024)); // HDFS knobs - static opt use_hdfs( - "enable-hdfs", cat(HDFSOptions), - desc("Enable HDFS filesystem") - ); - static opt hdfs_Address( - "hdfs-ip", cat(HDFSOptions), - desc("IP of the HDFS filesystem (including port)."), - init("") - ); - static opt hdfs_username( - "hdfs-username", cat(HDFSOptions), - desc("Username of the HDFS filesystem."), - init("") - ); - - + static opt use_hdfs("enable-hdfs", cat(HDFSOptions), desc("Enable HDFS filesystem")); + static opt hdfs_Address("hdfs-ip", cat(HDFSOptions), desc("IP of the HDFS filesystem (including port)."), + init("")); + static opt hdfs_username("hdfs-username", cat(HDFSOptions), desc("Username of the HDFS filesystem."), + init("")); + // Scheduling options - static opt taskPartitioningScheme("partitioning", - cat(schedulingOptions), desc("Choose task partitioning scheme:"), - values( - clEnumVal(STATIC , "Static (default)"), - clEnumVal(SS, "Self-scheduling"), - clEnumVal(GSS, "Guided self-scheduling"), - clEnumVal(TSS, "Trapezoid self-scheduling"), - clEnumVal(FAC2, "Factoring self-scheduling"), - clEnumVal(TFSS, "Trapezoid Factoring self-scheduling"), - clEnumVal(FISS, "Fixed-increase self-scheduling"), - clEnumVal(VISS, "Variable-increase self-scheduling"), - clEnumVal(PLS, "Performance loop-based self-scheduling"), - clEnumVal(MSTATIC, "Modified version of Static, i.e., instead of n/p, it uses n/(4*p) where n 
is number of tasks and p is number of threads"), - clEnumVal(MFSC, "Modified version of fixed size chunk self-scheduling, i.e., MFSC does not require profiling information as FSC"), - clEnumVal(PSS, "Probabilistic self-scheduling"), - clEnumVal(AUTO, "Automatic partitioning") - ), - init(STATIC) - ); - static opt queueSetupScheme("queue_layout", - cat(schedulingOptions), desc("Choose queue setup scheme:"), - values( - clEnumVal(CENTRALIZED, "One queue (default)"), - clEnumVal(PERGROUP, "One queue per CPU group"), - clEnumVal(PERCPU, "One queue per CPU core") - ), - init(CENTRALIZED) - ); - static opt victimSelection("victim_selection", - cat(schedulingOptions), desc("Choose work stealing victim selection logic:"), - values( - clEnumVal(SEQ, "Steal from next adjacent worker (default)"), - clEnumVal(SEQPRI, "Steal from next adjacent worker, prioritize same NUMA domain"), - clEnumVal(RANDOM, "Steal from random worker"), - clEnumVal(RANDOMPRI, "Steal from random worker, prioritize same NUMA domain") - ), - init(SEQ) - ); - - static opt numberOfThreads( - "num-threads", cat(schedulingOptions), - desc( - "Define the number of the CPU threads used by the vectorized execution engine " - "(default is equal to the number of physical cores on the target node that executes the code)" - ) - ); - static opt minimumTaskSize( - "grain-size", cat(schedulingOptions), - desc( - "Define the minimum grain size of a task (default is 1)" - ), - init(1) - ); - static opt useVectorizedPipelines( - "vec", cat(schedulingOptions), - desc("Enable vectorized execution engine") - ); - static opt useDistributedRuntime( - "distributed", cat(daphneOptions), - desc("Enable distributed runtime") - ); - static opt prePartitionRows( - "pre-partition", cat(schedulingOptions), - desc("Partition rows into the number of queues before applying scheduling technique") - ); - static opt pinWorkers( - "pin-workers", cat(schedulingOptions), - desc("Pin workers to CPU cores") - ); - static opt hyperthreadingEnabled( - "hyperthreading", cat(schedulingOptions), - desc("Utilize multiple logical CPUs located on the same physical CPU") - ); - static opt debugMultiThreading( - "debug-mt", cat(schedulingOptions), - desc("Prints debug information about the Multithreading Wrapper") - ); - + static opt taskPartitioningScheme( + "partitioning", cat(schedulingOptions), desc("Choose task partitioning scheme:"), + values(clEnumVal(STATIC, "Static (default)"), clEnumVal(SS, "Self-scheduling"), + clEnumVal(GSS, "Guided self-scheduling"), clEnumVal(TSS, "Trapezoid self-scheduling"), + clEnumVal(FAC2, "Factoring self-scheduling"), clEnumVal(TFSS, "Trapezoid Factoring self-scheduling"), + clEnumVal(FISS, "Fixed-increase self-scheduling"), clEnumVal(VISS, "Variable-increase self-scheduling"), + clEnumVal(PLS, "Performance loop-based self-scheduling"), + clEnumVal(MSTATIC, "Modified version of Static, i.e., instead " + "of n/p, it uses n/(4*p) where n is number " + "of tasks and p is number of threads"), + clEnumVal(MFSC, "Modified version of fixed size chunk self-scheduling, " + "i.e., MFSC does not require profiling information as FSC"), + clEnumVal(PSS, "Probabilistic self-scheduling"), clEnumVal(AUTO, "Automatic partitioning")), + init(STATIC)); + static opt queueSetupScheme( + "queue_layout", cat(schedulingOptions), desc("Choose queue setup scheme:"), + values(clEnumVal(CENTRALIZED, "One queue (default)"), clEnumVal(PERGROUP, "One queue per CPU group"), + clEnumVal(PERCPU, "One queue per CPU core")), + init(CENTRALIZED)); + static opt victimSelection( + 
"victim_selection", cat(schedulingOptions), desc("Choose work stealing victim selection logic:"), + values(clEnumVal(SEQ, "Steal from next adjacent worker (default)"), + clEnumVal(SEQPRI, "Steal from next adjacent worker, prioritize same NUMA domain"), + clEnumVal(RANDOM, "Steal from random worker"), + clEnumVal(RANDOMPRI, "Steal from random worker, prioritize same NUMA domain")), + init(SEQ)); + + static opt numberOfThreads("num-threads", cat(schedulingOptions), + desc("Define the number of the CPU threads used by the vectorized " + "execution engine " + "(default is equal to the number of physical cores on the target " + "node that executes the code)")); + static opt minimumTaskSize("grain-size", cat(schedulingOptions), + desc("Define the minimum grain size of a task (default is 1)"), init(1)); + static opt useVectorizedPipelines("vec", cat(schedulingOptions), desc("Enable vectorized execution engine")); + static opt useDistributedRuntime("distributed", cat(daphneOptions), desc("Enable distributed runtime")); + static opt prePartitionRows("pre-partition", cat(schedulingOptions), + desc("Partition rows into the number of queues before applying " + "scheduling technique")); + static opt pinWorkers("pin-workers", cat(schedulingOptions), desc("Pin workers to CPU cores")); + static opt hyperthreadingEnabled("hyperthreading", cat(schedulingOptions), + desc("Utilize multiple logical CPUs located on the same physical CPU")); + static opt debugMultiThreading("debug-mt", cat(schedulingOptions), + desc("Prints debug information about the Multithreading Wrapper")); + // Other options - static opt noObjRefMgnt( - "no-obj-ref-mgnt", cat(daphneOptions), - desc( - "Switch off garbage collection by not managing data " - "objects' reference counters" - ) - ); - static opt noIPAConstPropa( - "no-ipa-const-propa", cat(daphneOptions), - desc("Switch off inter-procedural constant propagation") - ); - static opt noPhyOpSelection( - "no-phy-op-selection", cat(daphneOptions), - desc("Switch off physical operator selection, use default kernels for all operations") - ); - static opt selectMatrixRepr( - "select-matrix-repr", cat(daphneOptions), - desc( - "Automatically choose physical matrix representations " - "(e.g., dense/sparse)" - ) - ); + static opt noObjRefMgnt("no-obj-ref-mgnt", cat(daphneOptions), + desc("Switch off garbage collection by not managing data " + "objects' reference counters")); + static opt noIPAConstPropa("no-ipa-const-propa", cat(daphneOptions), + desc("Switch off inter-procedural constant propagation")); + static opt noPhyOpSelection("no-phy-op-selection", cat(daphneOptions), + desc("Switch off physical operator selection, use default kernels for " + "all operations")); + static opt selectMatrixRepr("select-matrix-repr", cat(daphneOptions), + desc("Automatically choose physical matrix representations " + "(e.g., dense/sparse)")); static alias selectMatrixReprAlias( // to still support the longer old form - "select-matrix-representations", aliasopt(selectMatrixRepr), - desc("Alias for --select-matrix-repr") - ); - static opt cuda( - "cuda", cat(daphneOptions), - desc("Use CUDA") - ); - static opt fpgaopencl( - "fpgaopencl", cat(daphneOptions), - desc("Use FPGAOPENCL") - ); - static opt libDir( - "libdir", cat(daphneOptions), - desc( - "The directory containing the kernel catalog files " - "(typically, but not necessarily, along with the kernel shared libraries)" - ) - ); - - static opt mlirCodegen( - "mlir-codegen", cat(daphneOptions), - desc("Enables lowering of certain DaphneIR operations 
on DenseMatrix to low-level MLIR operations.") - ); - static opt matmul_vec_size_bits( - "matmul-vec-size-bits", cat(daphneOptions), - desc("Set the vector size to be used in the lowering of the MatMul operation if possible. Value of 0 is interpreted as off switch."), - init(0) - ); - static opt matmul_tile( - "matmul-tile", cat(daphneOptions), - desc("Enables loop tiling in the lowering of the MatMul operation.") - ); - static opt matmul_unroll_factor( - "matmul-unroll-factor", cat(daphneOptions), - desc("Factor by which to unroll the finally resulting inner most loop in the lowered MatMul if tiling is used."), - init(1) - ); - static opt matmul_unroll_jam_factor( - "matmul-unroll-jam-factor", cat(daphneOptions), - desc("Factor by which to unroll jam the two inner most loop in the lowered MatMul if tiling is used."), - init(4) - ); - static opt matmul_num_vec_registers( - "matmul-num-vec-registers", cat(daphneOptions), - desc("Number of vector registers. Used during automatic tiling in lowering of MatMulOp"), - init(16) - ); + "select-matrix-representations", aliasopt(selectMatrixRepr), desc("Alias for --select-matrix-repr")); + static opt cuda("cuda", cat(daphneOptions), desc("Use CUDA")); + static opt fpgaopencl("fpgaopencl", cat(daphneOptions), desc("Use FPGAOPENCL")); + static opt libDir("libdir", cat(daphneOptions), + desc("The directory containing the kernel catalog files " "(typically, but not necessarily, along with the kernel shared " "libraries)")); + + static opt mlirCodegen("mlir-codegen", cat(daphneOptions), + desc("Enables lowering of certain DaphneIR operations on DenseMatrix " "to low-level MLIR operations.")); + static opt matmul_vec_size_bits("matmul-vec-size-bits", cat(daphneOptions), + desc("Set the vector size to be used in the lowering of the MatMul " "operation if possible. A value of 0 is interpreted as an off switch."), + init(0)); + static opt matmul_tile("matmul-tile", cat(daphneOptions), + desc("Enables loop tiling in the lowering of the MatMul operation.")); + static opt matmul_unroll_factor("matmul-unroll-factor", cat(daphneOptions), + desc("Factor by which to unroll the finally resulting innermost loop " "in the lowered MatMul if tiling is used."), + init(1)); + static opt matmul_unroll_jam_factor("matmul-unroll-jam-factor", cat(daphneOptions), + desc("Factor by which to unroll-jam the two innermost loops in the " "lowered MatMul if tiling is used."), + init(4)); + static opt matmul_num_vec_registers("matmul-num-vec-registers", cat(daphneOptions), + desc("Number of vector registers. Used during automatic tiling in " "lowering of MatMulOp."), + init(16)); static llvm::cl::list matmul_fixed_tile_sizes( "matmul-fixed-tile-sizes", cat(daphneOptions), - desc("Set fixed tile sizes to be used for the lowering of MatMul if tiling is used. This also enables tiling."), - CommaSeparated - ); - static opt matmul_invert_loops( - "matmul-invert-loops", cat(daphneOptions), - desc("Enable inverting of the inner two loops in the matrix multiplication as a fallback option, if tiling is not possible or deactivated.") - ); - - - static opt performHybridCodegen( - "mlir-hybrid-codegen", cat(daphneOptions), - desc("Enables prototypical hybrid code generation combining pre-compiled kernels and MLIR code generation.") - ); - static opt kernelExt( - "kernel-ext", cat(daphneOptions), - desc("Additional kernel extension to register (path to a kernel catalog JSON file).") - ); + desc("Set fixed tile sizes to be used for the lowering of MatMul if " "tiling is used. 
This also enables tiling."), + CommaSeparated); + static opt matmul_invert_loops("matmul-invert-loops", cat(daphneOptions), + desc("Enable inverting of the inner two loops in the matrix " + "multiplication as a fallback option, if tiling is not possible " + "or deactivated.")); + + static opt performHybridCodegen("mlir-hybrid-codegen", cat(daphneOptions), + desc("Enables prototypical hybrid code generation combining " + "pre-compiled kernels and MLIR code generation.")); + static opt kernelExt("kernel-ext", cat(daphneOptions), + desc("Additional kernel extension to register " + "(path to a kernel catalog JSON file).")); enum ExplainArgs { - kernels, - llvm, - parsing, - parsing_simplified, - property_inference, - select_matrix_repr, - sql, - phy_op_selection, - type_adaptation, - vectorized, - obj_ref_mgnt, - mlir_codegen + kernels, + llvm, + parsing, + parsing_simplified, + property_inference, + select_matrix_repr, + sql, + phy_op_selection, + type_adaptation, + vectorized, + obj_ref_mgnt, + mlir_codegen }; static llvm::cl::list explainArgList( "explain", cat(daphneOptions), llvm::cl::desc("Show DaphneIR after certain compiler passes (separate " "multiple values by comma, the order is irrelevant)"), - llvm::cl::values( - clEnumVal(parsing, "Show DaphneIR after parsing"), - clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), - clEnumVal(sql, "Show DaphneIR after SQL parsing"), - clEnumVal(property_inference, "Show DaphneIR after property inference"), - clEnumVal(select_matrix_repr, "Show DaphneIR after selecting physical matrix representations"), - clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), - clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), - clEnumVal(vectorized, "Show DaphneIR after vectorization"), - clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), - clEnumVal(kernels, "Show DaphneIR after kernel lowering"), - clEnumVal(llvm, "Show DaphneIR after llvm lowering"), - clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen")), + llvm::cl::values(clEnumVal(parsing, "Show DaphneIR after parsing"), + clEnumVal(parsing_simplified, "Show DaphneIR after parsing and some simplifications"), + clEnumVal(sql, "Show DaphneIR after SQL parsing"), + clEnumVal(property_inference, "Show DaphneIR after property inference"), + clEnumVal(select_matrix_repr, "Show DaphneIR after selecting " + "physical matrix representations"), + clEnumVal(phy_op_selection, "Show DaphneIR after selecting physical operators"), + clEnumVal(type_adaptation, "Show DaphneIR after adapting types to available kernels"), + clEnumVal(vectorized, "Show DaphneIR after vectorization"), + clEnumVal(obj_ref_mgnt, "Show DaphneIR after managing object references"), + clEnumVal(kernels, "Show DaphneIR after kernel lowering"), + clEnumVal(llvm, "Show DaphneIR after llvm lowering"), + clEnumVal(mlir_codegen, "Show DaphneIR after MLIR codegen")), CommaSeparated); - static llvm::cl::list scriptArgs1( - "args", cat(daphneOptions), - desc( - "Alternative way of specifying arguments to the DaphneDSL " - "script; must be a comma-separated list of name-value-pairs, " - "e.g., `--args x=1,y=2.2`" - ), - CommaSeparated - ); + static llvm::cl::list scriptArgs1("args", cat(daphneOptions), + desc("Alternative way of specifying arguments to the DaphneDSL " + "script; must be a comma-separated list of name-value-pairs, " + "e.g., `--args x=1,y=2.2`"), + CommaSeparated); const std::string configFileInitValue = "-"; - 
static opt configFile( - "config", cat(daphneOptions), - desc("A JSON file that contains the DAPHNE configuration"), - value_desc("filename"), - llvm::cl::init(configFileInitValue) - ); - - static opt enableStatistics( - "statistics", cat(daphneOptions), - desc("Enables runtime statistics output.")); - - static opt enableProfiling ( - "enable-profiling", cat(daphneOptions), - desc("Enable profiling support") - ); - static opt timing ( - "timing", cat(daphneOptions), - desc("Enable timing of high-level steps (start-up, parsing, compilation, execution) and print the times to stderr in JSON format") - ); + static opt configFile("config", cat(daphneOptions), + desc("A JSON file that contains the DAPHNE configuration"), value_desc("filename"), + llvm::cl::init(configFileInitValue)); + + static opt enableStatistics("statistics", cat(daphneOptions), desc("Enables runtime statistics output.")); + + static opt enableProfiling("enable-profiling", cat(daphneOptions), desc("Enable profiling support")); + static opt timing("timing", cat(daphneOptions), + desc("Enable timing of high-level steps (start-up, " + "parsing, compilation, execution) and print " + "the times to stderr in JSON format")); // Positional arguments --------------------------------------------------- - + static opt inputFile(Positional, desc("script"), Required); static llvm::cl::list scriptArgs2(ConsumeAfter, desc("[arguments]")); // ------------------------------------------------------------------------ // Parse arguments // ------------------------------------------------------------------------ - + static std::vector visibleCategories; visibleCategories.push_back(&daphneOptions); visibleCategories.push_back(&schedulingOptions); visibleCategories.push_back(&distributedBackEndSetupOptions); visibleCategories.push_back(&HDFSOptions); - + HideUnrelatedOptions(visibleCategories); - extrahelp( - "\nEXAMPLES:\n\n" - " daphne example.daphne\n" - " daphne --vec example.daphne x=1 y=2.2 z=\"foo\"\n" - " daphne --vec --args x=1,y=2.2,z=\"foo\" example.daphne\n" - " daphne --vec --args x=1,y=2.2 example.daphne z=\"foo\"\n" - ); + extrahelp("\nEXAMPLES:\n\n" + " daphne example.daphne\n" + " daphne --vec example.daphne x=1 y=2.2 z=\"foo\"\n" + " daphne --vec --args x=1,y=2.2,z=\"foo\" example.daphne\n" + " daphne --vec --args x=1,y=2.2 example.daphne z=\"foo\"\n"); SetVersionPrinter(&printVersion); - ParseCommandLineOptions( - argc, argv, - "The DAPHNE Prototype.\n\nThis program compiles and executes a DaphneDSL script.\n" - ); + ParseCommandLineOptions(argc, argv, + "The DAPHNE Prototype.\n\nThis program compiles " + "and executes a DaphneDSL script.\n"); // ************************************************************************ // Process parsed arguments @@ -446,18 +357,17 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int if (configFile != configFileInitValue && ConfigParser::fileExists(configFile)) { ConfigParser::readUserConfig(configFile, user_config); } - } - catch(std::exception & e) { + } catch (std::exception &e) { logErrorDaphneLibAware(daphneLibRes, "Parser error while reading user config:\n" + std::string(e.what())); return StatusCode::PARSER_ERROR; } // initialize logging facility - if(not logger) + if (not logger) logger = std::make_unique(user_config); user_config.use_vectorized_exec = useVectorizedPipelines; - user_config.use_distributed = useDistributedRuntime; + user_config.use_distributed = useDistributedRuntime; user_config.use_obj_ref_mgnt = !noObjRefMgnt; user_config.use_ipa_const_propa 
= !noIPAConstPropa; user_config.use_phy_op_selection = !noPhyOpSelection; @@ -476,33 +386,34 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int } user_config.use_mlir_hybrid_codegen = performHybridCodegen; - if(!libDir.getValue().empty()) + if (!libDir.getValue().empty()) user_config.libdir = libDir.getValue(); user_config.resolveLibDir(); user_config.taskPartitioningScheme = taskPartitioningScheme; user_config.queueSetupScheme = queueSetupScheme; - user_config.victimSelection = victimSelection; + user_config.victimSelection = victimSelection; // only overwrite with non-defaults - if(numberOfThreads != 0) { + if (numberOfThreads != 0) { spdlog::trace("Overwriting config file supplied numberOfThreads={} with command line argument --num-threads={}", - user_config.numberOfThreads, numberOfThreads); + user_config.numberOfThreads, static_cast(numberOfThreads)); user_config.numberOfThreads = numberOfThreads; } - user_config.minimumTaskSize = minimumTaskSize; + user_config.minimumTaskSize = minimumTaskSize; user_config.pinWorkers = pinWorkers; user_config.hyperthreadingEnabled = hyperthreadingEnabled; user_config.debugMultiThreading = debugMultiThreading; user_config.prePartitionRows = prePartitionRows; user_config.distributedBackEndSetup = distributedBackEndSetup; - if(user_config.use_distributed) - { - if(user_config.distributedBackEndSetup!=ALLOCATION_TYPE::DIST_MPI && user_config.distributedBackEndSetup!=ALLOCATION_TYPE::DIST_GRPC_SYNC && user_config.distributedBackEndSetup!=ALLOCATION_TYPE::DIST_GRPC_ASYNC) + if (user_config.use_distributed) { + if (user_config.distributedBackEndSetup != ALLOCATION_TYPE::DIST_MPI && + user_config.distributedBackEndSetup != ALLOCATION_TYPE::DIST_GRPC_SYNC && + user_config.distributedBackEndSetup != ALLOCATION_TYPE::DIST_GRPC_ASYNC) spdlog::warn("No backend has been selected. 
Will use the default 'MPI'"); } - user_config.max_distributed_serialization_chunk_size = maxDistrChunkSize; + user_config.max_distributed_serialization_chunk_size = maxDistrChunkSize; // only overwrite with non-defaults if (use_hdfs) { @@ -514,96 +425,99 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int if (hdfs_username != "") { user_config.hdfs_username = hdfs_username; } - if (user_config.use_hdfs && (user_config.hdfs_Address == "" || user_config.hdfs_username == "")){ - spdlog::warn("HDFS is enabled, but the HDFS IP address or username were not provided."); + if (user_config.use_hdfs && (user_config.hdfs_Address == "" || user_config.hdfs_username == "")) { + spdlog::warn("HDFS is enabled, but the HDFS IP address or username " "were not provided."); } #ifndef USE_HDFS - if (user_config.use_hdfs){ - throw std::runtime_error("you are trying to use HDFS, but Daphne was not build with --hdfs option\n"); + if (user_config.use_hdfs) { + throw std::runtime_error("you are trying to use HDFS, but Daphne was " "not built with --hdfs option\n"); } #endif for (auto explain : explainArgList) { switch (explain) { - case kernels: - user_config.explain_kernels = true; - break; - case llvm: - user_config.explain_llvm = true; - break; - case parsing: - user_config.explain_parsing = true; - break; - case parsing_simplified: - user_config.explain_parsing_simplified = true; - break; - case property_inference: - user_config.explain_property_inference = true; - break; - case select_matrix_repr: - user_config.explain_select_matrix_repr = true; - break; - case sql: - user_config.explain_sql = true; - break; - case phy_op_selection: - user_config.explain_phy_op_selection = true; - break; - case type_adaptation: - user_config.explain_type_adaptation = true; - break; - case vectorized: - user_config.explain_vectorized = true; - break; - case obj_ref_mgnt: - user_config.explain_obj_ref_mgnt = true; - break; - case mlir_codegen: - user_config.explain_mlir_codegen = true; - break; + case kernels: + user_config.explain_kernels = true; + break; + case llvm: + user_config.explain_llvm = true; + break; + case parsing: + user_config.explain_parsing = true; + break; + case parsing_simplified: + user_config.explain_parsing_simplified = true; + break; + case property_inference: + user_config.explain_property_inference = true; + break; + case select_matrix_repr: + user_config.explain_select_matrix_repr = true; + break; + case sql: + user_config.explain_sql = true; + break; + case phy_op_selection: + user_config.explain_phy_op_selection = true; + break; + case type_adaptation: + user_config.explain_type_adaptation = true; + break; + case vectorized: + user_config.explain_vectorized = true; + break; + case obj_ref_mgnt: + user_config.explain_obj_ref_mgnt = true; + break; + case mlir_codegen: + user_config.explain_mlir_codegen = true; + break; } } user_config.statistics = enableStatistics; - if(user_config.use_distributed && distributedBackEndSetup==ALLOCATION_TYPE::DIST_MPI) - { + if (user_config.use_distributed && distributedBackEndSetup == ALLOCATION_TYPE::DIST_MPI) { #ifndef USE_MPI - throw std::runtime_error("you are trying to use the MPI backend. But, Daphne was not build with --mpi option\n"); + throw std::runtime_error("you are trying to use the MPI backend, 
But, " + "Daphne was not build with --mpi option\n"); #else - MPI_Init(NULL,NULL); + MPI_Init(NULL, NULL); MPI_Comm_rank(MPI_COMM_WORLD, id); - int size=0; + int size = 0; MPI_Comm_size(MPI_COMM_WORLD, &size); - if(size<=1) - { - throw std::runtime_error("you need to rerun with at least 2 MPI ranks (1 Master + 1 Worker)\n"); + if (size <= 1) { + throw std::runtime_error("you need to rerun with at least 2 MPI " + "ranks (1 Master + 1 Worker)\n"); } - if(*id!=COORDINATOR) - { - return *id; + if (*id != COORDINATOR) { + return *id; } -#endif +#endif } - if(cuda) { + if (cuda) { int device_count = 0; #ifdef USE_CUDA CHECK_CUDART(cudaGetDeviceCount(&device_count)); #endif - if(device_count < 1) - spdlog::warn("CUDA ops requested by user option but no suitable device found"); + if (device_count < 1) + spdlog::warn("CUDA ops requested by user option but no suitable " + "device found"); else { user_config.use_cuda = true; } } - if(fpgaopencl) { + if (fpgaopencl) { user_config.use_fpgaopencl = true; } - if(enableProfiling) { + if (enableProfiling) { #ifndef USE_PAPI - throw std::runtime_error("you are trying to use profiling, but daphne was built with --no-papi\n"); + throw std::runtime_error("you are trying to use profiling, but daphne " + "was built with --no-papi\n"); #else user_config.enable_profiling = true; #endif @@ -617,8 +531,7 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int try { parseScriptArgs(scriptArgs2, scriptArgsFinal); parseScriptArgs(scriptArgs1, scriptArgsFinal); - } - catch(exception& e) { + } catch (exception &e) { logErrorDaphneLibAware(daphneLibRes, "Parser error: " + std::string(e.what())); return StatusCode::PARSER_ERROR; } @@ -629,20 +542,20 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int // Creates an MLIR context and loads the required MLIR dialects. 
DaphneIrExecutor executor(selectMatrixRepr, user_config); - mlir::MLIRContext * mctx = executor.getContext(); + mlir::MLIRContext *mctx = executor.getContext(); // ************************************************************************ // Populate kernel extension catalog // ************************************************************************ - KernelCatalog & kc = executor.getUserConfig().kernelCatalog; + KernelCatalog &kc = executor.getUserConfig().kernelCatalog; // kc.dump(); KernelCatalogParser kcp(mctx); kcp.parseKernelCatalog(user_config.libdir + "/catalog.json", kc); - if(user_config.use_cuda) + if (user_config.use_cuda) kcp.parseKernelCatalog(user_config.libdir + "/CUDAcatalog.json", kc); // kc.dump(); - if(!kernelExt.empty()) + if (!kernelExt.empty()) kcp.parseKernelCatalog(kernelExt, kc); // ************************************************************************ @@ -657,7 +570,7 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int OpBuilder builder(mctx); auto loc = mlir::FileLineColLoc::get(builder.getStringAttr(inputFile), 0, 0); auto moduleOp = ModuleOp::create(loc); - auto * body = moduleOp.getBody(); + auto *body = moduleOp.getBody(); builder.setInsertionPoint(body, body->begin()); // Parse the input file and generate the corresponding DaphneIR operations @@ -665,8 +578,7 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int DaphneDSLParser parser(scriptArgsFinal, user_config); try { parser.parseFile(builder, inputFile); - } - catch(std::exception & e) { + } catch (std::exception &e) { logErrorDaphneLibAware(daphneLibRes, "While parsing: " + std::string(e.what())); return StatusCode::PARSER_ERROR; } @@ -674,16 +586,14 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int clock::time_point tpBegComp = clock::now(); // Further, process the module, including optimization and lowering passes. - try{ + try { if (!executor.runPasses(moduleOp)) { return StatusCode::PASS_ERROR; } } catch (std::exception &e) { - logErrorDaphneLibAware( - daphneLibRes, - "Lowering pipeline error.{}\nPassManager failed module lowering, " - "responsible IR written to module_fail.log.\n" + std::string(e.what()) - ); + logErrorDaphneLibAware(daphneLibRes, "Lowering pipeline error.{}\nPassManager failed module lowering, " + "responsible IR written to module_fail.log.\n" + + std::string(e.what())); return StatusCode::PASS_ERROR; } catch (...) { logErrorDaphneLibAware(daphneLibRes, "Lowering pipeline error: Unknown exception"); @@ -693,91 +603,86 @@ int startDAPHNE(int argc, const char** argv, DaphneLibResult* daphneLibRes, int // JIT-compile the module and execute it. // module->dump(); // print the LLVM IR representation clock::time_point tpBegExec; - try{ + try { auto engine = executor.createExecutionEngine(moduleOp); tpBegExec = clock::now(); - // set jump address for catching exceptions in kernel libraries via signal handling - if(setjmp(return_from_handler) == 0) { + // set jump address for catching exceptions in kernel libraries via + // signal handling + if (setjmp(return_from_handler) == 0) { auto error = engine->invoke("main"); if (error) { llvm::errs() << "JIT-Engine invocation failed: " << error; return StatusCode::EXECUTION_ERROR; } - } - else { - logErrorDaphneLibAware( - daphneLibRes, - "Got an abort signal from the execution engine. Most likely an " - "exception in a shared library. 
Check logs!\n" - "Execution error: Returning from signal " + std::to_string(gSignalStatus) - ); + } else { + logErrorDaphneLibAware(daphneLibRes, "Got an abort signal from the execution engine. Most likely an " + "exception in a shared library. Check logs!\n" + "Execution error: Returning from signal " + + std::to_string(gSignalStatus)); return StatusCode::EXECUTION_ERROR; } - } - catch (std::runtime_error& re) { + } catch (std::runtime_error &re) { logErrorDaphneLibAware(daphneLibRes, "Execution error: " + std::string(re.what())); return StatusCode::EXECUTION_ERROR; - } - catch(std::exception & e){ + } catch (std::exception &e) { logErrorDaphneLibAware(daphneLibRes, "Execution error " + std::string(e.what())); return StatusCode::EXECUTION_ERROR; } clock::time_point tpEnd = clock::now(); - if(timing) { + if (timing) { // Calculate durations of the individual high-level steps of DAPHNE. - double durStrt = chrono::duration_cast>(tpBegPars - tpBeg ).count(); - double durPars = chrono::duration_cast>(tpBegComp - tpBegPars).count(); - double durComp = chrono::duration_cast>(tpBegExec - tpBegComp).count(); - double durExec = chrono::duration_cast>(tpEnd - tpBegExec).count(); - double durTotal = chrono::duration_cast>(tpEnd - tpBeg ).count(); + double durStrt = chrono::duration_cast>(tpBegPars - tpBeg).count(); + double durPars = chrono::duration_cast>(tpBegComp - tpBegPars).count(); + double durComp = chrono::duration_cast>(tpBegExec - tpBegComp).count(); + double durExec = chrono::duration_cast>(tpEnd - tpBegExec).count(); + double durTotal = chrono::duration_cast>(tpEnd - tpBeg).count(); // ToDo: use logger // Output durations in JSON. std::cerr << "{"; - std::cerr << "\"startup_seconds\": " << durStrt << ", "; - std::cerr << "\"parsing_seconds\": " << durPars << ", "; - std::cerr << "\"compilation_seconds\": " << durComp << ", "; - std::cerr << "\"execution_seconds\": " << durExec << ", "; - std::cerr << "\"total_seconds\": " << durTotal; + std::cerr << "\"startup_seconds\": " << durStrt << ", "; + std::cerr << "\"parsing_seconds\": " << durPars << ", "; + std::cerr << "\"compilation_seconds\": " << durComp << ", "; + std::cerr << "\"execution_seconds\": " << durExec << ", "; + std::cerr << "\"total_seconds\": " << durTotal; std::cerr << "}" << std::endl; } if (user_config.statistics) Statistics::instance().dumpStatistics(KernelDispatchMapping::instance()); - // explicitly destroying the moduleOp here due to valgrind complaining about a memory leak otherwise. + // explicitly destroying the moduleOp here due to valgrind complaining about + // a memory leak otherwise. moduleOp->destroy(); return StatusCode::SUCCESS; } - -int mainInternal(int argc, const char** argv, DaphneLibResult* daphneLibRes){ - int id=-1; // this -1 would not change if the user did not select mpi backend during execution +int mainInternal(int argc, const char **argv, DaphneLibResult *daphneLibRes) { + int id = -1; // this -1 would not change if the user did not select mpi + // backend during execution // Initialize user configuration. 
DaphneUserConfig user_config{}; - int res=startDAPHNE(argc, argv, daphneLibRes, &id, user_config); + int res = startDAPHNE(argc, argv, daphneLibRes, &id, user_config); -#ifdef USE_MPI - if(id==COORDINATOR) - { - int size=0; +#ifdef USE_MPI + if (id == COORDINATOR) { + int size = 0; MPI_Comm_size(MPI_COMM_WORLD, &size); - unsigned char terminateMessage=0x00; - for(int i=1;i<size;i++){ - MPI_Send(&terminateMessage,1, MPI_UNSIGNED_CHAR, i, DETACH, MPI_COMM_WORLD); - } - MPI_Finalize(); - } - else if(id>-1){ + unsigned char terminateMessage = 0x00; + for (int i = 1; i < size; i++) { + MPI_Send(&terminateMessage, 1, MPI_UNSIGNED_CHAR, i, DETACH, MPI_COMM_WORLD); + } + MPI_Finalize(); + } else if (id > -1) { MPIWorker worker(user_config); worker.joinComputingTeam(); - res=StatusCode::SUCCESS; + res = StatusCode::SUCCESS; MPI_Finalize(); } #endif - + return res; } diff --git a/src/api/internal/daphne_internal.h b/src/api/internal/daphne_internal.h index 737b0383f..c237ff35d 100644 --- a/src/api/internal/daphne_internal.h +++ b/src/api/internal/daphne_internal.h @@ -18,4 +18,4 @@ #include -int mainInternal(int argc, const char** argv, DaphneLibResult* daphneLibRes); \ No newline at end of file +int mainInternal(int argc, const char **argv, DaphneLibResult *daphneLibRes); \ No newline at end of file diff --git a/src/compiler/catalog/KernelCatalog.h b/src/compiler/catalog/KernelCatalog.h index 258cb8d78..f15a031bc 100644 --- a/src/compiler/catalog/KernelCatalog.h +++ b/src/compiler/catalog/KernelCatalog.h @@ -60,15 +60,9 @@ struct KernelInfo { */ const std::string libPath; - KernelInfo( - const std::string kernelFuncName, - const std::vector resTypes, - const std::vector argTypes, - const std::string backend, - const std::string libPath - ) : - kernelFuncName(kernelFuncName), resTypes(resTypes), argTypes(argTypes), backend(backend), libPath(libPath) - { + KernelInfo(const std::string kernelFuncName, const std::vector resTypes, + const std::vector argTypes, const std::string backend, const std::string libPath) : kernelFuncName(kernelFuncName), resTypes(resTypes), argTypes(argTypes), backend(backend), libPath(libPath) { // } }; @@ -78,44 +72,46 @@ struct KernelInfo { */ class KernelCatalog { /** - * @brief The central data structure mapping DaphneIR operations to registered kernels. - * - * The DaphneIR operation is represented by its mnemonic. The kernels are represented - * by their kernel information. + * @brief The central data structure mapping DaphneIR operations to + * registered kernels. + * + * The DaphneIR operation is represented by its mnemonic. The kernels are + * represented by their kernel information. */ std::unordered_map> kernelInfosByOp; /** * @brief Prints the given kernel information. - * + * * @param opMnemonic The mnemonic of the corresponding DaphneIR operation. * @param kernelInfos The kernel information to print. * @param os The stream to print to. Defaults to `std::cerr`. 
*/ - void dumpKernelInfos(const std::string & opMnemonic, const std::vector & kernelInfos, std::ostream & os = std::cerr) const { + void dumpKernelInfos(const std::string &opMnemonic, const std::vector &kernelInfos, + std::ostream &os = std::cerr) const { os << "- operation `" << opMnemonic << "` (" << kernelInfos.size() << " kernels)" << std::endl; - for(KernelInfo ki : kernelInfos) { + for (KernelInfo ki : kernelInfos) { os << " - kernel `" << ki.kernelFuncName << "`: ("; - for(size_t i = 0; i < ki.argTypes.size(); i++) { + for (size_t i = 0; i < ki.argTypes.size(); i++) { os << ki.argTypes[i]; - if(i < ki.argTypes.size() - 1) + if (i < ki.argTypes.size() - 1) os << ", "; } os << ") -> ("; - for(size_t i = 0; i < ki.resTypes.size(); i++) { + for (size_t i = 0; i < ki.resTypes.size(); i++) { os << ki.resTypes[i]; - if(i < ki.resTypes.size() - 1) + if (i < ki.resTypes.size() - 1) os << ", "; } - os << ") for backend `" << ki.backend << "` (in `" << ki.libPath << "`)" << std::endl; + os << ") for backend `" << ki.backend << "` (in `" << ki.libPath << "`)" << std::endl; } } -public: + public: /** - * @brief Registers the given kernel information as a kernel for the DaphneIR - * operation with the given mnemonic. - * + * @brief Registers the given kernel information as a kernel for the + * DaphneIR operation with the given mnemonic. + * * @param opMnemonic The DaphneIR operation's mnemonic. * @param kernelInfo The information on the kernel. */ @@ -124,15 +120,16 @@ class KernelCatalog { } /** - * @brief Retrieves information on all kernels registered for the given DaphneIR operation. - * + * @brief Retrieves information on all kernels registered for the given + * DaphneIR operation. + * * @param opMnemonic The mnemonic of the DaphneIR operation. - * @return A vector of kernel information, or an empty vector if no kernels are registered - * for the given operation. + * @return A vector of kernel information, or an empty vector if no kernels + * are registered for the given operation. */ - const std::vector getKernelInfos(const std::string & opMnemonic) const { + const std::vector getKernelInfos(const std::string &opMnemonic) const { auto it = kernelInfosByOp.find(opMnemonic); - if(it != kernelInfosByOp.end()) + if (it != kernelInfosByOp.end()) return it->second; else return {}; @@ -145,44 +142,43 @@ class KernelCatalog { * @param kernelFuncName The name of the kernel function to look for. * @return The mnemonic of the operation. */ - std::string getOpMnemonic(const std::string & kernelFuncName) { - for(auto it : kernelInfosByOp) { + std::string getOpMnemonic(const std::string &kernelFuncName) { + for (auto it : kernelInfosByOp) { std::string opMnemonic = it.first; - const std::vector & kis = it.second; - for(auto it2 : kis) - if(it2.kernelFuncName == kernelFuncName) + const std::vector &kis = it.second; + for (auto it2 : kis) + if (it2.kernelFuncName == kernelFuncName) return opMnemonic; } - throw std::runtime_error( - "no kernel with name `" + kernelFuncName + "` registered in the kernel catalog" - ); + throw std::runtime_error("no kernel with name `" + kernelFuncName + "` registered in the kernel catalog"); } /** * @brief Prints high-level statistics on the kernel catalog. - * + * * @param os The stream to print to. Defaults to `std::cerr`. 
*/ - void stats(std::ostream & os = std::cerr) const { + void stats(std::ostream &os = std::cerr) const { const size_t numOps = kernelInfosByOp.size(); size_t numKernels = 0; - for(auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) + for (auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) numKernels += it->second.size(); os << "KernelCatalog (" << numOps << " ops, " << numKernels << " kernels)" << std::endl; } /** * @brief Prints this kernel catalog. - * - * @param opMnemonic If an empty string, print registered kernels for all DaphneIR - * operations; otherwise, consider only the specified DaphneIR operation. + * + * @param opMnemonic If an empty string, print registered kernels for all + * DaphneIR operations; otherwise, consider only the specified DaphneIR + * operation. * @param os The stream to print to. Defaults to `std::cerr`. */ - void dump(std::string opMnemonic = "", std::ostream & os = std::cerr) const { + void dump(std::string opMnemonic = "", std::ostream &os = std::cerr) const { stats(os); - if(opMnemonic.empty()) + if (opMnemonic.empty()) // Print info on all ops. - for(auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) + for (auto it = kernelInfosByOp.begin(); it != kernelInfosByOp.end(); it++) dumpKernelInfos(it->first, it->second, os); else // Print info on specified op only. @@ -190,17 +186,18 @@ class KernelCatalog { } /** - * @brief Returns all distinct kernel libraries in the form of a mapping from - * the library path to the constant `false`. + * @brief Returns all distinct kernel libraries in the form of a mapping + * from the library path to the constant `false`. * - * @return A mapping from each distict kernel library path to the constant `false`. + * @return A mapping from each distinct kernel library path to the constant + * `false`. 
*/ std::unordered_map getLibPaths() const { std::unordered_map res; - for(auto it : kernelInfosByOp) { - const std::vector & kis = it.second; - for(auto it2 : kis) + for (auto it : kernelInfosByOp) { + const std::vector &kis = it.second; + for (auto it2 : kis) res[it2.libPath] = false; } diff --git a/src/compiler/execution/DaphneIrExecutor.cpp b/src/compiler/execution/DaphneIrExecutor.cpp index b19f76ba7..67fdac21d 100644 --- a/src/compiler/execution/DaphneIrExecutor.cpp +++ b/src/compiler/execution/DaphneIrExecutor.cpp @@ -18,14 +18,13 @@ #include #include -#include #include +#include #include #include #include -#include "llvm/Support/TargetSelect.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" @@ -47,13 +46,13 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/TargetSelect.h" -DaphneIrExecutor::DaphneIrExecutor(bool selectMatrixRepresentations, - DaphneUserConfig cfg) - : userConfig_(std::move(cfg)), - selectMatrixRepresentations_(selectMatrixRepresentations) { +DaphneIrExecutor::DaphneIrExecutor(bool selectMatrixRepresentations, DaphneUserConfig cfg) + : userConfig_(std::move(cfg)), selectMatrixRepresentations_(selectMatrixRepresentations) { // register loggers - if (userConfig_.log_ptr != nullptr) userConfig_.log_ptr->registerLoggers(); + if (userConfig_.log_ptr != nullptr) + userConfig_.log_ptr->registerLoggers(); context_.getOrLoadDialect(); context_.getOrLoadDialect(); @@ -78,49 +77,34 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { // return false; //} - if (!module) return false; + if (!module) + return false; // This flag is really useful to figure out why the lowering failed llvm::DebugFlag = userConfig_.debug_llvm; - { - mlir::PassManager pm(&context_); - // TODO Enable the verifier for all passes where it is possible. - // Originally, it was only turned off for the - // SpecializeGenericFunctionsPass. - pm.enableVerifier(false); - if (userConfig_.explain_parsing) - pm.addPass(mlir::daphne::createPrintIRPass("IR after parsing:")); + mlir::PassManager pm(&context_); + // TODO Enable the verifier for all passes where it is possible. + // Originally, it was only turned off for the + // SpecializeGenericFunctionsPass. + pm.enableVerifier(false); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - if (userConfig_.explain_parsing_simplified) - pm.addPass(mlir::daphne::createPrintIRPass( - "IR after parsing and some simplifications:")); - - pm.addPass(mlir::daphne::createRewriteSqlOpPass()); // calls SQL Parser - if (userConfig_.explain_sql) - pm.addPass( - mlir::daphne::createPrintIRPass("IR after SQL parsing:")); - - pm.addPass( - mlir::daphne::createSpecializeGenericFunctionsPass(userConfig_)); - if (userConfig_.explain_property_inference) - pm.addPass(mlir::daphne::createPrintIRPass("IR after inference:")); - - try { - if (failed(pm.run(module))) { - module->dump(); - module->emitError("module pass error"); - return false; - } - } catch(...) 
{ - ErrorHandler::dumpModuleToDisk(module); - throw; - } - } + if (userConfig_.explain_parsing) + pm.addPass(mlir::daphne::createPrintIRPass("IR after parsing:")); + + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + if (userConfig_.explain_parsing_simplified) + pm.addPass(mlir::daphne::createPrintIRPass("IR after parsing and some simplifications:")); + + pm.addPass(mlir::daphne::createRewriteSqlOpPass()); // calls SQL Parser + if (userConfig_.explain_sql) + pm.addPass(mlir::daphne::createPrintIRPass("IR after SQL parsing:")); + + pm.addPass(mlir::daphne::createSpecializeGenericFunctionsPass(userConfig_)); + if (userConfig_.explain_property_inference) + pm.addPass(mlir::daphne::createPrintIRPass("IR after inference:")); - mlir::PassManager pm(&context_); // Note that property inference and canonicalization have already been done // in the SpecializeGenericFunctionsPass, so actually, it's not necessary // here anymore. @@ -137,22 +121,18 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { } if (userConfig_.explain_select_matrix_repr) - pm.addPass(mlir::daphne::createPrintIRPass( - "IR after selecting matrix representations:")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after selecting matrix representations:")); if (userConfig_.use_phy_op_selection) { pm.addPass(mlir::daphne::createPhyOperatorSelectionPass()); pm.addPass(mlir::createCSEPass()); } if (userConfig_.explain_phy_op_selection) - pm.addPass(mlir::daphne::createPrintIRPass( - "IR after selecting physical operators:")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after selecting physical operators:")); - pm.addNestedPass( - mlir::daphne::createAdaptTypesToKernelsPass()); + pm.addNestedPass(mlir::daphne::createAdaptTypesToKernelsPass()); if (userConfig_.explain_type_adaptation) - pm.addPass( - mlir::daphne::createPrintIRPass("IR after type adaptation:")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after type adaptation:")); // For now, in order to use the distributed runtime we also require the // vectorized engine to be enabled to create pipelines. 
Therefore, *if* @@ -160,8 +140,7 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { if (userConfig_.use_vectorized_exec || userConfig_.use_distributed) { // TODO: add inference here if we have rewrites that could apply to // vectorized pipelines due to smaller sizes - pm.addNestedPass( - mlir::daphne::createVectorizeComputationsPass()); + pm.addNestedPass(mlir::daphne::createVectorizeComputationsPass()); pm.addPass(mlir::createCanonicalizerPass()); } if (userConfig_.explain_vectorized) @@ -170,25 +149,22 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { if (userConfig_.use_distributed) pm.addPass(mlir::daphne::createDistributePipelinesPass()); - if (userConfig_.use_mlir_codegen || userConfig_.use_mlir_hybrid_codegen) buildCodegenPipeline(pm); + if (userConfig_.use_mlir_codegen || userConfig_.use_mlir_hybrid_codegen) + buildCodegenPipeline(pm); if (userConfig_.enable_profiling) - pm.addNestedPass( - mlir::daphne::createProfilingPass()); + pm.addNestedPass(mlir::daphne::createProfilingPass()); - pm.addNestedPass( - mlir::daphne::createInsertDaphneContextPass(userConfig_)); + pm.addNestedPass(mlir::daphne::createInsertDaphneContextPass(userConfig_)); #ifdef USE_CUDA if (userConfig_.use_cuda) - pm.addNestedPass( - mlir::daphne::createMarkCUDAOpsPass(userConfig_)); + pm.addNestedPass(mlir::daphne::createMarkCUDAOpsPass(userConfig_)); #endif #ifdef USE_FPGAOPENCL if (userConfig_.use_fpgaopencl) - pm.addNestedPass( - mlir::daphne::createMarkFPGAOPENCLOpsPass(userConfig_)); + pm.addNestedPass(mlir::daphne::createMarkFPGAOPENCLOpsPass(userConfig_)); #endif // Tidy up the IR before managing object reference counters with IncRefOp @@ -200,21 +176,16 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { pm.addPass(mlir::createCSEPass()); if (userConfig_.use_obj_ref_mgnt) - pm.addNestedPass( - mlir::daphne::createManageObjRefsPass()); + pm.addNestedPass(mlir::daphne::createManageObjRefsPass()); if (userConfig_.explain_obj_ref_mgnt) - pm.addPass(mlir::daphne::createPrintIRPass( - "IR after managing object references:")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after managing object references:")); - pm.addNestedPass( - mlir::daphne::createRewriteToCallKernelOpPass(userConfig_, usedLibPaths)); + pm.addNestedPass(mlir::daphne::createRewriteToCallKernelOpPass(userConfig_, usedLibPaths)); if (userConfig_.explain_kernels) - pm.addPass( - mlir::daphne::createPrintIRPass("IR after kernel lowering:")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after kernel lowering:")); pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addNestedPass( - mlir::LLVM::createRequestCWrappersPass()); + pm.addNestedPass(mlir::LLVM::createRequestCWrappersPass()); pm.addPass(mlir::daphne::createLowerToLLVMPass(userConfig_)); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); if (userConfig_.explain_llvm) @@ -222,7 +193,7 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { // Initialize the use of each distinct kernels library to false. usedLibPaths = userConfig_.kernelCatalog.getLibPaths(); - + try { if (failed(pm.run(module))) { module->dump(); @@ -237,9 +208,9 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) { return true; } -std::unique_ptr DaphneIrExecutor::createExecutionEngine( - mlir::ModuleOp module) { - if (!module) return nullptr; +std::unique_ptr DaphneIrExecutor::createExecutionEngine(mlir::ModuleOp module) { + if (!module) + return nullptr; // An optimization pipeline to use within the execution engine. 
unsigned optLevel = 0; unsigned sizeLevel = 0; @@ -248,19 +219,17 @@ std::unique_ptr DaphneIrExecutor::createExecutionEngine( // Determine the actually used kernels libraries. std::vector sharedLibRefs; - for(auto it = usedLibPaths.begin(); it != usedLibPaths.end(); it++) - if(it->second) { + for (auto it = usedLibPaths.begin(); it != usedLibPaths.end(); it++) + if (it->second) { std::string usedLibPath = it->first; sharedLibRefPaths.push_back(usedLibPath); sharedLibRefs.emplace_back(sharedLibRefPaths.back()); - // Check if the used kernels library really exists at the expected path - // and throw an understandable error, otherwise. - if(!std::filesystem::exists(usedLibPath)) - throw std::runtime_error( - "the shared library `" + usedLibPath + - "` is needed for some kernel, but the file does not exist" - ); + // Check if the used kernels library really exists at the expected + // path and throw an understandable error, otherwise. + if (!std::filesystem::exists(usedLibPath)) + throw std::runtime_error("the shared library `" + usedLibPath + + "` is needed for some kernel, but the file does not exist"); } registerLLVMDialectTranslation(context_); @@ -276,8 +245,7 @@ std::unique_ptr DaphneIrExecutor::createExecutionEngine( auto maybeEngine = mlir::ExecutionEngine::create(module, options); if (!maybeEngine) { - llvm::errs() << "Failed to create JIT-Execution engine: " - << maybeEngine.takeError(); + llvm::errs() << "Failed to create JIT-Execution engine: " << maybeEngine.takeError(); return nullptr; } return std::move(maybeEngine.get()); @@ -285,8 +253,7 @@ std::unique_ptr DaphneIrExecutor::createExecutionEngine( void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { if (userConfig_.explain_mlir_codegen) - pm.addPass( - mlir::daphne::createPrintIRPass("IR before codegen pipeline")); + pm.addPass(mlir::daphne::createPrintIRPass("IR before codegen pipeline")); pm.addPass(mlir::daphne::createDaphneOptPass()); pm.addPass(mlir::daphne::createEwOpLoweringPass()); @@ -298,28 +265,23 @@ void DaphneIrExecutor::buildCodegenPipeline(mlir::PassManager &pm) { if (!userConfig_.use_mlir_hybrid_codegen) { pm.addPass(mlir::daphne::createMatMulOpLoweringPass( - userConfig_.matmul_tile, userConfig_.matmul_vec_size_bits, - userConfig_.matmul_fixed_tile_sizes, - userConfig_.matmul_use_fixed_tile_sizes, - userConfig_.matmul_unroll_factor, userConfig_.matmul_unroll_jam_factor, - userConfig_.matmul_num_vec_registers, - userConfig_.matmul_invert_loops)); + userConfig_.matmul_tile, userConfig_.matmul_vec_size_bits, userConfig_.matmul_fixed_tile_sizes, + userConfig_.matmul_use_fixed_tile_sizes, userConfig_.matmul_unroll_factor, + userConfig_.matmul_unroll_jam_factor, userConfig_.matmul_num_vec_registers, + userConfig_.matmul_invert_loops)); if (userConfig_.explain_mlir_codegen) - pm.addPass( - mlir::daphne::createPrintIRPass("IR directly after lowering MatMulOp.")); + pm.addPass(mlir::daphne::createPrintIRPass("IR directly after lowering MatMulOp.")); } pm.addPass(mlir::createConvertMathToLLVMPass()); pm.addPass(mlir::daphne::createModOpLoweringPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addNestedPass( - mlir::createAffineScalarReplacementPass()); + pm.addNestedPass(mlir::createAffineScalarReplacementPass()); pm.addPass(mlir::createLowerAffinePass()); mlir::LowerVectorToLLVMOptions lowerVectorToLLVMOptions; pm.addPass(mlir::createConvertVectorToLLVMPass(lowerVectorToLLVMOptions)); - + if (userConfig_.explain_mlir_codegen) - pm.addPass( - 
mlir::daphne::createPrintIRPass("IR after codegen pipeline")); + pm.addPass(mlir::daphne::createPrintIRPass("IR after codegen pipeline")); } diff --git a/src/compiler/execution/DaphneIrExecutor.h b/src/compiler/execution/DaphneIrExecutor.h index b13a23a73..809dafb45 100644 --- a/src/compiler/execution/DaphneIrExecutor.h +++ b/src/compiler/execution/DaphneIrExecutor.h @@ -16,33 +16,27 @@ #pragma once -#include "mlir/IR/BuiltinOps.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" -#include +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" +#include #include -class DaphneIrExecutor -{ -public: +class DaphneIrExecutor { + public: DaphneIrExecutor(bool selectMatrixRepresentations, DaphneUserConfig cfg); bool runPasses(mlir::ModuleOp module); std::unique_ptr createExecutionEngine(mlir::ModuleOp module); - mlir::MLIRContext *getContext() - { return &context_; } + mlir::MLIRContext *getContext() { return &context_; } - DaphneUserConfig & getUserConfig() { - return userConfig_; - } + DaphneUserConfig &getUserConfig() { return userConfig_; } - const DaphneUserConfig & getUserConfig() const { - return userConfig_; - } + const DaphneUserConfig &getUserConfig() const { return userConfig_; } -private: + private: mlir::MLIRContext context_; DaphneUserConfig userConfig_; bool selectMatrixRepresentations_; @@ -50,17 +44,16 @@ class DaphneIrExecutor std::vector sharedLibRefPaths; /** - * @brief A map indicating which of the distinct kernels libraries known to the - * kernel catalog are actually used in the MLIR module. + * @brief A map indicating which of the distinct kernels libraries known to + * the kernel catalog are actually used in the MLIR module. * - * This map gets pre-populated with `false` for each distinct library. The values - * are set to `true` when a call to a pre-compiled kernel from that library is - * created by this pass. This approach is thread-safe, since the structure of the - * map does not change anymore. Thus, it can be used by multiple concurrent - * instances of this pass. + * This map gets pre-populated with `false` for each distinct library. The + * values are set to `true` when a call to a pre-compiled kernel from that + * library is created by the RewriteToCallKernelOpPass pass. This approach + * is thread-safe, since the structure of the map does not change anymore. + * Thus, it can be used by multiple concurrent instances of this pass. */ std::unordered_map usedLibPaths; void buildCodegenPipeline(mlir::PassManager &); }; - diff --git a/src/compiler/explanation/PrintIRPass.cpp b/src/compiler/explanation/PrintIRPass.cpp index 3adf1bf5b..befc773fa 100644 --- a/src/compiler/explanation/PrintIRPass.cpp +++ b/src/compiler/explanation/PrintIRPass.cpp @@ -31,7 +31,7 @@ using namespace mlir; class PrintIRPass : public PassWrapper> { std::string message; - public: + public: PrintIRPass(const std::string message) : message(message) {} void runOnOperation() final; diff --git a/src/compiler/inference/AdaptTypesToKernelsPass.cpp b/src/compiler/inference/AdaptTypesToKernelsPass.cpp index 0e3ec7d4e..414c27827 100644 --- a/src/compiler/inference/AdaptTypesToKernelsPass.cpp +++ b/src/compiler/inference/AdaptTypesToKernelsPass.cpp @@ -26,121 +26,110 @@ using namespace mlir; /** - * @brief Adapts an operation's input/output types such that it can be lowered to an available pre-compiled kernel. - * - * While type inference propagates types through the IR, it is not guaranteed that a pre-compiled kernel - * for each infered type combination is available. 
Thus, the task of this pass is to adapt input and - * output types by casts, where necessary, to ensure that an existing pre-compiled kernel can be used. - * - * At the moment, this pass is implemented in a very simple way. It supports only two concrete actions: - * - It harmonizes the value types of all inputs with those of the single output, for certain operations. - * This is because so far we mainly pre-compile our kernels for homogeneous combinations of input/output - * types. - * - It harmonizes the value types of all inputs (independently of the output type), for certain operations. - * This is because some kernels need to output a different type than their inputs (e.g., comparisons on - * non-numeric value types). - * In general, the affected operations are marked by traits. - * - * In the future, this pass should take the kernel registry and/or extension catalog into account to find - * out for which type combinations there are available kernels. + * @brief Adapts an operation's input/output types such that it can be lowered + * to an available pre-compiled kernel. + * + * While type inference propagates types through the IR, it is not guaranteed + * that a pre-compiled kernel for each inferred type combination is available. + * Thus, the task of this pass is to adapt input and output types by casts, + * where necessary, to ensure that an existing pre-compiled kernel can be used. + * + * At the moment, this pass is implemented in a very simple way. It supports + * only two concrete actions: + * - It harmonizes the value types of all inputs with those of the single + * output, for certain operations. This is because so far we mainly pre-compile + * our kernels for homogeneous combinations of input/output types. + * - It harmonizes the value types of all inputs (independently of the output + * type), for certain operations. This is because some kernels need to output a + * different type than their inputs (e.g., comparisons on non-numeric value + * types). In general, the affected operations are marked by traits. + * + * In the future, this pass should take the kernel registry and/or extension + * catalog into account to find out for which type combinations there are + * available kernels. */ -struct AdaptTypesToKernelsPass : public PassWrapper> -{ +struct AdaptTypesToKernelsPass : public PassWrapper> { void runOnOperation() final; StringRef getArgument() const final { return "adapt-types-to-kernels"; } - StringRef getDescription() const final { - return "TODO"; - } + StringRef getDescription() const final { return "TODO"; } }; -void AdaptTypesToKernelsPass::runOnOperation() -{ +void AdaptTypesToKernelsPass::runOnOperation() { func::FuncOp f = getOperation(); OpBuilder builder(f.getContext()); - f.getBody().front().walk([&](Operation* op) { + f.getBody().front().walk([&](Operation *op) { const size_t numOperands = op->getNumOperands(); - // Depending on the related trait, determine which inputs to cast to which value type. + // Depending on the related trait, determine which inputs to cast to + // which value type. std::vector operandIdxs; // the inputs to cast - Type targetVTy; // the value type to cast to + Type targetVTy; // the value type to cast to - if(op->hasTrait()) { + if (op->hasTrait()) { // The only related trait that does not consider the result type. // TODO Support frame ops. - // Skip frame ops, since we cannot easily cast the column types of frames anyway. 
- if(llvm::any_of(op->getOperands(), [](Value operand){ - return llvm::isa(operand.getType()); - })) + // Skip frame ops, since we cannot easily cast the column types of + // frames anyway. + if (llvm::any_of(op->getOperands(), + [](Value operand) { return llvm::isa(operand.getType()); })) return; // Cast all inputs to the most general input value type. - for(size_t i = 0; i < numOperands; i++) + for (size_t i = 0; i < numOperands; i++) operandIdxs.push_back(i); std::vector argVTys; - for(size_t i = 0; i < numOperands; i++) + for (size_t i = 0; i < numOperands; i++) argVTys.push_back(CompilerUtils::getValueType(op->getOperand(i).getType())); targetVTy = mostGeneralVt(argVTys); - } - else { + } else { // All remaining related traits consider the result type. // Skip operations without results. - if(!op->getNumResults()) + if (!op->getNumResults()) return; Type resTy = op->getResult(0).getType(); // TODO Support frame ops. - // Skip frame ops, since we cannot easily cast the column types of frames anyway. - if( - llvm::isa(resTy) || - llvm::any_of(op->getOperands(), [](Value operand){ + // Skip frame ops, since we cannot easily cast the column types of + // frames anyway. + if (llvm::isa(resTy) || llvm::any_of(op->getOperands(), [](Value operand) { return llvm::isa(operand.getType()); - }) - ) + })) return; Type resVTy = CompilerUtils::getValueType(resTy); - if(op->hasTrait()) { + if (op->hasTrait()) { // Cast all inputs to the result value type. - for(size_t i = 0; i < numOperands; i++) + for (size_t i = 0; i < numOperands; i++) operandIdxs.push_back(i); targetVTy = resVTy; - } - else if(op->hasTrait()) { + } else if (op->hasTrait()) { // Cast inputs 0 and 1 to the result value type. operandIdxs = {0, 1}; targetVTy = resVTy; } - // TODO Instead of such a non-reusable op-specific trait, we should rather check for the concrete op here. - else if(op->hasTrait()) { + // TODO Instead of such a non-reusable op-specific trait, we should + // rather check for the concrete op here. + else if (op->hasTrait()) { // Cast inputs 2 and 3 to the result value type. operandIdxs = {2, 3}; targetVTy = resVTy; } } - if(!operandIdxs.empty()) { + if (!operandIdxs.empty()) { // Insert casts where necessary. builder.setInsertionPoint(op); - for(size_t i : operandIdxs) { + for (size_t i : operandIdxs) { Value argVal = op->getOperand(i); Type argTy = argVal.getType(); - if(CompilerUtils::getValueType(argTy) != targetVTy) { - op->setOperand( - i, - builder.create( - argVal.getLoc(), - CompilerUtils::setValueType(argTy, targetVTy), - argVal - ) - ); + if (CompilerUtils::getValueType(argTy) != targetVTy) { + op->setOperand(i, builder.create( + argVal.getLoc(), CompilerUtils::setValueType(argTy, targetVTy), argVal)); } } } }); } -std::unique_ptr daphne::createAdaptTypesToKernelsPass() -{ - return std::make_unique(); -} +std::unique_ptr daphne::createAdaptTypesToKernelsPass() { return std::make_unique(); } diff --git a/src/compiler/inference/InferencePass.cpp b/src/compiler/inference/InferencePass.cpp index 192890783..2fe63de74 100644 --- a/src/compiler/inference/InferencePass.cpp +++ b/src/compiler/inference/InferencePass.cpp @@ -14,121 +14,105 @@ * limitations under the License. 
*/ -#include #include #include +#include #include #include #include #include -#include #include -#include +#include #include +#include using namespace mlir; -daphne::InferenceConfig::InferenceConfig(bool partialInferenceAllowed, - bool typeInference, - bool shapeInference, - bool frameLabelInference, - bool sparsityInference) +daphne::InferenceConfig::InferenceConfig(bool partialInferenceAllowed, bool typeInference, bool shapeInference, + bool frameLabelInference, bool sparsityInference) : partialInferenceAllowed(partialInferenceAllowed), typeInference(typeInference), shapeInference(shapeInference), frameLabelInference(frameLabelInference), sparsityInference(sparsityInference) {} namespace { - void castOperandIf(OpBuilder & builder, Operation * op, size_t operandIdx, Type type) { - Value operand = op->getOperand(operandIdx); - if(operand.getType() != type) { - builder.setInsertionPoint(op); - op->setOperand( - operandIdx, - // TODO Is this the right loc? - builder.create(op->getLoc(), type, operand) - ); - } +void castOperandIf(OpBuilder &builder, Operation *op, size_t operandIdx, Type type) { + Value operand = op->getOperand(operandIdx); + if (operand.getType() != type) { + builder.setInsertionPoint(op); + op->setOperand(operandIdx, + // TODO Is this the right loc? + builder.create(op->getLoc(), type, operand)); } +} - /** - * @brief Returns a type retaining all common properties of the two - * given types, and setting all mismatching properties to unknown. - * - * If the two given types are of different data types, then `nullptr` - * is returned. - */ - Type getTypeWithCommonInfo(Type t1, Type t2) { - MLIRContext* ctx = t1.getContext(); - Type u = daphne::UnknownType::get(ctx); - auto mat1 = t1.dyn_cast(); - auto mat2 = t2.dyn_cast(); - auto frm1 = t1.dyn_cast(); - auto frm2 = t2.dyn_cast(); - - if(mat1 && mat2) { // both types are matrices - const Type vt1 = mat1.getElementType(); - const Type vt2 = mat2.getElementType(); - const ssize_t nr1 = mat1.getNumRows(); - const ssize_t nr2 = mat2.getNumRows(); - const ssize_t nc1 = mat1.getNumCols(); - const ssize_t nc2 = mat2.getNumCols(); - const ssize_t sp1 = mat1.getSparsity(); - const ssize_t sp2 = mat2.getSparsity(); - const daphne::MatrixRepresentation repr1 = mat1.getRepresentation(); - const daphne::MatrixRepresentation repr2 = mat2.getRepresentation(); - return daphne::MatrixType::get( - ctx, - (vt1 == vt2) ? vt1 : u, - (nr1 == nr2) ? nr1 : -1, - (nc1 == nc2) ? nc1 : -1, - // TODO Maybe do approximate comparison of floating-point values. - (sp1 == sp2) ? sp1 : -1, - (repr1 == repr2) ? repr1 : daphne::MatrixRepresentation::Default - ); - } - else if(frm1 && frm2) { // both types are frames - const std::vector cts1 = frm1.getColumnTypes(); - const std::vector cts2 = frm2.getColumnTypes(); - std::vector cts3; - if(cts1.size() == cts2.size()) - for(size_t i = 0; i < cts1.size(); i++) - cts3.push_back((cts1[i] == cts2[i]) ? cts1[i] : u); - else - // TODO How to represent a frame with unknown column - // types? See #421. - cts3.push_back(u); - const ssize_t nr1 = frm1.getNumRows(); - const ssize_t nr2 = frm2.getNumRows(); - const ssize_t nc1 = frm1.getNumCols(); - const ssize_t nc2 = frm2.getNumCols(); - std::vector* lbls1 = frm1.getLabels(); - std::vector* lbls2 = frm2.getLabels(); - return daphne::FrameType::get( - ctx, - cts3, - (nr1 == nr2) ? nr1 : -1, - (nc1 == nc2) ? nc1 : -1, - // TODO Take #485 into account. - (lbls1 == lbls2) ? 
lbls1 : nullptr - ); - } - else if(mat1 || mat2 || frm1 || frm2) // t1 and t2 are of different data types (matrix, frame, scalar) - return nullptr; - else // both types are unknown or scalars - return (t1 == t2) ? t1 : u; - } +/** + * @brief Returns a type retaining all common properties of the two + * given types, and setting all mismatching properties to unknown. + * + * If the two given types are of different data types, then `nullptr` + * is returned. + */ +Type getTypeWithCommonInfo(Type t1, Type t2) { + MLIRContext *ctx = t1.getContext(); + Type u = daphne::UnknownType::get(ctx); + auto mat1 = t1.dyn_cast(); + auto mat2 = t2.dyn_cast(); + auto frm1 = t1.dyn_cast(); + auto frm2 = t2.dyn_cast(); + + if (mat1 && mat2) { // both types are matrices + const Type vt1 = mat1.getElementType(); + const Type vt2 = mat2.getElementType(); + const ssize_t nr1 = mat1.getNumRows(); + const ssize_t nr2 = mat2.getNumRows(); + const ssize_t nc1 = mat1.getNumCols(); + const ssize_t nc2 = mat2.getNumCols(); + const ssize_t sp1 = mat1.getSparsity(); + const ssize_t sp2 = mat2.getSparsity(); + const daphne::MatrixRepresentation repr1 = mat1.getRepresentation(); + const daphne::MatrixRepresentation repr2 = mat2.getRepresentation(); + return daphne::MatrixType::get(ctx, (vt1 == vt2) ? vt1 : u, (nr1 == nr2) ? nr1 : -1, (nc1 == nc2) ? nc1 : -1, + // TODO Maybe do approximate comparison of floating-point values. + (sp1 == sp2) ? sp1 : -1, + (repr1 == repr2) ? repr1 : daphne::MatrixRepresentation::Default); + } else if (frm1 && frm2) { // both types are frames + const std::vector cts1 = frm1.getColumnTypes(); + const std::vector cts2 = frm2.getColumnTypes(); + std::vector cts3; + if (cts1.size() == cts2.size()) + for (size_t i = 0; i < cts1.size(); i++) + cts3.push_back((cts1[i] == cts2[i]) ? cts1[i] : u); + else + // TODO How to represent a frame with unknown column + // types? See #421. + cts3.push_back(u); + const ssize_t nr1 = frm1.getNumRows(); + const ssize_t nr2 = frm2.getNumRows(); + const ssize_t nc1 = frm1.getNumCols(); + const ssize_t nc2 = frm2.getNumCols(); + std::vector *lbls1 = frm1.getLabels(); + std::vector *lbls2 = frm2.getLabels(); + return daphne::FrameType::get(ctx, cts3, (nr1 == nr2) ? nr1 : -1, (nc1 == nc2) ? nc1 : -1, + // TODO Take #485 into account. + (lbls1 == lbls2) ? lbls1 : nullptr); + } else if (mat1 || mat2 || frm1 || frm2) // t1 and t2 are of different data + // types (matrix, frame, scalar) + return nullptr; + else // both types are unknown or scalars + return (t1 == t2) ? t1 : u; } +} // namespace /** * @brief A compiler pass inferring various properties of the data objects. - * + * * Rooted at a function, the pass walks the operations, and for each operation * it encounters, it infers all currently considered properties of the * operation's results based on the properties of the operation's arguments. * This approach can easily handle dependencies between different properties to * be inferred without explicitly modeling them. - * + * * Note that the actual inference logic is outsourced to MLIR operation * interfaces. */ @@ -136,18 +120,19 @@ class InferencePass : public PassWrapper walkSetUnknown = [&](Operation * op) { - // For all other operations, we reset the types of all results to unknown. - for(size_t i = 0; i < op->getNumResults(); i++) { + std::function walkSetUnknown = [&](Operation *op) { + // For all other operations, we reset the types of all results to + // unknown.
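// Illustrative (derived from how these helpers are used here, not from the
// patch itself): withSameElementType() keeps only the value type and drops
// shape/sparsity/representation, e.g. a result typed as a 5x3 f64 matrix is
// reset to an f64 matrix of unknown shape; withSameColumnTypes() resets a
// frame analogously while keeping its column value types.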
+ for (size_t i = 0; i < op->getNumResults(); i++) { Type t = op->getResult(i).getType(); - if(auto mt = t.dyn_cast()) + if (auto mt = t.dyn_cast()) t = mt.withSameElementType(); - else if(auto ft = t.dyn_cast()) + else if (auto ft = t.dyn_cast()) t = ft.withSameColumnTypes(); op->getResult(i).setType(t); } @@ -157,38 +142,31 @@ class InferencePass : public PassWrapper walkOp = [&](Operation * op) { + std::function walkOp = [&](Operation *op) { const bool isScfOp = op->getDialect() == op->getContext()->getOrLoadDialect(); // ---------------------------------------------------------------- // Handle all non-control-flow (non-SCF) operations // ---------------------------------------------------------------- - if(llvm::isa(op)) { - Type typeWithCommonInfo = getTypeWithCommonInfo( - op->getOperand(1).getType(), - op->getOperand(2).getType() - ); - if(!typeWithCommonInfo) { - throw ErrorHandler::compilerError( - op, "InferencePass.cpp:" + std::to_string(__LINE__), - " a variable must not be assigned values of " - "different data types (matrix, frame, scalar) " - "in then/else branches (arith.select)"); + if (llvm::isa(op)) { + Type typeWithCommonInfo = getTypeWithCommonInfo(op->getOperand(1).getType(), op->getOperand(2).getType()); + if (!typeWithCommonInfo) { + throw ErrorHandler::compilerError(op, "InferencePass.cpp:" + std::to_string(__LINE__), + " a variable must not be assigned values of " + "different data types (matrix, frame, scalar) " + "in then/else branches (arith.select)"); } OpBuilder builder(op->getContext()); castOperandIf(builder, op, 1, typeWithCommonInfo); castOperandIf(builder, op, 2, typeWithCommonInfo); op->getResult(0).setType(typeWithCommonInfo); - } - else if(!isScfOp) { + } else if (!isScfOp) { if (cfg.typeInference && returnsUnknownType(op)) { // Try to infer the types of all results of this operation. try { daphne::setInferedTypes(op, cfg.partialInferenceAllowed); - } - catch (std::runtime_error& re) { - throw ErrorHandler::rethrowError( - "InferencePass.cpp:" + std::to_string(__LINE__), re.what()); + } catch (std::runtime_error &re) { + throw ErrorHandler::rethrowError("InferencePass.cpp:" + std::to_string(__LINE__), re.what()); } } if (cfg.shapeInference && returnsUnknownShape(op)) { @@ -196,18 +174,16 @@ class InferencePass : public PassWrapper> shapes = daphne::tryInferShape(op); const size_t numRes = op->getNumResults(); if (shapes.size() != numRes) { - throw ErrorHandler::compilerError( - op, "InferencePass.cpp:" + std::to_string(__LINE__), - "shape inference for op " + - op->getName().getStringRef().str() + " returned " + - std::to_string(shapes.size()) + - " shapes, but the op has " + - std::to_string(numRes) + " results"); + throw ErrorHandler::compilerError(op, "InferencePass.cpp:" + std::to_string(__LINE__), + "shape inference for op " + op->getName().getStringRef().str() + + " returned " + std::to_string(shapes.size()) + + " shapes, but the op has " + std::to_string(numRes) + + " results"); } // Set the inferred shapes on all results of this operation.
- for(size_t i = 0 ; i < numRes ; i++) { - if(llvm::isa(op->getResultTypes()[i]) || - llvm::isa(op->getResultTypes()[i])) { + for (size_t i = 0; i < numRes; i++) { + if (llvm::isa(op->getResultTypes()[i]) || + llvm::isa(op->getResultTypes()[i])) { const ssize_t numRows = shapes[i].first; const ssize_t numCols = shapes[i].second; Value rv = op->getResult(i); @@ -218,11 +194,11 @@ class InferencePass : public PassWrappergetName().getStringRef().str() + - " operand " + std::to_string(i) + ", since it " - "is neither a matrix nor a frame" - ); + "shape inference cannot set the shape of op " + + op->getName().getStringRef().str() + " operand " + + std::to_string(i) + + ", since it " + "is neither a matrix nor a frame"); } } } @@ -230,18 +206,19 @@ class InferencePass : public PassWrapper sparsities = daphne::tryInferSparsity(op); const size_t numRes = op->getNumResults(); - if(sparsities.size() != numRes) + if (sparsities.size() != numRes) throw ErrorHandler::compilerError(op, "InferencePass", - "sparsity inference for op " + - op->getName().getStringRef().str() + " returned " + - std::to_string(sparsities.size()) + " shapes, but the " - "op has " + std::to_string(numRes) + " results" - ); + "sparsity inference for op " + + op->getName().getStringRef().str() + " returned " + + std::to_string(sparsities.size()) + + " shapes, but the " + "op has " + + std::to_string(numRes) + " results"); // Set the inferred sparsities on all results of this operation. - for(size_t i = 0 ; i < numRes ; i++) { + for (size_t i = 0; i < numRes; i++) { const double sparsity = sparsities[i]; - if(llvm::isa(op->getResultTypes()[i]) || - llvm::isa(op->getResultTypes()[i])) { + if (llvm::isa(op->getResultTypes()[i]) || + llvm::isa(op->getResultTypes()[i])) { Value rv = op->getResult(i); const Type rt = rv.getType(); auto mt = rt.dyn_cast(); @@ -253,16 +230,17 @@ class InferencePass : public PassWrappergetName().getStringRef().str() + - " operand " + std::to_string(i) + ", since it " - "is not a matrix" - ); + "sparsity inference cannot set the shape of " + "op " + + op->getName().getStringRef().str() + " operand " + + std::to_string(i) + + ", since it " + "is not a matrix"); } } } if (cfg.frameLabelInference && returnsFrameWithUnknownLabels(op)) { - if(auto inferFrameLabelsOp = llvm::dyn_cast(op)) + if (auto inferFrameLabelsOp = llvm::dyn_cast(op)) inferFrameLabelsOp.inferFrameLabels(); // Else: Not a problem, since currently we use the frame labels // only to aid type inference, and for this purpose, we don't @@ -274,35 +252,34 @@ class InferencePass : public PassWrapper(op)) { - Block & beforeBlock = whileOp.getBefore().front(); - Block & afterBlock = whileOp.getAfter().front(); + // the then-branch and the value yielded in the else-branch must have + // the same type in MLIR. At the same time, we encode interesting data + // properties (such as those inferred by this pass) as MLIR type + // parameters. As a consequence, e.g., a matrix with two rows and a + // matrix with three rows are technically different MLIR types. Thus, + // e.g., an IfOp cannot simply yield matrices of different shapes from + // the then- and else-branches. To solve this general problem, and to + // allow control-flow operations to change all properties of a data + // object, we generally set mismatching properties to unknown. The + // details depend on the specific SCF operation. 
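// Illustrative scenario (invented for exposition): if a loop body appends one
// row to a matrix per iteration, the yielded value's #rows no longer matches
// the corresponding block argument's #rows. The fix-point search below then
// degrades #rows to unknown on the argument, the yielded value, and the loop
// result, so the IR still type-checks while the unaffected properties
// (#cols, value type, ...) stay precise.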
+ else if (auto whileOp = llvm::dyn_cast(op)) { + Block &beforeBlock = whileOp.getBefore().front(); + Block &afterBlock = whileOp.getAfter().front(); OpBuilder builder(whileOp.getContext()); // Infer the types/properties inside the loop body. If some property - // of some argument is changed inside the loop body, this property is - // set to unknown for both the argument and the yielded value. If that - // is the case, we need to do the inference anew, with the new set of - // arguments' properties. - // This loop searches a fix-point and always terminates, since we only - // set properties to unknown and in the extreme case, after a finite - // number of iterations all of the arguments' properties will have - // become unknown. - while(true) { + // of some argument is changed inside the loop body, this property + // is set to unknown for both the argument and the yielded value. If + // that is the case, we need to do the inference anew, with the new + // set of arguments' properties. This loop searches a fix-point and + // always terminates, since we only set properties to unknown and in + // the extreme case, after a finite number of iterations all of the + // arguments' properties will have become unknown. + while (true) { bool repeat = false; // Transfer the WhileOp's operand types to the block arguments // of the before-block to fulfill constraints on the WhileOp. - for(size_t i = 0; i < whileOp.getNumOperands(); i++) { + for (size_t i = 0; i < whileOp.getNumOperands(); i++) { Type t = whileOp->getOperand(i).getType(); beforeBlock.getArgument(i).setType(t); } @@ -312,17 +289,17 @@ class InferencePass : public PassWrapper(walkOp); // Get the ConditionOp. - Operation * condOp = beforeBlock.getTerminator(); + Operation *condOp = beforeBlock.getTerminator(); - if(!llvm::isa(condOp)) + if (!llvm::isa(condOp)) throw ErrorHandler::compilerError(op, "InferencePass", "WhileOp terminator is not a ConditionOp"); - // Transfer the ConditionOp's operand types to the block arguments - // of the after-block and the results of the WhileOp to fulfill - // constraints on the WhileOp. - // Note that the first operand of the ConditionOp is skipped, since it - // is the condition value itself. - for(size_t i = 1; i < condOp->getNumOperands(); i++) { + // Transfer the ConditionOp's operand types to the block + // arguments of the after-block and the results of the WhileOp + // to fulfill constraints on the WhileOp. Note that the first + // operand of the ConditionOp is skipped, since it is the + // condition value itself. + for (size_t i = 1; i < condOp->getNumOperands(); i++) { Type t = condOp->getOperand(i).getType(); afterBlock.getArgument(i - 1).setType(t); whileOp.getResult(i - 1).setType(t); @@ -333,72 +310,71 @@ class InferencePass : public PassWrapper(walkOp); // Get the YieldOp. - Operation * yieldOp = afterBlock.getTerminator(); + Operation *yieldOp = afterBlock.getTerminator(); - if(whileOp->getNumOperands() != yieldOp->getNumOperands()) - throw ErrorHandler::compilerError( - op, "InferencePass", - "WhileOp and YieldOp must have the same number of " - "operands"); + if (whileOp->getNumOperands() != yieldOp->getNumOperands()) + throw ErrorHandler::compilerError(op, "InferencePass", + "WhileOp and YieldOp must have the same number of " + "operands"); // Check if the inferred MLIR types match the result MLIR types. - // If any interesting properties were changed inside the loop body, - // we set them to unknown to make the type comparison pass. 
- for(size_t i = 0; i < whileOp.getNumOperands(); i++) { + // If any interesting properties were changed inside the loop + // body, we set them to unknown to make the type comparison + // pass. + for (size_t i = 0; i < whileOp.getNumOperands(); i++) { Type yieldedTy = yieldOp->getOperand(i).getType(); Type operandTy = op->getOperand(i).getType(); - if(yieldedTy != operandTy) { - // Get a type with the conflicting properties set to unknown. + if (yieldedTy != operandTy) { + // Get a type with the conflicting properties set to + // unknown. Type typeWithCommonInfo = getTypeWithCommonInfo(yieldedTy, operandTy); - if(!typeWithCommonInfo) { - throw ErrorHandler::compilerError( - op, "InferencePass", - "the data type (matrix, frame, scalar) of a " - "variable " - "must not be changed within the body of a " - "while-loop"); + if (!typeWithCommonInfo) { + throw ErrorHandler::compilerError(op, "InferencePass", + "the data type (matrix, frame, scalar) of a " + "variable " + "must not be changed within the body of a " + "while-loop"); } // Use casts to remove those properties accordingly. castOperandIf(builder, yieldOp, i, typeWithCommonInfo); castOperandIf(builder, whileOp, i, typeWithCommonInfo); - // Since the WhileOp's argument types/properties have changed, - // we must repeat the inference for the loop body. + // Since the WhileOp's argument types/properties have + // changed, we must repeat the inference for the loop + // body. repeat = true; } } - if(repeat) { - // Before we can repeat the inference, we reset all information - // inferred so far to unknown (in the loop body). + if (repeat) { + // Before we can repeat the inference, we reset all + // information inferred so far to unknown (in the loop + // body). beforeBlock.walk(walkSetUnknown); afterBlock.walk(walkSetUnknown); - } - else + } else // If all types matched, we are done. break; } // Tell the walker to skip the descendants of the WhileOp, we // have already triggered a walk on them explicitly. return WalkResult::skip(); - } - else if(auto forOp = llvm::dyn_cast(op)) { - Block & block = forOp.getRegion().front(); + } else if (auto forOp = llvm::dyn_cast(op)) { + Block &block = forOp.getRegion().front(); const size_t numIndVars = forOp.getNumInductionVars(); OpBuilder builder(forOp.getContext()); // Infer the types/properties inside the loop body. If some property - // of some argument is changed inside the loop body, this property is - // set to unknown for both the argument and the yielded value. If that - // is the case, we need to do the inference anew, with the new set of - // arguments' properties. - // This loop searches a fix-point and always terminates, since we only - // set properties to unknown and in the extreme case, after a finite - // number of iterations all of the arguments' properties will have - // become unknown. - while(true) { + // of some argument is changed inside the loop body, this property + // is set to unknown for both the argument and the yielded value. If + // that is the case, we need to do the inference anew, with the new + // set of arguments' properties. This loop searches a fix-point and + // always terminates, since we only set properties to unknown and in + // the extreme case, after a finite number of iterations all of the + // arguments' properties will have become unknown. + while (true) { bool repeat = false; // Transfer the ForOp's operand types to the block arguments // and results to fulfill constraints on the ForOp. 
- for(size_t i = 0; i < forOp.getNumIterOperands(); i++) { + for (size_t i = 0; i < forOp.getNumIterOperands(); i++) { Type t = forOp.getIterOpOperands()[i].get().getType(); block.getArgument(i + numIndVars).setType(t); forOp.getResult(i).setType(t); @@ -409,35 +385,38 @@ class InferencePass : public PassWrapper(walkOp); // Get the YieldOp. - Operation * yieldOp = block.getTerminator(); + Operation *yieldOp = block.getTerminator(); // Check if the inferred MLIR types match the result MLIR types. - // If any interesting properties were changed inside the loop body, - // we set them to unknown to make the type comparison pass. - for(size_t i = 0; i < forOp.getNumIterOperands(); i++) { + // If any interesting properties were changed inside the loop + // body, we set them to unknown to make the type comparison + // pass. + for (size_t i = 0; i < forOp.getNumIterOperands(); i++) { Type yieldedTy = yieldOp->getOperand(i).getType(); Type resultTy = op->getResult(i).getType(); - if(yieldedTy != resultTy) { - // Get a type with the conflicting properties set to unknown. + if (yieldedTy != resultTy) { + // Get a type with the conflicting properties set to + // unknown. Type typeWithCommonInfo = getTypeWithCommonInfo(yieldedTy, resultTy); - if(!typeWithCommonInfo) - throw ErrorHandler::compilerError( - op, "InferencePass.cpp:" + std::to_string(__LINE__), - "the data type (matrix, frame, scalar) of a " - "variable " - "must not be changed within the body of a " - "for-loop."); + if (!typeWithCommonInfo) + throw ErrorHandler::compilerError(op, "InferencePass.cpp:" + std::to_string(__LINE__), + "the data type (matrix, frame, scalar) of a " + "variable " + "must not be changed within the body of a " + "for-loop."); // Use casts to remove those properties accordingly. castOperandIf(builder, yieldOp, i, typeWithCommonInfo); castOperandIf(builder, forOp, forOp.getNumControlOperands() + i, typeWithCommonInfo); - // Since the WhileOp's argument types/properties have changed, - // we must repeat the inference for the loop body. + // Since the ForOp's argument types/properties have + // changed, we must repeat the inference for the loop + // body. repeat = true; } } - if(repeat) - // Before we can repeat the inference, we reset all information - // inferred so far to unknown (in the loop body). + if (repeat) + // Before we can repeat the inference, we reset all + // information inferred so far to unknown (in the loop + // body). block.walk(walkSetUnknown); else // If all types matched, we are done. @@ -446,12 +425,11 @@ class InferencePass : public PassWrapper(op)) { + } else if (auto ifOp = llvm::dyn_cast(op)) { // Walk the then/else blocks first. We need the inference on // them before we can do anything about the IfOp itself.
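// Illustrative example (shapes invented for exposition): if the then-branch
// yields a 2x2 f64 matrix and the else-branch a 3x2 f64 matrix,
// getTypeWithCommonInfo produces an f64 matrix with unknown #rows and
// #cols = 2; casts are inserted on both yields, and the IfOp result takes
// this common type.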
ifOp.thenBlock()->walk(walkOp); - if(ifOp.elseBlock()) { + if (ifOp.elseBlock()) { ifOp.elseBlock()->walk(walkOp); // For all pairs of corresponding values yielded in the @@ -461,17 +439,14 @@ class InferencePass : public PassWrappergetOperand(i).getType(), - elseYield->getOperand(i).getType() - ); - if(!typeWithCommonInfo) - throw ErrorHandler::compilerError( - op, "InferencePass" + std::to_string(__LINE__), - "a variable must not be assigned values of " - "different data types (matrix, frame, scalar) " - "in then/else branches"); + for (size_t i = 0; i < ifOp.getNumResults(); i++) { + Type typeWithCommonInfo = + getTypeWithCommonInfo(thenYield->getOperand(i).getType(), elseYield->getOperand(i).getType()); + if (!typeWithCommonInfo) + throw ErrorHandler::compilerError(op, "InferencePass" + std::to_string(__LINE__), + "a variable must not be assigned values of " + "different data types (matrix, frame, scalar) " + "in then/else branches"); castOperandIf(builder, thenYield, i, typeWithCommonInfo); castOperandIf(builder, elseYield, i, typeWithCommonInfo); ifOp.getResult(i).setType(typeWithCommonInfo); @@ -486,7 +461,7 @@ class InferencePass : public PassWrapper(walkOp); } catch (std::runtime_error &re) { - throw ErrorHandler::rethrowError( - "InferencePass.cpp:" + std::to_string(__LINE__), re.what()); + throw ErrorHandler::rethrowError("InferencePass.cpp:" + std::to_string(__LINE__), re.what()); } // infer function return types - f.setType(FunctionType::get(&getContext(), - f.getFunctionType().getInputs(), - f.getBody().back().getTerminator()->getOperandTypes())); + f.setType(FunctionType::get(&getContext(), f.getFunctionType().getInputs(), + f.getBody().back().getTerminator()->getOperandTypes())); } static bool returnsUnknownType(Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type resType) { - if(llvm::isa(resType)) + if (llvm::isa(resType)) return true; - if(auto mt = resType.dyn_cast()) + if (auto mt = resType.dyn_cast()) return llvm::isa(mt.getElementType()); - if(auto ft = resType.dyn_cast()) - for(Type ct : ft.getColumnTypes()) - if(llvm::isa(ct)) + if (auto ft = resType.dyn_cast()) + for (Type ct : ft.getColumnTypes()) + if (llvm::isa(ct)) return true; return false; }); } - static bool returnsFrameWithUnknownLabels(Operation * op) { + static bool returnsFrameWithUnknownLabels(Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type resultType) { auto ft = resultType.dyn_cast(); return ft && !ft.getLabels(); }); } - static bool returnsUnknownShape(Operation * op) { + static bool returnsUnknownShape(Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type rt) { - if(auto mt = rt.dyn_cast()) + if (auto mt = rt.dyn_cast()) return mt.getNumRows() == -1 || mt.getNumCols() == -1; - if(auto ft = rt.dyn_cast()) + if (auto ft = rt.dyn_cast()) return ft.getNumRows() == -1 || ft.getNumCols() == -1; return false; }); } - static bool returnsUnknownSparsity(Operation * op) { + static bool returnsUnknownSparsity(Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type rt) { - if(auto mt = rt.dyn_cast()) + if (auto mt = rt.dyn_cast()) return mt.getSparsity() == -1.0; return false; }); diff --git a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp index b05c996d9..f49fa05f1 100644 --- a/src/compiler/inference/SelectMatrixRepresentationsPass.cpp +++ b/src/compiler/inference/SelectMatrixRepresentationsPass.cpp @@ -22,26 +22,27 @@ #include #include -#include #include +#include using namespace 
mlir; -class SelectMatrixRepresentationsPass : public PassWrapper> { - const DaphneUserConfig& cfg; +class SelectMatrixRepresentationsPass + : public PassWrapper> { + const DaphneUserConfig &cfg; - std::function walkOp = [&](Operation * op) { - if(returnsKnownProperties(op)) { + std::function walkOp = [&](Operation *op) { + if (returnsKnownProperties(op)) { const bool isScfOp = op->getDialect() == op->getContext()->getOrLoadDialect(); // ---------------------------------------------------------------- // Handle all non-SCF operations // ---------------------------------------------------------------- - if(!isScfOp) { + if (!isScfOp) { // Set the matrix representation for all result types - for(auto res : op->getResults()) { - if(auto matTy = res.getType().dyn_cast()) { + for (auto res : op->getResults()) { + if (auto matTy = res.getType().dyn_cast()) { const double sparsity = matTy.getSparsity(); - if(sparsity < cfg.sparsity_threshold) { + if (sparsity < cfg.sparsity_threshold) { res.setType(matTy.withRepresentation(daphne::MatrixRepresentation::Sparse)); } } @@ -52,12 +53,12 @@ class SelectMatrixRepresentationsPass : public PassWrapper(op)) { + else if (auto whileOp = llvm::dyn_cast(op)) { Block &beforeBlock = whileOp.getBefore().front(); Block &afterBlock = whileOp.getAfter().front(); // Transfer the WhileOp's operand types to the block arguments // and results to fulfill constraints on the WhileOp. - for(size_t i = 0 ; i < whileOp.getNumOperands() ; i++) { + for (size_t i = 0; i < whileOp.getNumOperands(); i++) { Type t = whileOp->getOperand(i).getType(); beforeBlock.getArgument(i).setType(t); afterBlock.getArgument(i).setType(t); @@ -68,31 +69,30 @@ class SelectMatrixRepresentationsPass : public PassWrapper(walkOp); afterBlock.walk(walkOp); - // Check if the inferred matrix representations match the required result representations. - // This is not the case if, for instance, the representation of some - // variable written in the loop changes. The WhileOp would also - // check this later during verification, but here, we want to - // throw a readable error message. + // Check if the inferred matrix representations match the + // required result representations. This is not the case if, for + // instance, the representation of some variable written in the + // loop changes. The WhileOp would also check this later during + // verification, but here, we want to throw a readable error + // message. Operation *yieldOp = afterBlock.getTerminator(); - for(size_t i = 0 ; i < whileOp.getNumOperands() ; i++) { + for (size_t i = 0; i < whileOp.getNumOperands(); i++) { Type yieldedTy = yieldOp->getOperand(i).getType(); Type resultTy = op->getResult(i).getType(); if (yieldedTy != resultTy) - throw ErrorHandler::compilerError( - whileOp, "SelectMatrixRepresentationsPass", - "the representation of a matrix must not be " - "changed within the body of a while-loop."); + throw ErrorHandler::compilerError(whileOp, "SelectMatrixRepresentationsPass", + "the representation of a matrix must not be " + "changed within the body of a while-loop."); } // Tell the walker to skip the descendants of the WhileOp, we // have already triggered a walk on them explicitly. return WalkResult::skip(); - } - else if(auto forOp = llvm::dyn_cast(op)) { + } else if (auto forOp = llvm::dyn_cast(op)) { Block &block = forOp.getRegion().front(); const size_t numIndVars = forOp.getNumInductionVars(); // Transfer the ForOp's operand types to the block arguments // and results to fulfill constraints on the ForOp. 
- for(size_t i = 0 ; i < forOp.getNumIterOperands() ; i++) { + for (size_t i = 0; i < forOp.getNumIterOperands(); i++) { Type t = forOp.getIterOperands()[i].getType(); block.getArgument(i + numIndVars).setType(t); forOp.getResult(i).setType(t); @@ -100,45 +100,43 @@ class SelectMatrixRepresentationsPass : public PassWrapper(walkOp); - // Check if the infered matrix representations match the required result representations. - // This is not the case if, for instance, the representation of some - // variable written in the loop changes. The ForOp would also - // check this later during verification, but here, we want to - // throw a readable error message. + // Check if the inferred matrix representations match the + // required result representations. This is not the case if, for + // instance, the representation of some variable written in the + // loop changes. The ForOp would also check this later during + // verification, but here, we want to throw a readable error + // message. Operation *yieldOp = block.getTerminator(); - for(size_t i = 0 ; i < forOp.getNumIterOperands() ; i++) { + for (size_t i = 0; i < forOp.getNumIterOperands(); i++) { Type yieldedTy = yieldOp->getOperand(i).getType(); Type resultTy = op->getResult(i).getType(); if (yieldedTy != resultTy) - throw ErrorHandler::compilerError( - forOp, "SelectMatrixRepresentationsPass", - "the representation of a matrix must not be " - "changed within the body of a for-loop"); + throw ErrorHandler::compilerError(forOp, "SelectMatrixRepresentationsPass", + "the representation of a matrix must not be " + "changed within the body of a for-loop"); } // Tell the walker to skip the descendants of the ForOp, we // have already triggered a walk on them explicitly. return WalkResult::skip(); - } - else if(auto ifOp = llvm::dyn_cast(op)) { + } else if (auto ifOp = llvm::dyn_cast(op)) { // Walk the then/else blocks first. We need the inference on // them before we can do anything about the IfOp itself. ifOp.thenBlock()->walk(walkOp); ifOp.elseBlock()->walk(walkOp); - // Check if the yielded matrix representations are the same in both - // branches. The IfOp would also check this later during + // Check if the yielded matrix representations are the same in + // both branches. The IfOp would also check this later during // verification, but here, we want to throw a readable error // message. // Additionally, we set the result types of the IfOp here. scf::YieldOp thenYield = ifOp.thenYield(); scf::YieldOp elseYield = ifOp.elseYield(); - for(size_t i = 0 ; i < ifOp.getNumResults() ; i++) { + for (size_t i = 0; i < ifOp.getNumResults(); i++) { Type thenTy = thenYield->getOperand(i).getType(); Type elseTy = elseYield->getOperand(i).getType(); if (thenTy != elseTy) - throw ErrorHandler::compilerError( - ifOp, "SelectMatrixRepresentationsPass", - "a matrix must not be assigned two values of " - "different representations in then/else branches"); + throw ErrorHandler::compilerError(ifOp, "SelectMatrixRepresentationsPass", + "a matrix must not be assigned two values of " + "different representations in then/else branches"); ifOp.getResult(i).setType(thenTy); } // Tell the walker to skip the descendants of the IfOp, we @@ -150,17 +148,16 @@ class SelectMatrixRepresentationsPass : public PassWrapper(walkOp); // infer function return types // TODO: cast for UDFs?
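// Illustrative: the function type is simply rebuilt from the operand types
// of the body's final terminator, e.g. a body whose return value was just
// selected to be a sparse f64 matrix turns a more generic declared result
// type into that concrete matrix type (see f.setType(...) below).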
- f.setType(FunctionType::get(&getContext(), - f.getFunctionType().getInputs(), - f.getBody().back().getTerminator()->getOperandTypes())); + f.setType(FunctionType::get(&getContext(), f.getFunctionType().getInputs(), + f.getBody().back().getTerminator()->getOperandTypes())); } StringRef getArgument() const final { return "select-matrix-representations"; } @@ -168,13 +165,13 @@ class SelectMatrixRepresentationsPass : public PassWrappergetResultTypes(), [](Type rt) { - if(auto mt = rt.dyn_cast()) + if (auto mt = rt.dyn_cast()) return mt.getSparsity() != -1.0; return false; }); } }; -std::unique_ptr daphne::createSelectMatrixRepresentationsPass(const DaphneUserConfig& cfg) { +std::unique_ptr daphne::createSelectMatrixRepresentationsPass(const DaphneUserConfig &cfg) { return std::make_unique(cfg); } diff --git a/src/compiler/inference/TypeInferenceUtils.cpp b/src/compiler/inference/TypeInferenceUtils.cpp index e9a3a5365..aba2d3611 100644 --- a/src/compiler/inference/TypeInferenceUtils.cpp +++ b/src/compiler/inference/TypeInferenceUtils.cpp @@ -18,67 +18,73 @@ int generality(mlir::Type t) { using namespace mlir; - + // TODO It is debatable if unsigned int shall be more general than signed // int of the same bit width. - + // The greater the number, the more general the type. - if(llvm::isa(t)) return 11; - if(llvm::isa(t)) return 10; - if(t.isF64()) return 9; - if(t.isF32()) return 8; - if(t.isUnsignedInteger(64)) return 7; - if(t. isSignedInteger(64)) return 6; - if(t.isIndex()) return 5; - if(t.isUnsignedInteger(32)) return 4; - if(t. isSignedInteger(32)) return 3; - if(t.isUnsignedInteger(8)) return 2; - if(t. isSignedInteger(8)) return 1; - if(t. isInteger(1)) return 0; - + if (llvm::isa(t)) + return 11; + if (llvm::isa(t)) + return 10; + if (t.isF64()) + return 9; + if (t.isF32()) + return 8; + if (t.isUnsignedInteger(64)) + return 7; + if (t.isSignedInteger(64)) + return 6; + if (t.isIndex()) + return 5; + if (t.isUnsignedInteger(32)) + return 4; + if (t.isSignedInteger(32)) + return 3; + if (t.isUnsignedInteger(8)) + return 2; + if (t.isSignedInteger(8)) + return 1; + if (t.isInteger(1)) + return 0; + std::string str; llvm::raw_string_ostream msg(str); msg << "no generality code available for value type: " << t; throw std::runtime_error(msg.str()); } -mlir::Type mostGeneralVt(const std::vector & vt) { - if(vt.empty()) - throw std::runtime_error( - "mostGeneralVt() invoked with empty list of value types" - ); - +mlir::Type mostGeneralVt(const std::vector &vt) { + if (vt.empty()) + throw std::runtime_error("mostGeneralVt() invoked with empty list of value types"); + mlir::Type res = vt[0]; - for(size_t i = 1; i < vt.size(); i++) - if(generality(vt[i]) > generality(res)) + for (size_t i = 1; i < vt.size(); i++) + if (generality(vt[i]) > generality(res)) res = vt[i]; - + return res; } -mlir::Type mostGeneralVt(const std::vector> & vts, size_t num) { - if(vts.empty()) - throw std::runtime_error( - "mostGeneralVt() invoked with empty list of lists of value types" - ); - - if(num == 0) +mlir::Type mostGeneralVt(const std::vector> &vts, size_t num) { + if (vts.empty()) + throw std::runtime_error("mostGeneralVt() invoked with empty list of lists of value types"); + + if (num == 0) num = vts.size(); - + mlir::Type res = mostGeneralVt(vts[0]); - for(size_t i = 1; i < std::min(vts.size(), num); i++) { + for (size_t i = 1; i < std::min(vts.size(), num); i++) { mlir::Type cur = mostGeneralVt(vts[i]); - if(generality(cur) > generality(res)) + if (generality(cur) > generality(res)) res = cur; } - + 
return res; } -std::vector inferValueTypeFromArgs( - const std::vector & argDtc, - std::vector> & argVts -) { +std::vector inferValueTypeFromArgs(const std::vector &argDtc, + std::vector> &argVts) { // TODO Simplify: resDtc is already known. If it's not Frame, this // can be done simpler and we don't need the getMostGeneralVt later. @@ -86,31 +92,29 @@ std::vector inferValueTypeFromArgs( // arguments to match the number of column types of frame arguments. size_t commonNumFrameCols = 1; bool hasFrame = false; - for(size_t i = 0; i < argVts.size(); i++) - if(argDtc[i] == DataTypeCode::FRAME) { - if(hasFrame && argVts[i].size() != commonNumFrameCols) - throw std::runtime_error( - "type inference trait ValueTypeFromArgs requires that " - "all input frames have the same number of columns" - ); + for (size_t i = 0; i < argVts.size(); i++) + if (argDtc[i] == DataTypeCode::FRAME) { + if (hasFrame && argVts[i].size() != commonNumFrameCols) + throw std::runtime_error("type inference trait ValueTypeFromArgs requires that " + "all input frames have the same number of columns"); hasFrame = true; commonNumFrameCols = argVts[i].size(); } // If required: Expand the value type of matrix and scalar arguments to // match the common number of column types of frame arguments. - if(hasFrame) - for(size_t i = 0; i < argVts.size(); i++) - if(argDtc[i] != DataTypeCode::FRAME) + if (hasFrame) + for (size_t i = 0; i < argVts.size(); i++) + if (argDtc[i] != DataTypeCode::FRAME) argVts[i] = std::vector(commonNumFrameCols, argVts[i][0]); // Determine the most general argument value type. This is done for each // column separately, if frames are involved. std::vector resVts = argVts[0]; - for(size_t i = 1; i < argVts.size(); i++) - for(size_t k = 0; k < commonNumFrameCols; k++) - if(generality(argVts[i][k]) > generality(resVts[k])) + for (size_t i = 1; i < argVts.size(); i++) + for (size_t k = 0; k < commonNumFrameCols; k++) + if (generality(argVts[i][k]) > generality(resVts[k])) resVts[k] = argVts[i][k]; - + return resVts; } diff --git a/src/compiler/inference/TypeInferenceUtils.h b/src/compiler/inference/TypeInferenceUtils.h index e03f64373..5de58c8e2 100644 --- a/src/compiler/inference/TypeInferenceUtils.h +++ b/src/compiler/inference/TypeInferenceUtils.h @@ -18,20 +18,20 @@ #include -#include #include +#include #include #include /** * @brief Returns an integer code representing how general a value type is. - * + * * This code can be used to determine which of two value types is more general. * The larger the code, the more general the value type. - * + * * @param t - * @return + * @return */ int generality(mlir::Type t); @@ -51,54 +51,48 @@ enum class DataTypeCode : uint8_t { /** * @brief Returns the most general value type in a list of value types. - * + * * @param vt A list of value types. - * @return + * @return */ -mlir::Type mostGeneralVt(const std::vector & vt); +mlir::Type mostGeneralVt(const std::vector &vt); /** * @brief Returns the most general value type in a list of lists of value types. - * + * * @param vts A list of lists of value types. * @param num Optionally, only consider the first `num` lists of value types. - * @return + * @return */ -mlir::Type mostGeneralVt( - const std::vector> & vts, - size_t num = 0 -); +mlir::Type mostGeneralVt(const std::vector> &vts, size_t num = 0); /** * @brief Infers the value type assuming the type inference trait * `ValueTypeFromArgs`. - * + * * @param argDtc Information on the argument data types. * @param argVts Information on the argument value types. 
* @return The inferred value type. */ -std::vector inferValueTypeFromArgs( - const std::vector & argDtc, - std::vector> & argVts -); +std::vector inferValueTypeFromArgs(const std::vector &argDtc, + std::vector> &argVts); /** * @brief Infers the type of the result of the given operation based on its * type inference traits. - * + * * @tparam O The type of the operation. For the inference in the compiler we * use `mlir::Operation`, but for the unit tests we use a mock class. * @param op * @return The inferred type of the single result of the operation. */ -template -mlir::Type inferTypeByTraits(O * op) { +template mlir::Type inferTypeByTraits(O *op) { using namespace mlir; using namespace mlir::OpTrait; - - MLIRContext * ctx = op->getContext(); + + MLIRContext *ctx = op->getContext(); Type u = daphne::UnknownType::get(ctx); - + Type resTy = u; // -------------------------------------------------------------------- @@ -107,20 +101,17 @@ mlir::Type inferTypeByTraits(O * op) { std::vector argDtc; std::vector> argVts; - for(Type t : op->getOperandTypes()) { - if(llvm::isa(t)) { + for (Type t : op->getOperandTypes()) { + if (llvm::isa(t)) { argDtc.push_back(DataTypeCode::UNKNOWN); argVts.push_back({u}); - } - else if(auto ft = t.dyn_cast()) { + } else if (auto ft = t.dyn_cast()) { argDtc.push_back(DataTypeCode::FRAME); argVts.push_back(ft.getColumnTypes()); - } - else if(auto mt = t.dyn_cast()) { + } else if (auto mt = t.dyn_cast()) { argDtc.push_back(DataTypeCode::MATRIX); argVts.push_back({mt.getElementType()}); - } - else { // TODO Check if this is really a supported scalar type! + } else { // TODO Check if this is really a supported scalar type! argDtc.push_back(DataTypeCode::SCALAR); argVts.push_back({t}); } @@ -132,19 +123,18 @@ mlir::Type inferTypeByTraits(O * op) { DataTypeCode resDtc = DataTypeCode::UNKNOWN; - if(op->template hasTrait() || op->template hasTrait()) + if (op->template hasTrait() || op->template hasTrait()) resDtc = argDtc[0]; - else if(op->template hasTrait()) { + else if (op->template hasTrait()) { resDtc = argDtc[0]; - for(size_t i = 1; i < argDtc.size(); i++) - if(argDtc[i] > resDtc) // generality comparison + for (size_t i = 1; i < argDtc.size(); i++) + if (argDtc[i] > resDtc) // generality comparison resDtc = argDtc[i]; - } - else if(op->template hasTrait()) + } else if (op->template hasTrait()) resDtc = DataTypeCode::SCALAR; - else if(op->template hasTrait()) + else if (op->template hasTrait()) resDtc = DataTypeCode::MATRIX; - else if(op->template hasTrait()) + else if (op->template hasTrait()) resDtc = DataTypeCode::FRAME; // -------------------------------------------------------------------- @@ -154,158 +144,140 @@ mlir::Type inferTypeByTraits(O * op) { // TODO What about the #cols, if the data type is a frame (see #421)? std::vector resVts = {u}; - if(op->template hasTrait()) { + if (op->template hasTrait()) { // Initially take the most general value type of the arguments, // resVts has one element for scalars and matrices, or one element // per column for frames. resVts = inferValueTypeFromArgs(argDtc, argVts); // Replace string by si64. Otherwise, we would represent the result // of the comparison of two strings as a string.
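// Illustrative example (scenario invented for exposition): for an
// elementwise equality comparison of two string matrices,
// inferValueTypeFromArgs would yield the string type, so the loop below
// replaces it by si64 and the result is inferred as a 0/1 si64 matrix
// rather than a string matrix.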
- for(size_t i = 0; i < resVts.size(); i++) - if(llvm::isa(resVts[i])) + for (size_t i = 0; i < resVts.size(); i++) + if (llvm::isa(resVts[i])) resVts[i] = IntegerType::get(ctx, 64, IntegerType::SignednessSemantics::Signed); - } - else if(op->template hasTrait()) + } else if (op->template hasTrait()) resVts = argVts[0]; - else if(op->template hasTrait()) { - if(resDtc == DataTypeCode::FRAME && argDtc[0] == DataTypeCode::MATRIX) { + else if (op->template hasTrait()) { + if (resDtc == DataTypeCode::FRAME && argDtc[0] == DataTypeCode::MATRIX) { // We need to make sure that the value type of the input matrix is // repeated in the column value types of the output frame to match // the number of columns of the input matrix. - const ssize_t numCols = op->getOperand(0) - .getType() - .template dyn_cast() - .getNumCols(); - if(numCols == -1) + const ssize_t numCols = op->getOperand(0).getType().template dyn_cast().getNumCols(); + if (numCols == -1) // The input's number of columns is unknown. resVts = {u}; // TODO How to properly represent such cases (see #421)? else // The input's number of columns is known. resVts = std::vector(numCols, argVts[0][0]); - } - else + } else // Even if the first arg is a frame, its column types get collapsed // to the most general type later on. resVts = argVts[0]; } - // TODO Reduce the code duplication. Merge the traits ValueTypeFromFirstArg and - // ValueTypeFromThirdArg into one parametric trait ValueTypeFromArg, see #487. - else if(op->template hasTrait()) { - if(resDtc == DataTypeCode::FRAME && argDtc[2] == DataTypeCode::MATRIX) { + // TODO Reduce the code duplication. Merge the traits ValueTypeFromFirstArg + // and ValueTypeFromThirdArg into one parametric trait ValueTypeFromArg, + // see #487. + else if (op->template hasTrait()) { + if (resDtc == DataTypeCode::FRAME && argDtc[2] == DataTypeCode::MATRIX) { // We need to make sure that the value type of the input matrix is // repeated in the column value types of the output frame to match // the number of columns of the input matrix. - const ssize_t numCols = op->getOperand(2) - .getType() - .template dyn_cast() - .getNumCols(); - if(numCols == -1) + const ssize_t numCols = op->getOperand(2).getType().template dyn_cast().getNumCols(); + if (numCols == -1) // The input's number of columns is unknown. resVts = {u}; // TODO How to properly represent such cases (see #421)? else // The input's number of columns is known. resVts = std::vector(numCols, argVts[2][0]); - } - else + } else // Even if the third arg is a frame, its column types get collapsed // to the most general type later on. resVts = argVts[2]; - } - else if(op->template hasTrait()) + } else if (op->template hasTrait()) resVts = inferValueTypeFromArgs(argDtc, argVts); - else if(op->template hasTrait()) { + else if (op->template hasTrait()) { // Get the most general value types... resVts = inferValueTypeFromArgs(argDtc, argVts); // ...and replace them by the most general floating-point type where // necessary. - for(size_t i = 0; i < resVts.size(); i++) - if(!llvm::isa(resVts[i]) && !llvm::isa(resVts[i])) + for (size_t i = 0; i < resVts.size(); i++) + if (!llvm::isa(resVts[i]) && !llvm::isa(resVts[i])) resVts[i] = FloatType::getF64(ctx); - } - else if(op->template hasTrait()) { + } else if (op->template hasTrait()) { // Get the most general value types... resVts = inferValueTypeFromArgs(argDtc, argVts); // ...and replace them by the most general integer type where // necessary. 
- for(size_t i = 0; i < resVts.size(); i++) - if(!llvm::isa(resVts[i]) && !llvm::isa(resVts[i])) - resVts[i] = IntegerType::get( - ctx, 64, IntegerType::SignednessSemantics::Unsigned - ); - } - else if(op->template hasTrait()) { + for (size_t i = 0; i < resVts.size(); i++) + if (!llvm::isa(resVts[i]) && !llvm::isa(resVts[i])) + resVts[i] = IntegerType::get(ctx, 64, IntegerType::SignednessSemantics::Unsigned); + } else if (op->template hasTrait()) { const size_t numArgsConsider = 2; - if(argVts.size() < numArgsConsider) - throw std::runtime_error( - "type inference trait ValueTypesConcat requires at least " - "two arguments" - ); + if (argVts.size() < numArgsConsider) + throw std::runtime_error("type inference trait ValueTypesConcat requires at least " + "two arguments"); - switch(resDtc) { - case DataTypeCode::FRAME: - resVts = {}; - for(size_t i = 0; i < numArgsConsider; i++) { - bool abort = false; - switch(argDtc[i]) { - case DataTypeCode::FRAME: - // Append this input frame's column types to the - // result's column types. - for(size_t k = 0; k < argVts[i].size(); k++) - resVts.push_back(argVts[i][k]); - break; - case DataTypeCode::MATRIX: { - const ssize_t numCols = op->getOperand(i) - .getType() - .template dyn_cast() - .getNumCols(); - if(numCols == -1) { - // The number of columns of this input matrix - // is unknown, so it is unclear how often its - // value type needs to be appended to the - // result column types. - resVts = {u}; // TODO How to best represent this case (see #421)? - abort = true; - } - else - // The number of columns of this input matrix - // is known, so we append its value type to the - // result column types that number of times. - for(ssize_t k = 0; k < numCols; k++) - resVts.push_back(argVts[i][0]); - break; - } - case DataTypeCode::SCALAR: - // Append the value type of this input scalar to - // the result column types. + switch (resDtc) { + case DataTypeCode::FRAME: + resVts = {}; + for (size_t i = 0; i < numArgsConsider; i++) { + bool abort = false; + switch (argDtc[i]) { + case DataTypeCode::FRAME: + // Append this input frame's column types to the + // result's column types. + for (size_t k = 0; k < argVts[i].size(); k++) + resVts.push_back(argVts[i][k]); + break; + case DataTypeCode::MATRIX: { + const ssize_t numCols = + op->getOperand(i).getType().template dyn_cast().getNumCols(); + if (numCols == -1) { + // The number of columns of this input matrix + // is unknown, so it is unclear how often its + // value type needs to be appended to the + // result column types. + resVts = {u}; // TODO How to best represent this case + // (see #421)? + abort = true; + } else + // The number of columns of this input matrix + // is known, so we append its value type to the + // result column types that number of times. + for (ssize_t k = 0; k < numCols; k++) resVts.push_back(argVts[i][0]); - break; - case DataTypeCode::UNKNOWN: - // It is unclear how this input contributes to - // the result's column types. - resVts = {u}; // TODO How to best represent this case (see #421)? - abort = true; - break; - } - if(abort) - break; + break; + } + case DataTypeCode::SCALAR: + // Append the value type of this input scalar to + // the result column types. + resVts.push_back(argVts[i][0]); + break; + case DataTypeCode::UNKNOWN: + // It is unclear how this input contributes to + // the result's column types. + resVts = {u}; // TODO How to best represent this case (see #421)? 
+ abort = true; + break; } - break; - case DataTypeCode::MATRIX: // fall-through intended - case DataTypeCode::SCALAR: - resVts = {mostGeneralVt(argVts, numArgsConsider)}; - break; - case DataTypeCode::UNKNOWN: - // nothing to do - break; + if (abort) + break; + } + break; + case DataTypeCode::MATRIX: // fall-through intended + case DataTypeCode::SCALAR: + resVts = {mostGeneralVt(argVts, numArgsConsider)}; + break; + case DataTypeCode::UNKNOWN: + // nothing to do + break; } - } - else if(op->template hasTrait()) + } else if (op->template hasTrait()) resVts = {IntegerType::get(ctx, 64, IntegerType::SignednessSemantics::Signed)}; - else if(op->template hasTrait()) + else if (op->template hasTrait()) resVts = {IndexType::get(ctx)}; - else if(op->template hasTrait()) + else if (op->template hasTrait()) resVts = {daphne::StringType::get(ctx)}; // -------------------------------------------------------------------- @@ -314,20 +286,20 @@ mlir::Type inferTypeByTraits(O * op) { // It is important to recreate matrix and frame types (not reuse those from // the inputs) to get rid of any additional properties (shape, etc.). - switch(resDtc) { - case DataTypeCode::UNKNOWN: - resTy = u; - break; - case DataTypeCode::SCALAR: - resTy = mostGeneralVt(resVts); - break; - case DataTypeCode::MATRIX: - resTy = daphne::MatrixType::get(ctx, mostGeneralVt(resVts)); - break; - case DataTypeCode::FRAME: { - resTy = daphne::FrameType::get(ctx, resVts); - break; - } + switch (resDtc) { + case DataTypeCode::UNKNOWN: + resTy = u; + break; + case DataTypeCode::SCALAR: + resTy = mostGeneralVt(resVts); + break; + case DataTypeCode::MATRIX: + resTy = daphne::MatrixType::get(ctx, mostGeneralVt(resVts)); + break; + case DataTypeCode::FRAME: { + resTy = daphne::FrameType::get(ctx, resVts); + break; + } } // Note that all our type inference traits assume that the operation has diff --git a/src/compiler/lowering/AggAllOpLowering.cpp b/src/compiler/lowering/AggAllOpLowering.cpp index e6398803f..031f814d0 100644 --- a/src/compiler/lowering/AggAllOpLowering.cpp +++ b/src/compiler/lowering/AggAllOpLowering.cpp @@ -58,128 +58,120 @@ using namespace mlir; class SumAllOpLowering : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - SumAllOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) - : mlir::OpConversionPattern(typeConverter, ctx) { - this->setDebugName("SumAllOpLowering"); - } - // Float and Integer value type matrices have to be handled separately, since - // arith operations are different. 
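// Sketch of the loop nest this pattern emits (illustrative pseudo-IR; the
// float variant is shown, the integer variant accumulates with arith.addi on
// a signless type and casts at the boundaries):
//
//   %res = affine.for %i = 0 to %nR iter_args(%outer = %zero) {
//     %row = affine.for %j = 0 to %nC iter_args(%inner = %zero) {
//       %v = affine.load %memRef[%i, %j]
//       %s = arith.addf %inner, %v
//       affine.yield %s
//     }
//     %o = arith.addf %outer, %row
//     affine.yield %o
//   }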
- LogicalResult - matchAndRewrite(daphne::AllAggSumOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - mlir::daphne::MatrixType matrixType = - adaptor.getArg().getType().dyn_cast(); - auto loc = op->getLoc(); - auto nR = matrixType.getNumRows(); - auto nC = matrixType.getNumCols(); - - auto matrixElementType = matrixType.getElementType(); - auto memRefType = mlir::MemRefType::get({nR, nC}, matrixElementType); - auto memRef = rewriter.create( - op->getLoc(), memRefType, adaptor.getArg()); - - if (matrixElementType.isIntOrIndex()) { - IntegerType signless_type = - rewriter.getIntegerType(matrixElementType.getIntOrFloatBitWidth()); - Value sum = rewriter.create( - loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); - - SmallVector loopIvs; - SmallVector forOps; - auto outerLoop = - rewriter.create(loc, 0, nR, 1, ValueRange{sum}); - for (Operation &nested : *outerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(outerLoop.getInductionVar()); - // outer loop body - rewriter.setInsertionPointToStart(outerLoop.getBody()); - Value sum_iter = rewriter.create( - loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); - // inner loop - auto innerLoop = - rewriter.create(loc, 0, nC, 1, ValueRange{sum_iter}); - for (Operation &nested : *innerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(innerLoop.getInductionVar()); - // inner loop body - rewriter.setInsertionPointToStart(innerLoop.getBody()); - // load value from memref - Value elementLoad = rewriter.create(loc, memRef, loopIvs); - auto castedElement = this->typeConverter->materializeSourceConversion( - rewriter, loc, signless_type, ValueRange{elementLoad}); - // sum loop iter arg and memref value - mlir::Value inner_sum = rewriter.create( - loc, innerLoop.getRegionIterArgs()[0], castedElement); - // yield inner loop result - rewriter.setInsertionPointToEnd(innerLoop.getBody()); - rewriter.create(loc, inner_sum); - // yield outer loop result - rewriter.setInsertionPointToEnd(outerLoop.getBody()); - mlir::Value outer_sum = rewriter.create( - loc, outerLoop.getRegionIterArgs()[0], innerLoop.getResult(0)); - rewriter.create(loc, outer_sum); - - rewriter.setInsertionPointAfter(outerLoop); - rewriter.create(loc, adaptor.getArg()); - // replace sumAll op with result of loops - auto castedRes = this->typeConverter->materializeTargetConversion( - rewriter, loc, matrixElementType, - ValueRange{outerLoop->getResult(0)}); - rewriter.replaceOp(op, ValueRange{castedRes}); - - return success(); - } else { - Value sum = rewriter.create( - loc, matrixElementType, rewriter.getFloatAttr(matrixElementType, 0)); - - SmallVector loopIvs; - SmallVector forOps; - auto outerLoop = - rewriter.create(loc, 0, nR, 1, ValueRange{sum}); - for (Operation &nested : *outerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(outerLoop.getInductionVar()); - // outer loop body - rewriter.setInsertionPointToStart(outerLoop.getBody()); - Value sum_iter = rewriter.create( - loc, matrixElementType, rewriter.getFloatAttr(matrixElementType, 0)); - // inner loop - auto innerLoop = - rewriter.create(loc, 0, nC, 1, ValueRange{sum_iter}); - for (Operation &nested : *innerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(innerLoop.getInductionVar()); - // inner loop body - rewriter.setInsertionPointToStart(innerLoop.getBody()); - // load value from memref - auto elementLoad = rewriter.create(loc, memRef, loopIvs); - // sum loop iter arg and memref value - 
mlir::Value inner_sum = rewriter.create( - loc, innerLoop.getRegionIterArgs()[0], elementLoad); - // yield inner loop result - rewriter.setInsertionPointToEnd(innerLoop.getBody()); - rewriter.create(loc, inner_sum); - // yield outer loop result - rewriter.setInsertionPointToEnd(outerLoop.getBody()); - mlir::Value outer_sum = rewriter.create( - loc, outerLoop.getRegionIterArgs()[0], innerLoop.getResult(0)); - rewriter.create(loc, outer_sum); - - rewriter.setInsertionPointAfter(outerLoop); - rewriter.create(loc, adaptor.getArg()); - // replace sumAll op with result of loops - rewriter.replaceOp(op, outerLoop.getResult(0)); - - return success(); + public: + using OpConversionPattern::OpConversionPattern; + + SumAllOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx) { + this->setDebugName("SumAllOpLowering"); + } + // Float and Integer value type matrices have to be handled separately, + // since arith operations are different. + LogicalResult matchAndRewrite(daphne::AllAggSumOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::daphne::MatrixType matrixType = adaptor.getArg().getType().dyn_cast(); + auto loc = op->getLoc(); + auto nR = matrixType.getNumRows(); + auto nC = matrixType.getNumCols(); + + auto matrixElementType = matrixType.getElementType(); + auto memRefType = mlir::MemRefType::get({nR, nC}, matrixElementType); + auto memRef = + rewriter.create(op->getLoc(), memRefType, adaptor.getArg()); + + if (matrixElementType.isIntOrIndex()) { + IntegerType signless_type = rewriter.getIntegerType(matrixElementType.getIntOrFloatBitWidth()); + Value sum = + rewriter.create(loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); + + SmallVector loopIvs; + SmallVector forOps; + auto outerLoop = rewriter.create(loc, 0, nR, 1, ValueRange{sum}); + for (Operation &nested : *outerLoop.getBody()) { + rewriter.eraseOp(&nested); + } + loopIvs.push_back(outerLoop.getInductionVar()); + // outer loop body + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value sum_iter = + rewriter.create(loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); + // inner loop + auto innerLoop = rewriter.create(loc, 0, nC, 1, ValueRange{sum_iter}); + for (Operation &nested : *innerLoop.getBody()) { + rewriter.eraseOp(&nested); + } + loopIvs.push_back(innerLoop.getInductionVar()); + // inner loop body + rewriter.setInsertionPointToStart(innerLoop.getBody()); + // load value from memref + Value elementLoad = rewriter.create(loc, memRef, loopIvs); + auto castedElement = + this->typeConverter->materializeSourceConversion(rewriter, loc, signless_type, ValueRange{elementLoad}); + // sum loop iter arg and memref value + mlir::Value inner_sum = + rewriter.create(loc, innerLoop.getRegionIterArgs()[0], castedElement); + // yield inner loop result + rewriter.setInsertionPointToEnd(innerLoop.getBody()); + rewriter.create(loc, inner_sum); + // yield outer loop result + rewriter.setInsertionPointToEnd(outerLoop.getBody()); + mlir::Value outer_sum = + rewriter.create(loc, outerLoop.getRegionIterArgs()[0], innerLoop.getResult(0)); + rewriter.create(loc, outer_sum); + + rewriter.setInsertionPointAfter(outerLoop); + rewriter.create(loc, adaptor.getArg()); + // replace sumAll op with result of loops + auto castedRes = this->typeConverter->materializeTargetConversion(rewriter, loc, matrixElementType, + ValueRange{outerLoop->getResult(0)}); + rewriter.replaceOp(op, ValueRange{castedRes}); + + return success(); + } else { + 
Value sum = rewriter.create(loc, matrixElementType, + rewriter.getFloatAttr(matrixElementType, 0)); + + SmallVector loopIvs; + SmallVector forOps; + auto outerLoop = rewriter.create(loc, 0, nR, 1, ValueRange{sum}); + for (Operation &nested : *outerLoop.getBody()) { + rewriter.eraseOp(&nested); + } + loopIvs.push_back(outerLoop.getInductionVar()); + // outer loop body + rewriter.setInsertionPointToStart(outerLoop.getBody()); + Value sum_iter = rewriter.create(loc, matrixElementType, + rewriter.getFloatAttr(matrixElementType, 0)); + // inner loop + auto innerLoop = rewriter.create(loc, 0, nC, 1, ValueRange{sum_iter}); + for (Operation &nested : *innerLoop.getBody()) { + rewriter.eraseOp(&nested); + } + loopIvs.push_back(innerLoop.getInductionVar()); + // inner loop body + rewriter.setInsertionPointToStart(innerLoop.getBody()); + // load value from memref + auto elementLoad = rewriter.create(loc, memRef, loopIvs); + // sum loop iter arg and memref value + mlir::Value inner_sum = + rewriter.create(loc, innerLoop.getRegionIterArgs()[0], elementLoad); + // yield inner loop result + rewriter.setInsertionPointToEnd(innerLoop.getBody()); + rewriter.create(loc, inner_sum); + // yield outer loop result + rewriter.setInsertionPointToEnd(outerLoop.getBody()); + mlir::Value outer_sum = + rewriter.create(loc, outerLoop.getRegionIterArgs()[0], innerLoop.getResult(0)); + rewriter.create(loc, outer_sum); + + rewriter.setInsertionPointAfter(outerLoop); + rewriter.create(loc, adaptor.getArg()); + // replace sumAll op with result of loops + rewriter.replaceOp(op, outerLoop.getResult(0)); + + return success(); + } } - } }; namespace { @@ -191,61 +183,58 @@ namespace { * This rewrite may enable loop fusion of the produced affine loops by * running the loop fusion pass. */ -struct AggAllLoweringPass - : public mlir::PassWrapper> { - explicit AggAllLoweringPass() {} - - StringRef getArgument() const final { return "lower-agg"; } - StringRef getDescription() const final { - return "Lowers AggAll operators to a set of affine loops and performs " - "the aggregation on a MemRef which is created from the input " - "DenseMatrix."; - } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() final; +struct AggAllLoweringPass : public mlir::PassWrapper> { + explicit AggAllLoweringPass() {} + + StringRef getArgument() const final { return "lower-agg"; } + StringRef getDescription() const final { + return "Lowers AggAll operators to a set of affine loops and performs " + "the aggregation on a MemRef which is created from the input " + "DenseMatrix."; + } + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; }; } // end anonymous namespace void AggAllLoweringPass::runOnOperation() { - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - LowerToLLVMOptions llvmOptions(&getContext()); - LLVMTypeConverter typeConverter(&getContext(), llvmOptions); - - typeConverter.addConversion(convertInteger); - typeConverter.addConversion(convertFloat); - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addArgumentMaterialization(materializeCastFromIllegal); - typeConverter.addSourceMaterialization(materializeCastToIllegal); - typeConverter.addTargetMaterialization(materializeCastFromIllegal); - - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - 
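// ---------------------------------------------------------------------------
// For illustration (not part of the patch): the loop structure that
// SumAllOpLowering builds above, as a standalone sketch of the float case.
// It assumes MLIR's Affine and Arith dialects (with AffineForOp in the mlir
// namespace, as in the LLVM version this file targets); the helper name
// `buildMatrixSumF64` is hypothetical. Instead of erasing the auto-created
// loop terminators as the pattern does, it uses AffineForOp's body-builder
// callback; the running sum is threaded through the loops' iter_args.
static mlir::Value buildMatrixSumF64(mlir::OpBuilder &b, mlir::Location loc, mlir::Value memRef, int64_t nR,
                                     int64_t nC) {
    using namespace mlir;
    Value zero = b.create<arith::ConstantOp>(loc, b.getF64FloatAttr(0.0));
    auto outerLoop = b.create<AffineForOp>(
        loc, 0, nR, 1, ValueRange{zero}, [&](OpBuilder &ob, Location l, Value i, ValueRange outerIter) {
            auto innerLoop = ob.create<AffineForOp>(
                l, 0, nC, 1, ValueRange{outerIter[0]}, [&](OpBuilder &ib, Location l2, Value j, ValueRange innerIter) {
                    // Load one element and add it to the running sum.
                    Value elem = ib.create<AffineLoadOp>(l2, memRef, ValueRange{i, j});
                    Value acc = ib.create<arith::AddFOp>(l2, innerIter[0], elem);
                    ib.create<AffineYieldOp>(l2, acc);
                });
            // Forward the row sum to the outer loop's iter_arg.
            ob.create<AffineYieldOp>(l, innerLoop.getResult(0));
        });
    // Result of the outer loop = sum over all nR x nC elements.
    return outerLoop.getResult(0);
}
// ---------------------------------------------------------------------------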
target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - target.addIllegalOp(); - - patterns.insert(typeConverter, &getContext()); - auto module = getOperation(); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - signalPassFailure(); - } + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + target.addIllegalOp(); + + patterns.insert(typeConverter, &getContext()); + auto module = getOperation(); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); + } } std::unique_ptr mlir::daphne::createAggAllOpLoweringPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/src/compiler/lowering/CMakeLists.txt b/src/compiler/lowering/CMakeLists.txt index 02bbf26ca..af7f0eb88 100644 --- a/src/compiler/lowering/CMakeLists.txt +++ b/src/compiler/lowering/CMakeLists.txt @@ -26,7 +26,6 @@ add_mlir_dialect_library(MLIRDaphneTransforms RewriteToCallKernelOpPass.cpp SpecializeGenericFunctionsPass.cpp VectorizeComputationsPass.cpp - WhileLoopInvariantCodeMotionPass.cpp DaphneOptPass.cpp EwOpsLowering.cpp ModOpLowering.cpp diff --git a/src/compiler/lowering/DaphneOptPass.cpp b/src/compiler/lowering/DaphneOptPass.cpp index 8795962e2..4b4ea2493 100644 --- a/src/compiler/lowering/DaphneOptPass.cpp +++ b/src/compiler/lowering/DaphneOptPass.cpp @@ -2,7 +2,6 @@ #include "compiler/utils/LoweringUtils.h" #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" -#include "llvm/Support/Debug.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -14,33 +13,31 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "dm-opt" using namespace mlir; class IntegerModOpt : public mlir::OpConversionPattern { - public: + public: using OpConversionPattern::OpConversionPattern; [[nodiscard]] static bool optimization_viable(mlir::daphne::EwModOp op) { - if (!op.getRhs().getType().isUnsignedInteger()) return false; + if (!op.getRhs().getType().isUnsignedInteger()) + return false; - std::pair isConstant = - CompilerUtils::isConstant(op.getRhs()); - // Apply (lhs % rhs) to (lhs & (rhs - 1)) optimization when rhs is a power of two + std::pair isConstant = CompilerUtils::isConstant(op.getRhs()); + // Apply (lhs % rhs) to (lhs & (rhs - 1)) optimization when rhs is a + // power of two return isConstant.first && (isConstant.second & (isConstant.second - 1)) == 0; } - mlir::LogicalResult 
matchAndRewrite( - mlir::daphne::EwModOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Value cst_one = rewriter.create( - op.getLoc(), static_cast(1)); - mlir::Value sub = rewriter.create( - op.getLoc(), adaptor.getRhs(), cst_one); - mlir::Value andOp = rewriter.create( - op.getLoc(), adaptor.getLhs(), sub); + mlir::LogicalResult matchAndRewrite(mlir::daphne::EwModOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::Value cst_one = rewriter.create(op.getLoc(), static_cast(1)); + mlir::Value sub = rewriter.create(op.getLoc(), adaptor.getRhs(), cst_one); + mlir::Value andOp = rewriter.create(op.getLoc(), adaptor.getLhs(), sub); rewriter.replaceOp(op, andOp); return success(); } @@ -52,14 +49,11 @@ namespace { * the DaphneDialect to a different set of operations also from the * DaphneDialect. */ -struct DenseMatrixOptPass - : public mlir::PassWrapper> { +struct DenseMatrixOptPass : public mlir::PassWrapper> { explicit DenseMatrixOptPass() {} void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } void runOnOperation() final; @@ -70,7 +64,7 @@ struct DenseMatrixOptPass "also from the DaphneDialect."; } }; -} // end anonymous namespace +} // end anonymous namespace void DenseMatrixOptPass::runOnOperation() { mlir::ConversionTarget target(getContext()); @@ -85,9 +79,7 @@ void DenseMatrixOptPass::runOnOperation() { target.addLegalDialect(); target.addDynamicallyLegalOp( - [&](mlir::daphne::EwModOp op) { - return !IntegerModOpt::optimization_viable(op); - }); + [&](mlir::daphne::EwModOp op) { return !IntegerModOpt::optimization_viable(op); }); patterns.insert(typeConverter, &getContext()); @@ -97,6 +89,4 @@ void DenseMatrixOptPass::runOnOperation() { } } -std::unique_ptr mlir::daphne::createDaphneOptPass() { - return std::make_unique(); -} +std::unique_ptr mlir::daphne::createDaphneOptPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/DistributeComputationsPass.cpp b/src/compiler/lowering/DistributeComputationsPass.cpp index 088688d31..8f9428587 100644 --- a/src/compiler/lowering/DistributeComputationsPass.cpp +++ b/src/compiler/lowering/DistributeComputationsPass.cpp @@ -18,25 +18,21 @@ #include "ir/daphneir/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" +#include #include #include using namespace mlir; -namespace -{ -struct Distribute : public OpInterfaceConversionPattern -{ +namespace { +struct Distribute : public OpInterfaceConversionPattern { using OpInterfaceConversionPattern::OpInterfaceConversionPattern; - LogicalResult - matchAndRewrite(daphne::Distributable op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::Distributable op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { std::vector distributedInputs; for (auto zipIt : llvm::zip(operands, op.getOperandDistrPrimitives())) { Value operand = std::get<0>(zipIt); @@ -56,7 +52,7 @@ struct Distribute : public OpInterfaceConversionPattern else { // The operands need to be distributed/broadcasted first. 
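// ---------------------------------------------------------------------------
// For illustration (not part of the patch): the arithmetic identity behind
// IntegerModOpt above, as a plain C++ sketch with hypothetical helper names.
// For unsigned lhs and rhs == 2^k, the remainder lhs % rhs is exactly the low
// k bits of lhs, i.e. lhs & (rhs - 1). Note that the test (n & (n - 1)) == 0
// on its own is also true for n == 0, so a zero guard is included here; in
// the pattern above, rhs is additionally required to be a compile-time
// constant.
#include <cassert>
#include <cstdint>

static bool isPowerOfTwo(uint64_t n) { return n != 0 && (n & (n - 1)) == 0; }

static uint64_t modPowerOfTwo(uint64_t lhs, uint64_t rhs) {
    assert(isPowerOfTwo(rhs));
    return lhs & (rhs - 1); // same value as lhs % rhs, but a single AND
}
// ---------------------------------------------------------------------------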
Type t = daphne::HandleType::get(getContext(), operand.getType()); - if(isBroadcast) + if (isBroadcast) distributedInputs.push_back(rewriter.create(op->getLoc(), t, operand)); else distributedInputs.push_back(rewriter.create(op->getLoc(), t, operand)); @@ -69,24 +65,19 @@ struct Distribute : public OpInterfaceConversionPattern } }; -struct DistributeComputationsPass - : public PassWrapper> -{ +struct DistributeComputationsPass : public PassWrapper> { void runOnOperation() final; StringRef getArgument() const final { return "distribute-computation"; } StringRef getDescription() const final { return "TODO"; } }; -} +} // namespace -bool onlyMatrixOperands(Operation * op) { - return llvm::all_of(op->getOperandTypes(), [](Type t) { - return llvm::isa(t); - }); +bool onlyMatrixOperands(Operation *op) { + return llvm::all_of(op->getOperandTypes(), [](Type t) { return llvm::isa(t); }); } -void DistributeComputationsPass::runOnOperation() -{ +void DistributeComputationsPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); @@ -95,17 +86,16 @@ void DistributeComputationsPass::runOnOperation() ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); - target.addDynamicallyLegalDialect([](Operation *op) - { + target.addDynamicallyLegalDialect([](Operation *op) { // An operation is legal (does not need to be replaced), if ... return - // ... it is not distributable - !llvm::isa(op) || - // ... it is inside some distributed computation already - op->getParentOfType() || - // ... not all of its operands are matrices - // TODO Support distributing frames and scalars. - !onlyMatrixOperands(op); + // ... it is not distributable + !llvm::isa(op) || + // ... it is inside some distributed computation already + op->getParentOfType() || + // ... not all of its operands are matrices + // TODO Support distributing frames and scalars. + !onlyMatrixOperands(op); }); patterns.add(&getContext()); @@ -114,7 +104,6 @@ void DistributeComputationsPass::runOnOperation() signalPassFailure(); } -std::unique_ptr daphne::createDistributeComputationsPass() -{ +std::unique_ptr daphne::createDistributeComputationsPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/DistributePipelinesPass.cpp b/src/compiler/lowering/DistributePipelinesPass.cpp index 03489f63e..ede7ed178 100644 --- a/src/compiler/lowering/DistributePipelinesPass.cpp +++ b/src/compiler/lowering/DistributePipelinesPass.cpp @@ -18,31 +18,28 @@ #include "ir/daphneir/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/Transforms/DialectConversion.h" +#include using namespace mlir; /** * @brief Replaces vectorized pipelines by distributed pipelines. 
*/ -struct DistributePipelines : public OpConversionPattern -{ +struct DistributePipelines : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::VectorizedPipelineOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::VectorizedPipelineOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { MLIRContext newContext; OpBuilder tempBuilder(&newContext); std::string funcName = "dist"; auto &bodyBlock = op.getBody().front(); - auto funcType = tempBuilder.getFunctionType( - bodyBlock.getArgumentTypes(), bodyBlock.getTerminator()->getOperandTypes()); + auto funcType = + tempBuilder.getFunctionType(bodyBlock.getArgumentTypes(), bodyBlock.getTerminator()->getOperandTypes()); auto funcOp = tempBuilder.create(op.getLoc(), funcName, funcType); IRMapping mapper; @@ -71,36 +68,31 @@ struct DistributePipelines : public OpConversionPattern(op.getLoc(), stream.str()); - rewriter.replaceOpWithNewOp( - op.getOperation(), - op.getOutputs().getTypes(), irStr, newInputs, - op.getOutRows(), op.getOutCols(), rewriter.getArrayAttr(newSplits), op.getCombines() - ); - + rewriter.replaceOpWithNewOp(op.getOperation(), op.getOutputs().getTypes(), irStr, + newInputs, op.getOutRows(), op.getOutCols(), + rewriter.getArrayAttr(newSplits), op.getCombines()); + return success(); } }; -struct DistributePipelinesPass - : public PassWrapper> -{ +struct DistributePipelinesPass : public PassWrapper> { void runOnOperation() final; StringRef getArgument() const final { return "distribute-pipelines"; } StringRef getDescription() const final { return "TODO"; } }; -void DistributePipelinesPass::runOnOperation() -{ +void DistributePipelinesPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); @@ -110,11 +102,11 @@ void DistributePipelinesPass::runOnOperation() // TODO do we need all these? target.addLegalDialect(); target.addLegalOp(); - target.addDynamicallyLegalOp([](daphne::VectorizedPipelineOp op) - { - // TODO Carefully decide if this pipeline shall be distributed, e.g., - // based on physical input size. For now, all pipelines are distributed - // (false means this pipeline is illegal and must be rewritten). + target.addDynamicallyLegalOp([](daphne::VectorizedPipelineOp op) { + // TODO Carefully decide if this pipeline shall be distributed, + // e.g., based on physical input size. For now, all pipelines are + // distributed (false means this pipeline is illegal and must be + // rewritten). 
return false; }); @@ -124,7 +116,4 @@ void DistributePipelinesPass::runOnOperation() signalPassFailure(); } -std::unique_ptr daphne::createDistributePipelinesPass() -{ - return std::make_unique(); -} +std::unique_ptr daphne::createDistributePipelinesPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/EwOpsLowering.cpp b/src/compiler/lowering/EwOpsLowering.cpp index 056e46f44..d04960c83 100644 --- a/src/compiler/lowering/EwOpsLowering.cpp +++ b/src/compiler/lowering/EwOpsLowering.cpp @@ -44,27 +44,23 @@ using namespace mlir; -template -struct UnaryOpLowering : public mlir::OpConversionPattern { +template struct UnaryOpLowering : public mlir::OpConversionPattern { using OpAdaptor = typename mlir::OpConversionPattern::OpAdaptor; - public: + public: UnaryOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) : mlir::OpConversionPattern(typeConverter, ctx) { this->setDebugName("EwDaphneOpsLowering"); } - mlir::LogicalResult matchAndRewrite( - UnaryOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LogicalResult matchAndRewrite(UnaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { mlir::Type type = op.getType(); if (llvm::isa(type)) { - rewriter.replaceOpWithNewOp(op.getOperation(), - adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op.getOperation(), adaptor.getOperands()); } else if (llvm::isa(type)) { - rewriter.replaceOpWithNewOp(op.getOperation(), - adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op.getOperation(), adaptor.getOperands()); } else { return mlir::failure(); } @@ -76,50 +72,42 @@ template class BinaryOpLowering final : public mlir::OpConversionPattern { using OpAdaptor = typename mlir::OpConversionPattern::OpAdaptor; - public: + public: BinaryOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) : mlir::OpConversionPattern(typeConverter, ctx) { this->setDebugName("EwDaphneOpLowering"); } - mlir::LogicalResult convertEwScalar( - BinaryOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const { + mlir::LogicalResult convertEwScalar(BinaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { auto lhs = adaptor.getLhs(); auto rhs = adaptor.getRhs(); auto loc = op.getLoc(); - if (lhs.getType().template isa() && - rhs.getType().template isa()) { - rewriter.replaceOpWithNewOp(op.getOperation(), - adaptor.getOperands()); + if (lhs.getType().template isa() && rhs.getType().template isa()) { + rewriter.replaceOpWithNewOp(op.getOperation(), adaptor.getOperands()); return mlir::success(); } Value castedLhs = this->typeConverter->materializeTargetConversion( - rewriter, loc, - rewriter.getIntegerType( - adaptor.getRhs().getType().getIntOrFloatBitWidth()), + rewriter, loc, rewriter.getIntegerType(adaptor.getRhs().getType().getIntOrFloatBitWidth()), ValueRange{adaptor.getLhs()}); Value castedRhs = this->typeConverter->materializeTargetConversion( - rewriter, loc, - rewriter.getIntegerType( - adaptor.getRhs().getType().getIntOrFloatBitWidth()), + rewriter, loc, rewriter.getIntegerType(adaptor.getRhs().getType().getIntOrFloatBitWidth()), ValueRange{adaptor.getRhs()}); Value binaryOp = rewriter.create(loc, castedLhs, castedRhs); - Value res = this->typeConverter->materializeSourceConversion( - rewriter, loc, lhs.getType(), ValueRange{binaryOp}); + Value res = + this->typeConverter->materializeSourceConversion(rewriter, loc, lhs.getType(), ValueRange{binaryOp}); rewriter.replaceOp(op, res); return 
mlir::success(); } - mlir::LogicalResult matchAndRewrite( - BinaryOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LogicalResult matchAndRewrite(BinaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { auto lhs = adaptor.getLhs(); auto rhs = adaptor.getRhs(); @@ -130,98 +118,71 @@ class BinaryOpLowering final : public mlir::OpConversionPattern { // for now assume matrix is LHS and RHS is non matrix mlir::daphne::MatrixType lhsMatrixType = - adaptor.getLhs() - .getType() - .template dyn_cast(); + adaptor.getLhs().getType().template dyn_cast(); auto matrixElementType = lhsMatrixType.getElementType(); auto lhsRows = lhsMatrixType.getNumRows(); auto lhsCols = lhsMatrixType.getNumCols(); - auto lhsMemRefType = - mlir::MemRefType::get({lhsRows, lhsCols}, matrixElementType); + auto lhsMemRefType = mlir::MemRefType::get({lhsRows, lhsCols}, matrixElementType); mlir::Type elementType{}; mlir::Value memRefLhs = - rewriter.create( - op->getLoc(), lhsMemRefType, adaptor.getLhs()); + rewriter.create(op->getLoc(), lhsMemRefType, adaptor.getLhs()); mlir::Value memRefRhs{}; - bool isMatrixMatrix = - rhs.getType().template isa(); + bool isMatrixMatrix = rhs.getType().template isa(); if (isMatrixMatrix) { - memRefRhs = - rewriter.create( - op->getLoc(), lhsMemRefType, adaptor.getRhs()); + memRefRhs = rewriter.create(op->getLoc(), lhsMemRefType, + adaptor.getRhs()); elementType = lhsMemRefType.getElementType(); } else { elementType = rhs.getType(); } - mlir::Value outputMemRef = - insertMemRefAlloc(lhsMemRefType, op->getLoc(), rewriter); + mlir::Value outputMemRef = insertMemRefAlloc(lhsMemRefType, op->getLoc(), rewriter); SmallVector lowerBounds(/*Rank=*/2, /*Value=*/0); SmallVector steps(/*Rank=*/2, /*Value=*/1); buildAffineLoopNest( - rewriter, op.getLoc(), lowerBounds, - {lhsMatrixType.getNumRows(), lhsMatrixType.getNumCols()}, steps, + rewriter, op.getLoc(), lowerBounds, {lhsMatrixType.getNumRows(), lhsMatrixType.getNumCols()}, steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - mlir::Value loadLhs = - nestedBuilder.create(loc, memRefLhs, ivs); + mlir::Value loadLhs = nestedBuilder.create(loc, memRefLhs, ivs); mlir::Value binaryOp{}; - if (adaptor.getRhs() - .getType() - .template isa()) { - binaryOp = nestedBuilder.create(loc, loadLhs, - adaptor.getRhs()); + if (adaptor.getRhs().getType().template isa()) { + binaryOp = nestedBuilder.create(loc, loadLhs, adaptor.getRhs()); - nestedBuilder.create(loc, binaryOp, - outputMemRef, ivs); + nestedBuilder.create(loc, binaryOp, outputMemRef, ivs); return; } mlir::Value rhs{}; if (isMatrixMatrix) - rhs = - nestedBuilder.create(loc, memRefRhs, ivs); + rhs = nestedBuilder.create(loc, memRefRhs, ivs); else rhs = adaptor.getRhs(); // is integer - if (elementType.isInteger( - elementType.getIntOrFloatBitWidth())) { - Value castedLhs = - this->typeConverter->materializeTargetConversion( - nestedBuilder, loc, - nestedBuilder.getIntegerType( - lhsMemRefType.getElementTypeBitWidth()), - ValueRange{loadLhs}); - - Value castedRhs = - this->typeConverter->materializeTargetConversion( - nestedBuilder, loc, - nestedBuilder.getIntegerType( - lhsMemRefType.getElementTypeBitWidth()), - ValueRange{rhs}); - - binaryOp = - nestedBuilder.create(loc, castedLhs, castedRhs); - Value castedRes = - this->typeConverter->materializeSourceConversion( - nestedBuilder, loc, elementType, - ValueRange{binaryOp}); - nestedBuilder.create(loc, castedRes, - outputMemRef, ivs); + if 
(elementType.isInteger(elementType.getIntOrFloatBitWidth())) { + Value castedLhs = this->typeConverter->materializeTargetConversion( + nestedBuilder, loc, nestedBuilder.getIntegerType(lhsMemRefType.getElementTypeBitWidth()), + ValueRange{loadLhs}); + + Value castedRhs = this->typeConverter->materializeTargetConversion( + nestedBuilder, loc, nestedBuilder.getIntegerType(lhsMemRefType.getElementTypeBitWidth()), + ValueRange{rhs}); + + binaryOp = nestedBuilder.create(loc, castedLhs, castedRhs); + Value castedRes = this->typeConverter->materializeSourceConversion(nestedBuilder, loc, elementType, + ValueRange{binaryOp}); + nestedBuilder.create(loc, castedRes, outputMemRef, ivs); } else { // is float binaryOp = nestedBuilder.create(loc, loadLhs, rhs); - nestedBuilder.create(loc, binaryOp, - outputMemRef, ivs); + nestedBuilder.create(loc, binaryOp, outputMemRef, ivs); } }); - mlir::Value output = convertMemRefToDenseMatrix( - op->getLoc(), rewriter, outputMemRef, op.getType()); + mlir::Value output = convertMemRefToDenseMatrix(op->getLoc(), rewriter, outputMemRef, op.getType()); rewriter.replaceOp(op, output); return mlir::success(); @@ -247,14 +208,11 @@ namespace { * This rewrite may enable loop fusion of the produced affine loops by * running the loop fusion pass. */ -struct EwOpLoweringPass - : public mlir::PassWrapper> { +struct EwOpLoweringPass : public mlir::PassWrapper> { explicit EwOpLoweringPass() {} void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); } void runOnOperation() final; @@ -265,10 +223,9 @@ struct EwOpLoweringPass "structures and arithmetic operations."; } }; -} // end anonymous namespace +} // end anonymous namespace -void populateLowerEwOpConversionPatterns(mlir::LLVMTypeConverter &typeConverter, - mlir::RewritePatternSet &patterns) { +void populateLowerEwOpConversionPatterns(mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns) { // clang-format off patterns.insert< AddOpLowering, @@ -294,29 +251,20 @@ void EwOpLoweringPass::runOnOperation() { typeConverter.addSourceMaterialization(materializeCastToIllegal); typeConverter.addTargetMaterialization(materializeCastFromIllegal); - target.addLegalDialect(); + target.addLegalDialect(); target.addDynamicallyLegalOp( - [](Operation *op) { - return llvm::isa(op->getOperandTypes()[0]); - }); + [](Operation *op) { return llvm::isa(op->getOperandTypes()[0]); }); - target.addDynamicallyLegalOp([](Operation *op) { + target.addDynamicallyLegalOp([](Operation *op) { if (llvm::isa(op->getOperandTypes()[0]) && llvm::isa(op->getOperandTypes()[1])) { - mlir::daphne::MatrixType lhs = - op->getOperandTypes()[0] - .template dyn_cast(); - mlir::daphne::MatrixType rhs = - op->getOperandTypes()[1] - .template dyn_cast(); - if (lhs.getNumRows() != rhs.getNumRows() || - lhs.getNumCols() != rhs.getNumCols() || + mlir::daphne::MatrixType lhs = op->getOperandTypes()[0].template dyn_cast(); + mlir::daphne::MatrixType rhs = op->getOperandTypes()[1].template dyn_cast(); + if (lhs.getNumRows() != rhs.getNumRows() || lhs.getNumCols() != rhs.getNumCols() || lhs.getNumRows() == -1 || lhs.getNumCols() == -1) return true; @@ -324,8 +272,7 @@ void EwOpLoweringPass::runOnOperation() { } if (llvm::isa(op->getOperandTypes()[0])) { - mlir::daphne::MatrixType lhsMatrixType = - op->getOperandTypes()[0].dyn_cast(); + mlir::daphne::MatrixType lhsMatrixType = op->getOperandTypes()[0].dyn_cast(); return lhsMatrixType.getNumRows() == -1 || lhsMatrixType.getNumCols() == -1; } @@ -339,6 +286,4 @@ void 
EwOpLoweringPass::runOnOperation() { signalPassFailure(); } -std::unique_ptr mlir::daphne::createEwOpLoweringPass() { - return std::make_unique(); -} +std::unique_ptr mlir::daphne::createEwOpLoweringPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/InsertDaphneContextPass.cpp b/src/compiler/lowering/InsertDaphneContextPass.cpp index c801cda8f..feb3a8bf0 100644 --- a/src/compiler/lowering/InsertDaphneContextPass.cpp +++ b/src/compiler/lowering/InsertDaphneContextPass.cpp @@ -33,54 +33,52 @@ using namespace mlir; // extensions in several directions, e.g.: // - inserting the context into blocks (e.g. parfor loop bodies) // - passing the context as an argument to a function -struct InsertDaphneContextPass : public PassWrapper> -{ - const DaphneUserConfig& user_config; - explicit InsertDaphneContextPass(const DaphneUserConfig& cfg) : user_config(cfg) {} +struct InsertDaphneContextPass : public PassWrapper> { + const DaphneUserConfig &user_config; + explicit InsertDaphneContextPass(const DaphneUserConfig &cfg) : user_config(cfg) {} void runOnOperation() final; }; -void InsertDaphneContextPass::runOnOperation() -{ +void InsertDaphneContextPass::runOnOperation() { func::FuncOp f = getOperation(); - Block & b = f.getBody().front(); - + Block &b = f.getBody().front(); + OpBuilder builder(&b, b.begin()); Location loc = f.getLoc(); // Insert a CreateDaphneContextOp as the first operation in the block. - builder.create(loc, daphne::DaphneContextType::get(&getContext()), - builder.create(loc, reinterpret_cast(&user_config)), - builder.create(loc, reinterpret_cast(&KernelDispatchMapping::instance())), - builder.create(loc, reinterpret_cast(&Statistics::instance())), - builder.create(loc, reinterpret_cast(&StringRefCounter::instance()))); + builder.create( + loc, daphne::DaphneContextType::get(&getContext()), + builder.create(loc, reinterpret_cast(&user_config)), + builder.create(loc, reinterpret_cast(&KernelDispatchMapping::instance())), + builder.create(loc, reinterpret_cast(&Statistics::instance())), + builder.create(loc, reinterpret_cast(&StringRefCounter::instance()))); #ifdef USE_CUDA - if(user_config.use_cuda) { + if (user_config.use_cuda) { builder.create(loc); } #endif - if (user_config.use_distributed){ + if (user_config.use_distributed) { builder.create(loc); } #ifdef USE_HDFS - if(user_config.use_hdfs) { + if (user_config.use_hdfs) { builder.create(loc); } #endif #ifdef USE_FPGAOPENCL - if(user_config.use_fpgaopencl) { + if (user_config.use_fpgaopencl) { builder.create(loc); } #endif - + // Insert a DestroyDaphneContextOp as the last operation in the block, but // before the block's terminator. builder.setInsertionPoint(b.getTerminator()); builder.create(loc); } -std::unique_ptr daphne::createInsertDaphneContextPass(const DaphneUserConfig& cfg) -{ +std::unique_ptr daphne::createInsertDaphneContextPass(const DaphneUserConfig &cfg) { return std::make_unique(cfg); } diff --git a/src/compiler/lowering/LowerToLLVMPass.cpp b/src/compiler/lowering/LowerToLLVMPass.cpp index 31d981ab3..0f14ab0cf 100644 --- a/src/compiler/lowering/LowerToLLVMPass.cpp +++ b/src/compiler/lowering/LowerToLLVMPass.cpp @@ -14,9 +14,9 @@ * limitations under the License. 
 */

+#include "compiler/utils/CompilerUtils.h"
 #include "ir/daphneir/Daphne.h"
 #include "ir/daphneir/Passes.h"
-#include "compiler/utils/CompilerUtils.h"

 #include

 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
@@ -24,13 +24,13 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
-#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
@@ -38,38 +38,36 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Transforms/DialectConversion.h"

+#include
 #include
 #include
 #include
-#include

 using namespace mlir;

 // Remark on the creation of mlir::LLVM::AllocaOp
 // ==============================================
-// This pass creates an mlir::LLVM::AllocaOp in several places and for various purposes,
-// e.g., to store the result pointer of a kernel call, for variadic operands/results, etc.
-// AllocaOp should not be inside a loop, as its repeated execution at run-time can lead
-// to a stack overflow (depending on the number of iterations, the number of AllocaOps
-// inside the loop, and the stack size). The reason is that the memory allocated by AllocaOp
-// is freed only at the end of the scope (i.e., function).
-// To avoid such problems, we don't create AllocaOps at the original insertion point of
-// the rewriter, but at the beginning of function surrounding the currently considered op.
-// To this end, we use the rewriter's ability to switch between different insertion points.
-// Note that the memory allocated by an AllocaOp can be reused by multiple repeated
-// kernel calls.
+// This pass creates an mlir::LLVM::AllocaOp in several places and for various
+// purposes, e.g., to store the result pointer of a kernel call, for variadic
+// operands/results, etc. AllocaOp should not be inside a loop, as its repeated
+// execution at run-time can lead to a stack overflow (depending on the number
+// of iterations, the number of AllocaOps inside the loop, and the stack size).
+// The reason is that the memory allocated by AllocaOp is freed only at the end
+// of the scope (i.e., function). To avoid such problems, we don't create
+// AllocaOps at the original insertion point of the rewriter, but at the
+// beginning of the function surrounding the currently considered op. To this
+// end, we use the rewriter's ability to switch between different insertion
+// points. Note that the memory allocated by an AllocaOp can be reused by
+// multiple repeated kernel calls. (A standalone sketch of this hoisting idiom
+// follows below.)

 // Optional attribute of CallKernelOp, which indicates that all results shall
 // be combined into a single variadic result.
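// ---------------------------------------------------------------------------
// For illustration (not part of the patch): the AllocaOp-hoisting idiom
// described in the remark above, as a standalone sketch. The helper name
// `createHoistedAlloca` is hypothetical, and the surrounding function is
// assumed to be an LLVM::LLVMFuncOp; the point is the save/set/restore dance
// around the rewriter's insertion point.
static mlir::Value createHoistedAlloca(mlir::PatternRewriter &rewriter, mlir::Operation *op, mlir::Type ptrType,
                                       mlir::Value numElements) {
    // Remember where lowering currently is.
    mlir::OpBuilder::InsertPoint ipHere = rewriter.saveInsertionPoint();
    // Emit the alloca once, at the entry block of the surrounding function,
    // so it is not re-executed on every loop iteration.
    mlir::Block &entry = op->getParentOfType<mlir::LLVM::LLVMFuncOp>().getBody().front();
    rewriter.setInsertionPointToStart(&entry);
    mlir::Value buf = rewriter.create<mlir::LLVM::AllocaOp>(op->getLoc(), ptrType, numElements, /*alignment=*/1);
    // Return to the original insertion point.
    rewriter.restoreInsertionPoint(ipHere);
    return buf;
}
// ---------------------------------------------------------------------------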
const std::string ATTR_HASVARIADICRESULTS = "hasVariadicResults"; -struct ReturnOpLowering : public OpRewritePattern -{ +struct ReturnOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(daphne::ReturnOp op, - PatternRewriter &rewriter) const final - { + LogicalResult matchAndRewrite(daphne::ReturnOp op, PatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, op.getOperands()); return success(); } @@ -78,9 +76,8 @@ struct ReturnOpLowering : public OpRewritePattern struct CastOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(daphne::CastOp op, - PatternRewriter &rewriter) const final { - if(op.isTrivialCast() || op.isRemovePropertyCast()) { + LogicalResult matchAndRewrite(daphne::CastOp op, PatternRewriter &rewriter) const final { + if (op.isTrivialCast() || op.isRemovePropertyCast()) { rewriter.replaceOp(op, op.getOperand()); return success(); } @@ -90,17 +87,14 @@ struct CastOpLowering : public OpRewritePattern { /// ConstantOp lowering for types not handled before (str) -class ConstantOpLowering : public OpConversionPattern -{ -public: +class ConstantOpLowering : public OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - if(auto strAttr = op.getValue().dyn_cast()) { + if (auto strAttr = op.getValue().dyn_cast()) { StringRef sr = strAttr.getValue(); #if 1 // MLIR does not have direct support for strings. Thus, if this is @@ -109,76 +103,60 @@ class ConstantOpLowering : public OpConversionPattern // characters of the string constant to that array one by one. The // SSA value of the constant is replaced by a pointer to i8 // pointing to the allocated buffer. - Type i8PtrType = LLVM::LLVMPointerType::get( - IntegerType::get(rewriter.getContext(), 8) - ); + Type i8PtrType = LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)); const size_t numChars = sr.size() + 1; // +1 for trailing '\0' const std::string str = sr.str(); - const char * chars = str.c_str(); + const char *chars = str.c_str(); - // We could assume that the daphne::ConstantOp `op` is *not* inside a loop, - // because constants are typically moved to the top of a function during - // canonicalization. Consequently, we would not need to change the insertion - // point. However, being defensive, we still do it. + // We could assume that the daphne::ConstantOp `op` is *not* inside + // a loop, because constants are typically moved to the top of a + // function during canonicalization. Consequently, we would not need + // to change the insertion point. However, being defensive, we still + // do it. - // Set the insertion point to the beginning of the function surrounding this ConstantOp - // (see comment on AllocaOp above). + // Set the insertion point to the beginning of the function + // surrounding this ConstantOp (see comment on AllocaOp above). 
             OpBuilder::InsertPoint ipHere = rewriter.saveInsertionPoint();
-            Block & fb = op.getOperation()->getParentOfType().getBody().front();
+            Block &fb = op.getOperation()->getParentOfType().getBody().front();
             rewriter.setInsertionPointToStart(&fb);
             auto allocaOp = rewriter.replaceOpWithNewOp(
-                op.getOperation(),
-                i8PtrType,
-                rewriter.create(loc, rewriter.getI64IntegerAttr(numChars)),
-                1
-            );
+                op.getOperation(), i8PtrType,
+                rewriter.create(loc, rewriter.getI64IntegerAttr(numChars)), 1);

             // Go back to the original insertion point.
             rewriter.restoreInsertionPoint(ipHere);

-            for(size_t i = 0; i < numChars; i++) {
-                std::vector indices = {
-                    rewriter.create(loc, rewriter.getI64IntegerAttr(i))
-                };
+            for (size_t i = 0; i < numChars; i++) {
+                std::vector indices = {rewriter.create(loc, rewriter.getI64IntegerAttr(i))};
                 rewriter.create(
-                    loc,
-                    rewriter.create(
-                        loc, rewriter.getI8IntegerAttr(chars[i])
-                    ),
-                    rewriter.create(
-                        op->getLoc(), i8PtrType, allocaOp, indices
-                    )
-                );
+                    loc, rewriter.create(loc, rewriter.getI8IntegerAttr(chars[i])),
+                    rewriter.create(op->getLoc(), i8PtrType, allocaOp, indices));
             }
 #else
             // Alternatively, we could create a global string, which would
             // yield a pointer to i8, too. However, we would need to choose a
             // unique name.
-            rewriter.replaceOp(
-                op.getOperation(),
-                LLVM::createGlobalString(
-                    loc, rewriter, "someName", sr,
-                    LLVM::Linkage::Private // TODO Does that make sense?
-                )
-            );
+            rewriter.replaceOp(op.getOperation(),
+                               LLVM::createGlobalString(loc, rewriter, "someName", sr,
+                                                        LLVM::Linkage::Private // TODO Does that make sense?
+                                                        ));
 #endif
-        }
-        else {
-            // Constants of all other types are lowered to an mlir::arith::ConstantOp.
-            // Note that this is a different op than mlir::daphne::ConstantOp!
+        } else {
+            // Constants of all other types are lowered to an
+            // mlir::arith::ConstantOp. Note that this is a different op than
+            // mlir::daphne::ConstantOp!
 #if 1
             rewriter.replaceOpWithNewOp(op.getOperation(), op.getValue());
 #else
-            // NOTE: this fixes printing due to an error in the LLVMDialect, but is the wrong behaviour.
+            // NOTE: this fixes printing due to an error in the LLVMDialect, but
+            // is the wrong behaviour.
// Use this for debugging only if (auto iTy = op.getType().dyn_cast()) { auto ty = IntegerType::get(getContext(), iTy.getWidth()); - rewriter.replaceOpWithNewOp(op.getOperation(), - ty, - IntegerAttr::get(ty, op.getValue().cast().getValue())); - } - else { + rewriter.replaceOpWithNewOp( + op.getOperation(), ty, IntegerAttr::get(ty, op.getValue().cast().getValue())); + } else { rewriter.replaceOpWithNewOp(op.getOperation(), op.getValue()); } #endif @@ -188,79 +166,62 @@ class ConstantOpLowering : public OpConversionPattern } }; -class CallKernelOpLowering : public OpConversionPattern -{ - - static std::vector getLLVMInputOutputTypes(Location &loc, - MLIRContext *context, - TypeConverter *typeConverter, - TypeRange resultTypes, - TypeRange operandTypes, - bool hasVarRes, - Type indexType) - { +class CallKernelOpLowering : public OpConversionPattern { + + static std::vector getLLVMInputOutputTypes(Location &loc, MLIRContext *context, TypeConverter *typeConverter, + TypeRange resultTypes, TypeRange operandTypes, bool hasVarRes, + Type indexType) { llvm::SmallVector args; - + // -------------------------------------------------------------------- // Results // -------------------------------------------------------------------- - + const size_t numRes = resultTypes.size(); - if(hasVarRes) { // combine all results into one variadic result + if (hasVarRes) { // combine all results into one variadic result // TODO Support individual result types, at least if they are all // mapped to the superclass Structure (see #397). // Check if all results have the same type. Type t0 = resultTypes[0]; Type mt0 = t0.dyn_cast().withSameElementTypeAndRepr(); - for(size_t i = 1; i < numRes; i++) - if (mt0 != resultTypes[i] - .dyn_cast() - .withSameElementTypeAndRepr()) { - throw ErrorHandler::compilerError( - loc, "LowerToLLVMPass", - "all results of a CallKernelOp must have the same " - "type to combine them into a single variadic result"); + for (size_t i = 1; i < numRes; i++) + if (mt0 != resultTypes[i].dyn_cast().withSameElementTypeAndRepr()) { + throw ErrorHandler::compilerError(loc, "LowerToLLVMPass", + "all results of a CallKernelOp must have the same " + "type to combine them into a single variadic result"); } // Wrap the common result type into a pointer, since we need an // array of that type. - args.push_back(LLVM::LLVMPointerType::get( - typeConverter->isLegal(t0) - ? t0 - : typeConverter->convertType(t0) - )); - } - else // typical case + args.push_back( + LLVM::LLVMPointerType::get(typeConverter->isLegal(t0) ? t0 : typeConverter->convertType(t0))); + } else // typical case for (auto type : resultTypes) { if (typeConverter->isLegal(type)) { args.push_back(type); - } - else if (failed(typeConverter->convertType(type, args))) + } else if (failed(typeConverter->convertType(type, args))) emitError(loc) << "Couldn't convert result type `" << type << "`\n"; } - + // -------------------------------------------------------------------- // Operands // -------------------------------------------------------------------- - - if(hasVarRes) + + if (hasVarRes) // Create a parameter for passing the number of results in the // single variadic result. - args.push_back(typeConverter->isLegal(indexType) - ? indexType - : typeConverter->convertType(indexType)); - + args.push_back(typeConverter->isLegal(indexType) ? 
indexType : typeConverter->convertType(indexType));
+
         for (auto type : operandTypes) {
             if (typeConverter->isLegal(type)) {
                 args.push_back(type);
-            }
-            else if (failed(typeConverter->convertType(type, args)))
+            } else if (failed(typeConverter->convertType(type, args)))
                 emitError(loc) << "Couldn't convert operand type `" << type << "`\n";
         }

         // --------------------------------------------------------------------
         // Create final LLVM types
         // --------------------------------------------------------------------
-
+
         std::vector argsLLVM;
         for (size_t i = 0; i < args.size(); i++) {
             Type type = args[i];
@@ -271,18 +232,15 @@ class CallKernelOpLowering : public OpConversionPattern
             if (!hasVarRes && i < numRes) {
                 type = LLVM::LLVMPointerType::get(type);
             }
-
+
             argsLLVM.push_back(type);
         }
-
+
         return argsLLVM;
     }

-    static FlatSymbolRefAttr
-    getOrInsertFunctionAttr(OpBuilder &rewriter, ModuleOp module,
-                            llvm::StringRef funcName,
-                            LLVM::LLVMFunctionType llvmFnType)
-    {
+    static FlatSymbolRefAttr getOrInsertFunctionAttr(OpBuilder &rewriter, ModuleOp module, llvm::StringRef funcName,
+                                                     LLVM::LLVMFunctionType llvmFnType) {
         auto *context = module.getContext();
         if (module.lookupSymbol(funcName))
             return SymbolRefAttr::get(context, funcName);
@@ -293,86 +251,61 @@ class CallKernelOpLowering : public OpConversionPattern
         return SymbolRefAttr::get(context, funcName);
     }

-    static LLVM::LLVMFunctionType
-    getKernelFuncSignature(MLIRContext *context, std::vector argsLLVM)
-    {
-        return LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), argsLLVM,
-                                           false);
+    static LLVM::LLVMFunctionType getKernelFuncSignature(MLIRContext *context, std::vector argsLLVM) {
+        return LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), argsLLVM, false);
     }

-public:
+  public:
     using OpConversionPattern::OpConversionPattern;

-    LogicalResult
-    matchAndRewrite(daphne::CallKernelOp op, OpAdaptor adaptor,
-                    ConversionPatternRewriter &rewriter) const override
-    {
+    LogicalResult matchAndRewrite(daphne::CallKernelOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) const override {
         // Whether all results of the operation shall be combined into one
         // variadic result. If this is false (typical case), we pass a
         // separate nullptr for each result to the kernel. If it is true, we
         // create an array with the number of results, fill it with nullptrs,
         // and pass that to the kernel (variadic results).
         const bool hasVarRes = op->hasAttr(ATTR_HASVARIADICRESULTS)
-                                   ? op->getAttr(ATTR_HASVARIADICRESULTS).dyn_cast().getValue()
-                                   : false;
-
+                                   ? op->getAttr(ATTR_HASVARIADICRESULTS).dyn_cast().getValue()
+                                   : false;
+
         auto module = op->getParentOfType();
         auto loc = op.getLoc();

-        auto inputOutputTypes = getLLVMInputOutputTypes(
-            loc, rewriter.getContext(), typeConverter, op.getResultTypes(),
-            ValueRange(adaptor.getOperands()).getTypes(), hasVarRes,
-            rewriter.getIndexType());
+        auto inputOutputTypes =
+            getLLVMInputOutputTypes(loc, rewriter.getContext(), typeConverter, op.getResultTypes(),
+                                    ValueRange(adaptor.getOperands()).getTypes(), hasVarRes, rewriter.getIndexType());

         // create function prototype and get `FlatSymbolRefAttr` to it
-        auto kernelRef = getOrInsertFunctionAttr(
-            rewriter, module, op.getCalleeAttr().getValue(),
-            getKernelFuncSignature(rewriter.getContext(), inputOutputTypes));
+        auto kernelRef = getOrInsertFunctionAttr(rewriter, module, op.getCalleeAttr().getValue(),
+                                                 getKernelFuncSignature(rewriter.getContext(), inputOutputTypes));

-        auto kernelOperands = allocOutputReferences(
-            loc, rewriter, adaptor.getOperands(), inputOutputTypes,
-            op->getNumResults(), hasVarRes, op);
+        auto kernelOperands = allocOutputReferences(loc, rewriter, adaptor.getOperands(), inputOutputTypes,
+                                                    op->getNumResults(), hasVarRes, op);

         // call function
         // The kernel call has an empty list of return types, because our
         // kernel(-wrapper)s generally return via parameters.
         TypeRange ts;
-        rewriter.create(
-            loc, kernelRef,
-            ts,
-            kernelOperands);
-        rewriter.replaceOp(op, dereferenceOutputs(loc, rewriter, module,
-                                                  op->getNumResults(),
-                                                  hasVarRes, kernelOperands));
+        rewriter.create(loc, kernelRef, ts, kernelOperands);
+        rewriter.replaceOp(op,
+                           dereferenceOutputs(loc, rewriter, module, op->getNumResults(), hasVarRes, kernelOperands));
         return success();
     }

-private:
-
-    static std::vector
-    dereferenceOutputs(Location &loc, PatternRewriter &rewriter, ModuleOp &module,
-                       size_t numResults, bool hasVarRes, std::vector kernelOperands)
-    {
+  private:
+    static std::vector dereferenceOutputs(Location &loc, PatternRewriter &rewriter, ModuleOp &module,
+                                          size_t numResults, bool hasVarRes, std::vector kernelOperands) {
         // transformed results
         std::vector results;
-
-        if(hasVarRes) { // combine all results into one variadic result
-            for(size_t i = 0; i < numResults; i++) {
-                std::vector indices = {
-                    rewriter.create(loc, rewriter.getI64IntegerAttr(i))
-                };
+
+        if (hasVarRes) { // combine all results into one variadic result
+            for (size_t i = 0; i < numResults; i++) {
+                std::vector indices = {rewriter.create(loc, rewriter.getI64IntegerAttr(i))};
                 results.push_back(rewriter.create(
-                    loc,
-                    rewriter.create(
-                        loc,
-                        kernelOperands[0].getType(),
-                        kernelOperands[0],
-                        indices
-                    )
-                ));
+                    loc, rewriter.create(loc, kernelOperands[0].getType(), kernelOperands[0], indices)));
             }
-        }
-        else // typical case
+        } else // typical case
             for (size_t i = 0; i < numResults; i++) {
                 // dereference output
                 auto value = kernelOperands[i];
@@ -381,23 +314,20 @@ class CallKernelOpLowering : public OpConversionPattern
                 results.push_back(resultVal);
             }
-
+
         return results;
     }

-    std::vector
-    allocOutputReferences(Location &loc, PatternRewriter &rewriter,
-                          ValueRange operands,
-                          std::vector inputOutputTypes, size_t numRes, bool hasVarRes,
-                          daphne::CallKernelOp op) const
-    {
+    std::vector allocOutputReferences(Location &loc, PatternRewriter &rewriter, ValueRange operands,
+                                      std::vector inputOutputTypes, size_t numRes, bool hasVarRes,
+                                      daphne::CallKernelOp op) const {
         std::vector kernelOperands;
-
-        // Obtain an insertion point at the beginning of the function surrounding this
CallKernelOp - // (see comment on AllocaOp above). + + // Obtain an insertion point at the beginning of the function + // surrounding this CallKernelOp (see comment on AllocaOp above). OpBuilder::InsertPoint ipHere = rewriter.saveInsertionPoint(); - Block & fb = op.getOperation()->getParentOfType().getBody().front(); + Block &fb = op.getOperation()->getParentOfType().getBody().front(); rewriter.setInsertionPointToStart(&fb); OpBuilder::InsertPoint ipFuncStart = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipHere); @@ -405,18 +335,17 @@ class CallKernelOpLowering : public OpConversionPattern // -------------------------------------------------------------------- // Results // -------------------------------------------------------------------- - - if(hasVarRes) { // combine all results into one variadic result + + if (hasVarRes) { // combine all results into one variadic result // Allocate an array of numRes elements. - // Set the insertion point to the beginning of the function (see comment on AllocaOp above). + // Set the insertion point to the beginning of the function (see + // comment on AllocaOp above). ipHere = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipFuncStart); auto allocaOp = rewriter.create( - loc, - inputOutputTypes[0], - rewriter.create(loc, rewriter.getI64IntegerAttr(numRes)).getResult() - ); + loc, inputOutputTypes[0], + rewriter.create(loc, rewriter.getI64IntegerAttr(numRes)).getResult()); ipFuncStart = rewriter.saveInsertionPoint(); // Go back to the original insertion point. @@ -430,24 +359,19 @@ class CallKernelOpLowering : public OpConversionPattern // (i.e. when it represents a scalar), initialization is not // required. Type elType = inputOutputTypes[0].dyn_cast().getElementType(); - if(llvm::isa(elType)) { - for(size_t i = 0; i < numRes; i++) { + if (llvm::isa(elType)) { + for (size_t i = 0; i < numRes; i++) { std::vector indices = { - rewriter.create(loc, rewriter.getI64IntegerAttr(i)) - }; + rewriter.create(loc, rewriter.getI64IntegerAttr(i))}; rewriter.create( - loc, - rewriter.create(loc, elType), - rewriter.create( - loc, inputOutputTypes[0], allocaOp, indices - ) - ); + loc, rewriter.create(loc, elType), + rewriter.create(loc, inputOutputTypes[0], allocaOp, indices)); } } - } - else { // typical case + } else { // typical case // Constant of 1 for AllocaOp of output. - // Set the insertion point to the beginning of the function (see comment on AllocaOp above). + // Set the insertion point to the beginning of the function (see + // comment on AllocaOp above). ipHere = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipFuncStart); Value cst1 = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); @@ -455,10 +379,11 @@ class CallKernelOpLowering : public OpConversionPattern // Go back to the original insertion point. rewriter.restoreInsertionPoint(ipHere); - + for (size_t i = 0; i < numRes; i++) { // Allocate space for a single element. - // Set the insertion point to the beginning of the function (see comment on AllocaOp above). + // Set the insertion point to the beginning of the function (see + // comment on AllocaOp above). ipHere = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipFuncStart); auto allocaOp = rewriter.create(loc, inputOutputTypes[i], cst1); @@ -468,33 +393,30 @@ class CallKernelOpLowering : public OpConversionPattern // Go back to the original insertion point. rewriter.restoreInsertionPoint(ipHere); - // If the type of this result parameter is a pointer (i.e. 
when it - // represents a matrix or frame), then initialize the allocated - // element with a null pointer (required by the kernels). Otherwise - // (i.e. when it represents a scalar), initialization is not - // required. + // If the type of this result parameter is a pointer (i.e. when + // it represents a matrix or frame), then initialize the + // allocated element with a null pointer (required by the + // kernels). Otherwise (i.e. when it represents a scalar), + // initialization is not required. Type elType = inputOutputTypes[i].dyn_cast().getElementType(); - if(llvm::isa(elType)) { - rewriter.create( - loc, - rewriter.create(loc, elType), - allocaOp - ); + if (llvm::isa(elType)) { + rewriter.create(loc, rewriter.create(loc, elType), allocaOp); } } } - + // -------------------------------------------------------------------- // Operands // -------------------------------------------------------------------- - - if(hasVarRes) - // Insert the number of results in the variadic result as a constant. + + if (hasVarRes) + // Insert the number of results in the variadic result as a + // constant. kernelOperands.push_back(rewriter.create(loc, rewriter.getIndexAttr(numRes))); - - for(auto op : operands) + + for (auto op : operands) kernelOperands.push_back(op); - + return kernelOperands; } }; @@ -503,28 +425,22 @@ class CallKernelOpLowering : public OpConversionPattern * @brief Rewrites `daphne::CreateVariadicPackOp` to `LLVM::AllocaOp` to create * an array for the required number of occurrences of a variadic operand. */ -class CreateVariadicPackOpLowering : public OpConversionPattern -{ -public: +class CreateVariadicPackOpLowering : public OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::CreateVariadicPackOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - // Set the insertion point to the beginning of the function surrounding this CreateVariadicPackOp - // (see comment on AllocaOp above). - Block & fb = op.getOperation()->getParentOfType().getBody().front(); + LogicalResult matchAndRewrite(daphne::CreateVariadicPackOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Set the insertion point to the beginning of the function surrounding + // this CreateVariadicPackOp (see comment on AllocaOp above). 
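// ---------------------------------------------------------------------------
// For illustration (not part of the patch): what the variadic-pack lowerings
// here amount to at the LLVM level. CreateVariadicPackOp becomes a stack
// array (AllocaOp), and StoreVariadicPackOp writes one element through a GEP.
// Minimal sketch assuming the typed-pointer LLVM dialect used elsewhere in
// this pass; the helper name `storePackElement` is hypothetical, and
// `elemPtrTy` is the LLVM pointer type of the pack's element type.
static void storePackElement(mlir::OpBuilder &b, mlir::Location loc, mlir::Value pack, mlir::Value item, int64_t pos,
                             mlir::Type elemPtrTy) {
    // addr = &pack[pos]
    mlir::Value idx = b.create<mlir::LLVM::ConstantOp>(loc, b.getI64Type(), b.getI64IntegerAttr(pos));
    mlir::Value addr = b.create<mlir::LLVM::GEPOp>(loc, elemPtrTy, pack, mlir::ValueRange{idx});
    // *addr = item
    b.create<mlir::LLVM::StoreOp>(loc, item, addr);
}
// ---------------------------------------------------------------------------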
+ Block &fb = op.getOperation()->getParentOfType().getBody().front(); rewriter.setInsertionPointToStart(&fb); Type contType = op.getRes().getType().dyn_cast().getContainedType(); Type convType = typeConverter->convertType(contType); rewriter.replaceOpWithNewOp( - op.getOperation(), - LLVM::LLVMPointerType::get(convType), - rewriter.create(op->getLoc(), op.getNumElementsAttr()), - 1 - ); + op.getOperation(), LLVM::LLVMPointerType::get(convType), + rewriter.create(op->getLoc(), op.getNumElementsAttr()), 1); return success(); } }; @@ -534,65 +450,49 @@ class CreateVariadicPackOpLowering : public OpConversionPattern -{ -public: +class StoreVariadicPackOpLowering : public OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::StoreVariadicPackOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::StoreVariadicPackOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { mlir::Location loc = op->getLoc(); mlir::Value pack = adaptor.getOperands()[0]; mlir::Value item = adaptor.getOperands()[1]; auto elementType = pack.getType().cast().getElementType(); - std::vector indices = { - rewriter.create(loc, op.getPosAttr()) - }; - auto addr = rewriter.create( - loc, pack.getType(), pack, indices - ); + std::vector indices = {rewriter.create(loc, op.getPosAttr())}; + auto addr = rewriter.create(loc, pack.getType(), pack, indices); Type itemType = item.getType(); if (itemType != elementType) { if (llvm::isa(elementType)) { - if(itemType.isSignedInteger()) + if (itemType.isSignedInteger()) item = rewriter.create(loc, rewriter.getI64Type(), item); - else if(itemType.isUnsignedInteger() || itemType.isSignlessInteger()) + else if (itemType.isUnsignedInteger() || itemType.isSignlessInteger()) item = rewriter.create(loc, rewriter.getI64Type(), item); - else if(llvm::isa(itemType)) { + else if (llvm::isa(itemType)) { item = rewriter.create(loc, rewriter.getF64Type(), item); item = rewriter.create(loc, rewriter.getI64Type(), item); } else { - throw ErrorHandler::compilerError( - loc, "LowerToLLVMPass", - "itemType is an unsupported type"); + throw ErrorHandler::compilerError(loc, "LowerToLLVMPass", "itemType is an unsupported type"); } item = rewriter.create(loc, elementType, item); - } - else { + } else { throw ErrorHandler::compilerError(loc, "LowerToLLVMPass", - "casting to a non-pointer type in " - "StoreVariadicPackOpLowering is not implemented yet" - ); + "casting to a non-pointer type in " + "StoreVariadicPackOpLowering is not implemented yet"); } } - rewriter.replaceOpWithNewOp( - op.getOperation(), item, addr - ); + rewriter.replaceOpWithNewOp(op.getOperation(), item, addr); return success(); } }; -class MapOpLowering : public OpConversionPattern -{ -public: +class MapOpLowering : public OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::MapOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::MapOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto module = op->getParentOfType(); @@ -605,42 +505,34 @@ class MapOpLowering : public OpConversionPattern // Input Matrix callee << "__" << CompilerUtils::mlirTypeToCppTypeName(op.getArg().getType(), false); - // Pointer to UDF + // Pointer to UDF callee << "__void"; - - // get pointer to UDF + // get pointer 
to UDF LLVM::LLVMFuncOp udfFuncOp = module.lookupSymbol(op.getFunc()); auto udfFnPtr = rewriter.create(loc, udfFuncOp); std::vector kernelOperands{op.getArg(), udfFnPtr}; - auto kernel = rewriter.create( - loc, - callee.str(), - kernelOperands, - op->getResultTypes() - ); + auto kernel = rewriter.create(loc, callee.str(), kernelOperands, op->getResultTypes()); rewriter.replaceOp(op, kernel.getResults()); return success(); } }; -class VectorizedPipelineOpLowering : public OpConversionPattern -{ - const DaphneUserConfig& cfg; +class VectorizedPipelineOpLowering : public OpConversionPattern { + const DaphneUserConfig &cfg; -public: - explicit VectorizedPipelineOpLowering(TypeConverter &typeConverter, MLIRContext *context, const DaphneUserConfig &cfg) - : OpConversionPattern(typeConverter, context), cfg(cfg) {} + public: + explicit VectorizedPipelineOpLowering(TypeConverter &typeConverter, MLIRContext *context, + const DaphneUserConfig &cfg) + : OpConversionPattern(typeConverter, context), cfg(cfg) {} using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::VectorizedPipelineOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::VectorizedPipelineOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (op.getCtx() == nullptr) { op->emitOpError() << "`DaphneContext` not known"; return failure(); @@ -658,7 +550,7 @@ class VectorizedPipelineOpLowering : public OpConversionPatterngetParentOfType(); - Block * moduleBody = moduleOp.getBody(); + Block *moduleBody = moduleOp.getBody(); rewriter.setInsertionPointToStart(moduleBody); static auto ix = 0; @@ -666,7 +558,8 @@ class VectorizedPipelineOpLowering : public OpConversionPattern(loc, funcName, funcType); fOp.getBody().takeBody(op.getBody()); @@ -675,35 +568,32 @@ class VectorizedPipelineOpLowering : public OpConversionPattern()) { callKernelOp.setOperand(callKernelOp.getNumOperands() - 1, daphneContext); } - // Extract inputs from array containing them and remove the block arguments matching the old inputs of the - // `VectorizedPipelineOp` + // Extract inputs from array containing them and remove the block + // arguments matching the old inputs of the `VectorizedPipelineOp` rewriter.setInsertionPointToStart(&funcBlock); - for(auto i = 0u; i < numDataOperands; ++i) { - auto addr = rewriter.create(loc, - ptrPtrI1Ty, - inputsArg, - ArrayRef({ - rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); + for (auto i = 0u; i < numDataOperands; ++i) { + auto addr = rewriter.create( + loc, ptrPtrI1Ty, inputsArg, + ArrayRef({rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); Value val = rewriter.create(loc, addr); auto expTy = typeConverter->convertType(op.getInputs().getType()[i]); if (expTy != val.getType()) { // casting for scalars val = rewriter.create(loc, rewriter.getI64Type(), val); - if(llvm::isa(expTy)) + if (llvm::isa(expTy)) val = rewriter.create(loc, expTy, val); - else if(llvm::isa(expTy)) { + else if (llvm::isa(expTy)) { val = rewriter.create(loc, rewriter.getF64Type(), val); val = rewriter.create(loc, expTy, val); } else { - throw ErrorHandler::compilerError( - loc, "LowerToLLVMPass", - "expTy is an unsupported type"); + throw ErrorHandler::compilerError(loc, "LowerToLLVMPass", "expTy is an unsupported type"); } } funcBlock.getArgument(0).replaceAllUsesWith(val); @@ -715,11 +605,14 @@ class VectorizedPipelineOpLowering : public OpConversionPatterngetNumOperands(); ++i) { auto retVal = 
oldReturn->getOperand(i); - // TODO: check how the GEPOp works exactly, and if this can be written better - auto addr1 = rewriter.create(op->getLoc(), pppI1Ty, returnRef, ArrayRef( - {rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); + // TODO: check how the GEPOp works exactly, and if this can be + // written better + auto addr1 = rewriter.create( + op->getLoc(), pppI1Ty, returnRef, + ArrayRef({rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); auto addr2 = rewriter.create(op->getLoc(), addr1); - Value retValConverted = typeConverter->materializeTargetConversion(rewriter, oldReturn->getLoc(), typeConverter->convertType(retVal.getType()), {retVal}); + Value retValConverted = typeConverter->materializeTargetConversion( + rewriter, oldReturn->getLoc(), typeConverter->convertType(retVal.getType()), {retVal}); rewriter.create(loc, retValConverted, addr2); } // Replace the old ReturnOp with operands by a new ReturnOp without @@ -731,19 +624,20 @@ class VectorizedPipelineOpLowering : public OpConversionPatterngetParentOfType(); - Block * moduleBody = moduleOp.getBody(); + Block *moduleBody = moduleOp.getBody(); rewriter.setInsertionPointToStart(moduleBody); static auto ix = 0; std::string funcName = "_vect_cuda" + std::to_string(++ix); auto funcType = LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(rewriter.getContext()), - {/*outputs...*/pppI1Ty, /*inputs...*/ ptrPtrI1Ty, /*daphneContext...*/ptrI1Ty}); + {/*outputs...*/ pppI1Ty, /*inputs...*/ ptrPtrI1Ty, + /*daphneContext...*/ ptrI1Ty}); fOp2 = rewriter.create(loc, funcName, funcType); fOp2.getBody().takeBody(op.getCuda()); @@ -753,42 +647,50 @@ class VectorizedPipelineOpLowering : public OpConversionPattern()) { + // TODO: we should not create a new daphneContext, instead pass + // the one created in the main function + for (auto callKernelOp : funcBlock.getOps()) { callKernelOp.setOperand(callKernelOp.getNumOperands() - 1, daphneContext); } - // Extract inputs from array containing them and remove the block arguments matching the old inputs of the + // Extract inputs from array containing them and remove the + // block arguments matching the old inputs of the // `VectorizedPipelineOp` rewriter.setInsertionPointToStart(&funcBlock); for (auto i = 0u; i < numDataOperands; ++i) { - auto addr = rewriter.create(loc, ptrPtrI1Ty, inputsArg, ArrayRef({ - rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); + auto addr = rewriter.create( + loc, ptrPtrI1Ty, inputsArg, + ArrayRef({rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); Value val = rewriter.create(loc, addr); auto expTy = typeConverter->convertType(op.getInputs().getType()[i]); if (expTy != val.getType()) { - val = rewriter.create(loc, rewriter.getIntegerType(expTy.getIntOrFloatBitWidth(),false), val); + val = rewriter.create( + loc, rewriter.getIntegerType(expTy.getIntOrFloatBitWidth(), false), val); val = rewriter.create(loc, expTy, val); } funcBlock.getArgument(0).replaceAllUsesWith(val); funcBlock.eraseArgument(0); } - // Update function block to write return value by reference instead + // Update function block to write return value by reference + // instead auto oldReturn = funcBlock.getTerminator(); rewriter.setInsertionPoint(oldReturn); for (auto i = 0u; i < oldReturn->getNumOperands(); ++i) { auto retVal = oldReturn->getOperand(i); - // TODO: check how the GEPOp works exactly, and if this can be written better - auto addr1 = rewriter.create(op->getLoc(), pppI1Ty, returnRef, ArrayRef( - {rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); + // TODO: check how 
the GEPOp works exactly, and if this can + // be written better + auto addr1 = rewriter.create( + op->getLoc(), pppI1Ty, returnRef, + ArrayRef({rewriter.create(loc, rewriter.getI64IntegerAttr(i))})); auto addr2 = rewriter.create(op->getLoc(), addr1); - Value retValConverted = typeConverter->materializeTargetConversion(rewriter, oldReturn->getLoc(), typeConverter->convertType(retVal.getType()), {retVal}); + Value retValConverted = typeConverter->materializeTargetConversion( + rewriter, oldReturn->getLoc(), typeConverter->convertType(retVal.getType()), {retVal}); rewriter.create(loc, retValConverted, addr2); } - // Replace the old ReturnOp with operands by a new ReturnOp without - // operands. + // Replace the old ReturnOp with operands by a new ReturnOp + // without operands. rewriter.replaceOpWithNewOp(oldReturn); } @@ -803,95 +705,84 @@ class VectorizedPipelineOpLowering : public OpConversionPatterngetResultTypes(); const size_t numRes = op->getNumResults(); - if(numRes > 0) { + if (numRes > 0) { // TODO Support individual types for all outputs (see #397). // Check if all results have the same type. Type mt0 = resultTypes[0].dyn_cast().withSameElementTypeAndRepr(); for (size_t i = 1; i < numRes; i++) { - if (mt0 != resultTypes[i] - .dyn_cast() - .withSameElementTypeAndRepr()) { - throw ErrorHandler::compilerError( - op, "LowerToLLVMPass", - "encountered a vectorized pipelines with different " - "result types, but at the moment we require all " - "results to have the same type"); + if (mt0 != resultTypes[i].dyn_cast().withSameElementTypeAndRepr()) { + throw ErrorHandler::compilerError(op, "LowerToLLVMPass", + "encountered a vectorized pipelines with different " + "result types, but at the moment we require all " + "results to have the same type"); } } - // Append the name of the common type of all results to the kernel name. + // Append the name of the common type of all results to the kernel + // name. callee << "__" << CompilerUtils::mlirTypeToCppTypeName(resultTypes[0], false) << "_variadic__size_t"; } mlir::Type operandType; std::vector newOperands; - if(numRes > 0) { + if (numRes > 0) { auto m32type = rewriter.getF32Type(); auto m64type = rewriter.getF64Type(); auto msi64type = rewriter.getIntegerType(64, true); auto res_elem_type = op->getResult(0).getType().dyn_cast().getElementType(); - if(res_elem_type == m64type) + if (res_elem_type == m64type) operandType = daphne::MatrixType::get(getContext(), m64type); - else if(res_elem_type == m32type) + else if (res_elem_type == m32type) operandType = daphne::MatrixType::get(getContext(), m32type); - else if(res_elem_type == msi64type) + else if (res_elem_type == msi64type) operandType = daphne::MatrixType::get(getContext(), msi64type); else { std::string str; llvm::raw_string_ostream output(str); op->getResult(0).getType().print(output); - throw ErrorHandler::compilerError( - op, "LowerToLLVMPass", - "Unsupported result type for vectorizedPipeline op: " + - str); + throw ErrorHandler::compilerError(op, "LowerToLLVMPass", + "Unsupported result type for vectorizedPipeline op: " + str); } - } - else { - throw ErrorHandler::compilerError( - op, "LowerToLLVMPass", - "vectorizedPipelineOp without outputs not supported at the " - "moment!"); + } else { + throw ErrorHandler::compilerError(op, "LowerToLLVMPass", + "vectorizedPipelineOp without outputs not supported at the " + "moment!"); } // Handle variadic operands isScalar and inputs (both share numInputs). auto attrNumInputs = rewriter.getI64IntegerAttr(numDataOperands); // For isScalar. 
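// --- Illustration (not part of the patch): how the `callee` stream above
// assembles the kernel's mangled name. A fixed base name is extended by one
// type-derived suffix per result/operand group, in call order. The exact
// suffix spellings below are illustrative assumptions, not necessarily the
// strings CompilerUtils::mlirTypeToCppTypeName() produces.
#include <iostream>
#include <sstream>

int main() {
    std::stringstream callee;
    callee << "_vectorizedPipeline";                   // kernel base name (assumed)
    callee << "__DenseMatrix_double_variadic__size_t"; // common result type
    callee << "__bool";                                // isScalar pack
    callee << "__DenseMatrix_double_variadic__size_t"; // inputs pack + numInputs
    callee << "__size_t";                              // number of pipeline functions
    callee << "__void_variadic";                       // function-pointer array
    std::cout << callee.str() << "\n";
}
// --- end illustration ---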
callee << "__bool"; - auto vpScalar = rewriter.create(loc, - daphne::VariadicPackType::get(rewriter.getContext(), rewriter.getI1Type()), - attrNumInputs); + auto vpScalar = rewriter.create( + loc, daphne::VariadicPackType::get(rewriter.getContext(), rewriter.getI1Type()), attrNumInputs); // For inputs and numInputs. callee << "__" << CompilerUtils::mlirTypeToCppTypeName(operandType, false, true); callee << "_variadic__size_t"; - auto vpInputs = rewriter.create(loc, - daphne::VariadicPackType::get(rewriter.getContext(), operandType), - attrNumInputs); + auto vpInputs = rewriter.create( + loc, daphne::VariadicPackType::get(rewriter.getContext(), operandType), attrNumInputs); // Populate the variadic packs for isScalar and inputs. - for(size_t k = 0; k < numDataOperands; k++) { + for (size_t k = 0; k < numDataOperands; k++) { auto attrK = rewriter.getI64IntegerAttr(k); rewriter.create( + loc, vpScalar, + rewriter.create( loc, - vpScalar, - rewriter.create( - loc, - // We assume this input to be a scalar if its type - // has not been converted to a pointer type. - !llvm::isa(adaptor.getOperands()[k].getType()) - ), - attrK - ); - rewriter.create( - loc, vpInputs, adaptor.getOperands()[k], attrK - ); + // We assume this input to be a scalar if its type + // has not been converted to a pointer type. + !llvm::isa(adaptor.getOperands()[k].getType())), + attrK); + rewriter.create(loc, vpInputs, adaptor.getOperands()[k], attrK); } newOperands.push_back(vpScalar); newOperands.push_back(vpInputs); - newOperands.push_back(rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(numDataOperands))); + newOperands.push_back( + rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(numDataOperands))); - // Obtain an insertion point at the beginning of the function surrounding this VectorizedPipelineOp - // (see comment on AllocaOp above). + // Obtain an insertion point at the beginning of the function + // surrounding this VectorizedPipelineOp (see comment on AllocaOp + // above). 
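// --- Illustration (not part of the patch): what the CreateVariadicPackOp /
// StoreVariadicPackOp pair above amounts to after lowering: allocate a
// fixed-size pack, fill one slot per operand, then hand the raw array plus
// its length to a variadic kernel. All names here (VariadicPack, runKernel)
// are illustrative, not DAPHNE's actual runtime API.
#include <cstddef>
#include <cstdio>
#include <memory>

template <typename T> struct VariadicPack {
    std::unique_ptr<T[]> slots;
    std::size_t size;
    explicit VariadicPack(std::size_t n)                 // CreateVariadicPackOp
        : slots(new T[n]()), size(n) {}
    void store(std::size_t pos, T v) { slots[pos] = v; } // StoreVariadicPackOp
};

// A "kernel" receiving the pack as pointer + element count.
static void runKernel(const bool *isScalar, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        std::printf("input %zu is a %s\n", i, isScalar[i] ? "scalar" : "matrix");
}

int main() {
    VariadicPack<bool> vpScalar(3); // numDataOperands = 3
    vpScalar.store(0, false);       // pointer-typed operand -> matrix
    vpScalar.store(1, true);        // non-pointer operand   -> scalar
    vpScalar.store(2, false);
    runKernel(vpScalar.slots.get(), vpScalar.size);
}
// --- end illustration ---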
OpBuilder::InsertPoint ipHere = rewriter.saveInsertionPoint(); - Block & fb = op.getOperation()->getParentOfType().getBody().front(); + Block &fb = op.getOperation()->getParentOfType().getBody().front(); rewriter.setInsertionPointToStart(&fb); OpBuilder::InsertPoint ipFuncStart = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipHere); @@ -900,16 +791,17 @@ class VectorizedPipelineOpLowering : public OpConversionPattern splitConsts; - for(auto split : op.getSplits()) { + for (auto split : op.getSplits()) { splitConsts.push_back(rewriter.create(loc, split)); } newOperands.push_back(convertToArray(loc, rewriter, rewriter.getI64Type(), splitConsts, ipFuncStart)); @@ -917,57 +809,53 @@ class VectorizedPipelineOpLowering : public OpConversionPattern combineConsts; - for(auto combine : op.getCombines()) { + for (auto combine : op.getCombines()) { combineConsts.push_back(rewriter.create(loc, combine)); } newOperands.push_back(convertToArray(loc, rewriter, rewriter.getI64Type(), combineConsts, ipFuncStart)); - // TODO: pass function pointer with special placeholder instead of `void` + // TODO: pass function pointer with special placeholder instead of + // `void` callee << "__size_t"; - newOperands.push_back(rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(func_ptrs.size()))); + newOperands.push_back( + rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(func_ptrs.size()))); callee << "__void_variadic"; newOperands.push_back(convertToArray(loc, rewriter, ptrPtrI1Ty, func_ptrs, ipFuncStart)); -// newOperands.push_back(fnPtr); + // newOperands.push_back(fnPtr); // Add ctx -// newOperands.push_back(operands.back()); + // newOperands.push_back(operands.back()); if (op.getCtx() == nullptr) { op->emitOpError() << "`DaphneContext` not known"; return failure(); - } - else + } else newOperands.push_back(op.getCtx()); // Create a CallKernelOp for the kernel function to call and return // success(). - auto kernel = rewriter.create( - loc, - callee.str(), - newOperands, - resultTypes - ); + auto kernel = rewriter.create(loc, callee.str(), newOperands, resultTypes); kernel->setAttr(ATTR_HASVARIADICRESULTS, rewriter.getBoolAttr(true)); rewriter.replaceOp(op, kernel.getResults()); return success(); } -private: - static Value convertToArray(Location loc, ConversionPatternRewriter &rewriter, Type valueTy, ValueRange values, OpBuilder::InsertPoint & ipFuncStart) - { - // Set the insertion point to the beginning of the function surrounding this VectorizedPipelineOp - // (see comment on AllocaOp above). + + private: + static Value convertToArray(Location loc, ConversionPatternRewriter &rewriter, Type valueTy, ValueRange values, + OpBuilder::InsertPoint &ipFuncStart) { + // Set the insertion point to the beginning of the function surrounding + // this VectorizedPipelineOp (see comment on AllocaOp above). OpBuilder::InsertPoint ipHere = rewriter.saveInsertionPoint(); rewriter.restoreInsertionPoint(ipFuncStart); auto valuePtrTy = LLVM::LLVMPointerType::get(valueTy); - auto array = rewriter.create(loc, - valuePtrTy, - Value(rewriter.create(loc, rewriter.getI64IntegerAttr(values.size())))); + auto array = rewriter.create( + loc, valuePtrTy, Value(rewriter.create(loc, rewriter.getI64IntegerAttr(values.size())))); ipFuncStart = rewriter.saveInsertionPoint(); // Go back to the original insertion point. 
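// --- Illustration (not part of the patch): the saveInsertionPoint() /
// restoreInsertionPoint() dance used here exists so the AllocaOp lands once
// in the surrounding function's entry region instead of inside a loop body.
// The toy Builder below mimics that cursor juggling; it is a sketch, not
// MLIR's OpBuilder.
#include <iostream>
#include <iterator>
#include <list>
#include <string>

struct Builder {
    std::list<std::string> func;         // the "IR" of one function
    std::list<std::string>::iterator ip; // ops are inserted before ip
    Builder() : ip(func.end()) {}
    void emit(const std::string &op) { func.insert(ip, op); }
};

int main() {
    Builder b;
    b.emit("entry:");
    b.emit("loop:");
    auto ipFuncStart = std::prev(b.func.end()); // inserts land after "entry:"
    b.emit("  store %v -> %array[i]");          // loop body uses %array
    auto ipHere = b.ip;                         // save the current position
    b.ip = ipFuncStart;
    b.emit("  %array = alloca 4 x i64");        // hoisted to the function start
    b.ip = ipHere;                              // go back to the loop body
    b.emit("  br loop");
    for (const auto &op : b.func)
        std::cout << op << "\n";
}
// --- end illustration ---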
rewriter.restoreInsertionPoint(ipHere); - for(auto i = 0u; i < values.size(); ++i) { + for (auto i = 0u; i < values.size(); ++i) { Value cstI = rewriter.create(loc, rewriter.getI64IntegerAttr(i)); auto addr = rewriter.create(loc, valuePtrTy, array, ArrayRef({cstI})); auto val = values[i]; @@ -980,38 +868,30 @@ class VectorizedPipelineOpLowering : public OpConversionPattern -{ -public: +class GenericCallOpLowering : public OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::GenericCallOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { + LogicalResult matchAndRewrite(daphne::GenericCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getCallee(), op->getResultTypes(), adaptor.getOperands()); return success(); } }; -namespace -{ - struct DaphneLowerToLLVMPass - : public PassWrapper> - { - explicit DaphneLowerToLLVMPass(const DaphneUserConfig& cfg) : cfg(cfg) { } - const DaphneUserConfig& cfg; +namespace { +struct DaphneLowerToLLVMPass : public PassWrapper> { + explicit DaphneLowerToLLVMPass(const DaphneUserConfig &cfg) : cfg(cfg) {} + const DaphneUserConfig &cfg; - void getDependentDialects(DialectRegistry & registry) const override - { - registry.insert(); - } - void runOnOperation() final; - }; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() final; +}; } // end anonymous namespace -void DaphneLowerToLLVMPass::runOnOperation() -{ +void DaphneLowerToLLVMPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); @@ -1019,57 +899,27 @@ void DaphneLowerToLLVMPass::runOnOperation() LowerToLLVMOptions llvmOptions(&getContext()); // llvmOptions.useBarePtrCallConv = true; LLVMTypeConverter typeConverter(&getContext(), llvmOptions); - typeConverter.addConversion([&](daphne::MatrixType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::FrameType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::ListType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::StringType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 8)); - }); - typeConverter.addConversion([&](daphne::VariadicPackType t) - { - return LLVM::LLVMPointerType::get( - typeConverter.convertType(t.getContainedType()) - ); - }); - typeConverter.addConversion([&](daphne::DaphneContextType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::HandleType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::FileType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::DescriptorType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); - }); - typeConverter.addConversion([&](daphne::TargetType t) - { - return LLVM::LLVMPointerType::get( - IntegerType::get(t.getContext(), 1)); + typeConverter.addConversion( + [&](daphne::MatrixType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + 
[&](daphne::FrameType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::ListType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::StringType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 8)); }); + typeConverter.addConversion([&](daphne::VariadicPackType t) { + return LLVM::LLVMPointerType::get(typeConverter.convertType(t.getContainedType())); }); + typeConverter.addConversion( + [&](daphne::DaphneContextType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::HandleType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::FileType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::DescriptorType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); + typeConverter.addConversion( + [&](daphne::TargetType t) { return LLVM::LLVMPointerType::get(IntegerType::get(t.getContext(), 1)); }); LLVMConversionTarget target(getContext()); @@ -1087,17 +937,11 @@ void DaphneLowerToLLVMPass::runOnOperation() // for trivial casts no lowering to kernels -> higher benefit patterns.insert(&getContext(), 2); - patterns.insert( - typeConverter, &getContext()); + patterns.insert(typeConverter, &getContext()); patterns.insert(typeConverter, &getContext(), cfg); - patterns.insert< - ConstantOpLowering, - ReturnOpLowering, - StoreVariadicPackOpLowering, - GenericCallOpLowering, - MapOpLowering - >(&getContext()); + patterns.insert(&getContext()); // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. @@ -1105,7 +949,6 @@ void DaphneLowerToLLVMPass::runOnOperation() signalPassFailure(); } -std::unique_ptr daphne::createLowerToLLVMPass(const DaphneUserConfig& cfg) -{ +std::unique_ptr daphne::createLowerToLLVMPass(const DaphneUserConfig &cfg) { return std::make_unique(cfg); } diff --git a/src/compiler/lowering/ManageObjRefsPass.cpp b/src/compiler/lowering/ManageObjRefsPass.cpp index f0e89d6d7..fd9f8746e 100644 --- a/src/compiler/lowering/ManageObjRefsPass.cpp +++ b/src/compiler/lowering/ManageObjRefsPass.cpp @@ -15,10 +15,10 @@ */ #include -#include #include #include #include +#include #include #include @@ -46,8 +46,7 @@ using namespace mlir; * object that is still needed in a surrounding scope, i.e., to prevent * double frees. */ -struct ManageObjRefsPass : public PassWrapper> -{ +struct ManageObjRefsPass : public PassWrapper> { explicit ManageObjRefsPass() {} void runOnOperation() final; @@ -56,11 +55,10 @@ struct ManageObjRefsPass : public PassWrapper(v.getLoc(), - v.getDefiningOp()->getOperand(0)); + builder.create(v.getLoc(), v.getDefiningOp()->getOperand(0)); } /** @@ -76,49 +74,48 @@ void processValue(OpBuilder builder, Value v) { // We only need to manage the reference counters of DAPHNE data objects // like matrices and frames (not of scalars). - Operation* defOp = v.getDefiningOp(); + Operation *defOp = v.getDefiningOp(); if (defOp && llvm::isa(defOp)) processMemRefInterop(builder, v); // Increase the reference counter of string literals, such that they don't // get gargabe collected. - if(defOp && llvm::isa(defOp) && llvm::isa(v.getType())) { - // The given value is a string literal. 
We want to increase its reference - // counter right after its definition, such that it is never removed. - // But if the defining op is the block of a FuncOp, make sure not to insert the - // IncRefOp before the CreateDaphneContextOp, otherwise we will run - // into problems during/after lowering to kernel calls. - Block * pb = v.getParentBlock(); - if(auto fo = dyn_cast(pb->getParentOp())) { + if (defOp && llvm::isa(defOp) && llvm::isa(v.getType())) { + // The given value is a string literal. We want to increase its + // reference counter right after its definition, such that it is never + // removed. But if the defining op is the block of a FuncOp, make sure + // not to insert the IncRefOp before the CreateDaphneContextOp, + // otherwise we will run into problems during/after lowering to kernel + // calls. + Block *pb = v.getParentBlock(); + if (auto fo = dyn_cast(pb->getParentOp())) { Value dctx = CompilerUtils::getDaphneContext(fo); builder.setInsertionPointAfterValue(dctx); - } - else + } else builder.setInsertionPointAfter(defOp); builder.create(v.getLoc(), v); } - // Increase the reference counter of the result of the arith.select op, if it is - // a string scalar. - // This is necessary because for arith.select, we have no clue which of - // its two arguments (2nd or 3rd one) it will return. Unless we do something - // about it, the reference counter of the result will be too low by 1. - // Thus, we increase the result's reference counter here. - if(defOp && llvm::isa(defOp) && llvm::isa(v.getType())) { + // Increase the reference counter of the result of the arith.select op, if + // it is a string scalar. This is necessary because for arith.select, we + // have no clue which of its two arguments (2nd or 3rd one) it will return. + // Unless we do something about it, the reference counter of the result will + // be too low by 1. Thus, we increase the result's reference counter here. + if (defOp && llvm::isa(defOp) && llvm::isa(v.getType())) { builder.setInsertionPointAfter(defOp); builder.create(v.getLoc(), v); } - if (!llvm::isa(v.getType())) + if (!llvm::isa(v.getType())) return; - Operation* decRefAfterOp = nullptr; + Operation *decRefAfterOp = nullptr; if (v.use_empty()) { // If the given SSA value has no uses, we want to decrease its // reference counter directly after its definition (nullptr for block // args). Note that ideally, there should be no unused SSA values. - if (defOp) decRefAfterOp = defOp; + if (defOp) + decRefAfterOp = defOp; // else: decRefAfterOp stays nullptr } else { // If the given SSA value has uses, we need to find the last of them. @@ -135,11 +132,11 @@ void processValue(OpBuilder builder, Value v) { // At this point, decRefAfterOp is nullptr, or the last user of v, or the // defining op of v. - if(decRefAfterOp) { + if (decRefAfterOp) { // The given value is used and/or an OpResult. // Don't insert a DecRefOp if the last user is a terminator. - if(decRefAfterOp->hasTrait()) + if (decRefAfterOp->hasTrait()) // The value is handed out of its block (e.g., return, yield, ...). // So a new reference to it is created. Thus, the reference counter // must remain unchanged. Moreover, it is impossible to insert any @@ -150,23 +147,21 @@ void processValue(OpBuilder builder, Value v) { // Don't insert a DecRefOp if there is already one. 
Currently, this can // happen only on the distributed worker, since the IR it gets already // contains - if(llvm::isa(decRefAfterOp)) + if (llvm::isa(decRefAfterOp)) return; builder.setInsertionPointAfter(decRefAfterOp); - } - else { + } else { // The given value is an unused block arg. Decrease its reference // counter at the beginning of the block. // But if this is the block of a FuncOp, make sure not to insert the // DecRefOp before the CreateDaphneContextOp, otherwise we will run // into problems during/after lowering to kernel calls. - Block * pb = v.getParentBlock(); - if(auto fo = dyn_cast(pb->getParentOp())) { + Block *pb = v.getParentBlock(); + if (auto fo = dyn_cast(pb->getParentOp())) { Value dctx = CompilerUtils::getDaphneContext(fo); builder.setInsertionPointAfterValue(dctx); - } - else + } else builder.setInsertionPointToStart(pb); } @@ -183,15 +178,14 @@ void processValue(OpBuilder builder, Value v) { * @param v * @param b */ -void incRefIfObj(Value v, OpBuilder & b) { +void incRefIfObj(Value v, OpBuilder &b) { Type t = v.getType(); - if(llvm::isa(t)) + if (llvm::isa(t)) b.create(v.getLoc(), v); - else if(llvm::isa(t)) - throw ErrorHandler::compilerError( - v.getDefiningOp(), "ManageObjRefsPass", - "ManageObjRefsPass encountered a value of unknown type, so it " - "cannot know if it is a data object."); + else if (llvm::isa(t)) + throw ErrorHandler::compilerError(v.getDefiningOp(), "ManageObjRefsPass", + "ManageObjRefsPass encountered a value of unknown type, so it " + "cannot know if it is a data object."); } /** @@ -202,9 +196,9 @@ void incRefIfObj(Value v, OpBuilder & b) { * @param op * @param b */ -void incRefArgs(Operation& op, OpBuilder & b) { +void incRefArgs(Operation &op, OpBuilder &b) { b.setInsertionPoint(&op); - for(Value arg : op.getOperands()) + for (Value arg : op.getOperands()) incRefIfObj(arg, b); } @@ -215,46 +209,46 @@ void incRefArgs(Operation& op, OpBuilder & b) { * @param builder * @param b */ -void processBlock(OpBuilder builder, Block * b) { +void processBlock(OpBuilder builder, Block *b) { // Make sure that the reference counters of block arguments are decreased. - for(BlockArgument& arg : b->getArguments()) + for (BlockArgument &arg : b->getArguments()) processValue(builder, arg); // Make sure the reference counters of op results are decreased, and // Increase the reference counters of operands where necessary. - for(Operation& op : b->getOperations()) { + for (Operation &op : b->getOperations()) { // 1) Increase the reference counters of operands, if necessary. // TODO We could use traits to identify those cases. // Casts that will not call a kernel. - if(auto co = dyn_cast(op)) { - if(co.isTrivialCast() || co.isRemovePropertyCast()) + if (auto co = dyn_cast(op)) { + if (co.isTrivialCast() || co.isRemovePropertyCast()) incRefArgs(op, builder); } // Loops and function calls. - else if(llvm::isa(op)) + else if (llvm::isa(op)) incRefArgs(op, builder); // YieldOp of IfOp. - else if(llvm::isa(op) && llvm::isa(op.getParentOp())) { + else if (llvm::isa(op) && llvm::isa(op.getParentOp())) { // Increase the reference counters of data objects that already // existed before the IfOp, because yielding them creates a new // SSA value referring to them. builder.setInsertionPoint(&op); - for(Value arg : op.getOperands()) - if(arg.getParentBlock() != op.getBlock()) + for (Value arg : op.getOperands()) + if (arg.getParentBlock() != op.getBlock()) incRefIfObj(arg, builder); } // Terminators. 
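// --- Illustration (not part of the patch): the placement rule implemented
// by processValue() above in a nutshell: a DecRefOp goes right after the
// *last* user of a value, and is skipped entirely when that last user is a
// terminator, because the value then escapes the block. Toy model, not MLIR.
#include <cstdio>
#include <string>
#include <vector>

struct Op {
    std::string name;
    bool usesV;
    bool isTerminator;
};

int main() {
    std::vector<Op> block = {
        {"kernel.create", false, false}, // defines v
        {"kernel.use", true, false},
        {"kernel.other", false, false},
        {"kernel.use", true, false},     // last use of v
        {"return", false, true},
    };
    int decRefAfter = 0; // default: right after the definition
    for (int i = 0; i < (int)block.size(); ++i)
        if (block[i].usesV)
            decRefAfter = i;
    if (block[decRefAfter].isTerminator)
        std::puts("value escapes the block: leave the refcounter alone");
    else
        std::printf("insert DecRef(v) after op #%d (%s)\n", decRefAfter, block[decRefAfter].name.c_str());
}
// --- end illustration ---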
- else if(op.hasTrait()) { + else if (op.hasTrait()) { // By default, we do not decrease the reference counter of a // terminator's argument. If the same value is used multiple times // as an argument, we need to increase its reference counter. builder.setInsertionPoint(&op); - for(size_t i = 1; i < op.getNumOperands(); i++) { + for (size_t i = 1; i < op.getNumOperands(); i++) { Value arg = op.getOperand(i); - for(size_t k = 0; k < i; k++) - if(arg == op.getOperand(k)) + for (size_t k = 0; k < i; k++) + if (arg == op.getOperand(k)) incRefIfObj(arg, builder); } } @@ -263,27 +257,21 @@ void processBlock(OpBuilder builder, Block * b) { // of vectorized pipelines, because internally, a pipeline processes // views into its inputs. These are individual data objects. - // 2) Make sure the reference counters of op results are decreased. - for(Value v : op.getResults()) + for (Value v : op.getResults()) processValue(builder, v); - // 3) Recurse into the op, if it has regions. - for(Region& r : op.getRegions()) - for(Block& b2 : r.getBlocks()) + for (Region &r : op.getRegions()) + for (Block &b2 : r.getBlocks()) processBlock(builder, &b2); } } -void ManageObjRefsPass::runOnOperation() -{ +void ManageObjRefsPass::runOnOperation() { func::FuncOp f = getOperation(); OpBuilder builder(f.getContext()); processBlock(builder, &(f.getBody().front())); } -std::unique_ptr daphne::createManageObjRefsPass() -{ - return std::make_unique(); -} +std::unique_ptr daphne::createManageObjRefsPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/MapOpLowering.cpp b/src/compiler/lowering/MapOpLowering.cpp index 27fff5dcc..fa6ac9016 100644 --- a/src/compiler/lowering/MapOpLowering.cpp +++ b/src/compiler/lowering/MapOpLowering.cpp @@ -32,33 +32,27 @@ using namespace mlir; -class InlineMapOpLowering - : public mlir::OpConversionPattern { - public: +class InlineMapOpLowering : public mlir::OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; - mlir::LogicalResult matchAndRewrite( - mlir::daphne::MapOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { + mlir::LogicalResult matchAndRewrite(mlir::daphne::MapOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - mlir::daphne::MatrixType lhsMatrixType = - op->getOperandTypes().front().dyn_cast(); + mlir::daphne::MatrixType lhsMatrixType = op->getOperandTypes().front().dyn_cast(); auto matrixElementType = lhsMatrixType.getElementType(); - auto lhsMemRefType = mlir::MemRefType::get( - {lhsMatrixType.getNumRows(), lhsMatrixType.getNumCols()}, matrixElementType); + auto lhsMemRefType = + mlir::MemRefType::get({lhsMatrixType.getNumRows(), lhsMatrixType.getNumCols()}, matrixElementType); mlir::Value lhs = - rewriter.create( - loc, lhsMemRefType, adaptor.getArg()); + rewriter.create(loc, lhsMemRefType, adaptor.getArg()); mlir::ModuleOp module = op->getParentOfType(); - func::FuncOp udfFuncOp = - module.lookupSymbol(op.getFunc()); + func::FuncOp udfFuncOp = module.lookupSymbol(op.getFunc()); SmallVector loopIvs; - auto outerLoop = - rewriter.create(loc, 0, lhsMatrixType.getNumRows(), 1); + auto outerLoop = rewriter.create(loc, 0, lhsMatrixType.getNumRows(), 1); for (Operation &nested : *outerLoop.getBody()) { rewriter.eraseOp(&nested); } @@ -66,8 +60,7 @@ class InlineMapOpLowering // outer loop body rewriter.setInsertionPointToStart(outerLoop.getBody()); - auto innerLoop = - rewriter.create(loc, 0, lhsMatrixType.getNumCols(), 1); + auto innerLoop = 
rewriter.create(loc, 0, lhsMatrixType.getNumCols(), 1); for (Operation &nested : *innerLoop.getBody()) { rewriter.eraseOp(&nested); } @@ -77,15 +70,12 @@ class InlineMapOpLowering // inner loop body mlir::Value lhsValue = rewriter.create(loc, lhs, loopIvs); - mlir::Value res = - rewriter.create(loc, udfFuncOp, ValueRange{lhsValue}) - ->getResult(0); + mlir::Value res = rewriter.create(loc, udfFuncOp, ValueRange{lhsValue})->getResult(0); rewriter.create(loc, res, lhs, loopIvs); rewriter.create(loc); rewriter.setInsertionPointAfter(outerLoop); - mlir::Value output = convertMemRefToDenseMatrix(op->getLoc(), rewriter, - lhs, op.getType()); + mlir::Value output = convertMemRefToDenseMatrix(op->getLoc(), rewriter, lhs, op.getType()); rewriter.replaceOp(op, output); return mlir::success(); } @@ -100,14 +90,11 @@ namespace { * This rewrite enables subsequent inlining pass to completely replace * the daphne::MapOp by inlining the produced CallOps from this pass. */ -struct MapOpLoweringPass - : public mlir::PassWrapper> { +struct MapOpLoweringPass : public mlir::PassWrapper> { explicit MapOpLoweringPass() {} void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); } void runOnOperation() final; @@ -120,7 +107,7 @@ struct MapOpLoweringPass "UDF."; } }; -} // end anonymous namespace +} // end anonymous namespace void MapOpLoweringPass::runOnOperation() { mlir::ConversionTarget target(getContext()); @@ -128,8 +115,7 @@ void MapOpLoweringPass::runOnOperation() { mlir::LowerToLLVMOptions llvmOptions(&getContext()); mlir::LLVMTypeConverter typeConverter(&getContext(), llvmOptions); - target.addLegalDialect(); target.addIllegalOp(); @@ -141,6 +127,4 @@ void MapOpLoweringPass::runOnOperation() { } } -std::unique_ptr mlir::daphne::createMapOpLoweringPass() { - return std::make_unique(); -} +std::unique_ptr mlir::daphne::createMapOpLoweringPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/MarkCUDAOpsPass.cpp b/src/compiler/lowering/MarkCUDAOpsPass.cpp index 96ea66f65..de89f2f16 100644 --- a/src/compiler/lowering/MarkCUDAOpsPass.cpp +++ b/src/compiler/lowering/MarkCUDAOpsPass.cpp @@ -24,76 +24,76 @@ using namespace mlir; struct MarkCUDAOpsPass : public PassWrapper> { - + /** * @brief User configuration influencing the rewrite pass */ - const DaphneUserConfig& cfg; + const DaphneUserConfig &cfg; size_t available_gpu_mem{}; size_t total_gpu_mem{}; size_t mem_budget; std::shared_ptr logger; - explicit MarkCUDAOpsPass(const DaphneUserConfig& cfg) : cfg(cfg) { + explicit MarkCUDAOpsPass(const DaphneUserConfig &cfg) : cfg(cfg) { // ToDo: use context and per device mem info cudaMemGetInfo(&available_gpu_mem, &total_gpu_mem); mem_budget = std::floor(0.9 * static_cast(total_gpu_mem)); logger = spdlog::get("compiler::cuda"); } - + void runOnOperation() final; - - void addCUDAOpsToVectorizedPipeline(OpBuilder& builder, daphne::VectorizedPipelineOp& pipelineOp) const { - - auto& pipeline = pipelineOp.getBody().front().getOperations(); + + void addCUDAOpsToVectorizedPipeline(OpBuilder &builder, daphne::VectorizedPipelineOp &pipelineOp) const { + + auto &pipeline = pipelineOp.getBody().front().getOperations(); bool build_cuda_pipeline; - - // add CUDA ops if at least one (cuda_fuse_any) or all (!cuda_fuse_any) ops would be supported - if(cfg.cuda_fuse_any) { - bool pipeline_has_supported_cuda_ops = llvm::any_of(pipeline, [&](Operation& o) { - return llvm::isa(o) || checkUseCUDA(&o); - }); + + // add CUDA ops if at least one (cuda_fuse_any) or all (!cuda_fuse_any) 
+ // ops would be supported + if (cfg.cuda_fuse_any) { + bool pipeline_has_supported_cuda_ops = llvm::any_of( + pipeline, [&](Operation &o) { return llvm::isa(o) || checkUseCUDA(&o); }); build_cuda_pipeline = pipeline_has_supported_cuda_ops; - } - else { - bool pipeline_has_unsupported_cuda_ops = llvm::any_of(pipeline, [&](Operation& o) { + } else { + bool pipeline_has_unsupported_cuda_ops = llvm::any_of(pipeline, [&](Operation &o) { if (!llvm::isa(o)) { bool out = checkUseCUDA(&o); logger->trace("checking pipeline op for cuda: {}: {}", o.getName().getStringRef().str(), out); return !out; - } - else return false; + } else + return false; }); build_cuda_pipeline = !pipeline_has_unsupported_cuda_ops; } - - // clone body region into cuda region if there's a cuda supported op in body - if(build_cuda_pipeline) { + + // clone body region into cuda region if there's a cuda supported op in + // body + if (build_cuda_pipeline) { PatternRewriter::InsertionGuard insertGuard(builder); IRMapping mapper; pipelineOp.getBody().cloneInto(&pipelineOp.getCuda(), mapper); - for (auto &op: pipelineOp.getCuda().front().getOperations()) { + for (auto &op : pipelineOp.getCuda().front().getOperations()) { bool isMat = CompilerUtils::isMatrixComputation(&op); if (op.hasTrait() && isMat) op.setAttr("cuda_device", builder.getI32IntegerAttr(0)); } } } - - bool fitsInMemory(mlir::Operation* op) const { + + bool fitsInMemory(mlir::Operation *op) const { auto opSize = 0ul; - for(auto operand : op->getOperands()) { + for (auto operand : op->getOperands()) { auto type = operand.getType(); - if(auto t = type.dyn_cast()) { + if (auto t = type.dyn_cast()) { auto rows = t.getNumRows(); auto cols = t.getNumCols(); - if(rows < 0 || cols < 0) { + if (rows < 0 || cols < 0) { logger->warn("Ignoring unknown dimension in max mem check of {}" - "dims are: {}x{}\nsetting unknowns to 1 for this test", op->getName().getStringRef().str(), - rows, cols); - if(rows < 0) + "dims are: {}x{}\nsetting unknowns to 1 for this test", + op->getName().getStringRef().str(), rows, cols); + if (rows < 0) rows = 1; - if(cols < 0) + if (cols < 0) cols = 1; } opSize += rows * cols * t.getElementType().getIntOrFloatBitWidth() / 8; @@ -101,33 +101,34 @@ struct MarkCUDAOpsPass : public PassWrappertrace("op in size: {} kb", opSize / 1024); - for(auto result : op->getResults()) { + for (auto result : op->getResults()) { auto type = result.getType(); - if(auto t = type.dyn_cast()) { + if (auto t = type.dyn_cast()) { opSize += t.getNumRows() * t.getNumCols() * t.getElementType().getIntOrFloatBitWidth() / 8; } } - logger->debug("op out size: {} kb\ntotal op size: {} mb", (opSize-inSize) / 1024, - opSize / 1048576); + logger->debug("op out size: {} kb\ntotal op size: {} mb", (opSize - inSize) / 1024, opSize / 1048576); - if(opSize < mem_budget) + if (opSize < mem_budget) return true; else return false; } - + // ToDo: requirements should be set per operator in tablegen - bool hasReqMinDims(mlir::Operation* op) const { - auto checkDims = [this,op](const mlir::Type& type) -> bool { - if(auto t = type.dyn_cast()) { + bool hasReqMinDims(mlir::Operation *op) const { + auto checkDims = [this, op](const mlir::Type &type) -> bool { + if (auto t = type.dyn_cast()) { auto rows = t.getNumRows(); auto cols = t.getNumCols(); - if(rows < 0 || cols < 0) { - logger->warn("Ignoring unknown dimension in min input size check of {} dims are: {}x{}\nsetting " - "unknowns to 256 for this test", op->getName().getStringRef().str(), rows, cols); - if(rows < 0) + if (rows < 0 || cols < 0) { 
+ logger->warn("Ignoring unknown dimension in min input size " + "check of {} dims are: {}x{}\nsetting " + "unknowns to 256 for this test", + op->getName().getStringRef().str(), rows, cols); + if (rows < 0) rows = 256; - if(cols < 0) + if (cols < 0) cols = 256; } return (rows > 255 || cols > 255); @@ -136,27 +137,27 @@ struct MarkCUDAOpsPass : public PassWrappergetOperandTypes()) { - if((ret = checkDims(type))) + for (auto type : op->getOperandTypes()) { + if ((ret = checkDims(type))) break; } - if(!ret) { - for (auto type: op->getResultTypes()) { - if((ret = checkDims(type))) + if (!ret) { + for (auto type : op->getResultTypes()) { + if ((ret = checkDims(type))) break; } } return ret; } - - bool checkUseCUDA(Operation* op) const { + + bool checkUseCUDA(Operation *op) const { logger->trace("checkUseCUDA: {}", op->getName().getStringRef().str()); bool use_cuda = op->hasTrait(); logger->trace("{} CUDA supported={}", op->getName().getStringRef().str(), use_cuda); use_cuda = use_cuda && CompilerUtils::isMatrixComputation(op); logger->trace("{} isMatrixComputation={}", op->getName().getStringRef().str(), use_cuda); - if(!cfg.force_cuda) { + if (!cfg.force_cuda) { use_cuda = use_cuda && hasReqMinDims(op); logger->trace("{} hasMinInputDims={}", op->getName().getStringRef().str(), use_cuda); use_cuda = use_cuda && fitsInMemory(op); @@ -167,21 +168,19 @@ struct MarkCUDAOpsPass : public PassWrapperwalk([&](Operation* op) { + getOperation()->walk([&](Operation *op) { logger->debug("MarkCUDAOpsPass: {} parent: {}", op->getName().getStringRef().str(), - op->getParentOp()->getName().getStringRef().str()); + op->getParentOp()->getName().getStringRef().str()); OpBuilder builder(op); // handle vectorizedPipelineOps - if (auto constOp = llvm::dyn_cast(op)) - { + if (auto constOp = llvm::dyn_cast(op)) { WalkResult::advance(); return; - } - else if (auto pipelineOp = llvm::dyn_cast(op)) + } else if (auto pipelineOp = llvm::dyn_cast(op)) addCUDAOpsToVectorizedPipeline(builder, pipelineOp); else { - if((!llvm::isa(op->getParentOp()) && checkUseCUDA(op)) || - llvm::isa(op)) { + if ((!llvm::isa(op->getParentOp()) && checkUseCUDA(op)) || + llvm::isa(op)) { op->setAttr("cuda_device", builder.getI32IntegerAttr(0)); } } @@ -189,7 +188,7 @@ void MarkCUDAOpsPass::runOnOperation() { }); } -std::unique_ptr daphne::createMarkCUDAOpsPass(const DaphneUserConfig& cfg) { +std::unique_ptr daphne::createMarkCUDAOpsPass(const DaphneUserConfig &cfg) { return std::make_unique(cfg); } diff --git a/src/compiler/lowering/MarkFPGAOPENCLOpsPass.cpp b/src/compiler/lowering/MarkFPGAOPENCLOpsPass.cpp index 7d281322a..ef5666d06 100644 --- a/src/compiler/lowering/MarkFPGAOPENCLOpsPass.cpp +++ b/src/compiler/lowering/MarkFPGAOPENCLOpsPass.cpp @@ -27,31 +27,31 @@ struct MarkFPGAOPENCLOpsPass : public PassWrappergetName().getStringRef().str() << std::endl; + bool checkUseFPGAOPENCL(Operation *op) const { + // std::cout << "checkUseFPGAOPENCL: " << + // op->getName().getStringRef().str() << std::endl; return op->hasTrait(); } }; void MarkFPGAOPENCLOpsPass::runOnOperation() { func::FuncOp f = getOperation(); - f->walk([&](Operation* op) { + f->walk([&](Operation *op) { OpBuilder builder(op); - if(checkUseFPGAOPENCL(op)) { + if (checkUseFPGAOPENCL(op)) { op->setAttr("fpgaopencl_device", builder.getI32IntegerAttr(0)); } WalkResult::advance(); }); } -std::unique_ptr daphne::createMarkFPGAOPENCLOpsPass(const DaphneUserConfig& cfg) { +std::unique_ptr daphne::createMarkFPGAOPENCLOpsPass(const DaphneUserConfig &cfg) { return std::make_unique(cfg); } 
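// --- Illustration (not part of the patch): the fitsInMemory() arithmetic
// above, worked through once. Every operand and result contributes
// rows * cols * bitwidth / 8 bytes, and the op is only marked for CUDA if
// the sum stays below 90% of total device memory. The 8 GiB device size is
// an assumption for the example.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Shape {
    std::int64_t rows, cols;
    int bitwidth;
};

int main() {
    std::vector<Shape> operandsAndResults = {
        {10000, 10000, 64}, // lhs, f64
        {10000, 10000, 64}, // rhs, f64
        {10000, 10000, 64}, // result, f64
    };
    std::uint64_t opSize = 0;
    for (auto s : operandsAndResults) {
        std::int64_t rows = s.rows < 0 ? 1 : s.rows; // unknown dims default to 1
        std::int64_t cols = s.cols < 0 ? 1 : s.cols;
        opSize += std::uint64_t(rows) * cols * s.bitwidth / 8;
    }
    std::uint64_t total_gpu_mem = 8ull << 30;          // assume an 8 GiB device
    std::uint64_t mem_budget = total_gpu_mem * 9 / 10; // 0.9 * total_gpu_mem
    std::printf("op needs %llu MiB, budget %llu MiB -> %s\n", (unsigned long long)(opSize >> 20),
                (unsigned long long)(mem_budget >> 20), opSize < mem_budget ? "offload" : "keep on host");
}
// --- end illustration ---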
-#endif +#endif diff --git a/src/compiler/lowering/MatMulOpLowering.cpp b/src/compiler/lowering/MatMulOpLowering.cpp index 80b6709b7..c0aef46db 100644 --- a/src/compiler/lowering/MatMulOpLowering.cpp +++ b/src/compiler/lowering/MatMulOpLowering.cpp @@ -23,7 +23,6 @@ #include #include "compiler/utils/LoweringUtils.h" -#include #include "hwloc.h" #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" @@ -66,6 +65,7 @@ #include "spdlog/spdlog.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include namespace mlir { #define GEN_PASS_DECL_MATMULOPLOWERINGPASS @@ -79,537 +79,531 @@ static constexpr int ROW = 0; static constexpr int COL = 1; struct LowerMatMulOpOptions { - LowerMatMulOpOptions() {} - int vec_size_bits{0}; - int num_vec_registers{0}; - bool vectorize{false}; - bool tile{false}; - bool invert_loops{false}; - bool useFixedTileSizes{false}; - llvm::SmallVector cache_sizes; - llvm::SmallVector tile_sizes; - int unroll_factor{0}; - int unroll_jam_factor{0}; - - LowerMatMulOpOptions &setTileSizes(std::vector sizes) { - tile_sizes.clear(); - for (auto s : sizes) { - tile_sizes.push_back(s); + LowerMatMulOpOptions() {} + int vec_size_bits{0}; + int num_vec_registers{0}; + bool vectorize{false}; + bool tile{false}; + bool invert_loops{false}; + bool useFixedTileSizes{false}; + llvm::SmallVector cache_sizes; + llvm::SmallVector tile_sizes; + int unroll_factor{0}; + int unroll_jam_factor{0}; + + LowerMatMulOpOptions &setTileSizes(std::vector sizes) { + tile_sizes.clear(); + for (auto s : sizes) { + tile_sizes.push_back(s); + } + return *this; + } + LowerMatMulOpOptions &setUnrollFactor(int f) { + unroll_factor = f; + return *this; + } + LowerMatMulOpOptions &setUnrollJamFactor(int f) { + unroll_jam_factor = f; + return *this; + } + LowerMatMulOpOptions &setCacheSizes(llvm::SmallVector caches) { + cache_sizes.clear(); + for (auto c : caches) { + cache_sizes.push_back(c); + } + return *this; + } + LowerMatMulOpOptions &enableVectorization(bool b = true) { + vectorize = b; + return *this; + } + LowerMatMulOpOptions &setVectorSizeBits(int s) { + vec_size_bits = s; + return *this; + } + LowerMatMulOpOptions &setNumberOfVectorRegisters(int s) { + num_vec_registers = s; + return *this; } - return *this; - } - LowerMatMulOpOptions &setUnrollFactor(int f) { - unroll_factor = f; - return *this; - } - LowerMatMulOpOptions &setUnrollJamFactor(int f) { - unroll_jam_factor = f; - return *this; - } - LowerMatMulOpOptions &setCacheSizes(llvm::SmallVector caches) { - cache_sizes.clear(); - for (auto c : caches) { - cache_sizes.push_back(c); + LowerMatMulOpOptions &enableTiling(bool b = true) { + tile = b; + return *this; } - return *this; - } - LowerMatMulOpOptions &enableVectorization(bool b = true) { - vectorize = b; - return *this; - } - LowerMatMulOpOptions &setVectorSizeBits(int s) { - vec_size_bits = s; - return *this; - } - LowerMatMulOpOptions &setNumberOfVectorRegisters(int s) { - num_vec_registers = s; - return *this; - } - LowerMatMulOpOptions &enableTiling(bool b = true) { - tile = b; - return *this; - } - LowerMatMulOpOptions &enableLoopInversion(bool b = true) { - invert_loops = b; - return *this; - } - int getVecSize(int bitwidth) const { - if (vec_size_bits > 0) { - return std::max(1, vec_size_bits / bitwidth); - } else { - return 1; + LowerMatMulOpOptions &enableLoopInversion(bool b = true) { + invert_loops = b; + return *this; } - } - int getRegisterSize() const { - if (num_vec_registers != 0 && vec_size_bits != 0) { - return std::max(1, num_vec_registers * 
vec_size_bits); + int getVecSize(int bitwidth) const { + if (vec_size_bits > 0) { + return std::max(1, vec_size_bits / bitwidth); + } else { + return 1; + } + } + int getRegisterSize() const { + if (num_vec_registers != 0 && vec_size_bits != 0) { + return std::max(1, num_vec_registers * vec_size_bits); + } + return 1; } - return 1; - } }; bool is_valid_options(LowerMatMulOpOptions const options) { - for (auto s : options.tile_sizes) - if (s <= 1) { - spdlog::warn("Tile sizes must be an integer larger than 1."); - return false; + for (auto s : options.tile_sizes) + if (s <= 1) { + spdlog::warn("Tile sizes must be an integer larger than 1."); + return false; + } + if (options.unroll_factor < 0) { + spdlog::warn("Unroll factor must be an integer >= 0."); + return false; } - if (options.unroll_factor < 0) { - spdlog::warn("Unroll factor must be an integer >= 0."); - return false; - } - if (options.unroll_jam_factor < 0) { - spdlog::warn("Unroll jam factor must be an integer >= 0."); - return false; - } - if (options.vec_size_bits < 0) { - spdlog::warn("Vector size bits must be an integer >= 0."); - return false; - } - return true; -} - -class MatMulLowering : public OpConversionPattern { - const LowerMatMulOpOptions options; - -public: - using OpConversionPattern::OpConversionPattern; - explicit MatMulLowering(mlir::TypeConverter &typeConverter, - MLIRContext *context, - LowerMatMulOpOptions const &options) - : OpConversionPattern(typeConverter, context, - PatternBenefit(1)), - options(options) { - this->setDebugName("MatMulLowering"); - } - - bool is_vectorizable(ArrayRef const rhsShape, - Type const matrixElementType) const { - if (rhsShape[COL] % - options.getVecSize(matrixElementType.getIntOrFloatBitWidth()) != - 0) { - return false; + if (options.unroll_jam_factor < 0) { + spdlog::warn("Unroll jam factor must be an integer >= 0."); + return false; } - if (!matrixElementType.isa()) { - return false; + if (options.vec_size_bits < 0) { + spdlog::warn("Vector size bits must be an integer >= 0."); + return false; } return true; - } - - bool is_tileable(ArrayRef const rhsShape) const { return true; } - - llvm::SmallVector - affineMatMul(mlir::Value &lhs, mlir::Value &rhs, mlir::Value &output, - ConversionPatternRewriter &rewriter, mlir::Location loc, - ArrayRef lhsShape, ArrayRef rhsShape, - mlir::MLIRContext *ctx, SmallVector &loops, - Type elementType) const { - // row loop - auto rowLoop = rewriter.create(loc, 0, lhsShape[ROW], 1); - // row loop body - rewriter.setInsertionPointToStart(rowLoop.getBody()); - // col loop - auto colLoop = rewriter.create(loc, 0, rhsShape[COL], 1); - // col loop body - rewriter.setInsertionPointToStart(colLoop.getBody()); - // fma loop - auto fmaLoop = rewriter.create(loc, 0, rhsShape[ROW], 1); - // inner loop body - rewriter.setInsertionPointToStart(fmaLoop.getBody()); - - auto a = rewriter.create( - loc, lhs, - ValueRange{rowLoop.getInductionVar(), fmaLoop.getInductionVar()}); - auto b = rewriter.create( - loc, rhs, - ValueRange{fmaLoop.getInductionVar(), colLoop.getInductionVar()}); - auto c = rewriter.create( - loc, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); - if (elementType.isIntOrIndex()) { - // Arith operates on MLIR signless integers, while Daphne uses (un)signed - // integers. 
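// --- Illustration (not part of the patch): the three affine loops built by
// affineMatMul() above compute the classic i-j-k matmul over a zero-filled
// output (rowLoop = i, colLoop = j, fmaLoop = k). The same computation in
// plain C++ over row-major buffers:
#include <cstdio>
#include <vector>

void matmul(const double *A, const double *B, double *C, int M, int N, int K) {
    for (int i = 0; i < M; ++i)         // row loop
        for (int j = 0; j < N; ++j)     // col loop
            for (int k = 0; k < K; ++k) // fma loop
                C[i * N + j] += A[i * K + k] * B[k * N + j];
}

int main() {
    std::vector<double> A = {1, 2, 3, 4}, B = {5, 6, 7, 8}, C(4, 0.0);
    matmul(A.data(), B.data(), C.data(), 2, 2, 2);
    std::printf("%g %g / %g %g\n", C[0], C[1], C[2], C[3]); // 19 22 / 43 50
}
// --- end illustration ---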
- Value castedA = this->typeConverter->materializeTargetConversion( - rewriter, loc, - rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), - ValueRange{a}); - Value castedB = this->typeConverter->materializeTargetConversion( - rewriter, loc, - rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), - ValueRange{b}); - Value castedC = this->typeConverter->materializeTargetConversion( - rewriter, loc, - rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), - ValueRange{c}); - Value added = rewriter.create(loc, castedA, castedB); - Value res = rewriter.create(loc, added, castedC); - Value castedRes = this->typeConverter->materializeSourceConversion( - rewriter, loc, elementType, ValueRange{res}); - rewriter.create( - loc, castedRes, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); - } else { - Value res = rewriter.create(loc, a, b, c); - rewriter.create( - loc, res, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); - } +} - // AffineYieldOp at end of loop blocks - rewriter.setInsertionPointAfter(fmaLoop); - rewriter.setInsertionPointAfter(colLoop); - rewriter.setInsertionPointAfter(rowLoop); - - loops.push_back(rowLoop); - loops.push_back(colLoop); - loops.push_back(fmaLoop); - return loops; - } - - llvm::SmallVector vectorizedAffineMatMul( - mlir::Value &lhs, mlir::Value &rhs, mlir::Value &output, - ConversionPatternRewriter &rewriter, mlir::Location loc, - ArrayRef lhsShape, ArrayRef rhsShape, - mlir::MLIRContext *ctx, llvm::SmallVector &loops, - Type elementType, int64_t vec_size) const { - auto vec_Type = mlir::VectorType::get({vec_size}, elementType); - - // row loop - auto rowLoop = rewriter.create(loc, 0, lhsShape[ROW], 1); - // row loop body - rewriter.setInsertionPointToStart(rowLoop.getBody()); - // col loop - auto colLoop = - rewriter.create(loc, 0, rhsShape[COL], vec_size); - // col loop body - rewriter.setInsertionPointToStart(colLoop.getBody()); - // fma loop - auto fmaLoop = rewriter.create(loc, 0, rhsShape[ROW], 1); - // inner loop body - rewriter.setInsertionPointToStart(fmaLoop.getBody()); - - auto a_single = rewriter.create( - loc, lhs, - ValueRange{rowLoop.getInductionVar(), fmaLoop.getInductionVar()}); - auto a = rewriter.create(loc, a_single, vec_Type); - auto b = rewriter.create( - loc, vec_Type, rhs, - ValueRange{fmaLoop.getInductionVar(), colLoop.getInductionVar()}); - auto c = rewriter.create( - loc, vec_Type, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); - - // TODO: Integer doesn't actually work yet, so is disabled in - // is_vectorizable. 
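// --- Illustration (not part of the patch): what vectorizedAffineMatMul()
// above does per k-iteration: broadcast one scalar of A across a vector,
// load a contiguous chunk of a B row and of the C row, and fuse
// multiply-add them. Plain C++ stand-in with VEC fixed to 4 lanes; like the
// pass (see is_vectorizable), it assumes N % VEC == 0.
constexpr int VEC = 4; // lanes; in the pass: vec_size_bits / element bitwidth

void matmulVectorized(const double *A, const double *B, double *C, int M, int N, int K) {
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; j += VEC) { // col loop steps by VEC
            for (int k = 0; k < K; ++k) {
                double a = A[i * K + k];      // scalar load + broadcast
                for (int l = 0; l < VEC; ++l) // one "vector" FMA
                    C[i * N + j + l] += a * B[k * N + j + l];
            }
        }
}
// --- end illustration ---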
- if (elementType.isIntOrIndex()) { - Value added = rewriter.create(loc, a, b); - Value res = rewriter.create(loc, added, c); - rewriter.create( - loc, res, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); - } else { - Value res = rewriter.create(loc, a, b, c); - rewriter.create( - loc, res, output, - ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); +class MatMulLowering : public OpConversionPattern { + const LowerMatMulOpOptions options; + + public: + using OpConversionPattern::OpConversionPattern; + explicit MatMulLowering(mlir::TypeConverter &typeConverter, MLIRContext *context, + LowerMatMulOpOptions const &options) + : OpConversionPattern(typeConverter, context, PatternBenefit(1)), options(options) { + this->setDebugName("MatMulLowering"); } - // AffineYieldOp at end of loop blocks - rewriter.setInsertionPointAfter(fmaLoop); - rewriter.setInsertionPointAfter(colLoop); - rewriter.setInsertionPointAfter(rowLoop); - - loops.push_back(rowLoop); - loops.push_back(colLoop); - loops.push_back(fmaLoop); - return loops; - } - - LogicalResult - matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - mlir::daphne::MatrixType lhsMatrixType = - adaptor.getLhs().getType().dyn_cast(); - mlir::daphne::MatrixType rhsMatrixType = - adaptor.getRhs().getType().dyn_cast(); - - auto lhsRows = lhsMatrixType.getNumRows(); - auto lhsCols = lhsMatrixType.getNumCols(); - - auto rhsRows = rhsMatrixType.getNumRows(); - auto rhsCols = rhsMatrixType.getNumCols(); - - auto matrixElementType = lhsMatrixType.getElementType(); - - // TODO(phil): if shape is unknown, e.g., row/col = -1 we currently - // can't create a MemRefType - auto lhsMemRefType = - mlir::MemRefType::get({lhsRows, lhsCols}, matrixElementType); - auto rhsMemRefType = - mlir::MemRefType::get({rhsRows, rhsCols}, matrixElementType); - - mlir::MemRefType outputMemRefType = - mlir::MemRefType::get({lhsRows, rhsCols}, matrixElementType); - - // daphne::Matrix -> memref - mlir::Value lhs = rewriter.create( - op->getLoc(), lhsMemRefType, adaptor.getLhs()); - mlir::Value rhs = rewriter.create( - op->getLoc(), rhsMemRefType, adaptor.getRhs()); - - // Alloc output memref - mlir::Value outputMemRef = - insertMemRefAlloc(outputMemRefType, loc, rewriter); - - // Fill the output MemRef - if (matrixElementType.isIntOrIndex()) { - auto signless_type = - rewriter.getIntegerType(matrixElementType.getIntOrFloatBitWidth()); - auto fillValue = rewriter.create( - loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); - auto castedFillValue = this->typeConverter->materializeTargetConversion( - rewriter, loc, matrixElementType, mlir::ValueRange{fillValue}); - affineFillMemRefInt(castedFillValue, rewriter, loc, - outputMemRefType.getShape(), op->getContext(), - outputMemRef); - } else { - affineFillMemRef(0.0, rewriter, loc, outputMemRefType.getShape(), - op->getContext(), outputMemRef, matrixElementType); + bool is_vectorizable(ArrayRef const rhsShape, Type const matrixElementType) const { + if (rhsShape[COL] % options.getVecSize(matrixElementType.getIntOrFloatBitWidth()) != 0) { + return false; + } + if (!matrixElementType.isa()) { + return false; + } + return true; } - // Do the actual MatMul with hand built codegen - SmallVector loops; - if (options.vectorize && - is_vectorizable(rhsMemRefType.getShape(), matrixElementType)) { - vectorizedAffineMatMul( - lhs, rhs, outputMemRef, rewriter, loc, lhsMemRefType.getShape(), - 
rhsMemRefType.getShape(), op->getContext(), loops, matrixElementType, - options.getVecSize(matrixElementType.getIntOrFloatBitWidth())); - } else { - affineMatMul(lhs, rhs, outputMemRef, rewriter, loc, - lhsMemRefType.getShape(), rhsMemRefType.getShape(), - op->getContext(), loops, matrixElementType); - } - if (options.tile && is_tileable(rhsMemRefType.getShape())) { - auto tile_sizes = extendTileSizes(lhsRows); - if (!options.useFixedTileSizes) { - tile_sizes = getTileSizesFromCache(matrixElementType, - loops[1].getStep(), lhsRows); - } - tile_loops(loc, loops, tile_sizes); - } else if (options.invert_loops){ - permuteLoops(loops, {0, 2, 1}); - } - mlir::Value DM = - convertMemRefToDenseMatrix(loc, rewriter, outputMemRef, op.getType()); - - rewriter.replaceOp(op, DM); - return success(); - } - - // tile_loops requires 5 tile sizes. If fewer tile sizes are specified, we can - // extend with the size of the loop, since loops with only one iteration are - // later removed. - SmallVector extendTileSizes(int64_t max_loop_length) const { - SmallVector tile_sizes = options.tile_sizes; - while (tile_sizes.size() < 5) { - tile_sizes.push_back(max_loop_length); - } - return tile_sizes; - } - - // Choose tile sizes so that reuse is happening across the cache levels. This - // is just a proof of concept and not a very sophisticated strategy. Assuming - // cache sizes are in Bytes not KB or other units. Assume square matmul of - // length loop_length. The target below is laid out assuming there are a - // number of vector registers available. If not all cache sizes "move down" a - // slot if set. If there are also no cache sizes available, set MR and NR to - // 2, since otherwise the tiling breaks. Target: MR * NR ~ Register size * 3 - // / 4 - // KC * NR ~ L1, - // MC * KC ~ L2, - // NC * MC ~ L3 - // & NR divides NC & MR divides MC - SmallVector getTileSizesFromCache(Type const matrixElementType, - int64_t vec_size, - int64_t loop_length) const { - SmallVector tile_sizes; - int bitwidth = matrixElementType.getIntOrFloatBitWidth(); - int register_size = options.getRegisterSize(); - int no_register = 0; - if (register_size == 1) { - if (options.cache_sizes.size() > 0) { - tile_sizes.push_back( - std::max(2, (int)(std::sqrt(register_size / bitwidth)))); - tile_sizes.push_back(tile_sizes.back()); - no_register++; - } else { - tile_sizes.push_back(2); - tile_sizes.push_back(2); - } - } else { - tile_sizes.push_back( - std::max(2, (int)(std::sqrt(register_size / bitwidth * 3 / 4)))); - tile_sizes.push_back(tile_sizes.back()); + + bool is_tileable(ArrayRef const rhsShape) const { return true; } + + llvm::SmallVector affineMatMul(mlir::Value &lhs, mlir::Value &rhs, mlir::Value &output, + ConversionPatternRewriter &rewriter, mlir::Location loc, + ArrayRef lhsShape, ArrayRef rhsShape, + mlir::MLIRContext *ctx, SmallVector &loops, + Type elementType) const { + // row loop + auto rowLoop = rewriter.create(loc, 0, lhsShape[ROW], 1); + // row loop body + rewriter.setInsertionPointToStart(rowLoop.getBody()); + // col loop + auto colLoop = rewriter.create(loc, 0, rhsShape[COL], 1); + // col loop body + rewriter.setInsertionPointToStart(colLoop.getBody()); + // fma loop + auto fmaLoop = rewriter.create(loc, 0, rhsShape[ROW], 1); + // inner loop body + rewriter.setInsertionPointToStart(fmaLoop.getBody()); + + auto a = + rewriter.create(loc, lhs, ValueRange{rowLoop.getInductionVar(), fmaLoop.getInductionVar()}); + auto b = + rewriter.create(loc, rhs, ValueRange{fmaLoop.getInductionVar(), colLoop.getInductionVar()}); + 
auto c = rewriter.create(loc, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + if (elementType.isIntOrIndex()) { + // Arith operates on MLIR signless integers, while Daphne uses + // (un)signed integers. + Value castedA = this->typeConverter->materializeTargetConversion( + rewriter, loc, rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), ValueRange{a}); + Value castedB = this->typeConverter->materializeTargetConversion( + rewriter, loc, rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), ValueRange{b}); + Value castedC = this->typeConverter->materializeTargetConversion( + rewriter, loc, rewriter.getIntegerType(elementType.getIntOrFloatBitWidth()), ValueRange{c}); + Value added = rewriter.create(loc, castedA, castedB); + Value res = rewriter.create(loc, added, castedC); + Value castedRes = + this->typeConverter->materializeSourceConversion(rewriter, loc, elementType, ValueRange{res}); + rewriter.create(loc, castedRes, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + } else { + Value res = rewriter.create(loc, a, b, c); + rewriter.create(loc, res, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + } + + // AffineYieldOp at end of loop blocks + rewriter.setInsertionPointAfter(fmaLoop); + rewriter.setInsertionPointAfter(colLoop); + rewriter.setInsertionPointAfter(rowLoop); + + loops.push_back(rowLoop); + loops.push_back(colLoop); + loops.push_back(fmaLoop); + return loops; } - if (options.cache_sizes.size() > 0) { - int idx = 0; - for (auto cache_size = options.cache_sizes.begin() + no_register; - cache_size != options.cache_sizes.end(); cache_size++) { - unsigned candidate = - std::max(1, (int)(*cache_size / tile_sizes.back() / bitwidth)); - if (idx == 3) - candidate = candidate - (candidate % tile_sizes[0]); - if (idx == 4) - candidate = candidate - (candidate % tile_sizes[1]); - tile_sizes.push_back(candidate); - idx++; - } + + llvm::SmallVector vectorizedAffineMatMul(mlir::Value &lhs, mlir::Value &rhs, mlir::Value &output, + ConversionPatternRewriter &rewriter, mlir::Location loc, + ArrayRef lhsShape, ArrayRef rhsShape, + mlir::MLIRContext *ctx, + llvm::SmallVector &loops, Type elementType, + int64_t vec_size) const { + auto vec_Type = mlir::VectorType::get({vec_size}, elementType); + + // row loop + auto rowLoop = rewriter.create(loc, 0, lhsShape[ROW], 1); + // row loop body + rewriter.setInsertionPointToStart(rowLoop.getBody()); + // col loop + auto colLoop = rewriter.create(loc, 0, rhsShape[COL], vec_size); + // col loop body + rewriter.setInsertionPointToStart(colLoop.getBody()); + // fma loop + auto fmaLoop = rewriter.create(loc, 0, rhsShape[ROW], 1); + // inner loop body + rewriter.setInsertionPointToStart(fmaLoop.getBody()); + + auto a_single = + rewriter.create(loc, lhs, ValueRange{rowLoop.getInductionVar(), fmaLoop.getInductionVar()}); + auto a = rewriter.create(loc, a_single, vec_Type); + auto b = rewriter.create(loc, vec_Type, rhs, + ValueRange{fmaLoop.getInductionVar(), colLoop.getInductionVar()}); + auto c = rewriter.create(loc, vec_Type, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + + // TODO: Integer doesn't actually work yet, so is disabled in + // is_vectorizable. 
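The vectorized variant above differs from the scalar nest only in the column loop, which advances by vec_size lanes at a time: one element of the left operand is broadcast across the lanes, multiplied with a contiguous slice of row k of the right operand, and accumulated into the matching slice of the output row. A scalar C++ model of that inner kernel (names and the row-major layout are illustrative assumptions; the pass itself emits broadcast, vector-load, FMA, and vector-store ops rather than a lane loop):

#include <cstdint>

// One execution of the innermost body, vec_size lanes at a time.
void vectorizedBody(const double *A, const double *B, double *C, int64_t N,
                    int64_t K, int64_t i, int64_t j, int64_t k,
                    int64_t vec_size) {
    const double a = A[i * K + k]; // broadcast operand
    for (int64_t lane = 0; lane < vec_size; ++lane)
        C[i * N + j + lane] += a * B[k * N + j + lane];
}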
+ if (elementType.isIntOrIndex()) { + Value added = rewriter.create(loc, a, b); + Value res = rewriter.create(loc, added, c); + rewriter.create(loc, res, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + } else { + Value res = rewriter.create(loc, a, b, c); + rewriter.create(loc, res, output, + ValueRange{rowLoop.getInductionVar(), colLoop.getInductionVar()}); + } + + // AffineYieldOp at end of loop blocks + rewriter.setInsertionPointAfter(fmaLoop); + rewriter.setInsertionPointAfter(colLoop); + rewriter.setInsertionPointAfter(rowLoop); + + loops.push_back(rowLoop); + loops.push_back(colLoop); + loops.push_back(fmaLoop); + return loops; } - while (tile_sizes.size() < 5) { - tile_sizes.push_back(loop_length); + + LogicalResult matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + mlir::daphne::MatrixType lhsMatrixType = adaptor.getLhs().getType().dyn_cast(); + mlir::daphne::MatrixType rhsMatrixType = adaptor.getRhs().getType().dyn_cast(); + + auto lhsRows = lhsMatrixType.getNumRows(); + auto lhsCols = lhsMatrixType.getNumCols(); + + auto rhsRows = rhsMatrixType.getNumRows(); + auto rhsCols = rhsMatrixType.getNumCols(); + + auto matrixElementType = lhsMatrixType.getElementType(); + + // TODO(phil): if shape is unknown, e.g., row/col = -1 we currently + // can't create a MemRefType + auto lhsMemRefType = mlir::MemRefType::get({lhsRows, lhsCols}, matrixElementType); + auto rhsMemRefType = mlir::MemRefType::get({rhsRows, rhsCols}, matrixElementType); + + mlir::MemRefType outputMemRefType = mlir::MemRefType::get({lhsRows, rhsCols}, matrixElementType); + + // daphne::Matrix -> memref + mlir::Value lhs = + rewriter.create(op->getLoc(), lhsMemRefType, adaptor.getLhs()); + mlir::Value rhs = + rewriter.create(op->getLoc(), rhsMemRefType, adaptor.getRhs()); + + // Alloc output memref + mlir::Value outputMemRef = insertMemRefAlloc(outputMemRefType, loc, rewriter); + + // Fill the output MemRef + if (matrixElementType.isIntOrIndex()) { + auto signless_type = rewriter.getIntegerType(matrixElementType.getIntOrFloatBitWidth()); + auto fillValue = + rewriter.create(loc, signless_type, rewriter.getIntegerAttr(signless_type, 0)); + auto castedFillValue = this->typeConverter->materializeTargetConversion(rewriter, loc, matrixElementType, + mlir::ValueRange{fillValue}); + affineFillMemRef(castedFillValue, rewriter, loc, outputMemRefType.getShape(), op->getContext(), + outputMemRef); + } else { + mlir::Value fillValue = rewriter.create( + loc, matrixElementType, rewriter.getFloatAttr(matrixElementType, 0.0)); + affineFillMemRef(fillValue, rewriter, loc, outputMemRefType.getShape(), op->getContext(), outputMemRef); + } + // Do the actual MatMul with hand built codegen + SmallVector loops; + if (options.vectorize && is_vectorizable(rhsMemRefType.getShape(), matrixElementType)) { + vectorizedAffineMatMul(lhs, rhs, outputMemRef, rewriter, loc, lhsMemRefType.getShape(), + rhsMemRefType.getShape(), op->getContext(), loops, matrixElementType, + options.getVecSize(matrixElementType.getIntOrFloatBitWidth())); + } else { + affineMatMul(lhs, rhs, outputMemRef, rewriter, loc, lhsMemRefType.getShape(), rhsMemRefType.getShape(), + op->getContext(), loops, matrixElementType); + } + if (options.tile && is_tileable(rhsMemRefType.getShape())) { + auto tile_sizes = extendTileSizes(lhsRows); + if (!options.useFixedTileSizes) { + tile_sizes = getTileSizesFromCache(matrixElementType, loops[1].getStep(), lhsRows); + } + 
tile_loops(loc, loops, tile_sizes); + } else if (options.invert_loops) { + permuteLoops(loops, {0, 2, 1}); + } + mlir::Value DM = convertMemRefToDenseMatrix(loc, rewriter, outputMemRef, op.getType()); + + rewriter.replaceOp(op, DM); + return success(); } - // If vector size is longer than 1, we need to keep that in mind for the NR - // loop - if (vec_size > 1) - tile_sizes[1] = std::max(1, (int)(tile_sizes[1] / vec_size)); - return tile_sizes; - } - - // Tile the affine loop nest generated from MatMulOp with the specified tile - // sizes. Includes validations to follow the movement and creation of the tile - // loops. - void tile_loops(mlir::Location loc, - SmallVector loops, - SmallVector tile_sizes) const { - unsigned NC = tile_sizes[4]; - unsigned MC = tile_sizes[3]; - unsigned KC = tile_sizes[2]; - unsigned NR = tile_sizes[1]; - unsigned MR = tile_sizes[0]; - unsigned KU = options.unroll_factor; - [[maybe_unused]] auto vec_size = loops[1].getStep(); - llvm::SmallVector loopNest; - getPerfectlyNestedLoops(loopNest, loops.front()); - // tile i with MC, j with NC, k with KC - llvm::SmallVector tiledNest; - if (failed(tilePerfectlyNested(loopNest, {MC, NC, KC}, &tiledNest))) { - spdlog::warn("Could not tile the loop nest in MatMulLowering"); - }; - - #define GEN_ERR_MSG(name, size, expected) \ - std::string(name) + " should have step size " + std::string(expected) + " but is " + std::to_string(size) - - if (tiledNest[0].getStep() != MC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 0", tiledNest[0].getStep(), "MC (" + std::to_string(MC) + ")")); - if (tiledNest[1].getStep() != NC * vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 1", tiledNest[1].getStep(), "NC * vec_size (" + std::to_string(NC * vec_size) + ")")); - if (tiledNest[2].getStep() != KC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 2", tiledNest[2].getStep(), "KC (" + std::to_string(KC) + ")")); - if (tiledNest[3].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 3", tiledNest[3].getStep(), "1")); - if (tiledNest[4].getStep() != vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 4", tiledNest[4].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); - if (tiledNest[5].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 5", tiledNest[5].getStep(), "1")); - - // Further tile the i mod MC loop with MR - if (failed(tilePerfectlyNested(tiledNest[3], {MR}))) { - spdlog::warn("Could not tile the second i loop in MatMulLowering"); - }; - - // Further tile the j mod NC loop with NR - if (tiledNest[4].getStep() != vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("tiledNest 4", tiledNest[4].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); - if (failed(tilePerfectlyNested(tiledNest[4], {NR}))) { - spdlog::warn("Could not tile the second j loop in MatMulLowering"); - }; - - llvm::SmallVector twiceTiledNest; - getPerfectlyNestedLoops(twiceTiledNest, tiledNest[0]); - // i loops - if (twiceTiledNest[0].getStep() != MC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 0", twiceTiledNest[0].getStep(), "MC (" + std::to_string(MC) + ")")); - if (twiceTiledNest[3].getStep() != MR) - throw 
ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 3", twiceTiledNest[3].getStep(), "MR (" + std::to_string(MR) + ")")); - if (twiceTiledNest[4].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 4", twiceTiledNest[4].getStep(), "1")); - - // j loops - if (twiceTiledNest[1].getStep() != NC * vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 1", twiceTiledNest[1].getStep(), "NC * vec_size (" + std::to_string(NC * vec_size) + ")")); - if (twiceTiledNest[5].getStep() != NR * vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 5", twiceTiledNest[5].getStep(), "NR * vec_size (" + std::to_string(NR * vec_size) + ")")); - if (twiceTiledNest[6].getStep() != vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 6", twiceTiledNest[6].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); - - // k loops - if (twiceTiledNest[2].getStep() != KC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 2", twiceTiledNest[2].getStep(), "KC (" + std::to_string(KC) + ")")); - if (twiceTiledNest[7].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("twiceTiledNest 7", twiceTiledNest[7].getStep(), "1")); - - // permute loops to final order (i / MC, j / NC, k / KC, i / MR, i mod MR, j - // / NR, j mod NR, k mod KC) -> - // (j / NC, k / KC, i / MC, j / NR, i / MR, k - // mod KC, j mod NR, i mod MR) - unsigned root_idx = permuteLoops(twiceTiledNest, {2, 0, 1, 4, 7, 3, 6, 5}); - - // Unroll and jam - llvm::SmallVector blisTiledLoops; - getPerfectlyNestedLoops(blisTiledLoops, twiceTiledNest[root_idx]); - // i loops - if (blisTiledLoops[2].getStep() != MC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 2", blisTiledLoops[2].getStep(), "MC (" + std::to_string(MC) + ")")); - if (blisTiledLoops[4].getStep() != MR) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 4", blisTiledLoops[4].getStep(), "MR (" + std::to_string(MR) + ")")); - if (blisTiledLoops[7].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 7", blisTiledLoops[7].getStep(), "1")); - - // j loops - if (blisTiledLoops[0].getStep() != NC * vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 0", blisTiledLoops[0].getStep(), "NC * vec_size (" + std::to_string(NC * vec_size) + ")")); - if (blisTiledLoops[3].getStep() != NR * vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 3", blisTiledLoops[3].getStep(), "NR * vec_size (" + std::to_string(NR * vec_size) + ")")); - if (blisTiledLoops[6].getStep() != vec_size) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 6", blisTiledLoops[6].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); - - // k loops - if (blisTiledLoops[1].getStep() != KC) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", GEN_ERR_MSG("blisTiled 1", blisTiledLoops[1].getStep(), "KC (" + std::to_string(KC) + ")")); - if (blisTiledLoops[5].getStep() != 1) - throw ErrorHandler::compilerError(loc, "MatMulOpLowering 
(tile_loops)", GEN_ERR_MSG("blisTiled 5", blisTiledLoops[5].getStep(), "1")); - - #undef GEN_ERR_MSG - - // Unroll jam causes Segfault, if called in a way where the loop is not - // cleanly divided. - if (options.unroll_jam_factor > 0 && - blisTiledLoops[5].getUpperBound().getMap().getNumResults() == 1 && - succeeded(loopUnrollJamUpToFactor(blisTiledLoops[5], - options.unroll_jam_factor))) { - if (blisTiledLoops[6].getUpperBound().getMap().getNumResults() != 1 || - failed(loopUnrollJamUpToFactor(blisTiledLoops[6], - options.unroll_jam_factor))) { - spdlog::warn( - "Could not unroll the (j mod NC) mod NR loop in MatMulLowering"); - } - } else { - spdlog::warn( - "Could not unroll the (i mod MC) mod MR loop in MatMulLowering"); + + // tile_loops requires 5 tile sizes. If fewer tile sizes are specified, we + // can extend the list with the loop length, since loops with only one + // iteration are removed later. + SmallVector extendTileSizes(int64_t max_loop_length) const { + SmallVector tile_sizes = options.tile_sizes; + while (tile_sizes.size() < 5) { + tile_sizes.push_back(max_loop_length); + } + return tile_sizes; } - llvm::SmallVector lastNest; - getPerfectlyNestedLoops(lastNest, blisTiledLoops.front()); - int64_t i = 0; - while (succeeded(promoteIfSingleIteration(lastNest[i])) && i < 4) { - i++; + // Choose tile sizes so that reuse happens across the cache levels. + // This is just a proof of concept and not a very sophisticated strategy. + // Cache sizes are assumed to be in bytes, not KB or other units. Assume a + // square matmul of side length loop_length. The targets below assume that + // a number of vector registers is available; if not, all cache sizes move + // down a slot (if set). If no cache sizes are available either, set MR + // and NR to 2, since the tiling breaks otherwise.
Target: MR * NR ~ + // Register size * 3 / 4 + // KC * NR ~ L1, + // MC * KC ~ L2, + // NC * MC ~ L3 + // & NR divides NC & MR divides MC + SmallVector getTileSizesFromCache(Type const matrixElementType, int64_t vec_size, + int64_t loop_length) const { + SmallVector tile_sizes; + int bitwidth = matrixElementType.getIntOrFloatBitWidth(); + int register_size = options.getRegisterSize(); + int no_register = 0; + if (register_size == 1) { + if (options.cache_sizes.size() > 0) { + tile_sizes.push_back(std::max(2, (int)(std::sqrt(register_size / bitwidth)))); + tile_sizes.push_back(tile_sizes.back()); + no_register++; + } else { + tile_sizes.push_back(2); + tile_sizes.push_back(2); + } + } else { + tile_sizes.push_back(std::max(2, (int)(std::sqrt(register_size / bitwidth * 3 / 4)))); + tile_sizes.push_back(tile_sizes.back()); + } + if (options.cache_sizes.size() > 0) { + int idx = 0; + for (auto cache_size = options.cache_sizes.begin() + no_register; cache_size != options.cache_sizes.end(); + cache_size++) { + unsigned candidate = std::max(1, (int)(*cache_size / tile_sizes.back() / bitwidth)); + if (idx == 3) + candidate = candidate - (candidate % tile_sizes[0]); + if (idx == 4) + candidate = candidate - (candidate % tile_sizes[1]); + tile_sizes.push_back(candidate); + idx++; + } + } + while (tile_sizes.size() < 5) { + tile_sizes.push_back(loop_length); + } + // If vector size is longer than 1, we need to keep that in mind for the + // NR loop + if (vec_size > 1) + tile_sizes[1] = std::max(1, (int)(tile_sizes[1] / vec_size)); + return tile_sizes; } - if (KU > 0 && failed(loopUnrollUpToFactor(lastNest.back(), KU))) { - spdlog::warn("Could not unroll the K loop in MatMulLowering"); + // Tile the affine loop nest generated from MatMulOp with the specified tile + // sizes. Includes validations to follow the movement and creation of the + // tile loops. 
+ void tile_loops(mlir::Location loc, SmallVector loops, SmallVector tile_sizes) const { + unsigned NC = tile_sizes[4]; + unsigned MC = tile_sizes[3]; + unsigned KC = tile_sizes[2]; + unsigned NR = tile_sizes[1]; + unsigned MR = tile_sizes[0]; + unsigned KU = options.unroll_factor; + [[maybe_unused]] auto vec_size = loops[1].getStep(); + llvm::SmallVector loopNest; + getPerfectlyNestedLoops(loopNest, loops.front()); + // tile i with MC, j with NC, k with KC + llvm::SmallVector tiledNest; + if (failed(tilePerfectlyNested(loopNest, {MC, NC, KC}, &tiledNest))) { + spdlog::warn("Could not tile the loop nest in MatMulLowering"); + }; + +#define GEN_ERR_MSG(name, size, expected) \ + std::string(name) + " should have step size " + std::string(expected) + " but is " + std::to_string(size) + + if (tiledNest[0].getStep() != MC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 0", tiledNest[0].getStep(), "MC (" + std::to_string(MC) + ")")); + if (tiledNest[1].getStep() != NC * vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 1", tiledNest[1].getStep(), + "NC * vec_size (" + std::to_string(NC * vec_size) + ")")); + if (tiledNest[2].getStep() != KC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 2", tiledNest[2].getStep(), "KC (" + std::to_string(KC) + ")")); + if (tiledNest[3].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 3", tiledNest[3].getStep(), "1")); + if (tiledNest[4].getStep() != vec_size) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 4", tiledNest[4].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); + if (tiledNest[5].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 5", tiledNest[5].getStep(), "1")); + + // Further tile the i mod MC loop with MR + if (failed(tilePerfectlyNested(tiledNest[3], {MR}))) { + spdlog::warn("Could not tile the second i loop in MatMulLowering"); + }; + + // Further tile the j mod NC loop with NR + if (tiledNest[4].getStep() != vec_size) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("tiledNest 4", tiledNest[4].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); + if (failed(tilePerfectlyNested(tiledNest[4], {NR}))) { + spdlog::warn("Could not tile the second j loop in MatMulLowering"); + }; + + llvm::SmallVector twiceTiledNest; + getPerfectlyNestedLoops(twiceTiledNest, tiledNest[0]); + // i loops + if (twiceTiledNest[0].getStep() != MC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 0", twiceTiledNest[0].getStep(), "MC (" + std::to_string(MC) + ")")); + if (twiceTiledNest[3].getStep() != MR) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 3", twiceTiledNest[3].getStep(), "MR (" + std::to_string(MR) + ")")); + if (twiceTiledNest[4].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 4", twiceTiledNest[4].getStep(), "1")); + + // j loops + if (twiceTiledNest[1].getStep() != NC * vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 1", twiceTiledNest[1].getStep(), + "NC * vec_size (" + std::to_string(NC * vec_size) 
+ ")")); + if (twiceTiledNest[5].getStep() != NR * vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 5", twiceTiledNest[5].getStep(), + "NR * vec_size (" + std::to_string(NR * vec_size) + ")")); + if (twiceTiledNest[6].getStep() != vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 6", twiceTiledNest[6].getStep(), + "vec_size (" + std::to_string(vec_size) + ")")); + + // k loops + if (twiceTiledNest[2].getStep() != KC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 2", twiceTiledNest[2].getStep(), "KC (" + std::to_string(KC) + ")")); + if (twiceTiledNest[7].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("twiceTiledNest 7", twiceTiledNest[7].getStep(), "1")); + + // permute loops to final order (i / MC, j / NC, k / KC, i / MR, i mod + // MR, j / NR, j mod NR, k mod KC) -> + // (j / NC, k / KC, i / MC, j / NR, i / MR, + // k mod KC, j mod NR, i mod MR) + unsigned root_idx = permuteLoops(twiceTiledNest, {2, 0, 1, 4, 7, 3, 6, 5}); + + // Unroll and jam + llvm::SmallVector blisTiledLoops; + getPerfectlyNestedLoops(blisTiledLoops, twiceTiledNest[root_idx]); + // i loops + if (blisTiledLoops[2].getStep() != MC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 2", blisTiledLoops[2].getStep(), "MC (" + std::to_string(MC) + ")")); + if (blisTiledLoops[4].getStep() != MR) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 4", blisTiledLoops[4].getStep(), "MR (" + std::to_string(MR) + ")")); + if (blisTiledLoops[7].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 7", blisTiledLoops[7].getStep(), "1")); + + // j loops + if (blisTiledLoops[0].getStep() != NC * vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 0", blisTiledLoops[0].getStep(), + "NC * vec_size (" + std::to_string(NC * vec_size) + ")")); + if (blisTiledLoops[3].getStep() != NR * vec_size) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 3", blisTiledLoops[3].getStep(), + "NR * vec_size (" + std::to_string(NR * vec_size) + ")")); + if (blisTiledLoops[6].getStep() != vec_size) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 6", blisTiledLoops[6].getStep(), "vec_size (" + std::to_string(vec_size) + ")")); + + // k loops + if (blisTiledLoops[1].getStep() != KC) + throw ErrorHandler::compilerError( + loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 1", blisTiledLoops[1].getStep(), "KC (" + std::to_string(KC) + ")")); + if (blisTiledLoops[5].getStep() != 1) + throw ErrorHandler::compilerError(loc, "MatMulOpLowering (tile_loops)", + GEN_ERR_MSG("blisTiled 5", blisTiledLoops[5].getStep(), "1")); + +#undef GEN_ERR_MSG + + // Unroll jam causes Segfault, if called in a way where the loop is not + // cleanly divided. 
+ if (options.unroll_jam_factor > 0 && blisTiledLoops[5].getUpperBound().getMap().getNumResults() == 1 && + succeeded(loopUnrollJamUpToFactor(blisTiledLoops[5], options.unroll_jam_factor))) { + if (blisTiledLoops[6].getUpperBound().getMap().getNumResults() != 1 || + failed(loopUnrollJamUpToFactor(blisTiledLoops[6], options.unroll_jam_factor))) { + spdlog::warn("Could not unroll the (j mod NC) mod NR loop in " + "MatMulLowering"); + } + } else { + spdlog::warn("Could not unroll the (i mod MC) mod MR loop in " + "MatMulLowering"); + } + + llvm::SmallVector lastNest; + getPerfectlyNestedLoops(lastNest, blisTiledLoops.front()); + int64_t i = 0; + while (succeeded(promoteIfSingleIteration(lastNest[i])) && i < 4) { + i++; + } + + if (KU > 0 && failed(loopUnrollUpToFactor(lastNest.back(), KU))) { + spdlog::warn("Could not unroll the K loop in MatMulLowering"); + } } - } }; namespace { @@ -626,129 +620,118 @@ namespace { * * A more detailed description can be found in 'daphneir/Passes.td'. */ -struct MatMulLoweringPass - : public impl::MatMulOpLoweringPassBase { - MatMulLoweringPass() = default; - -public: - explicit MatMulLoweringPass(bool matmul_tile, int matmul_vec_size_bits, - std::vector matmul_fixed_tile_sizes, - bool matmul_use_fixed_tile_sizes, - int matmul_unroll_factor, - int matmul_unroll_jam_factor, - int matmul_num_vec_registers, - bool matmul_invert_loops) - : impl::MatMulOpLoweringPassBase() { - this->matmul_tile = matmul_tile; - this->matmul_vec_size_bits = matmul_vec_size_bits; - this->matmul_fixed_tile_sizes = matmul_fixed_tile_sizes; - this->matmul_use_fixed_tile_sizes = matmul_use_fixed_tile_sizes; - this->matmul_unroll_factor = matmul_unroll_factor; - this->matmul_unroll_jam_factor = matmul_unroll_jam_factor; - this->matmul_num_vec_registers = matmul_num_vec_registers; - this->matmul_invert_loops = matmul_invert_loops; - } - - void runOnOperation() override; - -private: - // Get the L1, L2 and L3 cache sizes to adapt tile sizes. - // So far assumes process is executed on a single processing unit. 
- // See example: - // https://www.open-mpi.org/projects/hwloc/doc/v2.2.0/a00324.php#cli_examples - SmallVector get_cache_sizes() const { - hwloc_topology_t topology; - hwloc_obj_t obj; - SmallVector sizes; - - // Allocate and initialize topology object - hwloc_topology_init(&topology); - // Perform topology detection - hwloc_topology_load(topology); - - for (obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_PU, 0); obj; - obj = obj->parent) - if (hwloc_obj_type_is_cache(obj->type)) { - sizes.push_back(obj->attr->cache.size); - } - return sizes; - } +struct MatMulLoweringPass : public impl::MatMulOpLoweringPassBase { + MatMulLoweringPass() = default; + + public: + explicit MatMulLoweringPass(bool matmul_tile, int matmul_vec_size_bits, + std::vector matmul_fixed_tile_sizes, bool matmul_use_fixed_tile_sizes, + int matmul_unroll_factor, int matmul_unroll_jam_factor, int matmul_num_vec_registers, + bool matmul_invert_loops) + : impl::MatMulOpLoweringPassBase() { + this->matmul_tile = matmul_tile; + this->matmul_vec_size_bits = matmul_vec_size_bits; + this->matmul_fixed_tile_sizes = matmul_fixed_tile_sizes; + this->matmul_use_fixed_tile_sizes = matmul_use_fixed_tile_sizes; + this->matmul_unroll_factor = matmul_unroll_factor; + this->matmul_unroll_jam_factor = matmul_unroll_jam_factor; + this->matmul_num_vec_registers = matmul_num_vec_registers; + this->matmul_invert_loops = matmul_invert_loops; + } + + void runOnOperation() override; + + private: + // Get the L1, L2 and L3 cache sizes to adapt tile sizes. + // So far assumes process is executed on a single processing unit. + // See example: + // https://www.open-mpi.org/projects/hwloc/doc/v2.2.0/a00324.php#cli_examples + SmallVector get_cache_sizes() const { + hwloc_topology_t topology; + hwloc_obj_t obj; + SmallVector sizes; + + // Allocate and initialize topology object + hwloc_topology_init(&topology); + // Perform topology detection + hwloc_topology_load(topology); + + for (obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_PU, 0); obj; obj = obj->parent) + if (hwloc_obj_type_is_cache(obj->type)) { + sizes.push_back(obj->attr->cache.size); + } + return sizes; + } }; } // end anonymous namespace void MatMulLoweringPass::runOnOperation() { - auto module = getOperation(); - mlir::ConversionTarget target(getContext()); - mlir::RewritePatternSet patterns(&getContext()); - LowerToLLVMOptions llvmOptions(&getContext()); - LLVMTypeConverter typeConverter(&getContext(), llvmOptions); - - typeConverter.addConversion(convertInteger); - typeConverter.addConversion(convertFloat); - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addArgumentMaterialization(materializeCastFromIllegal); - typeConverter.addSourceMaterialization(materializeCastToIllegal); - typeConverter.addTargetMaterialization(materializeCastFromIllegal); - - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - LowerMatMulOpOptions options; - if (matmul_tile) { - options.enableTiling(); - if (matmul_use_fixed_tile_sizes) { - options.useFixedTileSizes = true; - options.setTileSizes(matmul_fixed_tile_sizes); - } else { - options.setCacheSizes(get_cache_sizes()); + auto module = getOperation(); + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet 
patterns(&getContext()); + LowerToLLVMOptions llvmOptions(&getContext()); + LLVMTypeConverter typeConverter(&getContext(), llvmOptions); + + typeConverter.addConversion(convertInteger); + typeConverter.addConversion(convertFloat); + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addArgumentMaterialization(materializeCastFromIllegal); + typeConverter.addSourceMaterialization(materializeCastToIllegal); + typeConverter.addTargetMaterialization(materializeCastFromIllegal); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + LowerMatMulOpOptions options; + if (matmul_tile) { + options.enableTiling(); + if (matmul_use_fixed_tile_sizes) { + options.useFixedTileSizes = true; + options.setTileSizes(matmul_fixed_tile_sizes); + } else { + options.setCacheSizes(get_cache_sizes()); + } + options.setUnrollFactor(matmul_unroll_factor); + options.setUnrollJamFactor(matmul_unroll_jam_factor); + } + if (matmul_vec_size_bits > 0) { + options.enableVectorization(); + options.setVectorSizeBits(matmul_vec_size_bits); + } + options.enableLoopInversion(matmul_invert_loops); + options.setNumberOfVectorRegisters(matmul_num_vec_registers); + target.addDynamicallyLegalOp( + [options](Operation *op) { return !is_valid_options(options); }); + + patterns.insert(typeConverter, &getContext(), options); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + signalPassFailure(); } - options.setUnrollFactor(matmul_unroll_factor); - options.setUnrollJamFactor(matmul_unroll_jam_factor); - } - if (matmul_vec_size_bits > 0) { - options.enableVectorization(); - options.setVectorSizeBits(matmul_vec_size_bits); - } - options.enableLoopInversion(matmul_invert_loops); - options.setNumberOfVectorRegisters(matmul_num_vec_registers); - target.addDynamicallyLegalOp( - [options](Operation *op) { return !is_valid_options(options); }); - - patterns.insert(typeConverter, &getContext(), options); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { - signalPassFailure(); - } } -std::unique_ptr> -mlir::daphne::createMatMulOpLoweringPass( - bool matmul_tile, int matmul_vec_size_bits, - std::vector matmul_fixed_tile_sizes, - bool matmul_use_fixed_tile_sizes, int matmul_unroll_factor, - int matmul_unroll_jam_factor, int matmul_num_vec_registers, - bool matmul_invert_loops) { - return std::make_unique( - matmul_tile, matmul_vec_size_bits, matmul_fixed_tile_sizes, - matmul_use_fixed_tile_sizes, matmul_unroll_factor, - matmul_unroll_jam_factor, matmul_num_vec_registers, - matmul_invert_loops); +std::unique_ptr> mlir::daphne::createMatMulOpLoweringPass( + bool matmul_tile, int matmul_vec_size_bits, std::vector matmul_fixed_tile_sizes, + bool matmul_use_fixed_tile_sizes, int matmul_unroll_factor, int matmul_unroll_jam_factor, + int matmul_num_vec_registers, bool matmul_invert_loops) { + return std::make_unique( + matmul_tile, matmul_vec_size_bits, matmul_fixed_tile_sizes, matmul_use_fixed_tile_sizes, matmul_unroll_factor, + matmul_unroll_jam_factor, matmul_num_vec_registers, matmul_invert_loops); } // This is used by daphne-opt and automatically inserts the options provided on // the command line into the pass. 
-std::unique_ptr> -mlir::daphne::createMatMulOpLoweringPass() { - return std::make_unique(); +std::unique_ptr> mlir::daphne::createMatMulOpLoweringPass() { + return std::make_unique(); } diff --git a/src/compiler/lowering/ModOpLowering.cpp b/src/compiler/lowering/ModOpLowering.cpp index fb1fd8f11..8d047726e 100644 --- a/src/compiler/lowering/ModOpLowering.cpp +++ b/src/compiler/lowering/ModOpLowering.cpp @@ -32,131 +32,97 @@ using namespace mlir; -class EwModOpLowering - : public mlir::OpConversionPattern { - public: +class EwModOpLowering : public mlir::OpConversionPattern { + public: using OpConversionPattern::OpConversionPattern; [[nodiscard]] bool optimization_viable(mlir::Value divisor) const { - std::pair isConstant = - CompilerUtils::isConstant(divisor); + std::pair isConstant = CompilerUtils::isConstant(divisor); return isConstant.first && (isConstant.second & (isConstant.second - 1)) == 0; } - void optimizeEwModOp(mlir::Value memRef, mlir::Value divisor, - ArrayRef shape, - ConversionPatternRewriter &rewriter, - Location loc) const { + void optimizeEwModOp(mlir::Value memRef, mlir::Value divisor, ArrayRef shape, + ConversionPatternRewriter &rewriter, Location loc) const { // divisor - 1 - mlir::Value cst_one = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); + mlir::Value cst_one = + rewriter.create(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(1)); - auto casted_divisor = typeConverter->materializeTargetConversion( - rewriter, loc, rewriter.getI64Type(), ValueRange{divisor}); + auto casted_divisor = + typeConverter->materializeTargetConversion(rewriter, loc, rewriter.getI64Type(), ValueRange{divisor}); - mlir::Value rhs = - rewriter.create(loc, casted_divisor, cst_one); + mlir::Value rhs = rewriter.create(loc, casted_divisor, cst_one); SmallVector lowerBounds(/*Rank=*/2, /*Value=*/0); SmallVector steps(/*Rank=*/2, /*Value=*/1); buildAffineLoopNest( - rewriter, loc, lowerBounds, shape, steps, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - mlir::Value load = - nestedBuilder.create(loc, memRef, ivs); + rewriter, loc, lowerBounds, shape, steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + mlir::Value load = nestedBuilder.create(loc, memRef, ivs); mlir::Value res{}; - Value castedLhs = - this->typeConverter->materializeTargetConversion( - nestedBuilder, loc, - nestedBuilder.getIntegerType( - divisor.getType().getIntOrFloatBitWidth()), - ValueRange{load}); + Value castedLhs = this->typeConverter->materializeTargetConversion( + nestedBuilder, loc, nestedBuilder.getIntegerType(divisor.getType().getIntOrFloatBitWidth()), + ValueRange{load}); res = nestedBuilder.create(loc, castedLhs, rhs); - Value castedRes = - this->typeConverter->materializeSourceConversion( - nestedBuilder, loc, divisor.getType(), ValueRange{res}); + Value castedRes = this->typeConverter->materializeSourceConversion(nestedBuilder, loc, + divisor.getType(), ValueRange{res}); - nestedBuilder.create(loc, castedRes, memRef, - ivs); + nestedBuilder.create(loc, castedRes, memRef, ivs); }); } - void lowerEwModOp(mlir::Value memRef, mlir::Value divisor, - ArrayRef shape, + void lowerEwModOp(mlir::Value memRef, mlir::Value divisor, ArrayRef shape, ConversionPatternRewriter &rewriter, Location loc) const { SmallVector lowerBounds(/*Rank=*/2, /*Value=*/0); SmallVector steps(/*Rank=*/2, /*Value=*/1); buildAffineLoopNest( - rewriter, loc, lowerBounds, shape, steps, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - mlir::Value load = - 
nestedBuilder.create(loc, memRef, ivs); + rewriter, loc, lowerBounds, shape, steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { + mlir::Value load = nestedBuilder.create(loc, memRef, ivs); mlir::Value res{}; // this is enough since divisor will be cast to float if // matrix is float if (llvm::isa(divisor.getType())) { - res = - nestedBuilder.create(loc, load, divisor); + res = nestedBuilder.create(loc, load, divisor); nestedBuilder.create(loc, res, memRef, ivs); return; } - Value castedLhs = - this->typeConverter->materializeTargetConversion( - nestedBuilder, loc, - nestedBuilder.getIntegerType( - divisor.getType().getIntOrFloatBitWidth()), - ValueRange{load}); - - Value castedRhs = - this->typeConverter->materializeTargetConversion( - nestedBuilder, loc, - nestedBuilder.getIntegerType( - divisor.getType().getIntOrFloatBitWidth()), - ValueRange{divisor}); - - res = nestedBuilder.create(loc, castedLhs, - castedRhs); - Value castedRes = - this->typeConverter->materializeSourceConversion( - nestedBuilder, loc, divisor.getType(), ValueRange{res}); - - nestedBuilder.create(loc, castedRes, memRef, - ivs); + Value castedLhs = this->typeConverter->materializeTargetConversion( + nestedBuilder, loc, nestedBuilder.getIntegerType(divisor.getType().getIntOrFloatBitWidth()), + ValueRange{load}); + + Value castedRhs = this->typeConverter->materializeTargetConversion( + nestedBuilder, loc, nestedBuilder.getIntegerType(divisor.getType().getIntOrFloatBitWidth()), + ValueRange{divisor}); + + res = nestedBuilder.create(loc, castedLhs, castedRhs); + Value castedRes = this->typeConverter->materializeSourceConversion(nestedBuilder, loc, + divisor.getType(), ValueRange{res}); + + nestedBuilder.create(loc, castedRes, memRef, ivs); }); } - mlir::LogicalResult matchAndRewrite( - mlir::daphne::EwModOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - mlir::daphne::MatrixType lhsTensor = - adaptor.getLhs().getType().dyn_cast(); + mlir::LogicalResult matchAndRewrite(mlir::daphne::EwModOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::daphne::MatrixType lhsTensor = adaptor.getLhs().getType().dyn_cast(); auto lhsRows = lhsTensor.getNumRows(); auto lhsCols = lhsTensor.getNumCols(); - auto lhsMemRefType = mlir::MemRefType::get({lhsRows, lhsCols}, - lhsTensor.getElementType()); + auto lhsMemRefType = mlir::MemRefType::get({lhsRows, lhsCols}, lhsTensor.getElementType()); // daphne::Matrix -> memref mlir::Value lhs = - rewriter.create( - op->getLoc(), lhsMemRefType, adaptor.getLhs()); + rewriter.create(op->getLoc(), lhsMemRefType, adaptor.getLhs()); mlir::Value rhs = adaptor.getRhs(); if (optimization_viable(rhs)) - optimizeEwModOp(lhs, rhs, - {lhsTensor.getNumRows(), lhsTensor.getNumCols()}, - rewriter, op->getLoc()); + optimizeEwModOp(lhs, rhs, {lhsTensor.getNumRows(), lhsTensor.getNumCols()}, rewriter, op->getLoc()); else - lowerEwModOp(lhs, rhs, - {lhsTensor.getNumRows(), lhsTensor.getNumCols()}, - rewriter, op->getLoc()); + lowerEwModOp(lhs, rhs, {lhsTensor.getNumRows(), lhsTensor.getNumCols()}, rewriter, op->getLoc()); - mlir::Value output = convertMemRefToDenseMatrix(op->getLoc(), rewriter, - lhs, op.getType()); + mlir::Value output = convertMemRefToDenseMatrix(op->getLoc(), rewriter, lhs, op.getType()); rewriter.replaceOp(op, output); return success(); } @@ -171,15 +137,12 @@ namespace { * If possible, we additionally perform the integer modulo optimization by * replacing the modulo with a bitwise AND and a
subtraction. */ -struct ModOpLoweringPass - : public mlir::PassWrapper> { +struct ModOpLoweringPass : public mlir::PassWrapper> { explicit ModOpLoweringPass() {} void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() final; @@ -190,7 +153,7 @@ struct ModOpLoweringPass "and performing the mod op on values loaded from a MemRef."; } }; -} // end anonymous namespace +} // end anonymous namespace void ModOpLoweringPass::runOnOperation() { mlir::ConversionTarget target(getContext()); @@ -221,6 +184,4 @@ void ModOpLoweringPass::runOnOperation() { } } -std::unique_ptr mlir::daphne::createModOpLoweringPass() { - return std::make_unique(); -} +std::unique_ptr mlir::daphne::createModOpLoweringPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/PhyOperatorSelectionPass.cpp b/src/compiler/lowering/PhyOperatorSelectionPass.cpp index 0ac0de062..798784df9 100644 --- a/src/compiler/lowering/PhyOperatorSelectionPass.cpp +++ b/src/compiler/lowering/PhyOperatorSelectionPass.cpp @@ -14,9 +14,9 @@ * limitations under the License. */ -#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" +#include #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -25,10 +25,10 @@ #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" #include @@ -38,25 +38,23 @@ using namespace mlir; class MatMulOpLowering : public OpConversionPattern { -public: + public: using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(daphne::MatMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - if(auto to = lhs.getDefiningOp()) { + if (auto to = lhs.getDefiningOp()) { bool rhsTransposed = CompilerUtils::constantOrThrow( - op.getTransb(), "MatMulOp.getTransb() is expected to be a constant" - ); - if(to.getArg() == rhs && !rhsTransposed) { + op.getTransb(), "MatMulOp.getTransb() is expected to be a constant"); + if (to.getArg() == rhs && !rhsTransposed) { // `t(M) @ M` -> `syrk(M)` rewriter.replaceOpWithNewOp(op, op.getResult().getType(), rhs); return success(); } auto rhsMatTy = rhs.getType().dyn_cast(); - if((!rhsTransposed && rhsMatTy.getNumCols() == 1) || (rhsTransposed && rhsMatTy.getNumRows() == 1)) { + if ((!rhsTransposed && rhsMatTy.getNumCols() == 1) || (rhsTransposed && rhsMatTy.getNumRows() == 1)) { // `t(M) @ v` -> `gemv(M, v)` rewriter.replaceOpWithNewOp(op, op.getResult().getType(), to.getArg(), rhs); return success(); @@ -67,11 +65,10 @@ class MatMulOpLowering : public OpConversionPattern { }; namespace { - struct PhyOperatorSelectionPass - : public PassWrapper> { - explicit PhyOperatorSelectionPass() { } - void runOnOperation() final; - }; +struct PhyOperatorSelectionPass : public PassWrapper> { + explicit PhyOperatorSelectionPass() {} + void runOnOperation() final; +}; } // end anonymous namespace void PhyOperatorSelectionPass::runOnOperation() { @@ 
-93,26 +90,22 @@ void PhyOperatorSelectionPass::runOnOperation() { // (see MatMulOp::canonicalize()), once we do it there again, we need // to account for it here (and above in MatMulOpLowering). auto to = op.getLhs().getDefiningOp(); - bool rhsTransposed = CompilerUtils::constantOrThrow( - op.getTransb(), "MatMulOp.getTransb() is expected to be a constant" - ); + bool rhsTransposed = + CompilerUtils::constantOrThrow(op.getTransb(), "MatMulOp.getTransb() is expected to be a constant"); auto rhsMatTy = op.getRhs().getType().dyn_cast(); - return !(to && ( - // `t(M) @ M` -> `syrk(M)` - (to.getArg() == op.getRhs() && !rhsTransposed) || - // `t(M) @ v` -> `gemv(M, v)` - (!rhsTransposed && rhsMatTy.getNumCols() == 1) || - (rhsTransposed && rhsMatTy.getNumRows() == 1) - )); + return !(to && + ( + // `t(M) @ M` -> `syrk(M)` + (to.getArg() == op.getRhs() && !rhsTransposed) || + // `t(M) @ v` -> `gemv(M, v)` + (!rhsTransposed && rhsMatTy.getNumCols() == 1) || (rhsTransposed && rhsMatTy.getNumRows() == 1))); }); RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - if(failed(applyPartialConversion(module, target, std::move(patterns)))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } -std::unique_ptr daphne::createPhyOperatorSelectionPass() { - return std::make_unique(); -} \ No newline at end of file +std::unique_ptr daphne::createPhyOperatorSelectionPass() { return std::make_unique(); } \ No newline at end of file diff --git a/src/compiler/lowering/ProfilingPass.cpp b/src/compiler/lowering/ProfilingPass.cpp index beda1edc0..d91a28c27 100644 --- a/src/compiler/lowering/ProfilingPass.cpp +++ b/src/compiler/lowering/ProfilingPass.cpp @@ -26,16 +26,14 @@ using namespace mlir; /** * @brief Inserts profiling tracepoints */ -struct ProfilingPass: public PassWrapper> -{ +struct ProfilingPass : public PassWrapper> { explicit ProfilingPass() {} void runOnOperation() final; }; -void ProfilingPass::runOnOperation() -{ +void ProfilingPass::runOnOperation() { func::FuncOp f = getOperation(); - Block & b = f.getBody().front(); + Block &b = f.getBody().front(); OpBuilder builder(&b, b.begin()); Location loc = builder.getUnknownLoc(); @@ -45,7 +43,4 @@ void ProfilingPass::runOnOperation() builder.create(loc); } -std::unique_ptr daphne::createProfilingPass() -{ - return std::make_unique(); -} +std::unique_ptr daphne::createProfilingPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/RewriteSqlOpPass.cpp b/src/compiler/lowering/RewriteSqlOpPass.cpp index 832d3a82c..02509516e 100644 --- a/src/compiler/lowering/RewriteSqlOpPass.cpp +++ b/src/compiler/lowering/RewriteSqlOpPass.cpp @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" +#include #include #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,77 +23,67 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" -#include -#include #include +#include #include #include #include #include +#include using namespace mlir; -namespace -{ - - std::unordered_map tables; - struct SqlReplacement : public RewritePattern{ - - SqlReplacement(MLIRContext * context, PatternBenefit benefit = 1) - : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, context) - {} - - LogicalResult matchAndRewrite( - Operation *op, - PatternRewriter &rewriter - ) const override - { - if(auto rOp = llvm::dyn_cast(op)){ - std::stringstream view_stream; - view_stream << rOp.getView().str(); - mlir::Value arg = rOp.getArg(); - - tables[view_stream.str()] = arg; - rewriter.eraseOp(op); - return success(); - }else if(auto sqlop = llvm::dyn_cast(op)){ - std::stringstream sql_query; - sql_query << sqlop.getSql().str(); - - SQLParser parser; - parser.setView(tables); - parser.setSqlOp(sqlop); - std::string sourceName; - llvm::raw_string_ostream ss(sourceName); - ss << sqlop->getLoc(); - mlir::Value result_op; - try { - result_op = parser.parseStreamFrame(rewriter, sql_query, sourceName); - } catch (std::runtime_error &re) { - throw ErrorHandler::rethrowError("RewriteSqlOpPass", - re.what()); - } - rewriter.replaceOp(op, result_op); - // TODO Why is this necessary when we have already replaced the op? - rewriter.replaceAllUsesWith(op->getResult(0), result_op); - return success(); +namespace { + +std::unordered_map tables; +struct SqlReplacement : public RewritePattern { + + SqlReplacement(MLIRContext *context, PatternBenefit benefit = 1) + : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, context) {} + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + if (auto rOp = llvm::dyn_cast(op)) { + std::stringstream view_stream; + view_stream << rOp.getView().str(); + mlir::Value arg = rOp.getArg(); + + tables[view_stream.str()] = arg; + rewriter.eraseOp(op); + return success(); + } else if (auto sqlop = llvm::dyn_cast(op)) { + std::stringstream sql_query; + sql_query << sqlop.getSql().str(); + + SQLParser parser; + parser.setView(tables); + parser.setSqlOp(sqlop); + std::string sourceName; + llvm::raw_string_ostream ss(sourceName); + ss << sqlop->getLoc(); + mlir::Value result_op; + try { + result_op = parser.parseStreamFrame(rewriter, sql_query, sourceName); + } catch (std::runtime_error &re) { + throw ErrorHandler::rethrowError("RewriteSqlOpPass", re.what()); } - return failure(); + rewriter.replaceOp(op, result_op); + // TODO Why is this necessary when we have already replaced the op? 
+ rewriter.replaceAllUsesWith(op->getResult(0), result_op); + return success(); } - }; + return failure(); + } +}; - struct RewriteSqlOpPass - : public PassWrapper > - { - void runOnOperation() final; +struct RewriteSqlOpPass : public PassWrapper> { + void runOnOperation() final; StringRef getArgument() const final { return "rewrite-sqlop"; } StringRef getDescription() const final { return "TODO"; } - }; -} +}; +} // namespace -void RewriteSqlOpPass::runOnOperation() -{ +void RewriteSqlOpPass::runOnOperation() { auto module = getOperation(); RewritePatternSet patterns(&getContext()); @@ -108,7 +98,4 @@ void RewriteSqlOpPass::runOnOperation() signalPassFailure(); } -std::unique_ptr daphne::createRewriteSqlOpPass() -{ - return std::make_unique(); -} +std::unique_ptr daphne::createRewriteSqlOpPass() { return std::make_unique(); } diff --git a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp index 8f41d8f06..f99c9e472 100644 --- a/src/compiler/lowering/RewriteToCallKernelOpPass.cpp +++ b/src/compiler/lowering/RewriteToCallKernelOpPass.cpp @@ -15,11 +15,11 @@ */ #include "compiler/utils/CompilerUtils.h" -#include -#include -#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" +#include +#include +#include #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -29,683 +29,574 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/IR/IRMapping.h" -#include -#include #include +#include #include #include #include #include #include +#include #include using namespace mlir; -namespace -{ - class KernelReplacement : public RewritePattern - { - // TODO This method is only required since MLIR does not seem to - // provide a means to get this information. - static size_t getNumODSOperands(Operation * op) { - if(llvm::isa(op)) - return 4; - if(llvm::isa(op)) - return 4; - if(llvm::isa(op)) - return 3; - if(llvm::isa(op)) - return 2; - if(llvm::isa(op)) - return 1; - - throw ErrorHandler::compilerError( - op, "RewriteToCallKernelOpPass", - "lowering to kernel call not yet supported for this variadic " - "operation: " + - op->getName().getStringRef().str()); +namespace { +class KernelReplacement : public RewritePattern { + // TODO This method is only required since MLIR does not seem to + // provide a means to get this information. + static size_t getNumODSOperands(Operation *op) { + if (llvm::isa(op)) + return 4; + if (llvm::isa(op)) + return 4; + if (llvm::isa(op)) + return 3; + if (llvm::isa(op)) + return 2; + if (llvm::isa(op)) + return 1; + + throw ErrorHandler::compilerError(op, "RewriteToCallKernelOpPass", + "lowering to kernel call not yet supported for this variadic " + "operation: " + + op->getName().getStringRef().str()); + } + + // TODO This method is only required since MLIR does not seem to + // provide a means to get this information. But, for instance, the + // isVariadic boolean array is automatically generated *within* the + // getODSOperandIndexAndLength method. + static std::tuple getODSOperandInfo(Operation *op, unsigned index) { + // TODO Simplify those by a macro. 
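One way the "Simplify those by a macro" TODO below could be addressed is a case macro that stamps out the repeated dyn_cast blocks; this is only a sketch, the macro name is made up here, and it assumes the per-ODS-operand variadic flags stay hard-coded:

#define ODS_OPERAND_INFO_CASE(OP, ...)                                         \
    if (auto concreteOp = llvm::dyn_cast<OP>(op)) {                            \
        auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index);        \
        static bool isVariadic[] = {__VA_ARGS__};                              \
        return std::make_tuple(idxAndLen.first, idxAndLen.second,              \
                               isVariadic[index]);                             \
    }
// Hypothetical usage, for an op whose ODS operands are
// (non-variadic, variadic, variadic):
// ODS_OPERAND_INFO_CASE(SomeOp, false, true, true)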
+ if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {true, true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {false, true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {false, true, true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {false, false, true, true}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + if (auto concreteOp = llvm::dyn_cast(op)) { + auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); + static bool isVariadic[] = {false, true, true, false}; + return std::make_tuple(idxAndLen.first, idxAndLen.second, isVariadic[index]); + } + throw ErrorHandler::compilerError(op, "RewriteToCallKernelOpPass", + "lowering to kernel call not yet supported for this variadic " + "operation: " + + op->getName().getStringRef().str()); + } + + /** + * @brief The value of type `DaphneContext` to insert as the last + * argument to all kernel calls. + */ + Value dctx; + + const DaphneUserConfig &userConfig; + std::unordered_map &usedLibPaths; + + mlir::Type adaptType(mlir::Type t, bool generalizeToStructure) const { + MLIRContext *mctx = t.getContext(); + if (generalizeToStructure && t.isa()) + return mlir::daphne::StructureType::get(mctx); + if (auto mt = t.dyn_cast()) + return mt.withSameElementTypeAndRepr(); + if (t.isa()) + return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); + if (auto lt = t.dyn_cast()) + return mlir::daphne::ListType::get(mctx, adaptType(lt.getElementType(), generalizeToStructure)); + if (auto mrt = t.dyn_cast()) + // Remove any dimension information ({0, 0}), but retain the element + // type. + return mlir::MemRefType::get({0, 0}, mrt.getElementType()); + return t; + } + + public: + /** + * Creates a new KernelReplacement rewrite pattern. + * + * @param mctx The MLIR context. + * @param dctx The DaphneContext to pass to the kernels. + * @param userConfig The user config. + * @param benefit + */ + KernelReplacement(MLIRContext *mctx, Value dctx, const DaphneUserConfig &userConfig, + std::unordered_map &usedLibPaths, PatternBenefit benefit = 1) + : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, mctx), dctx(dctx), userConfig(userConfig), + usedLibPaths(usedLibPaths) {} + + /** + * @brief Rewrites the given operation to a `CallKernelOp`. 
+ *
+ * This involves looking up a matching kernel from the kernel catalog based
+ * on the mnemonic, argument/result types, and backend (e.g., hardware
+ * accelerator) of the given operation. Variadic operands are also taken
+ * into account.
+ *
+ * @param op The operation to rewrite.
+ * @param rewriter The rewriter.
+ * @return Always returns `mlir::success()` unless an exception is thrown.
+ */
+ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+
+ // The argument/result types of the given operation.
+ Operation::operand_type_range opArgTys = op->getOperandTypes();
+ Operation::result_type_range opResTys = op->getResultTypes();
+
+ // The argument/result types to use for kernel look-up.
+ std::vector lookupArgTys;
+ std::vector lookupResTys;
+ // Differences between op argument types and look-up argument types:
+ // - The look-up argument types summarize n occurrences of a variadic
+ // operand into
+ // one variadic pack and one number of occurrences.
+ // - The look-up argument types omit most of the properties of the op
+ // argument types,
+ // because those would complicate the search for matching kernels.
+ // Differences between op result types and look-up result types:
+ // - The look-up result types omit most of the properties of the op
+ // result types,
+ // because those would complicate the search for matching kernels.
+
+ // The operands to use for the CallKernelOp to be created. These may
+ // differ from the operands of the given operation, if it has a variadic
+ // operand.
+ std::vector kernelArgs;
+
+ // *****************************************************************************
+ // Prepare the kernel look-up and the creation of the CallKernelOp
+ // *****************************************************************************
+ // Determine the argument/result types for the kernel look-up as well as
+ // the arguments of the CallKernelOp to be created. Variadic operands
+ // are taken into account.
+
+ // Find out if argument types shall be generalized from matrix/frame to
+ // the supertype structure.
+ // TODO Don't enumerate all ops, decide based on a trait.
+ const bool generalizeInputTypes =
+ llvm::isa(op) || llvm::isa(op) ||
+ llvm::isa(op) || llvm::isa(op) || llvm::isa(op) ||
+ llvm::isa(op) || llvm::isa(op);
+
+ // Append converted op result types to the look-up result types.
+ for (size_t i = 0; i < opResTys.size(); i++)
+ lookupResTys.push_back(adaptType(opResTys[i], false));
+
+ // Append converted op argument types to the look-up argument types.
+ // Variadic operands, which can have an arbitrary number of occurrences,
+ // are treated specially.
+ if (
+ // TODO Unfortunately, one needs to know the exact N for
+ // AtLeastNOperands... There seems to be no simple way to
+ // detect if an operation has variadic ODS operands with any N.
+ op->hasTrait() || op->hasTrait::Impl>() ||
+ op->hasTrait::Impl>()) {
+ // For operations with variadic ODS operands, we replace all
+ // occurrences of a variadic ODS operand by a single operand of
+ // type VariadicPack as well as an operand for the number of
+ // occurrences. All occurrences of the variadic ODS operand are
+ // stored in the VariadicPack.
+ // Note that a variadic ODS operand may have zero occurrences.
+ // In that case, there is no operand corresponding to the
+ // variadic ODS operand.
+ const size_t numODSOperands = getNumODSOperands(op); + for (size_t i = 0; i < numODSOperands; i++) { + auto odsOpInfo = getODSOperandInfo(op, i); + const unsigned idx = std::get<0>(odsOpInfo); + const unsigned len = std::get<1>(odsOpInfo); + const bool isVariadic = std::get<2>(odsOpInfo); + + // Determine the MLIR type of the current ODS operand. + Type odsOperandTy; + if (len > 0) { + // If the current ODS operand has occurrences, then + // we use the type of the first operand belonging to + // the current ODS operand. + odsOperandTy = opArgTys[idx]; + } else { // len == 0 + // If the current ODS operand does not have any occurrences + // (e.g., a variadic ODS operand with zero concrete operands + // provided), then we cannot derive the type of the + // current ODS operand from any given operand. Instead, + // we use a default type depending on which ODS operand of + // which operation it is. + // Note that we cannot simply omit the type, since the + // underlying kernel expects an "empty list" (represented + // in the DAPHNE compiler by an empty VariadicPack). + if (llvm::dyn_cast(op) && i == 2) + // A GroupOp may have zero aggregation column names. + odsOperandTy = daphne::StringType::get(rewriter.getContext()); + else + throw std::runtime_error("RewriteToCallKernelOpPass encountered a variadic " + "ODS operand with zero occurrences, " + "but does not know how to handle it: ODS operand " + + std::to_string(i) + " of operation " + + op->getName().getStringRef().str()); + } - // TODO This method is only required since MLIR does not seem to - // provide a means to get this information. But, for instance, the - // isVariadic boolean array is automatically generated *within* the - // getODSOperandIndexAndLength method. - static std::tuple getODSOperandInfo(Operation * op, unsigned index) { - // TODO Simplify those by a macro. - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {true, true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); - } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); + lookupArgTys.push_back(adaptType(odsOperandTy, generalizeInputTypes)); + + if (isVariadic) { + // Variadic operand. + lookupArgTys.push_back(rewriter.getIndexType()); + auto cvpOp = rewriter.create( + loc, daphne::VariadicPackType::get(rewriter.getContext(), odsOperandTy), + rewriter.getI64IntegerAttr(len)); + for (int64_t k = 0; k < len; k++) + rewriter.create(loc, cvpOp, op->getOperand(idx + k), + rewriter.getI64IntegerAttr(k)); + kernelArgs.push_back(cvpOp); + kernelArgs.push_back( + rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(len))); + } else + // Non-variadic operand. + kernelArgs.push_back(op->getOperand(idx)); } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {false, true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); + } else + // For operations without variadic operands, we simply append + // the type of each operand to the vector of types to use for + // kernel look-up, and pass all operands to the CallKernelOp as-is. 
+ for (size_t i = 0; i < opArgTys.size(); i++) { + lookupArgTys.push_back(adaptType(opArgTys[i], generalizeInputTypes)); + kernelArgs.push_back(op->getOperand(i)); } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); - } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {false, true, true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); - } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {false, false, true, true}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); - } - if(auto concreteOp = llvm::dyn_cast(op)) { - auto idxAndLen = concreteOp.getODSOperandIndexAndLength(index); - static bool isVariadic[] = {false, true, true, false}; - return std::make_tuple( - idxAndLen.first, - idxAndLen.second, - isVariadic[index] - ); - } - throw ErrorHandler::compilerError( - op, "RewriteToCallKernelOpPass", - "lowering to kernel call not yet supported for this variadic " - "operation: " + - op->getName().getStringRef().str()); - } - /** - * @brief The value of type `DaphneContext` to insert as the last - * argument to all kernel calls. - */ - Value dctx; - - const DaphneUserConfig & userConfig; - std::unordered_map & usedLibPaths; - - mlir::Type adaptType(mlir::Type t, bool generalizeToStructure) const { - MLIRContext * mctx = t.getContext(); - if(generalizeToStructure && t.isa()) - return mlir::daphne::StructureType::get(mctx); - if(auto mt = t.dyn_cast()) - return mt.withSameElementTypeAndRepr(); - if(t.isa()) - return mlir::daphne::FrameType::get(mctx, {mlir::daphne::UnknownType::get(mctx)}); - if(auto lt = t.dyn_cast()) - return mlir::daphne::ListType::get(mctx, adaptType(lt.getElementType(), generalizeToStructure)); - if(auto mrt = t.dyn_cast()) - // Remove any dimension information ({0, 0}), but retain the element type. - return mlir::MemRefType::get({0, 0}, mrt.getElementType()); - return t; + if (auto groupOp = llvm::dyn_cast(op)) { + // GroupOp carries the aggregation functions to apply as an + // attribute. Since attributes do not automatically become + // inputs to the kernel call, we need to add them explicitly + // here. + + ArrayAttr aggFuncs = groupOp.getAggFuncs(); + const size_t numAggFuncs = aggFuncs.size(); + const Type t = rewriter.getIntegerType(32, false); + auto cvpOp = rewriter.create( + loc, daphne::VariadicPackType::get(rewriter.getContext(), t), rewriter.getI64IntegerAttr(numAggFuncs)); + size_t k = 0; + for (Attribute aggFunc : aggFuncs.getValue()) + rewriter.create( + loc, cvpOp, + rewriter.create( + loc, t, + rewriter.getIntegerAttr( + t, static_cast(aggFunc.dyn_cast().getValue()))), + rewriter.getI64IntegerAttr(k++)); + kernelArgs.push_back(cvpOp); + kernelArgs.push_back( + rewriter.create(loc, rewriter.getIndexType(), rewriter.getIndexAttr(numAggFuncs))); } - public: - /** - * Creates a new KernelReplacement rewrite pattern. - * - * @param mctx The MLIR context. - * @param dctx The DaphneContext to pass to the kernels. - * @param userConfig The user config. 
- * @param benefit - */ - KernelReplacement( - MLIRContext * mctx, - Value dctx, - const DaphneUserConfig & userConfig, - std::unordered_map & usedLibPaths, - PatternBenefit benefit = 1 - ) - : RewritePattern(Pattern::MatchAnyOpTypeTag(), benefit, mctx), - dctx(dctx), userConfig(userConfig), usedLibPaths(usedLibPaths) - { + if (auto thetaJoinOp = llvm::dyn_cast(op)) { + // ThetaJoinOp carries multiple CompareOperation as an + // attribute. Since attributes do not automatically become + // inputs to the kernel call, we need to add them explicitly + // here. + + // get array of CompareOperations + ArrayAttr compareOperations = thetaJoinOp.getCmp(); + const size_t numCompareOperations = compareOperations.size(); + const Type t = rewriter.getIntegerType(32, false); + // create Variadic Pack + auto cvpOp = rewriter.create( + loc, daphne::VariadicPackType::get(rewriter.getContext(), t), + rewriter.getI64IntegerAttr(numCompareOperations)); + // fill variadic pack + size_t k = 0; + for (Attribute compareOperation : compareOperations.getValue()) + rewriter.create( + loc, cvpOp, + rewriter.create( + loc, t, + rewriter.getIntegerAttr( + t, static_cast( + compareOperation.dyn_cast().getValue()))), + rewriter.getI64IntegerAttr(k++)); + // add created variadic pack and size of this pack as + // new operands / parameters of the ThetaJoin-Kernel call + kernelArgs.push_back(cvpOp); + kernelArgs.push_back(rewriter.create(loc, rewriter.getIndexType(), + rewriter.getIndexAttr(numCompareOperations))); } - /** - * @brief Rewrites the given operation to a `CallKernelOp`. - * - * This involves looking up a matching kernel from the kernel catalog based on the - * mnemonic, argument/result types, and backend (e.g., hardware accelerator) of the - * given operation. Variadic operands are also taken into account. - * - * @param op The operation to rewrite. - * @param rewriter The rewriter. - * @result Always returns `mlir::success()` unless an exception is thrown. - */ - LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - - // The argument/result types of the given operation. - Operation::operand_type_range opArgTys = op->getOperandTypes(); - Operation::result_type_range opResTys = op->getResultTypes(); - - // The argument/result types to use for kernel look-up. - std::vector lookupArgTys; - std::vector lookupResTys; - // Differences between op argument types and look-up argument types: - // - The look-up argument types summarize n occurrences of a variadic operand into - // one variadic pack and one number of occurrences. - // - The look-up argument types omit most of the properties of the op argument types, - // because those would complicate the search for matching kernels. - // Differences between op result types and look-up result types: - // - The look-up result types omit most of the properties of the op result types, - // because those would complicate the search for matching kernels. - - // The operands to use for the CallKernelOp to be created. These may differ from - // the operands of the given operation, if it has a variadic operand. - std::vector kernelArgs; - - // ***************************************************************************** - // Prepare the kernel look-up and the creation of the CallKernelOp - // ***************************************************************************** - // Determine the argument/result types for the kernel look-up as well as - // the arguments of the CallKernelOp to be created. 
Variadic operands are taken - // into account. - - // Find out if argument types shall the generalized from matrix/frame to the - // supertype structure. - // TODO Don't enumerate all ops, decide based on a trait. - const bool generalizeInputTypes = - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op); - - // Append converted op result types to the look-up result types. - for(size_t i = 0; i < opResTys.size(); i++) - lookupResTys.push_back(adaptType(opResTys[i], false)); - - // Append converted op argument types to the look-up argument types. - // Variadic operands, which can have an arbitrary number of occurrences, are - // treated specially. - if( - // TODO Unfortunately, one needs to know the exact N for - // AtLeastNOperands... There seems to be no simple way to - // detect if an operation has variadic ODS operands with any N. - op->hasTrait() || - op->hasTrait::Impl>() || - op->hasTrait::Impl>() - ) { - // For operations with variadic ODS operands, we replace all - // occurrences of a variadic ODS operand by a single operand of - // type VariadicPack as well as an operand for the number of - // occurrences. All occurrences of the variadic ODS operand are - // stored in the VariadicPack. - // Note that a variadic ODS operand may have zero occurrences. - // In that case, there is no operand corresponding to the - // variadic ODS operand. - const size_t numODSOperands = getNumODSOperands(op); - for(size_t i = 0; i < numODSOperands; i++) { - auto odsOpInfo = getODSOperandInfo(op, i); - const unsigned idx = std::get<0>(odsOpInfo); - const unsigned len = std::get<1>(odsOpInfo); - const bool isVariadic = std::get<2>(odsOpInfo); - - // Determine the MLIR type of the current ODS operand. - Type odsOperandTy; - if(len > 0) { - // If the current ODS operand has occurrences, then - // we use the type of the first operand belonging to - // the current ODS operand. - odsOperandTy = opArgTys[idx]; - } - else { // len == 0 - // If the current ODS operand does not have any occurrences - // (e.g., a variadic ODS operand with zero concrete operands - // provided), then we cannot derive the type of the - // current ODS operand from any given operand. Instead, - // we use a default type depending on which ODS operand of - // which operation it is. - // Note that we cannot simply omit the type, since the - // underlying kernel expects an "empty list" (represented - // in the DAPHNE compiler by an empty VariadicPack). - if(llvm::dyn_cast(op) && i == 2) - // A GroupOp may have zero aggregation column names. - odsOperandTy = daphne::StringType::get(rewriter.getContext()); - else - throw std::runtime_error( - "RewriteToCallKernelOpPass encountered a variadic ODS operand with zero occurrences, " - "but does not know how to handle it: ODS operand " + std::to_string(i) + - " of operation " + op->getName().getStringRef().str() - ); - } - - lookupArgTys.push_back(adaptType(odsOperandTy, generalizeInputTypes)); - - if(isVariadic) { - // Variadic operand. 
- lookupArgTys.push_back(rewriter.getIndexType()); - auto cvpOp = rewriter.create( - loc, - daphne::VariadicPackType::get( - rewriter.getContext(), - odsOperandTy - ), - rewriter.getI64IntegerAttr(len) - ); - for(int64_t k = 0; k < len; k++) - rewriter.create( - loc, - cvpOp, - op->getOperand(idx + k), - rewriter.getI64IntegerAttr(k) - ); - kernelArgs.push_back(cvpOp); - kernelArgs.push_back(rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(len) - )); - } - else - // Non-variadic operand. - kernelArgs.push_back(op->getOperand(idx)); - } - } - else - // For operations without variadic operands, we simply append - // the type of each operand to the vector of types to use for - // kernel look-up, and pass all operands to the CallKernelOp as-is. - for(size_t i = 0; i < opArgTys.size(); i++) { - lookupArgTys.push_back(adaptType(opArgTys[i], generalizeInputTypes)); - kernelArgs.push_back(op->getOperand(i)); - } + if (auto distCompOp = llvm::dyn_cast(op)) { + MLIRContext newContext; // TODO Reuse the existing context. + OpBuilder tempBuilder(&newContext); + std::string funcName = "dist"; - if(auto groupOp = llvm::dyn_cast(op)) { - // GroupOp carries the aggregation functions to apply as an - // attribute. Since attributes do not automatically become - // inputs to the kernel call, we need to add them explicitly - // here. - - ArrayAttr aggFuncs = groupOp.getAggFuncs(); - const size_t numAggFuncs = aggFuncs.size(); - const Type t = rewriter.getIntegerType(32, false); - auto cvpOp = rewriter.create( - loc, - daphne::VariadicPackType::get(rewriter.getContext(), t), - rewriter.getI64IntegerAttr(numAggFuncs) - ); - size_t k = 0; - for(Attribute aggFunc : aggFuncs.getValue()) - rewriter.create( - loc, - cvpOp, - rewriter.create( - loc, - t, - rewriter.getIntegerAttr( - t, - static_cast( - aggFunc.dyn_cast().getValue() - ) - ) - ), - rewriter.getI64IntegerAttr(k++) - ); - kernelArgs.push_back(cvpOp); - kernelArgs.push_back(rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(numAggFuncs)) - ); - } - - if(auto thetaJoinOp = llvm::dyn_cast(op)) { - // ThetaJoinOp carries multiple CompareOperation as an - // attribute. Since attributes do not automatically become - // inputs to the kernel call, we need to add them explicitly - // here. 
- - // get array of CompareOperations - ArrayAttr compareOperations = thetaJoinOp.getCmp(); - const size_t numCompareOperations = compareOperations.size(); - const Type t = rewriter.getIntegerType(32, false); - // create Variadic Pack - auto cvpOp = rewriter.create( - loc, - daphne::VariadicPackType::get(rewriter.getContext(), t), - rewriter.getI64IntegerAttr(numCompareOperations) - ); - // fill variadic pack - size_t k = 0; - for(Attribute compareOperation : compareOperations.getValue()) - rewriter.create( - loc, - cvpOp, - rewriter.create( - loc, - t, - rewriter.getIntegerAttr( - t, - static_cast( - compareOperation.dyn_cast().getValue() - ) - ) - ), - rewriter.getI64IntegerAttr(k++) - ); - // add created variadic pack and size of this pack as - // new operands / parameters of the ThetaJoin-Kernel call - kernelArgs.push_back(cvpOp); - kernelArgs.push_back(rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(numCompareOperations)) - ); - } + auto &bodyBlock = distCompOp.getBody().front(); + auto funcType = + tempBuilder.getFunctionType(bodyBlock.getArgumentTypes(), bodyBlock.getTerminator()->getOperandTypes()); + auto funcOp = tempBuilder.create(loc, funcName, funcType); - if(auto distCompOp = llvm::dyn_cast(op)) { - MLIRContext newContext; // TODO Reuse the existing context. - OpBuilder tempBuilder(&newContext); - std::string funcName = "dist"; - - auto &bodyBlock = distCompOp.getBody().front(); - auto funcType = tempBuilder.getFunctionType( - bodyBlock.getArgumentTypes(), bodyBlock.getTerminator()->getOperandTypes()); - auto funcOp = tempBuilder.create(loc, funcName, funcType); - - IRMapping mapper; - distCompOp.getBody().cloneInto(&funcOp.getRegion(), mapper); - - // write recompile region as string constant - std::string s; - llvm::raw_string_ostream stream(s); - funcOp.print(stream); - - auto strTy = daphne::StringType::get(rewriter.getContext()); - Value - rewriteStr = rewriter.create(loc, strTy, rewriter.getStringAttr(stream.str())); - lookupArgTys.push_back(mlir::daphne::StringType::get(&newContext)); - kernelArgs.push_back(rewriteStr); - } + IRMapping mapper; + distCompOp.getBody().cloneInto(&funcOp.getRegion(), mapper); + + // write recompile region as string constant + std::string s; + llvm::raw_string_ostream stream(s); + funcOp.print(stream); + + auto strTy = daphne::StringType::get(rewriter.getContext()); + Value rewriteStr = rewriter.create(loc, strTy, rewriter.getStringAttr(stream.str())); + lookupArgTys.push_back(mlir::daphne::StringType::get(&newContext)); + kernelArgs.push_back(rewriteStr); + } - // ***************************************************************************** - // Look up a matching kernel from the kernel catalog. - // ***************************************************************************** - - const KernelCatalog & kc = userConfig.kernelCatalog; - const std::string opMnemonic = op->getName().stripDialect().data(); - std::vector kernelInfos = kc.getKernelInfos(opMnemonic); - - std::string libPath; - std::string kernelFuncName; - // TODO Don't hardcode the attribute name, put it in a central place. - if(op->hasAttr("kernel_hint")) { - // The operation has a kernel hint. Lower to the hinted kernel if possible. - - // TODO Check if the attribute has the right type. 
- kernelFuncName = op->getAttrOfType("kernel_hint").getValue().str(); - bool found = false; - for(size_t i = 0; i < kernelInfos.size() && !found; i++) { - auto ki = kernelInfos[i]; - if(ki.kernelFuncName == kernelFuncName) { - libPath = ki.libPath; - found = true; - } + // ***************************************************************************** + // Look up a matching kernel from the kernel catalog. + // ***************************************************************************** + + const KernelCatalog &kc = userConfig.kernelCatalog; + const std::string opMnemonic = op->getName().stripDialect().data(); + std::vector kernelInfos = kc.getKernelInfos(opMnemonic); + + std::string libPath; + std::string kernelFuncName; + // TODO Don't hardcode the attribute name, put it in a central place. + if (op->hasAttr("kernel_hint")) { + // The operation has a kernel hint. Lower to the hinted kernel if + // possible. + + // TODO Check if the attribute has the right type. + kernelFuncName = op->getAttrOfType("kernel_hint").getValue().str(); + bool found = false; + for (size_t i = 0; i < kernelInfos.size() && !found; i++) { + auto ki = kernelInfos[i]; + if (ki.kernelFuncName == kernelFuncName) { + libPath = ki.libPath; + found = true; } - if(!found) - throw ErrorHandler::compilerError( - loc, - "RewriteToCallKernelOpPass", - "no kernel found for operation `" + opMnemonic + - "` with hinted name `" + kernelFuncName + "`" - ); } - else { - // The operation does not have a kernel hint. Search for a kernel - // for this operation and the given result/argument types and backend. - - if(kernelInfos.empty()) - throw ErrorHandler::compilerError( - loc, - "RewriteToCallKernelOpPass", - "no kernels registered for operation `" + opMnemonic + "`" - ); - - std::string backend; - if(op->hasAttr("cuda_device")) - backend = "CUDA"; - else if(op->hasAttr("fpgaopencl_device")) - backend = "FPGAOPENCL"; - else - backend = "CPP"; - - const size_t numArgs = lookupArgTys.size(); - const size_t numRess = lookupResTys.size(); - int chosenKernelIdx = -1; - for(size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) { - auto ki = kernelInfos[i]; - if(ki.backend != backend) - continue; - if(numArgs != ki.argTypes.size()) - continue; - if(numRess != ki.resTypes.size()) - continue; - - bool mismatch = false; - for(size_t i = 0; i < numArgs && !mismatch; i++) - if(lookupArgTys[i] != ki.argTypes[i]) - mismatch = true; - for(size_t i = 0; i < numRess && !mismatch; i++) - if(lookupResTys[i] != ki.resTypes[i]) - mismatch = true; - if(!mismatch) - chosenKernelIdx = i; + if (!found) + throw ErrorHandler::compilerError(loc, "RewriteToCallKernelOpPass", + "no kernel found for operation `" + opMnemonic + + "` with hinted name `" + kernelFuncName + "`"); + } else { + // The operation does not have a kernel hint. Search for a kernel + // for this operation and the given result/argument types and + // backend. 
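+ // A kernel matches only if its backend and its exact argument/result
+ // types (after the adaptations above) agree with the look-up types.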
+ + if (kernelInfos.empty()) + throw ErrorHandler::compilerError(loc, "RewriteToCallKernelOpPass", + "no kernels registered for operation `" + opMnemonic + "`"); + + std::string backend; + if (op->hasAttr("cuda_device")) + backend = "CUDA"; + else if (op->hasAttr("fpgaopencl_device")) + backend = "FPGAOPENCL"; + else + backend = "CPP"; + + const size_t numArgs = lookupArgTys.size(); + const size_t numRess = lookupResTys.size(); + int chosenKernelIdx = -1; + for (size_t i = 0; i < kernelInfos.size() && chosenKernelIdx == -1; i++) { + auto ki = kernelInfos[i]; + if (ki.backend != backend) + continue; + if (numArgs != ki.argTypes.size()) + continue; + if (numRess != ki.resTypes.size()) + continue; + + bool mismatch = false; + for (size_t i = 0; i < numArgs && !mismatch; i++) + if (lookupArgTys[i] != ki.argTypes[i]) + mismatch = true; + for (size_t i = 0; i < numRess && !mismatch; i++) + if (lookupResTys[i] != ki.resTypes[i]) + mismatch = true; + if (!mismatch) + chosenKernelIdx = i; + } + if (chosenKernelIdx == -1) { + std::stringstream s; + s << "no kernel for operation `" << opMnemonic << "` available for the required input types `("; + for (size_t i = 0; i < numArgs; i++) { + s << lookupArgTys[i]; + if (i < numArgs - 1) + s << ", "; } - if(chosenKernelIdx == -1) { - std::stringstream s; - s << "no kernel for operation `" << opMnemonic - << "` available for the required input types `("; - for(size_t i = 0; i < numArgs; i++) { - s << lookupArgTys[i]; - if(i < numArgs - 1) - s << ", "; - } - s << + ")` and output types `("; - for(size_t i = 0; i < numRess; i++) { - s << lookupResTys[i]; - if(i < numRess - 1) - s << ", "; - } - s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl; - kc.dump(opMnemonic, s); - throw ErrorHandler::compilerError(loc, "RewriteToCallKernelOpPass", s.str()); + s << +")` and output types `("; + for (size_t i = 0; i < numRess; i++) { + s << lookupResTys[i]; + if (i < numRess - 1) + s << ", "; } - KernelInfo chosenKI = kernelInfos[chosenKernelIdx]; - libPath = chosenKI.libPath; - kernelFuncName = chosenKI.kernelFuncName; + s << ")` for backend `" << backend << "`, registered kernels for this op:" << std::endl; + kc.dump(opMnemonic, s); + throw ErrorHandler::compilerError(loc, "RewriteToCallKernelOpPass", s.str()); } - - // ***************************************************************************** - // Add kernel id and DAPHNE context as arguments - // ***************************************************************************** - - auto kId = rewriter.create( - loc, rewriter.getI32IntegerAttr( - KernelDispatchMapping::instance().registerKernel( - kernelFuncName, op))); - - // NOTE: kId has to be added before CreateDaphneContextOp because - // there is an assumption that the CTX is the last argument - // (LowerToLLVMPass.cpp::623,702). This means the kId is expected to - // be the second to last argument. - kernelArgs.push_back(kId); - - // Inject the current DaphneContext as the last input parameter to - // all kernel calls, unless it's a CreateDaphneContextOp. - if(!llvm::isa(op)) - kernelArgs.push_back(dctx); - - // ***************************************************************************** - // Create the CallKernelOp - // ***************************************************************************** - - // Mark the shared library the chosen kernel comes from as used. This means we - // will link this library into the JIT-compiled program later. 
- usedLibPaths.at(libPath) = true; - - // Create a CallKernelOp for the kernel function to call and return success(). - auto kernel = rewriter.create( - loc, - kernelFuncName, - kernelArgs, - opResTys - ); - rewriter.replaceOp(op, kernel.getResults()); - return success(); - } - }; - - class DistributedPipelineKernelReplacement : public OpConversionPattern { - Value dctx; - const DaphneUserConfig & userConfig; - std::unordered_map & usedLibPaths; - - public: - using OpConversionPattern::OpConversionPattern; - DistributedPipelineKernelReplacement( - MLIRContext * mctx, - Value dctx, - const DaphneUserConfig & userConfig, - std::unordered_map & usedLibPaths, - PatternBenefit benefit = 2 - ) - : OpConversionPattern(mctx, benefit), - dctx(dctx), userConfig(userConfig), usedLibPaths(usedLibPaths) - { + KernelInfo chosenKI = kernelInfos[chosenKernelIdx]; + libPath = chosenKI.libPath; + kernelFuncName = chosenKI.kernelFuncName; } - LogicalResult matchAndRewrite(daphne::DistributedPipelineOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override - { - size_t numOutputs = op.getOutputs().size(); - size_t numInputs = op.getInputs().size(); - - - std::stringstream callee; - callee << "_distributedPipeline"; // kernel name - callee << "__DenseMatrix_double_variadic" // outputs - << "__size_t" // numOutputs - << "__Structure_variadic" // inputs - << "__size_t" // numInputs - << "__int64_t" // outRows - << "__int64_t" // outCols - << "__int64_t" // splits - << "__int64_t" // combines - << "__char"; // irCode - - MLIRContext* mctx = rewriter.getContext(); - - Location loc = op.getLoc(); - Type vptObj = daphne::VariadicPackType::get(mctx, daphne::MatrixType::get(mctx, rewriter.getF64Type())); - Type vptSize = daphne::VariadicPackType::get(mctx, rewriter.getIntegerType(64, false)); - Type vptInt64 = daphne::VariadicPackType::get(mctx, rewriter.getIntegerType(64, true)); - - // Variadic pack for inputs. - auto cvpInputs = rewriter.create(loc, vptObj, rewriter.getI64IntegerAttr(numInputs)); - for(size_t i = 0; i < numInputs; i++) - rewriter.create( - loc, cvpInputs, op.getInputs()[i], rewriter.getI64IntegerAttr(i) - ); - // Constants for #inputs. - auto coNumInputs = rewriter.create(loc, numInputs); - [[maybe_unused]] auto coNumOutputs = rewriter.create(loc, numOutputs); - // Variadic pack for out_rows. - auto cvpOutRows = rewriter.create(loc, vptSize, rewriter.getI64IntegerAttr(numOutputs)); - for(size_t i = 0; i < numOutputs; i++) - rewriter.create( - loc, cvpOutRows, op.getOutRows()[i], rewriter.getI64IntegerAttr(i) - ); - // Variadic pack for out_cols. - auto cvpOutCols = rewriter.create(loc, vptSize, rewriter.getI64IntegerAttr(numOutputs)); - for(size_t i = 0; i < numOutputs; i++) - rewriter.create( - loc, cvpOutCols, op.getOutCols()[i], rewriter.getI64IntegerAttr(i) - ); - // Variadic pack for splits. - auto cvpSplits = rewriter.create(loc, vptInt64, rewriter.getI64IntegerAttr(numInputs)); - for(size_t i = 0; i < numInputs; i++) - rewriter.create( - loc, - cvpSplits, - rewriter.create( - loc, static_cast(op.getSplits()[i].dyn_cast().getValue()) - ), - rewriter.getI64IntegerAttr(i) - ); - // Variadic pack for combines. - auto cvpCombines = rewriter.create(loc, vptInt64, rewriter.getI64IntegerAttr(numOutputs)); - for(size_t i = 0; i < numOutputs; i++) - rewriter.create( - loc, - cvpCombines, - rewriter.create( - loc, static_cast(op.getCombines()[i].dyn_cast().getValue()) - ), - rewriter.getI64IntegerAttr(i) - ); - - // Create CallKernelOp. 
- std::vector newOperands = { - cvpInputs, coNumInputs, cvpOutRows, cvpOutCols, cvpSplits, cvpCombines, op.getIr(), dctx - }; - auto cko = rewriter.replaceOpWithNewOp( - op.getOperation(), - callee.str(), - newOperands, - op.getOutputs().getTypes() - ); - // TODO Use ATTR_HASVARIADICRESULTS from LowerToLLVMPass.cpp. - cko->setAttr("hasVariadicResults", rewriter.getBoolAttr(true)); - - return success(); - } - }; - - struct RewriteToCallKernelOpPass - : public PassWrapper> - { - const DaphneUserConfig& userConfig; - std::unordered_map & usedLibPaths; - - explicit RewriteToCallKernelOpPass( - const DaphneUserConfig& cfg, std::unordered_map & usedLibPaths - ) : userConfig(cfg), usedLibPaths(usedLibPaths) {} - - void runOnOperation() final; - }; -} - -void RewriteToCallKernelOpPass::runOnOperation() -{ + // ***************************************************************************** + // Add kernel id and DAPHNE context as arguments + // ***************************************************************************** + + auto kId = rewriter.create( + loc, rewriter.getI32IntegerAttr(KernelDispatchMapping::instance().registerKernel(kernelFuncName, op))); + + // NOTE: kId has to be added before CreateDaphneContextOp because + // there is an assumption that the CTX is the last argument + // (LowerToLLVMPass.cpp::623,702). This means the kId is expected to + // be the second to last argument. + kernelArgs.push_back(kId); + + // Inject the current DaphneContext as the last input parameter to + // all kernel calls, unless it's a CreateDaphneContextOp. + if (!llvm::isa(op)) + kernelArgs.push_back(dctx); + + // ***************************************************************************** + // Create the CallKernelOp + // ***************************************************************************** + + // Mark the shared library the chosen kernel comes from as used. This + // means we will link this library into the JIT-compiled program later. + usedLibPaths.at(libPath) = true; + + // Create a CallKernelOp for the kernel function to call and return + // success(). 
+ auto kernel = rewriter.create(loc, kernelFuncName, kernelArgs, opResTys); + rewriter.replaceOp(op, kernel.getResults()); + return success(); + } +}; + +class DistributedPipelineKernelReplacement : public OpConversionPattern { + Value dctx; + const DaphneUserConfig &userConfig; + std::unordered_map &usedLibPaths; + + public: + using OpConversionPattern::OpConversionPattern; + DistributedPipelineKernelReplacement(MLIRContext *mctx, Value dctx, const DaphneUserConfig &userConfig, + std::unordered_map &usedLibPaths, + PatternBenefit benefit = 2) + : OpConversionPattern(mctx, benefit), dctx(dctx), userConfig(userConfig), usedLibPaths(usedLibPaths) {} + + LogicalResult matchAndRewrite(daphne::DistributedPipelineOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + size_t numOutputs = op.getOutputs().size(); + size_t numInputs = op.getInputs().size(); + + std::stringstream callee; + callee << "_distributedPipeline"; // kernel name + callee << "__DenseMatrix_double_variadic" // outputs + << "__size_t" // numOutputs + << "__Structure_variadic" // inputs + << "__size_t" // numInputs + << "__int64_t" // outRows + << "__int64_t" // outCols + << "__int64_t" // splits + << "__int64_t" // combines + << "__char"; // irCode + + MLIRContext *mctx = rewriter.getContext(); + + Location loc = op.getLoc(); + Type vptObj = daphne::VariadicPackType::get(mctx, daphne::MatrixType::get(mctx, rewriter.getF64Type())); + Type vptSize = daphne::VariadicPackType::get(mctx, rewriter.getIntegerType(64, false)); + Type vptInt64 = daphne::VariadicPackType::get(mctx, rewriter.getIntegerType(64, true)); + + // Variadic pack for inputs. + auto cvpInputs = + rewriter.create(loc, vptObj, rewriter.getI64IntegerAttr(numInputs)); + for (size_t i = 0; i < numInputs; i++) + rewriter.create(loc, cvpInputs, op.getInputs()[i], + rewriter.getI64IntegerAttr(i)); + // Constants for #inputs. + auto coNumInputs = rewriter.create(loc, numInputs); + [[maybe_unused]] auto coNumOutputs = rewriter.create(loc, numOutputs); + // Variadic pack for out_rows. + auto cvpOutRows = + rewriter.create(loc, vptSize, rewriter.getI64IntegerAttr(numOutputs)); + for (size_t i = 0; i < numOutputs; i++) + rewriter.create(loc, cvpOutRows, op.getOutRows()[i], + rewriter.getI64IntegerAttr(i)); + // Variadic pack for out_cols. + auto cvpOutCols = + rewriter.create(loc, vptSize, rewriter.getI64IntegerAttr(numOutputs)); + for (size_t i = 0; i < numOutputs; i++) + rewriter.create(loc, cvpOutCols, op.getOutCols()[i], + rewriter.getI64IntegerAttr(i)); + // Variadic pack for splits. + auto cvpSplits = + rewriter.create(loc, vptInt64, rewriter.getI64IntegerAttr(numInputs)); + for (size_t i = 0; i < numInputs; i++) + rewriter.create( + loc, cvpSplits, + rewriter.create( + loc, static_cast(op.getSplits()[i].dyn_cast().getValue())), + rewriter.getI64IntegerAttr(i)); + // Variadic pack for combines. + auto cvpCombines = + rewriter.create(loc, vptInt64, rewriter.getI64IntegerAttr(numOutputs)); + for (size_t i = 0; i < numOutputs; i++) + rewriter.create( + loc, cvpCombines, + rewriter.create( + loc, static_cast(op.getCombines()[i].dyn_cast().getValue())), + rewriter.getI64IntegerAttr(i)); + + // Create CallKernelOp. + std::vector newOperands = {cvpInputs, coNumInputs, cvpOutRows, cvpOutCols, + cvpSplits, cvpCombines, op.getIr(), dctx}; + auto cko = rewriter.replaceOpWithNewOp(op.getOperation(), callee.str(), newOperands, + op.getOutputs().getTypes()); + // TODO Use ATTR_HASVARIADICRESULTS from LowerToLLVMPass.cpp. 
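+ // The distributed pipeline produces a variable number of outputs, so
+ // the created call is flagged for the later lowering to LLVM.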
+ cko->setAttr("hasVariadicResults", rewriter.getBoolAttr(true)); + + return success(); + } +}; + +struct RewriteToCallKernelOpPass : public PassWrapper> { + const DaphneUserConfig &userConfig; + std::unordered_map &usedLibPaths; + + explicit RewriteToCallKernelOpPass(const DaphneUserConfig &cfg, std::unordered_map &usedLibPaths) + : userConfig(cfg), usedLibPaths(usedLibPaths) {} + + void runOnOperation() final; +}; +} // namespace + +void RewriteToCallKernelOpPass::runOnOperation() { func::FuncOp func = getOperation(); RewritePatternSet patterns(&getContext()); @@ -713,47 +604,29 @@ void RewriteToCallKernelOpPass::runOnOperation() // Specification of (il)legal dialects/operations. All DaphneIR operations // but those explicitly marked as legal will be replaced by CallKernelOp. ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalOp(); target.addIllegalDialect(); - target.addLegalOp< - daphne::ConstantOp, - daphne::ReturnOp, - daphne::CallKernelOp, - daphne::CreateVariadicPackOp, - daphne::StoreVariadicPackOp, - daphne::VectorizedPipelineOp, - scf::ForOp, - memref::LoadOp, - daphne::GenericCallOp, - daphne::MapOp - >(); - target.addDynamicallyLegalOp([](daphne::CastOp op) { - return op.isTrivialCast() || op.isRemovePropertyCast(); - }); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [](daphne::CastOp op) { return op.isTrivialCast() || op.isRemovePropertyCast(); }); // Determine the DaphneContext valid in the MLIR function being rewritten. mlir::Value dctx = CompilerUtils::getDaphneContext(func); - func->walk([&](daphne::VectorizedPipelineOp vpo) - { - vpo.getCtxMutable().assign(dctx); - }); + func->walk([&](daphne::VectorizedPipelineOp vpo) { vpo.getCtxMutable().assign(dctx); }); // Apply conversion to CallKernelOps. - patterns.insert< - KernelReplacement, - DistributedPipelineKernelReplacement - >(&getContext(), dctx, userConfig, usedLibPaths); + patterns.insert(&getContext(), dctx, userConfig, + usedLibPaths); if (failed(applyPartialConversion(func, target, std::move(patterns)))) signalPassFailure(); - } -std::unique_ptr daphne::createRewriteToCallKernelOpPass(const DaphneUserConfig& cfg, std::unordered_map & usedLibPaths) -{ +std::unique_ptr daphne::createRewriteToCallKernelOpPass(const DaphneUserConfig &cfg, + std::unordered_map &usedLibPaths) { return std::make_unique(cfg, usedLibPaths); } diff --git a/src/compiler/lowering/SpecializeGenericFunctionsPass.cpp b/src/compiler/lowering/SpecializeGenericFunctionsPass.cpp index 3541121c2..e6b98bc8a 100644 --- a/src/compiler/lowering/SpecializeGenericFunctionsPass.cpp +++ b/src/compiler/lowering/SpecializeGenericFunctionsPass.cpp @@ -14,10 +14,10 @@ * limitations under the License. */ -#include -#include #include "ir/daphneir/Daphne.h" #include "ir/daphneir/Passes.h" +#include +#include #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -26,435 +26,448 @@ #include "mlir/Transforms/Passes.h" #include +#include #include #include -#include #include using namespace mlir; namespace { - /** - * @brief Checks if the function is untyped, i.e., if at least one of the inputs is - * of unknown type. 
- * - * @param op The `FuncOp` to check - * @return true if `FuncOp` is untyped, false otherwise - */ - bool isUntypedFunction(func::FuncOp op) { - return llvm::any_of( - op.getFunctionType().getInputs(), - [&](Type ty) { - auto matTy = ty.dyn_cast(); - return - llvm::isa(ty) || - (matTy && (llvm::isa(matTy.getElementType()))); - } - ); - } +/** + * @brief Checks if the function is untyped, i.e., if at least one of the inputs + * is of unknown type. + * + * @param op The `FuncOp` to check + * @return true if `FuncOp` is untyped, false otherwise + */ +bool isUntypedFunction(func::FuncOp op) { + return llvm::any_of(op.getFunctionType().getInputs(), [&](Type ty) { + auto matTy = ty.dyn_cast(); + return llvm::isa(ty) || + (matTy && (llvm::isa(matTy.getElementType()))); + }); +} - /** - * @brief Checks if the function is a template, by checking the types of input arguments. - * - * We consider a function a template iff: - * (1) it is an untyped function (i.e., at least one of the inputs is of unknown type - * or a matrix of unknown value type), or - * (2) at least one of the inputs is a matrix with unknown properties - * - * @param op The `FuncOp` to check - * @return true if `FuncOp` is a template, false otherwise - */ - bool isFunctionTemplate(func::FuncOp op) { - return llvm::any_of( - op.getFunctionType().getInputs(), - [&](Type ty) { - auto matTy = ty.dyn_cast(); - return - llvm::isa(ty) || - (matTy && ( - llvm::isa(matTy.getElementType()) || - (matTy.getNumRows() == -1 && matTy.getNumCols() == -1 && matTy.getSparsity() == -1) - )); - } - ); +/** + * @brief Checks if the function is a template, by checking the types of input + * arguments. + * + * We consider a function a template iff: + * (1) it is an untyped function (i.e., at least one of the inputs is of unknown + * type or a matrix of unknown value type), or (2) at least one of the inputs is + * a matrix with unknown properties + * + * @param op The `FuncOp` to check + * @return true if `FuncOp` is a template, false otherwise + */ +bool isFunctionTemplate(func::FuncOp op) { + return llvm::any_of(op.getFunctionType().getInputs(), [&](Type ty) { + auto matTy = ty.dyn_cast(); + return llvm::isa(ty) || + (matTy && (llvm::isa(matTy.getElementType()) || + (matTy.getNumRows() == -1 && matTy.getNumCols() == -1 && matTy.getSparsity() == -1))); + }); +} + +std::string uniqueSpecializedFuncName(const std::string &functionName) { + static unsigned functionUniqueId = 0; + return functionName + "-" + std::to_string(++functionUniqueId); +} + +/** + * @brief Check if a function with the given input/output types can be called + * with the input types given. + * @param functionType The type of the function + * @param callTypes The types used in the call + * @return true if the types match for a call, false otherwise + */ +bool callTypesMatchFunctionTypes(FunctionType functionType, TypeRange callTypes) { + for (auto zipIt : llvm::zip(functionType.getInputs(), callTypes)) { + auto funcTy = std::get<0>(zipIt); + auto callTy = std::get<1>(zipIt); + // Note that we explicitly take all properties (e.g., shape) into + // account. + if (funcTy != callTy) + return false; } + return true; +} - std::string uniqueSpecializedFuncName(const std::string &functionName) { - static unsigned functionUniqueId = 0; - return functionName + "-" + std::to_string(++functionUniqueId); +/** + * @brief Get argument types for the specialized version of a template function. + * @param functionType The types of the template function. 
+ * @param callTypes The types used in the call to the specialized version.
+ * @param funcName The name of the function to call
+ * @param callLoc The location of the call
+ * @return The argument types to use for the specialized version
+ */
+std::vector getSpecializedFuncArgTypes(FunctionType functionType, TypeRange callTypes,
+ const std::string &funcName, mlir::Location callLoc) {
+ auto unknownTy = daphne::UnknownType::get(functionType.getContext());
+ std::vector specializedTypes;
+ for (auto it : llvm::enumerate(llvm::zip(functionType.getInputs(), callTypes))) {
+ auto index = it.index();
+ auto funcInTy = std::get<0>(it.value());
+ auto specializedTy = std::get<1>(it.value());
+ if (funcInTy != specializedTy) {
+ auto funcMatTy = funcInTy.dyn_cast();
+ auto specializedMatTy = specializedTy.dyn_cast();
+ bool isMatchingUnknownMatrix = funcMatTy && specializedMatTy && funcMatTy.getElementType() == unknownTy;
+ bool isMatchingUnknownPropertiesMatrix =
+ funcMatTy && specializedMatTy && funcMatTy.getElementType() == specializedMatTy.getElementType() &&
+ funcMatTy.getNumRows() == -1 && funcMatTy.getNumCols() == -1 && funcMatTy.getSparsity() == -1;
+ if (!isMatchingUnknownMatrix && !isMatchingUnknownPropertiesMatrix && funcInTy != unknownTy) {
+ std::string s;
+ llvm::raw_string_ostream stream(s);
+ // TODO The function name funcName has a cryptic suffix from
+ // overloading/specialization, which is not suitable for users
+ // to see.
+ // TODO This error message can show up even for typed functions
+ // which are not "templates", which is confusing for a user.
+ // TODO The index seems to be off by 1 (too large)... (or not,
+ // simply 0-based counting).
+ stream << "call to function template `" << funcName << "` with invalid types for argument " << index
+ << ": expected `" << funcInTy << "`, got `" << specializedTy << "`";
+ throw ErrorHandler::compilerError(callLoc, "SpecializeGenericFunctionsPass", stream.str());
+ }
+ }
+ // Note that specializedTy may explicitly contain property information
+ // (e.g., shape).
+ specializedTypes.push_back(specializedTy);
+ }
+ return specializedTypes;
+}

- /**
- * @brief Check if a function with the given input/output types can be called with the input types given.
- * @param functionType The type of the function
- * @param callTypes The types used in the call
- * @return true if the types match for a call, false otherwise
- */
- bool callTypesMatchFunctionTypes(FunctionType functionType, TypeRange callTypes) {
- for(auto zipIt : llvm::zip(functionType.getInputs(), callTypes)) {
- auto funcTy = std::get<0>(zipIt);
- auto callTy = std::get<1>(zipIt);
- // Note that we explicitly take all properties (e.g., shape) into account.
- if(funcTy != callTy)
- return false;
+ /**
+ * @brief Set the result types to the types of the function results.
+ * @param results The results for which to fix the types
+ * @param functionType The function type
+ * @return true if changes were made, else false
+ */
+ bool fixResultTypes(ResultRange results, FunctionType functionType) {
+ bool madeChanges = false;
+ for (auto it : llvm::zip(results, functionType.getResults())) {
+ auto result = std::get<0>(it);
+ auto functionResultTy = std::get<1>(it);
+ if (result.getType() != functionResultTy) {
+ madeChanges = true;
+ result.setType(functionResultTy);
+ }
+ }
+ return madeChanges;
+ }

/**
+ * @brief Run partial type and label inference on the given `FuncOp`.
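+ * Inference and canonicalization are run in alternation, since type/shape
+ * inference and constant folding (part of canonicalization) depend on each
+ * other (see #173).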
+ * @param function The `FuncOp` + * @return The inferred `FuncOp` (same as input), or `nullptr` if an error + * happened + */ +func::FuncOp inferTypesInFunction(func::FuncOp function) { + // Run inference + mlir::PassManager pm(function->getContext(), "func.func"); + pm.enableVerifier(false); + // TODO There is a cyclic dependency between (shape) inference and + // constant folding (included in canonicalization), at the moment we + // run only three iterations of both passes (see #173). + pm.addPass(daphne::createInferencePass({true, true, true, true, true})); + pm.addPass(createCanonicalizerPass()); + pm.addPass(daphne::createInferencePass({true, true, true, true, true})); + pm.addPass(createCanonicalizerPass()); + pm.addPass(daphne::createInferencePass({true, true, true, true, true})); + pm.addPass(createCanonicalizerPass()); + pm.addPass(daphne::createInferencePass({true, true, true, true, true})); + pm.addPass(createCanonicalizerPass()); + if (failed(pm.run(function))) { + throw ErrorHandler::compilerError(function.getOperation(), "SpecializeGenericFunctionsPass", + "could not infer types for a call of function template: " + + function.getName().str()); + } + return function; +} +class SpecializeGenericFunctionsPass : public PassWrapper> { + std::unordered_map functions; + std::multimap specializedVersions; + std::set visited; + std::set called; + std::set templateFunctions; + + const DaphneUserConfig &userConfig; + std::shared_ptr logger; + + public: + explicit SpecializeGenericFunctionsPass(const DaphneUserConfig &cfg) : userConfig(cfg) { + logger = spdlog::get("compiler"); + } + + private: /** - * @brief Get argument types for the specialized version of a template function. - * @param functionType The types of the template function. - * @param callTypes The types used in the call to the specialized version. - * @param funcName The name of the function to call - * @param callLoc The location of the call - * @return The argument types to use for the specialized version + * @brief Create a specialized version of the template function. + * @param templateFunction The template function. + * @param specializedTypes The specialized function arguments + * @param operands The operands of the call operation + * @return The specialized function */ - std::vector getSpecializedFuncArgTypes(FunctionType functionType, TypeRange callTypes, const std::string & funcName, mlir::Location callLoc) { - auto unknownTy = daphne::UnknownType::get(functionType.getContext()); - std::vector specializedTypes; - for(auto it : llvm::enumerate(llvm::zip(functionType.getInputs(), callTypes))) { - auto index = it.index(); - auto funcInTy = std::get<0>(it.value()); - auto specializedTy = std::get<1>(it.value()); - if(funcInTy != specializedTy) { - auto funcMatTy = funcInTy.dyn_cast(); - auto specializedMatTy = specializedTy.dyn_cast(); - bool isMatchingUnknownMatrix = - funcMatTy && specializedMatTy && funcMatTy.getElementType() == unknownTy; - bool isMatchingUnknownPropertiesMatrix = - funcMatTy && specializedMatTy && funcMatTy.getElementType() == specializedMatTy.getElementType() && - funcMatTy.getNumRows() == -1 && funcMatTy.getNumCols() == -1 && funcMatTy.getSparsity() == -1; - if(!isMatchingUnknownMatrix && !isMatchingUnknownPropertiesMatrix && funcInTy != unknownTy) { - std::string s; - llvm::raw_string_ostream stream(s); - // TODO The function name funcName has a cryptic suffix from overloading/specialization, which is not suitable for users for see. 
- // TODO This error message can shiw up even for typed functions which are no "templates", which is confusing for a user. - // TODO The index seems to be off by 1 (too large)... (or not, simply 0-based counting). - stream << "call to function template `" << funcName << "` with invalid types for argument " << index - << ": expected `" << funcInTy << "`, got `" << specializedTy << "`"; - throw ErrorHandler::compilerError(callLoc, "SpecializeGenericFunctionsPass", stream.str()); + func::FuncOp createSpecializedFunction(func::FuncOp templateFunction, TypeRange specializedTypes, + ValueRange operands) { + OpBuilder builder(templateFunction); + auto specializedFunc = templateFunction.clone(); + builder.insert(specializedFunc); + + auto uniqueFuncName = uniqueSpecializedFuncName(templateFunction.getSymName().str()); + specializedFunc.setName(uniqueFuncName); + functions.insert({uniqueFuncName, specializedFunc}); + + // change argument types + specializedFunc.setType( + builder.getFunctionType(specializedTypes, specializedFunc.getFunctionType().getResults())); + for (auto it : llvm::zip(specializedFunc.getArguments(), specializedTypes)) { + std::get<0>(it).setType(std::get<1>(it)); + } + + bool insertedConst = false; + // Don't propagate constants into untyped functions, since that still + // causes problems for some reason. + if (userConfig.use_ipa_const_propa && !isUntypedFunction(templateFunction)) { + // Insert compile-time constant scalar call operands into the + // function. + Block &specializedFuncBodyBlock = specializedFunc.getBody().front(); + builder.setInsertionPointToStart(&specializedFuncBodyBlock); + for (auto it : llvm::enumerate(operands)) { + auto i = it.index(); + Value v = it.value(); + if (Operation *co = CompilerUtils::constantOfAnyType(v)) { + // Clone the constant operation into the function body. + Operation *coNew = co->clone(); + builder.insert(coNew); + // Replace all uses of the corresponding block argument by + // the newly inserted constant. + specializedFuncBodyBlock.getArgument(i).replaceAllUsesWith(coNew->getResult(0)); + // TODO We could even remove the corresponding function + // argument. + insertedConst = true; } } - // Note that specializedTy may explicitly contain property information (e.g., shape). - specializedTypes.push_back(specializedTy); } - return specializedTypes; + // Remember the newly specialized function for reuse only if we did not + // insert any constant call operands. + // TODO We could reuse it for other calls with the same constant (it's + // just more book-keeping effort). + if (!insertedConst) + specializedVersions.insert({templateFunction.getSymName().str(), specializedFunc}); + + return inferTypesInFunction(specializedFunc); } /** - * @brief Set the result types to the types of the function results. 
- * @param results The results for which to fix the types - * @param functionType The function type - * @return true if changes where made, else false + * @brief Try to reuse an existing specialization for the given template + * function + * @param operandTypes Operand types of the call operation + * @param operands Operands of the call operation or an empty list if the + * operands are not available + * @param templateFunction The template function called by the call + * operation + * @return either an existing and matching `FuncOp`, `nullptr` otherwise */ - bool fixResultTypes(ResultRange results, FunctionType functionType) { - bool madeChanges = false; - for(auto it : llvm::zip(results, functionType.getResults())) { - auto result = std::get<0>(it); - auto functionResultTy = std::get<1>(it); - if(result.getType() != functionResultTy) { - madeChanges = true; - result.setType(functionResultTy); + func::FuncOp tryReuseExistingSpecialization(TypeRange operandTypes, ValueRange operands, + func::FuncOp templateFunction) { + if (userConfig.use_ipa_const_propa) { + // If any call operand is a compile-time constant scalar, we don't + // reuse an existing specialization, but create a new one while + // propagating the constant to the function body. + // TODO We could reuse a former specialization that uses the same + // constant. + for (Value v : operands) + if (CompilerUtils::constantOfAnyType(v)) + return nullptr; + } + + // Try to find a reusable function specialization based on types and + // data properties. + auto eqIt = specializedVersions.equal_range(templateFunction.getSymName().str()); + for (auto it = eqIt.first; it != eqIt.second; ++it) { + auto specializedFunc = it->second; + + if (callTypesMatchFunctionTypes(specializedFunc.getFunctionType(), operandTypes)) { + // reuse existing specialized function + return specializedFunc; } } - return madeChanges; + + return nullptr; } /** - * @brief Run partial type and label inference on the given `FuncOp`. - * @param function The `FuncOp` - * @return The inferred `FuncOp` (same as input), or `nullptr` if an error happened + * @brief Try to reuse an existing specializtion if one exists, else creates + * a new specialization + * @param operandTypes Operand types of the call operation + * @param operands Operands of the call operation or an empty list if the + * operands are not available + * @param calledFunction The function called by the call operation + * @param callLoc The location of the call for which a function + * specialization shall be created or reused + * @return A `FuncOp`for the specialization */ - func::FuncOp inferTypesInFunction(func::FuncOp function) { - // Run inference - mlir::PassManager pm(function->getContext(), "func.func"); - pm.enableVerifier(false); - // TODO There is a cyclic dependency between (shape) inference and - // constant folding (included in canonicalization), at the moment we - // run only three iterations of both passes (see #173). 
- pm.addPass(daphne::createInferencePass({true, true, true, true, true})); - pm.addPass(createCanonicalizerPass()); - pm.addPass(daphne::createInferencePass({true, true, true, true, true})); - pm.addPass(createCanonicalizerPass()); - pm.addPass(daphne::createInferencePass({true, true, true, true, true})); - pm.addPass(createCanonicalizerPass()); - pm.addPass(daphne::createInferencePass({true, true, true, true, true})); - pm.addPass(createCanonicalizerPass()); - if(failed(pm.run(function))) { - throw ErrorHandler::compilerError( - function.getOperation(), "SpecializeGenericFunctionsPass", - "could not infer types for a call of function template: " + - function.getName().str()); + func::FuncOp createOrReuseSpecialization(TypeRange operandTypes, ValueRange operands, func::FuncOp calledFunction, + mlir::Location callLoc) { + // check for existing specialization that matches + func::FuncOp specializedFunc = tryReuseExistingSpecialization(operandTypes, operands, calledFunction); + if (!specializedFunc) { + // Create specialized function + auto specializedTypes = getSpecializedFuncArgTypes(calledFunction.getFunctionType(), operandTypes, + calledFunction.getSymName().str(), callLoc); + specializedFunc = createSpecializedFunction(calledFunction, specializedTypes, operands); + } + if (logger->should_log(spdlog::level::debug)) { + std::string s; + llvm::raw_string_ostream stream(s); + calledFunction->getLoc().print(stream); + logger->debug("calledFunction\n\tname: {}\n\tlocation: {}", calledFunction.getSymName().str(), s); } - return function; + templateFunctions.insert(calledFunction); + return specializedFunc; } - class SpecializeGenericFunctionsPass - : public PassWrapper> { - std::unordered_map functions; - std::multimap specializedVersions; - std::set visited; - std::set called; - std::set templateFunctions; - - const DaphneUserConfig& userConfig; - std::shared_ptr logger; - - public: - explicit SpecializeGenericFunctionsPass(const DaphneUserConfig& cfg) : userConfig(cfg) { - logger = spdlog::get("compiler"); + /** + * @brief Recursively specializes all functions within a `FuncOp` based on + * calls to the functions + * @param function The `FuncOp` to scan for function specializations + */ + void specializeCallsInFunction(func::FuncOp function) { + if (visited.count(function)) { + return; } - - private: - /** - * @brief Create a specialized version of the template function. - * @param templateFunction The template function. - * @param specializedTypes The specialized function arguments - * @param operands The operands of the call operation - * @return The specialized function - */ - func::FuncOp createSpecializedFunction(func::FuncOp templateFunction, TypeRange specializedTypes, ValueRange operands) { - OpBuilder builder(templateFunction); - auto specializedFunc = templateFunction.clone(); - builder.insert(specializedFunc); - - auto uniqueFuncName = uniqueSpecializedFuncName(templateFunction.getSymName().str()); - specializedFunc.setName(uniqueFuncName); - functions.insert({uniqueFuncName, specializedFunc}); - - // change argument types - specializedFunc - .setType(builder.getFunctionType(specializedTypes, specializedFunc.getFunctionType().getResults())); - for(auto it : llvm::zip(specializedFunc.getArguments(), specializedTypes)) { - std::get<0>(it).setType(std::get<1>(it)); - } - - bool insertedConst = false; - // Don't propagate constants into untyped functions, since that still causes problems for some reason. 
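The removed inferTypesInFunction body above makes the design visible: inference and canonicalization are cyclically dependent, so the pass simply alternates them for a fixed number of rounds (see the TODO referencing #173) instead of iterating to a true fixed point. The general shape, as a plain C++ illustration (hypothetical driver, not DAPHNE API):

    #include <functional>
    #include <iostream>

    // Alternate two mutually dependent transformations for a bounded
    // number of rounds, stopping early if neither reports a change --
    // the same ping-pong that inferTypesInFunction runs with a fixed
    // round count instead of a convergence check.
    bool runToBoundedFixpoint(const std::function<bool()> &infer,
                              const std::function<bool()> &canonicalize, int maxRounds) {
        for (int i = 0; i < maxRounds; ++i) {
            bool changed = infer();
            changed = canonicalize() || changed;
            if (!changed)
                return true; // converged before the bound
        }
        return false; // bound hit; later rounds might still find facts
    }

    int main() {
        int facts = 0;
        auto infer = [&] { return facts < 3 ? (++facts, true) : false; };
        auto canon = [] { return false; };
        std::cout << runToBoundedFixpoint(infer, canon, 4) << "\n"; // 1 (converged)
    }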
- if(userConfig.use_ipa_const_propa && !isUntypedFunction(templateFunction)) { - // Insert compile-time constant scalar call operands into the function. - Block & specializedFuncBodyBlock = specializedFunc.getBody().front(); - builder.setInsertionPointToStart(&specializedFuncBodyBlock); - for(auto it : llvm::enumerate(operands)) { - auto i = it.index(); - Value v = it.value(); - if(Operation * co = CompilerUtils::constantOfAnyType(v)) { - // Clone the constant operation into the function body. - Operation * coNew = co->clone(); - builder.insert(coNew); - // Replace all uses of the corresponding block argument by the newly inserted constant. - specializedFuncBodyBlock.getArgument(i).replaceAllUsesWith(coNew->getResult(0)); - // TODO We could even remove the corresponding function argument. - insertedConst = true; - } + visited.insert(function); + // Specialize all functions called directly + function.walk([&](daphne::GenericCallOp callOp) { + auto calledFunction = functions[callOp.getCallee().str()]; + bool hasConstantInput = llvm::any_of( + callOp.getOperands(), [&](Value v) { return CompilerUtils::constantOfAnyType(v) != nullptr; }); + if (isFunctionTemplate(calledFunction) || hasConstantInput) { + func::FuncOp specializedFunc = createOrReuseSpecialization( + callOp.getOperandTypes(), callOp.getOperands(), calledFunction, callOp.getLoc()); + callOp.setCalleeAttr(specializedFunc.getSymNameAttr()); + if (fixResultTypes(callOp->getResults(), specializedFunc.getFunctionType())) { + inferTypesInFunction(function); } + specializeCallsInFunction(specializedFunc); + called.insert(specializedFunc); + } else { + specializeCallsInFunction(calledFunction); + called.insert(calledFunction); } - // Remember the newly specialized function for reuse only if we did not insert any constant - // call operands. - // TODO We could reuse it for other calls with the same constant (it's just more book-keeping effort). - if(!insertedConst) - specializedVersions.insert({templateFunction.getSymName().str(), specializedFunc}); - - return inferTypesInFunction(specializedFunc); - } - - /** - * @brief Try to reuse an existing specialization for the given template function - * @param operandTypes Operand types of the call operation - * @param operands Operands of the call operation or an empty list if the operands are not available - * @param templateFunction The template function called by the call operation - * @return either an existing and matching `FuncOp`, `nullptr` otherwise - */ - func::FuncOp tryReuseExistingSpecialization(TypeRange operandTypes, ValueRange operands, func::FuncOp templateFunction) { - if(userConfig.use_ipa_const_propa) { - // If any call operand is a compile-time constant scalar, we don't reuse an existing specialization, - // but create a new one while propagating the constant to the function body. - // TODO We could reuse a former specialization that uses the same constant. - for(Value v : operands) - if(CompilerUtils::constantOfAnyType(v)) - return nullptr; - } - - // Try to find a reusable function specialization based on types and data properties. 
- auto eqIt = specializedVersions.equal_range(templateFunction.getSymName().str());
- for(auto it = eqIt.first ; it != eqIt.second ; ++it) {
- auto specializedFunc = it->second;
-
- if(callTypesMatchFunctionTypes(specializedFunc.getFunctionType(), operandTypes)) {
- // reuse existing specialized function
- return specializedFunc;
+ });
+
+ // Specialize all functions called by MapOp
+ function.walk([&](daphne::MapOp mapOp) {
+ auto calledFunction = functions[mapOp.getFunc().str()];
+ if (isFunctionTemplate(calledFunction)) {
+ // Get the element type of the matrix the function should be
+ // mapped on
+ mlir::Type opTy = mapOp.getArg().getType();
+ auto inpMatrixTy = opTy.dyn_cast<daphne::MatrixType>();
+ func::FuncOp specializedFunc =
+ createOrReuseSpecialization(inpMatrixTy.getElementType(), {}, calledFunction, mapOp.getLoc());
+ mapOp.setFuncAttr(specializedFunc.getSymNameAttr());
+
+ // We only allow functions that return exactly one result for
+ // mapOp
+ if (specializedFunc.getFunctionType().getNumResults() != 1) {
+ throw ErrorHandler::compilerError(
+ mapOp.getOperation(), "SpecializeGenericFunctionsPass",
+ "map expects a function with exactly one return "
+ "value. The provided function returns " +
+ std::to_string(specializedFunc.getFunctionType().getNumResults()) + " values instead.");
 }
- }
- return nullptr;
- }
-
- /**
- * @brief Try to reuse an existing specializtion if one exists, else creates a new
- * specialization
- * @param operandTypes Operand types of the call operation
- * @param operands Operands of the call operation or an empty list if the operands are not available
- * @param calledFunction The function called by the call operation
- * @param callLoc The location of the call for which a function specialization shall be created or reused
- * @return A `FuncOp`for the specialization
- */
- func::FuncOp createOrReuseSpecialization(TypeRange operandTypes, ValueRange operands, func::FuncOp calledFunction, mlir::Location callLoc) {
- // check for existing specialization that matches
- func::FuncOp specializedFunc = tryReuseExistingSpecialization(operandTypes, operands, calledFunction);
- if(!specializedFunc) {
- // Create specialized function
- auto specializedTypes =
- getSpecializedFuncArgTypes(calledFunction.getFunctionType(), operandTypes, calledFunction.getSymName().str(), callLoc);
- specializedFunc = createSpecializedFunction(calledFunction, specializedTypes, operands);
- }
- if(logger->should_log(spdlog::level::debug)) {
- std::string s;
- llvm::raw_string_ostream stream(s);
- calledFunction->getLoc().print(stream);
- logger->debug("calledFunction\n\tname: {}\n\tlocation: {}", calledFunction.getSymName().str(), s);
- }
- templateFunctions.insert(calledFunction);
- return specializedFunc;
- }
+ // Get current mapOp result matrix type and fix it if needed.
+ // If we fixed something we rerun inference of the whole + // function + daphne::MatrixType resMatrixTy = mapOp.getType().dyn_cast(); + mlir::Type funcResTy = specializedFunc.getFunctionType().getResult(0); + + // The matrix that results from the mapOp has the same dimension + // as the input matrix and the element-type returned by the + // specialized function + if (resMatrixTy.getNumCols() != inpMatrixTy.getNumCols() || + resMatrixTy.getNumRows() != inpMatrixTy.getNumRows() || resMatrixTy.getElementType() != funcResTy) { + mapOp.getResult().setType(inpMatrixTy.withElementType(funcResTy)); + inferTypesInFunction(function); + } - /** - * @brief Recursively specializes all functions within a `FuncOp` based on calls to the functions - * @param function The `FuncOp` to scan for function specializations - */ - void specializeCallsInFunction(func::FuncOp function) { - if(visited.count(function)) { - return; + specializeCallsInFunction(specializedFunc); + called.insert(specializedFunc); + } else { + specializeCallsInFunction(calledFunction); + called.insert(calledFunction); } - visited.insert(function); - // Specialize all functions called directly - function.walk([&](daphne::GenericCallOp callOp) { - auto calledFunction = functions[callOp.getCallee().str()]; - bool hasConstantInput = llvm::any_of( - callOp.getOperands(), - [&](Value v) { - return CompilerUtils::constantOfAnyType(v) != nullptr; - } - ); - if(isFunctionTemplate(calledFunction) || hasConstantInput) { - func::FuncOp specializedFunc = createOrReuseSpecialization(callOp.getOperandTypes(), callOp.getOperands(), calledFunction, callOp.getLoc()); - callOp.setCalleeAttr(specializedFunc.getSymNameAttr()); - if(fixResultTypes(callOp->getResults(), specializedFunc.getFunctionType())) { - inferTypesInFunction(function); - } - specializeCallsInFunction(specializedFunc); - called.insert(specializedFunc); - } - else { - specializeCallsInFunction(calledFunction); - called.insert(calledFunction); - } - }); - - // Specialize all functions called by MapOp - function.walk([&](daphne::MapOp mapOp) { - auto calledFunction = functions[mapOp.getFunc().str()]; - if(isFunctionTemplate(calledFunction)) { - // Get the element type of the matrix the function should be mapped on - mlir::Type opTy = mapOp.getArg().getType(); - auto inpMatrixTy = opTy.dyn_cast(); - func::FuncOp specializedFunc = createOrReuseSpecialization(inpMatrixTy.getElementType(), {}, calledFunction, mapOp.getLoc()); - mapOp.setFuncAttr(specializedFunc.getSymNameAttr()); - - // We only allow functions that return exactly one result for mapOp - if (specializedFunc.getFunctionType().getNumResults() != 1) { - throw ErrorHandler::compilerError( - mapOp.getOperation(), - "SpecializeGenericFunctionsPass", - "map expects a function with exactly one return " - "value. The provided function returns" + - std::to_string(specializedFunc.getFunctionType() - .getNumResults()) + - "values instead."); - } - - // Get current mapOp result matrix type and fix it if needed. 
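The result-type rule applied here is worth stating on its own: the map result keeps the input matrix's dimensions and takes its element type from the specialized function's single result. As a toy computation (stand-in types, purely illustrative):

    #include <cassert>
    #include <string>

    struct MatTy {
        long rows, cols; // -1 would mean "unknown" in the real IR
        std::string elemTy;
    };

    // Same dimensions as the input, element type from the mapped
    // function's result -- the condition in the pass only rewrites the
    // type (and reruns inference) when one of these three differs.
    MatTy mapResultType(const MatTy &input, const std::string &funcResTy) {
        return MatTy{input.rows, input.cols, funcResTy};
    }

    int main() {
        MatTy out = mapResultType(MatTy{10, 5, "f64"}, "si64");
        assert(out.rows == 10 && out.cols == 5 && out.elemTy == "si64");
    }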
- // If we fixed something we rerun inference of the whole function - daphne::MatrixType resMatrixTy = mapOp.getType().dyn_cast(); - mlir::Type funcResTy = specializedFunc.getFunctionType().getResult(0); - - // The matrix that results from the mapOp has the same dimension as the input - // matrix and the element-type returned by the specialized function - if(resMatrixTy.getNumCols() != inpMatrixTy.getNumCols() || - resMatrixTy.getNumRows() != inpMatrixTy.getNumRows() || - resMatrixTy.getElementType() != funcResTy) { - mapOp.getResult().setType(inpMatrixTy.withElementType(funcResTy)); - inferTypesInFunction(function); - } - - specializeCallsInFunction(specializedFunc); - called.insert(specializedFunc); - } - else { - specializeCallsInFunction(calledFunction); - called.insert(calledFunction); - } - }); - } + }); + } - public: - void runOnOperation() final; + public: + void runOnOperation() final; StringRef getArgument() const final { return "specialize-generic-funcs"; } StringRef getDescription() const final { return "TODO"; } - }; -} +}; +} // namespace /** - * @brief Generate and call specialized functions from template definitions and remove templates. + * @brief Generate and call specialized functions from template definitions and + * remove templates. * * We start entry functions (like `main` or `dist`) and then proceed as follows: * - * 1. Infer types (types up to the first `GenericCallOp` will be inferred for sure) - * 2. If the function called by `GenericCallOp` is untyped (input types are unknown), we clone it and set the input types - * to the types used in the call. For this specialized function we then do the same steps starting at 1. - * 3. With the (possibly cloned) specialized function we now know the outputs. Starting here we infer up to the next - * `GenericCallOp` and go back to step 2. + * 1. Infer types (types up to the first `GenericCallOp` will be inferred for + * sure) + * 2. If the function called by `GenericCallOp` is untyped (input types are + * unknown), we clone it and set the input types to the types used in the call. + * For this specialized function we then do the same steps starting at 1. + * 3. With the (possibly cloned) specialized function we now know the outputs. + * Starting here we infer up to the next `GenericCallOp` and go back to step 2. * 4. When all `GenericCallOp`s are specialized we are finished * - * Finally we delete all the template functions such that the MLIR code can be verified for correct input and output types. + * Finally we delete all the template functions such that the MLIR code can be + * verified for correct input and output types. */ void SpecializeGenericFunctionsPass::runOnOperation() { auto module = getOperation(); - module.walk([&](func::FuncOp funcOp) { - functions.insert({funcOp.getSymName().str(), funcOp}); - }); + module.walk([&](func::FuncOp funcOp) { functions.insert({funcOp.getSymName().str(), funcOp}); }); - // `entryFunctions` will hold entry functions like `main`, but also `dist` (for distributed computation) - // we could also directly specify the names `main`, `dist` etc. (if we add more `entry` functions), or just set - // an attribute flag for those functions. + // `entryFunctions` will hold entry functions like `main`, but also `dist` + // (for distributed computation) we could also directly specify the names + // `main`, `dist` etc. (if we add more `entry` functions), or just set an + // attribute flag for those functions. 
 std::vector<func::FuncOp> entryFunctions;
- for(const auto &entry : functions) {
+ for (const auto &entry : functions) {
 entryFunctions.push_back(entry.second);
 }
- for(const auto &function : entryFunctions) {
- if(isFunctionTemplate(function) || visited.count(function) || templateFunctions.count(function))
+ for (const auto &function : entryFunctions) {
+ if (isFunctionTemplate(function) || visited.count(function) || templateFunctions.count(function))
 continue;
 try {
 inferTypesInFunction(function);
- } catch (std::runtime_error& e) {
+ } catch (std::runtime_error &e) {
 throw ErrorHandler::rethrowError("SpecializeGenericFunctionsPass", e.what());
 }
 specializeCallsInFunction(function);
 }

 // Delete non-called functions.
- for(auto f : functions) {
+ for (auto f : functions) {
 // Never remove the main or dist function.
- if(f.first == "main" or f.first == "dist")
+ if (f.first == "main" or f.first == "dist")
 continue;
 // Remove a function that was present before creating specializations,
 // if it is never called.
- if(!called.count(f.second) || templateFunctions.count(f.second))
+ if (!called.count(f.second) || templateFunctions.count(f.second))
 f.second.erase();
 }
 }

-std::unique_ptr<Pass> daphne::createSpecializeGenericFunctionsPass(const DaphneUserConfig& cfg) {
+std::unique_ptr<Pass> daphne::createSpecializeGenericFunctionsPass(const DaphneUserConfig &cfg) {
 return std::make_unique<SpecializeGenericFunctionsPass>(cfg);
 }
diff --git a/src/compiler/lowering/VectorizeComputationsPass.cpp b/src/compiler/lowering/VectorizeComputationsPass.cpp
index a891f3543..985c6442e 100644
--- a/src/compiler/lowering/VectorizeComputationsPass.cpp
+++ b/src/compiler/lowering/VectorizeComputationsPass.cpp
@@ -14,190 +14,196 @@
 * limitations under the License.
 */

- #include "compiler/utils/CompilerUtils.h"
-#include
 #include "ir/daphneir/Daphne.h"
 #include "ir/daphneir/Passes.h"
+#include
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include
 #include
 #include
-#include

 using namespace mlir;

-namespace
-{
- /**
- * @brief Recursive function checking if the given value is transitively dependant on the operation `op`.
- * @param value The value to check
- * @param op The operation to check
- * @return true if there is a dependency, false otherwise
- */
- bool valueDependsOnResultOf(Value value, Operation *op) {
- if (auto defOp = value.getDefiningOp()) {
- if (defOp == op)
- return true;
 #if 1
- // TODO This crashes if defOp and op are not in the same block.
- // At the same time, it does not seem to be strictly required.
-// if (defOp->isBeforeInBlock(op))
- // Nevertheless, this modified line seems to be a good soft-filter;
- // without that, the vectorization pass may take very long on
- // programs with 100s of operations.
- if (defOp->getBlock() == op->getBlock() && defOp->isBeforeInBlock(op))
- // can't have results of `op` as inputs, as it is defined before
- return false;
+namespace {
+/**
+ * @brief Recursive function checking if the given value is transitively
+ * dependent on the operation `op`.
+ * @param value The value to check
+ * @param op The operation to check
+ * @return true if there is a dependency, false otherwise
+ */
+bool valueDependsOnResultOf(Value value, Operation *op) {
+ if (auto defOp = value.getDefiningOp()) {
+ if (defOp == op)
+ return true;
 #if 1
+ // TODO This crashes if defOp and op are not in the same block.
+ // At the same time, it does not seem to be strictly required.
+ // if (defOp->isBeforeInBlock(op)) + // Nevertheless, this modified line seems to be a good soft-filter; + // without that, the vectorization pass may take very long on + // programs with 100s of operations. + if (defOp->getBlock() == op->getBlock() && defOp->isBeforeInBlock(op)) + // can't have results of `op` as inputs, as it is defined before + return false; #endif - for (auto operand : defOp->getOperands()) { - if (valueDependsOnResultOf(operand, op)) - return true; - } + for (auto operand : defOp->getOperands()) { + if (valueDependsOnResultOf(operand, op)) + return true; } - return false; } + return false; +} - /** - * @brief Check if the vectorizable operation can directly be fused into the pipeline, without requiring any other - * operation to be fused first. - * @param opBefore The vectorizable operation to check - * @param pipeline The pipeline - * @return true if it can be directly fused, false otherwise - */ - bool isDirectlyFusible(daphne::Vectorizable opBefore, const std::vector& pipeline) { - for (auto pipeOp : pipeline) { - for (auto operand : pipeOp->getOperands()) { - if (std::find(pipeline.begin(), pipeline.end(), operand.getDefiningOp()) != pipeline.end()) { - // transitive dependencies inside the pipeline are of course fine. - continue; - } - if (operand.getDefiningOp() != opBefore && valueDependsOnResultOf(operand, opBefore)) { - return false; - } +/** + * @brief Check if the vectorizable operation can directly be fused into the + * pipeline, without requiring any other operation to be fused first. + * @param opBefore The vectorizable operation to check + * @param pipeline The pipeline + * @return true if it can be directly fused, false otherwise + */ +bool isDirectlyFusible(daphne::Vectorizable opBefore, const std::vector &pipeline) { + for (auto pipeOp : pipeline) { + for (auto operand : pipeOp->getOperands()) { + if (std::find(pipeline.begin(), pipeline.end(), operand.getDefiningOp()) != pipeline.end()) { + // transitive dependencies inside the pipeline are of course + // fine. + continue; + } + if (operand.getDefiningOp() != opBefore && valueDependsOnResultOf(operand, opBefore)) { + return false; } } - return true; } + return true; +} - /** - * @brief Greedily fuses the operation into the pipeline if possible. 
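valueDependsOnResultOf is a depth-first walk over operand producers, with the same-block "defined before" test acting only as a pruning filter. Ignoring that filter, the core reachability check looks like this toy sketch (plain structs, not MLIR):

    #include <iostream>
    #include <vector>

    // Toy operation: only records which ops produce its operands.
    struct Op {
        std::vector<const Op *> operandDefs;
    };

    // Does `def` (the producer of some value) transitively depend on
    // `op`? Mirrors the recursion over defOp->getOperands() above.
    bool dependsOn(const Op *def, const Op *op) {
        if (!def)
            return false; // block argument: no defining op
        if (def == op)
            return true;
        for (const Op *d : def->operandDefs)
            if (dependsOn(d, op))
                return true;
        return false;
    }

    int main() {
        Op a;
        Op b{{&a}};
        Op c{{&b}};
        std::cout << dependsOn(&c, &a) << " " << dependsOn(&a, &c) << "\n"; // 1 0
    }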
- * @param operationToPipelineIx A map of operations to their index in the pipelines collection - * @param pipelines The collection of pipelines - * @param currentPipelineIx The index of the current pipeline into which we want to possibly fuse the operation - * @param operationToCheck The operation we possibly want to fuse into the current pipeline - */ - void greedyPipelineFusion(std::map &operationToPipelineIx, - std::vector> &pipelines, - size_t currentPipelineIx, daphne::Vectorizable operationToCheck) { - auto ¤tPipeline = pipelines[currentPipelineIx]; - auto existingPipelineIt = operationToPipelineIx.find(operationToCheck); - if(existingPipelineIt != operationToPipelineIx.end()) { - // existing pipeline is sure to be after the current pipeline (due to reverse iteration order) - auto existingPipelineIx = existingPipelineIt->second; - auto &existingPipeline = pipelines[existingPipelineIx]; - for (auto op : currentPipeline) { - if (!isDirectlyFusible(op, existingPipeline)) { - continue; - } - } - // append existing to current - currentPipeline.insert(currentPipeline.end(), existingPipeline.begin(), existingPipeline.end()); - for (auto vectorizable : existingPipeline) { - operationToPipelineIx[vectorizable] = currentPipelineIx; +/** + * @brief Greedily fuses the operation into the pipeline if possible. + * @param operationToPipelineIx A map of operations to their index in the + * pipelines collection + * @param pipelines The collection of pipelines + * @param currentPipelineIx The index of the current pipeline into which we want + * to possibly fuse the operation + * @param operationToCheck The operation we possibly want to fuse into the + * current pipeline + */ +void greedyPipelineFusion(std::map &operationToPipelineIx, + std::vector> &pipelines, size_t currentPipelineIx, + daphne::Vectorizable operationToCheck) { + auto ¤tPipeline = pipelines[currentPipelineIx]; + auto existingPipelineIt = operationToPipelineIx.find(operationToCheck); + if (existingPipelineIt != operationToPipelineIx.end()) { + // existing pipeline is sure to be after the current pipeline (due to + // reverse iteration order) + auto existingPipelineIx = existingPipelineIt->second; + auto &existingPipeline = pipelines[existingPipelineIx]; + for (auto op : currentPipeline) { + if (!isDirectlyFusible(op, existingPipeline)) { + continue; } - // just make it empty, it will be skipped later. Ixs changes and reshuffling is therefore not necessary. - existingPipeline.clear(); } - else if(isDirectlyFusible(operationToCheck, currentPipeline)) { - currentPipeline.push_back(operationToCheck); - operationToPipelineIx[operationToCheck] = currentPipelineIx; + // append existing to current + currentPipeline.insert(currentPipeline.end(), existingPipeline.begin(), existingPipeline.end()); + for (auto vectorizable : existingPipeline) { + operationToPipelineIx[vectorizable] = currentPipelineIx; } + // just make it empty, it will be skipped later. Ixs changes and + // reshuffling is therefore not necessary. + existingPipeline.clear(); + } else if (isDirectlyFusible(operationToCheck, currentPipeline)) { + currentPipeline.push_back(operationToCheck); + operationToPipelineIx[operationToCheck] = currentPipelineIx; } +} - /** - * @brief Moves operation which are between the operations, which should be fused into a single pipeline, before - * or after the position where the pipeline will be placed. 
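The merge branch of greedyPipelineFusion does three things: append the other pipeline's ops to the current one, repoint their index-map entries, and clear the old pipeline so later code skips it without renumbering. That bookkeeping, reduced to integers (toy types; the per-op fusibility check is elided):

    #include <cstddef>
    #include <iostream>
    #include <map>
    #include <vector>

    using Pipelines = std::vector<std::vector<int>>;

    // Append pipeline `src` to pipeline `dst`, repoint all of src's ops
    // to dst, and leave src empty -- emptiness is the "skip me" marker,
    // so no indices need reshuffling afterwards.
    void mergePipelines(Pipelines &ps, std::map<int, std::size_t> &opToPipe, std::size_t dst,
                        std::size_t src) {
        ps[dst].insert(ps[dst].end(), ps[src].begin(), ps[src].end());
        for (int op : ps[src])
            opToPipe[op] = dst;
        ps[src].clear();
    }

    int main() {
        Pipelines ps{{1, 2}, {3}};
        std::map<int, std::size_t> idx{{1, 0}, {2, 0}, {3, 1}};
        mergePipelines(ps, idx, 0, 1);
        std::cout << ps[0].size() << " " << ps[1].size() << " " << idx[3] << "\n"; // 3 0 0
    }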
- * @param pipelinePosition The position where the pipeline will be
- * @param pipeline The pipeline for which this function should be executed
- */
- void movePipelineInterleavedOperations(Block::iterator pipelinePosition, const std::vector<daphne::Vectorizable> &pipeline) {
- // first operation in pipeline vector is last in IR, and the last is the first
- auto startPos = pipeline.back()->getIterator();
- auto endPos = pipeline.front()->getIterator();
- auto currSkip = pipeline.rbegin();
- std::vector<Operation*> moveBeforeOps;
- std::vector<Operation*> moveAfterOps;
- for(auto it = startPos; it != endPos; ++it) {
- if (it == (*currSkip)->getIterator()) {
- ++currSkip;
- continue;
- }
+/**
+ * @brief Moves operations that lie between the operations which should be
+ * fused into a single pipeline before or after the position where the pipeline
+ * will be placed.
+ * @param pipelinePosition The position where the pipeline will be
+ * @param pipeline The pipeline for which this function should be executed
+ */
+void movePipelineInterleavedOperations(Block::iterator pipelinePosition,
+ const std::vector<daphne::Vectorizable> &pipeline) {
+ // first operation in pipeline vector is last in IR, and the last is the
+ // first
+ auto startPos = pipeline.back()->getIterator();
+ auto endPos = pipeline.front()->getIterator();
+ auto currSkip = pipeline.rbegin();
+ std::vector<Operation *> moveBeforeOps;
+ std::vector<Operation *> moveAfterOps;
+ for (auto it = startPos; it != endPos; ++it) {
+ if (it == (*currSkip)->getIterator()) {
+ ++currSkip;
+ continue;
+ }

- bool dependsOnPipeline = false;
- auto pipelineOpsBeforeIt = currSkip;
- while (--pipelineOpsBeforeIt != pipeline.rbegin()) {
- for (auto operand : it->getOperands()) {
- if(valueDependsOnResultOf(operand, *pipelineOpsBeforeIt)) {
- dependsOnPipeline = true;
- break;
- }
- }
- if (dependsOnPipeline) {
- break;
- }
- }
- // check first pipeline op
+ bool dependsOnPipeline = false;
+ auto pipelineOpsBeforeIt = currSkip;
+ while (--pipelineOpsBeforeIt != pipeline.rbegin()) {
 for (auto operand : it->getOperands()) {
- if(valueDependsOnResultOf(operand, *pipelineOpsBeforeIt)) {
+ if (valueDependsOnResultOf(operand, *pipelineOpsBeforeIt)) {
 dependsOnPipeline = true;
 break;
 }
 }
 if (dependsOnPipeline) {
- moveAfterOps.push_back(&(*it));
- }
- else {
- moveBeforeOps.push_back(&(*it));
+ break;
 }
 }
-
- for(auto moveBeforeOp: moveBeforeOps) {
- moveBeforeOp->moveBefore(pipelinePosition->getBlock(), pipelinePosition);
+ // check first pipeline op
+ for (auto operand : it->getOperands()) {
+ if (valueDependsOnResultOf(operand, *pipelineOpsBeforeIt)) {
+ dependsOnPipeline = true;
+ break;
+ }
 }
- for(auto moveAfterOp: moveAfterOps) {
- moveAfterOp->moveAfter(pipelinePosition->getBlock(), pipelinePosition);
- pipelinePosition = moveAfterOp->getIterator();
+ if (dependsOnPipeline) {
+ moveAfterOps.push_back(&(*it));
+ } else {
+ moveBeforeOps.push_back(&(*it));
 }
 }

- struct VectorizeComputationsPass : public PassWrapper<VectorizeComputationsPass, OperationPass<func::FuncOp>> {
- void runOnOperation() final;
- };
+ for (auto moveBeforeOp : moveBeforeOps) {
+ moveBeforeOp->moveBefore(pipelinePosition->getBlock(), pipelinePosition);
+ }
+ for (auto moveAfterOp : moveAfterOps) {
+ moveAfterOp->moveAfter(pipelinePosition->getBlock(), pipelinePosition);
+ pipelinePosition = moveAfterOp->getIterator();
+ }
 }

-void VectorizeComputationsPass::runOnOperation()
-{
+struct VectorizeComputationsPass : public PassWrapper<VectorizeComputationsPass, OperationPass<func::FuncOp>> {
+ void runOnOperation() final;
+};
+} // namespace
+
+void VectorizeComputationsPass::runOnOperation() {
 auto func = getOperation();
- // TODO: fuse pipelines that have the matching inputs,
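Abstractly, the loop above partitions every op interleaved with the pipeline into two sets: ops that depend on a pipeline result must move after the fused pipeline op, everything else may move before it. A stand-alone illustration with an assumed dependency predicate:

    #include <functional>
    #include <iostream>
    #include <vector>

    // Split the interleaved ops: dependents of the pipeline go after
    // the fused op, independent ops go before it (order preserved).
    void partitionInterleaved(const std::vector<int> &interleaved,
                              const std::function<bool(int)> &dependsOnPipeline,
                              std::vector<int> &moveBefore, std::vector<int> &moveAfter) {
        for (int op : interleaved)
            (dependsOnPipeline(op) ? moveAfter : moveBefore).push_back(op);
    }

    int main() {
        std::vector<int> before, after;
        partitionInterleaved({4, 5, 6}, [](int op) { return op == 5; }, before, after);
        std::cout << before.size() << " " << after.size() << "\n"; // 2 1
    }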
even if no output of the one pipeline is used by the other. - // This requires multi-returns in way more cases, which is not implemented yet. + // TODO: fuse pipelines that have the matching inputs, even if no output of + // the one pipeline is used by the other. + // This requires multi-returns in way more cases, which is not implemented + // yet. // Find vectorizable operations and their inputs of vectorizable operations std::vector vectOps; - func->walk([&](daphne::Vectorizable op) - { - if(CompilerUtils::isMatrixComputation(op)) - vectOps.emplace_back(op); + func->walk([&](daphne::Vectorizable op) { + if (CompilerUtils::isMatrixComputation(op)) + vectOps.emplace_back(op); }); std::vector vectorizables(vectOps.begin(), vectOps.end()); std::multimap possibleMerges; - for(auto v : vectorizables) { - for(auto e : llvm::zip(v->getOperands(), v.getVectorSplits())) { + for (auto v : vectorizables) { + for (auto e : llvm::zip(v->getOperands(), v.getVectorSplits())) { auto operand = std::get<0>(e); auto defOp = operand.getDefiningOp(); - if(defOp && v->getBlock() == defOp->getBlock() && CompilerUtils::isMatrixComputation(defOp)) { + if (defOp && v->getBlock() == defOp->getBlock() && CompilerUtils::isMatrixComputation(defOp)) { // defOp is not a candidate for fusion with v, if the // result/operand along which we would fuse is used within a // nested block (e.g., control structure) between defOp and v. @@ -207,56 +213,53 @@ void VectorizeComputationsPass::runOnOperation() // when it would be safe (also taking NoSideEffect into // account). bool qualified = true; - for(OpOperand & use : operand.getUses()) { - Operation * user = use.getOwner(); - if(user->getBlock() != v->getBlock()) { + for (OpOperand &use : operand.getUses()) { + Operation *user = use.getOwner(); + if (user->getBlock() != v->getBlock()) { // user must be in a child block of the block in which // v resides, because we have already checked that v // and defOp are in the same block. 
- while(user->getBlock() != v->getBlock()) + while (user->getBlock() != v->getBlock()) user = user->getParentOp(); - if(user->isBeforeInBlock(v)) { + if (user->isBeforeInBlock(v)) { qualified = false; break; } } } - if(qualified){ + if (qualified) { auto split = std::get<1>(e); // find the corresponding `OpResult` to figure out combine auto opResult = *llvm::find(defOp->getResults(), operand); auto combine = defOp.getVectorCombines()[opResult.getResultNumber()]; - if(split == daphne::VectorSplit::ROWS) { - if(combine == daphne::VectorCombine::ROWS) + if (split == daphne::VectorSplit::ROWS) { + if (combine == daphne::VectorCombine::ROWS) possibleMerges.insert({v, defOp}); - } - else if (split == daphne::VectorSplit::NONE) { + } else if (split == daphne::VectorSplit::NONE) { // can't be merged - } - else { - throw ErrorHandler::compilerError( - v, "VectorizeComputationsPass", - "VectorSplit case `" + stringifyEnum(split).str() + - "` not handled"); + } else { + throw ErrorHandler::compilerError(v, "VectorizeComputationsPass", + "VectorSplit case `" + stringifyEnum(split).str() + + "` not handled"); } } } } } - // Collect vectorizable operations that can be computed together in pipelines + // Collect vectorizable operations that can be computed together in + // pipelines std::map operationToPipelineIx; std::vector> pipelines; - for(auto vIt = vectorizables.rbegin(); vIt != vectorizables.rend(); ++vIt) { + for (auto vIt = vectorizables.rbegin(); vIt != vectorizables.rend(); ++vIt) { auto v = *vIt; size_t pipelineIx; auto pipelineIt = operationToPipelineIx.find(v); - if(pipelineIt != operationToPipelineIx.end()) { + if (pipelineIt != operationToPipelineIx.end()) { pipelineIx = pipelineIt->second; - } - else { + } else { pipelineIx = pipelines.size(); std::vector pipeline; pipeline.push_back(v); @@ -265,17 +268,18 @@ void VectorizeComputationsPass::runOnOperation() // iterate all operands that could be combined into the pipeline auto itRange = possibleMerges.equal_range(v); - for(auto it = itRange.first; it != itRange.second; ++it) { + for (auto it = itRange.first; it != itRange.second; ++it) { auto operandVectorizable = it->second; - // TODO: this fuses greedily, the first pipeline we can fuse this operation into, we do. improve + // TODO: this fuses greedily, the first pipeline we can fuse this + // operation into, we do. 
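The qualification test above climbs from a user inside a nested block up through parent operations until it reaches an ancestor in v's block, and only then compares positions. The climb itself, on a toy parent/block structure (illustrative only):

    #include <iostream>

    struct Node {
        const Node *parent = nullptr;
        int block = 0;    // id of the block this node lives in
        int position = 0; // order within that block
    };

    // Mirror of: while (user->getBlock() != v->getBlock())
    //                user = user->getParentOp();
    const Node *hoistToBlockOf(const Node *user, const Node *v) {
        while (user && user->block != v->block)
            user = user->parent;
        return user;
    }

    int main() {
        Node v{nullptr, 1, 5};
        Node enclosing{nullptr, 1, 2}; // e.g. a control-flow op before v
        Node nestedUser{&enclosing, 2, 0};
        const Node *anc = hoistToBlockOf(&nestedUser, &v);
        std::cout << (anc && anc->position < v.position) << "\n"; // 1: use precedes v
    }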
improve greedyPipelineFusion(operationToPipelineIx, pipelines, pipelineIx, operandVectorizable); } } OpBuilder builder(func); // Create the `VectorizedPipelineOp`s - for(auto pipeline : pipelines) { - if(pipeline.empty()) { + for (auto pipeline : pipelines) { + if (pipeline.empty()) { continue; } auto valueIsPartOfPipeline = [&](Value operand) { @@ -291,110 +295,106 @@ void VectorizeComputationsPass::runOnOperation() // first op in pipeline is last in IR builder.setInsertionPoint(pipeline.front()); - // move all operations, between the operations that will be part of the pipeline, before or after the - // completed pipeline + // move all operations, between the operations that will be part of the + // pipeline, before or after the completed pipeline movePipelineInterleavedOperations(builder.getInsertionPoint(), pipeline); - for(auto vIt = pipeline.rbegin(); vIt != pipeline.rend(); ++vIt) { + for (auto vIt = pipeline.rbegin(); vIt != pipeline.rend(); ++vIt) { auto v = *vIt; auto vSplits = v.getVectorSplits(); auto vCombines = v.getVectorCombines(); - // TODO: although we do create enum attributes, it might make sense/make it easier to + // TODO: although we do create enum attributes, it might make + // sense/make it easier to // just directly use an I64ArrayAttribute - for(auto i = 0u; i < v->getNumOperands(); ++i) { + for (auto i = 0u; i < v->getNumOperands(); ++i) { auto operand = v->getOperand(i); - if(!valueIsPartOfPipeline(operand)) { + if (!valueIsPartOfPipeline(operand)) { vSplitAttrs.push_back(daphne::VectorSplitAttr::get(&getContext(), vSplits[i])); operands.push_back(operand); } } - for(auto vCombine : vCombines) { + for (auto vCombine : vCombines) { vCombineAttrs.push_back(daphne::VectorCombineAttr::get(&getContext(), vCombine)); } locations.push_back(v->getLoc()); - for(auto result: v->getResults()) { + for (auto result : v->getResults()) { results.push_back(result); } - for(auto outSize: v.createOpsOutputSizes(builder)) { + for (auto outSize : v.createOpsOutputSizes(builder)) { outRows.push_back(outSize.first); outCols.push_back(outSize.second); } } std::vector locs; locs.reserve(pipeline.size()); - for(auto op: pipeline) { + for (auto op : pipeline) { locs.push_back(op->getLoc()); } auto loc = builder.getFusedLoc(locs); - auto pipelineOp = builder.create(loc, - ValueRange(results).getTypes(), - operands, - outRows, - outCols, - builder.getArrayAttr(vSplitAttrs), - builder.getArrayAttr(vCombineAttrs), - nullptr); + auto pipelineOp = builder.create( + loc, ValueRange(results).getTypes(), operands, outRows, outCols, builder.getArrayAttr(vSplitAttrs), + builder.getArrayAttr(vCombineAttrs), nullptr); Block *bodyBlock = builder.createBlock(&pipelineOp.getBody()); - for(size_t i = 0u; i < operands.size(); ++i) { + for (size_t i = 0u; i < operands.size(); ++i) { auto argTy = operands[i].getType(); switch (vSplitAttrs[i].cast().getValue()) { - case daphne::VectorSplit::ROWS: { - auto matTy = argTy.cast(); - // only remove row information - argTy = matTy.withShape(-1, matTy.getNumCols()); - break; - } - case daphne::VectorSplit::NONE: - // keep any size information - break; + case daphne::VectorSplit::ROWS: { + auto matTy = argTy.cast(); + // only remove row information + argTy = matTy.withShape(-1, matTy.getNumCols()); + break; + } + case daphne::VectorSplit::NONE: + // keep any size information + break; } bodyBlock->addArgument(argTy, builder.getUnknownLoc()); } auto argsIx = 0u; auto resultsIx = 0u; - for(auto vIt = pipeline.rbegin(); vIt != pipeline.rend(); ++vIt) { + for (auto vIt 
= pipeline.rbegin(); vIt != pipeline.rend(); ++vIt) { auto v = *vIt; auto numOperands = v->getNumOperands(); auto numResults = v->getNumResults(); v->moveBefore(bodyBlock, bodyBlock->end()); - for(auto i = 0u; i < numOperands; ++i) { - if(!valueIsPartOfPipeline(v->getOperand(i))) { + for (auto i = 0u; i < numOperands; ++i) { + if (!valueIsPartOfPipeline(v->getOperand(i))) { v->setOperand(i, bodyBlock->getArgument(argsIx++)); } } auto pipelineReplaceResults = pipelineOp->getResults().drop_front(resultsIx).take_front(numResults); resultsIx += numResults; - for(auto z: llvm::zip(v->getResults(), pipelineReplaceResults)) { + for (auto z : llvm::zip(v->getResults(), pipelineReplaceResults)) { auto old = std::get<0>(z); auto replacement = std::get<1>(z); // TODO: switch to type based size inference instead // FIXME: if output is dynamic sized, we can't do this // replace `NumRowOp` and `NumColOp`s for output size inference - for(auto& use: old.getUses()) { - auto* op = use.getOwner(); - if(auto nrowOp = llvm::dyn_cast(op)) { + for (auto &use : old.getUses()) { + auto *op = use.getOwner(); + if (auto nrowOp = llvm::dyn_cast(op)) { nrowOp.replaceAllUsesWith(pipelineOp.getOutRows()[replacement.getResultNumber()]); nrowOp.erase(); } - if(auto ncolOp = llvm::dyn_cast(op)) { + if (auto ncolOp = llvm::dyn_cast(op)) { ncolOp.replaceAllUsesWith(pipelineOp.getOutCols()[replacement.getResultNumber()]); ncolOp.erase(); } } // Replace only if not used by pipeline op - old.replaceUsesWithIf(replacement, [&](OpOperand& opOperand) { + old.replaceUsesWithIf(replacement, [&](OpOperand &opOperand) { return llvm::count(pipeline, opOperand.getOwner()) == 0; }); } } - bodyBlock->walk([](Operation* op) { - for(auto resVal: op->getResults()) { - if(auto ty = resVal.getType().dyn_cast()) { + bodyBlock->walk([](Operation *op) { + for (auto resVal : op->getResults()) { + if (auto ty = resVal.getType().dyn_cast()) { resVal.setType(ty.withShape(-1, -1)); } } diff --git a/src/compiler/lowering/WhileLoopInvariantCodeMotionPass.cpp b/src/compiler/lowering/WhileLoopInvariantCodeMotionPass.cpp deleted file mode 100644 index 8e933155e..000000000 --- a/src/compiler/lowering/WhileLoopInvariantCodeMotionPass.cpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2021 The DAPHNE Consortium - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ir/daphneir/Passes.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SCF/IR/SCF.h" - -#include - -using namespace mlir; - -/** - * @brief This is a very limited variant of loop invariant code motion (LICM), - * tailored just to WhileOp. - * - * We need this because MLIR does not seem to support LICM for while loops. - * Nevertheless, we should clarify this (see #175). - * - * This pass is strongly inspired by MLIR's LoopInvariantCodeMotion.cpp, but - * significantly simplified. 
- */ -struct WhileLoopInvariantCodeMotionPass -: public PassWrapper > { - void runOnOperation() final; - - StringRef getArgument() const final { return "while-loop-invariant-code-motion"; } - StringRef getDescription() const final { return "TODO"; } -}; - -void WhileLoopInvariantCodeMotionPass::runOnOperation() { - getOperation()->walk([&](scf::WhileOp whileOp) { - Region & loopBody = whileOp.getAfter(); - - SmallPtrSet willBeMovedSet; - SmallVector opsToMove; - - auto isDefinedOutsideOfBody = [&](Value value) { - auto definingOp = value.getDefiningOp(); - return (definingOp && !!willBeMovedSet.count(definingOp)) || - !loopBody.isAncestor(value.getParentRegion()); - }; - - for(auto & block : loopBody) - for(auto & op : block.without_terminator()) { - auto memInterface = dyn_cast(op); - if( - llvm::all_of(op.getOperands(), isDefinedOutsideOfBody) && - op.hasTrait() && // such that we don't need to recurse - memInterface && memInterface.hasNoEffect() - ) { - opsToMove.push_back(&op); - willBeMovedSet.insert(&op); - } - } - - for(auto op : opsToMove) - op->moveBefore(whileOp); - }); -} - -std::unique_ptr daphne::createWhileLoopInvariantCodeMotionPass() { - return std::make_unique(); -} diff --git a/src/compiler/utils/CompilerUtils.cpp b/src/compiler/utils/CompilerUtils.cpp index 85c55a814..da0931ab5 100644 --- a/src/compiler/utils/CompilerUtils.cpp +++ b/src/compiler/utils/CompilerUtils.cpp @@ -24,15 +24,16 @@ // Specializations of isConstantHelper for string types // ************************************************************************************************** -template<> -std::pair CompilerUtils::isConstantHelper(mlir::Value v, const std::function& func) { - if(auto co = v.getDefiningOp()) { - if(auto attr = co.getValue().dyn_cast()) { +template <> +std::pair CompilerUtils::isConstantHelper( + mlir::Value v, const std::function &func) { + if (auto co = v.getDefiningOp()) { + if (auto attr = co.getValue().dyn_cast()) { return std::make_pair(true, func(attr)); } } - if(auto co = v.getDefiningOp()) { - if(auto attr = co.getValue().dyn_cast()) { + if (auto co = v.getDefiningOp()) { + if (auto attr = co.getValue().dyn_cast()) { return std::make_pair(true, func(attr)); } } @@ -43,166 +44,120 @@ std::pair CompilerUtils::isConstantHelper -std::pair CompilerUtils::isConstant(mlir::Value v) { - return isConstantHelper( - v, [](mlir::StringAttr attr){return attr.getValue().str();} - ); +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { + return isConstantHelper(v, + [](mlir::StringAttr attr) { return attr.getValue().str(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> 
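For reference, the deleted pass hoisted an op out of an scf.WhileOp body exactly when it was side-effect-free, had no regions, and all of its operands were defined outside the loop body or by ops already scheduled for hoisting. That criterion restated on toy data (not MLIR types; purely illustrative):

    #include <iostream>
    #include <set>
    #include <vector>

    struct TOp {
        int id;
        std::vector<int> operands; // ids of the defining ops
        bool pure;                 // no memory effects, no nested regions
    };

    // Loop-invariance test of the deleted pass: pure, and every operand
    // defined outside the loop or by an op already marked for hoisting.
    bool isHoistable(const TOp &op, const std::set<int> &definedInLoop,
                     const std::set<int> &willBeMoved) {
        if (!op.pure)
            return false;
        for (int d : op.operands)
            if (definedInLoop.count(d) && !willBeMoved.count(d))
                return false;
        return true;
    }

    int main() {
        std::set<int> inLoop{10, 11}, moved;
        TOp a{10, {1}, true};  // uses only a value from outside the loop
        TOp b{11, {10}, true}; // uses a's result
        if (isHoistable(a, inLoop, moved))
            moved.insert(a.id);
        std::cout << isHoistable(b, inLoop, moved) << "\n"; // 1: hoistable after a
    }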
-std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::FloatAttr attr){return attr.getValue().convertToFloat();} - ); + v, [](mlir::FloatAttr attr) { return attr.getValue().convertToFloat(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { return isConstantHelper( - v, [](mlir::FloatAttr attr){return attr.getValue().convertToDouble();} - ); + v, [](mlir::FloatAttr attr) { return attr.getValue().convertToDouble(); }); } -template<> -std::pair CompilerUtils::isConstant(mlir::Value v) { - return isConstantHelper( - v, [](mlir::BoolAttr attr){return attr.getValue();} - ); +template <> std::pair CompilerUtils::isConstant(mlir::Value v) { + return isConstantHelper(v, [](mlir::BoolAttr attr) { return attr.getValue(); }); } // ************************************************************************************************** // Specializations of constantOrThrow for various types // ************************************************************************************************** -template<> -std::string CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> std::string CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::StringAttr attr){return attr.getValue().str();}, errorMsg, "string" - ); + v, [](mlir::StringAttr attr) { return attr.getValue().str(); }, errorMsg, "string"); } -template<> -int64_t CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> int64_t CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();}, errorMsg, "integer" - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }, errorMsg, "integer"); } -template<> -uint64_t CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> uint64_t CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();}, errorMsg, "integer" - ); + v, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }, errorMsg, "integer"); } -template<> -float CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> float 
CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::FloatAttr attr){return attr.getValue().convertToFloat();}, errorMsg, "float" - ); + v, [](mlir::FloatAttr attr) { return attr.getValue().convertToFloat(); }, errorMsg, "float"); } -template<> -double CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> double CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::FloatAttr attr){return attr.getValue().convertToDouble();}, errorMsg, "double" - ); + v, [](mlir::FloatAttr attr) { return attr.getValue().convertToDouble(); }, errorMsg, "double"); } -template<> -bool CompilerUtils::constantOrThrow(mlir::Value v, const std::string & errorMsg) { +template <> bool CompilerUtils::constantOrThrow(mlir::Value v, const std::string &errorMsg) { return constantOrThrowHelper( - v, [](mlir::BoolAttr attr){return attr.getValue();}, errorMsg, "bool" - ); + v, [](mlir::BoolAttr attr) { return attr.getValue(); }, errorMsg, "bool"); } // ************************************************************************************************** // Specializations of constantOrDefault for various types // ************************************************************************************************** -template<> -std::string CompilerUtils::constantOrDefault(mlir::Value v, std::string d) { +template <> std::string CompilerUtils::constantOrDefault(mlir::Value v, std::string d) { return constantOrDefaultHelper( - v, std::move(d), [](mlir::StringAttr attr){return attr.getValue().str();} - ); + v, std::move(d), [](mlir::StringAttr attr) { return attr.getValue().str(); }); } -template<> -int64_t CompilerUtils::constantOrDefault(mlir::Value v, int64_t d) { +template <> int64_t CompilerUtils::constantOrDefault(mlir::Value v, int64_t d) { return constantOrDefaultHelper( - v, d, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, d, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -uint64_t CompilerUtils::constantOrDefault(mlir::Value v, uint64_t d) { +template <> uint64_t CompilerUtils::constantOrDefault(mlir::Value v, uint64_t d) { return constantOrDefaultHelper( - v, d, [](mlir::IntegerAttr attr){return attr.getValue().getLimitedValue();} - ); + v, d, [](mlir::IntegerAttr attr) { return attr.getValue().getLimitedValue(); }); } -template<> -float CompilerUtils::constantOrDefault(mlir::Value v, float d) { +template <> float CompilerUtils::constantOrDefault(mlir::Value v, float d) { return constantOrDefaultHelper( - v, d, [](mlir::FloatAttr attr){return attr.getValue().convertToFloat();} - ); + v, d, [](mlir::FloatAttr attr) { return attr.getValue().convertToFloat(); }); } -template<> -double CompilerUtils::constantOrDefault(mlir::Value v, double d) { +template <> double CompilerUtils::constantOrDefault(mlir::Value v, double d) { return constantOrDefaultHelper( - v, d, [](mlir::FloatAttr attr){return attr.getValue().convertToDouble();} - ); + v, d, [](mlir::FloatAttr attr) { return attr.getValue().convertToDouble(); }); } -template<> -bool CompilerUtils::constantOrDefault(mlir::Value v, bool d) { - return constantOrDefaultHelper( - v, d, [](mlir::BoolAttr attr){return attr.getValue();} - ); +template <> bool CompilerUtils::constantOrDefault(mlir::Value v, bool d) { + return constantOrDefaultHelper(v, d, [](mlir::BoolAttr attr) { return attr.getValue(); }); } // 
************************************************************************************************** @@ -214,8 +169,6 @@ bool CompilerUtils::constantOrDefault(mlir::Value v, bool d) { } bool CompilerUtils::isMatrixComputation(mlir::Operation *v) { - return - llvm::any_of(v->getOperandTypes(), [&](mlir::Type ty){ return llvm::isa(ty); }) - || - llvm::any_of(v->getResultTypes(), [&](mlir::Type ty){ return llvm::isa(ty); }); + return llvm::any_of(v->getOperandTypes(), [&](mlir::Type ty) { return llvm::isa(ty); }) || + llvm::any_of(v->getResultTypes(), [&](mlir::Type ty) { return llvm::isa(ty); }); } diff --git a/src/compiler/utils/CompilerUtils.h b/src/compiler/utils/CompilerUtils.h index fe418419e..a8dd3dfd2 100644 --- a/src/compiler/utils/CompilerUtils.h +++ b/src/compiler/utils/CompilerUtils.h @@ -16,6 +16,7 @@ #pragma once +// clang-format off #include #include #include "util/ErrorHandler.h" @@ -25,183 +26,187 @@ #include #include +// clang-format on struct CompilerUtils { -private: - - template - static std::pair isConstantHelper(mlir::Value v, const std::function& func) { - if(auto co = v.getDefiningOp()) - if(auto attr = co.getValue().dyn_cast()) + private: + template + static std::pair isConstantHelper(mlir::Value v, const std::function &func) { + if (auto co = v.getDefiningOp()) + if (auto attr = co.getValue().dyn_cast()) return std::make_pair(true, func(attr)); - if(auto co = v.getDefiningOp()) - if(auto attr = co.getValue().dyn_cast()) + if (auto co = v.getDefiningOp()) + if (auto attr = co.getValue().dyn_cast()) return std::make_pair(true, func(attr)); return std::make_pair(false, ValT(0)); } - template - static ValT constantOrThrowHelper(mlir::Value v, std::function func, const std::string & errorMsg, const std::string & valTypeName) { + template + static ValT constantOrThrowHelper(mlir::Value v, std::function func, + const std::string &errorMsg, const std::string &valTypeName) { auto p = isConstantHelper(v, func); - if(p.first) + if (p.first) return p.second; else - throw ErrorHandler::compilerError(v.getLoc(), "constantOrThrow", - errorMsg.empty() ? - ("the given value must be a constant of " + valTypeName + " type") - : errorMsg - ); + throw ErrorHandler::compilerError( + v.getLoc(), "constantOrThrow", + errorMsg.empty() ? ("the given value must be a constant of " + valTypeName + " type") : errorMsg); } - template + template static ValT constantOrDefaultHelper(mlir::Value v, ValT d, std::function func) { auto p = isConstantHelper(v, func); - if(p.first) + if (p.first) return p.second; else return d; } - -public: + public: /** - * @brief If the given `Value` is defined by some constant operation, return that constant - * operation; otherwise, return `nullptr`. - * + * @brief If the given `Value` is defined by some constant operation, return + * that constant operation; otherwise, return `nullptr`. + * * @param v The `Value`. * @return The defining constant operation or `nullptr`. */ - static mlir::Operation * constantOfAnyType(mlir::Value v) { - if(auto co = v.getDefiningOp()) + static mlir::Operation *constantOfAnyType(mlir::Value v) { + if (auto co = v.getDefiningOp()) return co; - if(auto co = v.getDefiningOp()) + if (auto co = v.getDefiningOp()) return co; return nullptr; } /** - * @brief Returns if the given `Value` is a constant, and if so, also the constant itself. - * + * @brief Returns if the given `Value` is a constant, and if so, also the + * constant itself. + * * @tparam T The C++ type of the constant to extract. * @param v The `Value`. 
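All the isConstant<T>/constantOrThrow/constantOrDefault specializations above funnel into one generic helper parameterized by an attribute-to-value extractor. The pattern, reduced to standalone C++ with a toy Attr type (assumed names, for illustration):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <optional>
    #include <utility>

    // Stand-in for an MLIR attribute that may carry an integer payload.
    struct Attr {
        std::optional<long> intVal;
    };

    // One helper, many specializations: each isConstant<T> only differs
    // in the lambda that turns the attribute into a T.
    template <typename ValT>
    std::pair<bool, ValT> isConstantHelper(const Attr &a, const std::function<ValT(long)> &extract) {
        if (a.intVal)
            return {true, extract(*a.intVal)};
        return {false, ValT(0)};
    }

    int main() {
        auto [ok, v] = isConstantHelper<int64_t>(Attr{42}, [](long x) { return x; });
        std::cout << ok << " " << v << "\n"; // 1 42
    }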
- * @return If the given value is a constant: a pair of the value `true` and the constant value as type `T`; - * otherwise, a pair of the value `false` and an unspecified value of type `T`. + * @return If the given value is a constant: a pair of the value `true` and + * the constant value as type `T`; otherwise, a pair of the value `false` + * and an unspecified value of type `T`. */ - template - static std::pair isConstant(mlir::Value v); - + template static std::pair isConstant(mlir::Value v); + /** - * @brief Returns a constant extracted from the given `Value`, or throws an exception if this is not possible. - * + * @brief Returns a constant extracted from the given `Value`, or throws an + * exception if this is not possible. + * * @tparam T The C++ type of the constant to extract. * @param v The `Value`. - * @param errorMsg The message of the exception to throw. In case of an empty string (default), the exception - * will have a generic error message. - * @return The extracted constant as a value of type `T`, if the given value is a constant. + * @param errorMsg The message of the exception to throw. In case of an + * empty string (default), the exception will have a generic error message. + * @return The extracted constant as a value of type `T`, if the given value + * is a constant. */ - template - static T constantOrThrow(mlir::Value v, const std::string & errorMsg = ""); + template static T constantOrThrow(mlir::Value v, const std::string &errorMsg = ""); /** - * @brief Returns a constant extracted from the given `Value`, or a default value if this is not possible. - * + * @brief Returns a constant extracted from the given `Value`, or a default + * value if this is not possible. + * * @tparam T The C++ type of the constant to extract. * @param v The `Value`. * @param d The default value. - * @return The extracted constant as a value of type `T`, if the given value is a constant, or the given - * default value, otherwise. + * @return The extracted constant as a value of type `T`, if the given value + * is a constant, or the given default value, otherwise. */ - template - static T constantOrDefault(mlir::Value v, T d); + template static T constantOrDefault(mlir::Value v, T d); [[maybe_unused]] static FileMetaData getFileMetaData(mlir::Value filename); /** - * @brief Produces a string containing the C++ type name of the corresponding MLIR type. Mainly used to - * generate function names for generated kernel libraries. This function is defined recursively to also print - * the value types of templated containers (e.g., DenseMatrix). A pragma is added to silence clang-tidy which - * might complain about recursion. + * @brief Produces a string containing the C++ type name of the + * corresponding MLIR type. Mainly used to generate function names for + * generated kernel libraries. This function is defined recursively to also + * print the value types of templated containers (e.g., DenseMatrix). + * A pragma is added to silence clang-tidy which might complain about + * recursion. * * @param t MLIR type name - * @param angleBrackets If `true` (default), angle brackets are used for C++ template types (e.g., `DenseMatrix`); - * Otherwise, underscores are used (e.g., `DenseMatrix_float`). - * @param generalizeToStructure If `true`, `Structure` is used instead of derived types like `DenseMatrix` etc. + * @param angleBrackets If `true` (default), angle brackets are used for C++ + * template types (e.g., `DenseMatrix`); Otherwise, underscores are + * used (e.g., `DenseMatrix_float`). 
+ * @param generalizeToStructure If `true`, `Structure` is used instead of + * derived types like `DenseMatrix` etc. * @return A string representation of the C++ type names */ - // TODO The parameter generalizeToStructure seems to be used only by some remaining kernel name generation - // in LowerToLLVMPass. Once those call-sites have been refactored to use the kernel catalog, this feature + // TODO The parameter generalizeToStructure seems to be used only by some + // remaining kernel name generation in LowerToLLVMPass. Once those + // call-sites have been refactored to use the kernel catalog, this feature // can be removed here. - static std::string mlirTypeToCppTypeName(mlir::Type t, bool angleBrackets = true, bool generalizeToStructure = false) { // NOLINT(misc-no-recursion) - if(t.isF64()) + static std::string mlirTypeToCppTypeName(mlir::Type t, bool angleBrackets = true, + bool generalizeToStructure = false) { // NOLINT(misc-no-recursion) + if (t.isF64()) return "double"; - else if(t.isF32()) + else if (t.isF32()) return "float"; - else if(t.isSignedInteger(8)) + else if (t.isSignedInteger(8)) return "int8_t"; - else if(t.isSignedInteger(32)) + else if (t.isSignedInteger(32)) return "int32_t"; - else if(t.isSignedInteger(64)) + else if (t.isSignedInteger(64)) return "int64_t"; - else if(t.isUnsignedInteger(8)) + else if (t.isUnsignedInteger(8)) return "uint8_t"; - else if(t.isUnsignedInteger(32)) + else if (t.isUnsignedInteger(32)) return "uint32_t"; - else if(t.isUnsignedInteger(64)) + else if (t.isUnsignedInteger(64)) return "uint64_t"; - else if(t.isSignlessInteger(1)) + else if (t.isSignlessInteger(1)) return "bool"; - else if(t.isIndex()) + else if (t.isIndex()) return "size_t"; - else if(t.isa()) + else if (t.isa()) return "Structure"; - else if(auto matTy = t.dyn_cast()) { - if(generalizeToStructure) + else if (auto matTy = t.dyn_cast()) { + if (generalizeToStructure) return "Structure"; else { switch (matTy.getRepresentation()) { - case mlir::daphne::MatrixRepresentation::Dense: { - const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); - return angleBrackets ? ("DenseMatrix<" + vtName + ">") : ("DenseMatrix_" + vtName); - } - case mlir::daphne::MatrixRepresentation::Sparse: { - const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); - return angleBrackets ? ("CSRMatrix<" + vtName + ">") : ("CSRMatrix_" + vtName); - } + case mlir::daphne::MatrixRepresentation::Dense: { + const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); + return angleBrackets ? ("DenseMatrix<" + vtName + ">") : ("DenseMatrix_" + vtName); + } + case mlir::daphne::MatrixRepresentation::Sparse: { + const std::string vtName = mlirTypeToCppTypeName(matTy.getElementType(), angleBrackets, false); + return angleBrackets ? ("CSRMatrix<" + vtName + ">") : ("CSRMatrix_" + vtName); + } } } - } - else if(llvm::isa(t)) - if(generalizeToStructure) + } else if (llvm::isa(t)) + if (generalizeToStructure) return "Structure"; else return "Frame"; - else if(auto lstTy = t.dyn_cast()) { - if(generalizeToStructure) + else if (auto lstTy = t.dyn_cast()) { + if (generalizeToStructure) return "Structure"; else { const std::string dtName = mlirTypeToCppTypeName(lstTy.getElementType(), angleBrackets, false); return angleBrackets ? 
("List<" + dtName + ">") : ("List_" + dtName); } - } - else if(llvm::isa(t)) + } else if (llvm::isa(t)) // This becomes "const char *" (which makes perfect sense for // strings) when inserted into the typical "const DT *" template of // kernel input parameters. return "char"; - else if(llvm::isa(t)) + else if (llvm::isa(t)) return "DaphneContext"; - else if(auto handleTy = t.dyn_cast()) { - const std::string tName = mlirTypeToCppTypeName(handleTy.getDataType(), angleBrackets, generalizeToStructure); + else if (auto handleTy = t.dyn_cast()) { + const std::string tName = + mlirTypeToCppTypeName(handleTy.getDataType(), angleBrackets, generalizeToStructure); return angleBrackets ? ("Handle<" + tName + ">") : ("Handle_" + tName); - } - else if(llvm::isa(t)) + } else if (llvm::isa(t)) return "File"; - else if(llvm::isa(t)) + else if (llvm::isa(t)) return "Descriptor"; - else if(llvm::isa(t)) + else if (llvm::isa(t)) return "Target"; - else if(auto memRefType = t.dyn_cast()) { + else if (auto memRefType = t.dyn_cast()) { const std::string vtName = mlirTypeToCppTypeName(memRefType.getElementType(), angleBrackets, false); return angleBrackets ? ("StridedMemRefType<" + vtName + ",2>") : ("StridedMemRefType_" + vtName + "_2"); } @@ -209,60 +214,55 @@ struct CompilerUtils { std::string typeName; llvm::raw_string_ostream rsos(typeName); t.print(rsos); - throw std::runtime_error( - "no C++ type name known for the given MLIR type: " + typeName - ); + throw std::runtime_error("no C++ type name known for the given MLIR type: " + typeName); } static bool isMatrixComputation(mlir::Operation *v); /** * @brief Returns the DAPHNE context used in the given function. - * + * * Throws if there is not exactly one DAPHNE context. - * + * * @param func - * @return + * @return */ - [[maybe_unused]] mlir::Value static getDaphneContext(mlir::func::FuncOp & func) { + [[maybe_unused]] mlir::Value static getDaphneContext(mlir::func::FuncOp &func) { mlir::Value dctx = nullptr; auto ops = func.getBody().front().getOps(); - for(auto op : ops) { - if(!dctx) + for (auto op : ops) { + if (!dctx) dctx = op.getResult(); else throw ErrorHandler::compilerError(op.getLoc(), "getDaphneContext", - "function body block contains more than one CreateDaphneContextOp" - ); + "function body block contains more than one " + "CreateDaphneContextOp"); } - if(!dctx) + if (!dctx) throw ErrorHandler::compilerError(func.getLoc(), "getDaphneContext", - "function body block contains no CreateDaphneContextOp" - ); + "function body block contains no CreateDaphneContextOp"); return dctx; } - + [[maybe_unused]] static bool isObjType(mlir::Type t) { return llvm::isa(t); } - - [[maybe_unused]] static bool hasObjType(mlir::Value v) { - return isObjType(v.getType()); - } + + [[maybe_unused]] static bool hasObjType(mlir::Value v) { return isObjType(v.getType()); } /** * @brief Returns the value type of the given scalar/matrix/frame type. - * + * * For matrices and frames, the value type is extracted. For scalars, * the type itself is the value type. - * + * * @param t the given scalar/matrix/frame type * @return the value type of the given type */ static mlir::Type getValueType(mlir::Type t) { - if(auto mt = t.dyn_cast()) + if (auto mt = t.dyn_cast()) return mt.getElementType(); - if(auto ft = t.dyn_cast()) + if (auto ft = t.dyn_cast()) throw std::runtime_error("getValueType() doesn't support frames yet"); // TODO else // TODO Check if this is really a scalar. 
return t; @@ -271,18 +271,18 @@ struct CompilerUtils { /** * @brief Sets the value type of the given scalar/matrix/frame type to the * given value type and returns this derived type. - * + * * For matrices and frames, the value type is set to the given value type. * For scalars, the given value type itself is returned. - * + * * @param t the scalar/matrix/frame type whose value type shall be set * @param vt the value type to use * @return the derived scalar/matrix/frame type */ static mlir::Type setValueType(mlir::Type t, mlir::Type vt) { - if(auto mt = t.dyn_cast()) + if (auto mt = t.dyn_cast()) return mt.withElementType(vt); - if(auto ft = t.dyn_cast()) + if (auto ft = t.dyn_cast()) throw std::runtime_error("setValueType() doesn't support frames yet"); // TODO else // TODO Check if this is really a scalar. return vt; @@ -291,13 +291,13 @@ struct CompilerUtils { /** * @brief Checks if the two given types are the same, whereby * DaphneIR's unknown type acts as a wildcard. - * + * * The two types are considered equal, iff they are exactly the same * type, or one of the following "excuses" holds: * - at least one of the types is unknown * - both types are matrices and at least one of them has an unknown * value type - * + * * @param t1 The first type * @param t2 The second type * @result `true` if the two types are considered equal, `false` otherwise @@ -311,16 +311,13 @@ struct CompilerUtils { // The two types are exactly the same... t1 == t2 // ...or one of the following "excuses" holds: - || ( + || + ( // at least one of the types is unknown llvm::isa(t1) || llvm::isa(t2) || // both types are matrices and at least one of them // has an unknown value type - (matT1 && matT2 && ( - llvm::isa(matT1.getElementType()) || - llvm::isa(matT2.getElementType()) - )) - ) - ); + (matT1 && matT2 && + (llvm::isa(matT1.getElementType()) || llvm::isa(matT2.getElementType()))))); } }; diff --git a/src/compiler/utils/LoweringUtils.cpp b/src/compiler/utils/LoweringUtils.cpp index b9f3107a2..230ff0281 100644 --- a/src/compiler/utils/LoweringUtils.cpp +++ b/src/compiler/utils/LoweringUtils.cpp @@ -28,8 +28,7 @@ #include "mlir/Transforms/Passes.h" /// Insert an allocation for the given MemRefType. -mlir::Value insertMemRefAlloc(mlir::MemRefType type, mlir::Location loc, - mlir::PatternRewriter &rewriter) { +mlir::Value insertMemRefAlloc(mlir::MemRefType type, mlir::Location loc, mlir::PatternRewriter &rewriter) { auto alloc = rewriter.create(loc, type); // Make sure to allocate at the beginning of the block. 
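
A quick sketch of the `getValueType`/`setValueType` helpers above (illustration only; it exercises just the documented scalar behavior, since the frame cases still throw):

    #include <cassert>

    // For scalars, the type itself is the value type, and setValueType simply
    // returns the new value type; for matrices, the element type is used.
    void valueTypeSketch(mlir::MLIRContext *ctx) {
        mlir::Type f64 = mlir::Float64Type::get(ctx);
        mlir::Type si64 = mlir::IntegerType::get(ctx, 64, mlir::IntegerType::Signed);
        assert(CompilerUtils::getValueType(f64) == f64);
        assert(CompilerUtils::setValueType(f64, si64) == si64);
    }
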
@@ -39,49 +38,13 @@ mlir::Value insertMemRefAlloc(mlir::MemRefType type, mlir::Location loc, return alloc; } -void insertMemRefDealloc(mlir::Value memref, mlir::Location loc, - mlir::PatternRewriter &rewriter) { +void insertMemRefDealloc(mlir::Value memref, mlir::Location loc, mlir::PatternRewriter &rewriter) { auto dealloc = rewriter.create(loc, memref); dealloc->moveBefore(&memref.getParentBlock()->back()); } -// TODO(phil) try to provide function templates to remove duplication -void affineFillMemRefInt(int value, mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef, - mlir::Type elemType) { - constexpr int ROW = 0; - constexpr int COL = 1; - mlir::Value fillValue = rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(value)); - - llvm::SmallVector loopIvs; - - auto outerLoop = rewriter.create(loc, 0, shape[ROW], 1); - for (mlir::Operation &nested : *outerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(outerLoop.getInductionVar()); - - // outer loop body - rewriter.setInsertionPointToStart(outerLoop.getBody()); - auto innerLoop = rewriter.create(loc, 0, shape[COL], 1); - for (mlir::Operation &nested : *innerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(innerLoop.getInductionVar()); - rewriter.create(loc); - rewriter.setInsertionPointToStart(innerLoop.getBody()); - rewriter.create(loc, fillValue, memRef, loopIvs); - - rewriter.create(loc); - rewriter.setInsertionPointAfter(outerLoop); -} - -// Specify the fill Value directly -void affineFillMemRefInt(mlir::Value value, mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef) { +void affineFillMemRef(mlir::Value value, mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::ArrayRef shape, mlir::MLIRContext *ctx, mlir::Value memRef) { constexpr int ROW = 0; constexpr int COL = 1; llvm::SmallVector loopIvs; @@ -107,47 +70,11 @@ void affineFillMemRefInt(mlir::Value value, mlir::ConversionPatternRewriter &rew rewriter.setInsertionPointAfter(outerLoop); } -void affineFillMemRef(double value, mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef, - mlir::Type elemType) { - constexpr int ROW = 0; - constexpr int COL = 1; - mlir::Value fillValue = rewriter.create( - loc, elemType, rewriter.getFloatAttr(elemType, value)); - - llvm::SmallVector loopIvs; - - auto outerLoop = rewriter.create(loc, 0, shape[ROW], 1); - for (mlir::Operation &nested : *outerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(outerLoop.getInductionVar()); - - // outer loop body - rewriter.setInsertionPointToStart(outerLoop.getBody()); - auto innerLoop = rewriter.create(loc, 0, shape[COL], 1); - for (mlir::Operation &nested : *innerLoop.getBody()) { - rewriter.eraseOp(&nested); - } - loopIvs.push_back(innerLoop.getInductionVar()); - rewriter.create(loc); - rewriter.setInsertionPointToStart(innerLoop.getBody()); - rewriter.create(loc, fillValue, memRef, loopIvs); - - rewriter.create(loc); - rewriter.setInsertionPointAfter(outerLoop); -} - -mlir::Value convertMemRefToDenseMatrix( - mlir::Location loc, mlir::ConversionPatternRewriter &rewriter, - mlir::Value memRef, mlir::Type type) { - auto extractStridedMetadataOp = - rewriter.create(loc, memRef); +mlir::Value convertMemRefToDenseMatrix(mlir::Location loc, mlir::ConversionPatternRewriter 
&rewriter, + mlir::Value memRef, mlir::Type type) { + auto extractStridedMetadataOp = rewriter.create(loc, memRef); // aligned ptr (memref.data) - mlir::Value alignedPtr = - rewriter.create(loc, - memRef); + mlir::Value alignedPtr = rewriter.create(loc, memRef); // offset mlir::Value offset = extractStridedMetadataOp.getOffset(); // strides @@ -155,51 +82,38 @@ mlir::Value convertMemRefToDenseMatrix( // sizes mlir::ResultRange sizes = extractStridedMetadataOp.getSizes(); - return rewriter.create( - loc, type, alignedPtr, offset, sizes[0], sizes[1], strides[0], - strides[1]); + return rewriter.create(loc, type, alignedPtr, offset, sizes[0], sizes[1], + strides[0], strides[1]); } mlir::Type convertFloat(mlir::FloatType floatType) { - return mlir::IntegerType::get(floatType.getContext(), - floatType.getIntOrFloatBitWidth()); + return mlir::IntegerType::get(floatType.getContext(), floatType.getIntOrFloatBitWidth()); } mlir::Type convertInteger(mlir::IntegerType intType) { - return mlir::IntegerType::get(intType.getContext(), - intType.getIntOrFloatBitWidth()); + return mlir::IntegerType::get(intType.getContext(), intType.getIntOrFloatBitWidth()); } -llvm::Optional materializeCastFromIllegal(mlir::OpBuilder &builder, - mlir::Type type, - mlir::ValueRange inputs, - mlir::Location loc) { +llvm::Optional materializeCastFromIllegal(mlir::OpBuilder &builder, mlir::Type type, + mlir::ValueRange inputs, mlir::Location loc) { mlir::Type fromType = getElementTypeOrSelf(inputs[0].getType()); mlir::Type toType = getElementTypeOrSelf(type); - if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || - !toType.isSignlessInteger()) + if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || !toType.isSignlessInteger()) return std::nullopt; // Use unrealized conversion casts to do signful->signless conversions. - return builder - .create(loc, type, inputs[0]) - ->getResult(0); + return builder.create(loc, type, inputs[0])->getResult(0); } -llvm::Optional materializeCastToIllegal(mlir::OpBuilder &builder, - mlir::Type type, - mlir::ValueRange inputs, +llvm::Optional materializeCastToIllegal(mlir::OpBuilder &builder, mlir::Type type, mlir::ValueRange inputs, mlir::Location loc) { mlir::Type fromType = getElementTypeOrSelf(inputs[0].getType()); mlir::Type toType = getElementTypeOrSelf(type); - if (!fromType.isSignlessInteger() || - (!toType.isSignedInteger() && !toType.isUnsignedInteger())) + if (!fromType.isSignlessInteger() || (!toType.isSignedInteger() && !toType.isUnsignedInteger())) return std::nullopt; // Use unrealized conversion casts to do signless->signful conversions. 
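    // Editorial sketch (not part of the patch): these materialization hooks are
    // intended to be registered on a mlir::TypeConverter in a lowering pass, e.g.:
    //   converter.addConversion([](mlir::IntegerType t) { return convertInteger(t); });
    //   converter.addConversion([](mlir::FloatType t) { return convertFloat(t); });
    //   converter.addSourceMaterialization(materializeCastToIllegal);
    //   converter.addTargetMaterialization(materializeCastFromIllegal);
    // The exact registration site in DAPHNE's passes may differ.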
- return builder - .create(loc, type, inputs[0]) - ->getResult(0); + return builder.create(loc, type, inputs[0])->getResult(0); } mlir::Operation *findLastUseOfSSAValue(mlir::Value &v) { diff --git a/src/compiler/utils/LoweringUtils.h b/src/compiler/utils/LoweringUtils.h index e5c49123a..6a3d11ed3 100644 --- a/src/compiler/utils/LoweringUtils.h +++ b/src/compiler/utils/LoweringUtils.h @@ -33,39 +33,20 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" -mlir::Value insertMemRefAlloc(mlir::MemRefType type, mlir::Location loc, - mlir::PatternRewriter &rewriter); +mlir::Value insertMemRefAlloc(mlir::MemRefType type, mlir::Location loc, mlir::PatternRewriter &rewriter); -void insertMemRefDealloc(mlir::Value memref, mlir::Location loc, - mlir::PatternRewriter &rewriter); +void insertMemRefDealloc(mlir::Value memref, mlir::Location loc, mlir::PatternRewriter &rewriter); -void affineFillMemRefInt(int value, mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef, - mlir::Type elemType); +void affineFillMemRef(mlir::Value value, mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::ArrayRef shape, mlir::MLIRContext *ctx, mlir::Value memRef); -void affineFillMemRefInt(mlir::Value value, - mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef); +mlir::Value convertMemRefToDenseMatrix(mlir::Location, mlir::ConversionPatternRewriter &, mlir::Value memRef, + mlir::Type); -void affineFillMemRef(double value, mlir::ConversionPatternRewriter &rewriter, - mlir::Location loc, mlir::ArrayRef shape, - mlir::MLIRContext *ctx, mlir::Value memRef, - mlir::Type elemType); +llvm::Optional materializeCastFromIllegal(mlir::OpBuilder &builder, mlir::Type type, + mlir::ValueRange inputs, mlir::Location loc); -mlir::Value convertMemRefToDenseMatrix(mlir::Location, - mlir::ConversionPatternRewriter &, - mlir::Value memRef, mlir::Type); - -llvm::Optional materializeCastFromIllegal(mlir::OpBuilder &builder, - mlir::Type type, - mlir::ValueRange inputs, - mlir::Location loc); - -llvm::Optional materializeCastToIllegal(mlir::OpBuilder &builder, - mlir::Type type, - mlir::ValueRange inputs, +llvm::Optional materializeCastToIllegal(mlir::OpBuilder &builder, mlir::Type type, mlir::ValueRange inputs, mlir::Location loc); mlir::Type convertFloat(mlir::FloatType floatType); diff --git a/src/compiler/utils/TypePrinting.cpp b/src/compiler/utils/TypePrinting.cpp index bb2e98582..dc58cb6b3 100644 --- a/src/compiler/utils/TypePrinting.cpp +++ b/src/compiler/utils/TypePrinting.cpp @@ -22,7 +22,7 @@ #include #include -std::ostream & operator<<(std::ostream & os, mlir::Type t) { +std::ostream &operator<<(std::ostream &os, mlir::Type t) { std::string s; llvm::raw_string_ostream rsos(s); t.print(rsos); diff --git a/src/compiler/utils/TypePrinting.h b/src/compiler/utils/TypePrinting.h index cfde50395..50d5d7bdd 100644 --- a/src/compiler/utils/TypePrinting.h +++ b/src/compiler/utils/TypePrinting.h @@ -20,4 +20,4 @@ #include -std::ostream & operator<<(std::ostream & os, mlir::Type t); \ No newline at end of file +std::ostream &operator<<(std::ostream &os, mlir::Type t); \ No newline at end of file diff --git a/src/ir/daphneir/CMakeLists.txt b/src/ir/daphneir/CMakeLists.txt index 87a0fb6ac..3016dc0c5 100644 --- a/src/ir/daphneir/CMakeLists.txt +++ b/src/ir/daphneir/CMakeLists.txt @@ -36,6 +36,8 @@ add_mlir_doc(Passes -gen-pass-doc DaphnePasses 
Dialects/) add_mlir_dialect_library(MLIRDaphne DaphneDialect.cpp + Fold.cpp + Canonicalize.cpp DaphneDistributableOpInterface.cpp DaphneInferFrameLabelsOpInterface.cpp DaphneInferShapeOpInterface.cpp diff --git a/src/ir/daphneir/Canonicalize.cpp b/src/ir/daphneir/Canonicalize.cpp new file mode 100644 index 000000000..b296cc08b --- /dev/null +++ b/src/ir/daphneir/Canonicalize.cpp @@ -0,0 +1,516 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/daphneir/Daphne.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LogicalResult.h" +#include + +mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphne::VectorizedPipelineOp op, + mlir::PatternRewriter &rewriter) { + // // Find duplicate inputs + std::vector vSplitsAttrs; + for (auto &split : op.getSplits()) + vSplitsAttrs.push_back(split); + auto currentSize = op.getInputs().size(); + + DenseMap inputMap; + + for (size_t i = 0; i < currentSize; i++) { + const auto &input = op.getInputs()[i]; + const auto &split = vSplitsAttrs[i].cast().getValue(); + + if (inputMap.count(input) == 0) { + inputMap[input] = i; + } else { + size_t j = inputMap[input]; + if (op.getSplits()[j].cast().getValue() == split) { + op.getBody().getArgument(i).replaceAllUsesWith(op.getBody().getArgument(j)); + op.getBody().eraseArgument(i); + op.getInputsMutable().erase(i); + vSplitsAttrs.erase(vSplitsAttrs.begin() + i); + currentSize--; + i--; + } + } + } + + std::vector resultsToReplace; + std::vector outRows; + std::vector outCols; + std::vector vCombineAttrs; + + llvm::BitVector eraseIxs; + eraseIxs.resize(op.getNumResults()); + for (auto result : op.getResults()) { + auto resultIx = result.getResultNumber(); + if (result.use_empty()) { + // remove + eraseIxs.set(resultIx); + } else { + resultsToReplace.push_back(result); + outRows.push_back(op.getOutRows()[resultIx]); + outCols.push_back(op.getOutCols()[resultIx]); + vCombineAttrs.push_back(op.getCombines()[resultIx]); + } + } + op.getBody().front().getTerminator()->eraseOperands(eraseIxs); + if (!op.getCuda().getBlocks().empty()) + op.getCuda().front().getTerminator()->eraseOperands(eraseIxs); + + if (resultsToReplace.size() == op->getNumResults() && op.getSplits().size() == vSplitsAttrs.size()) { + return failure(); + } + auto pipelineOp = rewriter.create( + op.getLoc(), ValueRange(resultsToReplace).getTypes(), op.getInputs(), outRows, outCols, + rewriter.getArrayAttr(vSplitsAttrs), rewriter.getArrayAttr(vCombineAttrs), op.getCtx()); + pipelineOp.getBody().takeBody(op.getBody()); + if (!op.getCuda().getBlocks().empty()) + pipelineOp.getCuda().takeBody(op.getCuda()); + for (auto e : llvm::enumerate(resultsToReplace)) { + auto resultToReplace = e.value(); + auto i = e.index(); + resultToReplace.replaceAllUsesWith(pipelineOp.getResult(i)); + } + op.erase(); + return success(); +} + +/** + * @brief Transposition-aware matrix multiplication + * Identifies if an input to a MatMulOp is the result of a TransposeOp; 
Rewrites + * the Operation, passing transposition info as a flag, instead of transposing + * the matrix before multiplication + */ +mlir::LogicalResult mlir::daphne::MatMulOp::canonicalize(mlir::daphne::MatMulOp op, PatternRewriter &rewriter) { + mlir::Value lhs = op.getLhs(); + mlir::Value rhs = op.getRhs(); + mlir::Value transa = op.getTransa(); + mlir::Value transb = op.getTransb(); + + // TODO If transa or transb are not constant, we cannot continue on the + // respective side; we cannot just assume false then. + bool ta = CompilerUtils::constantOrDefault(transa, false); + bool tb = CompilerUtils::constantOrDefault(transb, false); + + // TODO Turn on the transposition-awareness for the left-hand-side argument + // again (see #447). mlir::daphne::TransposeOp lhsTransposeOp = + // lhs.getDefiningOp(); + mlir::daphne::TransposeOp rhsTransposeOp = rhs.getDefiningOp(); + + // if (!lhsTransposeOp && !rhsTransposeOp){ + if (!rhsTransposeOp) { + return mlir::failure(); + } + + // ToDo: This check prevents merging transpose into matrix multiplication + // because that is not yet supported by our + // sparse kernels. + // ToDo: bring user config here for sparsity threshold or properly use + // MatrixRepresentation + if (auto t = rhs.getType().dyn_cast()) { + auto sparsity = t.getSparsity(); + if (sparsity < 0.25) + return mlir::failure(); + } + +#if 0 + // TODO Adapt PhyOperatorSelectionPass once this code is turned on again. + if(lhsTransposeOp) { + lhs = lhsTransposeOp.getArg(); + ta = !ta; + } +#endif + if (rhsTransposeOp) { + rhs = rhsTransposeOp.getArg(); + tb = !tb; + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, + static_cast(rewriter.create(transa.getLoc(), ta)), + static_cast(rewriter.create(transb.getLoc(), tb))); + return mlir::success(); +} + +/** + * @brief Replaces NumRowsOp by a constant, if the #rows of the input is known + * (e.g., due to shape inference). + */ +mlir::LogicalResult mlir::daphne::NumRowsOp::canonicalize(mlir::daphne::NumRowsOp op, PatternRewriter &rewriter) { + ssize_t numRows = -1; + + mlir::Type inTy = op.getArg().getType(); + if (auto t = inTy.dyn_cast()) + numRows = t.getNumRows(); + else if (auto t = inTy.dyn_cast()) + numRows = t.getNumRows(); + + if (numRows != -1) { + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + rewriter.getIndexAttr(numRows)); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces NumColsOp by a constant, if the #cols of the input is known + * (e.g., due to shape inference). + */ +mlir::LogicalResult mlir::daphne::NumColsOp::canonicalize(mlir::daphne::NumColsOp op, PatternRewriter &rewriter) { + ssize_t numCols = -1; + + mlir::Type inTy = op.getArg().getType(); + if (auto t = inTy.dyn_cast()) + numCols = t.getNumCols(); + else if (auto t = inTy.dyn_cast()) + numCols = t.getNumCols(); + + if (numCols != -1) { + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + rewriter.getIndexAttr(numCols)); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces NumCellsOp by a constant, if the #rows and #cols of the + * input is known (e.g., due to shape inference). 
+ */ +mlir::LogicalResult mlir::daphne::NumCellsOp::canonicalize(mlir::daphne::NumCellsOp op, PatternRewriter &rewriter) { + ssize_t numRows = -1; + ssize_t numCols = -1; + + mlir::Type inTy = op.getArg().getType(); + if (auto t = inTy.dyn_cast()) { + numRows = t.getNumRows(); + numCols = t.getNumCols(); + } else if (auto t = inTy.dyn_cast()) { + numRows = t.getNumRows(); + numCols = t.getNumCols(); + } + + if (numRows != -1 && numCols != -1) { + rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), + rewriter.getIndexAttr(numRows * numCols)); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces SparsityOp by a constant, if the sparsity of the input is + * known (e.g., due to sparsity inference). + */ +mlir::LogicalResult mlir::daphne::SparsityOp::canonicalize(mlir::daphne::SparsityOp op, PatternRewriter &rewriter) { + double sparsity = -1.0; + + mlir::Type inTy = op.getArg().getType(); + if (auto t = inTy.dyn_cast()) + sparsity = t.getSparsity(); + + if (sparsity != -1) { + rewriter.replaceOpWithNewOp(op, sparsity); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string, + * and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame). + * + * (1) is important, since we use the `+`-operator for both addition and + * string concatenation in DaphneDSL, while the types of the operands might be + * known only after type inference. + * + * (2) is important, since our kernels for elementwise binary operations only + * support scalars as the right-hand-side operand so far (see #203). + * + * @param op + * @param rewriter + * @return + */ +mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize(mlir::daphne::EwAddOp op, PatternRewriter &rewriter) { + mlir::Value lhs = op.getLhs(); + mlir::Value rhs = op.getRhs(); + + const bool lhsIsStr = llvm::isa(lhs.getType()); + const bool rhsIsStr = llvm::isa(rhs.getType()); + if (lhsIsStr || rhsIsStr) { + mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext()); + if (!lhsIsStr) + lhs = rewriter.create(op.getLoc(), strTy, lhs); + if (!rhsIsStr) + rhs = rewriter.create(op.getLoc(), strTy, rhs); + rewriter.replaceOpWithNewOp(op, strTy, lhs, rhs); + return mlir::success(); + } else { + const bool lhsIsSca = !llvm::isa(lhs.getType()); + const bool rhsIsSca = !llvm::isa(rhs.getType()); + if (lhsIsSca && !rhsIsSca) { + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), rhs, lhs); + return mlir::success(); + } + return mlir::failure(); + } +} + +/** + * @brief Replaces `a - X` by `(X * -1) + a` (`a` scalar, `X` matrix/frame). + * + * This is important, since our kernels for elementwise binary operations only + * support scalars as the right-hand-side operand so far (see #203). + * + * As a downside, an additional operation and intermediate result is introduced. 
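
(A quick numeric check of this identity: with a = 5 and X = [2, 4], `a - X` yields [3, 1], and `(X * -1) + a` yields [-2, -4] + 5 = [3, 1], applied elementwise.)
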
+ * + * @param op + * @param rewriter + * @return + */ +mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize(mlir::daphne::EwSubOp op, PatternRewriter &rewriter) { + mlir::Value lhs = op.getLhs(); + mlir::Value rhs = op.getRhs(); + const bool lhsIsSca = !llvm::isa(lhs.getType()); + const bool rhsIsSca = !llvm::isa(rhs.getType()); + if (lhsIsSca && !rhsIsSca) { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), + rewriter.create( + op->getLoc(), + mlir::daphne::UnknownType::get(op->getContext()), // to be inferred + rhs, rewriter.create(op->getLoc(), int64_t(-1))), + lhs); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces `a * X` by `X * a` (`a` scalar, `X` matrix/frame). + * + * This is important, since our kernels for elementwise binary operations only + * support scalars as the right-hand-side operand so far (see #203). + * + * @param op + * @param rewriter + * @return + */ +mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize(mlir::daphne::EwMulOp op, PatternRewriter &rewriter) { + mlir::Value lhs = op.getLhs(); + mlir::Value rhs = op.getRhs(); + const bool lhsIsSca = !llvm::isa(lhs.getType()); + const bool rhsIsSca = !llvm::isa(rhs.getType()); + if (lhsIsSca && !rhsIsSca) { + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), rhs, lhs); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces `a / X` by `(X ^ -1) * a` (`a` scalar, `X` matrix/frame), + * if `X` has a floating-point value type. + * + * This is important, since our kernels for elementwise binary operations only + * support scalars as the right-hand-side operand so far (see #203). + * + * As a downside, an additional operation and intermediate result is introduced. + * + * @param op + * @param rewriter + * @return + */ +mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize(mlir::daphne::EwDivOp op, PatternRewriter &rewriter) { + mlir::Value lhs = op.getLhs(); + mlir::Value rhs = op.getRhs(); + const bool lhsIsSca = !llvm::isa(lhs.getType()); + const bool rhsIsSca = !llvm::isa(rhs.getType()); + const bool rhsIsFP = llvm::isa(CompilerUtils::getValueType(rhs.getType())); + if (lhsIsSca && !rhsIsSca && rhsIsFP) { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), + rewriter.create(op->getLoc(), + mlir::daphne::UnknownType::get(op->getContext()), // to be inferred + rhs, + rewriter.create(op->getLoc(), double(-1))), + lhs); + return mlir::success(); + } + return mlir::failure(); +} + +/** + * @brief Replaces a `DistributeOp` by a `DistributedReadOp`, if its input + * value (a) is defined by a `ReadOp`, and (b) is not used elsewhere. + * @param context + */ +struct SimplifyDistributeRead : public mlir::OpRewritePattern { + SimplifyDistributeRead(mlir::MLIRContext *context) : OpRewritePattern(context, 1) { + // + } + + mlir::LogicalResult matchAndRewrite(mlir::daphne::DistributeOp op, mlir::PatternRewriter &rewriter) const override { + mlir::daphne::ReadOp readOp = op.getMat().getDefiningOp(); + if (!readOp || !readOp.getOperation()->hasOneUse()) + return mlir::failure(); + rewriter.replaceOp(op, {rewriter.create(readOp.getLoc(), op.getType(), + readOp.getFileName())}); + // TODO Instead of erasing the ReadOp here, the compiler should + // generally remove unused SSA values. Then, we might even drop the + // hasOneUse requirement above. 
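    // Editorial sketch (not part of the patch): in effect, this pattern turns
    //   distribute(read("data.csv"))
    // into a single
    //   distributedRead("data.csv")
    // so the data is never fully materialized locally before being distributed.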
+ rewriter.eraseOp(readOp); + return mlir::success(); + } +}; + +void mlir::daphne::DistributeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { + results.add(context); +} + +mlir::LogicalResult mlir::daphne::CondOp::canonicalize(mlir::daphne::CondOp op, mlir::PatternRewriter &rewriter) { + mlir::Value cond = op.getCond(); + if (llvm::isa(cond.getType())) + // If the condition is not a scalar, we cannot rewrite the operation + // here. + return mlir::failure(); + else { + // If the condition is a scalar, we rewrite the operation to an + // if-then-else construct using the SCF dialect. + // TODO Check if it is really a scalar. + + mlir::Location loc = op.getLoc(); + + // Ensure that the condition is a boolean. + if (!cond.getType().isSignlessInteger(1)) + cond = rewriter.create(loc, rewriter.getI1Type(), cond); + + mlir::Block thenBlock; + mlir::Block elseBlock; + mlir::Value thenVal = op.getThenVal(); + mlir::Value elseVal = op.getElseVal(); + + // Get rid of frame column labels, since they interfere with the type + // comparison (see #485). + if (auto thenFrmTy = thenVal.getType().dyn_cast()) + if (thenFrmTy.getLabels() != nullptr) + thenVal = rewriter.create(loc, thenFrmTy.withLabels(nullptr), thenVal); + if (auto elseFrmTy = elseVal.getType().dyn_cast()) + if (elseFrmTy.getLabels() != nullptr) + elseVal = rewriter.create(loc, elseFrmTy.withLabels(nullptr), elseVal); + + // Check if the types of the then-value and the else-value are the same. + if (thenVal.getType() != elseVal.getType()) { + if (llvm::isa(thenVal.getType()) || llvm::isa(elseVal.getType())) + // If one of them is unknown, we abort the rewrite (but this is + // not an error). The type may become known later, this rewrite + // will be triggered again. + return mlir::failure(); + else + // If both types are known, but different, this is an error. + // TODO We could try to cast the types. + throw ErrorHandler::compilerError(op, "CanonicalizerPass (mlir::daphne::CondOp)", + "the then/else-values of CondOp must have the same value " + "type"); + } + + { + // Save the insertion point (automatically restored at the end of + // the block). + PatternRewriter::InsertionGuard insertGuard(rewriter); + + // TODO The current implementation only makes sure that the correct + // value is returned, but the operations calculating the + // then/else-values are still outside the if-then-else and will + // always both be executed (unless, e.g., the entire branching can + // be elimitated). This could be good (e.g., if the then/else-values + // have common subexpressions with other code) or bad (e.g., if they + // are expensive to compute). See #486. + + // Create yield-operations in both branches. + rewriter.setInsertionPointToEnd(&thenBlock); + rewriter.create(loc, thenVal); + rewriter.setInsertionPointToEnd(&elseBlock); + rewriter.create(loc, elseVal); + } + + // Helper functions to move the operations in the two blocks created + // above into the actual branches of the if-operation. + auto insertThenBlockDo = [&](mlir::OpBuilder &nested, mlir::Location loc) { + nested.getBlock()->getOperations().splice(nested.getBlock()->end(), thenBlock.getOperations()); + }; + auto insertElseBlockDo = [&](mlir::OpBuilder &nested, mlir::Location loc) { + nested.getBlock()->getOperations().splice(nested.getBlock()->end(), elseBlock.getOperations()); + }; + + // Replace the daphne::CondOp by an scf::IfOp. 
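    // Editorial sketch of the IR produced by the replacement below (the printed
    // form is illustrative only):
    //   %res = scf.if %cond -> (T) {
    //     scf.yield %thenVal : T
    //   } else {
    //     scf.yield %elseVal : T
    //   }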
+ rewriter.replaceOpWithNewOp(op, cond, insertThenBlockDo, insertElseBlockDo); + + return mlir::success(); + } +} + +mlir::LogicalResult mlir::daphne::ConvertDenseMatrixToMemRef::canonicalize(mlir::daphne::ConvertDenseMatrixToMemRef op, + mlir::PatternRewriter &rewriter) { + // removes unnecessary conversions of MemRef -> DM -> MemRef + mlir::Operation *dmNode = op->getOperand(0).getDefiningOp(); + + if (!llvm::isa(dmNode)) + return failure(); + + mlir::Operation *originalMemRefOp = dmNode->getPrevNode()->getOperand(0).getDefiningOp(); + op.replaceAllUsesWith(originalMemRefOp); + + rewriter.eraseOp(op); + if (dmNode->getUsers().empty()) + rewriter.eraseOp(dmNode); + + return mlir::success(); +} + +mlir::LogicalResult mlir::daphne::ConvertMemRefToDenseMatrix::canonicalize(mlir::daphne::ConvertMemRefToDenseMatrix op, + mlir::PatternRewriter &rewriter) { + mlir::Operation *extractPtr = op->getPrevNode(); + auto srcMemRef = extractPtr->getOperand(0).getDefiningOp(); + extractPtr->moveAfter(srcMemRef); + op->moveAfter(extractPtr); + + return mlir::success(); +} + +mlir::LogicalResult mlir::daphne::RenameOp::canonicalize(mlir::daphne::RenameOp op, mlir::PatternRewriter &rewriter) { + // Replace the RenameOp by its argument, since we only need + // this operation during DaphneDSL parsing. + rewriter.replaceOp(op, op.getArg()); + return mlir::success(); +} + +/** + * @brief Replaces `--a` by `a` (`a` scalar). + * + * @param op + * @param rewriter + * @return + */ +mlir::LogicalResult mlir::daphne::EwMinusOp::canonicalize(mlir::daphne::EwMinusOp op, PatternRewriter &rewriter) { + if (auto innerOp = op.getOperand().getDefiningOp()) { + rewriter.replaceOp(op, innerOp.getOperand()); + return mlir::success(); + } + return mlir::failure(); +} diff --git a/src/ir/daphneir/Daphne.h b/src/ir/daphneir/Daphne.h index 73a2e6b23..f960c4e3c 100644 --- a/src/ir/daphneir/Daphne.h +++ b/src/ir/daphneir/Daphne.h @@ -20,11 +20,13 @@ // The following includes are required by... #include "llvm/ADT/StringRef.h" -// TODO Get rid of this workaround by removing the pragmas and the include within +// TODO Get rid of this workaround by removing the pragmas and the include +// within // (note that this header is also included transitively by FuncOps.h), // once the problem is fixed in MLIR/LLVM. // As of MLIR llvm/llvm-project@20d454c79bbca7822eee88d188afb7a8747dac58, -// AttrTypeSubElements.h yields the following warnings, which are hereby ignored: +// AttrTypeSubElements.h yields the following warnings, which are hereby +// ignored: // - "... parameter 'derived' set but not used [-Wunused-but-set-parameter]" // - "... parameter 'walkAttrsFn' set but not used [-Wunused-but-set-parameter]" // - "... 
parameter 'walkTypesFn' set but not used [-Wunused-but-set-parameter]" @@ -33,10 +35,8 @@ #include "mlir/IR/AttrTypeSubElements.h" #pragma GCC diagnostic pop -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -44,16 +44,19 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" // TODO Get rid of this workaround by removing the pragmas, // once the problem is fixed in MLIR/LLVM. // As of MLIR llvm/llvm-project@20d454c79bbca7822eee88d188afb7a8747dac58, // PatternMatch.h yields the following warning, which is hereby ignored: -// - "... typedef 'using FnTraitsT = struct llvm::function_traits' locally defined but not used [-Wunused-local-typedefs]" +// - "... typedef 'using FnTraitsT = struct llvm::function_traits' +// locally defined but not used [-Wunused-local-typedefs]" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-local-typedefs" #include "mlir/IR/PatternMatch.h" @@ -62,12 +65,12 @@ #include "mlir/Support/TypeID.h" #include -#include #include #include -#include #include +#include #include +#include #include #include @@ -75,29 +78,27 @@ #include namespace mlir::OpTrait { - template - class FPGAOPENCLSupport : public TraitBase { - }; -} +template class FPGAOPENCLSupport : public TraitBase {}; +} // namespace mlir::OpTrait namespace mlir::daphne { - enum class MatrixRepresentation { - Dense = 0, - // default is dense - Default = MatrixRepresentation::Dense, - Sparse = 1, - }; +enum class MatrixRepresentation { + Dense = 0, + // default is dense + Default = MatrixRepresentation::Dense, + Sparse = 1, +}; - std::string matrixRepresentationToString(MatrixRepresentation rep); +std::string matrixRepresentationToString(MatrixRepresentation rep); - MatrixRepresentation stringToMatrixRepresentation(const std::string &str); -} +MatrixRepresentation stringToMatrixRepresentation(const std::string &str); +} // namespace mlir::daphne // ... the following tablegen'erated headers. 
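
A round-trip sketch for the representation helpers declared above (illustration only; the string forms "dense" and "sparse" match the implementations in DaphneDialect.cpp later in this diff):

    #include <cassert>

    // matrixRepresentationToString and stringToMatrixRepresentation are
    // inverses on the supported representations.
    void representationRoundTrip() {
        using mlir::daphne::MatrixRepresentation;
        assert(mlir::daphne::matrixRepresentationToString(MatrixRepresentation::Dense) == "dense");
        assert(mlir::daphne::stringToMatrixRepresentation("sparse") == MatrixRepresentation::Sparse);
    }
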
#define GET_TYPEDEF_CLASSES -#include #include "ir/daphneir/DaphneOpsDialect.h.inc" +#include #define GET_OP_CLASSES #include "ir/daphneir/DaphneOps.h.inc" -#endif //SRC_IR_DAPHNEIR_DAPHNE_H +#endif // SRC_IR_DAPHNEIR_DAPHNE_H diff --git a/src/ir/daphneir/DaphneAdaptTypesToKernelsTraits.h b/src/ir/daphneir/DaphneAdaptTypesToKernelsTraits.h index 6e30169ac..01db2d091 100644 --- a/src/ir/daphneir/DaphneAdaptTypesToKernelsTraits.h +++ b/src/ir/daphneir/DaphneAdaptTypesToKernelsTraits.h @@ -19,18 +19,17 @@ namespace mlir::OpTrait { -template -class CastArgsToResType : public TraitBase {}; +template class CastArgsToResType : public TraitBase {}; -template +template class CastFirstTwoArgsToResType : public TraitBase {}; -template +template class CastArgsToResTypeRandMatrixOp : public TraitBase {}; -template +template class CastArgsToMostGeneralArgType : public TraitBase {}; -} +} // namespace mlir::OpTrait -#endif //SRC_IR_DAPHNEIR_DAPHNEADAPTTYPESTOKERNELSTRAITS_H \ No newline at end of file +#endif // SRC_IR_DAPHNEIR_DAPHNEADAPTTYPESTOKERNELSTRAITS_H \ No newline at end of file diff --git a/src/ir/daphneir/DaphneDialect.cpp b/src/ir/daphneir/DaphneDialect.cpp index 542e66810..6c543d314 100644 --- a/src/ir/daphneir/DaphneDialect.cpp +++ b/src/ir/daphneir/DaphneDialect.cpp @@ -15,8 +15,8 @@ */ #include -#include #include +#include #include @@ -31,7 +31,6 @@ #include #include -#include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -52,80 +51,69 @@ #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/ArrayRef.h" -#include #include #include +#include #include #include #include struct DaphneInlinerInterface : public mlir::DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; + using DialectInlinerInterface::DialectInlinerInterface; - bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable, - bool wouldBeCloned) const final { - return true; - } + bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable, bool wouldBeCloned) const final { + return true; + } - bool isLegalToInline(mlir::Operation *, mlir::Region *, bool, mlir::IRMapping &) const final { - return true; - } + bool isLegalToInline(mlir::Operation *, mlir::Region *, bool, mlir::IRMapping &) const final { return true; } - bool isLegalToInline(mlir::Region *, mlir::Region *, bool, mlir::IRMapping &) const final { - return true; - } + bool isLegalToInline(mlir::Region *, mlir::Region *, bool, mlir::IRMapping &) const final { return true; } - void handleTerminator(mlir::Operation *op, - mlir::ArrayRef valuesToRepl) const final { - auto returnOp = mlir::dyn_cast(op); + void handleTerminator(mlir::Operation *op, mlir::ArrayRef valuesToRepl) const final { + auto returnOp = mlir::dyn_cast(op); - // Replace the values directly with the return operands. - if (returnOp.getNumOperands() != valuesToRepl.size()) { - throw ErrorHandler::compilerError(op, "DaphneInlinerInterface (handleTerminator)", - "number of operands " + std::to_string(returnOp.getNumOperands()) - + " from " + op->getName().getStringRef().str() - + " do not match size " + std::to_string(valuesToRepl.size()) - ); - } + // Replace the values directly with the return operands. 
+ if (returnOp.getNumOperands() != valuesToRepl.size()) { + throw ErrorHandler::compilerError(op, "DaphneInlinerInterface (handleTerminator)", + "number of operands " + std::to_string(returnOp.getNumOperands()) + + " from " + op->getName().getStringRef().str() + + " do not match size " + std::to_string(valuesToRepl.size())); + } - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()].replaceAllUsesWith(it.value()); + } - mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder, mlir::Value input, - mlir::Type resultType, - mlir::Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); - } + mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder, mlir::Value input, mlir::Type resultType, + mlir::Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } }; -void mlir::daphne::DaphneDialect::initialize() -{ +void mlir::daphne::DaphneDialect::initialize() { addOperations< - #define GET_OP_LIST - #include - >(); +#define GET_OP_LIST +#include + >(); addTypes< - #define GET_TYPEDEF_LIST - #include - >(); +#define GET_TYPEDEF_LIST +#include + >(); addInterfaces(); } -mlir::Operation *mlir::daphne::DaphneDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - mlir::Location loc) -{ +mlir::Operation *mlir::daphne::DaphneDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, + mlir::Location loc) { return builder.create(loc, type, value); } -mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser) const -{ +mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser) const { llvm::StringRef keyword; mlir::ParseResult pr = parser.parseKeyword(&keyword); - if(mlir::failed(pr)) + if (mlir::failed(pr)) throw std::runtime_error("parsing a DaphneIR type failed"); // `Matrix` `<` (`?` | \d+) `x` (`?` | \d+) `x` \type // (`:` ( @@ -156,12 +144,11 @@ mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser return nullptr; } } - if (parser.parseXInDimensionList() || - parser.parseType(elementType) - ) { + if (parser.parseXInDimensionList() || parser.parseType(elementType)) { return nullptr; } - // additional properties (only print/read them when present, as this will probably get more and more) + // additional properties (only print/read them when present, as this + // will probably get more and more) while (succeeded(parser.parseOptionalColon())) { if (succeeded(parser.parseOptionalKeyword("sp"))) { if (sparsity != -1.0) { @@ -171,42 +158,33 @@ mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser if (parser.parseLSquare() || parser.parseFloat(sparsity) || parser.parseRSquare()) { return nullptr; } - } - else if (succeeded(parser.parseOptionalKeyword("rep"))) { + } else if (succeeded(parser.parseOptionalKeyword("rep"))) { llvm::StringRef repName; if (parser.parseLSquare() || parser.parseKeyword(&repName) || parser.parseRSquare()) { return nullptr; } representation = stringToMatrixRepresentation(repName.str()); - } - else { + } else { return nullptr; } } - if(parser.parseGreater()) { + if (parser.parseGreater()) { return nullptr; } - return MatrixType::get( - parser.getBuilder().getContext(), elementType, numRows, numCols, sparsity, representation - ); - } - else if (keyword == "Frame") { + return 
MatrixType::get(parser.getBuilder().getContext(), elementType, numRows, numCols, sparsity, + representation); + } else if (keyword == "Frame") { ssize_t numRows = -1; ssize_t numCols = -1; - if ( - parser.parseLess() || - parser.parseOptionalQuestion() || + if (parser.parseLess() || parser.parseOptionalQuestion() || // TODO Parse #rows if there was no '?'. - //parser.parseInteger(numRows) || - parser.parseKeyword("x") || - parser.parseLSquare() || - parser.parseOptionalQuestion() || + // parser.parseInteger(numRows) || + parser.parseKeyword("x") || parser.parseLSquare() || parser.parseOptionalQuestion() || // TODO Parse #cols if there was no '?'. - //parser.parseInteger(numCols) || + // parser.parseInteger(numCols) || // TODO Parse sparsity - parser.parseColon() - ) { + parser.parseColon()) { return nullptr; } std::vector cts; @@ -215,52 +193,37 @@ mlir::Type mlir::daphne::DaphneDialect::parseType(mlir::DialectAsmParser &parser if (parser.parseType(type)) return nullptr; cts.push_back(type); - } - while (succeeded(parser.parseOptionalComma())); + } while (succeeded(parser.parseOptionalComma())); if (parser.parseRSquare() || parser.parseGreater()) { return nullptr; } - return FrameType::get( - parser.getBuilder().getContext(), cts, numRows, numCols, nullptr - ); - } - else if (keyword == "Handle") { + return FrameType::get(parser.getBuilder().getContext(), cts, numRows, numCols, nullptr); + } else if (keyword == "Handle") { mlir::Type dataType; if (parser.parseLess() || parser.parseType(dataType) || parser.parseGreater()) { return nullptr; } return mlir::daphne::HandleType::get(parser.getBuilder().getContext(), dataType); - } - else if (keyword == "String") { + } else if (keyword == "String") { return StringType::get(parser.getBuilder().getContext()); - } - else if (keyword == "DaphneContext") { + } else if (keyword == "DaphneContext") { return mlir::daphne::DaphneContextType::get(parser.getBuilder().getContext()); - } - else { + } else { parser.emitError(parser.getCurrentLocation()) << "Parsing failed, keyword `" << keyword << "` not recognized!"; return nullptr; } } -std::string unknownStrIf(ssize_t val) { - return (val == -1) ? "?" : std::to_string(val); -} +std::string unknownStrIf(ssize_t val) { return (val == -1) ? "?" : std::to_string(val); } -std::string unknownStrIf(double val) { - return (val == -1.0) ? "?" : std::to_string(val); -} +std::string unknownStrIf(double val) { return (val == -1.0) ? "?" : std::to_string(val); } -void mlir::daphne::DaphneDialect::printType(mlir::Type type, - mlir::DialectAsmPrinter &os) const -{ +void mlir::daphne::DaphneDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { if (type.isa()) os << "Structure"; else if (auto t = type.dyn_cast()) { - os << "Matrix<" - << unknownStrIf(t.getNumRows()) << 'x' - << unknownStrIf(t.getNumCols()) << 'x' - << t.getElementType(); + os << "Matrix<" << unknownStrIf(t.getNumRows()) << 'x' << unknownStrIf(t.getNumCols()) << 'x' + << t.getElementType(); auto sparsity = t.getSparsity(); auto representation = t.getRepresentation(); @@ -271,41 +234,34 @@ void mlir::daphne::DaphneDialect::printType(mlir::Type type, os << ":rep[" << matrixRepresentationToString(representation) << ']'; } os << '>'; - } - else if (auto t = type.dyn_cast()) { - os << "Frame<" - << unknownStrIf(t.getNumRows()) << "x[" - << unknownStrIf(t.getNumCols()) << ": "; + } else if (auto t = type.dyn_cast()) { + os << "Frame<" << unknownStrIf(t.getNumRows()) << "x[" << unknownStrIf(t.getNumCols()) << ": "; // Column types. 
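    // Editorial examples of the textual forms printed here (sketch; optional
    // suffixes such as :rep[...] appear only when the property is known):
    //   Matrix<10x10xf64:rep[sparse]>
    //   Frame<?x[2: f64, si64], ["a", "b"]>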
std::vector cts = t.getColumnTypes(); for (size_t i = 0; i < cts.size(); i++) { os << cts[i]; - if(i < cts.size() - 1) + if (i < cts.size() - 1) os << ", "; } os << "], "; // Column labels. - std::vector * labels = t.getLabels(); - if(labels) { + std::vector *labels = t.getLabels(); + if (labels) { os << '['; for (size_t i = 0; i < labels->size(); i++) { os << '"' << (*labels)[i] << '"'; - if(i < labels->size() - 1) + if (i < labels->size() - 1) os << ", "; } os << ']'; - } - else + } else os << '?'; os << '>'; - } - else if (auto t = type.dyn_cast()) { + } else if (auto t = type.dyn_cast()) { os << "List<" << t.getElementType() << '>'; - } - else if (auto handle = type.dyn_cast()) { + } else if (auto handle = type.dyn_cast()) { os << "Handle<" << handle.getDataType() << ">"; - } - else if (isa(type)) + } else if (isa(type)) os << "String"; else if (auto t = type.dyn_cast()) os << "VariadicPack<" << t.getContainedType() << '>'; @@ -328,13 +284,12 @@ std::string mlir::daphne::matrixRepresentationToString(MatrixRepresentation rep) case MatrixRepresentation::Sparse: return "sparse"; default: - throw std::runtime_error("unknown mlir::daphne::MatrixRepresentation " + - std::to_string(static_cast(rep))); + throw std::runtime_error("unknown mlir::daphne::MatrixRepresentation " + std::to_string(static_cast(rep))); } } mlir::daphne::MatrixRepresentation mlir::daphne::stringToMatrixRepresentation(const std::string &str) { - if(str == "dense") + if (str == "dense") return MatrixRepresentation::Dense; else if (str == "sparse") return MatrixRepresentation::Sparse; @@ -343,1219 +298,110 @@ mlir::daphne::MatrixRepresentation mlir::daphne::stringToMatrixRepresentation(co } namespace mlir::daphne { - namespace detail { - struct MatrixTypeStorage : public ::mlir::TypeStorage { - // TODO: adapt epsilon for equality check (I think the only use is saving memory for the MLIR-IR representation of this type) - // the choosen epsilon directly defines how accurate our sparsity inference can be - constexpr static const double epsilon = 1e-6; - MatrixTypeStorage(::mlir::Type elementType, - ssize_t numRows, - ssize_t numCols, - double sparsity, - MatrixRepresentation representation) - : elementType(elementType), numRows(numRows), numCols(numCols), sparsity(sparsity), - representation(representation) {} - - /// The hash key is a tuple of the parameter types. - using KeyTy = std::tuple<::mlir::Type, ssize_t, ssize_t, double, MatrixRepresentation>; - bool operator==(const KeyTy &tblgenKey) const { - if(!(elementType == std::get<0>(tblgenKey))) - return false; - if(numRows != std::get<1>(tblgenKey)) - return false; - if(numCols != std::get<2>(tblgenKey)) - return false; - if(std::fabs(sparsity - std::get<3>(tblgenKey)) >= epsilon) - return false; - if(representation != std::get<4>(tblgenKey)) - return false; - return true; - } - static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { - auto float_hashable = static_cast(std::get<3>(tblgenKey) / epsilon); - return ::llvm::hash_combine(std::get<0>(tblgenKey), - std::get<1>(tblgenKey), - std::get<2>(tblgenKey), - float_hashable, - std::get<4>(tblgenKey)); - } - - /// Define a construction method for creating a new instance of this - /// storage. 
- static MatrixTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, - const KeyTy &tblgenKey) { - auto elementType = std::get<0>(tblgenKey); - auto numRows = std::get<1>(tblgenKey); - auto numCols = std::get<2>(tblgenKey); - auto sparsity = std::get<3>(tblgenKey); - auto representation = std::get<4>(tblgenKey); - - return new(allocator.allocate()) - MatrixTypeStorage(elementType, numRows, numCols, sparsity, representation); - } - ::mlir::Type elementType; - ssize_t numRows; - ssize_t numCols; - double sparsity; - MatrixRepresentation representation; - }; - } - ::mlir::Type MatrixType::getElementType() const { return getImpl()->elementType; } - ssize_t MatrixType::getNumRows() const { return getImpl()->numRows; } - ssize_t MatrixType::getNumCols() const { return getImpl()->numCols; } - double MatrixType::getSparsity() const { return getImpl()->sparsity; } - MatrixRepresentation MatrixType::getRepresentation() const { return getImpl()->representation; } -} - -mlir::OpFoldResult mlir::daphne::ConstantOp::fold(FoldAdaptor adaptor) -{ - if (!adaptor.getOperands().empty()) - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::ConstantOp::fold)", - "constant has no operands but " + std::to_string(adaptor.getOperands().size()) + " were given"); - - return getValue(); -} - -::mlir::LogicalResult mlir::daphne::MatrixType::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - Type elementType, - ssize_t numRows, ssize_t numCols, double sparsity, MatrixRepresentation rep -) -{ - if ( - ( +namespace detail { +struct MatrixTypeStorage : public ::mlir::TypeStorage { + // TODO: adapt epsilon for equality check (I think the only use is saving + // memory for the MLIR-IR representation of this type) + // the choosen epsilon directly defines how accurate our sparsity inference + // can be + constexpr static const double epsilon = 1e-6; + MatrixTypeStorage(::mlir::Type elementType, ssize_t numRows, ssize_t numCols, double sparsity, + MatrixRepresentation representation) + : elementType(elementType), numRows(numRows), numCols(numCols), sparsity(sparsity), + representation(representation) {} + + /// The hash key is a tuple of the parameter types. + using KeyTy = std::tuple<::mlir::Type, ssize_t, ssize_t, double, MatrixRepresentation>; + bool operator==(const KeyTy &tblgenKey) const { + if (!(elementType == std::get<0>(tblgenKey))) + return false; + if (numRows != std::get<1>(tblgenKey)) + return false; + if (numCols != std::get<2>(tblgenKey)) + return false; + if (std::fabs(sparsity - std::get<3>(tblgenKey)) >= epsilon) + return false; + if (representation != std::get<4>(tblgenKey)) + return false; + return true; + } + static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { + auto float_hashable = static_cast(std::get<3>(tblgenKey) / epsilon); + return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), std::get<2>(tblgenKey), + float_hashable, std::get<4>(tblgenKey)); + } + + /// Define a construction method for creating a new instance of this + /// storage. 
+ static MatrixTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { + auto elementType = std::get<0>(tblgenKey); + auto numRows = std::get<1>(tblgenKey); + auto numCols = std::get<2>(tblgenKey); + auto sparsity = std::get<3>(tblgenKey); + auto representation = std::get<4>(tblgenKey); + + return new (allocator.allocate<MatrixTypeStorage>()) + MatrixTypeStorage(elementType, numRows, numCols, sparsity, representation); + } + ::mlir::Type elementType; + ssize_t numRows; + ssize_t numCols; + double sparsity; + MatrixRepresentation representation; +}; +} // namespace detail +::mlir::Type MatrixType::getElementType() const { return getImpl()->elementType; } +ssize_t MatrixType::getNumRows() const { return getImpl()->numRows; } +ssize_t MatrixType::getNumCols() const { return getImpl()->numCols; } +double MatrixType::getSparsity() const { return getImpl()->sparsity; } +MatrixRepresentation MatrixType::getRepresentation() const { return getImpl()->representation; } +} // namespace mlir::daphne + +::mlir::LogicalResult mlir::daphne::MatrixType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + Type elementType, ssize_t numRows, ssize_t numCols, + double sparsity, MatrixRepresentation rep) { + if (( // Value type is unknown. llvm::isa<mlir::daphne::UnknownType>(elementType) // Value type is known. - || elementType.isSignedInteger(64) - || elementType.isUnsignedInteger(8) - || elementType.isUnsignedInteger(64) - || elementType.isF32() - || elementType.isF64() - || elementType.isIndex() - || elementType.isInteger(1) - || llvm::isa<mlir::daphne::StringType>(elementType) - || elementType.isUnsignedInteger(64) - || elementType.isUnsignedInteger(32) - || elementType.isSignedInteger(32) - || elementType.isSignedInteger(8) - ) && ( + || elementType.isSignedInteger(64) || elementType.isUnsignedInteger(8) || + elementType.isUnsignedInteger(64) || elementType.isF32() || elementType.isF64() || elementType.isIndex() || + elementType.isInteger(1) || llvm::isa<mlir::daphne::StringType>(elementType) || + elementType.isUnsignedInteger(64) || elementType.isUnsignedInteger(32) || elementType.isSignedInteger(32) || + elementType.isSignedInteger(8)) && + ( // Number of rows and columns are valid (-1 for unknown). - numRows >= -1 && numCols >= -1 - ) && ( - sparsity == -1 || (sparsity >= 0.0 && sparsity <= 1.0) - ) - ) + numRows >= -1 && numCols >= -1) && + (sparsity == -1 || (sparsity >= 0.0 && sparsity <= 1.0))) return mlir::success(); else return emitError() << "invalid matrix element type: " << elementType; } -::mlir::LogicalResult mlir::daphne::FrameType::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - std::vector<Type> columnTypes, - ssize_t numRows, ssize_t numCols, - std::vector<std::string> * labels -) -{ +::mlir::LogicalResult mlir::daphne::FrameType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + std::vector<Type> columnTypes, ssize_t numRows, ssize_t numCols, + std::vector<std::string> *labels) { // TODO Verify the individual column types.
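// Aside (illustration, not part of the patch): what the checks in
// MatrixType::verify above and FrameType::verify below accept, with -1
// encoding "unknown" throughout:
//     numRows=10, numCols=-1, sparsity=0.3 -> valid matrix (unknown #cols)
//     numRows=10, numCols=5, sparsity=1.5  -> invalid (sparsity outside [0, 1])
//     numCols=3 but columnTypes.size()==2  -> invalid frame (size mismatch)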
- if(numRows < -1 || numCols < -1) + if (numRows < -1 || numCols < -1) return mlir::failure(); - if(numCols != -1) { + if (numCols != -1) { // ToDo: ExtractColOp does not provide these columnTypes - if(!columnTypes.empty()) { + if (!columnTypes.empty()) { if (static_cast(columnTypes.size()) != numCols) return mlir::failure(); if (labels && static_cast(labels->size()) != numCols) return mlir::failure(); } } - if(labels && labels->size() != columnTypes.size()) + if (labels && labels->size() != columnTypes.size()) return mlir::failure(); return mlir::success(); } ::mlir::LogicalResult mlir::daphne::HandleType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - Type dataType) -{ + Type dataType) { if (llvm::isa(dataType)) { return mlir::success(); - } - else + } else return emitError() << "only matrix type is supported for handle atm, got: " << dataType; } - -mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphne::VectorizedPipelineOp op, - mlir::PatternRewriter &rewriter) -{ - // // Find duplicate inputs - std::vector vSplitsAttrs; - for (auto & split : op.getSplits()) - vSplitsAttrs.push_back(split); - auto currentSize = op.getInputs().size(); - - DenseMap inputMap; - - for (size_t i = 0; i < currentSize; i++) { - const auto& input = op.getInputs()[i]; - const auto& split = op.getSplits()[i].cast().getValue(); - - if (inputMap.count(input) == 0) { - inputMap[input] = i; - } else { - size_t j = inputMap[input]; - if (op.getSplits()[j].cast().getValue() == split) { - op.getBody().getArgument(i).replaceAllUsesWith(op.getBody().getArgument(j)); - op.getBody().eraseArgument(i); - op.getInputsMutable().erase(i); - vSplitsAttrs.erase(vSplitsAttrs.begin() + i); - currentSize--; - i--; - } - } - } - - std::vector resultsToReplace; - std::vector outRows; - std::vector outCols; - std::vector vCombineAttrs; - - llvm::BitVector eraseIxs; - eraseIxs.resize(op.getNumResults()); - for(auto result : op.getResults()) { - auto resultIx = result.getResultNumber(); - if(result.use_empty()) { - // remove - eraseIxs.set(resultIx); - } - else { - resultsToReplace.push_back(result); - outRows.push_back(op.getOutRows()[resultIx]); - outCols.push_back(op.getOutCols()[resultIx]); - vCombineAttrs.push_back(op.getCombines()[resultIx]); - } - } - op.getBody().front().getTerminator()->eraseOperands(eraseIxs); - if(!op.getCuda().getBlocks().empty()) - op.getCuda().front().getTerminator()->eraseOperands(eraseIxs); - - if(resultsToReplace.size() == op->getNumResults() && op.getSplits().size() == vSplitsAttrs.size()) { - return failure(); - } - auto pipelineOp = rewriter.create(op.getLoc(), - ValueRange(resultsToReplace).getTypes(), - op.getInputs(), - outRows, - outCols, - rewriter.getArrayAttr(vSplitsAttrs), - rewriter.getArrayAttr(vCombineAttrs), - op.getCtx()); - pipelineOp.getBody().takeBody(op.getBody()); - if(!op.getCuda().getBlocks().empty()) - pipelineOp.getCuda().takeBody(op.getCuda()); - for (auto e : llvm::enumerate(resultsToReplace)) { - auto resultToReplace = e.value(); - auto i = e.index(); - resultToReplace.replaceAllUsesWith(pipelineOp.getResult(i)); - } - op.erase(); - return success(); -} - -// **************************************************************************** -// Fold utility functions/macros -// **************************************************************************** -// For families of operations. 
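// Aside (illustration, not part of the patch): the fold utilities below all
// follow one pattern -- bail out unless every operand is a compile-time
// constant, then apply a small lambda to the constant values. A simplified,
// MLIR-free sketch of that pattern (all names here are hypothetical):
#include <functional>
#include <optional>
template <class T>
std::optional<T> constFoldBinary(const std::optional<T> &lhs, const std::optional<T> &rhs,
                                 const std::function<T(T, T)> &calculate) {
    if (!lhs || !rhs) // at least one operand is not a constant -> nothing to fold
        return std::nullopt;
    return calculate(*lhs, *rhs);
}
// Usage: constFoldBinary<int>({2}, {3}, [](int a, int b) { return a + b; }) yields 5.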
- -// Adapted from "mlir/Dialect/CommonFolders.h" -mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc); - -template< - class ArgAttrElementT, - class ResAttrElementT = ArgAttrElementT, - class ArgElementValueT = typename ArgAttrElementT::ValueType, - class ResElementValueT = typename ResAttrElementT::ValueType, - class CalculationT = std::function -> -mlir::Attribute constFoldBinaryOp(mlir::Location loc, mlir::Type resultType, llvm::ArrayRef operands, - const CalculationT &calculate) { - if (operands.size() != 2) - throw ErrorHandler::compilerError(loc, - "CanonicalizerPass (constFoldBinaryOp)", - "binary op takes two operands but " + std::to_string(operands.size()) + " were given"); - - if(!operands[0] || !operands[1]) - return {}; - - if(llvm::isa(operands[0]) && llvm::isa(operands[1])) { - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - // We need dedicated cases, as the parameters of ResAttrElementT::get() depend on ResAttrElementT. - if constexpr( - std::is_same::value || - std::is_same::value - ) { - mlir::Type l = lhs.getType(); - mlir::Type r = rhs.getType(); - if ((l.dyn_cast() || l.dyn_cast()) && - (r.dyn_cast() || r.dyn_cast())) { - auto lhsBitWidth = lhs.getType().getIntOrFloatBitWidth(); - auto rhsBitWidth = rhs.getType().getIntOrFloatBitWidth(); - - if (lhsBitWidth < rhsBitWidth) { - mlir::Attribute promotedLhs = performCast(lhs, rhs.getType(), loc); - lhs = promotedLhs.cast(); - } else if (rhsBitWidth < lhsBitWidth) { - mlir::Attribute promotedRhs = performCast(rhs, lhs.getType(), loc); - rhs = promotedRhs.cast(); - } - } - return ResAttrElementT::get(resultType, calculate(lhs.getValue(), rhs.getValue())); - } - else if constexpr(std::is_same::value) { - if(!resultType.isSignlessInteger(1)) - throw ErrorHandler::compilerError( - loc, "CanonicalizerPass (constFoldBinaryOp)", "expected boolean result type" - ); - return ResAttrElementT::get(lhs.getContext(), calculate(lhs.getValue(), rhs.getValue())); - } - else if constexpr(std::is_same::value) { - if(!resultType.isa()) - throw ErrorHandler::compilerError( - loc, "CanonicalizerPass (constFoldBinaryOp)", "expected string result type" - ); - return ResAttrElementT::get(calculate(lhs.getValue(), rhs.getValue()), resultType); - } - } - return {}; -} -template> -mlir::Attribute constFoldUnaryOp(mlir::Location loc, mlir::Type resultType, llvm::ArrayRef operands, - const CalculationT &calculate) { - if (operands.size() != 1) - throw ErrorHandler::compilerError(loc, - "CanonicalizerPass (constFoldUnaryOp)", - "unary op takes one operand but " + std::to_string(operands.size()) + " were given"); - - if (!operands[0]) - return {}; - - if (llvm::isa(operands[0])) { - auto operand = operands[0].cast(); - - return AttrElementT::get(resultType, calculate(operand.getValue())); - } - return {}; -} - -// **************************************************************************** -// Fold implementations -// **************************************************************************** -mlir::Attribute performCast(mlir::Attribute attr, mlir::Type targetType, mlir::Location loc) { - if (auto intAttr = attr.dyn_cast()) { - auto apInt = intAttr.getValue(); - - if (auto outTy = targetType.dyn_cast()) { - // Extend or truncate the integer value based on the target type - if (outTy.isUnsignedInteger()) { - apInt = apInt.zextOrTrunc(outTy.getWidth()); - } else if (outTy.isSignedInteger()) { - apInt = (intAttr.getType().isSignedInteger()) - ? 
apInt.sextOrTrunc(outTy.getWidth()) - : apInt.zextOrTrunc(outTy.getWidth()); - } - return mlir::IntegerAttr::getChecked(loc, outTy, apInt); - } - - if (auto outTy = targetType.dyn_cast()) { - return mlir::IntegerAttr::getChecked(loc, outTy, apInt); - } - - if (targetType.isF64()) { - if (intAttr.getType().isSignedInteger()) { - return mlir::FloatAttr::getChecked(loc, targetType, - llvm::APIntOps::RoundSignedAPIntToDouble(apInt)); - } - if (intAttr.getType().isUnsignedInteger() || intAttr.getType().isIndex()) { - return mlir::FloatAttr::getChecked(loc, targetType, - llvm::APIntOps::RoundAPIntToDouble(apInt)); - } - } - - if (targetType.isF32()) { - if (intAttr.getType().isSignedInteger()) { - return mlir::FloatAttr::getChecked(loc, targetType, - llvm::APIntOps::RoundSignedAPIntToFloat(apInt)); - } - if (intAttr.getType().isUnsignedInteger()) { - return mlir::FloatAttr::get(targetType, - llvm::APIntOps::RoundAPIntToFloat(apInt)); - } - } - } - else if (auto floatAttr = attr.dyn_cast()) { - auto val = floatAttr.getValueAsDouble(); - - if (targetType.isF64()) { - return mlir::FloatAttr::getChecked(loc, targetType, val); - } - if (targetType.isF32()) { - return mlir::FloatAttr::getChecked(loc, targetType, static_cast(val)); - } - if (targetType.isIntOrIndex()) { - auto num = static_cast(val); - return mlir::IntegerAttr::getChecked(loc, targetType, num); - } - } - - // If casting is not possible, return the original attribute - return {}; -} - -mlir::OpFoldResult mlir::daphne::CastOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - - if (isTrivialCast()) { - if (operands[0]) - return {operands[0]}; - else - return {getArg()}; - } - - if (operands[0]) { - if (auto castedAttr = performCast(operands[0], getType(), getLoc())) { - return castedAttr; - } - } - - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwAddOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a + b; }; - // TODO: we could check overflows - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a + b; }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwSubOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a - b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a - b; }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwMulOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a * b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a * b; }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwDivOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a / b; }; - auto sintOp = [&](const llvm::APInt &a, const 
llvm::APInt &b) { - if(b == 0) { - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::EwDivOp::fold)", - "Can't divide by 0"); - } - return a.sdiv(b); - }; - auto uintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(b == 0) { - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::EwDivOp::fold)", - "Can't divide by 0"); - } - return a.udiv(b); - }; - - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwMinusOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto intOp = [](const llvm::APInt &a) { return -a; }; - auto floatOp = [](const llvm::APFloat &a) { return -a; }; - - if (auto res = constFoldUnaryOp(getLoc(), getType(), operands, intOp)) - return res; - if (auto res = constFoldUnaryOp(getLoc(), getType(), operands, floatOp)) - return res; - - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwPowOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - // TODO: EwPowOp integer constant folding - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { - return std::pow(a.convertToDouble(), b.convertToDouble()); - }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwModOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto sintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(b == 0) { - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::EwModOp::fold)", - "Can't compute mod 0"); - } - return a.srem(b); - }; - auto uintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(b == 0) { - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::EwModOp::fold)", - "Can't compute mod 0"); - } - return a.urem(b); - }; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwLogOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { - // Compute the element-wise logarithm of a to the base b - // Equivalent to log_b(a) - return log(a.convertToDouble()) / log(b.convertToDouble()); - }; - if (auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwMinOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return llvm::minimum(a, b); }; - auto sintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(a.slt(b)) - return a; - else - return b; - }; - auto uintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(a.ult(b)) - return a; - else - return b; - }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto 
res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwMaxOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return llvm::maximum(a, b); }; - auto sintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(a.sgt(b)) - return a; - else - return b; - }; - auto uintOp = [&](const llvm::APInt &a, const llvm::APInt &b) { - if(a.ugt(b)) - return a; - else - return b; - }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwAndOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto boolOp = [](const bool &a, const bool &b) { return a && b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return (a != 0) && (b != 0); }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, boolOp)) - return res; - // TODO: should output bool? - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwBitwiseAndOp::fold(FoldAdaptor adaptor) { - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwOrOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto boolOp = [](const bool &a, const bool &b) { return a || b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return (a != 0) || (b != 0); }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, boolOp)) - return res; - // TODO: should output bool - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwXorOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto boolOp = [](const bool &a, const bool &b) { return a ^ b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return (a != 0) ^ (b != 0); }; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, boolOp)) - return res; - // TODO: should output bool - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwConcatOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - - if (operands.size() != 2) - throw ErrorHandler::compilerError( - this->getLoc(), "CanonicalizerPass (mlir::daphne::EwConcatOp::fold)", - "binary op takes two operands but " + std::to_string(operands.size()) + " were given"); - - if(!operands[0] || !operands[1]) - return {}; - - if(llvm::isa(operands[0]) && isa(operands[1])) { - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - auto concated = lhs.getValue().str() + rhs.getValue().str(); - return StringAttr::get(concated, getType()); - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwEqOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a == b; }; - auto intOp = [](const llvm::APInt 
&a, const llvm::APInt &b) { return a == b; }; - auto strOp = [](const llvm::StringRef &a, const llvm::StringRef &b) { return a == b; }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), IntegerType::get(getContext(), 64, IntegerType::SignednessSemantics::Signed), operands, strOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwNeqOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a != b; }; - auto intOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a != b; }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, intOp)) - return res; - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwLtOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a < b; }; - auto sintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.slt(b); }; - auto uintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.ult(b); }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwLeOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a <= b; }; - auto sintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.sle(b); }; - auto uintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.ule(b); }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwGtOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a > b; }; - auto sintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.sgt(b); }; - auto uintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.ugt(b); }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -mlir::OpFoldResult mlir::daphne::EwGeOp::fold(FoldAdaptor adaptor) { - ArrayRef operands = adaptor.getOperands(); - auto floatOp = [](const llvm::APFloat &a, const llvm::APFloat &b) { return a >= b; }; - auto sintOp = [](const 
llvm::APInt &a, const llvm::APInt &b) { return a.sge(b); }; - auto uintOp = [](const llvm::APInt &a, const llvm::APInt &b) { return a.uge(b); }; - // TODO: fix bool return - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, floatOp)) - return res; - if(getType().isSignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, sintOp)) - return res; - } - else if(getType().isUnsignedInteger()) { - if(auto res = constFoldBinaryOp(getLoc(), getType(), operands, uintOp)) - return res; - } - return {}; -} - -/** - * @brief Transposition-aware matrix multiplication - * Identifies if an input to a MatMulOp is the result of a TransposeOp; Rewrites the Operation, - * passing transposition info as a flag, instead of transposing the matrix before multiplication - */ -mlir::LogicalResult mlir::daphne::MatMulOp::canonicalize( - mlir::daphne::MatMulOp op, PatternRewriter &rewriter -) { - mlir::Value lhs = op.getLhs(); - mlir::Value rhs = op.getRhs(); - mlir::Value transa = op.getTransa(); - mlir::Value transb = op.getTransb(); - - // TODO If transa or transb are not constant, we cannot continue on the respective side; - // we cannot just assume false then. - bool ta = CompilerUtils::constantOrDefault(transa, false); - bool tb = CompilerUtils::constantOrDefault(transb, false); - - // TODO Turn on the transposition-awareness for the left-hand-side argument again (see #447). - // mlir::daphne::TransposeOp lhsTransposeOp = lhs.getDefiningOp(); - mlir::daphne::TransposeOp rhsTransposeOp = rhs.getDefiningOp(); - - //if (!lhsTransposeOp && !rhsTransposeOp){ - if (!rhsTransposeOp){ - return mlir::failure(); - } - - // ToDo: This check prevents merging transpose into matrix multiplication because that is not yet supported by our - // sparse kernels. - // ToDo: bring user config here for sparsity threshold or properly use MatrixRepresentation - if(auto t = rhs.getType().dyn_cast()) { - auto sparsity = t.getSparsity(); - if(sparsity < 0.25) - return mlir::failure(); - } - -#if 0 - // TODO Adapt PhyOperatorSelectionPass once this code is turned on again. - if(lhsTransposeOp) { - lhs = lhsTransposeOp.getArg(); - ta = !ta; - } -#endif - if(rhsTransposeOp) { - rhs = rhsTransposeOp.getArg(); - tb = !tb; - } - - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, rhs, - static_cast(rewriter.create(transa.getLoc(), ta)), - static_cast(rewriter.create(transb.getLoc(), tb)) - ); - return mlir::success(); -} - -/** - * @brief Replaces NumRowsOp by a constant, if the #rows of the input is known - * (e.g., due to shape inference). - */ -mlir::LogicalResult mlir::daphne::NumRowsOp::canonicalize( - mlir::daphne::NumRowsOp op, PatternRewriter &rewriter -) { - ssize_t numRows = -1; - - mlir::Type inTy = op.getArg().getType(); - if(auto t = inTy.dyn_cast()) - numRows = t.getNumRows(); - else if(auto t = inTy.dyn_cast()) - numRows = t.getNumRows(); - - if(numRows != -1) { - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), rewriter.getIndexAttr(numRows) - ); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces NumColsOp by a constant, if the #cols of the input is known - * (e.g., due to shape inference). 
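 *
 * Aside (illustration, not part of the patch): e.g., if shape inference has
 * annotated X as a 3x5 matrix, nrow(X) and ncol(X) fold to the constants 3
 * and 5. The transposition-aware rewrite further up similarly turns
 * matMul(X, t(Y), false, false) into matMul(X, Y, false, true), fusing the
 * transposition into the kernel call instead of materializing t(Y).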
- */ -mlir::LogicalResult mlir::daphne::NumColsOp::canonicalize( - mlir::daphne::NumColsOp op, PatternRewriter &rewriter -) { - ssize_t numCols = -1; - - mlir::Type inTy = op.getArg().getType(); - if(auto t = inTy.dyn_cast()) - numCols = t.getNumCols(); - else if(auto t = inTy.dyn_cast()) - numCols = t.getNumCols(); - - if(numCols != -1) { - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), rewriter.getIndexAttr(numCols) - ); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces NumCellsOp by a constant, if the #rows and #cols of the - * input is known (e.g., due to shape inference). - */ -mlir::LogicalResult mlir::daphne::NumCellsOp::canonicalize( - mlir::daphne::NumCellsOp op, PatternRewriter &rewriter -) { - ssize_t numRows = -1; - ssize_t numCols = -1; - - mlir::Type inTy = op.getArg().getType(); - if(auto t = inTy.dyn_cast()) { - numRows = t.getNumRows(); - numCols = t.getNumCols(); - } - else if(auto t = inTy.dyn_cast()) { - numRows = t.getNumRows(); - numCols = t.getNumCols(); - } - - if(numRows != -1 && numCols != -1) { - rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), rewriter.getIndexAttr(numRows * numCols) - ); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces SparsityOp by a constant, if the sparsity of the input is known - * (e.g., due to sparsity inference). - */ -mlir::LogicalResult mlir::daphne::SparsityOp::canonicalize( - mlir::daphne::SparsityOp op, PatternRewriter &rewriter -) { - double sparsity = -1.0; - - mlir::Type inTy = op.getArg().getType(); - if(auto t = inTy.dyn_cast()) - sparsity = t.getSparsity(); - - if(sparsity != -1) { - rewriter.replaceOpWithNewOp( - op, sparsity - ); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces a `DistributeOp` by a `DistributedReadOp`, if its input - * value (a) is defined by a `ReadOp`, and (b) is not used elsewhere. - * @param context - */ -struct SimplifyDistributeRead : public mlir::OpRewritePattern { - SimplifyDistributeRead(mlir::MLIRContext *context) - : OpRewritePattern(context, 1) { - // - } - - mlir::LogicalResult - matchAndRewrite( - mlir::daphne::DistributeOp op, mlir::PatternRewriter &rewriter - ) const override { - mlir::daphne::ReadOp readOp = op.getMat().getDefiningOp(); - if(!readOp || !readOp.getOperation()->hasOneUse()) - return mlir::failure(); - rewriter.replaceOp( - op, {rewriter.create( - readOp.getLoc(), op.getType(), readOp.getFileName() - )} - ); - // TODO Instead of erasing the ReadOp here, the compiler should - // generally remove unused SSA values. Then, we might even drop the - // hasOneUse requirement above. - rewriter.eraseOp(readOp); - return mlir::success(); - } -}; - -/** - * @brief Replaces (1) `a + b` by `a concat b`, if `a` or `b` is a string, - * and (2) `a + X` by `X + a` (`a` scalar, `X` matrix/frame). - * - * (1) is important, since we use the `+`-operator for both addition and - * string concatenation in DaphneDSL, while the types of the operands might be - * known only after type inference. - * - * (2) is important, since our kernels for elementwise binary operations only support - * scalars as the right-hand-side operand so far (see #203). 
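 *
 * Examples (illustration, not part of the patch):
 *   (1) "ab" + 1 -> concat("ab", cast-to-string(1)), once types are known;
 *   (2) 5 + X    -> X + 5, for a scalar 5 and a matrix X.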
- * - * @param op - * @param rewriter - * @return - */ -mlir::LogicalResult mlir::daphne::EwAddOp::canonicalize( - mlir::daphne::EwAddOp op, PatternRewriter &rewriter -) { - mlir::Value lhs = op.getLhs(); - mlir::Value rhs = op.getRhs(); - - const bool lhsIsStr = llvm::isa(lhs.getType()); - const bool rhsIsStr = llvm::isa(rhs.getType()); - if(lhsIsStr || rhsIsStr) { - mlir::Type strTy = mlir::daphne::StringType::get(rewriter.getContext()); - if(!lhsIsStr) - lhs = rewriter.create(op.getLoc(), strTy, lhs); - if(!rhsIsStr) - rhs = rewriter.create(op.getLoc(), strTy, rhs); - rewriter.replaceOpWithNewOp(op, strTy, lhs, rhs); - return mlir::success(); - } - else { - const bool lhsIsSca = !llvm::isa(lhs.getType()); - const bool rhsIsSca = !llvm::isa(rhs.getType()); - if(lhsIsSca && !rhsIsSca) { - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), rhs, lhs); - return mlir::success(); - } - return mlir::failure(); - } -} - -/** - * @brief Replaces `a - X` by `(X * -1) + a` (`a` scalar, `X` matrix/frame). - * - * This is important, since our kernels for elementwise binary operations only support - * scalars as the right-hand-side operand so far (see #203). - * - * As a downside, an additional operation and intermediate result is introduced. - * - * @param op - * @param rewriter - * @return - */ -mlir::LogicalResult mlir::daphne::EwSubOp::canonicalize( - mlir::daphne::EwSubOp op, PatternRewriter &rewriter -) { - mlir::Value lhs = op.getLhs(); - mlir::Value rhs = op.getRhs(); - const bool lhsIsSca = !llvm::isa(lhs.getType()); - const bool rhsIsSca = !llvm::isa(rhs.getType()); - if(lhsIsSca && !rhsIsSca) { - rewriter.replaceOpWithNewOp( - op, - op.getResult().getType(), - rewriter.create( - op->getLoc(), - mlir::daphne::UnknownType::get(op->getContext()), // to be inferred - rhs, - rewriter.create(op->getLoc(), int64_t(-1)) - ), - lhs - ); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces `a * X` by `X * a` (`a` scalar, `X` matrix/frame). - * - * This is important, since our kernels for elementwise binary operations only support - * scalars as the right-hand-side operand so far (see #203). - * - * @param op - * @param rewriter - * @return - */ -mlir::LogicalResult mlir::daphne::EwMulOp::canonicalize( - mlir::daphne::EwMulOp op, PatternRewriter &rewriter -) { - mlir::Value lhs = op.getLhs(); - mlir::Value rhs = op.getRhs(); - const bool lhsIsSca = !llvm::isa(lhs.getType()); - const bool rhsIsSca = !llvm::isa(rhs.getType()); - if(lhsIsSca && !rhsIsSca) { - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), rhs, lhs); - return mlir::success(); - } - return mlir::failure(); -} - -/** - * @brief Replaces `a / X` by `(X ^ -1) * a` (`a` scalar, `X` matrix/frame), - * if `X` has a floating-point value type. - * - * This is important, since our kernels for elementwise binary operations only support - * scalars as the right-hand-side operand so far (see #203). - * - * As a downside, an additional operation and intermediate result is introduced. 
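 *
 * Example (illustration, not part of the patch): for a scalar a = 2.0 and a
 * floating-point matrix X, 2.0 / X becomes (X ^ -1) * 2.0, trading one extra
 * elementwise power for the scalar-rhs multiplication kernel.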
- * - * @param op - * @param rewriter - * @return - */ -mlir::LogicalResult mlir::daphne::EwDivOp::canonicalize( - mlir::daphne::EwDivOp op, PatternRewriter &rewriter -) { - mlir::Value lhs = op.getLhs(); - mlir::Value rhs = op.getRhs(); - const bool lhsIsSca = !llvm::isa(lhs.getType()); - const bool rhsIsSca = !llvm::isa(rhs.getType()); - const bool rhsIsFP = llvm::isa(CompilerUtils::getValueType(rhs.getType())); - if(lhsIsSca && !rhsIsSca && rhsIsFP) { - rewriter.replaceOpWithNewOp( - op, - op.getResult().getType(), - rewriter.create( - op->getLoc(), - mlir::daphne::UnknownType::get(op->getContext()), // to be inferred - rhs, - rewriter.create(op->getLoc(), double(-1)) - ), - lhs - ); - return mlir::success(); - } - return mlir::failure(); -} - -void mlir::daphne::DistributeOp::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context -) { - results.add(context); -} - -mlir::LogicalResult mlir::daphne::CondOp::canonicalize(mlir::daphne::CondOp op, - mlir::PatternRewriter &rewriter) -{ - mlir::Value cond = op.getCond(); - if(llvm::isa(cond.getType())) - // If the condition is not a scalar, we cannot rewrite the operation here. - return mlir::failure(); - else { - // If the condition is a scalar, we rewrite the operation to an if-then-else construct - // using the SCF dialect. - // TODO Check if it is really a scalar. - - mlir::Location loc = op.getLoc(); - - // Ensure that the condition is a boolean. - if(!cond.getType().isSignlessInteger(1)) - cond = rewriter.create(loc, rewriter.getI1Type(), cond); - - mlir::Block thenBlock; - mlir::Block elseBlock; - mlir::Value thenVal = op.getThenVal(); - mlir::Value elseVal = op.getElseVal(); - - // Get rid of frame column labels, since they interfere with the type comparison (see #485). - if(auto thenFrmTy = thenVal.getType().dyn_cast()) - if(thenFrmTy.getLabels() != nullptr) - thenVal = rewriter.create(loc, thenFrmTy.withLabels(nullptr), thenVal); - if(auto elseFrmTy = elseVal.getType().dyn_cast()) - if(elseFrmTy.getLabels() != nullptr) - elseVal = rewriter.create(loc, elseFrmTy.withLabels(nullptr), elseVal); - - // Check if the types of the then-value and the else-value are the same. - if(thenVal.getType() != elseVal.getType()) { - if(llvm::isa(thenVal.getType()) || llvm::isa(elseVal.getType())) - // If one of them is unknown, we abort the rewrite (but this is not an error). - // The type may become known later, this rewrite will be triggered again. - return mlir::failure(); - else - // If both types are known, but different, this is an error. - // TODO We could try to cast the types. - throw ErrorHandler::compilerError( - op, "CanonicalizerPass (mlir::daphne::CondOp)", - "the then/else-values of CondOp must have the same value " - "type"); - } - - { - // Save the insertion point (automatically restored at the end of the block). - PatternRewriter::InsertionGuard insertGuard(rewriter); - - // TODO The current implementation only makes sure that the correct value is - // returned, but the operations calculating the then/else-values are still - // outside the if-then-else and will always both be executed (unless, e.g., - // the entire branching can be elimitated). This could be good (e.g., if - // the then/else-values have common subexpressions with other code) or bad - // (e.g., if they are expensive to compute). See #486. - - // Create yield-operations in both branches. 
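// Aside (illustration, not part of the patch): the net effect of this
// rewrite in pseudo-IR:
//     %r = daphne.cond %c, %thenVal, %elseVal
// becomes
//     %r = scf.if %c { scf.yield %thenVal } else { scf.yield %elseVal }
// with an additional CastOp on %c first if it is not already an i1.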
- rewriter.setInsertionPointToEnd(&thenBlock); - rewriter.create(loc, thenVal); - rewriter.setInsertionPointToEnd(&elseBlock); - rewriter.create(loc, elseVal); - } - - // Helper functions to move the operations in the two blocks created above - // into the actual branches of the if-operation. - auto insertThenBlockDo = [&](mlir::OpBuilder & nested, mlir::Location loc) { - nested.getBlock()->getOperations().splice(nested.getBlock()->end(), thenBlock.getOperations()); - }; - auto insertElseBlockDo = [&](mlir::OpBuilder & nested, mlir::Location loc) { - nested.getBlock()->getOperations().splice(nested.getBlock()->end(), elseBlock.getOperations()); - }; - - // Replace the daphne::CondOp by an scf::IfOp. - rewriter.replaceOpWithNewOp( - op, cond, insertThenBlockDo, insertElseBlockDo - ); - - return mlir::success(); - } -} - -mlir::LogicalResult mlir::daphne::ConvertDenseMatrixToMemRef::canonicalize( - mlir::daphne::ConvertDenseMatrixToMemRef op, - mlir::PatternRewriter &rewriter) { - // removes unnecessary conversions of MemRef -> DM -> MemRef - mlir::Operation *dmNode = op->getOperand(0).getDefiningOp(); - - if (!llvm::isa(dmNode)) - return failure(); - - mlir::Operation *originalMemRefOp = - dmNode->getPrevNode()->getOperand(0).getDefiningOp(); - op.replaceAllUsesWith(originalMemRefOp); - - rewriter.eraseOp(op); - if (dmNode->getUsers().empty()) rewriter.eraseOp(dmNode); - - return mlir::success(); -} - -mlir::LogicalResult mlir::daphne::ConvertMemRefToDenseMatrix::canonicalize( - mlir::daphne::ConvertMemRefToDenseMatrix op, - mlir::PatternRewriter &rewriter) { - mlir::Operation *extractPtr = op->getPrevNode(); - auto srcMemRef = extractPtr->getOperand(0).getDefiningOp(); - extractPtr->moveAfter(srcMemRef); - op->moveAfter(extractPtr); - - return mlir::success(); -} - -mlir::LogicalResult mlir::daphne::RenameOp::canonicalize( - mlir::daphne::RenameOp op, - mlir::PatternRewriter &rewriter -) { - // Replace the RenameOp by its argument, since we only need - // this operation during DaphneDSL parsing. - rewriter.replaceOp(op, op.getArg()); - return mlir::success(); -} - - -/** - * @brief Replaces `--a` by `a` (`a` scalar). 
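 *
 * Example (illustration, not part of the patch): for a scalar a, -(-a)
 * rewrites directly to a; the operand of the inner EwMinusOp is reused and
 * both negations disappear.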
- * - * @param op - * @param rewriter - * @return - */ -mlir::LogicalResult mlir::daphne::EwMinusOp::canonicalize( - mlir::daphne::EwMinusOp op, PatternRewriter &rewriter -) { - if (auto innerOp = op.getOperand().getDefiningOp()) { - rewriter.replaceOp(op, innerOp.getOperand()); - return mlir::success(); - } - return mlir::failure(); -} \ No newline at end of file diff --git a/src/ir/daphneir/DaphneDistributableOpInterface.cpp b/src/ir/daphneir/DaphneDistributableOpInterface.cpp index d12eb6e43..0575b33f7 100644 --- a/src/ir/daphneir/DaphneDistributableOpInterface.cpp +++ b/src/ir/daphneir/DaphneDistributableOpInterface.cpp @@ -17,12 +17,11 @@ #include #include +#include #include #include -#include -namespace mlir::daphne -{ +namespace mlir::daphne { #include } @@ -42,14 +41,12 @@ Type getWrappedType(Value v) { return wrappedType.dyn_cast().withSameElementTypeAndRepr(); } -template +template std::vector createEquivalentDistributedDAG_EwBinaryOp(EwBinaryOp *op, mlir::OpBuilder &builder, - mlir::ValueRange distributedInputs) -{ + mlir::ValueRange distributedInputs) { auto loc = op->getLoc(); - auto compute = builder.create(loc, - ArrayRef{daphne::HandleType::get(op->getContext(), op->getType())}, - distributedInputs); + auto compute = builder.create( + loc, ArrayRef{daphne::HandleType::get(op->getContext(), op->getType())}, distributedInputs); auto &block = compute.getBody().emplaceBlock(); auto argLhs = block.addArgument(getWrappedType(distributedInputs[0]), builder.getUnknownLoc()); auto argRhs = block.addArgument(getWrappedType(distributedInputs[1]), builder.getUnknownLoc()); @@ -68,50 +65,46 @@ std::vector createEquivalentDistributedDAG_EwBinaryOp(EwBinaryOp *o return ret; } -template -std::vector getOperandDistrPrimitives_EwBinaryOp(EwBinaryOp *op) { +template std::vector getOperandDistrPrimitives_EwBinaryOp(EwBinaryOp *op) { Type tL0 = op->getLhs().getType(); - auto tL = tL0.dyn_cast(); + auto tL = tL0.dyn_cast(); Type tR0 = op->getRhs().getType(); - auto tR = tR0.dyn_cast(); + auto tR = tR0.dyn_cast(); const ssize_t nrL = tL.getNumRows(); const ssize_t ncL = tL.getNumCols(); const ssize_t nrR = tR.getNumRows(); const ssize_t ncR = tR.getNumCols(); if (nrL == -1 || nrR == -1 || ncL == -1 || ncR == -1) - throw ErrorHandler::compilerError( - op->getLoc(), "DistributableOpInterface", - "unknown shapes of left and/or right operand to elementwise " - "binary operation are not supported while deciding " - "distribute/broadcast"); - - if(nrL == nrR && ncL == ncR) // matrix-matrix - return {false, false}; // distribute both inputs - else if(nrR == 1 && ncL == ncR) // matrix-row - return {false, true}; // distribute lhs, broadcast rhs - else if(nrL == nrR && ncR == 1) // matrix-col - return {false, true}; // distribute lhs, broadcast rhs + throw ErrorHandler::compilerError(op->getLoc(), "DistributableOpInterface", + "unknown shapes of left and/or right operand to elementwise " + "binary operation are not supported while deciding " + "distribute/broadcast"); + + if (nrL == nrR && ncL == ncR) // matrix-matrix + return {false, false}; // distribute both inputs + else if (nrR == 1 && ncL == ncR) // matrix-row + return {false, true}; // distribute lhs, broadcast rhs + else if (nrL == nrR && ncR == 1) // matrix-col + return {false, true}; // distribute lhs, broadcast rhs else - throw ErrorHandler::compilerError( - op->getLoc(), "DistributableOpInterface", - "mismatching shapes of left and right operand to elementwise " - "binary operation while deciding distribute/broadcast"); + throw 
ErrorHandler::compilerError(op->getLoc(), "DistributableOpInterface", + "mismatching shapes of left and right operand to elementwise " + "binary operation while deciding distribute/broadcast"); } // **************************************************************************** // DistributableOpInterface implementations // **************************************************************************** -#define IMPL_EWBINARYOP(OP) \ - std::vector mlir::daphne::OP::createEquivalentDistributedDAG(mlir::OpBuilder &builder, \ - mlir::ValueRange distributedInputs) \ - { \ - return createEquivalentDistributedDAG_EwBinaryOp(this, builder, distributedInputs); \ - } \ - \ - std::vector mlir::daphne::OP::getOperandDistrPrimitives() { \ - return getOperandDistrPrimitives_EwBinaryOp(this); \ +#define IMPL_EWBINARYOP(OP) \ + std::vector mlir::daphne::OP::createEquivalentDistributedDAG(mlir::OpBuilder &builder, \ + mlir::ValueRange distributedInputs) { \ + return createEquivalentDistributedDAG_EwBinaryOp(this, builder, distributedInputs); \ + } \ + \ + std::vector mlir::daphne::OP::getOperandDistrPrimitives() { \ + return getOperandDistrPrimitives_EwBinaryOp(this); \ } // TODO We should use traits (like for shape inference) so that we don't need @@ -149,13 +142,11 @@ IMPL_EWBINARYOP(EwLeOp) IMPL_EWBINARYOP(EwGtOp) IMPL_EWBINARYOP(EwGeOp) -std::vector daphne::RowAggMaxOp::createEquivalentDistributedDAG( - OpBuilder &builder, ValueRange distributedInputs -) { +std::vector daphne::RowAggMaxOp::createEquivalentDistributedDAG(OpBuilder &builder, + ValueRange distributedInputs) { auto loc = getLoc(); - auto compute = builder.create(loc, - ArrayRef{daphne::HandleType::get(getContext(), getType())}, - distributedInputs); + auto compute = builder.create( + loc, ArrayRef{daphne::HandleType::get(getContext(), getType())}, distributedInputs); auto &block = compute.getBody().emplaceBlock(); auto arg = block.addArgument(getWrappedType(distributedInputs[0]), builder.getUnknownLoc()); @@ -172,6 +163,4 @@ std::vector daphne::RowAggMaxOp::createEquivalentDistributedDAG( return ret; } -std::vector daphne::RowAggMaxOp::getOperandDistrPrimitives() { - return {false}; -} +std::vector daphne::RowAggMaxOp::getOperandDistrPrimitives() { return {false}; } diff --git a/src/ir/daphneir/DaphneDistributableOpInterface.h b/src/ir/daphneir/DaphneDistributableOpInterface.h index 0bcde3064..7d3c8a024 100644 --- a/src/ir/daphneir/DaphneDistributableOpInterface.h +++ b/src/ir/daphneir/DaphneDistributableOpInterface.h @@ -17,8 +17,8 @@ #ifndef SRC_IR_DAPHNEIR_DAPHNEDISTRIBUTABLEOPINTERFACE_H #define SRC_IR_DAPHNEIR_DAPHNEDISTRIBUTABLEOPINTERFACE_H -#include "mlir/IR/OpDefinition.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" namespace mlir { namespace daphne { diff --git a/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp b/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp index 0bfaa8c53..6746aebbe 100644 --- a/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferFrameLabelsOpInterface.cpp @@ -15,15 +15,15 @@ */ #include -#include #include #include +#include #include +#include #include #include -#include namespace mlir::daphne { #include @@ -36,10 +36,9 @@ using namespace mlir; // **************************************************************************** // For families of operations. 
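// Aside (illustration, not part of the patch): the label-inference methods in
// this file all end in the same step -- build or copy a std::vector of column
// labels and attach it to the result's FrameType via withLabels(). A minimal
// self-contained sketch of the copying step (copyLabels is hypothetical):
#include <string>
#include <vector>
static std::vector<std::string> *copyLabels(const std::vector<std::string> *src) {
    if (!src)
        return nullptr; // unknown labels stay unknown
    return new std::vector<std::string>(*src); // ownership passes to the new frame type
}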
-template <class ExtractOrFilterRowOp> -void inferFrameLabels_ExtractOrFilterRowOp(ExtractOrFilterRowOp * op) { +template <class ExtractOrFilterRowOp> void inferFrameLabels_ExtractOrFilterRowOp(ExtractOrFilterRowOp *op) { Type t = op->getSource().getType(); - if(auto ft = t.dyn_cast<daphne::FrameType>()) { + if (auto ft = t.dyn_cast<daphne::FrameType>()) { Value res = op->getResult(); res.setType(res.getType().dyn_cast<daphne::FrameType>().withLabels(ft.getLabels())); } @@ -53,7 +52,7 @@ void daphne::ReadOp::inferFrameLabels() { auto p = CompilerUtils::isConstant<std::string>(getFileName()); if (auto resType = getRes().getType().dyn_cast<daphne::FrameType>()) { if (p.first) { - std::vector<std::string> * labels; + std::vector<std::string> *labels; FileMetaData fmd = CompilerUtils::getFileMetaData(getFileName()); if (fmd.labels.empty()) { labels = nullptr; @@ -71,21 +70,19 @@ void daphne::ColBindOp::inferFrameLabels() { auto ftLhs = getLhs().getType().dyn_cast<daphne::FrameType>(); auto ftRhs = getRhs().getType().dyn_cast<daphne::FrameType>(); - if(!ftLhs || !ftRhs) - throw ErrorHandler::compilerError( - getLoc(), "daphne::ColBindOp::inferFrameLabels", - "currently ColBindOp can only infer its output labels if both " - "inputs are frames"); - if(!ftLhs.getLabels() || !ftRhs.getLabels()) - throw ErrorHandler::compilerError( - getLoc(), "daphne::ColBindOp::inferFrameLabels", - "currenly ColBindOp can only infer its output labels if the " - "labels of both input frames are known"); + if (!ftLhs || !ftRhs) + throw ErrorHandler::compilerError(getLoc(), "daphne::ColBindOp::inferFrameLabels", + "currently ColBindOp can only infer its output labels if both " + "inputs are frames"); + if (!ftLhs.getLabels() || !ftRhs.getLabels()) + throw ErrorHandler::compilerError(getLoc(), "daphne::ColBindOp::inferFrameLabels", + "currently ColBindOp can only infer its output labels if the " + "labels of both input frames are known"); auto labelsRes = new std::vector<std::string>(); - for(auto l : *(ftLhs.getLabels())) + for (auto l : *(ftLhs.getLabels())) labelsRes->push_back(l); - for(auto l : *(ftRhs.getLabels())) + for (auto l : *(ftRhs.getLabels())) labelsRes->push_back(l); Value res = getResult(); @@ -94,7 +91,7 @@ void daphne::CreateFrameOp::inferFrameLabels() { auto resLabels = new std::vector<std::string>(); - for(Value label : getLabels()) + for (Value label : getLabels()) resLabels->push_back(CompilerUtils::constantOrThrow<std::string>(label)); Value res = getResult(); res.setType(res.getType().dyn_cast<daphne::FrameType>().withLabels(resLabels)); @@ -103,8 +100,8 @@ void daphne::ExtractColOp::inferFrameLabels() { auto ft = getSource().getType().dyn_cast<daphne::FrameType>(); auto st = getSelectedCols().getType().dyn_cast<daphne::StringType>(); - - if(ft && st) { + + if (ft && st) { std::string label = CompilerUtils::constantOrThrow<std::string>(getSelectedCols()); std::string delimiter = "."; const std::string frameName = label.substr(0, label.find(delimiter)); @@ -116,7 +113,7 @@ std::string labelFrameName = label.substr(0, label.find(delimiter)); if (labelFrameName.compare(frameName) == 0) { resultLabels->push_back(label); - } + } } Value res = getResult(); res.setType(res.getType().dyn_cast<daphne::FrameType>().withLabels(resultLabels)); @@ -126,23 +123,18 @@ Value res = getResult(); res.setType(res.getType().dyn_cast<daphne::FrameType>().withLabels(resLabels)); } - } - } -void daphne::ExtractRowOp::inferFrameLabels() { - inferFrameLabels_ExtractOrFilterRowOp(this); -} +void daphne::ExtractRowOp::inferFrameLabels() { inferFrameLabels_ExtractOrFilterRowOp(this); } -void daphne::FilterRowOp::inferFrameLabels() { - inferFrameLabels_ExtractOrFilterRowOp(this); -} +void 
daphne::FilterRowOp::inferFrameLabels() { inferFrameLabels_ExtractOrFilterRowOp(this); } void daphne::GroupJoinOp::inferFrameLabels() { auto newLabels = new std::vector(); newLabels->push_back(CompilerUtils::constantOrThrow(getLhsOn())); - newLabels->push_back(std::string("SUM(") + CompilerUtils::constantOrThrow(getRhsAgg()) + std::string(")")); + newLabels->push_back(std::string("SUM(") + CompilerUtils::constantOrThrow(getRhsAgg()) + + std::string(")")); Value res = getResult(0); res.setType(res.getType().dyn_cast().withLabels(newLabels)); } @@ -158,14 +150,14 @@ void daphne::CartesianOp::inferFrameLabels() { auto newLabels = new std::vector(); auto ft1 = getLhs().getType().dyn_cast(); auto ft2 = getRhs().getType().dyn_cast(); - std::vector * labelsStr1 = ft1.getLabels(); - std::vector * labelsStr2 = ft2.getLabels(); + std::vector *labelsStr1 = ft1.getLabels(); + std::vector *labelsStr2 = ft2.getLabels(); - if(labelsStr1) - for(auto labelStr : *labelsStr1) + if (labelsStr1) + for (auto labelStr : *labelsStr1) newLabels->push_back(labelStr); - if(labelsStr2) - for(auto labelStr : *labelsStr2) + if (labelsStr2) + for (auto labelStr : *labelsStr2) newLabels->push_back(labelStr); getResult().setType(getRes().getType().dyn_cast().withLabels(newLabels)); @@ -173,7 +165,7 @@ void daphne::CartesianOp::inferFrameLabels() { void daphne::OrderOp::inferFrameLabels() { Type t = getArg().getType(); - if(auto ft = t.dyn_cast()) { + if (auto ft = t.dyn_cast()) { Value res = getResult(); res.setType(res.getType().dyn_cast().withLabels(ft.getLabels())); } @@ -183,32 +175,32 @@ void daphne::InnerJoinOp::inferFrameLabels() { auto newLabels = new std::vector(); auto ft1 = getLhs().getType().dyn_cast(); auto ft2 = getRhs().getType().dyn_cast(); - std::vector * labelsStr1 = ft1.getLabels(); - std::vector * labelsStr2 = ft2.getLabels(); + std::vector *labelsStr1 = ft1.getLabels(); + std::vector *labelsStr2 = ft2.getLabels(); - if(labelsStr1) - for(auto labelStr : *labelsStr1) + if (labelsStr1) + for (auto labelStr : *labelsStr1) newLabels->push_back(labelStr); - if(labelsStr2) - for(auto labelStr : *labelsStr2) + if (labelsStr2) + for (auto labelStr : *labelsStr2) newLabels->push_back(labelStr); getResult().setType(getRes().getType().dyn_cast().withLabels(newLabels)); } void daphne::ThetaJoinOp::inferFrameLabels() { - std::vector * newLabels = nullptr; - + std::vector *newLabels = nullptr; + auto ft1 = getLhs().getType().dyn_cast(); auto ft2 = getRhs().getType().dyn_cast(); - std::vector * labelsStr1 = ft1.getLabels(); - std::vector * labelsStr2 = ft2.getLabels(); + std::vector *labelsStr1 = ft1.getLabels(); + std::vector *labelsStr2 = ft2.getLabels(); - if(labelsStr1 && labelsStr2) { + if (labelsStr1 && labelsStr2) { newLabels = new std::vector(); - for(auto labelStr : *labelsStr1) + for (auto labelStr : *labelsStr1) newLabels->push_back(labelStr); - for(auto labelStr : *labelsStr2) + for (auto labelStr : *labelsStr2) newLabels->push_back(labelStr); } @@ -220,18 +212,18 @@ void daphne::GroupOp::inferFrameLabels() { std::vector aggColLabels; std::vector aggFuncNames; - for(Value t: getKeyCol()){ //Adopting keyCol Labels + for (Value t : getKeyCol()) { // Adopting keyCol Labels std::string keyLabel = CompilerUtils::constantOrThrow(t); std::string delimiter = "."; const std::string frameName = keyLabel.substr(0, keyLabel.find(delimiter)); const std::string colLabel = keyLabel.substr(keyLabel.find(delimiter) + delimiter.length(), keyLabel.length()); - - if(keyLabel == "*") { + + if (keyLabel == "*") { 
daphne::FrameType arg = getFrame().getType().dyn_cast(); for (std::string frameLabel : *arg.getLabels()) { newLabels->push_back(frameLabel); } - } else if(colLabel.compare("*") == 0) { + } else if (colLabel.compare("*") == 0) { daphne::FrameType arg = getFrame().getType().dyn_cast(); std::vector labels = *arg.getLabels(); for (std::string label : labels) { @@ -245,14 +237,14 @@ void daphne::GroupOp::inferFrameLabels() { } } - for(Value t: getAggCol()){ + for (Value t : getAggCol()) { aggColLabels.push_back(CompilerUtils::constantOrThrow(t)); } - for(Attribute t: getAggFuncs()){ + for (Attribute t : getAggFuncs()) { GroupEnum aggFuncValue = t.dyn_cast().getValue(); aggFuncNames.push_back(stringifyGroupEnum(aggFuncValue).str()); } - for(size_t i = 0; i < aggFuncNames.size() && i < aggColLabels.size(); i++){ + for (size_t i = 0; i < aggFuncNames.size() && i < aggColLabels.size(); i++) { newLabels->push_back(aggFuncNames.at(i) + "(" + aggColLabels.at(i) + ")"); } @@ -261,11 +253,10 @@ void daphne::GroupOp::inferFrameLabels() { void daphne::SetColLabelsOp::inferFrameLabels() { auto newLabels = new std::vector(); - for(Value label : getLabels()) { + for (Value label : getLabels()) { try { newLabels->push_back(CompilerUtils::constantOrThrow(label)); - } - catch(std::runtime_error&) { + } catch (std::runtime_error &) { // TODO This could be improved by supporting knowledge on only some // of the labels. // If we do not know the values of all label operands at @@ -281,9 +272,9 @@ void daphne::SetColLabelsPrefixOp::inferFrameLabels() { auto newLabels = new std::vector(); std::string prefixStr = CompilerUtils::constantOrThrow(getPrefix()); auto ft = getArg().getType().dyn_cast(); - std::vector * labelsStr = ft.getLabels(); - if(labelsStr) - for(auto labelStr : *labelsStr) + std::vector *labelsStr = ft.getLabels(); + if (labelsStr) + for (auto labelStr : *labelsStr) newLabels->push_back(LabelUtils::setPrefix(prefixStr, labelStr)); else { delete newLabels; diff --git a/src/ir/daphneir/DaphneInferShapeOpInterface.cpp b/src/ir/daphneir/DaphneInferShapeOpInterface.cpp index 91e05c648..d0c82c93e 100644 --- a/src/ir/daphneir/DaphneInferShapeOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferShapeOpInterface.cpp @@ -15,15 +15,15 @@ */ #include -#include #include #include +#include #include -#include #include #include +#include namespace mlir::daphne { #include @@ -38,37 +38,37 @@ using namespace mlir::OpTrait; std::pair getShape(Value v) { Type t = v.getType(); - if(auto mt = t.dyn_cast()) + if (auto mt = t.dyn_cast()) return std::make_pair(mt.getNumRows(), mt.getNumCols()); - if(auto ft = t.dyn_cast()) + if (auto ft = t.dyn_cast()) return std::make_pair(ft.getNumRows(), ft.getNumCols()); // TODO Maybe check if it is really a scalar type. else // scalar return std::make_pair(1, 1); } -ssize_t inferNumRowsFromArgs(Operation* op, ValueRange vs) { +ssize_t inferNumRowsFromArgs(Operation *op, ValueRange vs) { // If the #rows of all arguments is known and matches, then this is the // inferred #rows. If the known #rows of any two arguments mismatch, an // exception is thrown. Otherwise, if the #rows of any argument is unknown, // the inferred #rows is unknown. 
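// Aside (illustration, not part of the patch): expected results of the
// reduction below over the per-argument #rows, with -1 encoding "unknown":
//     {4, 4, 4}  -> 4
//     {4, -1, 4} -> -1 (some argument unknown)
//     {4, 5, 4}  -> compiler error (two known, mismatching values)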
     ssize_t numRows = getShape(vs[0]).first;
     bool someUnknown = false;
-    if(numRows == -1)
+    if (numRows == -1)
         someUnknown = true;
-    for(size_t i = 1; i < vs.size(); i++) {
+    for (size_t i = 1; i < vs.size(); i++) {
         const ssize_t nextNumRows = getShape(vs[i]).first;
-        if(nextNumRows == -1)
+        if (nextNumRows == -1)
             someUnknown = true;
-        else if(numRows == -1)
+        else if (numRows == -1)
             numRows = nextNumRows;
-        else if(nextNumRows != numRows)
+        else if (nextNumRows != numRows)
             throw ErrorHandler::compilerError(op->getLoc(), "InferShapeOpInterface",
-                "shape inference: inferNumRowsFromArgs() requires that "
-                "arguments have the same number of rows, but there is "
-                "one with " + std::to_string(numRows) + " and one with " +
-                std::to_string(nextNumRows) + " rows"
-            );
+                                              "shape inference: inferNumRowsFromArgs() requires that "
+                                              "arguments have the same number of rows, but there is "
+                                              "one with " +
+                                                  std::to_string(numRows) + " and one with " +
+                                                  std::to_string(nextNumRows) + " rows");
     }
     return someUnknown ? -1 : numRows;
 }
@@ -80,30 +80,30 @@ ssize_t inferNumColsFromArgs(Operation *op, ValueRange vs) {
     // If the #cols of all arguments is known and matches, then this is the
     // inferred #cols. If the known #cols of any two arguments mismatch, an
     // exception is thrown. Otherwise, if the #cols of any argument is unknown,
     // the inferred #cols is unknown.
     ssize_t numCols = getShape(vs[0]).second;
     bool someUnknown = false;
-    if(numCols == -1)
+    if (numCols == -1)
         someUnknown = true;
-    for(size_t i = 1; i < vs.size(); i++) {
+    for (size_t i = 1; i < vs.size(); i++) {
         const ssize_t nextNumCols = getShape(vs[i]).second;
-        if(nextNumCols == -1)
+        if (nextNumCols == -1)
             someUnknown = true;
-        else if(numCols == -1)
+        else if (numCols == -1)
             numCols = nextNumCols;
-        else if(nextNumCols != numCols)
+        else if (nextNumCols != numCols)
             throw ErrorHandler::compilerError(op->getLoc(), "InferShapeOpInterface",
-                "shape inference: inferNumColsFromArgs() requires that "
-                "arguments have the same number of columns, but there is "
-                "one with " + std::to_string(numCols) + " and one with " +
-                std::to_string(nextNumCols) + " columns"
-            );
+                                              "shape inference: inferNumColsFromArgs() requires that "
+                                              "arguments have the same number of columns, but there is "
+                                              "one with " +
+                                                  std::to_string(numCols) + " and one with " +
+                                                  std::to_string(nextNumCols) + " columns");
     }
     return someUnknown ? -1 : numCols;
 }
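MatMulOp's rule above: the result takes the (possibly transposed) row count of the left operand and the (possibly transposed) column count of the right operand, and a dimension stays -1 (unknown) when the corresponding transpose flag is not a compile-time constant. A hypothetical standalone version of the constant-flag case:

    #include <sys/types.h> // ssize_t (POSIX)
    #include <utility>

    // Shape of lhs (m x k, or k x m if transa) times rhs (k x n, or n x k if transb).
    std::pair<ssize_t, ssize_t> matMulShape(std::pair<ssize_t, ssize_t> lhs, std::pair<ssize_t, ssize_t> rhs,
                                            bool transa, bool transb) {
        return {transa ? lhs.second : lhs.first, transb ? rhs.first : rhs.second};
    }

E.g., a (2 x 3) times (3 x 4) product gives matMulShape({2, 3}, {3, 4}, false, false) == {2, 4}, and transposing a (3 x 2) left operand yields the same result: matMulShape({3, 2}, {3, 4}, true, false) == {2, 4}.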
 ssize_t inferNumRowsFromSumOfArgs(ValueRange vs) {
     ssize_t sumNumRows = 0;
-    for(Value v : vs) {
+    for (Value v : vs) {
         const ssize_t numRows = getShape(v).first;
-        if(numRows == -1)
+        if (numRows == -1)
             return -1;
         sumNumRows += numRows;
     }
@@ -112,9 +112,9 @@ ssize_t inferNumColsFromSumOfArgs(ValueRange vs) {
     ssize_t sumNumCols = 0;
-    for(Value v : vs) {
+    for (Value v : vs) {
         const ssize_t numCols = getShape(v).second;
-        if(numCols == -1)
+        if (numCols == -1)
             return -1;
         sumNumCols += numCols;
     }
@@ -133,43 +133,38 @@ ssize_t daphne::CartesianOp::inferNumRows() {

 ssize_t daphne::SeqOp::inferNumRows() {
     Type fromTy = getFrom().getType();
-    if(fromTy.isF64()) {
+    if (fromTy.isF64()) {
         try {
             double vFrom = CompilerUtils::constantOrThrow<double>(getFrom());
             double vTo = CompilerUtils::constantOrThrow<double>(getTo());
             double vInc = CompilerUtils::constantOrThrow<double>(getInc());
             return floor(vTo / vInc - vFrom / vInc) + 1;
-        }
-        catch(const std::runtime_error & e) {
+        } catch (const std::runtime_error &e) {
             return -1;
         }
     }
-    if(fromTy.isF32()) {
+    if (fromTy.isF32()) {
         try {
             float vFrom = CompilerUtils::constantOrThrow<float>(getFrom());
             float vTo = CompilerUtils::constantOrThrow<float>(getTo());
             float vInc = CompilerUtils::constantOrThrow<float>(getInc());
             return floor(vTo / vInc - vFrom / vInc) + 1;
-        }
-        catch(const std::runtime_error & e) {
+        } catch (const std::runtime_error &e) {
             return -1;
         }
-    }
-    else if(fromTy.isSignedInteger(64)) {
+    } else if (fromTy.isSignedInteger(64)) {
         try {
             int64_t vFrom = CompilerUtils::constantOrThrow<int64_t>(getFrom());
             int64_t vTo = CompilerUtils::constantOrThrow<int64_t>(getTo());
             int64_t vInc = CompilerUtils::constantOrThrow<int64_t>(getInc());
             return abs(vTo - vFrom) / abs(vInc) + 1;
-        }
-        catch(const std::runtime_error & e) {
+        } catch (const std::runtime_error &e) {
             return -1;
         }
     }
-    throw ErrorHandler::compilerError(
-        getLoc(), "InferShapeOpInterface (daphne::SeqOp::inferNumRows)",
-        "at the moment, shape inference for SeqOp supports only F64 and "
-        "SI64 value types");
+    throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SeqOp::inferNumRows)",
+                                      "at the moment, shape inference for SeqOp supports only F64 and "
+                                      "SI64 value types");
 }

 std::vector<std::pair<ssize_t, ssize_t>> daphne::CreateFrameOp::inferShape() {
@@ -188,18 +183,18 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::GroupOp::inferShape() {
     std::vector<std::string> newLabels;
-    for(Value t: getKeyCol()){ //Adopting keyCol Labels
+    for (Value t : getKeyCol()) { // Adopting keyCol Labels
         std::string keyLabel = CompilerUtils::constantOrThrow<std::string>(t);
         std::string delimiter = ".";
         const std::string frameName = keyLabel.substr(0, keyLabel.find(delimiter));
         const std::string colLabel = keyLabel.substr(keyLabel.find(delimiter) + delimiter.length(), keyLabel.length());
-
-        if(keyLabel == "*") {
+
+        if (keyLabel == "*") {
             daphne::FrameType arg = getFrame().getType().dyn_cast<daphne::FrameType>();
             for (std::string frameLabel : *arg.getLabels()) {
                 newLabels.push_back(frameLabel);
             }
-        } else if(colLabel.compare("*") == 0) {
+        } else if (colLabel.compare("*") == 0) {
             daphne::FrameType arg = getFrame().getType().dyn_cast<daphne::FrameType>();
             std::vector<std::string> labels = *arg.getLabels();
             for (std::string label : labels) {
@@ -212,7 +207,7 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::GroupOp::inferShape() {
             newLabels.push_back(keyLabel);
         }
     }
-
+
     const size_t numCols = newLabels.size() + getAggCol().size();
     return {{numRows, numCols}};
 }
@@ -223,12 +218,12 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::MatMulOp::inferShape() {
     ssize_t numRows = -1;
     std::pair<bool, bool> pr = CompilerUtils::isConstant<bool>(getTransa());
-    if(pr.first)
+    if (pr.first)
         numRows = pr.second ? shapeLhs.second : shapeLhs.first;
-
+
     ssize_t numCols = -1;
     std::pair<bool, bool> pc = CompilerUtils::isConstant<bool>(getTransb());
-    if(pc.first)
+    if (pc.first)
         numCols = pc.second ? shapeRhs.first : shapeRhs.second;

     return {{numRows, numCols}};
 }
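Conv2DForwardOp relies on the standard convolution output-size formula, Hout = floor((Hin + 2 * padh - Hf) / strideh) + 1 (and analogously for the width), with the result flattened to one row per input and F * Hout * Wout columns. A small worked sketch (hypothetical helper, not part of this patch):

    #include <cstddef>

    // Output extent of one spatial dimension; integer division floors here.
    size_t convOutDim(size_t in, size_t filter, size_t pad, size_t stride) {
        return (in + 2 * pad - filter) / stride + 1;
    }

For a 32 x 32 input with 3 x 3 filters, padding 1, and stride 1, convOutDim(32, 3, 1, 1) == 32, so F == 16 filters yield 16 * 32 * 32 == 16384 columns per row.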
@@ -249,20 +244,19 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::OrderOp::inferShape() {
     size_t numCols = -1;
     Type t = getArg().getType();
-    if(auto mt = t.dyn_cast<daphne::MatrixType>()){
+    if (auto mt = t.dyn_cast<daphne::MatrixType>()) {
         numRows = mt.getNumRows();
         numCols = mt.getNumCols();
     }
-    if(auto ft = t.dyn_cast<daphne::FrameType>()){
+    if (auto ft = t.dyn_cast<daphne::FrameType>()) {
         numRows = ft.getNumRows();
         numCols = ft.getNumCols();
     }
     std::pair<bool, bool> p = CompilerUtils::isConstant<bool>(getReturnIdxs());
-    if(p.first) {
-        if(p.second)
+    if (p.first) {
+        if (p.second)
             numCols = 1;
-    }
-    else
+    } else
         numCols = -1;

     return {{numRows, numCols}};
@@ -270,23 +264,22 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::CondOp::inferShape() {
     Type condTy = getCond().getType();
-    if(llvm::isa<daphne::UnknownType>(condTy))
+    if (llvm::isa<daphne::UnknownType>(condTy))
         // Actually, this should not happen, because if the type of the
         // condition is unknown, the type of the result should be unknown
         // too per type inference, such that shape inference should not
         // even get called. Nevertheless, returning unknown will probably
         // not hurt in case anyone ever calls this from somewhere else.
         return {{-1, -1}};
-    if(auto condMatTy = condTy.dyn_cast<daphne::MatrixType>())
+    if (auto condMatTy = condTy.dyn_cast<daphne::MatrixType>())
         return {{condMatTy.getNumRows(), condMatTy.getNumCols()}};
-    else if(auto condFrmTy = condTy.dyn_cast<daphne::FrameType>())
-        throw ErrorHandler::compilerError(
-            getLoc(), "InferShapeOpInterface (daphne::CondOp::inferShape)",
-            "CondOp does not support frames for the condition yet");
+    else if (auto condFrmTy = condTy.dyn_cast<daphne::FrameType>())
+        throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::CondOp::inferShape)",
+                                          "CondOp does not support frames for the condition yet");
     else { // cond is a scalar
         // TODO check if it is really a scalar
         Type thenTy = getThenVal().getType();
         Type elseTy = getElseVal().getType();
-
+
         ssize_t thenNumRows = -1;
         ssize_t thenNumCols = -1;
         ssize_t elseNumRows = -1;
@@ -295,28 +288,23 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::CondOp::inferShape() {
         auto thenFrmTy = thenTy.dyn_cast<daphne::FrameType>();
         auto elseMatTy = elseTy.dyn_cast<daphne::MatrixType>();
         auto elseFrmTy = elseTy.dyn_cast<daphne::FrameType>();
-        if(thenMatTy) {
+        if (thenMatTy) {
             thenNumRows = thenMatTy.getNumRows();
             thenNumCols = thenMatTy.getNumCols();
-        }
-        else if(thenFrmTy) {
+        } else if (thenFrmTy) {
             thenNumRows = thenFrmTy.getNumRows();
             thenNumCols = thenFrmTy.getNumCols();
         }
-        if(elseMatTy) {
+        if (elseMatTy) {
             elseNumRows = elseMatTy.getNumRows();
             elseNumCols = elseMatTy.getNumCols();
-        }
-        else if(elseFrmTy) {
+        } else if (elseFrmTy) {
             elseNumRows = elseFrmTy.getNumRows();
             elseNumCols = elseFrmTy.getNumCols();
         }
-        if((thenMatTy || thenFrmTy) && (elseMatTy || elseFrmTy))
-            return {{
-                (thenNumRows == elseNumRows) ? thenNumRows : -1,
-                (thenNumCols == elseNumCols) ? thenNumCols : -1
-            }};
+        if ((thenMatTy || thenFrmTy) && (elseMatTy || elseFrmTy))
+            return {{(thenNumRows == elseNumRows) ? thenNumRows : -1, (thenNumCols == elseNumCols) ? thenNumCols : -1}};
         else
             // Then-value or else-value is a scalar.
             return {{-1, -1}};
@@ -342,7 +330,8 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::Conv2DForwardOp::inferShape() {
     ssize_t numRows = shapeX.first;
     ssize_t numCols = F == -1 ? -1 : F * Hout * Wout;
-    // op output is [mat, scalar, scalar] for the convolved data and its dimensions
+    // op output is [mat, scalar, scalar] for the convolved data and its
+    // dimensions
     return {{numRows, numCols}, std::make_pair(1, 1), std::make_pair(1, 1)};
 }
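AvgPoolForwardOp and MaxPoolForwardOp above compute their spatial output extents with the same formula as Conv2DForwardOp, only with the pooling window in place of the filter. Reusing the convOutDim sketch from above: a 32 x 32 input pooled with a 2 x 2 window, no padding, and stride 2 gives convOutDim(32, 2, 0, 2) == 16 per dimension, i.e. C * 16 * 16 output columns.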
@@ -362,7 +351,8 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::AvgPoolForwardOp::inferShape()
     size_t Wout = std::floor((Win + 2 * padw - Wf) / stridew + 1);
     auto numCols = C * Hout * Wout;
-    // op output is [mat, scalar, scalar] for the convolved data and its dimensions
+    // op output is [mat, scalar, scalar] for the convolved data and its
+    // dimensions
     return {{numRows, numCols}, std::make_pair(1, 1), std::make_pair(1, 1)};
 }

@@ -382,7 +372,8 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::MaxPoolForwardOp::inferShape()
     size_t Wout = std::floor((Win + 2 * padw - Wf) / stridew + 1);
     auto numCols = C * Hout * Wout;
-    // op output is [mat, scalar, scalar] for the convolved data and its dimensions
+    // op output is [mat, scalar, scalar] for the convolved data and its
+    // dimensions
     return {{numRows, numCols}, std::make_pair(1, 1), std::make_pair(1, 1)};
 }
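The bounds checks above enforce half-open slice semantics [lowerIncl, upperExcl): the lower bound must lie in [0, numRows), the upper bound in [0, numRows], and lowerIncl must not exceed upperExcl, with the result having upperExcl - lowerIncl rows. A condensed sketch of the same validity predicate (hypothetical, not part of this patch):

    #include <sys/types.h> // ssize_t (POSIX)

    // True iff [lo, up) is a valid row range over n rows, per the checks above.
    bool isValidSliceRange(ssize_t lo, ssize_t up, ssize_t n) {
        return lo >= 0 && lo < n && up >= 0 && up <= n && lo <= up;
    }

E.g., isValidSliceRange(2, 5, 10) holds and the slice has 5 - 2 == 3 rows, while isValidSliceRange(5, 2, 10) fails because lowerIncl exceeds upperExcl.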
@@ -394,14 +385,13 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::CTableOp::inferShape() {
     // the lhs and rhs input matrices) and the lhs/rhs input matrices
     // are compile-time constants, then we could determine the number
     // of rows/columns here.
-    return {{
-        CompilerUtils::constantOrDefault<ssize_t>(getResNumRows(), -1),
-        CompilerUtils::constantOrDefault<ssize_t>(getResNumCols(), -1)
-    }};
+    return {{CompilerUtils::constantOrDefault<ssize_t>(getResNumRows(), -1),
+             CompilerUtils::constantOrDefault<ssize_t>(getResNumCols(), -1)}};
 }

 std::vector<std::pair<ssize_t, ssize_t>> daphne::MatrixConstantOp::inferShape() {
-    const Structure* mat = reinterpret_cast<const Structure*>(CompilerUtils::constantOrThrow<size_t>(getMatrixAddr()));
+    const Structure *mat =
+        reinterpret_cast<const Structure *>(CompilerUtils::constantOrThrow<size_t>(getMatrixAddr()));
     return {{mat->getNumRows(), mat->getNumCols()}};
 }

@@ -409,57 +399,48 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::SliceRowOp::inferShape() {
     Type srcTy = getSource().getType();
     ssize_t srcNumRows;
     ssize_t srcNumCols;
-    if(llvm::isa<daphne::UnknownType>(srcTy)) {
+    if (llvm::isa<daphne::UnknownType>(srcTy)) {
         srcNumRows = -1;
         srcNumCols = -1;
-    }
-    else if(auto srcMatTy = srcTy.dyn_cast<daphne::MatrixType>()) {
+    } else if (auto srcMatTy = srcTy.dyn_cast<daphne::MatrixType>()) {
         srcNumRows = srcMatTy.getNumRows();
         srcNumCols = srcMatTy.getNumCols();
-    }
-    else if(auto srcFrmTy = srcTy.dyn_cast<daphne::FrameType>()) {
+    } else if (auto srcFrmTy = srcTy.dyn_cast<daphne::FrameType>()) {
         srcNumRows = srcFrmTy.getNumRows();
         srcNumCols = srcFrmTy.getNumCols();
-    }
-    else
+    } else
         // If this is the case, shape inference shouldn't have been called.
-        throw ErrorHandler::compilerError(
-            getLoc(), "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
-            "SliceRowOp shape inference does only support unknown, matrix, and "
-            "frame inputs");
+        throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
+                                          "SliceRowOp shape inference does only support unknown, matrix, and "
+                                          "frame inputs");

     auto loIn = CompilerUtils::isConstant<ssize_t>(getLowerIncl());
     auto upEx = CompilerUtils::isConstant<ssize_t>(getUpperExcl());

     ssize_t resNumRows = -1;
-    if(srcNumRows != -1 && loIn.first && upEx.first) {
+    if (srcNumRows != -1 && loIn.first && upEx.first) {
         ssize_t loInPos = loIn.second;
         ssize_t upExPos = upEx.second;
-        if(loInPos < 0 || loInPos >= srcNumRows)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
-                "SliceRowOp shape inference: lowerIncl must be in [0, numRows), "
-                "but is " +
-                    std::to_string(loInPos) + " with " +
-                    std::to_string(srcNumRows) + " rows");
-        if(upExPos < 0 || upExPos > srcNumRows)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
-                "SliceRowOp shape inference: upperExcl must be in [0, numRows], "
-                "but is " + std::to_string(upExPos) +
-                " with " + std::to_string(srcNumRows) + " rows"
-            );
-        if(loInPos > upExPos)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
-                "SliceRowOp shape inference: lowerIncl must not be greater "
-                "than upperExcl"
-                " (found " +
-                    std::to_string(loInPos) + " and " +
-                    std::to_string(upExPos) + ")");
+        if (loInPos < 0 || loInPos >= srcNumRows)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
+                                              "SliceRowOp shape inference: lowerIncl must be in [0, "
+                                              "numRows), "
+                                              "but is " +
+                                                  std::to_string(loInPos) + " with " + std::to_string(srcNumRows) +
+                                                  " rows");
+        if (upExPos < 0 || upExPos > srcNumRows)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
+                                              "SliceRowOp shape inference: upperExcl must be in [0, "
+                                              "numRows], "
+                                              "but is " +
+                                                  std::to_string(upExPos) + " with " + std::to_string(srcNumRows) +
+                                                  " rows");
+        if (loInPos > upExPos)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceRowOp::inferShape)",
+                                              "SliceRowOp shape inference: lowerIncl must not be greater "
+                                              "than upperExcl"
+                                              " (found " +
+                                                  std::to_string(loInPos) + " and " + std::to_string(upExPos) + ")");
         resNumRows = upExPos - loInPos;
     }
@@ -470,52 +451,45 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::SliceColOp::inferShape() {
     Type srcTy = getSource().getType();
     ssize_t srcNumRows;
     ssize_t srcNumCols;
-    if(auto srcMatTy = srcTy.dyn_cast<daphne::MatrixType>()) {
+    if (auto srcMatTy = srcTy.dyn_cast<daphne::MatrixType>()) {
         srcNumRows = srcMatTy.getNumRows();
         srcNumCols = srcMatTy.getNumCols();
-    }
-    else if(auto srcFrmTy = srcTy.dyn_cast<daphne::FrameType>()) {
+    } else if (auto srcFrmTy = srcTy.dyn_cast<daphne::FrameType>()) {
         srcNumRows = srcFrmTy.getNumRows();
         srcNumCols = srcFrmTy.getNumCols();
-    }
-    else
+    } else
         // If this is the case, shape inference shouldn't have been called.
-        throw ErrorHandler::compilerError(
-            getLoc(), "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
-            "SliceColOp shape inference does only support matrix and frame "
-            "inputs");
+        throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
+                                          "SliceColOp shape inference does only support matrix and frame "
+                                          "inputs");

     auto loIn = CompilerUtils::isConstant<ssize_t>(getLowerIncl());
     auto upEx = CompilerUtils::isConstant<ssize_t>(getUpperExcl());

     ssize_t resNumCols = -1;
-    if(srcNumCols != -1 && loIn.first && upEx.first) {
+    if (srcNumCols != -1 && loIn.first && upEx.first) {
         ssize_t loInPos = loIn.second;
         ssize_t upExPos = upEx.second;
-        if(loInPos < 0 || loInPos >= srcNumCols)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
-                "SliceColOp shape inference: lowerIncl must be in [0, "
-                "numCols), "
-                "but is " +
-                    std::to_string(loInPos) + " with " +
-                    std::to_string(srcNumCols) + " cols");
-        if(upExPos < 0 || upExPos > srcNumCols)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
-                "SliceColOp shape inference: upperExcl must be in [0, numCols], "
-                "but is " + std::to_string(upExPos) +
-                " with " + std::to_string(srcNumCols) + " cols"
-            );
-        if(loInPos > upExPos)
-            throw ErrorHandler::compilerError(
-                getLoc(),
-                "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
-                "SliceColOp shape inference: lowerIncl must not be greater than upperExcl"
-                " (found " + std::to_string(loInPos) + " and " + std::to_string(upExPos) + ")"
-            );
+        if (loInPos < 0 || loInPos >= srcNumCols)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
+                                              "SliceColOp shape inference: lowerIncl must be in [0, "
+                                              "numCols), "
+                                              "but is " +
+                                                  std::to_string(loInPos) + " with " + std::to_string(srcNumCols) +
+                                                  " cols");
+        if (upExPos < 0 || upExPos > srcNumCols)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
+                                              "SliceColOp shape inference: upperExcl must be in [0, "
+                                              "numCols], "
+                                              "but is " +
+                                                  std::to_string(upExPos) + " with " + std::to_string(srcNumCols) +
+                                                  " cols");
+        if (loInPos > upExPos)
+            throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::SliceColOp::inferShape)",
+                                              "SliceColOp shape inference: lowerIncl must not be greater "
+                                              "than upperExcl"
+                                              " (found " +
+                                                  std::to_string(loInPos) + " and " + std::to_string(upExPos) + ")");
         resNumCols = upEx.second - loIn.second;
     }

@@ -526,8 +500,8 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::ExtractColOp::inferShape() {
     auto ft = getSource().getType().dyn_cast<daphne::FrameType>();
     auto srcNumRows = getShape(getOperand(0)).first;
     auto st = getSelectedCols().getType().dyn_cast<daphne::StringType>();
-
-    if(ft && st) {
+
+    if (ft && st) {
         std::string label = CompilerUtils::constantOrThrow<std::string>(getSelectedCols());
         std::string delimiter = ".";
         const std::string frameName = label.substr(0, label.find(delimiter));
@@ -545,7 +519,7 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::ExtractColOp::inferShape() {
         }
     }
     // Default case except when the selectedCols ends in a wildcard
-    return{{srcNumRows, getShape(getOperand(1)).second}};
+    return {{srcNumRows, getShape(getOperand(1)).second}};
 }

 std::vector<std::pair<ssize_t, ssize_t>> daphne::EigenOp::inferShape() {
@@ -562,25 +536,22 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::RecodeOp::inferShape() {
     ssize_t resNumRows;
     ssize_t resNumCols;
-    if(auto argMatTy = llvm::dyn_cast<daphne::MatrixType>(argTy)) {
+    if (auto argMatTy = llvm::dyn_cast<daphne::MatrixType>(argTy)) {
         resNumRows = argMatTy.getNumRows();
         resNumCols = argMatTy.getNumCols();
-    }
-    else if(auto argFrmTy = llvm::dyn_cast<daphne::FrameType>(argTy)) {
+    } else if (auto argFrmTy = llvm::dyn_cast<daphne::FrameType>(argTy)) {
        resNumRows = argFrmTy.getNumRows();
        resNumCols = argFrmTy.getNumCols();
-    }
-    else if(llvm::isa<daphne::UnknownType>(argTy)) {
+    } else if (llvm::isa<daphne::UnknownType>(argTy)) {
        resNumRows = -1;
        resNumCols = -1;
-    }
-    else
-        throw ErrorHandler::compilerError(
-            getLoc(), "InferShapeOpInterface (daphne::RecodeOp::inferShape)",
-            "the argument to recode has an invalid type");
+    } else
+        throw ErrorHandler::compilerError(getLoc(), "InferShapeOpInterface (daphne::RecodeOp::inferShape)",
+                                          "the argument to recode has an invalid type");

-    // TODO We could infer (or estimate) the number of rows of the dictionary result
-    // if we knew the number of distinct values in the argument (or could estimate it).
+    // TODO We could infer (or estimate) the number of rows of the dictionary
+    // result if we knew the number of distinct values in the argument (or could
+    // estimate it).
     const ssize_t dictNumRows = -1;
     const ssize_t dictNumCols = 1;
@@ -595,69 +566,60 @@ std::vector<std::pair<ssize_t, ssize_t>> daphne::RecodeOp::inferShape() {
  * @brief Utility for trying a parametric trait for all values of the parameter
  * from 0 to some upper bound.
  */
-template<size_t u, template<size_t> class tryParametricTrait>
-struct tryParamTraitUntil {
-    static void apply(ssize_t& numRows, ssize_t& numCols, Operation * op) {
+template <size_t u, template <size_t> class tryParametricTrait> struct tryParamTraitUntil {
+    static void apply(ssize_t &numRows, ssize_t &numCols, Operation *op) {
         tryParametricTrait<u>::apply(numRows, numCols, op);
         tryParamTraitUntil<u - 1, tryParametricTrait>::apply(numRows, numCols, op);
     }
 };
-template<template<size_t> class tryParametricTrait>
-struct tryParamTraitUntil<0, tryParametricTrait> {
-    static void apply(ssize_t& numRows, ssize_t& numCols, Operation * op) {
+template