diff --git a/include/blast/math/algorithm/Gemm.hpp b/include/blast/math/algorithm/Gemm.hpp
index 211c3acb..36b14939 100644
--- a/include/blast/math/algorithm/Gemm.hpp
+++ b/include/blast/math/algorithm/Gemm.hpp
@@ -62,6 +62,7 @@ namespace blast
         BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MT4>>, ET);
 
         tile<ET, StorageOrder(StorageOrder_v<MT4>)>(
+            xsimd::default_arch {},
             D.cachePreferredTraversal,
             M, N,
             [&] (auto& ker, size_t i, size_t j)
@@ -74,4 +75,4 @@ namespace blast
             }
         );
     }
-}
\ No newline at end of file
+}
diff --git a/include/blast/math/algorithm/Tile.hpp b/include/blast/math/algorithm/Tile.hpp
index 41d7aadb..280267fa 100644
--- a/include/blast/math/algorithm/Tile.hpp
+++ b/include/blast/math/algorithm/Tile.hpp
@@ -12,45 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
-#include
-#include
+#include
+
+#if XSIMD_WITH_AVX2
+#   include <blast/math/algorithm/arch/avx2/Tile.hpp>
+#endif
+
 #include
-#include
 #include
 
 
 namespace blast
 {
-    template <typename ET, size_t KM, size_t KN, StorageOrder SO, typename FF, typename FP>
-    BLAZE_ALWAYS_INLINE void tile_backend(size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial)
-    {
-        RegisterMatrix<ET, KM, KN, SO> ker;
-
-        if (i + KM <= m)
-        {
-            size_t j = 0;
-
-            for (; j + KN <= n; j += KN)
-                f_full(ker, i, j);
-
-            if (j < n)
-                f_partial(ker, i, j, KM, n - j);
-        }
-        else
-        {
-            size_t j = 0;
-
-            for (; j + KN <= n; j += KN)
-                f_partial(ker, i, j, m - i, KN);
-
-            if (j < n)
-                f_partial(ker, i, j, m - i, n - j);
-        }
-    }
-
-
     /**
      * @brief Cover a matrix with tiles of different sizes in a performance-efficient way.
      *
@@ -69,106 +44,18 @@ namespace blast
      *
      * @tparam ET type of matrix elements
     * @tparam SO matrix storage order
-     * @tparam F functor type
+     * @tparam FF functor type for full tiles
+     * @tparam FP functor type for partial tiles
+     * @tparam Arch instruction set architecture
      *
      * @param m number of matrix rows
      * @param n number of matrix columns
      * @param f_full functor to call on full tiles
      * @param f_partial functor to call on partial tiles
      */
-    template <typename ET, StorageOrder SO, typename FF, typename FP>
-    BLAST_ALWAYS_INLINE void tile(StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
+    template <typename ET, StorageOrder SO, typename FF, typename FP, typename Arch>
+    BLAST_ALWAYS_INLINE void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
     {
-        size_t constexpr SS = SimdSize_v<ET>;
-        size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be ppoperly determined
-
-        static_assert(SO == columnMajor, "tile() for row-major matrices not implemented");
-
-        if (traversal_order == columnMajor)
-        {
-            size_t j = 0;
-
-            // Main part
-            for (; j + TILE_STEP <= n; j += TILE_STEP)
-            {
-                size_t i = 0;
-
-                // i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
-                // it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
-                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
-                {
-                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
-                    f_full(ker, i, j);
-                }
-
-                for (; i + 2 * SS <= m; i += 2 * SS)
-                {
-                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
-                    f_full(ker, i, j);
-                }
-
-                for (; i + 1 * SS <= m; i += 1 * SS)
-                {
-                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
-                    f_full(ker, i, j);
-                }
-
-                // Bottom side
-                if (i < m)
-                {
-                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
-                    f_partial(ker, i, j, m - i, ker.columns());
-                }
-            }
-
-
-            // Right side
-            if (j < n)
-            {
-                size_t i = 0;
-
-                // i + 4 * TILE_STEP != M is to improve performance in case when the remaining number of rows is 4 * TILE_STEP:
-                // it is more efficient to apply 2 * TILE_STEP kernel 2 times than 3 * TILE_STEP + 1 * TILE_STEP kernel.
-                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
-                {
-                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
-                    f_partial(ker, i, j, ker.rows(), n - j);
-                }
-
-                for (; i + 2 * SS <= m; i += 2 * SS)
-                {
-                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
-                    f_partial(ker, i, j, ker.rows(), n - j);
-                }
-
-                for (; i + 1 * SS <= m; i += 1 * SS)
-                {
-                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
-                    f_partial(ker, i, j, ker.rows(), n - j);
-                }
-
-                // Bottom-right corner
-                if (i < m)
-                {
-                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
-                    f_partial(ker, i, j, m - i, n - j);
-                }
-            }
-        }
-        else
-        {
-            size_t i = 0;
-
-            // i + 4 * SS != M is to improve performance in case when the remaining number of rows is 4 * SS:
-            // it is more efficient to apply 2 * SS kernel 2 times than 3 * SS + 1 * SS kernel.
-            for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS)
-                tile_backend<ET, 3 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
-
-            for (; i + 1 * SS < m; i += 2 * SS)
-                tile_backend<ET, 2 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
-
-            for (; i + 0 * SS < m; i += 1 * SS)
-                tile_backend<ET, 1 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
-        }
+        detail::tile<ET, SO>(arch, traversal_order, m, n, f_full, f_partial);
     }
-}
\ No newline at end of file
+}
diff --git a/include/blast/math/algorithm/arch/avx2/Tile.hpp b/include/blast/math/algorithm/arch/avx2/Tile.hpp
new file mode 100644
index 00000000..01974521
--- /dev/null
+++ b/include/blast/math/algorithm/arch/avx2/Tile.hpp
@@ -0,0 +1,150 @@
+// Copyright 2023 Mikhail Katliar
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+
+
+namespace blast :: detail
+{
+    template <typename ET, size_t KM, size_t KN, StorageOrder SO, typename FF, typename FP>
+    BLAST_ALWAYS_INLINE void tile_backend(xsimd::avx2, size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial)
+    {
+        RegisterMatrix<ET, KM, KN, SO> ker;
+
+        if (i + KM <= m)
+        {
+            size_t j = 0;
+
+            for (; j + KN <= n; j += KN)
+                f_full(ker, i, j);
+
+            if (j < n)
+                f_partial(ker, i, j, KM, n - j);
+        }
+        else
+        {
+            size_t j = 0;
+
+            for (; j + KN <= n; j += KN)
+                f_partial(ker, i, j, m - i, KN);
+
+            if (j < n)
+                f_partial(ker, i, j, m - i, n - j);
+        }
+    }
+
+
+    template <typename ET, StorageOrder SO, typename FF, typename FP>
+    BLAST_ALWAYS_INLINE void tile(xsimd::avx2 const& arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
+    {
+        size_t constexpr SS = SimdSize_v<ET>;
+        size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be properly determined
+
+        static_assert(SO == columnMajor, "tile() for row-major matrices not implemented");
+
+        if (traversal_order == columnMajor)
+        {
+            size_t j = 0;
+
+            // Main part
+            for (; j + TILE_STEP <= n; j += TILE_STEP)
+            {
+                size_t i = 0;
+
+                // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+                // it is more efficient to apply the 2 * SS kernel 2 times than the 3 * SS + 1 * SS kernels.
+                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
+                {
+                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
+
+                for (; i + 2 * SS <= m; i += 2 * SS)
+                {
+                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
+
+                for (; i + 1 * SS <= m; i += 1 * SS)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_full(ker, i, j);
+                }
+
+                // Bottom side
+                if (i < m)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, m - i, ker.columns());
+                }
+            }
+
+
+            // Right side
+            if (j < n)
+            {
+                size_t i = 0;
+
+                // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+                // it is more efficient to apply the 2 * SS kernel 2 times than the 3 * SS + 1 * SS kernels.
+                for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
+                {
+                    RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
+
+                for (; i + 2 * SS <= m; i += 2 * SS)
+                {
+                    RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
+
+                for (; i + 1 * SS <= m; i += 1 * SS)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, ker.rows(), n - j);
+                }
+
+                // Bottom-right corner
+                if (i < m)
+                {
+                    RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
+                    f_partial(ker, i, j, m - i, n - j);
+                }
+            }
+        }
+        else
+        {
+            size_t i = 0;
+
+            // i + 4 * SS != m is to improve performance in case when the remaining number of rows is 4 * SS:
+            // it is more efficient to apply the 2 * SS kernel 2 times than the 3 * SS + 1 * SS kernels.
+            for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS)
+                tile_backend<ET, 3 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);
+
+            for (; i + 1 * SS < m; i += 2 * SS)
+                tile_backend<ET, 2 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);
+
+            for (; i + 0 * SS < m; i += 1 * SS)
+                tile_backend<ET, 1 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);
+        }
+    }
+}
diff --git a/include/blast/system/Inline.hpp b/include/blast/system/Inline.hpp
index 8ec59251..ff88fb88 100644
--- a/include/blast/system/Inline.hpp
+++ b/include/blast/system/Inline.hpp
@@ -14,6 +14,14 @@
 
 #pragma once
 
-#include
+#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
+#   define BLAST_STRONG_INLINE __forceinline
+#else
+#   define BLAST_STRONG_INLINE inline
+#endif
 
-#define BLAST_ALWAYS_INLINE BLAZE_ALWAYS_INLINE
+#if defined(__GNUC__)
+#   define BLAST_ALWAYS_INLINE __attribute__((always_inline)) inline
+#else
+#   define BLAST_ALWAYS_INLINE BLAST_STRONG_INLINE
+#endif
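Usage note (not part of the patch): tile() keeps its old tile/functor contract, but now takes an xsimd architecture tag as its first argument and forwards to an architecture-specific implementation in blast::detail, selected at compile time via the conditional include above. Below is a minimal sketch of a call site, assuming the template parameter lists reconstructed above (tile<ET, SO>), double elements, and placeholder functor bodies; the matrix dimensions are arbitrary illustrative values.

    #include <blast/math/algorithm/Tile.hpp>

    #include <xsimd/xsimd.hpp>

    #include <cstddef>

    int main()
    {
        using namespace blast;

        // Invoked with a full register tile positioned at element (i, j);
        // the whole ker.rows() x ker.columns() tile may be used.
        auto f_full = [] (auto& ker, std::size_t i, std::size_t j)
        {
            // process the full tile at (i, j)
        };

        // Invoked at the matrix edges; only the top-left m x n part
        // of the register tile is inside the matrix.
        auto f_partial = [] (auto& ker, std::size_t i, std::size_t j, std::size_t m, std::size_t n)
        {
            // process the m x n partial tile at (i, j)
        };

        // Cover a 100 x 50 column-major matrix, traversing column by column.
        // xsimd::default_arch {} selects the best instruction set enabled at compile time,
        // which the wrapper in Tile.hpp uses to pick the detail:: implementation.
        tile<double, columnMajor>(xsimd::default_arch {}, columnMajor, 100, 50, f_full, f_partial);
    }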