Skip to content

Commit

Permalink
Fix DenseTrmmTest for ARM
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Katliar authored and mkatliar committed Sep 13, 2024
1 parent 06f56c1 commit fc2375c
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 45 deletions.
74 changes: 33 additions & 41 deletions include/blast/math/dense/Trmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,65 @@

#pragma once

#include <blast/math/Matrix.hpp>
#include <blast/math/dense/TrmmBackend.hpp>
#include <blast/system/Tile.hpp>

#include <blaze/util/Exception.h>
#include <blaze/util/constraints/SameType.h>
#include <blaze/math/DenseMatrix.h>

#include <algorithm>


namespace blast
{
/// @brief C = alpha * A * B + C; A upper-triangular
///
/// NOTE(review): this text is a rendered diff hunk, so the removed and the
/// added version of a changed line appear one after the other (for example the
/// two `template <...>` headers and the two `using ET = ...` lines below).
///
/// The row count M and the column count N are taken from B. A must be M x M
/// and C must be M x N; otherwise an invalid-argument error is raised.
/// Triangularity of A is not verified in this function; only its size is.
///
/// @param alpha scalar factor of the formula above
/// @param A left-hand matrix of the product; must be M x M
/// @param B right-hand matrix of the product; defines M and N
/// @param C matrix C of the formula above; the only non-const matrix argument
///
/// @throws an invalid-argument error with the message
///   "Matrix sizes do not match" (`std::invalid_argument` in the added lines,
///   `BLAZE_THROW_INVALID_ARGUMENT` in the removed lines) when the sizes of
///   A or C do not agree with B.
///
template <typename ST, typename MT1, typename MT2, bool SO2, typename MT3>
template <typename ST, typename MT1, typename MT2, typename MT3>
requires Matrix<MT1, ST> && Matrix<MT2, ST> && Matrix<MT3, ST>
&& (StorageOrder_v<MT1> == columnMajor) && (StorageOrder_v<MT3> == columnMajor)
inline void trmmLeftUpper(
ST alpha,
DenseMatrix<MT1, columnMajor> const& A, DenseMatrix<MT2, SO2> const& B,
DenseMatrix<MT3, columnMajor>& C)
MT1 const& A, MT2 const& B,
MT3& C)
{
using ET = ElementType_t<MT1>;
using ET = ST;
// Tile size associated with the element type; all panel heights below are multiples of it.
size_t constexpr TILE_SIZE = TileSize_v<ET>;

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t<MT2>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t<MT3>, ET);

// Problem dimensions are defined by B.
size_t const M = rows(B);
size_t const N = columns(B);

// A must be square with as many rows as B.
if (rows(A) != M || columns(A) != M)
BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match");
throw std::invalid_argument {"Matrix sizes do not match"};

// C must have the same shape as B.
if (rows(C) != M || columns(C) != N)
BLAZE_THROW_INVALID_ARGUMENT("Matrix sizes do not match");
throw std::invalid_argument {"Matrix sizes do not match"};

// First row of the row panel that is processed next.
size_t i = 0;

// Each backend call below receives the number of remaining rows (M - i), N, alpha,
// and pointers to element (i, i) of A and to element (i, 0) of B and of C.
// NOTE(review): M - i can be smaller than the panel height of the selected kernel;
// presumably the backend handles such a partial last panel - confirm in TrmmBackend.hpp.

// Panels of 3 * TILE_SIZE rows, used while more than 2 * TILE_SIZE rows remain.
// The extra condition i + 4 * TILE_SIZE != M improves performance when exactly
// 4 * TILE_SIZE rows remain: applying the 2 * TILE_SIZE kernel twice is more
// efficient than applying the 3 * TILE_SIZE kernel followed by the 1 * TILE_SIZE kernel.
for (; i + 2 * TILE_SIZE < M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
trmmLeftUpper_backend<3 * TILE_SIZE, TILE_SIZE>(
M - i, N, alpha, ptr<aligned>(*A, i, i), ptr<aligned>(*B, i, 0), ptr<aligned>(*C, i, 0));
M - i, N, alpha, ptr<aligned>(A, i, i), ptr<aligned>(B, i, 0), ptr<aligned>(C, i, 0));

// Panels of 2 * TILE_SIZE rows, used while more than 1 * TILE_SIZE rows remain.
for (; i + 1 * TILE_SIZE < M; i += 2 * TILE_SIZE)
trmmLeftUpper_backend<2 * TILE_SIZE, TILE_SIZE>(
M - i, N, alpha, ptr<aligned>(*A, i, i), ptr<aligned>(*B, i, 0), ptr<aligned>(*C, i, 0));
M - i, N, alpha, ptr<aligned>(A, i, i), ptr<aligned>(B, i, 0), ptr<aligned>(C, i, 0));

