diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index 93ee257c..480af298 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -16,47 +16,47 @@ function legendre_ϕ_ψ(k) p = Polynomial([-1, 2]) # 2x-1 p2 = Polynomial([-1, 4]) # 4x-1 - for ki in 0:(k-1) + for ki in 0:(k - 1) l = convert(Polynomial, gen_poly(Legendre, ki)) # Legendre of n=ki - ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2*ki+1) .* coeffs(l(p)) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2*(2*ki+1)) .* coeffs(l(p2)) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * ki + 1) .* coeffs(l(p)) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * (2 * ki + 1)) .* coeffs(l(p2)) end - + ψ1_coefs = zeros(k, k) ψ2_coefs = zeros(k, k) - for ki in 0:(k-1) - ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] - for i in 0:(k-1) - a = ϕ_2x_coefs[ki+1, 1:(ki+1)] - b = ϕ_coefs[i+1, 1:(i+1)] + for ki in 0:(k - 1) + ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :] + for i in 0:(k - 1) + a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)] + b = ϕ_coefs[i + 1, 1:(i + 1)] proj_ = proj_factor(a, b) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) end - for j in 0:(k-1) - a = ϕ_2x_coefs[ki+1, 1:(ki+1)] - b = ψ1_coefs[j+1, :] + for j in 0:(k - 1) + a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)] + b = ψ1_coefs[j + 1, :] proj_ = proj_factor(a, b) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :) end - a = ψ1_coefs[ki+1, :] + a = ψ1_coefs[ki + 1, :] norm1 = proj_factor(a, a) - a = ψ2_coefs[ki+1, :] - norm2 = proj_factor(a, a, complement=true) + a = ψ2_coefs[ki + 1, :] + norm2 = proj_factor(a, a, complement = true) norm_ = sqrt(norm1 + norm2) - ψ1_coefs[ki+1, :] ./= norm_ - ψ2_coefs[ki+1, :] ./= norm_ + ψ1_coefs[ki + 1, :] ./= norm_ + ψ2_coefs[ki + 1, :] ./= norm_ zero_out!(ψ1_coefs) zero_out!(ψ2_coefs) end - ϕ = [Polynomial(ϕ_coefs[i,:]) for i in 1:k] - ψ1 = [Polynomial(ψ1_coefs[i,:]) for i in 1:k] - ψ2 = [Polynomial(ψ2_coefs[i,:]) for i in 1:k] + ϕ = [Polynomial(ϕ_coefs[i, :]) for i in 1:k] + ψ1 = [Polynomial(ψ1_coefs[i, :]) for i in 1:k] + ψ2 = [Polynomial(ψ2_coefs[i, :]) for i in 1:k] return ϕ, ψ1, ψ2 end @@ -68,14 +68,14 @@ function chebyshev_ϕ_ψ(k) p = Polynomial([-1, 2]) # 2x-1 p2 = Polynomial([-1, 4]) # 4x-1 - for ki in 0:(k-1) + for ki in 0:(k - 1) if ki == 0 - ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 / π) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(4 / π) else c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki - ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= 2 / sqrt(π) .* coeffs(c(p)) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2) * 2 / sqrt(π) .* coeffs(c(p2)) end end @@ -87,43 +87,43 @@ function chebyshev_ϕ_ψ(k) # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) # not needed for our purpose here, we use even k always to avoid wm = π / k_use / 2 - + ψ1_coefs = zeros(k, k) ψ2_coefs = zeros(k, k) - ψ1 = Array{Any,1}(undef, k) - ψ2 = Array{Any,1}(undef, k) + ψ1 = Array{Any, 1}(undef, k) + ψ2 = Array{Any, 1}(undef, k) - for ki in 0:(k-1) - ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] - for i in 0:(k-1) - proj_ = sum(wm .* ϕ[i+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + for ki in 0:(k - 1) + ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :] + for i in 0:(k - 1) + proj_ = sum(wm .* ϕ[i + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m)) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) end - for j in 0:(ki-1) - proj_ = sum(wm .* ψ1[j+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) + for j in 0:(ki - 1) + proj_ = sum(wm .* ψ1[j + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m)) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :) end - ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5) - ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5, ub=1.) - - norm1 = sum(wm .* ψ1[ki+1].(x_m) .* ψ1[ki+1].(x_m)) - norm2 = sum(wm .* ψ2[ki+1].(x_m) .* ψ2[ki+1].(x_m)) - + ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5) + ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5, ub = 1.0) + + norm1 = sum(wm .* ψ1[ki + 1].(x_m) .* ψ1[ki + 1].(x_m)) + norm2 = sum(wm .* ψ2[ki + 1].(x_m) .* ψ2[ki + 1].(x_m)) + norm_ = sqrt(norm1 + norm2) - ψ1_coefs[ki+1, :] ./= norm_ - ψ2_coefs[ki+1, :] ./= norm_ + ψ1_coefs[ki + 1, :] ./= norm_ + ψ2_coefs[ki + 1, :] ./= norm_ zero_out!(ψ1_coefs) zero_out!(ψ2_coefs) - - ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5+1e-16) - ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5+1e-16, ub=1.) + + ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5 + 1e-16) + ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5 + 1e-16, ub = 1.0) end - + return ϕ, ψ1, ψ2 end @@ -137,14 +137,18 @@ function legendre_filter(k) l = convert(Polynomial, gen_poly(Legendre, k)) x_m = roots(l(Polynomial([-1, 2]))) # 2x-1 m = 2 .* x_m .- 1 - wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k-1).(m) - - for ki in 0:(k-1) - for kpi in 0:(k-1) - H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) - G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) - H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k - 1).(m) + + for ki in 0:(k - 1) + for kpi in 0:(k - 1) + H0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m)) + G0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m)) + H1[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m)) + G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* + ϕ[kpi + 1].(x_m)) end end @@ -152,7 +156,7 @@ function legendre_filter(k) zero_out!(H1) zero_out!(G0) zero_out!(G1) - + return H0, H1, G0, G1, I(k), I(k) end @@ -172,14 +176,19 @@ function chebyshev_filter(k) # not needed for our purpose here, we use even k always to avoid wm = π / k_use / 2 - for ki in 0:(k-1) - for kpi in 0:(k-1) - H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) - H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) - G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - Φ0[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2x_m) .* ϕ[kpi+1].(2x_m)) - Φ1[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2 .* x_m .- 1) .* ϕ[kpi+1].(2 .* x_m .- 1)) + for ki in 0:(k - 1) + for kpi in 0:(k - 1) + H0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m)) + H1[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m)) + G0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m)) + G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* + ϕ[kpi + 1].(x_m)) + Φ0[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2x_m) .* ϕ[kpi + 1].(2x_m)) + Φ1[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2 .* x_m .- 1) .* + ϕ[kpi + 1].(2 .* x_m .- 1)) end end @@ -189,6 +198,6 @@ function chebyshev_filter(k) zero_out!(G1) zero_out!(Φ0) zero_out!(Φ1) - + return H0, H1, G0, G1, Φ0, Φ1 end diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl index 994fa75d..a76c13b6 100644 --- a/src/Transform/utils.jl +++ b/src/Transform/utils.jl @@ -1,6 +1,6 @@ -function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) +function ϕ_(ϕ_coefs; lb::Real = 0.0, ub::Real = 1.0) function partial(x) - mask = (lb ≤ x ≤ ub) * 1. + mask = (lb ≤ x ≤ ub) * 1.0 return Polynomial(ϕ_coefs)(x) * mask end return partial @@ -8,27 +8,27 @@ end function ψ(ψ1, ψ2, i, inp) mask = (inp .> 0.5) .* 1.0 - return ψ1[i+1].(inp) .* mask .+ ψ2[i+1].(inp) .* mask + return ψ1[i + 1].(inp) .* mask .+ ψ2[i + 1].(inp) .* mask end -zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0) +zero_out!(x; tol = 1e-8) = (x[abs.(x) .< tol] .= 0) function gen_poly(poly, n) - x = zeros(n+1) + x = zeros(n + 1) x[end] = 1 return poly(x) end function convolve(a, b) n = length(b) - y = similar(a, length(a)+n-1) + y = similar(a, length(a) + n - 1) for i in 1:length(a) - y[i:(i+n-1)] .+= a[i] .* b + y[i:(i + n - 1)] .+= a[i] .* b end return y end -function proj_factor(a, b; complement::Bool=false) +function proj_factor(a, b; complement::Bool = false) prod_ = convolve(a, b) r = collect(1:length(prod_)) s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r) @@ -36,11 +36,11 @@ function proj_factor(a, b; complement::Bool=false) return proj_ end -_legendre(k, x) = (2k+1) * gen_poly(Legendre, k)(x) +_legendre(k, x) = (2k + 1) * gen_poly(Legendre, k)(x) function legendre_der(k, x) out = 0 - for i in k-1:-2:-1 + for i in (k - 1):-2:-1 out += _legendre(i, x) end return out diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 41d54edb..d58970bb 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -1,15 +1,15 @@ export WaveletTransform -struct WaveletTransform{N, S}<:AbstractTransform - ec_d - ec_s +struct WaveletTransform{N, S} <: AbstractTransform + ec_d::Any + ec_s::Any modes::NTuple{N, S} # N == ndims(x) end Base.ndims(::WaveletTransform{N}) where {N} = N function transform(wt::WaveletTransform, 𝐱::AbstractArray) - N = size(X, ndims(wt)-1) + N = size(X, ndims(wt) - 1) # 1d Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :)) # 2d @@ -24,9 +24,7 @@ function transform(wt::WaveletTransform, 𝐱::AbstractArray) return d, s end -function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) - -end +function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) end # function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray) # return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch] diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index 82dcb6a2..fa261930 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -1,12 +1,12 @@ export - OperatorConv, - SpectralConv, - OperatorKernel, - SparseKernel, - SparseKernel1D, - SparseKernel2D, - SparseKernel3D, - MWT_CZ1d + OperatorConv, + SpectralConv, + OperatorKernel, + SparseKernel, + SparseKernel1D, + SparseKernel2D, + SparseKernel3D, + MWT_CZ1d struct OperatorConv{P, T, S, TT} weight::T @@ -196,36 +196,37 @@ Sparse kernel layer. * `ch`: Channel size for linear transform, e.g. `32`. * `σ`: Activation function. """ -struct SparseKernel{N,T,S} +struct SparseKernel{N, T, S} conv_blk::T out_weight::S end -function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} +function SparseKernel(filter::NTuple{N, T}, ch::Pair{S, S}; + init = Flux.glorot_uniform) where {N, T, S} input_dim, emb_dim = ch - conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) + conv = Conv(filter, input_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{N, typeof(conv), typeof(W_out)}(conv, W_out) end -function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k +function SparseKernel1D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k emb_dim = 128 - return SparseKernel((3, ), input_dim=>emb_dim; init=init) + return SparseKernel((3,), input_dim => emb_dim; init = init) end -function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - return SparseKernel((3, 3), input_dim=>emb_dim; init=init) +function SparseKernel2D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + return SparseKernel((3, 3), input_dim => emb_dim; init = init) end -function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) +function SparseKernel3D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + conv = Conv((3, 3, 3), emb_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{3, typeof(conv), typeof(W_out)}(conv, W_out) end Flux.@functor SparseKernel @@ -241,8 +242,7 @@ function (l::SparseKernel)(X::AbstractArray) return collect(Y) end - -struct MWT_CZ1d{T,S,R,Q,P} +struct MWT_CZ1d{T, S, R, Q, P} k::Int L::Int A::T @@ -255,17 +255,18 @@ struct MWT_CZ1d{T,S,R,Q,P} rc_o::P end -function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) +function MWT_CZ1d(k::Int = 3, α::Int = 5, L::Int = 0, c::Int = 1; base::Symbol = :legendre, + init = Flux.glorot_uniform) H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) H0r = zero_out!(H0 * Φ0) G0r = zero_out!(G0 * Φ0) H1r = zero_out!(H1 * Φ1) G1r = zero_out!(G1 * Φ1) - dim = c*k - A = SpectralConv(dim=>dim, (α,); init=init) - B = SpectralConv(dim=>dim, (α,); init=init) - C = SpectralConv(dim=>dim, (α,); init=init) + dim = c * k + A = SpectralConv(dim => dim, (α,); init = init) + B = SpectralConv(dim => dim, (α,); init = init) + C = SpectralConv(dim => dim, (α,); init = init) T0 = Dense(k, k) ec_s = vcat(H0', H1') @@ -275,7 +276,7 @@ function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendr return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) end -function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} N = size(X, 3) Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) d = NNlib.batched_mul(Xa, l.ec_d) @@ -283,17 +284,17 @@ function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} return d, s end -function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} +function even_odd(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} bch_sz, N, dims_r... = reverse(size(X)) dims = reverse(dims_r) - @assert dims[1] == 2*l.k + @assert dims[1] == 2 * l.k Y = similar(X, bch_sz, 2N, l.c, l.k) view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e) view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o) return Y end -function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} +function (l::MWT_CZ1d)(X::T) where {T <: AbstractArray} bch_sz, N, dims_r... = reverse(size(X)) ns = floor(log2(N)) stop = ns - l.L @@ -303,7 +304,7 @@ function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} Us = T[] for i in 1:stop d, X = wavelet_transform(l, X) - push!(Ud, l.A(d)+l.B(d)) + push!(Ud, l.A(d) + l.B(d)) push!(Us, l.C(d)) end X = l.T0(X) @@ -321,7 +322,6 @@ end # print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)") # end - ######### # utils # ######### diff --git a/test/Transform/polynomials.jl b/test/Transform/polynomials.jl index fdb53df7..2fb5a837 100644 --- a/test/Transform/polynomials.jl +++ b/test/Transform/polynomials.jl @@ -2,26 +2,44 @@ @testset "legendre_ϕ_ψ" begin ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(10) - @test all(coeffs(ϕ[1]) .≈ [1.]) + @test all(coeffs(ϕ[1]) .≈ [1.0]) @test all(coeffs(ϕ[2]) .≈ [-1.7320508075688772, 3.4641016151377544]) - @test all(coeffs(ϕ[3]) .≈ [2.23606797749979, -13.416407864998739, 13.416407864998739]) - @test all(coeffs(ϕ[4]) .≈ [-2.6457513110645907, 31.74901573277509, -79.37253933193772, 52.91502622129181]) + @test all(coeffs(ϕ[3]) .≈ + [2.23606797749979, -13.416407864998739, 13.416407864998739]) + @test all(coeffs(ϕ[4]) .≈ [ + -2.6457513110645907, + 31.74901573277509, + -79.37253933193772, + 52.91502622129181, + ]) @test all(coeffs(ϕ[5]) .≈ [3.0, -60.0, 270.0, -420.0, 210.0]) - @test all(coeffs(ϕ[6]) .≈ [-3.3166247903554, 99.498743710662, -696.491205974634, 1857.309882599024, - -2089.4736179239017, 835.7894471695607]) - @test all(coeffs(ϕ[7]) .≈ [3.605551275463989, -151.43315356948753, 1514.3315356948754, -6057.326142779501, - 11357.486517711566, -9994.588135586178, 3331.529378528726]) - @test all(coeffs(ϕ[8]) .≈ [-3.872983346207417, 216.88706738761536, -2927.9754097328073, 16266.530054071152, - -44732.957648695665, 64415.45901412176, -46522.27595464349, 13292.078844183856]) - @test all(coeffs(ϕ[9]) .≈ [4.123105625617661, -296.86360504447157, 5195.113088278253, -38097.49598070719, - 142865.60992765194, -297160.46864951606, 346687.21342443535, -212257.47760679715, - 53064.36940169929]) - @test all(coeffs(ϕ[10]) .≈ [-4.358898943540674, 392.30090491866065, -8630.619908210534, 80552.45247663166, - -392693.20582357934, 1099540.9763060221, -1832568.2938433702, 1795168.9409077913, - -953683.4998572641, 211929.66663494756]) - + @test all(coeffs(ϕ[6]) .≈ + [-3.3166247903554, 99.498743710662, -696.491205974634, 1857.309882599024, + -2089.4736179239017, 835.7894471695607]) + @test all(coeffs(ϕ[7]) .≈ + [3.605551275463989, -151.43315356948753, 1514.3315356948754, + -6057.326142779501, + 11357.486517711566, -9994.588135586178, 3331.529378528726]) + @test all(coeffs(ϕ[8]) .≈ + [-3.872983346207417, 216.88706738761536, -2927.9754097328073, + 16266.530054071152, + -44732.957648695665, 64415.45901412176, -46522.27595464349, + 13292.078844183856]) + @test all(coeffs(ϕ[9]) .≈ + [4.123105625617661, -296.86360504447157, 5195.113088278253, + -38097.49598070719, + 142865.60992765194, -297160.46864951606, 346687.21342443535, + -212257.47760679715, + 53064.36940169929]) + @test all(coeffs(ϕ[10]) .≈ + [-4.358898943540674, 392.30090491866065, -8630.619908210534, + 80552.45247663166, + -392693.20582357934, 1099540.9763060221, -1832568.2938433702, + 1795168.9409077913, + -953683.4998572641, 211929.66663494756]) + ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(3) - @test coeffs(ϕ[1]) ≈ [1.] + @test coeffs(ϕ[1]) ≈ [1.0] @test coeffs(ϕ[2]) ≈ [-1.7320508075688772, 3.4641016151377544] @test coeffs(ϕ[3]) ≈ [2.23606797749979, -13.416407864998739, 13.416407864998739] @test coeffs(ψ1[1]) ≈ [-1.0000000000000122, 6.000000000000073] @@ -36,50 +54,50 @@ ϕ, ψ1, ψ2 = NeuralOperators.chebyshev_ϕ_ψ(3) @test ϕ[1](0) ≈ 0.7978845608028654 @test ϕ[1](1) ≈ 0.7978845608028654 - @test ϕ[1](2) ≈ 0. + @test ϕ[1](2) ≈ 0.0 @test ϕ[2](0) ≈ -1.1283791670955126 @test ϕ[2](1) ≈ 1.1283791670955126 - @test ϕ[2](2) ≈ 0. + @test ϕ[2](2) ≈ 0.0 @test ϕ[3](0) ≈ 1.1283791670955126 @test ϕ[3](1) ≈ 1.1283791670955126 - @test ϕ[3](2) ≈ 0. + @test ϕ[3](2) ≈ 0.0 @test ψ1[1](0) ≈ -0.5560622352843183 - @test ψ1[1](1) ≈ 0. - @test ψ1[1](2) ≈ 0. + @test ψ1[1](1) ≈ 0.0 + @test ψ1[1](2) ≈ 0.0 @test ψ1[2](0) ≈ 0.932609257876051 - @test ψ1[2](1) ≈ 0. - @test ψ1[2](2) ≈ 0. + @test ψ1[2](1) ≈ 0.0 + @test ψ1[2](2) ≈ 0.0 @test ψ1[3](0) ≈ 1.0941547380212637 - @test ψ1[3](1) ≈ 0. - @test ψ1[3](2) ≈ 0. + @test ψ1[3](1) ≈ 0.0 + @test ψ1[3](2) ≈ 0.0 - @test ψ2[1](0) ≈ -0. + @test ψ2[1](0) ≈ -0.0 @test ψ2[1](1) ≈ 0.5560622352843181 - @test ψ2[1](2) ≈ 0. - @test ψ2[2](0) ≈ 0. + @test ψ2[1](2) ≈ 0.0 + @test ψ2[2](0) ≈ 0.0 @test ψ2[2](1) ≈ 0.9326092578760665 - @test ψ2[2](2) ≈ 0. - @test ψ2[3](0) ≈ 0. + @test ψ2[2](2) ≈ 0.0 + @test ψ2[3](0) ≈ 0.0 @test ψ2[3](1) ≈ -1.0941547380212384 - @test ψ2[3](2) ≈ 0. + @test ψ2[3](2) ≈ 0.0 end @testset "legendre_filter" begin H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.legendre_filter(3) - @test H0 ≈ [0.70710678 0. 0. ; - -0.61237244 0.35355339 0. ; - 0. -0.6846532 0.1767767] - @test H1 ≈ [0.70710678 0. 0. ; - 0.61237244 0.35355339 0. ; - 0. 0.6846532 0.1767767] - @test G0 ≈ [0.35355339 0.61237244 0. ; - 0. 0.1767767 0.6846532 ; - 0. 0. 0.70710678] - @test G1 ≈ [-0.35355339 0.61237244 0. ; - 0. -0.1767767 0.6846532 ; - 0. 0. -0.70710678] + @test H0 ≈ [0.70710678 0.0 0.0; + -0.61237244 0.35355339 0.0; + 0.0 -0.6846532 0.1767767] + @test H1 ≈ [0.70710678 0.0 0.0; + 0.61237244 0.35355339 0.0; + 0.0 0.6846532 0.1767767] + @test G0 ≈ [0.35355339 0.61237244 0.0; + 0.0 0.1767767 0.6846532; + 0.0 0.0 0.70710678] + @test G1 ≈ [-0.35355339 0.61237244 0.0; + 0.0 -0.1767767 0.6846532; + 0.0 0.0 -0.70710678] @test Φ1 == I(3) @test Φ2 == I(3) end @@ -87,23 +105,23 @@ @testset "chebyshev_filter" begin H0, H1, G0, G1, Φ0, Φ1 = NeuralOperators.chebyshev_filter(3) - @test H0 ≈ [0.70710678 0. 0. ; - -0.5 0.35355339 0. ; - -0.25 -0.70710678 0.1767767] - @test H1 ≈ [0.70710678 0. 0. ; - 0.5 0.35355339 0. ; - -0.25 0.70710678 0.1767767] - @test G0 ≈ [0.60944614 0.77940383 0. ; - 0.66325172 1.02726613 1.14270252; - 0.61723435 0.90708619 1.1562954 ] - @test G1 ≈ [-0.60944614 0.77940383 0. ; - 0.66325172 -1.02726613 1.14270252; - -0.61723435 0.90708619 -1.1562954 ] - @test Φ0 ≈ [1. -0.40715364 -0.21440101; - -0.40715364 0.84839559 -0.44820615; - -0.21440101 -0.44820615 0.84002127] - @test Φ1 ≈ [1. 0.40715364 -0.21440101; - 0.40715364 0.84839559 0.44820615; - -0.21440101 0.44820615 0.84002127] + @test H0 ≈ [0.70710678 0.0 0.0; + -0.5 0.35355339 0.0; + -0.25 -0.70710678 0.1767767] + @test H1 ≈ [0.70710678 0.0 0.0; + 0.5 0.35355339 0.0; + -0.25 0.70710678 0.1767767] + @test G0 ≈ [0.60944614 0.77940383 0.0; + 0.66325172 1.02726613 1.14270252; + 0.61723435 0.90708619 1.1562954] + @test G1 ≈ [-0.60944614 0.77940383 0.0; + 0.66325172 -1.02726613 1.14270252; + -0.61723435 0.90708619 -1.1562954] + @test Φ0 ≈ [1.0 -0.40715364 -0.21440101; + -0.40715364 0.84839559 -0.44820615; + -0.21440101 -0.44820615 0.84002127] + @test Φ1 ≈ [1.0 0.40715364 -0.21440101; + 0.40715364 0.84839559 0.44820615; + -0.21440101 0.44820615 0.84002127] end end diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl index 48705bf3..16211220 100644 --- a/test/Transform/wavelet_transform.jl +++ b/test/Transform/wavelet_transform.jl @@ -17,8 +17,8 @@ end mwt = MWT_CZ1d() # base functions - wavelet_transform(mwt, ) - even_odd(mwt, ) + wavelet_transform(mwt) + even_odd(mwt) # forward Y = mwt(X) @@ -26,5 +26,4 @@ end # backward g = gradient() end - end diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index be20b01b..fb3fee3f 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -152,14 +152,14 @@ end α = 4 c = 1 in_chs = 20 - X = rand(T, in_chs, c*k, batch_size) + X = rand(T, in_chs, c * k, batch_size) l1 = SparseKernel1D(k, α, c) Y = l1(X) @test l1 isa SparseKernel{1} @test size(Y) == size(X) - gs = gradient(()->sum(l1(X)), Flux.params(l1)) + gs = gradient(() -> sum(l1(X)), Flux.params(l1)) @test length(gs.grads) == 4 end @@ -168,14 +168,14 @@ end c = 3 Nx = 5 Ny = 7 - X = rand(T, Nx, Ny, c*k^2, batch_size) - + X = rand(T, Nx, Ny, c * k^2, batch_size) + l2 = SparseKernel2D(k, α, c) Y = l2(X) @test l2 isa SparseKernel{2} @test size(Y) == size(X) - gs = gradient(()->sum(l2(X)), Flux.params(l2)) + gs = gradient(() -> sum(l2(X)), Flux.params(l2)) @test length(gs.grads) == 4 end @@ -185,14 +185,14 @@ end Nx = 5 Ny = 7 Nz = 13 - X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) + X = rand(T, Nx, Ny, Nz, α * k^2, batch_size) l3 = SparseKernel3D(k, α, c) Y = l3(X) @test l3 isa SparseKernel{3} - @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + @test size(Y) == (Nx, Ny, Nz, c * k^2, batch_size) - gs = gradient(()->sum(l3(X)), Flux.params(l3)) + gs = gradient(() -> sum(l3(X)), Flux.params(l3)) @test length(gs.grads) == 4 end end