Skip to content

Commit

Permalink
Reordering files and getting rid of obsolete gemm() functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mkatliar committed Mar 8, 2024
1 parent e281988 commit 8ecbee4
Show file tree
Hide file tree
Showing 19 changed files with 26 additions and 193 deletions.
2 changes: 1 addition & 1 deletion bench/blast/math/simd/Ger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <bench/Benchmark.hpp>

Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/PartialGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>

#include <bench/Benchmark.hpp>
Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/PartialLoad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>

#include <bench/Benchmark.hpp>
Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/PartialStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>

#include <bench/Benchmark.hpp>
Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/Potrf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <bench/Benchmark.hpp>
#include <bench/Complexity.hpp>
Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/Trmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/dense/StaticMatrixPointer.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <bench/Benchmark.hpp>

Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/simd/Trsm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <bench/Benchmark.hpp>
#include <bench/Trsm.hpp>
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion include/blast/math/algorithm/Tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
#include <blast/math/StorageOrder.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <cstdlib>

Expand Down
87 changes: 0 additions & 87 deletions include/blast/math/dense/GemmBackend.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion include/blast/math/dense/Ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <blast/util/Exception.hpp>

#include <blast/system/Tile.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/typetraits/VectorPointer.hpp>
#include <blast/math/dense/VectorPointer.hpp>
#include <blast/math/dense/MatrixPointer.hpp>
Expand Down
2 changes: 1 addition & 1 deletion include/blast/math/dense/Getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <blast/math/dense/DynamicMatrixPointer.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/Getf2.hpp>
#include <blast/math/dense/Trsm.hpp>
#include <blast/math/dense/Gemm.hpp>
Expand Down
2 changes: 1 addition & 1 deletion include/blast/math/dense/Potrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include <blast/math/dense/MatrixPointer.hpp>
#include <blast/math/RowColumnVectorPointer.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/system/Tile.hpp>

#include <blaze/util/Exception.h>
Expand Down
3 changes: 1 addition & 2 deletions include/blast/math/dense/SyrkBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
#pragma once

#include <blast/system/Tile.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/DynamicMatrixPointer.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>
#include <blast/math/dense/GemmBackend.hpp>

#include <blaze/util/Exception.h>
#include <blaze/util/constraints/SameType.h>
Expand Down
2 changes: 1 addition & 1 deletion include/blast/math/dense/TrmmBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/dense/DynamicMatrixPointer.hpp>
#include <blast/math/dense/StaticMatrixPointer.hpp>
#include <blast/system/Tile.hpp>
Expand Down
2 changes: 1 addition & 1 deletion include/blast/math/expressions/PanelMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <blast/math/typetraits/IsPanelMatrix.hpp>
#include <blast/math/simd/Simd.hpp>
#include <blast/math/panel/PanelSize.hpp>
//#include <blast/math/simd/RegisterMatrix.hpp>
//#include <blast/math/register_matrix/RegisterMatrix.hpp>

#include <blaze/math/ReductionFlag.h>
#include <blaze/math/Matrix.h>
Expand Down
99 changes: 10 additions & 89 deletions include/blast/math/panel/Gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,96 +5,17 @@
#pragma once

#include <blast/math/PanelMatrix.hpp>
#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/register_matrix/Gemm.hpp>
#include <blast/math/panel/PanelSize.hpp>
#include <blast/math/panel/MatrixPointer.hpp>

#include <blaze/math/AlignmentFlag.h>
#include <blaze/util/Exception.h>
#include <blaze/util/constraints/SameType.h>

#include <algorithm>


