Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put tests of FFT backends into TestUtils submodule #78

Merged
merged 19 commits into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ version = "1.4.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,18 @@ To define a new FFT implementation in your own module, you should

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

## Testing implementations

`AbstractFFTs.jl` provides an experimental `TestUtils` module to help with testing downstream implementations,
available as a [weak extension](https://pkgdocs.julialang.org/v1.9/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) of `Test`.
The following functions test that all FFT functionality has been correctly implemented:
```@docs
AbstractFFTs.TestUtils.test_complex_ffts
AbstractFFTs.TestUtils.test_real_ffts
```
`TestUtils` also exposes lower level functions for generically testing particular plans:
```@docs
AbstractFFTs.TestUtils.test_plan
AbstractFFTs.TestUtils.test_plan_adjoint
```
232 changes: 232 additions & 0 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

module AbstractFFTsTestExt

using AbstractFFTs
using AbstractFFTs: TestUtils
using AbstractFFTs.LinearAlgebra
using Test

# Ground truth x_fft computed using FFTW library
const TEST_CASES = (
(; x = collect(1:7), dims = 1,
x_fft = [28.0 + 0.0im,
-3.5 + 7.267824888003178im,
-3.5 + 2.7911568610884143im,
-3.5 + 0.7988521603655248im,
-3.5 - 0.7988521603655248im,
-3.5 - 2.7911568610884143im,
-3.5 - 7.267824888003178im]),
(; x = collect(1:8), dims = 1,
x_fft = [36.0 + 0.0im,
-4.0 + 9.65685424949238im,
-4.0 + 4.0im,
-4.0 + 1.6568542494923806im,
-4.0 + 0.0im,
-4.0 - 1.6568542494923806im,
-4.0 - 4.0im,
-4.0 - 9.65685424949238im]),
(; x = collect(reshape(1:8, 2, 4)), dims = 2,
x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
(; x = collect(reshape(1:9, 3, 3)), dims = 2,
x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
(; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2,
x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
dims=3)),
(; x = collect(1:7) + im * collect(8:14), dims = 1,
x_fft = [28.0 + 77.0im,
-10.76782488800318 + 3.767824888003175im,
-6.291156861088416 - 0.7088431389115883im,
-4.298852160365525 - 2.7011478396344746im,
-2.7011478396344764 - 4.298852160365524im,
-0.7088431389115866 - 6.291156861088417im,
3.767824888003177 - 10.76782488800318im]),
(; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2,
x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
dims=3)),
)


function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
if !inplace_plan
@test P * _copy(x) ≈ x_transformed
@test P \ (P * _copy(x)) ≈ x
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) ≈ x_transformed
@test _x_out ≈ x_transformed
else
_x = copy(x)
@test P * _copy(_x) ≈ x_transformed
@test _x ≈ x_transformed
@test P \ _copy(_x) ≈ x
@test _x ≈ x
end
end

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
# test basic properties
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
# test correctness of adjoint and its inverse via the dot test
if !real_plan
@test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x)
@test dot(y, P \ _copy(x)) ≈ dot(P' \ _copy(y), x)
else
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
@test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y)
end
@test_throws MethodError mul!(x, P', y)
end

function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
@testset "correctness of fft, bfft, ifft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
x_fft = convert(ArrayType, _x_fft)

# FFT
@test fft(x, dims) ≈ x_fft
if test_inplace
_x_complexf = copy(x_complexf)
@test fft!(_x_complexf, dims) ≈ x_fft
@test _x_complexf ≈ x_fft
end
# test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft,
# which should give functionally identical plans
for P in (plan_fft(similar(x_complexf), dims),
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_complexf, x_fft)
if test_adjoint
@test fftdims(P') == fftdims(P)
TestUtils.test_plan_adjoint(P, x_complexf)
end
end
if test_inplace
# test IIP plans
for P in (plan_fft!(similar(x_complexf), dims),
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
end
end

# BFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test bfft(x_fft, dims) ≈ x_scaled
if test_inplace
_x_fft = copy(x_fft)
@test bfft!(_x_fft, dims) ≈ x_scaled
@test _x_fft ≈ x_scaled
end
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
for P in (plan_bfft(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
for P in (plan_bfft!(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
end

# IFFT
@test ifft(x_fft, dims) ≈ x
if test_inplace
_x_fft = copy(x_fft)
@test ifft!(_x_fft, dims) ≈ x
@test _x_fft ≈ x
end
# test OOP plans
for P in (plan_ifft(similar(x_complexf), dims),
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
if test_inplace
for P in (plan_ifft!(similar(x_complexf), dims),
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
end
end
end
end
end

function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
@testset "correctness of rfft, brfft, irfft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_real = float.(x) # for testing mutating real FFTs
x_fft = convert(ArrayType, _x_fft)
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))

if !(eltype(x) <: Real)
continue
end

# RFFT
@test rfft(x, dims) ≈ x_rfft
for P in (plan_rfft(similar(x_real), dims),
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Real
@test fftdims(P) == dims
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
end
end

# BRFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test brfft(x_rfft, size(x, first(dims)), dims) ≈ x_scaled
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
end
end

# IRFFT
@test irfft(x_rfft, size(x, first(dims)), dims) ≈ x
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
end
end
end
end
end

end
2 changes: 2 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("TestUtils.jl")

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
include("../ext/AbstractFFTsTestExt.jl")
end

end # module
73 changes: 73 additions & 0 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module TestUtils

"""
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)

Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.

# Arguments

- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
"""
function test_complex_ffts end

"""
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)

Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.

# Arguments

- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
"""
function test_real_ffts end

# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
"""
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false)

Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.

Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan end

"""
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)

Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
including its accuracy via the dot test.

Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan_adjoint end

if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
function __init__()
# Better error message if users forget to load Test
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
if any(f -> (f === exc.f), (test_real_ffts, test_complex_ffts, test_plan, test_plan_adjoint)) &&
(Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing)
print(io, "\nDid you forget to load Test?")
end
end
end
end

end
3 changes: 2 additions & 1 deletion src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ end

Return a plan that performs the adjoint operation of the original plan.

!!! note
!!! warning
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Expand All @@ -676,6 +676,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)
fftdims(p::AdjointPlan) = fftdims(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)

Expand Down
Loading