Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tile() function part 0.5 #8

Merged
merged 7 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# BLAST
*BLAST* is a high-performance linear algebra library that combines a BLAS-like nterface with modern C++ template metaprogramming.
*BLAST* (**BLAS** **T**emplates) is a high-performance linear algebra library that combines a BLAS-like nterface with modern C++ template metaprogramming.
mkatliar marked this conversation as resolved.
Show resolved Hide resolved
*BLAST* implementation is single-threaded and intended for matrices of small and medium size (a few hundred rows/columns), which is common for embedded control applications.

The name stands for **BLAS** **T**emplates.

The figures below shows the performance of BLAS *dgemm* routine for different LA implementations on an
*Intel(R) Core(TM) i7-9850H CPU @ 2.60GHz*:
![dgemm_performance](doc/dgemm_performance.png)
Expand Down Expand Up @@ -104,4 +102,4 @@ cd blast
docker build . --tag blast_bench .
docker run -v `pwd`/bench_result/docker:/root/blast/bench_result blast_bench
```
The benchmark results will be put in `/bench_result/docker`.
The benchmark results will be put in `/bench_result/docker`.
6 changes: 6 additions & 0 deletions bench/blast/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ add_executable(bench-blast
math/panel/DynamicPotrf.cpp
)

target_compile_definitions(bench-blast
# Use Blaze without linking to a BLAS library.
# Blaze is used to prepare data in some of the benchmarks.
PRIVATE BLAZE_BLAS_MODE=0
)

target_link_libraries(bench-blast
blast
bench-blast-common
Expand Down
124 changes: 124 additions & 0 deletions include/blast/math/algorithm/Tile.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/StorageOrder.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>

#include <cstdlib>


namespace blast
{
/**
* @brief Cover a matrix with tiles of different sizes in a performance-efficient way.
*
* The tile sizes and positions are chosen based on matrix element type, storage order, and current system architecture.
* Positions and sizes of the tiles are such that the entire matrix is coveren and tiles do not overlap.
mkatliar marked this conversation as resolved.
Show resolved Hide resolved
*
* This function is helpful in implementing register-blocked matrix algorithms.
*
* For each tile one of the two specified functors @a f_full, @a f_partial is called:
*
* 1) @a f_full (ker, i, j); // if tile size equals ker.columns() by ker.rows()
* 2) @a f_partial (ker, i, j, km, kn); // if tile size is smaller than ker.columns() by ker.rows()
*
* where ker is a RegisterMatrix object, (i, j) are indices of top left corner of the tile,
* and (km, kn) are dimensions of the tile.
*
* @tparam ET type of matrix elements
* @tparam SO matrix storage order
* @tparam F functor type
*
* @param m number of matrix rows
* @param n number of matrix columns
* @param f_full functor to call on full tiles
* @param f_partial functor to call on partial tiles
*/
template <typename ET, StorageOrder SO, typename FF, typename FP>
BLAST_ALWAYS_INLINE void tile(std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
{
size_t constexpr TILE_SIZE = TileSize_v<ET>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if we have matrix of integers, booleans or complex numbers?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I think it will not compile, because we don't have TileSize_v<ET> defined for those types.


size_t j = 0;

// Main part
for (; j + TILE_SIZE <= n; j += TILE_SIZE)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit magic for me, can't really understand why is it more efficient :/

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where magic 3 and 4 come from?

Copy link
Owner Author

@mkatliar mkatliar Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic nombers come from a specific case when you have 16 AVX registers, each storing 4 doubles, and TILE_SIZE == 4. When you have 16 rows left, it is more efficient (based on performance test) to apply a 8-row kernel 2 times than a 12-row kernel and then a 4-row kernel.

This code is tied to a specific architecture and is actually very old. It should be re-written in more general way.

// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why columnMajor here and not SO?

f_full(ker, i, j);
}

for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_full(ker, i, j);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Мне тут пояснительная бригада нужна, зачем мы применяем функтор к пустой матрице? А, я понял, это чтобы этими черипичками покрыть в случае четного и нечетного количества строк? Или нет?

}

for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_full(ker, i, j);
}

// Bottom side
if (i < m)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, m - i, ker.columns());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TILE_SIZE != ker.columns() ?

}
}


// Right side
if (j < n)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, ker.rows(), n - j);
}

// Bottom-right corner
if (i < m)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
f_partial(ker, i, j, m - i, n - j);
}
}
}
}
87 changes: 16 additions & 71 deletions include/blast/math/dense/Gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@

#pragma once

#include <blaze/math/StorageOrder.h>
#include <blaze/math/typetraits/StorageOrder.h>
#include <blaze/math/views/Forward.h>
#include <blast/math/dense/GemmBackend.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/typetraits/StorageOrder.hpp>
#include <blast/math/typetraits/MatrixPointer.hpp>
#include <blast/math/algorithm/Tile.hpp>
#include <blast/math/dense/MatrixPointer.hpp>

#include <algorithm>
#include <blaze/util/constraints/SameType.h>

#include <cstddef>
#include <type_traits>


namespace blast
{

/**
* @brief Performs the matrix-matrix operation
*
Expand Down Expand Up @@ -51,83 +54,25 @@ namespace blast
MatrixPointer<MPC> && StorageOrder_v<MPC> == columnMajor &&
MatrixPointer<MPD> && StorageOrder_v<MPD> == columnMajor
)
BLAZE_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
BLAST_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question, do we need to specify M, N, K? Are not they part of MPA, MPB, MPC?

{
using ET = std::remove_cv_t<ElementType_t<MPA>>;
using ET = std::remove_cv_t<ElementType_t<MPD>>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make a difference?

size_t constexpr TILE_SIZE = TileSize_v<ET>;

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPB>>, ET);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to put in require above?

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPC>>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET);

size_t j = 0;

// Main part
for (; j + TILE_SIZE <= N; j += TILE_SIZE)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
tile<ET, StorageOrder(StorageOrder_v<MPD>)>(M, N,
[&] (auto& ker, size_t i, size_t j)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j));
}

// Bottom edge
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, ker.columns());
}
}


// Right edge
if (j < N)
{
size_t i = 0;

// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE:
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel.
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j);
}

// Bottom-right corner
if (i < M)
},
[&] (auto& ker, size_t i, size_t j, size_t m, size_t n)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, N - j);
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), m, n);
}
}
);
}


Expand Down
4 changes: 2 additions & 2 deletions include/blast/math/dense/Ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ namespace blast
ker.store(B(i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
Expand All @@ -108,7 +108,7 @@ namespace blast
}


// Right edge
// Right side
if (j < N)
{
size_t i = 0;
Expand Down
4 changes: 2 additions & 2 deletions include/blast/math/dense/Syrk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace blast
ker.store(ptr<aligned>(D, i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
Expand All @@ -91,7 +91,7 @@ namespace blast
}


// Right edge
// Right side
if (j < M)
{
size_t i = j;
Expand Down
4 changes: 2 additions & 2 deletions include/blast/math/dense/Trmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ namespace blast
ker.store(ptr<aligned>(C, i, j));
}

// Bottom edge
// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
Expand All @@ -136,7 +136,7 @@ namespace blast
}


// Right edge
// Right side
if (j < N)
{
size_t i = 0;
Expand Down
19 changes: 19 additions & 0 deletions include/blast/system/Inline.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <blaze/system/Inline.h>

#define BLAST_ALWAYS_INLINE BLAZE_ALWAYS_INLINE
Loading