namespace blast
{
using namespace blaze;


template <
typename T, size_t M, size_t N,
typename MPA, typename MPB, typename MPC, typename MPD
>
requires
MatrixPointer<MPA> &&
MatrixPointer<MPB> &&
MatrixPointer<MPC> &&
MatrixPointer<MPD>
BLAZE_ALWAYS_INLINE void gemm_backend(
RegisterMatrix<T, M, N, columnMajor>& ker,
size_t K, T alpha, T beta,
MPA a, MPB b, MPC c, MPD d)
{
ker.load(beta, c);

for (size_t k = 0; k < K; ++k)
{
ker.ger(alpha, column(a), row(b));

a.hmove(1);
b.vmove(1);
}

ker.store(d);
}


template <
typename T, size_t M, size_t N,
typename MT1, bool SO1, typename MT2, bool SO2,
typename MT3, bool SO3, typename MT4, bool SO4
>
BLAZE_ALWAYS_INLINE void gemm_backend(RegisterMatrix<T, M, N, columnMajor>& ker, size_t K, T alpha, T beta,
Matrix<MT1, SO1> const& A, Matrix<MT2, SO2> const& B, Matrix<MT3, SO3> const& C, Matrix<MT4, SO4>& D)
{
ker.load(beta, *C);

for (size_t k = 0; k < K; ++k)
ker.ger(alpha, column(*A, k), row(*B, k));

ker.store(*D);
}


template <
typename T, size_t M, size_t N,
typename MPA, typename MPB, typename MPC, typename MPD
>
requires
MatrixPointer<MPA> &&
MatrixPointer<MPB> &&
MatrixPointer<MPC> &&
MatrixPointer<MPD>
BLAZE_ALWAYS_INLINE void gemm_backend(
RegisterMatrix<T, M, N, columnMajor>& ker,
size_t K, T alpha, T beta,
MPA a, MPB b, MPC c, MPD d,
size_t md, size_t nd)
{
ker.load(beta, c, md, nd);

for (size_t k = 0; k < K; ++k)
{
ker.ger(alpha, column(a), row(b), md, nd);

a.hmove(1);
b.vmove(1);
}

ker.store(d, md, nd);
}


template <typename MT1, typename MT2, typename MT3, typename MT4>
BLAZE_ALWAYS_INLINE void gemm_nt(
PanelMatrix<MT1, columnMajor> const& A, PanelMatrix<MT2, columnMajor> const& B,
Expand Down Expand Up @@ -190,29 +111,29 @@ namespace blast
auto a = ptr<aligned>(A, i, 0);

for (; j + KN <= N; j += KN)
gemm_backend(ker, K, alpha, beta,
gemm(ker, K, alpha,
a, trans(ptr<unaligned>(B, j, 0)),
ptr<aligned>(C, i, j), ptr<aligned>(D, i, j));
beta, ptr<aligned>(C, i, j), ptr<aligned>(D, i, j));

if (j < N)
gemm_backend(ker, K, alpha, beta,
gemm(ker, K, alpha,
a, trans(ptr<unaligned>(B, j, 0)),
ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), KM, N - j);
beta, ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), KM, N - j);
}
else
{
// Use partial save to calculate the bottom of the resulting matrix.
size_t j = 0;

for (; j + KN <= N; j += KN)
gemm_backend(ker, K, alpha, beta,
gemm(ker, K, alpha,
ptr<aligned>(A, i, 0), trans(ptr<unaligned>(B, j, 0)),
ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), M - i, KN);
beta, ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), M - i, KN);

if (j < N)
gemm_backend(ker, K, alpha, beta,
gemm(ker, K, alpha,
ptr<aligned>(A, i, 0), trans(ptr<unaligned>(B, j, 0)),
ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), M - i, N - j);
beta, ptr<aligned>(C, i, j), ptr<aligned>(D, i, j), M - i, N - j);
}
}
}
2 changes: 1 addition & 1 deletion test/blast/math/simd/DynamicRegisterMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/DynamicRegisterMatrix.hpp>
#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/dense/MatrixPointer.hpp>
#include <blast/math/views/submatrix/Panel.hpp>
Expand Down
2 changes: 1 addition & 1 deletion test/blast/math/simd/RegisterMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/math/simd/RegisterMatrix.hpp>
#include <blast/math/register_matrix/RegisterMatrix.hpp>
#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/panel/MatrixPointer.hpp>
#include <blast/math/dense/MatrixPointer.hpp>
Expand Down

0 comments on commit 8ecbee4

Please sign in to comment.