Skip to content

Commit

Permalink
customization: tpl hooks for sevaral BLAS1,2 fncs kokkos#96
Browse files Browse the repository at this point in the history
  • Loading branch information
fnrizzi committed Oct 27, 2021
1 parent 0cfe4a3 commit bb4b1ef
Show file tree
Hide file tree
Showing 9 changed files with 426 additions and 51 deletions.
99 changes: 95 additions & 4 deletions include/experimental/__p1673_bits/blas1_givens.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,33 @@ namespace linalg {
// DSQRT -> sqrt (Real input and return value)
// slapy2(real(fs), aimag(fs)) -> hypot(real(fs), imag(fs))


// begin anonymous namespace
namespace {
template <class Exec, class x_t, class y_t, class c_t, class s_t, class = void>
struct is_custom_givens_rotation_apply_avail : std::false_type {};

template <class Exec, class x_t, class y_t, class c_t, class s_t>
struct is_custom_givens_rotation_apply_avail<
Exec, x_t, y_t, c_t, s_t,
std::void_t<
decltype(givens_rotation_apply
(std::declval<Exec>(),
std::declval<x_t>(),
std::declval<y_t>(),
std::declval<const c_t>(),
std::declval<const s_t>()
)
)
>
>
{
static constexpr bool value = !std::is_same<Exec,std::experimental::linalg::impl::inline_exec_t>::value;
};
} // end anonymous namespace



template<std::floating_point Real>
void givens_rotation_setup(const Real f,
const Real g,
Expand Down Expand Up @@ -344,6 +371,7 @@ void givens_rotation_setup(const complex<Real>& f,
}
}

// c and s are std::floating_point
template<class ElementType1,
extents<>::size_type ext1,
class Layout1,
Expand All @@ -354,6 +382,7 @@ template<class ElementType1,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
std::experimental::linalg::impl::inline_exec_t&& /* exec */,
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
Expand Down Expand Up @@ -381,15 +410,47 @@ template<class ExecutionPolicy,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
ExecutionPolicy&& /* exec */,
ExecutionPolicy&& exec,
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
const Real s)
{

constexpr bool use_custom = is_custom_givens_rotation_apply_avail<
decltype(execpolicy_mapper(exec)), decltype(x), decltype(y), Real, Real
>::value;

if constexpr(use_custom){
givens_rotation_apply(execpolicy_mapper(exec), x, y, c, s);
}
else
{
givens_rotation_apply(std::experimental::linalg::impl::inline_exec_t(), x, y, c, s);
}
}

template<class ElementType1,
extents<>::size_type ext1,
class Layout1,
class Accessor1,
class ElementType2,
extents<>::size_type ext2,
class Layout2,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
const Real s)
{
givens_rotation_apply(x, y, c, s);
givens_rotation_apply(std::experimental::linalg::impl::default_exec_t(), x, y, c, s);
}


// c is std::floating_point
// s is complex<std::floating_point>
template<class ElementType1,
extents<>::size_type ext1,
class Layout1,
Expand All @@ -400,6 +461,7 @@ template<class ElementType1,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
std::experimental::linalg::impl::inline_exec_t&& /* exec */,
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
Expand Down Expand Up @@ -428,13 +490,42 @@ template<class ExecutionPolicy,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
ExecutionPolicy&& /* exec */,
ExecutionPolicy&& exec,
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
const complex<Real> s)
{

constexpr bool use_custom = is_custom_givens_rotation_apply_avail<
decltype(execpolicy_mapper(exec)), decltype(x), decltype(y), Real, complex<Real>
>::value;

if constexpr(use_custom){
givens_rotation_apply(execpolicy_mapper(exec), x, y, c, s);
}
else
{
givens_rotation_apply(std::experimental::linalg::impl::inline_exec_t(), x, y, c, s);
}
}

template<class ElementType1,
extents<>::size_type ext1,
class Layout1,
class Accessor1,
class ElementType2,
extents<>::size_type ext2,
class Layout2,
class Accessor2,
std::floating_point Real>
void givens_rotation_apply(
std::experimental::mdspan<ElementType1, std::experimental::extents<ext1>, Layout1, Accessor1> x,
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> y,
const Real c,
const complex<Real> s)
{
givens_rotation_apply(x, y, c, s);
givens_rotation_apply(std::experimental::linalg::impl::default_exec_t(), x, y, c, s);
}

} // end namespace linalg
Expand Down
4 changes: 0 additions & 4 deletions include/experimental/__p1673_bits/blas1_linalg_add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,6 @@ struct is_custom_add_avail<
} // end anonymous namespace


// ------------
// PUBLIC API:
// ------------

