diff --git a/Project.toml b/Project.toml
index 56bd62f9..5dc7cc58 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,9 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
+SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl
index 22c317f2..aafa0fb7 100644
--- a/src/NeuralOperators.jl
+++ b/src/NeuralOperators.jl
@@ -10,6 +10,9 @@ using Zygote
 using ChainRulesCore
 using GeometricFlux
 using Statistics
+using Polynomials
+using SpecialPolynomials
+using LinearAlgebra
 
 include("abstracttypes.jl")
diff --git a/src/Transform/Transform.jl b/src/Transform/Transform.jl
index 2a02f1b7..9e8e68cd 100644
--- a/src/Transform/Transform.jl
+++ b/src/Transform/Transform.jl
@@ -16,5 +16,8 @@ export
 """
 abstract type AbstractTransform end
 
+include("utils.jl")
+include("polynomials.jl")
 include("fourier_transform.jl")
 include("chebyshev_transform.jl")
+include("wavelet_transform.jl")
diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl
new file mode 100644
index 00000000..480af298
--- /dev/null
+++ b/src/Transform/polynomials.jl
@@ -0,0 +1,203 @@
+function get_filter(base::Symbol, k)
+    if base == :legendre
+        return legendre_filter(k)
+    elseif base == :chebyshev
+        return chebyshev_filter(k)
+    else
+        throw(ArgumentError("base must be one of :legendre or :chebyshev."))
+    end
+end
+
+function legendre_ϕ_ψ(k)
+    # TODO: the coefficient tables are filled row-by-row, ported from a
+    # row-major reference implementation; consider a column-major layout.
+    ϕ_coefs = zeros(k, k)
+    ϕ_2x_coefs = zeros(k, k)
+
+    p = Polynomial([-1, 2])   # 2x - 1, maps [0, 1] onto [-1, 1]
+    p2 = Polynomial([-1, 4])  # 4x - 1
+
+    for ki in 0:(k - 1)
+        l = convert(Polynomial, gen_poly(Legendre, ki))  # Legendre polynomial of degree ki
+        ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * ki + 1) .* coeffs(l(p))
+        ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * (2 * ki + 1)) .* coeffs(l(p2))
+    end
+
+    ψ1_coefs = zeros(k, k)
+    ψ2_coefs = zeros(k, k)
+    for ki in 0:(k - 1)
+        ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :]
+        # project out the scaling functions ϕ
+        for i in 0:(k - 1)
+            a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)]
+            b = ϕ_coefs[i + 1, 1:(i + 1)]
+            proj_ = proj_factor(a, b)
+            ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
+            ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
+        end
+
+        # Gram–Schmidt against the previously constructed wavelets
+        for j in 0:(k - 1)
+            a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)]
+            b = ψ1_coefs[j + 1, :]
+            proj_ = proj_factor(a, b)
+            ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :)
+            ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :)
+        end
+
+        # normalize over [0, 1/2] ∪ [1/2, 1]
+        a = ψ1_coefs[ki + 1, :]
+        norm1 = proj_factor(a, a)
+        a = ψ2_coefs[ki + 1, :]
+        norm2 = proj_factor(a, a, complement = true)
+        norm_ = sqrt(norm1 + norm2)
+        ψ1_coefs[ki + 1, :] ./= norm_
+        ψ2_coefs[ki + 1, :] ./= norm_
+        zero_out!(ψ1_coefs)
+        zero_out!(ψ2_coefs)
+    end
+
+    ϕ = [Polynomial(ϕ_coefs[i, :]) for i in 1:k]
+    ψ1 = [Polynomial(ψ1_coefs[i, :]) for i in 1:k]
+    ψ2 = [Polynomial(ψ2_coefs[i, :]) for i in 1:k]
+
+    return ϕ, ψ1, ψ2
+end
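+
+# Quick sanity check on `legendre_ϕ_ψ` (hypothetical REPL session; the
+# expected values are the ones asserted in test/Transform/polynomials.jl):
+#
+#   julia> ϕ, ψ1, ψ2 = legendre_ϕ_ψ(3);
+#
+#   julia> coeffs(ϕ[2])  # √3 ⋅ (2x - 1), the normalized shifted Legendre P₁
+#   2-element Vector{Float64}:
+#    -1.7320508075688772
+#     3.4641016151377544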
+
+function chebyshev_ϕ_ψ(k)
+    ϕ_coefs = zeros(k, k)
+    ϕ_2x_coefs = zeros(k, k)
+
+    p = Polynomial([-1, 2])   # 2x - 1
+    p2 = Polynomial([-1, 4])  # 4x - 1
+
+    for ki in 0:(k - 1)
+        if ki == 0
+            ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 / π)
+            ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(4 / π)
+        else
+            c = convert(Polynomial, gen_poly(Chebyshev, ki))  # Chebyshev polynomial of degree ki
+            ϕ_coefs[ki + 1, 1:(ki + 1)] .= 2 / sqrt(π) .* coeffs(c(p))
+            ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2) * 2 / sqrt(π) .* coeffs(c(p2))
+        end
+    end
+
+    ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k]
+
+    k_use = 2k
+    c = convert(Polynomial, gen_poly(Chebyshev, k_use))
+    x_m = roots(c(p))  # Gauss–Chebyshev nodes mapped to [0, 1]
+    # An even k keeps 0.5 out of the node set, so no node belongs to both
+    # ϕ(2x) and ϕ(2x - 1); no perturbation of the nodes is needed here.
+    wm = π / k_use / 2  # quadrature weight
+
+    ψ1_coefs = zeros(k, k)
+    ψ2_coefs = zeros(k, k)
+
+    ψ1 = Vector{Any}(undef, k)
+    ψ2 = Vector{Any}(undef, k)
+
+    for ki in 0:(k - 1)
+        ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :]
+        # project out the scaling functions ϕ
+        for i in 0:(k - 1)
+            proj_ = sum(wm .* ϕ[i + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m))
+            ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
+            ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
+        end
+
+        # Gram–Schmidt against the previously constructed wavelets
+        for j in 0:(ki - 1)
+            proj_ = sum(wm .* ψ1[j + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m))
+            ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :)
+            ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :)
+        end
+
+        ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5)
+        ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5, ub = 1.0)
+
+        # normalize by Gauss–Chebyshev quadrature
+        norm1 = sum(wm .* ψ1[ki + 1].(x_m) .* ψ1[ki + 1].(x_m))
+        norm2 = sum(wm .* ψ2[ki + 1].(x_m) .* ψ2[ki + 1].(x_m))
+        norm_ = sqrt(norm1 + norm2)
+        ψ1_coefs[ki + 1, :] ./= norm_
+        ψ2_coefs[ki + 1, :] ./= norm_
+        zero_out!(ψ1_coefs)
+        zero_out!(ψ2_coefs)
+
+        ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5 + 1e-16)
+        ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5 + 1e-16, ub = 1.0)
+    end
+
+    return ϕ, ψ1, ψ2
+end
+
+function legendre_filter(k)
+    H0 = zeros(k, k)
+    H1 = zeros(k, k)
+    G0 = zeros(k, k)
+    G1 = zeros(k, k)
+    ϕ, ψ1, ψ2 = legendre_ϕ_ψ(k)
+
+    # Gauss–Legendre quadrature on [0, 1]
+    l = convert(Polynomial, gen_poly(Legendre, k))
+    x_m = roots(l(Polynomial([-1, 2])))  # roots of P_k(2x - 1)
+    m = 2 .* x_m .- 1
+    wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k - 1).(m)
+
+    for ki in 0:(k - 1)
+        for kpi in 0:(k - 1)
+            H0[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m))
+            G0[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m))
+            H1[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
+            G1[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
+        end
+    end
+
+    zero_out!(H0)
+    zero_out!(H1)
+    zero_out!(G0)
+    zero_out!(G1)
+
+    return H0, H1, G0, G1, I(k), I(k)
+end
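+
+# For reference, `legendre_filter(3)` yields the filter bank asserted in
+# test/Transform/polynomials.jl (hypothetical REPL session, values rounded):
+#
+#   julia> H0, H1, G0, G1, Φ0, Φ1 = legendre_filter(3);
+#
+#   julia> H0
+#   3×3 Matrix{Float64}:
+#     0.707107   0.0        0.0
+#    -0.612372   0.353553   0.0
+#     0.0       -0.684653   0.176777
+#
+# Φ0 and Φ1 are exactly I(3) in the Legendre case.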
+
+function chebyshev_filter(k)
+    H0 = zeros(k, k)
+    H1 = zeros(k, k)
+    G0 = zeros(k, k)
+    G1 = zeros(k, k)
+    Φ0 = zeros(k, k)
+    Φ1 = zeros(k, k)
+    ϕ, ψ1, ψ2 = chebyshev_ϕ_ψ(k)
+
+    # Gauss–Chebyshev quadrature on [0, 1]
+    k_use = 2k
+    c = convert(Polynomial, gen_poly(Chebyshev, k_use))
+    x_m = roots(c(Polynomial([-1, 2])))  # roots of T_{2k}(2x - 1)
+    # An even k keeps 0.5 out of the node set, so no node belongs to both
+    # ϕ(2x) and ϕ(2x - 1); no perturbation of the nodes is needed here.
+    wm = π / k_use / 2
+
+    for ki in 0:(k - 1)
+        for kpi in 0:(k - 1)
+            H0[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m))
+            H1[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
+            G0[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m))
+            G1[ki + 1, kpi + 1] = 1 / sqrt(2) *
+                                  sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
+            Φ0[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2 .* x_m) .* ϕ[kpi + 1].(2 .* x_m))
+            Φ1[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2 .* x_m .- 1) .*
+                                          ϕ[kpi + 1].(2 .* x_m .- 1))
+        end
+    end
+
+    zero_out!(H0)
+    zero_out!(H1)
+    zero_out!(G0)
+    zero_out!(G1)
+    zero_out!(Φ0)
+    zero_out!(Φ1)
+
+    return H0, H1, G0, G1, Φ0, Φ1
+end
diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl
new file mode 100644
index 00000000..a76c13b6
--- /dev/null
+++ b/src/Transform/utils.jl
@@ -0,0 +1,47 @@
+# Polynomial with coefficients `ϕ_coefs`, restricted to [lb, ub] (zero outside).
+function ϕ_(ϕ_coefs; lb::Real = 0.0, ub::Real = 1.0)
+    p = Polynomial(ϕ_coefs)
+    partial(x) = p(x) * (lb ≤ x ≤ ub)
+    return partial
+end
+
+# Piecewise wavelet: ψ1 on [0, 1/2], ψ2 on (1/2, 1].
+function ψ(ψ1, ψ2, i, inp)
+    mask = inp .> 0.5
+    return ψ1[i + 1].(inp) .* .!mask .+ ψ2[i + 1].(inp) .* mask
+end
+
+function zero_out!(x; tol = 1e-8)
+    x[abs.(x) .< tol] .= 0
+    return x
+end
+
+# Degree-n member of the given polynomial family, e.g. gen_poly(Legendre, 2) = P₂.
+function gen_poly(poly, n)
+    x = zeros(n + 1)
+    x[end] = 1
+    return poly(x)
+end
+
+# Coefficients of the product polynomial, i.e. a discrete convolution.
+function convolve(a, b)
+    n = length(b)
+    y = fill!(similar(a, length(a) + n - 1), 0)
+    for i in 1:length(a)
+        y[i:(i + n - 1)] .+= a[i] .* b
+    end
+    return y
+end
+
+# ∫₀^½ a(x)b(x) dx computed from the product coefficients; with
+# `complement = true`, ∫_½^1 a(x)b(x) dx instead.
+function proj_factor(a, b; complement::Bool = false)
+    prod_ = convolve(a, b)
+    r = collect(1:length(prod_))
+    s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r)
+    return sum(prod_ ./ r .* s)
+end
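+
+# Quick sanity check (hypothetical REPL session): for a(x) = 1 and b(x) = x,
+# proj_factor([1.0], [0.0, 1.0]) == ∫₀^½ x dx = 0.125, and with
+# `complement = true` it returns ∫_½^1 x dx = 0.375.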
+""" +struct SparseKernel{N, T, S} + conv_blk::T + out_weight::S +end + +function SparseKernel(filter::NTuple{N, T}, ch::Pair{S, S}; + init = Flux.glorot_uniform) where {N, T, S} + input_dim, emb_dim = ch + conv = Conv(filter, input_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{N, typeof(conv), typeof(W_out)}(conv, W_out) +end + +function SparseKernel1D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k + emb_dim = 128 + return SparseKernel((3,), input_dim => emb_dim; init = init) +end + +function SparseKernel2D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + return SparseKernel((3, 3), input_dim => emb_dim; init = init) +end + +function SparseKernel3D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + conv = Conv((3, 3, 3), emb_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{3, typeof(conv), typeof(W_out)}(conv, W_out) +end + +Flux.@functor SparseKernel + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) +end + +struct MWT_CZ1d{T, S, R, Q, P} + k::Int + L::Int + A::T + B::S + C::R + T0::Q + ec_s::P + ec_d::P + rc_e::P + rc_o::P +end + +function MWT_CZ1d(k::Int = 3, α::Int = 5, L::Int = 0, c::Int = 1; base::Symbol = :legendre, + init = Flux.glorot_uniform) + H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) + H0r = zero_out!(H0 * Φ0) + G0r = zero_out!(G0 * Φ0) + H1r = zero_out!(H1 * Φ1) + G1r = zero_out!(G1 * Φ1) + + dim = c * k + A = SpectralConv(dim => dim, (α,); init = init) + B = SpectralConv(dim => dim, (α,); init = init) + C = SpectralConv(dim => dim, (α,); init = init) + T0 = Dense(k, k) + + ec_s = vcat(H0', H1') + ec_d = vcat(G0', G1') + rc_e = vcat(H0r, G0r) + rc_o = vcat(H1r, G1r) + return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) +end + +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} + N = size(X, 3) + Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) + d = NNlib.batched_mul(Xa, l.ec_d) + s = NNlib.batched_mul(Xa, l.ec_s) + return d, s +end + +function even_odd(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} + bch_sz, N, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + @assert dims[1] == 2 * l.k + Y = similar(X, bch_sz, 2N, l.c, l.k) + view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e) + view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o) + return Y +end + +function (l::MWT_CZ1d)(X::T) where {T <: AbstractArray} + bch_sz, N, dims_r... 
+
+struct MWT_CZ1d{T, S, R, Q, P}
+    k::Int
+    L::Int
+    A::T
+    B::S
+    C::R
+    T0::Q
+    ec_s::P
+    ec_d::P
+    rc_e::P
+    rc_o::P
+end
+
+function MWT_CZ1d(k::Int = 3, α::Int = 5, L::Int = 0, c::Int = 1; base::Symbol = :legendre,
+                  init = Flux.glorot_uniform)
+    H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
+    H0r = zero_out!(H0 * Φ0)
+    G0r = zero_out!(G0 * Φ0)
+    H1r = zero_out!(H1 * Φ1)
+    G1r = zero_out!(G1 * Φ1)
+
+    dim = c * k
+    A = SpectralConv(dim => dim, (α,); init = init)
+    B = SpectralConv(dim => dim, (α,); init = init)
+    C = SpectralConv(dim => dim, (α,); init = init)
+    T0 = Dense(k, k)
+
+    ec_s = vcat(H0', H1')
+    ec_d = vcat(G0', G1')
+    rc_e = vcat(H0r, G0r)
+    rc_o = vcat(H1r, G1r)
+    return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
+end
+
+function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T}
+    k, c, N, B = size(X)
+    # stack even- and odd-indexed samples along the k-dimension
+    Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))  # (2k, c, N÷2, B)
+    # contract with the (2k, k) filter banks
+    d = reshape(transpose(l.ec_d) * reshape(Xa, 2 * k, :), k, c, N ÷ 2, B)
+    s = reshape(transpose(l.ec_s) * reshape(Xa, 2 * k, :), k, c, N ÷ 2, B)
+    return d, s
+end
+
+function even_odd(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T}
+    ich, c, N, B = size(X)
+    @assert ich == 2 * l.k
+    # project back with the (2k, k) reconstruction filters ...
+    X_e = reshape(transpose(l.rc_e) * reshape(X, ich, :), l.k, c, N, B)
+    X_o = reshape(transpose(l.rc_o) * reshape(X, ich, :), l.k, c, N, B)
+    # ... and interleave; TODO: the in-place interleaving is not Zygote-friendly yet
+    Y = similar(X, l.k, c, 2 * N, B)
+    view(Y, :, :, 1:2:(2 * N), :) .= X_e
+    view(Y, :, :, 2:2:(2 * N), :) .= X_o
+    return Y
+end
+
+function (l::MWT_CZ1d)(X::T) where {T <: AbstractArray}
+    N = size(X, 3)  # input layout is (k, c, N, batch)
+    ns = floor(Int, log2(N))
+    stop = ns - l.L
+
+    # decompose
+    Ud = T[]
+    Us = T[]
+    for i in 1:stop
+        d, X = wavelet_transform(l, X)  # X becomes the smooth part s
+        # TODO: A/B/C expect (c⋅k, N, batch); d and X still need a reshape here
+        push!(Ud, l.A(d) + l.B(X))
+        push!(Us, l.C(d))
+    end
+    X = l.T0(X)  # coarsest scale transform
+
+    # reconstruct
+    for i in stop:-1:1
+        X += Us[i]
+        X = vcat(X, Ud[i])  # x = torch.cat((x, Ud[i]), -1)
+        X = even_odd(l, X)
+    end
+    return X
+end
+
+# function Base.show(io::IO, l::MWT_CZ1d)
+#     print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)")
+# end
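+
+# Construction sketch (hypothetical usage; the full forward pass is still
+# being ported, see the TODO above):
+#
+#   mwt = MWT_CZ1d(3, 5, 0, 1; base = :legendre)  # k = 3, α = 5, L = 0, c = 1
+#   X = rand(Float32, 3, 1, 32, 32)               # (k, c, N, batch), N a power of 2
+#   d, s = wavelet_transform(mwt, X)              # detail/smooth parts, N halved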
 
 #########
 # utils #
 #########
diff --git a/test/Transform/Transform.jl b/test/Transform/Transform.jl
index d5ff9a67..abb5cac9 100644
--- a/test/Transform/Transform.jl
+++ b/test/Transform/Transform.jl
@@ -1,4 +1,6 @@
 @testset "Transform" begin
+    include("polynomials.jl")
     include("fourier_transform.jl")
     include("chebyshev_transform.jl")
+    include("wavelet_transform.jl")
 end
diff --git a/test/Transform/polynomials.jl b/test/Transform/polynomials.jl
new file mode 100644
index 00000000..2fb5a837
--- /dev/null
+++ b/test/Transform/polynomials.jl
@@ -0,0 +1,127 @@
+@testset "polynomials" begin
+    @testset "legendre_ϕ_ψ" begin
+        ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(10)
+
+        @test all(coeffs(ϕ[1]) .≈ [1.0])
+        @test all(coeffs(ϕ[2]) .≈ [-1.7320508075688772, 3.4641016151377544])
+        @test all(coeffs(ϕ[3]) .≈
+                  [2.23606797749979, -13.416407864998739, 13.416407864998739])
+        @test all(coeffs(ϕ[4]) .≈
+                  [-2.6457513110645907, 31.74901573277509, -79.37253933193772,
+                   52.91502622129181])
+        @test all(coeffs(ϕ[5]) .≈ [3.0, -60.0, 270.0, -420.0, 210.0])
+        @test all(coeffs(ϕ[6]) .≈
+                  [-3.3166247903554, 99.498743710662, -696.491205974634,
+                   1857.309882599024, -2089.4736179239017, 835.7894471695607])
+        @test all(coeffs(ϕ[7]) .≈
+                  [3.605551275463989, -151.43315356948753, 1514.3315356948754,
+                   -6057.326142779501, 11357.486517711566, -9994.588135586178,
+                   3331.529378528726])
+        @test all(coeffs(ϕ[8]) .≈
+                  [-3.872983346207417, 216.88706738761536, -2927.9754097328073,
+                   16266.530054071152, -44732.957648695665, 64415.45901412176,
+                   -46522.27595464349, 13292.078844183856])
+        @test all(coeffs(ϕ[9]) .≈
+                  [4.123105625617661, -296.86360504447157, 5195.113088278253,
+                   -38097.49598070719, 142865.60992765194, -297160.46864951606,
+                   346687.21342443535, -212257.47760679715, 53064.36940169929])
+        @test all(coeffs(ϕ[10]) .≈
+                  [-4.358898943540674, 392.30090491866065, -8630.619908210534,
+                   80552.45247663166, -392693.20582357934, 1099540.9763060221,
+                   -1832568.2938433702, 1795168.9409077913, -953683.4998572641,
+                   211929.66663494756])
+
+        ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(3)
+        @test coeffs(ϕ[1]) ≈ [1.0]
+        @test coeffs(ϕ[2]) ≈ [-1.7320508075688772, 3.4641016151377544]
+        @test coeffs(ϕ[3]) ≈ [2.23606797749979, -13.416407864998739, 13.416407864998739]
+        @test coeffs(ψ1[1]) ≈ [-1.0000000000000122, 6.000000000000073]
+        @test coeffs(ψ1[2]) ≈ [1.7320508075691732, -24.248711305967735, 51.96152422707261]
+        @test coeffs(ψ1[3]) ≈ [2.2360679774995615, -26.832815729994504, 53.665631459989214]
+        @test coeffs(ψ2[1]) ≈ [-5.000000000000066, 6.000000000000073]
+        @test coeffs(ψ2[2]) ≈ [29.44486372867492, -79.67433714817852, 51.96152422707261]
+        @test coeffs(ψ2[3]) ≈ [-29.068883707507286, 80.49844719001908, -53.665631460012115]
+    end
+
+    @testset "chebyshev_ϕ_ψ" begin
+        ϕ, ψ1, ψ2 = NeuralOperators.chebyshev_ϕ_ψ(3)
+        @test ϕ[1](0) ≈ 0.7978845608028654
+        @test ϕ[1](1) ≈ 0.7978845608028654
+        @test ϕ[1](2) ≈ 0.0
+        @test ϕ[2](0) ≈ -1.1283791670955126
+        @test ϕ[2](1) ≈ 1.1283791670955126
+        @test ϕ[2](2) ≈ 0.0
+        @test ϕ[3](0) ≈ 1.1283791670955126
+        @test ϕ[3](1) ≈ 1.1283791670955126
+        @test ϕ[3](2) ≈ 0.0
+
+        @test ψ1[1](0) ≈ -0.5560622352843183
+        @test ψ1[1](1) ≈ 0.0
+        @test ψ1[1](2) ≈ 0.0
+        @test ψ1[2](0) ≈ 0.932609257876051
+        @test ψ1[2](1) ≈ 0.0
+        @test ψ1[2](2) ≈ 0.0
+        @test ψ1[3](0) ≈ 1.0941547380212637
+        @test ψ1[3](1) ≈ 0.0
+        @test ψ1[3](2) ≈ 0.0
+
+        @test ψ2[1](0) ≈ -0.0
+        @test ψ2[1](1) ≈ 0.5560622352843181
+        @test ψ2[1](2) ≈ 0.0
+        @test ψ2[2](0) ≈ 0.0
+        @test ψ2[2](1) ≈ 0.9326092578760665
+        @test ψ2[2](2) ≈ 0.0
+        @test ψ2[3](0) ≈ 0.0
+        @test ψ2[3](1) ≈ -1.0941547380212384
+        @test ψ2[3](2) ≈ 0.0
+    end
+
+    @testset "legendre_filter" begin
+        H0, H1, G0, G1, Φ0, Φ1 = NeuralOperators.legendre_filter(3)
+
+        @test H0 ≈ [0.70710678 0.0 0.0;
+                    -0.61237244 0.35355339 0.0;
+                    0.0 -0.6846532 0.1767767]
+        @test H1 ≈ [0.70710678 0.0 0.0;
+                    0.61237244 0.35355339 0.0;
+                    0.0 0.6846532 0.1767767]
+        @test G0 ≈ [0.35355339 0.61237244 0.0;
+                    0.0 0.1767767 0.6846532;
+                    0.0 0.0 0.70710678]
+        @test G1 ≈ [-0.35355339 0.61237244 0.0;
+                    0.0 -0.1767767 0.6846532;
+                    0.0 0.0 -0.70710678]
+        @test Φ0 == I(3)
+        @test Φ1 == I(3)
+    end
+
+    @testset "chebyshev_filter" begin
+        H0, H1, G0, G1, Φ0, Φ1 = NeuralOperators.chebyshev_filter(3)
+
+        @test H0 ≈ [0.70710678 0.0 0.0;
+                    -0.5 0.35355339 0.0;
+                    -0.25 -0.70710678 0.1767767]
+        @test H1 ≈ [0.70710678 0.0 0.0;
+                    0.5 0.35355339 0.0;
+                    -0.25 0.70710678 0.1767767]
+        @test G0 ≈ [0.60944614 0.77940383 0.0;
+                    0.66325172 1.02726613 1.14270252;
+                    0.61723435 0.90708619 1.1562954]
+        @test G1 ≈ [-0.60944614 0.77940383 0.0;
+                    0.66325172 -1.02726613 1.14270252;
+                    -0.61723435 0.90708619 -1.1562954]
+        @test Φ0 ≈ [1.0 -0.40715364 -0.21440101;
+                    -0.40715364 0.84839559 -0.44820615;
+                    -0.21440101 -0.44820615 0.84002127]
+        @test Φ1 ≈ [1.0 0.40715364 -0.21440101;
+                    0.40715364 0.84839559 0.44820615;
+                    -0.21440101 0.44820615 0.84002127]
+    end
+end
diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl
new file mode 100644
index 00000000..16211220
--- /dev/null
+++ b/test/Transform/wavelet_transform.jl
@@ -0,0 +1,29 @@
+@testset "wavelet transform" begin
+    𝐱 = rand(30, 40, 50, 6, 7)  # where ch == 6 and batch == 7
+
+    # TODO: `WaveletTransform` is still a stub; enable this once the filter
+    # banks, `inverse` and `truncate_modes` are implemented.
+    @test_skip begin
+        wt = WaveletTransform((3, 4, 5))
+        size(transform(wt, 𝐱)) == (30, 40, 50, 6, 7) &&
+            size(truncate_modes(wt, transform(wt, 𝐱))) == (3, 4, 5, 6, 7) &&
+            size(inverse(wt, truncate_modes(wt, transform(wt, 𝐱)))) == (3, 4, 5, 6, 7)
+    end
+end
+
+@testset "MWT_CZ" begin
+    T = Float32
+    k = 3
+    batch_size = 32
+
+    @testset "MWT_CZ1d" begin
+        c = 1
+        N = 32
+        mwt = MWT_CZ1d(k)
+        X = rand(T, k, c, N, batch_size)
+
+        # base functions
+        d, s = NeuralOperators.wavelet_transform(mwt, X)
+        @test size(d) == size(s) == (k, c, N ÷ 2, batch_size)
+        @test size(NeuralOperators.even_odd(mwt, vcat(d, s))) == (k, c, N, batch_size)
+
+        # forward and backward are still TODO, see src/operator_kernel.jl
+        @test_skip size(mwt(X)) == size(X)
+        @test_skip gradient(() -> sum(mwt(X)), Flux.params(mwt)) !== nothing
+    end
+end
diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl
index 2d00b4ff..e210a478 100644
--- a/test/operator_kernel.jl
+++ b/test/operator_kernel.jl
@@ -160,3 +160,57 @@ end
     data = [(𝐱, rand(Float32, 128, 1024, 5))]
     Flux.train!(loss, Flux.params(m), data, Flux.Adam())
 end
+
+@testset "SparseKernel" begin
+    T = Float32
+    k = 3
+    batch_size = 32
+
+    @testset "1D SparseKernel" begin
+        α = 4
+        c = 1
+        in_chs = 20
+        X = rand(T, in_chs, c * k, batch_size)
+
+        l1 = SparseKernel1D(k, α, c)
+        Y = l1(X)
+        @test l1 isa SparseKernel{1}
+        @test size(Y) == size(X)
+
+        gs = gradient(() -> sum(l1(X)), Flux.params(l1))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "2D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        X = rand(T, Nx, Ny, c * k^2, batch_size)
+
+        l2 = SparseKernel2D(k, α, c)
+        Y = l2(X)
+        @test l2 isa SparseKernel{2}
+        @test size(Y) == size(X)
+
+        gs = gradient(() -> sum(l2(X)), Flux.params(l2))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "3D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        Nz = 13
+        X = rand(T, Nx, Ny, Nz, α * k^2, batch_size)
+
+        l3 = SparseKernel3D(k, α, c)
+        Y = l3(X)
+        @test l3 isa SparseKernel{3}
+        @test size(Y) == (Nx, Ny, Nz, c * k^2, batch_size)
+
+        gs = gradient(() -> sum(l3(X)), Flux.params(l3))
+        @test length(gs.grads) == 4
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6ffe561a..e2d65a17 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,6 +3,8 @@ using CUDA
 using Flux
 using GeometricFlux
 using Graphs
+using LinearAlgebra
+using Polynomials
 using Zygote
 using Test