Skip to content

Commit

Permalink
Move dft_plan::execute to src
Browse files Browse the repository at this point in the history
  • Loading branch information
dancazarin committed Feb 12, 2024
1 parent a3abae2 commit 68b99bb
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 167 deletions.
194 changes: 27 additions & 167 deletions include/kfr/dft/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,21 @@ struct dft_stage;
template <typename T>
using dft_stage_ptr = std::unique_ptr<dft_stage<T>>;

// Declarations of the DFT routines that the plan classes in this header call.
// Only declarations appear here; the definitions are not visible in this view.
namespace internal_generic
{
/// Invoked from the dft_plan<T> constructor with the plan being constructed.
template <typename T>
void dft_initialize(dft_plan<T>& plan);
/// Invoked from the dft_plan_real<T> constructor with the plan being constructed.
template <typename T>
void dft_real_initialize(dft_plan_real<T>& plan);
/// Invoked by dft_plan<T>::execute_dft, which forwards the plan itself, the
/// compile-time direction tag, the output and input pointers and the temp buffer.
template <typename T, bool inverse>
void dft_execute(const dft_plan<T>& plan, cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp);

/// Pointer to a transpose routine. The multidimensional plans call it as
/// transpose(out, in, shape{ ... }) and also with the same pointer for both
/// the first and the second argument.
template <typename T>
using fn_transpose = void (*)(complex<T>*, const complex<T>*, shape<2>);
/// Invoked with the `transpose` member of the multidimensional plan types.
/// NOTE(review): presumably assigns the implementation pointer through the
/// reference -- confirm against the definition.
template <typename T>
void dft_initialize_transpose(fn_transpose<T>& transpose);

} // namespace internal_generic

/// @brief 1D DFT/FFT
template <typename T>
Expand Down Expand Up @@ -164,16 +175,10 @@ struct dft_plan
/// Builds a plan for `size` points: sets temp_size and data_size to 0, clears
/// arblen, then hands the plan to the initialization routine.
/// `order` is accepted but is not referenced in the visible body.
explicit dft_plan(size_t size, dft_order order = dft_order::normal)
: size(size), temp_size(0), data_size(0), arblen(false)
{
// NOTE(review): both an unqualified and an internal_generic-qualified call are
// present; this looks like removed and added diff lines shown together --
// confirm which one is current.
dft_initialize(*this);
internal_generic::dft_initialize(*this);
}

void dump() const
{
for (const std::unique_ptr<dft_stage<T>>& s : all_stages)
{
s->dump();
}
}
void dump() const;

KFR_MEM_INTRINSIC void execute(complex<T>* out, const complex<T>* in, u8* temp,
bool inverse = false) const
Expand Down Expand Up @@ -233,81 +238,10 @@ struct dft_plan
std::array<bitset, 2> disposition_inplace;
std::array<bitset, 2> disposition_outofplace;

/// Fills disposition_inplace and disposition_outofplace for both directions
/// (index 0: inverse == false, index 1: inverse == true).
void calc_disposition()
{
for (bool inverse : { false, true })
{
auto&& stages = this->stages[inverse];
// Gather each stage's can_inplace flag into bit i of the set.
bitset can_inplace_per_stage;
// NOTE(review): `i` is int while stages.size() is presumably unsigned -- mixed-sign comparison.
for (int i = 0; i < stages.size(); ++i)
{
can_inplace_per_stage[i] = stages[i]->can_inplace;
}
// NOTE(review): this declaration inside the loop body looks like the added
// diff line shown next to the removed inline body -- confirm.
void calc_disposition();

// Same per-stage flags, evaluated once with inplace_requested = true and once with false.
disposition_inplace[static_cast<int>(inverse)] =
precompute_disposition(stages.size(), can_inplace_per_stage, true);
disposition_outofplace[static_cast<int>(inverse)] =
precompute_disposition(stages.size(), can_inplace_per_stage, false);
}
}

