From 9defa20e208332800aaa2c9e7c969c5fd98535ed Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Thu, 12 Sep 2024 15:46:26 +0200 Subject: [PATCH] Make the code compile for ARM again --- include/blast/math/algorithm/Tile.hpp | 4 + .../blast/math/algorithm/arch/avx2/Tile.hpp | 3 +- .../blast/math/algorithm/arch/neon64/Tile.hpp | 141 ++++++++++++++++++ include/blast/math/dense/Trmm.hpp | 1 + include/blast/math/panel/PanelSize.hpp | 1 - include/blast/system/Tile.hpp | 12 +- 6 files changed, 152 insertions(+), 10 deletions(-) create mode 100644 include/blast/math/algorithm/arch/neon64/Tile.hpp diff --git a/include/blast/math/algorithm/Tile.hpp b/include/blast/math/algorithm/Tile.hpp index 7f6ecc63..3b4efef6 100644 --- a/include/blast/math/algorithm/Tile.hpp +++ b/include/blast/math/algorithm/Tile.hpp @@ -8,6 +8,10 @@ # include #endif +#if XSIMD_WITH_NEON64 +# include +#endif + #include #include diff --git a/include/blast/math/algorithm/arch/avx2/Tile.hpp b/include/blast/math/algorithm/arch/avx2/Tile.hpp index aac90b8a..f7decfde 100644 --- a/include/blast/math/algorithm/arch/avx2/Tile.hpp +++ b/include/blast/math/algorithm/arch/avx2/Tile.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -46,7 +47,7 @@ namespace blast :: detail BLAST_ALWAYS_INLINE void tile(xsimd::avx2 const& arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) { size_t constexpr SS = SimdSize_v; - size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be ppoperly determined + size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be properly determined static_assert(SO == columnMajor, "tile() for row-major matrices not implemented"); diff --git a/include/blast/math/algorithm/arch/neon64/Tile.hpp b/include/blast/math/algorithm/arch/neon64/Tile.hpp new file mode 100644 index 00000000..652a7917 --- /dev/null +++ b/include/blast/math/algorithm/arch/neon64/Tile.hpp @@ -0,0 +1,141 @@ +// Copyright 2024 Mikhail Katliar. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include + +#include + + +namespace blast :: detail +{ + template + BLAST_ALWAYS_INLINE void tile_backend(xsimd::neon64, size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial) + { + RegisterMatrix ker; + + if (i + KM <= m) + { + size_t j = 0; + + for (; j + KN <= n; j += KN) + f_full(ker, i, j); + + if (j < n) + f_partial(ker, i, j, KM, n - j); + } + else + { + size_t j = 0; + + for (; j + KN <= n; j += KN) + f_partial(ker, i, j, m - i, KN); + + if (j < n) + f_partial(ker, i, j, m - i, n - j); + } + } + + + template + BLAST_ALWAYS_INLINE void tile(xsimd::neon64 const& arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) + { + size_t constexpr SS = SimdSize_v; + size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be properly determined + + static_assert(SO == columnMajor, "tile() for row-major matrices not implemented"); + + if (traversal_order == columnMajor) + { + size_t j = 0; + + // Main part + for (; j + TILE_STEP <= n; j += TILE_STEP) + { + size_t i = 0; + + // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: + // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. + for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + for (; i + 2 * SS <= m; i += 2 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + for (; i + 1 * SS <= m; i += 1 * SS) + { + RegisterMatrix ker; + f_full(ker, i, j); + } + + // Bottom side + if (i < m) + { + RegisterMatrix ker; + f_partial(ker, i, j, m - i, ker.columns()); + } + } + + + // Right side + if (j < n) + { + size_t i = 0; + + // i + 4 * TILE_STEP != M is to improve performance in case when the remaining number of rows is 4 * TILE_STEP: + // it is more efficient to apply 2 * TILE_STEP kernel 2 times than 3 * TILE_STEP + 1 * TILE_STEP kernel. + for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + for (; i + 2 * SS <= m; i += 2 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + for (; i + 1 * SS <= m; i += 1 * SS) + { + RegisterMatrix ker; + f_partial(ker, i, j, ker.rows(), n - j); + } + + // Bottom-right corner + if (i < m) + { + RegisterMatrix ker; + f_partial(ker, i, j, m - i, n - j); + } + } + } + else + { + size_t i = 0; + + // i + 4 * SS != M is to improve performance in case when the remaining number of rows is 4 * SS: + // it is more efficient to apply 2 * SS kernel 2 times than 3 * SS + 1 * SS kernel. + for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS) + tile_backend(arch, m, n, i, f_full, f_partial); + + for (; i + 1 * SS < m; i += 2 * SS) + tile_backend(arch, m, n, i, f_full, f_partial); + + for (; i + 0 * SS < m; i += 1 * SS) + tile_backend(arch, m, n, i, f_full, f_partial); + } + } +} diff --git a/include/blast/math/dense/Trmm.hpp b/include/blast/math/dense/Trmm.hpp index 34dc7c4f..c82bd674 100644 --- a/include/blast/math/dense/Trmm.hpp +++ b/include/blast/math/dense/Trmm.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include diff --git a/include/blast/math/panel/PanelSize.hpp b/include/blast/math/panel/PanelSize.hpp index 309e4a7e..9f9b500e 100644 --- a/include/blast/math/panel/PanelSize.hpp +++ b/include/blast/math/panel/PanelSize.hpp @@ -19,7 +19,6 @@ namespace blast * * TODO: Is it always equal to SIMD size? Deprecate? * - * @tparam T data type * @tparam Arch architecture */ template diff --git a/include/blast/system/Tile.hpp b/include/blast/system/Tile.hpp index dd471a6f..7820eb13 100644 --- a/include/blast/system/Tile.hpp +++ b/include/blast/system/Tile.hpp @@ -4,18 +4,14 @@ #pragma once -//************************************************************************************************* -// Includes -//************************************************************************************************* - -#include +#include namespace blast { - using namespace blaze; - - + /** + * @brief TODO: deprecate? + */ template struct TileSize;