// Panels of 1 * TILE_SIZE rows for whatever rows are still left.
for (; i + 0 * TILE_SIZE < M; i += 1 * TILE_SIZE)
trmmLeftUpper_backend<1 * TILE_SIZE, TILE_SIZE>(
M - i, N, alpha, ptr<aligned>(*A, i, i), ptr<aligned>(*B, i, 0), ptr<aligned>(*C, i, 0));
M - i, N, alpha, ptr<aligned>(A, i, i), ptr<aligned>(B, i, 0), ptr<aligned>(C, i, 0));
}


/// @brief C = alpha * B * A + C; A lower-triangular
///
template <typename ET, typename MTB, typename MTA, bool SOA, typename MTC>
template <typename ET, typename MTB, typename MTA, typename MTC>
requires Matrix<MTB, ET> && Matrix<MTA, ET> && Matrix<MTC, ET>
&& (StorageOrder_v<MTB> == columnMajor) && (StorageOrder_v<MTC> == columnMajor)
inline void trmmRightLower(
ET alpha,
DenseMatrix<MTB, columnMajor> const& B, DenseMatrix<MTA, SOA> const& A,
DenseMatrix<MTC, columnMajor>& C)
MTB const& B, MTA const& A,
MTC& C)
{
size_t constexpr TILE_SIZE = TileSize_v<ET>;

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t<MTB>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t<MTA>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(ElementType_t<MTC>, ET);

size_t const M = rows(B);
size_t const N = columns(B);

Expand All @@ -93,46 +85,46 @@ namespace blast
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j));
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
/*
ker.trmmRightLower(alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
ker.gemm(K, alpha, ptr<aligned>(B, i, j + TILE_SIZE), ptr<aligned>(A, j + TILE_SIZE, j));
*/
ker.store(ptr<aligned>(*C, i, j));
ker.store(ptr<aligned>(C, i, j));
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j));
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
/*
ker.trmmRightLower(alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
ker.gemm(K, alpha, ptr<aligned>(B, i, j + TILE_SIZE), ptr<aligned>(A, j + TILE_SIZE, j));
*/
ker.store(ptr<aligned>(*C, i, j));
ker.store(ptr<aligned>(C, i, j));
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j));
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
/*
ker.trmmRightLower(alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
ker.gemm(K, alpha, ptr<aligned>(B, i, j + TILE_SIZE), ptr<aligned>(A, j + TILE_SIZE, j));
*/
ker.store(ptr<aligned>(*C, i, j));
ker.store(ptr<aligned>(C, i, j));
}

// Bottom side
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j), M - i, ker.columns());
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j), M - i, ker.columns());
/*
ker.trmmRightLower(alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j));
ker.gemm(K, alpha, ptr<aligned>(B, i, j + TILE_SIZE), ptr<aligned>(A, j + TILE_SIZE, j), M - i, ker.columns());
*/
ker.store(ptr<aligned>(*C, i, j), M - i, ker.columns());
ker.store(ptr<aligned>(C, i, j), M - i, ker.columns());
}
}

Expand All @@ -147,30 +139,30 @@ namespace blast
for (; i + 3 * TILE_SIZE <= M && i + 4 * TILE_SIZE != M; i += 3 * TILE_SIZE)
{
RegisterMatrix<ET, 3 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(*C, i, j), ker.rows(), N - j);
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(C, i, j), ker.rows(), N - j);
}

for (; i + 2 * TILE_SIZE <= M; i += 2 * TILE_SIZE)
{
RegisterMatrix<ET, 2 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(*C, i, j), ker.rows(), N - j);
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(C, i, j), ker.rows(), N - j);
}

for (; i + 1 * TILE_SIZE <= M; i += 1 * TILE_SIZE)
{
RegisterMatrix<ET, 1 * TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(*C, i, j), ker.rows(), N - j);
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j), ker.rows(), N - j);
ker.store(ptr<aligned>(C, i, j), ker.rows(), N - j);
}

// Bottom-right corner
if (i < M)
{
RegisterMatrix<ET, TILE_SIZE, TILE_SIZE, columnMajor> ker;
gemm(ker, N - j, alpha, ptr<aligned>(*B, i, j), ptr<aligned>(*A, j, j), M - i, N - j);
ker.store(ptr<aligned>(*C, i, j), M - i, N - j);
gemm(ker, N - j, alpha, ptr<aligned>(B, i, j), ptr<aligned>(A, j, j), M - i, N - j);
ker.store(ptr<aligned>(C, i, j), M - i, N - j);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion include/blast/math/dense/TrmmBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace blast
BLAZE_ALWAYS_INLINE void trmmLeftUpper_backend(size_t M, size_t N, T alpha, P1 a, P2 b, P3 c)
{
size_t constexpr TILE_SIZE = TileSize_v<T>;
BLAZE_STATIC_ASSERT(KM % TILE_SIZE == 0);
static_assert(KM % TILE_SIZE == 0);

RegisterMatrix<T, KM, KN, columnMajor> ker;

Expand Down
Loading

0 comments on commit fc2375c

Please sign in to comment.