/// Precomputes, for every stage, which buffer that stage takes its input from.
///
/// Bit layout of the result (bit index == stage index):
///   bit 0      : 0 - input,  1 - scratch
///   other bits : 0 - output, 1 - scratch
///
/// @param num_stages            number of stages covered by the pattern
/// @param can_inplace_per_stage bit i is set when stage i may run in place
/// @param inplace_requested     when false, bit 0 is cleared so the first stage goes IN->OUT
/// @return the disposition bit set
static bitset precompute_disposition(int num_stages, bitset can_inplace_per_stage, bool inplace_requested)
{
    static bitset even{ 0x5555555555555555ull };
    const bitset all_ones   = ~bitset();
    const bitset stage_mask = all_ones >> (DFT_MAX_STAGES - num_stages);

    // Start from the alternating pattern that works regardless of in-place support;
    // which phase is used depends on the parity of the stage count.
    const bool stage_count_is_even = (num_stages % 2 == 0);
    bitset disposition             = (stage_count_is_even ? ~even : even) & stage_mask;

    int remaining = can_inplace_per_stage.count();

#ifdef KFR_DFT_ELIMINATE_MEMCPY
    if (inplace_requested && remaining > 0)
    {
        // Input in scratch -> the number of in-place stages used must be odd,
        // otherwise it must be even; drop one when the parity does not match.
        const bool need_odd = disposition.test(0);
        const bool have_odd = (remaining % 2 != 0);
        if (need_odd != have_odd)
            --remaining;
    }
#endif

    // Walk from the last stage towards the first; each in-place stage flips the
    // buffer choice of itself and of every earlier stage.
    int stage = num_stages - 1;
    while (remaining > 0 && stage >= 0)
    {
        if (can_inplace_per_stage.test(stage))
        {
            disposition ^= all_ones >> (DFT_MAX_STAGES - (stage + 1));
            --remaining;
        }
        --stage;
    }

    // out-of-place first stage; IN->OUT
    if (!inplace_requested)
        disposition.reset(0);

    return disposition;
}
static bitset precompute_disposition(int num_stages, bitset can_inplace_per_stage,
bool inplace_requested);

protected:
struct noinit
Expand All @@ -317,89 +251,11 @@ struct dft_plan
: size(size), temp_size(0), data_size(0), arblen(false)
{
}
/// Picks the buffer a stage reads from: scratch when the stage's disposition bit
/// is set; otherwise `in` for stage 0 and `out` for every later stage.
const complex<T>* select_in(bitset disposition, size_t stage, const complex<T>* out, const complex<T>* in,
                            const complex<T>* scratch) const
{
    if (disposition.test(stage))
        return scratch;
    if (stage == 0)
        return in;
    return out;
}
/// Picks the buffer a stage writes to: the last stage always writes to `out`;
/// any other stage writes to scratch when the next stage's disposition bit is
/// set, and to `out` otherwise.
complex<T>* select_out(bitset disposition, size_t stage, size_t total_stages, complex<T>* out,
                       complex<T>* scratch) const
{
    const bool is_last_stage = (stage == total_stages - 1);
    // The next stage's bit is only consulted when there is a next stage.
    if (!is_last_stage && disposition.test(stage + 1))
        return scratch;
    return out;
}

/// Runs the stages of this plan for the direction selected by `inverse`,
/// reading from `in` and writing the final result to `out`; `temp` is the work buffer.
/// NOTE(review): two signatures appear back to back and the stage loop is followed
/// by a forwarding call to internal_generic::dft_execute; this looks like removed
/// and added diff lines shown together -- confirm which version is current.
template <bool inverse>
void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
KFR_INTRINSIC void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
{
// No temp buffer was passed although the plan needs one: re-enter this function
// through call_with_temp, binding every argument except the temp pointer.
if (temp == nullptr && temp_size > 0)
{
return call_with_temp(temp_size, std::bind(&dft_plan<T>::execute_dft<inverse>, this,
cbool_t<inverse>{}, out, in, std::placeholders::_1));
}
auto&& stages = this->stages[inverse];
// A lone stage is executed directly when it supports in-place operation or
// when input and output are different buffers.
if (stages.size() == 1 && (stages[0]->can_inplace || in != out))
{
return stages[0]->execute(cbool<inverse>, out, in, temp);
}
// Per-stage execution counters used by the recursive branch below.
size_t stack[DFT_MAX_STAGES] = { 0 };

// Choose the precomputed disposition matching in-place vs out-of-place use.
bitset disposition =
in == out ? this->disposition_inplace[inverse] : this->disposition_outofplace[inverse];

// Scratch starts align_up(sizeof(complex<T>) * (size + 1), cache alignment)
// bytes before the end of the temp_size-byte temp region.
complex<T>* scratch = ptr_cast<complex<T>>(
temp + this->temp_size -
align_up(sizeof(complex<T>) * (this->size + 1), platform<>::native_cache_alignment));

// Bit 0 set: the first stage reads from scratch, so the input is placed there
// first through the first stage's copy_input.
bool in_scratch = disposition.test(0);
if (in_scratch)
{
stages[0]->copy_input(inverse, scratch, in, this->size);
}

const size_t count = stages.size();

for (size_t depth = 0; depth < count;)
{
if (stages[depth]->recursion)
{
// Walk a run of consecutive recursive stages starting at `depth`.
// stack[] counts executions per stage; `offset` accumulates each executed
// stage's out_offset and is applied to both the input and output pointers.
size_t offset = 0;
size_t rdepth = depth;
size_t maxdepth = depth;
do
{
if (stack[rdepth] == stages[rdepth]->repeats)
{
// This stage has run `repeats` times: reset its counter and step back one stage.
stack[rdepth] = 0;
rdepth--;
}
else
{
complex<T>* rout = select_out(disposition, rdepth, stages.size(), out, scratch);
const complex<T>* rin = select_in(disposition, rdepth, out, in, scratch);
stages[rdepth]->execute(cbool<inverse>, rout + offset, rin + offset, temp);
offset += stages[rdepth]->out_offset;
stack[rdepth]++;
// Descend while the next stage is also recursive; otherwise record
// how deep this run reached.
if (rdepth < count - 1 && stages[rdepth + 1]->recursion)
rdepth++;
else
maxdepth = rdepth;
}
} while (rdepth != depth); // ends once control is back at the stage where the run began
// Continue with the stage after the deepest one handled by this run.
depth = maxdepth + 1;
}
else
{
// Non-recursive stage: execute repeatedly, advancing by stage_size until
// `size` elements have been covered.
size_t offset = 0;
while (offset < this->size)
{
stages[depth]->execute(
cbool<inverse>, select_out(disposition, depth, stages.size(), out, scratch) + offset,
select_in(disposition, depth, out, in, scratch) + offset, temp);
offset += stages[depth]->stage_size;
}
depth++;
}
}
internal_generic::dft_execute(*this, cbool<inverse>, out, in, temp);
}
};

