diff --git a/bench/blast/math/dense/DynamicGemm.cpp b/bench/blast/math/dense/DynamicGemm.cpp index 9d702760..10d8fb32 100644 --- a/bench/blast/math/dense/DynamicGemm.cpp +++ b/bench/blast/math/dense/DynamicGemm.cpp @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include +#include #include #include diff --git a/bench/blast/math/dense/StaticGemm.cpp b/bench/blast/math/dense/StaticGemm.cpp index 960c8ae9..2565d36a 100644 --- a/bench/blast/math/dense/StaticGemm.cpp +++ b/bench/blast/math/dense/StaticGemm.cpp @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -#include +#include #include #include diff --git a/bench/blast/math/panel/DynamicGemm.cpp b/bench/blast/math/panel/DynamicGemm.cpp index 96ac4d62..3b8866d0 100644 --- a/bench/blast/math/panel/DynamicGemm.cpp +++ b/bench/blast/math/panel/DynamicGemm.cpp @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. #include -#include +#include #include diff --git a/bench/blast/math/panel/StaticGemm.cpp b/bench/blast/math/panel/StaticGemm.cpp index c515c35e..0c201346 100644 --- a/bench/blast/math/panel/StaticGemm.cpp +++ b/bench/blast/math/panel/StaticGemm.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include diff --git a/include/blast/math/StaticPanelMatrix.hpp b/include/blast/math/StaticPanelMatrix.hpp index b191d789..cfbcde3f 100644 --- a/include/blast/math/StaticPanelMatrix.hpp +++ b/include/blast/math/StaticPanelMatrix.hpp @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include diff --git a/include/blast/math/algorithm/Gemm.hpp b/include/blast/math/algorithm/Gemm.hpp index 36b14939..ae80d0d1 100644 --- a/include/blast/math/algorithm/Gemm.hpp +++ b/include/blast/math/algorithm/Gemm.hpp @@ -4,15 +4,11 @@ #pragma once -#include -#include +#include #include #include - -#include - -#include -#include +#include +#include namespace blast @@ -26,12 +22,12 @@ namespace blast * alpha and beta are scalars, and A, B and C are matrices, with A * an m by k matrix, B a k by n matrix and C an m by n matrix. * - * @tparam ST1 - * @tparam MPA - * @tparam MPB - * @tparam ST2 - * @tparam MPC - * @tparam MPD + * @tparam ST1 scalar type for @a alpha + * @tparam MPA matrix pointer type for @a A + * @tparam MPB matrix pointer type for @a B + * @tparam ST2 scalar type for @a beta + * @tparam MPC matrix pointer type for @a C + * @tparam MPD matrix pointer type for @a D * * @param M the number of rows of the matrices A, C, and D. * @param N the number of columns of the matrices B and C. @@ -44,23 +40,13 @@ namespace blast * @param D the output matrix D */ template < - typename ST1, typename MPA, typename MPB, - typename ST2, typename MPC, typename MPD + typename ST1, MatrixPointer MPA, MatrixPointer MPB, + typename ST2, MatrixPointer MPC, MatrixPointer MPD > - requires ( - MatrixPointer && StorageOrder_v == columnMajor && - MatrixPointer && - MatrixPointer && StorageOrder_v == columnMajor && - MatrixPointer && StorageOrder_v == columnMajor - ) - BLAST_ALWAYS_INLINE void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D) + inline void gemm(size_t M, size_t N, size_t K, ST1 alpha, MPA A, MPB B, ST2 beta, MPC C, MPD D) { using ET = std::remove_cv_t>; - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t>, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t>, ET); - BLAZE_CONSTRAINT_MUST_BE_SAME_TYPE(std::remove_cv_t>, ET); - tile)>( xsimd::default_arch {}, D.cachePreferredTraversal, @@ -75,4 +61,60 @@ namespace blast } ); } + + + /** + * @brief Matrix-matrix multiplication for @a DenseMatrix arguments + * + * D := alpha*A*B + beta*C + * + * alpha and beta are scalars, and A, B and C are matrices, with A + * an m by k matrix, B a k by n matrix and C an m by n matrix. + * + * @param alpha the scalar alpha + * @param A the matrix A + * @param B the matrix B + * @param beta the scalar beta + * @param C the matrix C + * @param D the output matrix D + */ + template + inline void gemm(ST1 alpha, MT1 const& A, MT2 const& B, ST2 beta, MT3 const& C, MT4& D) + { + size_t const M = rows(A); + size_t const N = columns(B); + size_t const K = columns(A); + + if (rows(B) != K) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); + + if (rows(C) != M || columns(C) != N) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); + + if (rows(D) != M || columns(D) != N) + BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); + + gemm(M, N, K, alpha, ptr(*A), ptr(*B), beta, ptr(*C), ptr(*D)); + } + + + /** + * @brief Matrix-matrix multiplication for @a DenseMatrix arguments + * + * D := A*B + C + * + * A, B and C are matrices, with A + * an m by k matrix, B a k by n matrix and C an m by n matrix. + * + * @param A the matrix A + * @param B the matrix B + * @param C the matrix C + * @param D the output matrix D + */ + template + inline void gemm(MT1 const& A, MT2 const& B, MT3 const& C, MT4& D) + { + using ET = ElementType_t; + gemm(ET(1.), A, B, ET(1.), C, D); + } } diff --git a/include/blast/math/algorithm/Tile.hpp b/include/blast/math/algorithm/Tile.hpp index 280267fa..7f6ecc63 100644 --- a/include/blast/math/algorithm/Tile.hpp +++ b/include/blast/math/algorithm/Tile.hpp @@ -1,16 +1,6 @@ -// Copyright 2023 Mikhail Katliar -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Copyright 2024 Mikhail Katliar. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. #include @@ -19,7 +9,6 @@ #endif #include -#include #include @@ -54,7 +43,7 @@ namespace blast * @param f_partial functor to call on partial tiles */ template - BLAST_ALWAYS_INLINE void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) + inline void tile(Arch arch, StorageOrder traversal_order, std::size_t m, std::size_t n, FF&& f_full, FP&& f_partial) { detail::tile(arch, traversal_order, m, n, f_full, f_partial); } diff --git a/include/blast/math/algorithm/arch/avx2/Tile.hpp b/include/blast/math/algorithm/arch/avx2/Tile.hpp index 01974521..aac90b8a 100644 --- a/include/blast/math/algorithm/arch/avx2/Tile.hpp +++ b/include/blast/math/algorithm/arch/avx2/Tile.hpp @@ -1,16 +1,6 @@ -// Copyright 2023 Mikhail Katliar -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Copyright 2024 Mikhail Katliar. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. #include #include diff --git a/include/blast/math/dense/Gemm.hpp b/include/blast/math/dense/Gemm.hpp deleted file mode 100644 index 7c6cc40a..00000000 --- a/include/blast/math/dense/Gemm.hpp +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) 2019-2020 Mikhail Katliar All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#pragma once - -#include -#include -#include - - -namespace blast -{ - /** - * @brief Matrix-matrix multiplication for @a DenseMatrix arguments - * - * D := alpha*A*B + beta*C - * - * alpha and beta are scalars, and A, B and C are matrices, with A - * an m by k matrix, B a k by n matrix and C an m by n matrix. - * - * @param alpha the scalar alpha - * @param A the matrix A - * @param B the matrix B - * @param beta the scalar beta - * @param C the matrix C - * @param D the output matrix D - */ - template < - typename ST1, typename MT1, typename MT2, bool SO2, - typename ST2, typename MT3, typename MT4 - > - inline void gemm( - ST1 alpha, - DenseMatrix const& A, DenseMatrix const& B, - ST2 beta, DenseMatrix const& C, DenseMatrix& D) - { - size_t const M = rows(A); - size_t const N = columns(B); - size_t const K = columns(A); - - if (rows(B) != K) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - if (rows(C) != M || columns(C) != N) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - if (rows(D) != M || columns(D) != N) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - gemm(M, N, K, alpha, ptr(*A), ptr(*B), beta, ptr(*C), ptr(*D)); - } - - - /** - * @brief Matrix-matrix multiplication for @a DenseMatrix arguments - * - * D := A*B + C - * - * A, B and C are matrices, with A - * an m by k matrix, B a k by n matrix and C an m by n matrix. - * - * @param A the matrix A - * @param B the matrix B - * @param C the matrix C - * @param D the output matrix D - */ - template < - typename MT1, typename MT2, bool SO2, - typename MT3, typename MT4 - > - inline void gemm( - DenseMatrix const& A, DenseMatrix const& B, - DenseMatrix const& C, DenseMatrix& D) - { - using ET = ElementType_t; - gemm(ET(1.), A, B, ET(1.), C, D); - } -} diff --git a/include/blast/math/dense/Getrf.hpp b/include/blast/math/dense/Getrf.hpp index 543a3ae5..7b14d59a 100644 --- a/include/blast/math/dense/Getrf.hpp +++ b/include/blast/math/dense/Getrf.hpp @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/blast/math/panel/Gemm.hpp b/include/blast/math/panel/Gemm.hpp deleted file mode 100644 index 531c2801..00000000 --- a/include/blast/math/panel/Gemm.hpp +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2019-2020 Mikhail Katliar All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -#pragma once - -#include -#include -#include -#include - - -namespace blast -{ - /** - * @brief Matrix-matrix multiplication for @a PanelMatrix arguments - * - * D := alpha*A*B + beta*C - * - * alpha and beta are scalars, and A, B and C are matrices, with A - * an m by k matrix, B a k by n matrix and C an m by n matrix. - * - * @param alpha the scalar alpha - * @param A the matrix A - * @param B the matrix B - * @param beta the scalar beta - * @param C the matrix C - * @param D the output matrix D - */ - template < - typename ST1, typename MT1, typename MT2, bool SO2, - typename ST2, typename MT3, typename MT4 - > - inline void gemm( - ST1 alpha, - PanelMatrix const& A, PanelMatrix const& B, - ST2 beta, PanelMatrix const& C, PanelMatrix& D) - { - size_t const M = rows(A); - size_t const N = columns(B); - size_t const K = columns(A); - - if (rows(B) != K) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - if (rows(C) != M || columns(C) != N) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - if (rows(D) != M || columns(D) != N) - BLAST_THROW_EXCEPTION(std::invalid_argument {"Matrix sizes do not match"}); - - gemm(M, N, K, alpha, ptr(*A), ptr(*B), beta, ptr(*C), ptr(*D)); - } - - - /** - * @brief Matrix-matrix multiplication for @a PanelMatrix arguments - * - * D := A*B + C - * - * A, B and C are matrices, with A - * an m by k matrix, B a k by n matrix and C an m by n matrix. - * - * @param A the matrix A - * @param B the matrix B - * @param C the matrix C - * @param D the output matrix D - */ - template < - typename MT1, typename MT2, bool SO2, - typename MT3, typename MT4 - > - inline void gemm( - PanelMatrix const& A, PanelMatrix const& B, - PanelMatrix const& C, PanelMatrix& D) - { - using ET = ElementType_t; - gemm(ET(1.), A, B, ET(1.), C, D); - } -} diff --git a/include/blast/math/panel/Potrf.hpp b/include/blast/math/panel/Potrf.hpp index 879c46f0..e1763809 100644 --- a/include/blast/math/panel/Potrf.hpp +++ b/include/blast/math/panel/Potrf.hpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include @@ -97,4 +97,4 @@ namespace blast potrf_backend<1 * PANEL_SIZE, KN>(k, i, *A, *L); } } -} \ No newline at end of file +} diff --git a/include/blast/math/panel/StaticPanelMatrix.hpp b/include/blast/math/panel/StaticPanelMatrix.hpp index 82358689..0fa0a7b3 100644 --- a/include/blast/math/panel/StaticPanelMatrix.hpp +++ b/include/blast/math/panel/StaticPanelMatrix.hpp @@ -5,8 +5,8 @@ #pragma once #include -#include #include +#include #include #include #include @@ -16,6 +16,7 @@ #include #include #include +#include #include #include diff --git a/include/blast/math/views/submatrix/BaseTemplate.hpp b/include/blast/math/views/submatrix/BaseTemplate.hpp index ef3cc409..a19b7cfa 100644 --- a/include/blast/math/views/submatrix/BaseTemplate.hpp +++ b/include/blast/math/views/submatrix/BaseTemplate.hpp @@ -4,6 +4,8 @@ #pragma once +#include + namespace blast { @@ -13,4 +15,4 @@ namespace blast class PanelSubmatrix { }; -} \ No newline at end of file +} diff --git a/include/blast/system/Inline.hpp b/include/blast/system/Inline.hpp index ff88fb88..68897585 100644 --- a/include/blast/system/Inline.hpp +++ b/include/blast/system/Inline.hpp @@ -1,16 +1,6 @@ -// Copyright 2023 Mikhail Katliar -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +// Copyright 2024 Mikhail Katliar. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. #pragma once diff --git a/test/blast/math/dense/GemmTest.cpp b/test/blast/math/dense/GemmTest.cpp index 82a8b318..d56daaf5 100644 --- a/test/blast/math/dense/GemmTest.cpp +++ b/test/blast/math/dense/GemmTest.cpp @@ -4,7 +4,7 @@ #define BLAST_USER_ASSERTION 1 -#include +#include #include #include diff --git a/test/blast/math/dense/TrmmTest.cpp b/test/blast/math/dense/TrmmTest.cpp index a61e3c4c..04005df3 100644 --- a/test/blast/math/dense/TrmmTest.cpp +++ b/test/blast/math/dense/TrmmTest.cpp @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. #include -#include +#include #include #include diff --git a/test/blast/math/expressions/PMatTransExprTest.cpp b/test/blast/math/expressions/PMatTransExprTest.cpp index d268c250..bd850f35 100644 --- a/test/blast/math/expressions/PMatTransExprTest.cpp +++ b/test/blast/math/expressions/PMatTransExprTest.cpp @@ -4,12 +4,10 @@ #include #include -#include -#include #include #include -#include + namespace blast :: testing { diff --git a/test/blast/math/panel/GemmTest.cpp b/test/blast/math/panel/GemmTest.cpp index 2accdfcb..71e88bd1 100644 --- a/test/blast/math/panel/GemmTest.cpp +++ b/test/blast/math/panel/GemmTest.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include diff --git a/test/blast/math/panel/PotrfTest.cpp b/test/blast/math/panel/PotrfTest.cpp index d0c267cd..46502097 100644 --- a/test/blast/math/panel/PotrfTest.cpp +++ b/test/blast/math/panel/PotrfTest.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include