Commit

Introduced tile() function (#8)
* Update README.md

* Added tile() function

* gemm() using tile()

* Add BLAZE_BLAS_MODE=0 in bench-blast

* edge -> side

* Fixed spelling errors

* Added -ftemplate-backtrace-limit=0 in cmake.yml workflow for better diagnostics
mkatliar authored Mar 8, 2024
1 parent 9dbbee4 commit 153025e
Showing 9 changed files with 174 additions and 82 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cmake.yml
@@ -48,7 +48,7 @@ jobs:
cmake -B ${{github.workspace}}/build \
-DCMAKE_BUILD_TYPE=${{env.BUILD_TYPE}} \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CXX_FLAGS="-march=native -mfma -mavx -mavx2 -msse4 -fno-math-errno" \
-DCMAKE_CXX_FLAGS="-march=native -mfma -mavx -mavx2 -msse4 -fno-math-errno -ftemplate-backtrace-limit=0" \
-DCMAKE_CXX_FLAGS_RELEASE="-O3 -g -DNDEBUG -ffast-math" \
-DBLAST_WITH_BENCHMARK=OFF \
-DBLAST_WITH_TEST=ON
6 changes: 2 additions & 4 deletions README.md
@@ -1,9 +1,7 @@
# BLAST
*BLAST* is a high-performance linear algebra library that combines a BLAS-like nterface with modern C++ template metaprogramming.
*BLAST* (**BLAS** **T**emplates) is a high-performance linear algebra library that combines a BLAS-like interface with modern C++ template metaprogramming.
*BLAST* implementation is single-threaded and intended for matrices of small and medium size (a few hundred rows/columns), which is common for embedded control applications.

The name stands for **BLAS** **T**emplates.

The figures below show the performance of the BLAS *dgemm* routine for different LA implementations on an
*Intel(R) Core(TM) i7-9850H CPU @ 2.60GHz*:
![dgemm_performance](doc/dgemm_performance.png)
@@ -104,4 +102,4 @@ cd blast
docker build . --tag blast_bench .
docker run -v `pwd`/bench_result/docker:/root/blast/bench_result blast_bench
```
The benchmark results will be put in `/bench_result/docker`.
The benchmark results will be put in `/bench_result/docker`.
6 changes: 6 additions & 0 deletions bench/blast/CMakeLists.txt
@@ -35,6 +35,12 @@ add_executable(bench-blast
math/panel/DynamicPotrf.cpp
)

target_compile_definitions(bench-blast
# Use Blaze without linking to a BLAS library.
# Blaze is used to prepare data in some of the benchmarks.
PRIVATE BLAZE_BLAS_MODE=0
)

target_link_libraries(bench-blast
blast
bench-blast-common
124 changes: 124 additions & 0 deletions include/blast/math/algorithm/Tile.hpp
@@ -0,0 +1,124 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/StorageOrder.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>

#include <cstdlib>


namespace blast
{
/**
* @brief Cover a matrix with tiles of different sizes in a performance-efficient way.
*
* The tile sizes and positions are chosen based on matrix element type, storage order, and current system architecture.
* Positions and sizes of the tiles are such that the entire matrix is covered and tiles do not overlap.
*
* This function is helpful in implementing register-blocked matrix algorithms.
*
* For each tile one of the two specified functors @a f_full, @a f_partial is called:
*
* 1) @a f_full (ker, i, j); // if the tile size equals ker.rows() by ker.columns()
* 2) @a f_partial (ker, i, j, km, kn); // if the tile size is smaller than ker.rows() by ker.columns()
*
* where ker is a RegisterMatrix object, (i, j) are indices of top left corner of the tile,
* and (km, kn) are dimensions of the tile.
*
* @tparam ET type of matrix elements
* @tparam SO matrix storage order
* @tparam FF type of the functor called on full tiles
* @tparam FP type of the functor called on partial tiles
*
* @param m number of matrix rows
* @param n number of matrix columns
* @param f_full functor to call on full tiles
* @param f_partial functor to call on partial tiles
*/
template <typename ET, StorageOrder SO, typename FF, typename FP>
BLAST_ALWAYS_INLINE void tile(std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
{
size_t constexpr TILE_SIZE = TileSize_v<ET>;

size_t j = 0;

// Main part
for (; j + TILE_SIZE <= n; j += TILE_SIZE)
{
size_t i = 0;

// The condition i + 4 * TILE_SIZE != m improves performance when exactly 4 * TILE_SIZE rows remain:
// it is more efficient to apply the 2 * TILE_SIZE kernel twice than the 3 * TILE_SIZE kernel followed by the 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_full(ker, i, j);
}

for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_full(ker, i, j);
}

for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_full(ker, i, j);
}

// Bottom side
if (i < m)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, m - i, ker.columns());
}
}


// Right side
if (j < n)
{
size_t i = 0;

// The condition i + 4 * TILE_SIZE != m improves performance when exactly 4 * TILE_SIZE rows remain:
// it is more efficient to apply the 2 * TILE_SIZE kernel twice than the 3 * TILE_SIZE kernel followed by the 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

// Bottom-right corner
if (i < m)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, m - i, n - j);
}
}
}
}
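To see how the loops in tile() partition a matrix, the following minimal sketch replays the same loop structure outside the library. `TileRect` and `coverWithTiles` are hypothetical names used only for illustration, and the tile size is passed at run time as `ts` instead of being taken from `TileSize_v<ET>`. The sketch records tiles instead of calling functors, so it says nothing about the actual kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record of one tile; not part of BLAST.
struct TileRect
{
    std::size_t i, j;        // indices of the top left corner
    std::size_t rows, cols;  // tile dimensions
    bool full;               // true where tile() would call f_full
};

// Replays the loop structure of tile() for an m-by-n matrix.
// ts is an assumed stand-in for TileSize_v<ET>.
inline std::vector<TileRect> coverWithTiles(std::size_t m, std::size_t n, std::size_t ts)
{
    assert(ts > 0);  // a zero tile size would never advance the loops

    std::vector<TileRect> tiles;

    // Cover one column strip of width cols that starts at column j.
    auto strip = [&](std::size_t j, std::size_t cols)
    {
        bool const full_width = cols == ts;
        std::size_t i = 0;

        // Skip the 3 * ts kernel when exactly 4 * ts rows remain.
        for (; i + 3 * ts <= m && i + 4 * ts != m; i += 3 * ts)
            tiles.push_back({i, j, 3 * ts, cols, full_width});

        for (; i + 2 * ts <= m; i += 2 * ts)
            tiles.push_back({i, j, 2 * ts, cols, full_width});

        for (; i + ts <= m; i += ts)
            tiles.push_back({i, j, ts, cols, full_width});

        // Bottom side: fewer than ts rows remain, always a partial tile.
        if (i < m)
            tiles.push_back({i, j, m - i, cols, false});
    };

    std::size_t j = 0;

    // Main part: column strips of full width.
    for (; j + ts <= n; j += ts)
        strip(j, ts);

    // Right side: one narrower strip, all tiles partial.
    if (j < n)
        strip(j, n - j);

    return tiles;
}
```

With m = 16, n = 4 and an assumed ts = 4, the sketch produces two full tiles of 8 rows each instead of a 12-row tile followed by a 4-row tile, which is the case the `i + 4 * TILE_SIZE != m` condition targets.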
87 changes: 16 additions & 71 deletions include/blast/math/dense/Gemm.hpp
@@ -4,18 +4,21 @@

#pragma once

#include <blaze/math/StorageOrder.h>
#include <blaze/math/typetraits/StorageOrder.h>
#include <blaze/math/views/Forward.h>
#include <blast/math/dense/GemmBackend.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/typetraits/StorageOrder.hpp>
#include <blast/math/typetraits/MatrixPointer.hpp>
#include <blast/math/algorithm/Tile.hpp>
#include <blast/math/dense/MatrixPointer.hpp>

#include <algorithm>
#include <blaze/util/constraints/SameType.h>

#include <cstddef>
#include <type_traits>


namespace blast
{

/**
* @brief Performs the matrix-matrix operation
*
@@ -51,83 +54,25 @@ namespace blast
MatrixPointer<MPC> && StorageOrder_v<MPC> == columnMajor &&
MatrixPointer<MPD> && StorageOrder_v<MPD> == columnMajor
)
BLAZE_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
BLAST_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
{
using ET = std::remove_cv_t<ElementType_t<MPA>>;
using ET = std::remove_cv_t<ElementType_t<MPD>>;
size_t constexpr TILE_SIZE = TileSize_v<ET>;

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPB>>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPC>>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET);

size_t j = 0;

// Main part
for (; j + TILE_SIZE <= N; j += TILE_SIZE)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
tile<ET, StorageOrder(StorageOrder_v<MPD>)>(M, N,
[&] (auto& ker, size_t i, size_t j)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

// Bottom edge
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, ker.columns());
}
}


// Right edge
if (j < N)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

// Bottom-right corner
if (i < M)
},
[&] (auto& ker, size_t i, size_t j, size_t m, size_t n)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, N - j);
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), m, n);
}
}
);
}
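The refactored gemm() passes generic lambdas (`auto& ker`) to tile(), so a single lambda body is instantiated once per kernel type. The sketch below illustrates that mechanism in isolation. `Kernel` and `forEachKernel` are hypothetical stand-ins for RegisterMatrix and tile(), and the sizes 12, 8 and 4 are assumed values chosen for the example.

```cpp
#include <cstddef>

// Hypothetical stand-in for RegisterMatrix: carries only compile-time dimensions.
template <std::size_t M, std::size_t N>
struct Kernel
{
    static constexpr std::size_t rows() { return M; }
    static constexpr std::size_t columns() { return N; }
};

// Hypothetical stand-in for tile(): calls f with kernels of three different types.
template <typename F>
void forEachKernel(F&& f)
{
    Kernel<12, 4> k3;
    f(k3);

    Kernel<8, 4> k2;
    f(k2);

    Kernel<4, 4> k1;
    f(k1);
}

// The generic lambda is instantiated separately for each Kernel<M, N>,
// so ker.rows() is a compile-time constant inside every instantiation.
inline std::size_t totalRows()
{
    std::size_t total = 0;
    forEachKernel([&](auto& ker) { total += ker.rows(); });
    return total;
}
```

Because the kernel type is deduced per call, the caller no longer needs to spell out each RegisterMatrix size, which is what lets gemm() drop its hand-written loops.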


4 changes: 2 additions & 2 deletions include/blast/math/dense/Ger.hpp
@@ -97,7 +97,7 @@ namespace blast
ker.store(B(i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
@@ -108,7 +108,7 @@ namespace blast
}


// Right edge
// Right side
if (j < N)
{
size_t i = 0;
4 changes: 2 additions & 2 deletions include/blast/math/dense/Syrk.hpp
@@ -77,7 +77,7 @@ namespace blast
ker.store(ptr<aligned>(D, i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
@@ -91,7 +91,7 @@ namespace blast
}


// Right edge
// Right side
if (j < M)
{
size_t i = j;
4 changes: 2 additions & 2 deletions include/blast/math/dense/Trmm.hpp
@@ -122,7 +122,7 @@ namespace blast
ker.store(ptr<aligned>(C, i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
@@ -136,7 +136,7 @@ namespace blast
}


// Right edge
// Right side
if (j < N)
{
size_t i = 0;
19 changes: 19 additions & 0 deletions include/blast/system/Inline.hpp
@@ -0,0 +1,19 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <blaze/system/Inline.h>

#define BLAST_ALWAYS_INLINE BLAZE_ALWAYS_INLINE