Expand Down Expand Up @@ -434,7 +290,7 @@ struct dft_plan_real : dft_plan<T>
: dft_plan<T>(typename dft_plan<T>::noinit{}, size / 2), size(size), fmt(fmt)
{
KFR_LOGIC_CHECK(is_even(size), "dft_plan_real requires size to be even");
dft_real_initialize(*this);
internal_generic::dft_real_initialize(*this);
}

void execute(complex<T>*, const complex<T>*, u8*, bool = false) const = delete;
Expand Down Expand Up @@ -527,6 +383,7 @@ struct dft_plan_md
dfts[i] = dft_plan<T>(this->size[i]);
temp_size = std::max(temp_size, dfts[i].temp_size);
}
internal_generic::dft_initialize_transpose(transpose);
}

void execute(complex<T>* out, const complex<T>* in, u8* temp, bool inverse = false) const
Expand Down Expand Up @@ -590,7 +447,7 @@ struct dft_plan_md
builtin_memcpy(out, in, sizeof(complex<T>) * total);
}

matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() });
transpose(out, out, shape{ sh.remove_back().product(), sh.back() });

if (axis == 0)
break;
Expand All @@ -607,7 +464,7 @@ struct dft_plan_md
index_t axis = 0;
for (;;)
{
matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() });
transpose(out, in, shape{ sh.front(), sh.remove_front().product() });

if (size[axis] > 1)
{
Expand All @@ -626,6 +483,7 @@ struct dft_plan_md
using dft_list =
std::conditional_t<Dims == dynamic_shape, std::vector<dft_plan<T>>, std::array<dft_plan<T>, Dims>>;
dft_list dfts;
internal_generic::fn_transpose<T> transpose;
};

/// @brief Multidimensional DFT
Expand Down Expand Up @@ -687,6 +545,7 @@ struct dft_plan_md_real
{
temp_size += complex_size().product() * sizeof(complex<T>);
}
internal_generic::dft_initialize_transpose(transpose);
}

void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const
Expand Down Expand Up @@ -803,7 +662,7 @@ struct dft_plan_md_real
dfts[axis].execute(out + o, out + o, temp, cfalse);
}

matrix_transpose(out, out, shape{ sh.remove_back().product(), sh.back() });
transpose(out, out, shape{ sh.remove_back().product(), sh.back() });

if (axis == 0)
break;
Expand All @@ -822,7 +681,7 @@ struct dft_plan_md_real
index_t axis = 0;
for (;;)
{
matrix_transpose(out, in, shape{ sh.front(), sh.remove_front().product() });
transpose(out, in, shape{ sh.front(), sh.remove_front().product() });

if (size[axis] > 1)
{
Expand All @@ -847,6 +706,7 @@ struct dft_plan_md_real
std::array<dft_plan<T>, const_max(Dims, 1) - 1>>;
dft_list dfts;
dft_plan_real<T> dft_real;
internal_generic::fn_transpose<T> transpose;
};

/// @brief DCT type 2 (unscaled)
Expand Down
Loading

0 comments on commit 68b99bb

Please sign in to comment.