Skip to content

Commit

Permalink
Modify implementation to match proposal
Browse files Browse the repository at this point in the history
Specifically, this deals with the proposal's constraints and mandates.
See kokkos#66.
  • Loading branch information
amklinv-nnl committed Jun 18, 2021
1 parent fbfa0a0 commit 539e6df
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/experimental/__p1673_bits/blas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ Scalar dot(std::experimental::mdspan<ElementType1, std::experimental::extents<ex
std::experimental::mdspan<ElementType2, std::experimental::extents<ext2>, Layout2, Accessor2> v2,
Scalar init)
{
static_assert(v1.static_extent(0) == dynamic_extent ||
v2.static_extent(0) == dynamic_extent ||
v1.static_extent(0) == v2.static_extent(0));

for (size_t k = 0; k < v1.extent(0); ++k) {
init += v1(k) * v2(k);
}
Expand Down
8 changes: 8 additions & 0 deletions include/experimental/__p1673_bits/blas1_givens.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ void givens_rotation_apply(
const Real c,
const Real s)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

for (extents<>::size_type i = 0; i < x.extent(0); ++i) {
const auto dtemp = c * x(i) + s * y(i);
y(i) = c * y(i) - s * x(i);
Expand Down Expand Up @@ -401,6 +405,10 @@ void givens_rotation_apply(
const Real c,
const complex<Real> s)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

using std::conj;
for (extents<>::size_type i = 0; i < x.extent(0); ++i) {
const auto dtemp = c * x(i) + s * y(i);
Expand Down
30 changes: 30 additions & 0 deletions include/experimental/__p1673_bits/blas1_linalg_add.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ void add_rank_1(
std::experimental::mdspan<ElementType_y, std::experimental::extents<ext_y>, Layout_y, Accessor_y> y,
std::experimental::mdspan<ElementType_z, std::experimental::extents<ext_z>, Layout_z, Accessor_z> z)
{
static_assert(x.static_extent(0) == dynamic_extent ||
z.static_extent(0) == dynamic_extent ||
x.static_extent(0) == z.static_extent(0));
static_assert(y.static_extent(0) == dynamic_extent ||
z.static_extent(0) == dynamic_extent ||
y.static_extent(0) == z.static_extent(0));
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

for (extents<>::size_type i = 0; i < z.extent(0); ++i) {
z(i) = x(i) + y(i);
}
Expand All @@ -92,6 +102,26 @@ void add_rank_2(
std::experimental::mdspan<ElementType_y, std::experimental::extents<numRows_y, numCols_y>, Layout_y, Accessor_y> y,
std::experimental::mdspan<ElementType_z, std::experimental::extents<numRows_z, numCols_z>, Layout_z, Accessor_z> z)
{
static_assert(x.static_extent(0) == dynamic_extent ||
z.static_extent(0) == dynamic_extent ||
x.static_extent(0) == z.static_extent(0));
static_assert(y.static_extent(0) == dynamic_extent ||
z.static_extent(0) == dynamic_extent ||
y.static_extent(0) == z.static_extent(0));
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

static_assert(x.static_extent(1) == dynamic_extent ||
z.static_extent(1) == dynamic_extent ||
x.static_extent(1) == z.static_extent(1));
static_assert(y.static_extent(1) == dynamic_extent ||
z.static_extent(1) == dynamic_extent ||
y.static_extent(1) == z.static_extent(1));
static_assert(x.static_extent(1) == dynamic_extent ||
y.static_extent(1) == dynamic_extent ||
x.static_extent(1) == y.static_extent(1));

using size_type = typename extents<>::size_type;

for (size_type j = 0; j < x.extent(1); ++j) {
Expand Down
11 changes: 11 additions & 0 deletions include/experimental/__p1673_bits/blas1_linalg_copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ void copy_rank_1(
std::experimental::mdspan<ElementType_x, std::experimental::extents<ext_x>, Layout_x, Accessor_x> x,
std::experimental::mdspan<ElementType_y, std::experimental::extents<ext_y>, Layout_y, Accessor_y> y)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

for (extents<>::size_type i = 0; i < y.extent(0); ++i) {
y(i) = x(i);
}
Expand All @@ -81,6 +85,13 @@ void copy_rank_2(
std::experimental::mdspan<ElementType_x, std::experimental::extents<numRows_x, numCols_x>, Layout_x, Accessor_x> x,
std::experimental::mdspan<ElementType_y, std::experimental::extents<numRows_y, numCols_y>, Layout_y, Accessor_y> y)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));
static_assert(x.static_extent(1) == dynamic_extent ||
y.static_extent(1) == dynamic_extent ||
x.static_extent(1) == y.static_extent(1));

using size_type = typename extents<>::size_type;

for (size_type j = 0; j < y.extent(1); ++j) {
Expand Down
15 changes: 13 additions & 2 deletions include/experimental/__p1673_bits/blas1_linalg_swap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ void swap_rank_1(
std::experimental::mdspan<ElementType_x, std::experimental::extents<ext_x>, Layout_x, Accessor_x> x,
std::experimental::mdspan<ElementType_y, std::experimental::extents<ext_y>, Layout_y, Accessor_y> y)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));

using std::swap;
using size_type = typename extents<>::size_type;

Expand All @@ -86,6 +90,13 @@ void swap_rank_2(
std::experimental::mdspan<ElementType_x, std::experimental::extents<numRows_x, numCols_x>, Layout_x, Accessor_x> x,
std::experimental::mdspan<ElementType_y, std::experimental::extents<numRows_y, numCols_y>, Layout_y, Accessor_y> y)
{
static_assert(x.static_extent(0) == dynamic_extent ||
y.static_extent(0) == dynamic_extent ||
x.static_extent(0) == y.static_extent(0));
static_assert(x.static_extent(1) == dynamic_extent ||
y.static_extent(1) == dynamic_extent ||
x.static_extent(1) == y.static_extent(1));

using std::swap;
using size_type = typename extents<>::size_type;

Expand All @@ -111,10 +122,10 @@ void swap_elements(
std::experimental::mdspan<ElementType_x, std::experimental::extents<ext_x ...>, Layout_x, Accessor_x> x,
std::experimental::mdspan<ElementType_y, std::experimental::extents<ext_y ...>, Layout_y, Accessor_y> y)
{
if constexpr (x.rank() == 1) {
if constexpr (x.rank() == 1 && y.rank() == 1) {
swap_rank_1(x, y);
}
else if constexpr (x.rank() == 2) {
else if constexpr (x.rank() == 2 && y.rank() == 2) {
swap_rank_2(x, y);
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ template<class ElementType_A,
extents<>::size_type ext_y,
class Layout_y,
class Accessor_y>
requires (Layout_A::is_always_unique())
void matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
std::experimental::mdspan<ElementType_x, std::experimental::extents<ext_x>, Layout_x, Accessor_x> x,
Expand Down Expand Up @@ -123,6 +124,7 @@ template<class ElementType_A,
extents<>::size_type ext_z,
class Layout_z,
class Accessor_z>
requires (Layout_A::is_always_unique())
void matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
std::experimental::mdspan<ElementType_x, std::experimental::extents<ext_x>, Layout_x, Accessor_x> x,
Expand Down Expand Up @@ -183,6 +185,7 @@ template<class ElementType_A,
extents<>::size_type ext_y,
class Layout_y,
class Accessor_y>
requires (Layout_A::is_always_unique())
void symmetric_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down Expand Up @@ -260,6 +263,7 @@ template<class ElementType_A,
extents<>::size_type ext_z,
class Layout_z,
class Accessor_z>
requires (Layout_A::is_always_unique())
void symmetric_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down Expand Up @@ -339,6 +343,7 @@ template<class ElementType_A,
extents<>::size_type ext_y,
class Layout_y,
class Accessor_y>
requires (Layout_A::is_always_unique())
void hermitian_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down Expand Up @@ -387,6 +392,7 @@ template<class ExecutionPolicy,
extents<>::size_type ext_y,
class Layout_y,
class Accessor_y>
requires (Layout_A::is_always_unique())
void hermitian_matrix_vector_product(
ExecutionPolicy&& /* exec */,
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Expand Down Expand Up @@ -417,6 +423,7 @@ template<class ElementType_A,
extents<>::size_type ext_z,
class Layout_z,
class Accessor_z>
requires (Layout_A::is_always_unique())
void hermitian_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down Expand Up @@ -498,6 +505,7 @@ template<class ElementType_A,
extents<>::size_type ext_y,
class Layout_y,
class Accessor_y>
requires (Layout_A::is_always_unique())
void triangular_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down Expand Up @@ -585,6 +593,7 @@ template<class ElementType_A,
extents<>::size_type ext_z,
class Layout_z,
class Accessor_z>
requires (Layout_A::is_always_unique())
void triangular_matrix_vector_product(
std::experimental::mdspan<ElementType_A, std::experimental::extents<numRows_A, numCols_A>, Layout_A, Accessor_A> A,
Triangle t,
Expand Down

0 comments on commit 539e6df

Please sign in to comment.