Skip to content

Commit

Permalink
gemm() implementation independent of matrix types
Browse files Browse the repository at this point in the history
  • Loading branch information
mkatliar committed Aug 13, 2024
1 parent ce09527 commit 093e9f9
Show file tree
Hide file tree
Showing 20 changed files with 96 additions and 242 deletions.
2 changes: 1 addition & 1 deletion bench/blast/math/dense/DynamicGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/math/dense/Gemm.hpp>
#include <blast/math/algorithm/Gemm.hpp>
#include <blast/math/Matrix.hpp>
#include <blast/blaze/Math.hpp>

Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/dense/StaticGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/math/dense/Gemm.hpp>
#include <blast/math/algorithm/Gemm.hpp>
#include <blast/math/Matrix.hpp>
#include <blast/blaze/Math.hpp>

Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/panel/DynamicGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.

#include <blast/math/DynamicPanelMatrix.hpp>
#include <blast/math/panel/Gemm.hpp>
#include <blast/math/algorithm/Gemm.hpp>

#include <bench/Gemm.hpp>

Expand Down
2 changes: 1 addition & 1 deletion bench/blast/math/panel/StaticGemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#include <blast/math/StaticPanelMatrix.hpp>
#include <blast/math/Matrix.hpp>
#include <blast/math/panel/Gemm.hpp>
#include <blast/math/algorithm/Gemm.hpp>
#include <blast/blaze/Math.hpp>

#include <bench/Gemm.hpp>
Expand Down
1 change: 1 addition & 0 deletions include/blast/math/StaticPanelMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <blast/math/panel/StaticPanelMatrix.hpp>
#include <blast/math/expressions/PMatTransExpr.hpp>

#include <blaze/util/Random.h>
#include <blaze/util/constraints/Numeric.h>
Expand Down
94 changes: 68 additions & 26 deletions include/blast/math/algorithm/Gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@

#pragma once

#include <blast/system/Inline.hpp>
#include <blast/math/typetraits/MatrixPointer.hpp>
#include <blast/math/TypeTraits.hpp>
#include <blast/math/algorithm/Tile.hpp>
#include <blast/math/register_matrix/Gemm.hpp>

#include <blaze/util/constraints/SameType.h>

#include <cstddef>
#include <type_traits>
#include <blast/math/Matrix.hpp>
#include <blast/util/Exception.hpp>


namespace blast
Expand All @@ -26,12 +22,12 @@ namespace blast
* alpha and beta are scalars, and A, B and C are matrices, with A
* an m by k matrix, B a k by n matrix and C an m by n matrix.
*
* @tparam ST1
* @tparam MPA
* @tparam MPB
* @tparam ST2
* @tparam MPC
* @tparam MPD
* @tparam ST1 scalar type for @a alpha
* @tparam MPA matrix pointer type for @a A
* @tparam MPB matrix pointer type for @a B
* @tparam ST2 scalar type for @a beta
* @tparam MPC matrix pointer type for @a C
* @tparam MPD matrix pointer type for @a D
*
* @param M the number of rows of the matrices A, C, and D.
* @param N the number of columns of the matrices B and C.
Expand All @@ -44,23 +40,13 @@ namespace blast
* @param D the output matrix D
*/
template <
typename ST1, typename MPA, typename MPB,
typename ST2, typename MPC, typename MPD
typename ST1, MatrixPointer MPA, MatrixPointer MPB,
typename ST2, MatrixPointer MPC, MatrixPointer MPD
>
requires (
MatrixPointer<MPA> && StorageOrder_v<MPA> == columnMajor &&
MatrixPointer<MPB> &&
MatrixPointer<MPC> && StorageOrder_v<MPC> == columnMajor &&
MatrixPointer<MPD> && StorageOrder_v<MPD> == columnMajor
)
BLAST_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
inline void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D)
{
using ET = std::remove_cv_t<ElementType_t<MPD>>;

BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPB>>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPC>>, ET);
BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t<ElementType_t<MPD>>, ET);

tile<ET, StorageOrder(StorageOrder_v<MPD>)>(
xsimd::default_arch {},
D.cachePreferredTraversal,
Expand All @@ -75,4 +61,60 @@ namespace blast
}
);
}


/**
* @brief Matrix-matrix multiplication for @a DenseMatrix arguments
*
* D := alpha*A*B + beta*C
*
* alpha and beta are scalars, and A, B and C are matrices, with A
* an m by k matrix, B a k by n matrix and C an m by n matrix.
*
* @param alpha the scalar alpha
* @param A the matrix A
* @param B the matrix B
* @param beta the scalar beta
* @param C the matrix C
* @param D the output matrix D
*/
template <typename ST1, Matrix MT1, Matrix MT2, typename ST2, Matrix MT3, Matrix MT4>
inline void gemm(ST1 alpha, MT1 const& A, MT2 const& B, ST2 beta, MT3 const& C, MT4& D)
{
size_t const M = rows(A);
size_t const N = columns(B);
size_t const K = columns(A);

if (rows(B) != K)
BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"});

if (rows(C) != M || columns(C) != N)
BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"});

if (rows(D) != M || columns(D) != N)
BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"});

gemm(M, N, K, alpha, ptr(*A), ptr(*B), beta, ptr(*C), ptr(*D));
}


/**
* @brief Matrix-matrix multiplication for @a DenseMatrix arguments
*
* D := A*B + C
*
* A, B and C are matrices, with A
* an m by k matrix, B a k by n matrix and C an m by n matrix.
*
* @param A the matrix A
* @param B the matrix B
* @param C the matrix C
* @param D the output matrix D
*/
template <Matrix MT1, Matrix MT2, Matrix MT3, Matrix MT4>
inline void gemm(MT1 const& A, MT2 const& B, MT3 const& C, MT4& D)
{
using ET = ElementType_t<MT4>;
gemm(ET(1.), A, B, ET(1.), C, D);
}
}
19 changes: 4 additions & 15 deletions include/blast/math/algorithm/Tile.hpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright 2024 Mikhail Katliar. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/math/Simd.hpp>

Expand All @@ -19,7 +9,6 @@
#endif

#include <blast/math/StorageOrder.hpp>
#include <blast/system/Inline.hpp>

#include <cstdlib>

Expand Down Expand Up @@ -54,7 +43,7 @@ namespace blast
* @param f_partial functor to call on partial tiles
*/
template <typename ET, StorageOrder SO, typename FF, typename FP, typename Arch>
BLAST_ALWAYS_INLINE void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
inline void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial)
{
detail::tile<ET, SO>(arch, traversal_order, m, n, f_full, f_partial);
}
Expand Down
16 changes: 3 additions & 13 deletions include/blast/math/algorithm/arch/avx2/Tile.hpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
// Copyright 2023 Mikhail Katliar
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright 2024 Mikhail Katliar. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include <blast/system/Tile.hpp>
#include <blast/system/Inline.hpp>
Expand Down
79 changes: 0 additions & 79 deletions include/blast/math/dense/Gemm.hpp

This file was deleted.

2 changes: 1 addition & 1 deletion include/blast/math/dense/Getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <blast/math/RegisterMatrix.hpp>
#include <blast/math/dense/Getf2.hpp>
#include <blast/math/dense/Trsm.hpp>
#include <blast/math/dense/Gemm.hpp>
#include <blast/math/algorithm/Gemm.hpp>
#include <blast/system/Tile.hpp>

#include <blaze/util/Exception.h>
Expand Down
80 changes: 0 additions & 80 deletions include/blast/math/panel/Gemm.hpp

This file was deleted.

Loading

0 comments on commit 093e9f9

Please sign in to comment.