Commit 49030d9
Moved avx2-specific implementation of tile() to blast/math/algorithm/arch/avx2
mkatliar committed Aug 13, 2024
1 parent 3dee737 commit 49030d9
Showing 4 changed files with 176 additions and 130 deletions.
3 changes: 2 additions & 1 deletion include/blast/math/algorithm/Gemm.hpp
@@ -62,6 +62,7 @@ namespace blast
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET);

tile<ET, StorageOrder(StorageOrder_v<MPD>)>(
xsimd::default_arch {},
D.cachePreferredTraversal,
M, N,
[&] (auto& ker, size_t i, size_t j)
@@ -74,4 +75,4 @@
}
);
}
}
}
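
The one-line change above threads an architecture tag through tile(): the call site passes xsimd::default_arch {} and overload resolution selects the matching backend. A standalone sketch of this tag-dispatch idea, with a hypothetical process() standing in for the tile()/detail::tile() pair (not a blast API):

#include <xsimd/xsimd.hpp>
#include <cstdio>

namespace detail
{
#if XSIMD_WITH_AVX2
    // Overload available only when AVX2 is enabled, mirroring the
    // XSIMD_WITH_AVX2 include guard added to Tile.hpp below.
    void process(xsimd::avx2) { std::puts("avx2 backend"); }
#endif

    // Fallback for any other architecture tag (illustrative only).
    template <typename Arch>
    void process(Arch) { std::puts("generic backend"); }
}

int main()
{
    // xsimd::default_arch names the most capable instruction set enabled
    // at compile time; an instance of it drives overload resolution.
    detail::process(xsimd::default_arch {});
}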
141 changes: 14 additions & 127 deletions include/blast/math/algorithm/Tile.hpp
@@ -12,45 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/simd/SimdSize.hpp>
#include <blast/math/Simd.hpp>

#if XSIMD_WITH_AVX2
# include <blast/math/algorithm/arch/avx2/Tile.hpp>
#endif

#include <blast/math/StorageOrder.hpp>
#include <blast/math/RegisterMatrix.hpp>
#include <blast/system/Inline.hpp>

#include <cstdlib>


namespace blast
{
template <typename ET, size_t KM, size_t KN, StorageOrder SO, typename FF, typename FP>
BLAZE_ALWAYS_INLINE void tile_backend(size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial)
{
RegisterMatrix<ET, KM, KN, SO> ker;

if (i + KM <= m)
{
size_t j = 0;

for (; j + KN <= n; j += KN)
f_full(ker, i, j);

if (j < n)
f_partial(ker, i, j, KM, n - j);
}
else
{
size_t j = 0;

for (; j + KN <= n; j += KN)
f_partial(ker, i, j, m - i, KN);

if (j < n)
f_partial(ker, i, j, m - i, n - j);
}
}


/**
* @brief Cover a matrix with tiles of different sizes in a performance-efficient way.
*
@@ -69,106 +44,18 @@ namespace blast
*
* @tparam ET type of matrix elements
* @tparam SO matrix storage order
* @tparam F functor type
* @tparam FF functor type for full tiles
* @tparam FP functor type for partial tiles
* @tparam Arch instruction set architecture
*
* @param m number of matrix rows
* @param n number of matrix columns
* @param f_full functor to call on full tiles
* @param f_partial functor to call on partial tiles
*/
template <typename ET, StorageOrder SO, typename FF, typename FP>
BLAST_ALWAYS_INLINE void tile(StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
template <typename ET, StorageOrder SO, typename FF, typename FP, typename Arch>
BLAST_ALWAYS_INLINE void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
{
size_t constexpr SS = SimdSize_v<ET>;
        size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be properly determined

static_assert(SO == columnMajor, "tile() for row-major matrices not implemented");

if (traversal_order == columnMajor)
{
size_t j = 0;

// Main part
for (; j + TILE_STEP <= n; j += TILE_STEP)
{
size_t i = 0;

                // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
                // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
{
RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

for (; i + 2 * SS <= m; i += 2 * SS)
{
RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

for (; i + 1 * SS <= m; i += 1 * SS)
{
RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

// Bottom side
if (i < m)
{
RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, m - i, ker.columns());
}
}


// Right side
if (j < n)
{
size_t i = 0;

                // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
                // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
{
RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 2 * SS <= m; i += 2 * SS)
{
RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 1 * SS <= m; i += 1 * SS)
{
RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

// Bottom-right corner
if (i < m)
{
RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, m - i, n - j);
}
}
}
else
{
size_t i = 0;

            // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
            // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS)
tile_backend<ET, 3 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);

for (; i + 1 * SS < m; i += 2 * SS)
tile_backend<ET, 2 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);

for (; i + 0 * SS < m; i += 1 * SS)
tile_backend<ET, 1 * SS, TILE_STEP, SO>(m, n, i, f_full, f_partial);
}
detail::tile<ET, SO>(arch, traversal_order, m, n, f_full, f_partial);
}
}
}
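
A hedged usage sketch of the entry point above (not part of the commit): it assumes an AVX2-enabled build, since only the avx2 backend exists at this point, and that columnMajor is reachable as blast::columnMajor. The functors merely count the tiles that tile() visits.

#include <blast/math/algorithm/Tile.hpp>

#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t full = 0, partial = 0;

    blast::tile<double, blast::columnMajor>(
        xsimd::default_arch {}, blast::columnMajor, 13, 7,
        // Full tiles: the whole RegisterMatrix kernel is valid.
        [&] (auto& ker, std::size_t i, std::size_t j) { ++full; },
        // Partial tiles: only an m x n sub-block of the kernel is valid.
        [&] (auto& ker, std::size_t i, std::size_t j,
            std::size_t m, std::size_t n) { ++partial; }
    );

    std::printf("full: %zu, partial: %zu\n", full, partial);
}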
150 changes: 150 additions & 0 deletions include/blast/math/algorithm/arch/avx2/Tile.hpp
@@ -0,0 +1,150 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/StorageOrder.hpp>
#include <blast/math/RegisterMatrix.hpp>

#include <blast/math/Simd.hpp>

#include <cstdlib>


namespace blast :: detail
{
template <typename ET, size_t KM, size_t KN, StorageOrder SO, typename FF, typename FP>
BLAST_ALWAYS_INLINE void tile_backend(xsimd::avx2, size_t m, size_t n, size_t i, FF&& f_full, FP&& f_partial)
{
RegisterMatrix<ET, KM, KN, SO> ker;

if (i + KM <= m)
{
size_t j = 0;

for (; j + KN <= n; j += KN)
f_full(ker, i, j);

if (j < n)
f_partial(ker, i, j, KM, n - j);
}
else
{
size_t j = 0;

for (; j + KN <= n; j += KN)
f_partial(ker, i, j, m - i, KN);

if (j < n)
f_partial(ker, i, j, m - i, n - j);
}
}


template <typename ET, StorageOrder SO, typename FF, typename FP>
BLAST_ALWAYS_INLINE void tile(xsimd::avx2 const& arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
{
size_t constexpr SS = SimdSize_v<ET>;
        size_t constexpr TILE_STEP = 4; // TODO: this is almost arbitrary and needs to be properly determined

static_assert(SO == columnMajor, "tile() for row-major matrices not implemented");

if (traversal_order == columnMajor)
{
size_t j = 0;

// Main part
for (; j + TILE_STEP <= n; j += TILE_STEP)
{
size_t i = 0;

                // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
                // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
{
RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

for (; i + 2 * SS <= m; i += 2 * SS)
{
RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

for (; i + 1 * SS <= m; i += 1 * SS)
{
RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
f_full(ker, i, j);
}

// Bottom side
if (i < m)
{
RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, m - i, ker.columns());
}
}


// Right side
if (j < n)
{
size_t i = 0;

                // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
                // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
{
RegisterMatrix<ET, 3 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 2 * SS <= m; i += 2 * SS)
{
RegisterMatrix<ET, 2 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 1 * SS <= m; i += 1 * SS)
{
RegisterMatrix<ET, 1 * SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

// Bottom-right corner
if (i < m)
{
RegisterMatrix<ET, SS, TILE_STEP, SO> ker;
f_partial(ker, i, j, m - i, n - j);
}
}
}
else
{
size_t i = 0;

            // The condition i + 4 * SS != m improves performance when the remaining number of rows is 4 * SS:
            // it is more efficient to apply a 2 * SS kernel twice than a 3 * SS kernel followed by a 1 * SS kernel.
for (; i + 2 * SS < m && i + 4 * SS != m; i += 3 * SS)
tile_backend<ET, 3 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);

for (; i + 1 * SS < m; i += 2 * SS)
tile_backend<ET, 2 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);

for (; i + 0 * SS < m; i += 1 * SS)
tile_backend<ET, 1 * SS, TILE_STEP, SO>(arch, m, n, i, f_full, f_partial);
}
}
}
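
To make the i + 4 * SS != m guard above concrete: with SS = 4 (double on AVX2) and m = 16, the greedy split would be one 3 * SS panel plus one 1 * SS panel (12 + 4 rows), while the guarded loop produces two 2 * SS panels (8 + 8), i.e. the same kernel twice instead of two different ones. A self-contained sketch of just the row-panel selection, following the column-major loops of detail::tile() above (not library code):

#include <cstddef>
#include <cstdio>

// Prints the sequence of row-panel heights chosen for m rows with
// SIMD size SS.
void print_panels(std::size_t m, std::size_t SS)
{
    std::size_t i = 0;
    for (; i + 3 * SS <= m && i + 4 * SS != m; i += 3 * SS)
        std::printf("%zu ", 3 * SS);   // 3 * SS-row kernel
    for (; i + 2 * SS <= m; i += 2 * SS)
        std::printf("%zu ", 2 * SS);   // 2 * SS-row kernel
    for (; i + 1 * SS <= m; i += 1 * SS)
        std::printf("%zu ", 1 * SS);   // 1 * SS-row kernel
    if (i < m)
        std::printf("%zu", m - i);     // partial bottom panel
    std::printf("\n");
}

int main()
{
    print_panels(16, 4);   // prints "8 8", not "12 4"
    print_panels(20, 4);   // prints "12 8"
    print_panels(13, 4);   // prints "12 1"
}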
12 changes: 10 additions & 2 deletions include/blast/system/Inline.hpp
@@ -14,6 +14,14 @@

#pragma once

#include <blaze/system/Inline.h>
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
# define BLAST_STRONG_INLINE __forceinline
#else
# define BLAST_STRONG_INLINE inline
#endif

#define BLAST_ALWAYS_INLINE BLAZE_ALWAYS_INLINE
#if defined(__GNUC__)
# define BLAST_ALWAYS_INLINE __attribute__((always_inline)) inline
#else
# define BLAST_ALWAYS_INLINE BLAST_STRONG_INLINE
#endif
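
With these definitions BLAST no longer borrows Blaze's inlining macro: BLAST_ALWAYS_INLINE expands to __attribute__((always_inline)) inline under GCC and Clang, and falls back to BLAST_STRONG_INLINE elsewhere (__forceinline on MSVC and the Intel compiler, plain inline otherwise). A minimal usage sketch:

#include <blast/system/Inline.hpp>

// Forces inlining where the compiler supports it; on other toolchains
// the macro degrades gracefully to plain `inline`.
BLAST_ALWAYS_INLINE int twice(int x)
{
    return 2 * x;
}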
