Skip to content

Commit

Permalink
Made the code compile and tests pass on avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
mkatliar committed Sep 12, 2024
1 parent 9a418ea commit 145f1aa
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 23 deletions.
55 changes: 37 additions & 18 deletions include/blast/math/Matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
#include <blast/util/Types.hpp>
#include <blast/system/Inline.hpp>

#include <blaze/math/Aliases.h>
#include <iostream>
#include <stdexcept>
#include <random>


namespace blast
Expand Down Expand Up @@ -47,32 +50,48 @@ namespace blast
{
return ptr<IsAligned_v<MT>>(m, 0, 0);
}


/// @brief Overwrite every element of a matrix with a pseudo-random value.
///
/// Elements are assigned in row-major order from a std::uniform_real_distribution
/// constructed with its default parameters. The std::mt19937 engine is
/// default-constructed on every call, so the generated sequence is the same
/// each time the function is invoked.
///
/// @tparam M matrix type
///
/// @param m matrix whose elements are overwritten
template <Matrix M>
inline void randomize(M& m) noexcept
{
    using Element = ElementType_t<M>;

    std::mt19937 engine;
    std::uniform_real_distribution<Element> distribution;

    size_t const row_count = rows(m);
    size_t const column_count = columns(m);

    for (size_t row = 0; row < row_count; ++row)
    {
        for (size_t col = 0; col < column_count; ++col)
        {
            m(row, col) = distribution(engine);
        }
    }
}


/// @brief Element-wise equality comparison of two matrices.
///
/// Two matrices compare equal when they have the same number of rows and
/// columns and every pair of corresponding elements compares equal.
/// Matrices of different dimensions compare unequal; nothing is thrown,
/// in keeping with the noexcept specification.
///
/// @tparam MA type of the left-hand side matrix
/// @tparam MB type of the right-hand side matrix
///
/// @param a left-hand side matrix
/// @param b right-hand side matrix
///
/// @return true if the dimensions and all elements match, false otherwise
template <Matrix MA, Matrix MB>
inline bool operator==(MA const& a, MB const& b) noexcept
{
    size_t const M = rows(a);
    size_t const N = columns(a);

    // A size mismatch means "not equal"; throwing here would call std::terminate
    // because the operator is declared noexcept.
    if (M != rows(b) || N != columns(b))
        return false;

    for (size_t i = 0; i < M; ++i)
    {
        for (size_t j = 0; j < N; ++j)
        {
            if (a(i, j) != b(i, j))
                return false;
        }
    }

    return true;
}
// template <Matrix MA, Matrix MB>
// inline bool operator==(MA const& a, MB const& b) noexcept
// {
// size_t const M = columns(a);
// size_t const N = rows(a);

// if (M != columns(b) || N != rows(b))
// throw std::invalid_argument {"Inconsistent matrix sizes"};

/// @brief Print a matrix to an output stream, one matrix row per text line.
///
/// Elements are visited row by row. Every element is followed by a tab
/// character, and every row is terminated with std::endl (newline plus flush).
///
/// @tparam M matrix type
///
/// @param os output stream to write to
/// @param m matrix to print
///
/// @return reference to @a os
template <Matrix M>
inline std::ostream& operator<<(std::ostream& os, M const& m)
{
for (size_t i = 0; i < rows(m); ++i)
{
for (size_t j = 0; j < columns(m); ++j)
os << m(i, j) << "\t";
os << std::endl;
}
// NOTE(review): the commented-out lines below use names (M, N, a, b) that are
// not declared in this function; they look like part of the disabled
// operator== rather than of this operator -- confirm and relocate.
// for (size_t i = 0; i < M; ++i)
// for (size_t j = 0; j < N; ++j)
// if (a(i, j) != b(i, j))
// return false;

return os;
}
// return true;
// }


// template <Matrix M>
// inline std::ostream& operator<<(std::ostream& os, M const& m)
// {
// for (size_t i = 0; i < rows(m); ++i)
// {
// for (size_t j = 0; j < columns(m); ++j)
// os << m(i, j) << "\t";
// os << std::endl;
// }

// return os;
// }
}
2 changes: 1 addition & 1 deletion include/blast/math/panel/Potrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace blast
if (columns(L) != N)
BLAZE_THROW_INVALID_ARGUMENT("Invalid matrix size");

size_t constexpr KN = PANEL_SIZE;
size_t constexpr KN = 4;
size_t k = 0;

// This loop unroll gives some performance benefit for N >= 18,
Expand Down
21 changes: 20 additions & 1 deletion include/blast/math/reference/Ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,31 @@

namespace blast :: reference
{
/**
* @brief Reference implementation of rank-1 update with multiplier
*
* a(i, j) += alpha * x(i) * y(j)
* for i=0...m-1, j=0...n-1
*
* @tparam Real real number type
* @tparam VPX vector pointer type for the column vector @a x
* @tparam VPY vector pointer type for the row vector @a y
* @tparam MPA matrix pointer type for the matrix @a a
*
* @param m number of rows in the matrix
* @param n number of columns in the matrix
* @param alpha scalar multiplier
* @param x column vector
* @param y row vector
* @param a matrix to perform update on
*
*/
template <typename Real, typename VPX, typename VPY, typename MPA>
requires VectorPointer<VPX, Real> && VectorPointer<VPY, Real> && MatrixPointer<MPA, Real>
inline void ger(size_t m, size_t n, Real alpha, VPX x, VPY y, MPA a)
{
    // Element-wise rank-1 update: a(i, j) += alpha * x(i) * y(j), applied exactly
    // once per element. Each operand goes through operator~ before being indexed
    // and dereferenced.
    // NOTE(review): operator~ presumably unwraps the pointer expression to its
    // concrete type -- confirm against the pointer class definitions.
    for (size_t i = 0; i < m; ++i)
    {
        for (size_t j = 0; j < n; ++j)
        {
            *(~a)(i, j) += alpha * *(~x)(i) * *(~y)(j);
        }
    }
}
}
5 changes: 2 additions & 3 deletions test/blast/math/simd/RegisterMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

#include <blast/math/RegisterMatrix.hpp>
#include <blast/math/Matrix.hpp>
#include <blast/math/panel/MatrixPointer.hpp>
#include <blast/math/dense/MatrixPointer.hpp>
#include <blast/math/dense/VectorPointer.hpp>
#include <blast/math/Vector.hpp>
#include <blast/math/views/submatrix/Panel.hpp>
#include <blast/math/reference/Ger.hpp>
#include <blast/math/panel/StaticPanelMatrix.hpp>

#include <test/Testing.hpp>
#include <test/Randomize.hpp>
Expand Down

0 comments on commit 145f1aa

Please sign in to comment.