From 145f1aa5dfe6c4627aaa14d217954fa5e479f29e Mon Sep 17 00:00:00 2001 From: Mikhail Katliar Date: Thu, 12 Sep 2024 13:57:49 +0200 Subject: [PATCH] Made the code compile and tests pass on avx2 --- include/blast/math/Matrix.hpp | 55 ++++++++++++++------- include/blast/math/panel/Potrf.hpp | 2 +- include/blast/math/reference/Ger.hpp | 21 +++++++- test/blast/math/simd/RegisterMatrixTest.cpp | 5 +- 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/include/blast/math/Matrix.hpp b/include/blast/math/Matrix.hpp index cd01616..0e288f8 100644 --- a/include/blast/math/Matrix.hpp +++ b/include/blast/math/Matrix.hpp @@ -12,7 +12,10 @@ #include #include +#include #include +#include +#include namespace blast @@ -47,32 +50,48 @@ namespace blast { return ptr>(m, 0, 0); } - - + + template inline void randomize(M& m) noexcept { + std::mt19937 rng; + std::uniform_real_distribution> dist; + for (size_t i = 0; i < rows(m); ++i) + for (size_t j = 0; j < columns(m); ++j) + m(i, j) = dist(rng); } - template - inline bool operator==(MA const& a, MB const& b) noexcept - { - return false; - } + // template + // inline bool operator==(MA const& a, MB const& b) noexcept + // { + // size_t const M = columns(a); + // size_t const N = rows(a); + // if (M != columns(b) || N != rows(b)) + // throw std::invalid_argument {"Inconsistent matrix sizes"}; - template - inline std::ostream& operator<<(std::ostream& os, M const& m) - { - for (size_t i = 0; i < rows(m); ++i) - { - for (size_t j = 0; j < columns(m); ++j) - os << m(i, j) << "\t"; - os << std::endl; - } + // for (size_t i = 0; i < M; ++i) + // for (size_t j = 0; j < N; ++j) + // if (a(i, j) != b(i, j)) + // return false; - return os; - } + // return true; + // } + + + // template + // inline std::ostream& operator<<(std::ostream& os, M const& m) + // { + // for (size_t i = 0; i < rows(m); ++i) + // { + // for (size_t j = 0; j < columns(m); ++j) + // os << m(i, j) << "\t"; + // os << std::endl; + // } + + // return os; + // } } 
diff --git a/include/blast/math/panel/Potrf.hpp b/include/blast/math/panel/Potrf.hpp index e2a4196..e176380 100644 --- a/include/blast/math/panel/Potrf.hpp +++ b/include/blast/math/panel/Potrf.hpp @@ -77,7 +77,7 @@ namespace blast if (columns(L) != N) BLAZE_THROW_INVALID_ARGUMENT("Invalid matrix size"); - size_t constexpr KN = PANEL_SIZE; + size_t constexpr KN = 4; size_t k = 0; // This loop unroll gives some performance benefit for N >= 18, diff --git a/include/blast/math/reference/Ger.hpp b/include/blast/math/reference/Ger.hpp index 097856e..04774eb 100644 --- a/include/blast/math/reference/Ger.hpp +++ b/include/blast/math/reference/Ger.hpp @@ -10,12 +10,31 @@ namespace blast :: reference { + /** + * @brief Reference implementation of rank-1 update with multiplier + * + * a(i, j) += alpha * x(i) * y(j) + * for i=0...m-1, j=0...n-1 + * + * @tparam Real real number type + * @tparam VPX vector pointer type for the column vector @a x + * @tparam VPY vector pointer type for the row vector @a y + * @tparam MPA matrix pointer type for the matrix @a a + * + * @param m number of rows in the matrix + * @param n number of columns in the matrix + * @param alpha scalar multiplier + * @param x column vector + * @param y row vector + * @param a matrix to perform update on + * + */ template requires VectorPointer && VectorPointer && MatrixPointer inline void ger(size_t m, size_t n, Real alpha, VPX x, VPY y, MPA a) { for (size_t i = 0; i < m; ++i) for (size_t j = 0; j < n; ++j) - *a(i, j) += alpha * *x(i) * *y(j); + *(~a)(i, j) += alpha * *(~x)(i) * *(~y)(j); } } diff --git a/test/blast/math/simd/RegisterMatrixTest.cpp b/test/blast/math/simd/RegisterMatrixTest.cpp index 606c97d..767c24b 100644 --- a/test/blast/math/simd/RegisterMatrixTest.cpp +++ b/test/blast/math/simd/RegisterMatrixTest.cpp @@ -4,11 +4,10 @@ #include #include -#include -#include -#include +#include #include #include +#include #include #include