-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
tile() function part 0.5 #8
Changes from all commits
d75b59c
3116abf
e33d7b4
c788324
6fd9c95
0b79d8b
391e92a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
// Copyright 2023 Mikhail Katliar | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <blast/system/Tile.hpp> | ||
#include <blast/system/Inline.hpp> | ||
#include <blast/math/StorageOrder.hpp> | ||
#include <blast/math/simd/RegisterMatrix.hpp> | ||
|
||
#include <cstdlib> | ||
|
||
|
||
namespace blast | ||
{ | ||
/** | ||
* @brief Cover a matrix with tiles of different sizes in a performance-efficient way. | ||
* | ||
* The tile sizes and positions are chosen based on matrix element type, storage order, and current system architecture. | ||
* Positions and sizes of the tiles are such that the entire matrix is covered and tiles do not overlap. | ||
* | ||
* This function is helpful in implementing register-blocked matrix algorithms. | ||
* | ||
* For each tile one of the two specified functors @a f_full, @a f_partial is called: | ||
* | ||
* 1) @a f_full (ker, i, j); // if tile size equals ker.columns() by ker.rows() | ||
* 2) @a f_partial (ker, i, j, km, kn); // if tile size is smaller than ker.columns() by ker.rows() | ||
* | ||
* where ker is a RegisterMatrix object, (i, j) are indices of top left corner of the tile, | ||
* and (km, kn) are dimensions of the tile. | ||
* | ||
* @tparam ET type of matrix elements | ||
* @tparam SO matrix storage order | ||
* @tparam F functor type | ||
* | ||
* @param m number of matrix rows | ||
* @param n number of matrix columns | ||
* @param f_full functor to call on full tiles | ||
* @param f_partial functor to call on partial tiles | ||
*/ | ||
template <typename ET, StorageOrder SO, typename FF, typename FP> | ||
BLAST_ALWAYS_INLINE void tile(std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) | ||
{ | ||
size_t constexpr TILE_SIZE = TileSize_v<ET>; | ||
|
||
size_t j = 0; | ||
|
||
// Main part | ||
for (; j + TILE_SIZE <= n; j += TILE_SIZE) | ||
{ | ||
size_t i = 0; | ||
|
||
// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit magic for me, can't really understand why is it more efficient :/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where magic 3 and 4 come from? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The magic nombers come from a specific case when you have 16 AVX registers, each storing 4 This code is tied to a specific architecture and is actually very old. It should be re-written in more general way. |
||
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. | ||
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why |
||
f_full(ker, i, j); | ||
} | ||
|
||
for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_full(ker, i, j); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Мне тут пояснительная бригада нужна, зачем мы применяем функтор к пустой матрице? А, я понял, это чтобы этими черипичками покрыть в случае четного и нечетного количества строк? Или нет? |
||
} | ||
|
||
for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_full(ker, i, j); | ||
} | ||
|
||
// Bottom side | ||
if (i < m) | ||
{ | ||
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_partial(ker, i, j, m - i, ker.columns()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
} | ||
|
||
|
||
// Right side | ||
if (j < n) | ||
{ | ||
size_t i = 0; | ||
|
||
// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: | ||
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. | ||
for (; i + 3 * TILE_SIZE <= m && i + 4 * TILE_SIZE != m; i += 3 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_partial(ker, i, j, ker.rows(), n - j); | ||
} | ||
|
||
for (; i + 2 * TILE_SIZE <= m; i += 2 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_partial(ker, i, j, ker.rows(), n - j); | ||
} | ||
|
||
for (; i + 1 * TILE_SIZE <= m; i += 1 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_partial(ker, i, j, ker.rows(), n - j); | ||
} | ||
|
||
// Bottom-right corner | ||
if (i < m) | ||
{ | ||
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
f_partial(ker, i, j, m - i, n - j); | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,18 +4,21 @@ | |
|
||
#pragma once | ||
|
||
#include <blaze/math/StorageOrder.h> | ||
#include <blaze/math/typetraits/StorageOrder.h> | ||
#include <blaze/math/views/Forward.h> | ||
#include <blast/math/dense/GemmBackend.hpp> | ||
#include <blast/system/Inline.hpp> | ||
#include <blast/math/typetraits/StorageOrder.hpp> | ||
#include <blast/math/typetraits/MatrixPointer.hpp> | ||
#include <blast/math/algorithm/Tile.hpp> | ||
#include <blast/math/dense/MatrixPointer.hpp> | ||
|
||
#include <algorithm> | ||
#include <blaze/util/constraints/SameType.h> | ||
|
||
#include <cstddef> | ||
#include <type_traits> | ||
|
||
|
||
namespace blast | ||
{ | ||
|
||
/** | ||
* @brief Performs the matrix-matrix operation | ||
* | ||
|
@@ -51,83 +54,25 @@ namespace blast | |
MatrixPointer<MPC> && StorageOrder_v<MPC> == columnMajor && | ||
MatrixPointer<MPD> && StorageOrder_v<MPD> == columnMajor | ||
) | ||
BLAZE_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D) | ||
BLAST_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question, do we need to specify |
||
{ | ||
using ET = std::remove_cv_t<ElementType_t<MPA>>; | ||
using ET = std::remove_cv_t<ElementType_t<MPD>>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make a difference? |
||
size_t constexpr TILE_SIZE = TileSize_v<ET>; | ||
|
||
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPB>>, ET); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it make sense to put in |
||
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPC>>, ET); | ||
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET); | ||
|
||
size_t j = 0; | ||
|
||
// Main part | ||
for (; j + TILE_SIZE <= N; j += TILE_SIZE) | ||
{ | ||
size_t i = 0; | ||
|
||
// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: | ||
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. | ||
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j)); | ||
} | ||
|
||
for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) | ||
tile<ET, StorageOrder(StorageOrder_v<MPD>)>(M, N, | ||
[&] (auto& ker, size_t i, size_t j) | ||
{ | ||
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j)); | ||
} | ||
|
||
for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j)); | ||
} | ||
|
||
// Bottom edge | ||
if (i < M) | ||
{ | ||
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, ker.columns()); | ||
} | ||
} | ||
|
||
|
||
// Right edge | ||
if (j < N) | ||
{ | ||
size_t i = 0; | ||
|
||
// i + 4 * TILE_SIZE != M is to improve performance in case when the remaining number of rows is 4 * TILE_SIZE: | ||
// it is more efficient to apply 2 * TILE_SIZE kernel 2 times than 3 * TILE_SIZE + 1 * TILE_SIZE kernel. | ||
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j); | ||
} | ||
|
||
for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j); | ||
} | ||
|
||
for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE) | ||
{ | ||
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), ker.rows(), N - j); | ||
} | ||
|
||
// Bottom-right corner | ||
if (i < M) | ||
}, | ||
[&] (auto& ker, size_t i, size_t j, size_t m, size_t n) | ||
{ | ||
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker; | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), M - i, N - j); | ||
gemm(ker, K, alpha, A(i, 0), B(0, j), beta, C(i, j), D(i, j), m, n); | ||
} | ||
} | ||
); | ||
} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// Copyright 2023 Mikhail Katliar | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <blaze/system/Inline.h> | ||
|
||
#define BLAST_ALWAYS_INLINE BLAZE_ALWAYS_INLINE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if we have matrix of integers, booleans or complex numbers?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question. I think it will not compile, because we don't have
TileSize_v<ET>
defined for those types.