Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Implement multiwavelet operators #30

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
3 changes: 3 additions & 0 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ using Zygote
using ChainRulesCore
using GeometricFlux
using Statistics
using Polynomials
using SpecialPolynomials
using LinearAlgebra

include("abstracttypes.jl")

Expand Down
3 changes: 3 additions & 0 deletions src/Transform/Transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@ export
"""
abstract type AbstractTransform end

include("utils.jl")
include("polynomials.jl")
include("fourier_transform.jl")
include("chebyshev_transform.jl")
include("wavelet_transform.jl")
203 changes: 203 additions & 0 deletions src/Transform/polynomials.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
function get_filter(base::Symbol, k)
if base == :legendre
return legendre_filter(k)
elseif base == :chebyshev
return chebyshev_filter(k)
else
throw(ArgumentError("base must be one of :legendre or :chebyshev."))
end
end

function legendre_ϕ_ψ(k)
# TODO: row-major -> column major
ϕ_coefs = zeros(k, k)
ϕ_2x_coefs = zeros(k, k)

p = Polynomial([-1, 2]) # 2x-1
p2 = Polynomial([-1, 4]) # 4x-1

for ki in 0:(k - 1)
l = convert(Polynomial, gen_poly(Legendre, ki)) # Legendre of n=ki
ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * ki + 1) .* coeffs(l(p))
ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * (2 * ki + 1)) .* coeffs(l(p2))
end

ψ1_coefs = zeros(k, k)
ψ2_coefs = zeros(k, k)
for ki in 0:(k - 1)
ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :]
for i in 0:(k - 1)
a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)]
b = ϕ_coefs[i + 1, 1:(i + 1)]
proj_ = proj_factor(a, b)
ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
end

for j in 0:(k - 1)
a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)]
b = ψ1_coefs[j + 1, :]
proj_ = proj_factor(a, b)
ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :)
ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :)
end

a = ψ1_coefs[ki + 1, :]
norm1 = proj_factor(a, a)

a = ψ2_coefs[ki + 1, :]
norm2 = proj_factor(a, a, complement = true)
norm_ = sqrt(norm1 + norm2)
ψ1_coefs[ki + 1, :] ./= norm_
ψ2_coefs[ki + 1, :] ./= norm_
zero_out!(ψ1_coefs)
zero_out!(ψ2_coefs)
end

ϕ = [Polynomial(ϕ_coefs[i, :]) for i in 1:k]
ψ1 = [Polynomial(ψ1_coefs[i, :]) for i in 1:k]
ψ2 = [Polynomial(ψ2_coefs[i, :]) for i in 1:k]

return ϕ, ψ1, ψ2
end

function chebyshev_ϕ_ψ(k)
ϕ_coefs = zeros(k, k)
ϕ_2x_coefs = zeros(k, k)

p = Polynomial([-1, 2]) # 2x-1
p2 = Polynomial([-1, 4]) # 4x-1

for ki in 0:(k - 1)
if ki == 0
ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 / π)
ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(4 / π)
else
c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki
ϕ_coefs[ki + 1, 1:(ki + 1)] .= 2 / sqrt(π) .* coeffs(c(p))
ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2) * 2 / sqrt(π) .* coeffs(c(p2))
end
end

ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k]

k_use = 2k
c = convert(Polynomial, gen_poly(Chebyshev, k_use))
x_m = roots(c(p))
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
# not needed for our purpose here, we use even k always to avoid
wm = π / k_use / 2

ψ1_coefs = zeros(k, k)
ψ2_coefs = zeros(k, k)

ψ1 = Array{Any, 1}(undef, k)
ψ2 = Array{Any, 1}(undef, k)

for ki in 0:(k - 1)
ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :]
for i in 0:(k - 1)
proj_ = sum(wm .* ϕ[i + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m))
ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :)
end

for j in 0:(ki - 1)
proj_ = sum(wm .* ψ1[j + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m))
ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :)
ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :)
end

ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5)
ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5, ub = 1.0)

norm1 = sum(wm .* ψ1[ki + 1].(x_m) .* ψ1[ki + 1].(x_m))
norm2 = sum(wm .* ψ2[ki + 1].(x_m) .* ψ2[ki + 1].(x_m))

norm_ = sqrt(norm1 + norm2)
ψ1_coefs[ki + 1, :] ./= norm_
ψ2_coefs[ki + 1, :] ./= norm_
zero_out!(ψ1_coefs)
zero_out!(ψ2_coefs)

ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5 + 1e-16)
ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5 + 1e-16, ub = 1.0)
end

return ϕ, ψ1, ψ2
end

function legendre_filter(k)
H0 = zeros(k, k)
H1 = zeros(k, k)
G0 = zeros(k, k)
G1 = zeros(k, k)
ϕ, ψ1, ψ2 = legendre_ϕ_ψ(k)

l = convert(Polynomial, gen_poly(Legendre, k))
x_m = roots(l(Polynomial([-1, 2]))) # 2x-1
m = 2 .* x_m .- 1
wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k - 1).(m)

for ki in 0:(k - 1)
for kpi in 0:(k - 1)
H0[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m))
G0[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m))
H1[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .*
ϕ[kpi + 1].(x_m))
end
end

zero_out!(H0)
zero_out!(H1)
zero_out!(G0)
zero_out!(G1)

return H0, H1, G0, G1, I(k), I(k)
end

function chebyshev_filter(k)
H0 = zeros(k, k)
H1 = zeros(k, k)
G0 = zeros(k, k)
G1 = zeros(k, k)
Φ0 = zeros(k, k)
Φ1 = zeros(k, k)
ϕ, ψ1, ψ2 = chebyshev_ϕ_ψ(k)

k_use = 2k
c = convert(Polynomial, gen_poly(Chebyshev, k_use))
x_m = roots(c(Polynomial([-1, 2]))) # 2x-1
# x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1)
# not needed for our purpose here, we use even k always to avoid
wm = π / k_use / 2

for ki in 0:(k - 1)
for kpi in 0:(k - 1)
H0[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m))
H1[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m))
G0[ki + 1, kpi + 1] = 1 / sqrt(2) *
sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m))
G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .*
ϕ[kpi + 1].(x_m))
Φ0[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2x_m) .* ϕ[kpi + 1].(2x_m))
Φ1[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2 .* x_m .- 1) .*
ϕ[kpi + 1].(2 .* x_m .- 1))
end
end

zero_out!(H0)
zero_out!(H1)
zero_out!(G0)
zero_out!(G1)
zero_out!(Φ0)
zero_out!(Φ1)

return H0, H1, G0, G1, Φ0, Φ1
end
47 changes: 47 additions & 0 deletions src/Transform/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
function ϕ_(ϕ_coefs; lb::Real = 0.0, ub::Real = 1.0)
function partial(x)
mask = (lb ≤ x ≤ ub) * 1.0
return Polynomial(ϕ_coefs)(x) * mask
end
return partial
end

function ψ(ψ1, ψ2, i, inp)
mask = (inp .> 0.5) .* 1.0
return ψ1[i + 1].(inp) .* mask .+ ψ2[i + 1].(inp) .* mask
end

zero_out!(x; tol = 1e-8) = (x[abs.(x) .< tol] .= 0)

function gen_poly(poly, n)
x = zeros(n + 1)
x[end] = 1
return poly(x)
end

function convolve(a, b)
n = length(b)
y = similar(a, length(a) + n - 1)
for i in 1:length(a)
y[i:(i + n - 1)] .+= a[i] .* b
end
return y
end

function proj_factor(a, b; complement::Bool = false)
prod_ = convolve(a, b)
r = collect(1:length(prod_))
s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r)
proj_ = sum(prod_ ./ r .* s)
return proj_
end

_legendre(k, x) = (2k + 1) * gen_poly(Legendre, k)(x)

function legendre_der(k, x)
out = 0
for i in (k - 1):-2:-1
out += _legendre(i, x)
end
return out
end
31 changes: 31 additions & 0 deletions src/Transform/wavelet_transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
export WaveletTransform

struct WaveletTransform{N, S} <: AbstractTransform
ec_d::Any
ec_s::Any
modes::NTuple{N, S} # N == ndims(x)
end

Base.ndims(::WaveletTransform{N}) where {N} = N

function transform(wt::WaveletTransform, 𝐱::AbstractArray)
N = size(X, ndims(wt) - 1)
# 1d
Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :))
# 2d
# Xa = vcat(
# view(𝐱, :, :, 1:2:N, 1:2:N, :),
# view(𝐱, :, :, 1:2:N, 2:2:N, :),
# view(𝐱, :, :, 2:2:N, 1:2:N, :),
# view(𝐱, :, :, 2:2:N, 2:2:N, :),
# )
d = NNlib.batched_mul(Xa, wt.ec_d)
s = NNlib.batched_mul(Xa, wt.ec_s)
return d, s
end

function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) end

# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch]
# end
Loading