template<class ElementType_x,
extents<>::size_type ... ext_x,
class Layout_x,
Expand Down
64 changes: 59 additions & 5 deletions include/experimental/__p1673_bits/blas1_matrix_frob_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,43 @@ namespace experimental {
inline namespace __p1673_version_0 {
namespace linalg {

// begin anonymous namespace
namespace {

template <class Exec, class A_t, class Scalar, class = void>
struct is_custom_matrix_frob_norm_avail : std::false_type {};

template <class Exec, class A_t, class Scalar>
struct is_custom_matrix_frob_norm_avail<
Exec, A_t, Scalar,
std::enable_if_t<
std::is_same<
decltype(matrix_frob_norm
(std::declval<Exec>(),
std::declval<A_t>(),
std::declval<Scalar>()
)
),
Scalar
>::value
>
>
{
static constexpr bool value =
!std::is_same<Exec, std::experimental::linalg::impl::inline_exec_t>::value;
};

} // end anonymous namespace

template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_frob_norm(
std::experimental::linalg::impl::inline_exec_t&& /* exec */,
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{
Expand Down Expand Up @@ -100,25 +129,50 @@ Scalar matrix_frob_norm(

template<class ExecutionPolicy,
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_frob_norm(
ExecutionPolicy&& /* exec */,
ExecutionPolicy&& exec,
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{
return matrix_frob_norm(A, init);

constexpr bool use_custom = is_custom_matrix_frob_norm_avail<
decltype(execpolicy_mapper(exec)), decltype(A), Scalar
>::value;

if constexpr(use_custom){
return matrix_frob_norm(execpolicy_mapper(exec), A, init);
}
else{
return matrix_frob_norm(std::experimental::linalg::impl::inline_exec_t(), A, init);
}
}

template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_frob_norm(
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{
return matrix_frob_norm(std::experimental::linalg::impl::default_exec_t(), A, init);
}


// TODO: Implement auto functions
#if 0
template<class in_matrix_t>
auto matrix_frob_norm(in_matrix_t A)
{

}

template<class ExecutionPolicy,
Expand Down
71 changes: 62 additions & 9 deletions include/experimental/__p1673_bits/blas1_matrix_inf_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,43 @@ namespace experimental {
inline namespace __p1673_version_0 {
namespace linalg {

// begin anonymous namespace
namespace {

template <class Exec, class A_t, class Scalar, class = void>
struct is_custom_matrix_inf_norm_avail : std::false_type {};

template <class Exec, class A_t, class Scalar>
struct is_custom_matrix_inf_norm_avail<
Exec, A_t, Scalar,
std::enable_if_t<
std::is_same<
decltype(matrix_inf_norm
(std::declval<Exec>(),
std::declval<A_t>(),
std::declval<Scalar>()
)
),
Scalar
>::value
>
>
{
static constexpr bool value =
!std::is_same<Exec, std::experimental::linalg::impl::inline_exec_t>::value;
};

} // end anonymous namespace

template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_inf_norm(
std::experimental::linalg::impl::inline_exec_t&& /* exec */,
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{
Expand Down Expand Up @@ -89,28 +118,52 @@ Scalar matrix_inf_norm(
template<
class ExecutionPolicy,
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_inf_norm(
ExecutionPolicy&& /* exec */,
ExecutionPolicy&& exec,
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{

constexpr bool use_custom = is_custom_matrix_inf_norm_avail<
decltype(execpolicy_mapper(exec)), decltype(A), Scalar
>::value;

if constexpr(use_custom){
return matrix_inf_norm(execpolicy_mapper(exec), A, init);
}
else{
return matrix_inf_norm(std::experimental::linalg::impl::inline_exec_t(), A, init);
}
}

template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor,
class Scalar>
Scalar matrix_inf_norm(
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A,
Scalar init)
{
return matrix_inf_norm(A, init);
return matrix_inf_norm(std::experimental::linalg::impl::default_exec_t(), A, init);
}

namespace matrix_inf_norm_detail {

using std::abs;

// The point of this is to do correct ADL for abs,
// without exposing "using std::abs" in the outer namespace.
template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor>
Expand All @@ -121,21 +174,21 @@ namespace matrix_inf_norm_detail {

template<
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor>
auto matrix_inf_norm(
std::experimental::mdspan<ElementType, std::experimental::extents<numRows, numCols>, Layout, Accessor> A)
-> decltype(matrix_inf_norm_detail::matrix_inf_norm_return_type_deducer(A))
{
{
using return_t = decltype(matrix_inf_norm_detail::matrix_inf_norm_return_type_deducer(A));
return matrix_inf_norm(A, return_t{});
}

template<class ExecutionPolicy,
class ElementType,
extents<>::size_type numRows,
extents<>::size_type numRows,
extents<>::size_type numCols,
class Layout,
class Accessor>
Expand Down
Loading

0 comments on commit bb4b1ef

Please sign in to comment.