Skip to content

Commit

Permalink
Pass views by const& in fftshift (#192)
Browse files Browse the repository at this point in the history
* Pass views by const& in fftshift

* Review from yasahi-hpc

Co-authored-by: yasahi-hpc <[email protected]>

---------

Co-authored-by: yasahi-hpc <[email protected]>
  • Loading branch information
tpadioleau and yasahi-hpc authored Nov 5, 2024
1 parent 29389f2 commit 367e7cc
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ auto get_shift(const ViewType& inout, axis_type<DIM> axes, int direction = 1) {
}

template <typename ExecutionSpace, typename ViewType>
void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift,
axis_type<1>) {
void roll(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<1> shift, axis_type<1> /* axes */) {
// Last parameter is ignored but present for keeping the interface consistent
static_assert(ViewType::rank() == 1, "roll: Rank of View must be 1.");
int n0 = inout.extent_int(0);
Expand All @@ -67,13 +67,13 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift,
}
});

inout = tmp;
Kokkos::deep_copy(inout, tmp);
}
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM1 = 1>
void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
axis_type<DIM1> axes) {
void roll(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<2> shift, axis_type<DIM1> axes) {
constexpr int DIM0 = 2;
static_assert(ViewType::rank() == DIM0, "roll: Rank of View must be 2.");
int n0 = inout.extent_int(0), n1 = inout.extent_int(1);
Expand Down Expand Up @@ -128,18 +128,18 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
}
});

inout = tmp;
Kokkos::deep_copy(inout, tmp);
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
void fftshift_impl(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<DIM> axes) {
auto shift = get_shift(inout, axes);
roll(exec_space, inout, shift, axes);
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
void ifftshift_impl(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<DIM> axes) {
auto shift = get_shift(inout, axes, -1);
roll(exec_space, inout, shift, axes);
Expand Down Expand Up @@ -219,7 +219,7 @@ auto rfftfreq(const ExecutionSpace&, const std::size_t n,
/// \param axes [in] Axes over which to shift (default: nullopt, shifting over
/// all axes)
template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
void fftshift(const ExecutionSpace& exec_space, const ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
Expand All @@ -246,7 +246,7 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
void fftshift(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
Expand All @@ -269,7 +269,7 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
/// \param axes [in] Axes over which to shift (default: nullopt, shifting over
/// all axes)
template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
void ifftshift(const ExecutionSpace& exec_space, const ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
Expand All @@ -295,7 +295,7 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
/// \param inout [in,out] Spectrum
/// \param axes [in] Axes over which to shift
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
void ifftshift(const ExecutionSpace& exec_space, const ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
Expand Down

0 comments on commit 367e7cc

Please sign in to comment.