From 6dc4996461b49d79c7016e42ad0fabb7bd58de61 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sun, 19 Dec 2021 23:24:38 +0800 Subject: [PATCH 01/10] draft for SparseKernel1d --- src/Transform/Transform.jl | 1 + src/Transform/wavelet_transform.jl | 103 +++++++++++++++++++++++++++++ test/wavelet.jl | 13 ++++ 3 files changed, 117 insertions(+) create mode 100644 src/Transform/wavelet_transform.jl create mode 100644 test/wavelet.jl diff --git a/src/Transform/Transform.jl b/src/Transform/Transform.jl index 2a02f1b7..c6f4c7e5 100644 --- a/src/Transform/Transform.jl +++ b/src/Transform/Transform.jl @@ -18,3 +18,4 @@ abstract type AbstractTransform end include("fourier_transform.jl") include("chebyshev_transform.jl") +include("wavelet_transform.jl") diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl new file mode 100644 index 00000000..6dec15b7 --- /dev/null +++ b/src/Transform/wavelet_transform.jl @@ -0,0 +1,103 @@ +struct SparseKernel1d{T,S} + k::Int + conv_blk::S + out_weight::T +end + +function SparseKernel1d(k::Int, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k + emb_dim = 128 + conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel1d(k, conv, W_out) +end + +function (l::SparseKernel1d)(X::AbstractArray) + X_ = l.conv_blk(batched_transpose(X)) + Y = l.out_weight(batched_transpose(X_)) + return Y +end + + +# class MWT_CZ1d(nn.Module): +# def __init__(self, +# k = 3, alpha = 5, +# L = 0, c = 1, +# base = 'legendre', +# initializer = None, +# **kwargs): +# super(MWT_CZ1d, self).__init__() + +# self.k = k +# self.L = L +# H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) +# H0r = H0@PHI0 +# G0r = G0@PHI0 +# H1r = H1@PHI1 +# G1r = G1@PHI1 + +# H0r[np.abs(H0r)<1e-8]=0 +# H1r[np.abs(H1r)<1e-8]=0 +# G0r[np.abs(G0r)<1e-8]=0 +# G1r[np.abs(G1r)<1e-8]=0 + +# self.A = sparseKernelFT1d(k, alpha, c) +# self.B = sparseKernelFT1d(k, alpha, c) +# self.C = sparseKernelFT1d(k, alpha, c) + +# self.T0 = nn.Linear(k, k) + +# self.register_buffer('ec_s', torch.Tensor( +# np.concatenate((H0.T, H1.T), axis=0))) +# self.register_buffer('ec_d', torch.Tensor( +# np.concatenate((G0.T, G1.T), axis=0))) + +# self.register_buffer('rc_e', torch.Tensor( +# np.concatenate((H0r, G0r), axis=0))) +# self.register_buffer('rc_o', torch.Tensor( +# np.concatenate((H1r, G1r), axis=0))) + + +# def forward(self, x): + +# B, N, c, ich = x.shape # (B, N, k) +# ns = math.floor(np.log2(N)) + +# Ud = torch.jit.annotate(List[Tensor], []) +# Us = torch.jit.annotate(List[Tensor], []) +# # decompose +# for i in range(ns-self.L): +# d, x = self.wavelet_transform(x) +# Ud += [self.A(d) + self.B(x)] +# Us += [self.C(d)] +# x = self.T0(x) # coarsest scale transform + +# # reconstruct +# for i in range(ns-1-self.L,-1,-1): +# x = x + Us[i] +# x = torch.cat((x, Ud[i]), -1) +# x = self.evenOdd(x) +# return x + + +# def wavelet_transform(self, x): +# xa = torch.cat([x[:, ::2, :, :], +# x[:, 1::2, :, :], +# ], -1) +# d = torch.matmul(xa, self.ec_d) +# s = torch.matmul(xa, self.ec_s) +# return d, s + + +# def evenOdd(self, x): + +# B, N, c, ich = x.shape # (B, N, c, k) +# assert ich == 2*self.k +# x_e = torch.matmul(x, self.rc_e) +# x_o = torch.matmul(x, self.rc_o) + +# x = torch.zeros(B, N*2, c, self.k, +# device = x.device) +# x[..., ::2, :, :] = x_e +# x[..., 1::2, :, :] = x_o +# return x \ No newline at end of file diff --git a/test/wavelet.jl b/test/wavelet.jl new file mode 100644 index 00000000..48642c12 --- /dev/null +++ 
b/test/wavelet.jl @@ -0,0 +1,13 @@ +using NeuralOperators + +T = Float32 +k = 10 +c = 1 +in_chs = 20 +batch_size = 32 + + +l = NeuralOperators.SparseKernel1d(k, c) + +X = rand(T, c*k, in_chs, batch_size) +Y = l(X) From ec9051dda4a8ba45e4c2d67bcbef5fd738e1cc38 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Mon, 20 Dec 2021 13:47:12 +0800 Subject: [PATCH 02/10] complete SparseKernel1d/2d/3d --- Project.toml | 2 + src/NeuralOperators.jl | 2 + src/Transform/Transform.jl | 2 + src/Transform/polynomials.jl | 198 +++++++++++++++++++++++++++++ src/Transform/utils.jl | 39 ++++++ src/Transform/wavelet_transform.jl | 43 ++++++- test/wavelet.jl | 34 ++++- 7 files changed, 308 insertions(+), 12 deletions(-) create mode 100644 src/Transform/polynomials.jl create mode 100644 src/Transform/utils.jl diff --git a/Project.toml b/Project.toml index 56bd62f9..bda050e6 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,8 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 22c317f2..00c64175 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -10,6 +10,8 @@ using Zygote using ChainRulesCore using GeometricFlux using Statistics +using Polynomials +using SpecialPolynomials include("abstracttypes.jl") diff --git a/src/Transform/Transform.jl b/src/Transform/Transform.jl index c6f4c7e5..9e8e68cd 100644 --- a/src/Transform/Transform.jl +++ b/src/Transform/Transform.jl @@ -16,6 +16,8 @@ export """ abstract type AbstractTransform end +include("utils.jl") +include("polynomials.jl") include("fourier_transform.jl") include("chebyshev_transform.jl") include("wavelet_transform.jl") diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl new file mode 100644 index 00000000..4670caa0 --- /dev/null +++ b/src/Transform/polynomials.jl @@ -0,0 +1,198 @@ +function legendre_ϕ_ψ(k) + # TODO: row-major -> column major + ϕ_coefs = zeros(k, k) + ϕ_2x_coefs = zeros(k, k) + + p = Polynomial([-1, 2]) # 2x-1 + p2 = Polynomial([-1, 4]) # 4x-1 + + for ki in 0:(k-1) + l = convert(Polynomial, gen_poly(Legendre, ki)) # Legendre of n=ki + ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2*ki+1) .* coeffs(l(p)) + ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2*(2*ki+1)) .* coeffs(l(p2)) + end + + ψ1_coefs .= ϕ_2x_coefs + ψ2_coefs = zeros(k, k) + for ki in 0:(k-1) + for i in 0:(k-1) + a = ϕ_2x_coefs[ki+1, 1:(ki+1)] + b = ϕ_coefs[i+1, 1:(i+1)] + proj_ = proj_factor(a, b) + view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) + view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) + end + + for j in 0:(k-1) + a = ϕ_2x_coefs[ki+1, 1:(ki+1)] + b = ψ1_coefs[j+1, :] + proj_ = proj_factor(a, b) + view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ψ1_coefs, j+1, :) + view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ψ2_coefs, j+1, :) + end + + a = ψ1_coefs[ki+1, :] + norm1 = proj_factor(a, a) + + a = ψ2_coefs[ki+1, :] + norm2 = proj_factor(a, a, complement=true) + norm_ = sqrt(norm1 + norm2) + ψ1_coefs[ki+1, :] ./= norm_ + ψ2_coefs[ki+1, :] ./= norm_ + zero_out!(ψ1_coefs) + zero_out!(ψ2_coefs) + end + + ϕ = [Polynomial(ϕ_coefs[i,:]) for i in 1:k] + ψ1 = 
[Polynomial(ψ1_coefs[i,:]) for i in 1:k] + ψ2 = [Polynomial(ψ2_coefs[i,:]) for i in 1:k] + + return ϕ, ψ1, ψ2 +end + +# function chebyshev_ϕ_ψ(k) +# ϕ_coefs = zeros(k, k) +# ϕ_2x_coefs = zeros(k, k) + +# p = Polynomial([-1, 2]) # 2x-1 +# p2 = Polynomial([-1, 4]) # 4x-1 + +# for ki in 0:(k-1) +# if ki == 0 +# ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) +# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) +# else +# c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki +# ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) +# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) +# end +# end + +# ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k] + +# k_use = 2k + +# # phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)] + +# # x = Symbol('x') +# # kUse = 2*k +# # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() +# # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) +# # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) +# # # not needed for our purpose here, we use even k always to avoid +# # wm = np.pi / kUse / 2 + +# # psi1_coeff = np.zeros((k, k)) +# # psi2_coeff = np.zeros((k, k)) + +# # psi1 = [[] for _ in range(k)] +# # psi2 = [[] for _ in range(k)] + +# # for ki in range(k): +# # psi1_coeff[ki,:] = phi_2x_coeff[ki,:] +# # for i in range(k): +# # proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum() +# # psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:] +# # psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:] + +# # for j in range(ki): +# # proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum() +# # psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:] +# # psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:] + +# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5) +# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1) + +# # norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() +# # norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() + +# # norm_ = np.sqrt(norm1 + norm2) +# # psi1_coeff[ki,:] /= norm_ +# # psi2_coeff[ki,:] /= norm_ +# # psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0 +# # psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0 + +# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16) +# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1) + +# # return phi, psi1, psi2 +# end + +function legendre_filter(k) + # x = Symbol('x') + # H0 = np.zeros((k,k)) + # H1 = np.zeros((k,k)) + # G0 = np.zeros((k,k)) + # G1 = np.zeros((k,k)) + # PHI0 = np.zeros((k,k)) + # PHI1 = np.zeros((k,k)) + # phi, psi1, psi2 = get_phi_psi(k, base) + + # ---------------------------------------------------------- + + # roots = Poly(legendre(k, 2*x-1)).all_roots() + # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1) + + # for ki in range(k): + # for kpi in range(k): + # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() + # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() + # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() + # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() + + # PHI0 = np.eye(k) + # PHI1 = np.eye(k) + + # ---------------------------------------------------------- + + # H0[np.abs(H0)<1e-8] = 0 + # H1[np.abs(H1)<1e-8] = 0 + # G0[np.abs(G0)<1e-8] = 0 + # G1[np.abs(G1)<1e-8] = 0 + + # return H0, H1, G0, G1, PHI0, PHI1 +end + +function chebyshev_filter(k) + # x = Symbol('x') + # H0 = 
np.zeros((k,k)) + # H1 = np.zeros((k,k)) + # G0 = np.zeros((k,k)) + # G1 = np.zeros((k,k)) + # PHI0 = np.zeros((k,k)) + # PHI1 = np.zeros((k,k)) + # phi, psi1, psi2 = get_phi_psi(k, base) + + # ---------------------------------------------------------- + + # x = Symbol('x') + # kUse = 2*k + # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() + # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # # not needed for our purpose here, we use even k always to avoid + # wm = np.pi / kUse / 2 + + # for ki in range(k): + # for kpi in range(k): + # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() + # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() + # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() + # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() + + # PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2 + # PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2 + + # PHI0[np.abs(PHI0)<1e-8] = 0 + # PHI1[np.abs(PHI1)<1e-8] = 0 + + # ---------------------------------------------------------- + + # H0[np.abs(H0)<1e-8] = 0 + # H1[np.abs(H1)<1e-8] = 0 + # G0[np.abs(G0)<1e-8] = 0 + # G1[np.abs(G1)<1e-8] = 0 + + # return H0, H1, G0, G1, PHI0, PHI1 +end diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl new file mode 100644 index 00000000..d9d5855a --- /dev/null +++ b/src/Transform/utils.jl @@ -0,0 +1,39 @@ +# function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) +# mask = +# return Polynomial(ϕ_coefs) +# end + +# def phi_(phi_c, x, lb = 0, ub = 1): +# mask = np.logical_or(xub) * 1.0 +# return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask) + +function ψ(ψ1, ψ2, i, inp) + mask = (inp ≤ 0.5) * 1.0 + return ψ1[i](inp) * mask + ψ2[i](inp) * (1-mask) +end + +zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0) + +function gen_poly(poly, n) + x = zeros(n+1) + x[end] = 1 + return poly(x) +end + +function convolve(a, b) + n = length(b) + y = similar(a, length(a)+n-1) + for i in 1:length(a) + y[i:(i+n-1)] .+= a[i] .* b + end + return y +end + +function proj_factor(a, b; complement::Bool=false) + prod_ = convolve(a, b) + zero_out!(prod_) + r = collect(1:length(prod_)) + s = complement ? 
(1 .- 0.5 .^ r) : (0.5 .^ r) + proj_ = sum(prod_ ./ r .* s) + return proj_ +end diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 6dec15b7..52e4fe21 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -1,24 +1,53 @@ -struct SparseKernel1d{T,S} +struct SparseKernel{T,S} k::Int conv_blk::S out_weight::T end -function SparseKernel1d(k::Int, c::Int=1; init=Flux.glorot_uniform) +function SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) input_dim = c*k emb_dim = 128 conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel1d(k, conv, W_out) + return SparseKernel(k, conv, W_out) end -function (l::SparseKernel1d)(X::AbstractArray) - X_ = l.conv_blk(batched_transpose(X)) - Y = l.out_weight(batched_transpose(X_)) - return Y +function SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel(k, conv, W_out) +end + +function SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel(k, conv, W_out) +end + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) end +# struct MWT_CZ1d + +# end + +# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform) + +# end + # class MWT_CZ1d(nn.Module): # def __init__(self, # k = 3, alpha = 5, diff --git a/test/wavelet.jl b/test/wavelet.jl index 48642c12..d4e920d9 100644 --- a/test/wavelet.jl +++ b/test/wavelet.jl @@ -1,13 +1,37 @@ using NeuralOperators +using CUDA +using Zygote + +CUDA.allowscalar(false) T = Float32 -k = 10 +k = 3 +batch_size = 32 + +α = 4 c = 1 in_chs = 20 -batch_size = 32 -l = NeuralOperators.SparseKernel1d(k, c) +l1 = NeuralOperators.SparseKernel1d(k, α, c) +X = rand(T, in_chs, c*k, batch_size) +Y = l1(X) +gradient(x->sum(l1(x)), X) + + +α = 4 +c = 3 +Nx = 5 +Ny = 7 + +l2 = NeuralOperators.SparseKernel2d(k, α, c) +X = rand(T, Nx, Ny, c*k^2, batch_size) +Y = l2(X) +gradient(x->sum(l2(x)), X) + +Nz = 13 -X = rand(T, c*k, in_chs, batch_size) -Y = l(X) +l3 = NeuralOperators.SparseKernel3d(k, α, c) +X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) +Y = l3(X) +gradient(x->sum(l3(x)), X) From 15c6b065b47975e874632cbdf0fc7abdcb069600 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Thu, 23 Dec 2021 15:14:59 +0800 Subject: [PATCH 03/10] complete SparseKernel{N} --- src/Transform/wavelet_transform.jl | 39 +++++++++++++-------- test/Transform/Transform.jl | 1 + test/Transform/wavelet_transform.jl | 53 +++++++++++++++++++++++++++++ test/wavelet.jl | 37 -------------------- 4 files changed, 79 insertions(+), 51 deletions(-) create mode 100644 test/Transform/wavelet_transform.jl delete mode 100644 test/wavelet.jl diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 52e4fe21..49596c48 100644 --- a/src/Transform/wavelet_transform.jl +++ 
b/src/Transform/wavelet_transform.jl @@ -1,33 +1,44 @@ -struct SparseKernel{T,S} - k::Int - conv_blk::S - out_weight::T +export + SparseKernel, + SparseKernel1D, + SparseKernel2D, + SparseKernel3D + + +struct SparseKernel{N,T,S} + conv_blk::T + out_weight::S +end + +function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} + input_dim, emb_dim = ch + conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) end -function SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) +function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) input_dim = c*k emb_dim = 128 - conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel(k, conv, W_out) + return SparseKernel((3, ), input_dim=>emb_dim; init=init) end -function SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) +function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) input_dim = c*k^2 emb_dim = α*k^2 - conv = Conv((3, 3), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel(k, conv, W_out) + return SparseKernel((3, 3), input_dim=>emb_dim; init=init) end -function SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) +function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) input_dim = c*k^2 emb_dim = α*k^2 conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel(k, conv, W_out) + return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) end +Flux.@functor SparseKernel + function (l::SparseKernel)(X::AbstractArray) bch_sz, _, dims_r... 
= reverse(size(X)) dims = reverse(dims_r) diff --git a/test/Transform/Transform.jl b/test/Transform/Transform.jl index d5ff9a67..188675e7 100644 --- a/test/Transform/Transform.jl +++ b/test/Transform/Transform.jl @@ -1,4 +1,5 @@ @testset "Transform" begin include("fourier_transform.jl") include("chebyshev_transform.jl") + include("wavelet_transform.jl") end diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl new file mode 100644 index 00000000..726727eb --- /dev/null +++ b/test/Transform/wavelet_transform.jl @@ -0,0 +1,53 @@ +@testset "SparseKernel" begin + T = Float32 + k = 3 + batch_size = 32 + + @testset "1D SparseKernel" begin + α = 4 + c = 1 + in_chs = 20 + X = rand(T, in_chs, c*k, batch_size) + + l1 = SparseKernel1D(k, α, c) + Y = l1(X) + @test l1 isa SparseKernel{1} + @test size(Y) == size(X) + + gs = gradient(()->sum(l1(X)), Flux.params(l1)) + @test length(gs.grads) == 4 + end + + @testset "2D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + X = rand(T, Nx, Ny, c*k^2, batch_size) + + l2 = SparseKernel2D(k, α, c) + Y = l2(X) + @test l2 isa SparseKernel{2} + @test size(Y) == size(X) + + gs = gradient(()->sum(l2(X)), Flux.params(l2)) + @test length(gs.grads) == 4 + end + + @testset "3D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + Nz = 13 + X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) + + l3 = SparseKernel3D(k, α, c) + Y = l3(X) + @test l3 isa SparseKernel{3} + @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + + gs = gradient(()->sum(l3(X)), Flux.params(l3)) + @test length(gs.grads) == 4 + end +end diff --git a/test/wavelet.jl b/test/wavelet.jl deleted file mode 100644 index d4e920d9..00000000 --- a/test/wavelet.jl +++ /dev/null @@ -1,37 +0,0 @@ -using NeuralOperators -using CUDA -using Zygote - -CUDA.allowscalar(false) - -T = Float32 -k = 3 -batch_size = 32 - -α = 4 -c = 1 -in_chs = 20 - - -l1 = NeuralOperators.SparseKernel1d(k, α, c) -X = rand(T, in_chs, c*k, batch_size) -Y = l1(X) -gradient(x->sum(l1(x)), X) - - -α = 4 -c = 3 -Nx = 5 -Ny = 7 - -l2 = NeuralOperators.SparseKernel2d(k, α, c) -X = rand(T, Nx, Ny, c*k^2, batch_size) -Y = l2(X) -gradient(x->sum(l2(x)), X) - -Nz = 13 - -l3 = NeuralOperators.SparseKernel3d(k, α, c) -X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) -Y = l3(X) -gradient(x->sum(l3(x)), X) From 37e5d48aafdb298a1f900a89fcc4b016520f9869 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Fri, 24 Dec 2021 00:09:38 +0800 Subject: [PATCH 04/10] draft for MWT_CZ1d --- src/Transform/polynomials.jl | 10 ++ src/Transform/wavelet_transform.jl | 159 +++++++++++++---------------- 2 files changed, 83 insertions(+), 86 deletions(-) diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index 4670caa0..fe8f9fb8 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -1,3 +1,13 @@ +function get_filter(base::Symbol, k) + if base == :legendre + return legendre_filter(k) + elseif base == :chebyshev + return chebyshev_filter(k) + else + throw(ArgumentError("base must be one of :legendre or :chebyshev.")) + end +end + function legendre_ϕ_ψ(k) # TODO: row-major -> column major ϕ_coefs = zeros(k, k) diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 49596c48..d32b16f4 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -51,93 +51,80 @@ function (l::SparseKernel)(X::AbstractArray) end -# struct MWT_CZ1d - -# end - -# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform) - -# end - -# class MWT_CZ1d(nn.Module): -# 
def __init__(self, -# k = 3, alpha = 5, -# L = 0, c = 1, -# base = 'legendre', -# initializer = None, -# **kwargs): -# super(MWT_CZ1d, self).__init__() - -# self.k = k -# self.L = L -# H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k) -# H0r = H0@PHI0 -# G0r = G0@PHI0 -# H1r = H1@PHI1 -# G1r = G1@PHI1 - -# H0r[np.abs(H0r)<1e-8]=0 -# H1r[np.abs(H1r)<1e-8]=0 -# G0r[np.abs(G0r)<1e-8]=0 -# G1r[np.abs(G1r)<1e-8]=0 - -# self.A = sparseKernelFT1d(k, alpha, c) -# self.B = sparseKernelFT1d(k, alpha, c) -# self.C = sparseKernelFT1d(k, alpha, c) - -# self.T0 = nn.Linear(k, k) - -# self.register_buffer('ec_s', torch.Tensor( -# np.concatenate((H0.T, H1.T), axis=0))) -# self.register_buffer('ec_d', torch.Tensor( -# np.concatenate((G0.T, G1.T), axis=0))) - -# self.register_buffer('rc_e', torch.Tensor( -# np.concatenate((H0r, G0r), axis=0))) -# self.register_buffer('rc_o', torch.Tensor( -# np.concatenate((H1r, G1r), axis=0))) - - -# def forward(self, x): - -# B, N, c, ich = x.shape # (B, N, k) -# ns = math.floor(np.log2(N)) - -# Ud = torch.jit.annotate(List[Tensor], []) -# Us = torch.jit.annotate(List[Tensor], []) -# # decompose -# for i in range(ns-self.L): -# d, x = self.wavelet_transform(x) -# Ud += [self.A(d) + self.B(x)] -# Us += [self.C(d)] -# x = self.T0(x) # coarsest scale transform - -# # reconstruct -# for i in range(ns-1-self.L,-1,-1): -# x = x + Us[i] -# x = torch.cat((x, Ud[i]), -1) -# x = self.evenOdd(x) -# return x - - -# def wavelet_transform(self, x): -# xa = torch.cat([x[:, ::2, :, :], -# x[:, 1::2, :, :], -# ], -1) -# d = torch.matmul(xa, self.ec_d) -# s = torch.matmul(xa, self.ec_s) -# return d, s - - -# def evenOdd(self, x): - -# B, N, c, ich = x.shape # (B, N, c, k) -# assert ich == 2*self.k -# x_e = torch.matmul(x, self.rc_e) -# x_o = torch.matmul(x, self.rc_o) - +struct MWT_CZ1d{T,S,R,Q,P} + k::Int + L::Int + A::T + B::S + C::R + T0::Q + ec_s::P + ec_d::P + rc_e::P + rc_o::P +end + +function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) + H0, H1, G0, G1, ϕ0, ϕ1 = get_filter(base, k) + H0r = zero_out!(H0 * ϕ0) + G0r = zero_out!(G0 * ϕ0) + H1r = zero_out!(H1 * ϕ1) + G1r = zero_out!(G1 * ϕ1) + + dim = c*k + A = SpectralConv(dim=>dim, (α,); init=init) + B = SpectralConv(dim=>dim, (α,); init=init) + C = SpectralConv(dim=>dim, (α,); init=init) + T0 = Dense(k, k) + + ec_s = vcat(H0', H1') + ec_d = vcat(G0', G1') + rc_e = vcat(H0r, G0r) + rc_o = vcat(H1r, G1r) + return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) +end + +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + N = size(X, 3) + Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) + d = NNlib.batched_mul(Xa, l.ec_d) + s = NNlib.batched_mul(Xa, l.ec_s) + return d, s +end + +function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + bch_sz, N, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + @assert dims[1] == 2*l.k + Xₑ = NNlib.batched_mul(X, l.rc_e) + Xₒ = NNlib.batched_mul(X, l.rc_o) # x = torch.zeros(B, N*2, c, self.k, # device = x.device) # x[..., ::2, :, :] = x_e # x[..., 1::2, :, :] = x_o -# return x \ No newline at end of file + return X +end + +function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} + bch_sz, N, dims_r... 
= reverse(size(X)) + ns = floor(log2(N)) + stop = ns - l.L + + # decompose + Ud = T[] + Us = T[] + for i in 1:stop + d, X = wavelet_transform(l, X) + push!(Ud, l.A(d)+l.B(d)) + push!(Us, l.C(d)) + end + X = l.T0(X) + + # reconstruct + for i in stop:-1:1 + X += Us[i] + X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) + X = even_odd(l, X) + end + return X +end From 1a716261b011908c0f65748c10957368bc081d55 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 1 Feb 2022 01:09:59 +0800 Subject: [PATCH 05/10] complete legendre polynomials --- src/Transform/polynomials.jl | 73 ++++++++++++------------------ src/Transform/utils.jl | 1 - src/Transform/wavelet_transform.jl | 10 ++-- test/Transform/Transform.jl | 1 + test/polynomials.jl | 34 ++++++++++++++ test/runtests.jl | 1 + 6 files changed, 71 insertions(+), 49 deletions(-) create mode 100644 test/polynomials.jl diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index fe8f9fb8..e4e0658f 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -22,23 +22,24 @@ function legendre_ϕ_ψ(k) ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2*(2*ki+1)) .* coeffs(l(p2)) end - ψ1_coefs .= ϕ_2x_coefs + ψ1_coefs = zeros(k, k) ψ2_coefs = zeros(k, k) for ki in 0:(k-1) + ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] for i in 0:(k-1) a = ϕ_2x_coefs[ki+1, 1:(ki+1)] b = ϕ_coefs[i+1, 1:(i+1)] proj_ = proj_factor(a, b) - view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) - view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) + ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + ψ2_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) end for j in 0:(k-1) a = ϕ_2x_coefs[ki+1, 1:(ki+1)] b = ψ1_coefs[j+1, :] proj_ = proj_factor(a, b) - view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ψ1_coefs, j+1, :) - view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ψ2_coefs, j+1, :) + ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) + ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) end a = ψ1_coefs[ki+1, :] @@ -129,16 +130,11 @@ end # end function legendre_filter(k) - # x = Symbol('x') - # H0 = np.zeros((k,k)) - # H1 = np.zeros((k,k)) - # G0 = np.zeros((k,k)) - # G1 = np.zeros((k,k)) - # PHI0 = np.zeros((k,k)) - # PHI1 = np.zeros((k,k)) - # phi, psi1, psi2 = get_phi_psi(k, base) - - # ---------------------------------------------------------- + H0 = zeros(k, k)legendre + H1 = zeros(k, k) + G0 = zeros(k, k) + G1 = zeros(k, k) + ϕ, ψ1, ψ2 = legendre_ϕ_ψ(k) # roots = Poly(legendre(k, 2*x-1)).all_roots() # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) @@ -150,29 +146,23 @@ function legendre_filter(k) # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() - - # PHI0 = np.eye(k) - # PHI1 = np.eye(k) - - # ---------------------------------------------------------- - # H0[np.abs(H0)<1e-8] = 0 - # H1[np.abs(H1)<1e-8] = 0 - # G0[np.abs(G0)<1e-8] = 0 - # G1[np.abs(G1)<1e-8] = 0 + zero_out!(H0) + zero_out!(H1) + zero_out!(G0) + zero_out!(G1) - # return H0, H1, G0, G1, PHI0, PHI1 + return H0, H1, G0, G1, I(k), I(k) end function chebyshev_filter(k) - # x = Symbol('x') - # H0 = np.zeros((k,k)) - # H1 = np.zeros((k,k)) - # G0 = np.zeros((k,k)) - # G1 = np.zeros((k,k)) - # PHI0 = np.zeros((k,k)) - # PHI1 = np.zeros((k,k)) - # phi, psi1, psi2 = get_phi_psi(k, base) + H0 = zeros(k, k) + H1 = zeros(k, k) + G0 = zeros(k, k) + G1 = zeros(k, k) + Φ0 
= zeros(k, k) + Φ1 = zeros(k, k) + ϕ, ψ1, ψ2 = chebyshev_ϕ_ψ(k) # ---------------------------------------------------------- @@ -193,16 +183,13 @@ function chebyshev_filter(k) # PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2 # PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2 - - # PHI0[np.abs(PHI0)<1e-8] = 0 - # PHI1[np.abs(PHI1)<1e-8] = 0 - - # ---------------------------------------------------------- - # H0[np.abs(H0)<1e-8] = 0 - # H1[np.abs(H1)<1e-8] = 0 - # G0[np.abs(G0)<1e-8] = 0 - # G1[np.abs(G1)<1e-8] = 0 + zero_out!(H0) + zero_out!(H1) + zero_out!(G0) + zero_out!(G1) + zero_out!(Φ0) + zero_out!(Φ1) - # return H0, H1, G0, G1, PHI0, PHI1 + return H0, H1, G0, G1, Φ0, Φ1 end diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl index d9d5855a..f33f894f 100644 --- a/src/Transform/utils.jl +++ b/src/Transform/utils.jl @@ -31,7 +31,6 @@ end function proj_factor(a, b; complement::Bool=false) prod_ = convolve(a, b) - zero_out!(prod_) r = collect(1:length(prod_)) s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r) proj_ = sum(prod_ ./ r .* s) diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index d32b16f4..26749c73 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -65,11 +65,11 @@ struct MWT_CZ1d{T,S,R,Q,P} end function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) - H0, H1, G0, G1, ϕ0, ϕ1 = get_filter(base, k) - H0r = zero_out!(H0 * ϕ0) - G0r = zero_out!(G0 * ϕ0) - H1r = zero_out!(H1 * ϕ1) - G1r = zero_out!(G1 * ϕ1) + H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) + H0r = zero_out!(H0 * Φ0) + G0r = zero_out!(G0 * Φ0) + H1r = zero_out!(H1 * Φ1) + G1r = zero_out!(G1 * Φ1) dim = c*k A = SpectralConv(dim=>dim, (α,); init=init) diff --git a/test/Transform/Transform.jl b/test/Transform/Transform.jl index 188675e7..abb5cac9 100644 --- a/test/Transform/Transform.jl +++ b/test/Transform/Transform.jl @@ -1,4 +1,5 @@ @testset "Transform" begin + include("polynomials.jl") include("fourier_transform.jl") include("chebyshev_transform.jl") include("wavelet_transform.jl") diff --git a/test/polynomials.jl b/test/polynomials.jl new file mode 100644 index 00000000..31fb81d6 --- /dev/null +++ b/test/polynomials.jl @@ -0,0 +1,34 @@ +@testset "polynomials" begin + @testset "legendre_ϕ_ψ" begin + ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(10) + + @test all(coeffs(ϕ[1]) .≈ [1.]) + @test all(coeffs(ϕ[2]) .≈ [-1.7320508075688772, 3.4641016151377544]) + @test all(coeffs(ϕ[3]) .≈ [2.23606797749979, -13.416407864998739, 13.416407864998739]) + @test all(coeffs(ϕ[4]) .≈ [-2.6457513110645907, 31.74901573277509, -79.37253933193772, 52.91502622129181]) + @test all(coeffs(ϕ[5]) .≈ [3.0, -60.0, 270.0, -420.0, 210.0]) + @test all(coeffs(ϕ[6]) .≈ [-3.3166247903554, 99.498743710662, -696.491205974634, 1857.309882599024, + -2089.4736179239017, 835.7894471695607]) + @test all(coeffs(ϕ[7]) .≈ [3.605551275463989, -151.43315356948753, 1514.3315356948754, -6057.326142779501, + 11357.486517711566, -9994.588135586178, 3331.529378528726]) + @test all(coeffs(ϕ[8]) .≈ [-3.872983346207417, 216.88706738761536, -2927.9754097328073, 16266.530054071152, + -44732.957648695665, 64415.45901412176, -46522.27595464349, 13292.078844183856]) + @test all(coeffs(ϕ[9]) .≈ [4.123105625617661, -296.86360504447157, 5195.113088278253, -38097.49598070719, + 142865.60992765194, -297160.46864951606, 346687.21342443535, -212257.47760679715, + 53064.36940169929]) + @test all(coeffs(ϕ[10]) .≈ 
[-4.358898943540674, 392.30090491866065, -8630.619908210534, 80552.45247663166, + -392693.20582357934, 1099540.9763060221, -1832568.2938433702, 1795168.9409077913, + -953683.4998572641, 211929.66663494756]) + + ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(3) + @test coeffs(ϕ[1]) ≈ [1.] + @test coeffs(ϕ[2]) ≈ [-1.7320508075688772, 3.4641016151377544] + @test coeffs(ϕ[3]) ≈ [2.23606797749979, -13.416407864998739, 13.416407864998739] + @test coeffs(ψ1[1]) ≈ [-1.0000000000000122, 6.000000000000073] + @test coeffs(ψ1[2]) ≈ [1.7320508075691732, -24.248711305967735, 51.96152422707261] + @test coeffs(ψ1[3]) ≈ [2.2360679774995615, -26.832815729994504, 53.665631459989214] + @test coeffs(ψ2[1]) ≈ [-5.000000000000066, 6.000000000000073] + @test coeffs(ψ2[2]) ≈ [29.44486372867492, -79.67433714817852, 51.96152422707261] + @test coeffs(ψ2[3]) ≈ [-29.068883707507286, 80.49844719001908, -53.665631460012115] + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6ffe561a..90d4dcb2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using CUDA using Flux using GeometricFlux using Graphs +using Polynomials using Zygote using Test From e5afafbb02062f0811d420177e2de40dbb234ee5 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 1 Feb 2022 14:48:36 +0800 Subject: [PATCH 06/10] complete chebyshev polynomials --- src/Transform/polynomials.jl | 126 +++++++++++++++++------------------ src/Transform/utils.jl | 15 ++--- test/polynomials.jl | 33 +++++++++ 3 files changed, 102 insertions(+), 72 deletions(-) diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index e4e0658f..8bee986b 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -61,73 +61,71 @@ function legendre_ϕ_ψ(k) return ϕ, ψ1, ψ2 end -# function chebyshev_ϕ_ψ(k) -# ϕ_coefs = zeros(k, k) -# ϕ_2x_coefs = zeros(k, k) - -# p = Polynomial([-1, 2]) # 2x-1 -# p2 = Polynomial([-1, 4]) # 4x-1 - -# for ki in 0:(k-1) -# if ki == 0 -# ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) -# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) -# else -# c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki -# ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) -# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) -# end -# end - -# ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k] - -# k_use = 2k - -# # phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)] +function chebyshev_ϕ_ψ(k) + ϕ_coefs = zeros(k, k) + ϕ_2x_coefs = zeros(k, k) + + p = Polynomial([-1, 2]) # 2x-1 + p2 = Polynomial([-1, 4]) # 4x-1 + + for ki in 0:(k-1) + if ki == 0 + ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) + ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) + else + c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki + ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) + ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) + end + end + + ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k] + + k_use = 2k + c = convert(Polynomial, gen_poly(Chebyshev, k_use)) + x_m = roots(c(p)) + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = π / k_use / 2 + + ψ1_coefs = zeros(k, k) + ψ2_coefs = zeros(k, k) + + ψ1 = Array{Any,1}(undef, k) + ψ2 = Array{Any,1}(undef, k) + + for ki in 0:(k-1) + ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] + for i in 0:(k-1) + proj_ = sum(wm .* ϕ[i+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) + ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + ψ2_coefs[ki+1, :] .-= proj_ .* 
view(ϕ_coefs, i+1, :) + end + + for j in 0:(ki-1) + proj_ = sum(wm .* ψ1[j+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) + ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) + ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) + end + + ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5) + ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5, ub=1.) -# # x = Symbol('x') -# # kUse = 2*k -# # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() -# # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) -# # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) -# # # not needed for our purpose here, we use even k always to avoid -# # wm = np.pi / kUse / 2 + norm1 = sum(wm .* ψ1[ki+1].(x_m) .* ψ1[ki+1].(x_m)) + norm2 = sum(wm .* ψ2[ki+1].(x_m) .* ψ2[ki+1].(x_m)) -# # psi1_coeff = np.zeros((k, k)) -# # psi2_coeff = np.zeros((k, k)) - -# # psi1 = [[] for _ in range(k)] -# # psi2 = [[] for _ in range(k)] - -# # for ki in range(k): -# # psi1_coeff[ki,:] = phi_2x_coeff[ki,:] -# # for i in range(k): -# # proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum() -# # psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:] -# # psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:] - -# # for j in range(ki): -# # proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum() -# # psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:] -# # psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:] - -# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5) -# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1) - -# # norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() -# # norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() - -# # norm_ = np.sqrt(norm1 + norm2) -# # psi1_coeff[ki,:] /= norm_ -# # psi2_coeff[ki,:] /= norm_ -# # psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0 -# # psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0 - -# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16) -# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1) + norm_ = sqrt(norm1 + norm2) + ψ1_coefs[ki+1, :] ./= norm_ + ψ2_coefs[ki+1, :] ./= norm_ + zero_out!(ψ1_coefs) + zero_out!(ψ2_coefs) -# # return phi, psi1, psi2 -# end + ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5+1e-16) + ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5+1e-16, ub=1.) + end + + return ϕ, ψ1, ψ2 +end function legendre_filter(k) H0 = zeros(k, k)legendre diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl index f33f894f..ae64aa18 100644 --- a/src/Transform/utils.jl +++ b/src/Transform/utils.jl @@ -1,11 +1,10 @@ -# function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) -# mask = -# return Polynomial(ϕ_coefs) -# end - -# def phi_(phi_c, x, lb = 0, ub = 1): -# mask = np.logical_or(xub) * 1.0 -# return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask) +function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) + function partial(x) + mask = (lb ≤ x ≤ ub) * 1. + return Polynomial(ϕ_coefs)(x) * mask + end + return partial +end function ψ(ψ1, ψ2, i, inp) mask = (inp ≤ 0.5) * 1.0 diff --git a/test/polynomials.jl b/test/polynomials.jl index 31fb81d6..09fb11d5 100644 --- a/test/polynomials.jl +++ b/test/polynomials.jl @@ -31,4 +31,37 @@ @test coeffs(ψ2[2]) ≈ [29.44486372867492, -79.67433714817852, 51.96152422707261] @test coeffs(ψ2[3]) ≈ [-29.068883707507286, 80.49844719001908, -53.665631460012115] end + + @testset "chebyshev_ϕ_ψ" begin + ϕ, ψ1, ψ2 = NeuralOperators.chebyshev_ϕ_ψ(3) + @test ϕ[1](0) ≈ 0.7978845608028654 + @test ϕ[1](1) ≈ 0.7978845608028654 + @test ϕ[1](2) ≈ 0. 
+ @test ϕ[2](0) ≈ -1.1283791670955126 + @test ϕ[2](1) ≈ 1.1283791670955126 + @test ϕ[2](2) ≈ 0. + @test ϕ[3](0) ≈ 1.1283791670955126 + @test ϕ[3](1) ≈ 1.1283791670955126 + @test ϕ[3](2) ≈ 0. + + @test ψ1[1](0) ≈ -0.5560622352843183 + @test ψ1[1](1) ≈ 0. + @test ψ1[1](2) ≈ 0. + @test ψ1[2](0) ≈ 0.932609257876051 + @test ψ1[2](1) ≈ 0. + @test ψ1[2](2) ≈ 0. + @test ψ1[3](0) ≈ 1.0941547380212637 + @test ψ1[3](1) ≈ 0. + @test ψ1[3](2) ≈ 0. + + @test ψ2[1](0) ≈ -0. + @test ψ2[1](1) ≈ 0.5560622352843181 + @test ψ2[1](2) ≈ 0. + @test ψ2[2](0) ≈ 0. + @test ψ2[2](1) ≈ 0.9326092578760665 + @test ψ2[2](2) ≈ 0. + @test ψ2[3](0) ≈ 0. + @test ψ2[3](1) ≈ -1.0941547380212384 + @test ψ2[3](2) ≈ 0. + end end From a5d756f4aee199e2ae8741a0915516713ee5d077 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 1 Feb 2022 16:41:08 +0800 Subject: [PATCH 07/10] complete legendre_filter --- Project.toml | 1 + src/NeuralOperators.jl | 1 + src/Transform/polynomials.jl | 23 +++++++++++--------- src/Transform/utils.jl | 14 ++++++++++-- test/polynomials.jl | 42 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 70 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index bda050e6..5dc7cc58 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 00c64175..aafa0fb7 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -12,6 +12,7 @@ using GeometricFlux using Statistics using Polynomials using SpecialPolynomials +using LinearAlgebra include("abstracttypes.jl") diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index 8bee986b..a7d26704 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -128,22 +128,25 @@ function chebyshev_ϕ_ψ(k) end function legendre_filter(k) - H0 = zeros(k, k)legendre + H0 = zeros(k, k) H1 = zeros(k, k) G0 = zeros(k, k) G1 = zeros(k, k) ϕ, ψ1, ψ2 = legendre_ϕ_ψ(k) - # roots = Poly(legendre(k, 2*x-1)).all_roots() - # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) - # wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1) + l = convert(Polynomial, gen_poly(Legendre, k)) + x_m = roots(l(Polynomial([-1, 2]))) # 2x-1 + m = 2 .* x_m .- 1 + wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k-1).(m) - # for ki in range(k): - # for kpi in range(k): - # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() - # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() - # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() - # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() + for ki in 0:(k-1) + for kpi in 0:(k-1) + H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) + G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) + H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + end + end zero_out!(H0) zero_out!(H1) 
diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl
index ae64aa18..994fa75d 100644
--- a/src/Transform/utils.jl
+++ b/src/Transform/utils.jl
@@ -7,8 +7,8 @@ function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.)
 end
 
 function ψ(ψ1, ψ2, i, inp)
-    mask = (inp ≤ 0.5) * 1.0
-    return ψ1[i](inp) * mask + ψ2[i](inp) * (1-mask)
+    mask = (inp .> 0.5) .* 1.0
+    return ψ1[i+1].(inp) .* (1 .- mask) .+ ψ2[i+1].(inp) .* mask
 end
 
 zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0)
@@ -35,3 +35,13 @@ function proj_factor(a, b; complement::Bool=false)
     proj_ = sum(prod_ ./ r .* s)
     return proj_
 end
+
+_legendre(k, x) = (2k+1) * gen_poly(Legendre, k)(x)
+
+function legendre_der(k, x)
+    out = 0
+    for i in k-1:-2:-1
+        out += _legendre(i, x)
+    end
+    return out
+end
diff --git a/test/polynomials.jl b/test/polynomials.jl
index 09fb11d5..1163b7ef 100644
--- a/test/polynomials.jl
+++ b/test/polynomials.jl
@@ -64,4 +64,46 @@
         @test ψ2[3](1) ≈ -1.0941547380212384
         @test ψ2[3](2) ≈ 0.
     end
+
+    @testset "legendre_filter" begin
+        H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.legendre_filter(3)
+
+        @test H0 ≈ [ 0.70710678  0.          0.        ;
+                    -0.61237244  0.35355339  0.        ;
+                     0.         -0.6846532   0.1767767 ]
+        @test H1 ≈ [ 0.70710678  0.          0.        ;
+                     0.61237244  0.35355339  0.        ;
+                     0.          0.6846532   0.1767767 ]
+        @test G0 ≈ [ 0.35355339  0.61237244  0.        ;
+                     0.          0.1767767   0.6846532 ;
+                     0.          0.          0.70710678]
+        @test G1 ≈ [-0.35355339  0.61237244  0.        ;
+                     0.         -0.1767767   0.6846532 ;
+                     0.          0.         -0.70710678]
+        @test Φ1 == I(3)
+        @test Φ2 == I(3)
+    end
+
+    @testset "chebyshev_filter" begin
+        # H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.chebyshev_filter(3)
+
+        # @test H0 ≈ [ 0.70710678  0.          0.        ;
+        #             -0.5         0.35355339  0.        ;
+        #             -0.25       -0.70710678  0.1767767 ]
+        # @test H1 ≈ [ 0.70710678  0.          0.        ;
+        #              0.5         0.35355339  0.        ;
+        #             -0.25        0.70710678  0.1767767 ]
+        # @test G0 ≈ [ 0.60944614  0.77940383  0.        ;
+        #              0.66325172  1.02726613  1.14270252;
+        #              0.61723435  0.90708619  1.1562954 ]
+        # @test G1 ≈ [-0.60944614  0.77940383  0.        ;
+        #              0.66325172 -1.02726613  1.14270252;
+        #             -0.61723435  0.90708619 -1.1562954 ]
+        # @test Φ1 ≈ [ 1.         -0.40715364 -0.21440101;
+        #             -0.40715364  0.84839559 -0.44820615;
+        #             -0.21440101 -0.44820615  0.84002127]
+        # @test Φ2 ≈ [1. 
0.40715364 -0.21440101; + # 0.40715364 0.84839559 0.44820615; + # -0.21440101 0.44820615 0.84002127] + end end diff --git a/test/runtests.jl b/test/runtests.jl index 90d4dcb2..e2d65a17 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using CUDA using Flux using GeometricFlux using Graphs +using LinearAlgebra using Polynomials using Zygote using Test From 4e244b15312e679e1a4c96682227b6a92b8125b0 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Fri, 4 Feb 2022 05:37:00 +0800 Subject: [PATCH 08/10] complete chebyshev_filter --- src/Transform/polynomials.jl | 36 ++++++++++++++++------------------ test/polynomials.jl | 38 ++++++++++++++++++------------------ 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index a7d26704..93ee257c 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -165,25 +165,23 @@ function chebyshev_filter(k) Φ1 = zeros(k, k) ϕ, ψ1, ψ2 = chebyshev_ϕ_ψ(k) - # ---------------------------------------------------------- - - # x = Symbol('x') - # kUse = 2*k - # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() - # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) - # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) - # # not needed for our purpose here, we use even k always to avoid - # wm = np.pi / kUse / 2 - - # for ki in range(k): - # for kpi in range(k): - # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() - # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() - # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() - # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() - - # PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2 - # PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2 + k_use = 2k + c = convert(Polynomial, gen_poly(Chebyshev, k_use)) + x_m = roots(c(Polynomial([-1, 2]))) # 2x-1 + # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # not needed for our purpose here, we use even k always to avoid + wm = π / k_use / 2 + + for ki in 0:(k-1) + for kpi in 0:(k-1) + H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) + H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) + G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + Φ0[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2x_m) .* ϕ[kpi+1].(2x_m)) + Φ1[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2 .* x_m .- 1) .* ϕ[kpi+1].(2 .* x_m .- 1)) + end + end zero_out!(H0) zero_out!(H1) diff --git a/test/polynomials.jl b/test/polynomials.jl index 1163b7ef..fdb53df7 100644 --- a/test/polynomials.jl +++ b/test/polynomials.jl @@ -85,25 +85,25 @@ end @testset "chebyshev_filter" begin - # H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.chebyshev_filter(3) + H0, H1, G0, G1, Φ0, Φ1 = NeuralOperators.chebyshev_filter(3) - # @test H0 ≈ [0.70710678 0. 0. ; - # -0.5 0.35355339 0. ; - # -0.25 -0.70710678 0.1767767] - # @test H1 ≈ [0.70710678 0. 0. ; - # 0.5 0.35355339 0. ; - # -0.25 0.70710678 0.1767767] - # @test G0 ≈ [0.60944614 0.77940383 0. ; - # 0.66325172 1.02726613 1.14270252; - # 0.61723435 0.90708619 1.1562954 ] - # @test G1 ≈ [-0.60944614 0.77940383 0. 
; - # 0.66325172 -1.02726613 1.14270252; - # -0.61723435 0.90708619 -1.1562954 ] - # @test Φ1 ≈ [1. -0.40715364 -0.21440101; - # -0.40715364 0.84839559 -0.44820615; - # -0.21440101 -0.44820615 0.84002127] - # @test Φ2 ≈ [1. 0.40715364 -0.21440101; - # 0.40715364 0.84839559 0.44820615; - # -0.21440101 0.44820615 0.84002127] + @test H0 ≈ [0.70710678 0. 0. ; + -0.5 0.35355339 0. ; + -0.25 -0.70710678 0.1767767] + @test H1 ≈ [0.70710678 0. 0. ; + 0.5 0.35355339 0. ; + -0.25 0.70710678 0.1767767] + @test G0 ≈ [0.60944614 0.77940383 0. ; + 0.66325172 1.02726613 1.14270252; + 0.61723435 0.90708619 1.1562954 ] + @test G1 ≈ [-0.60944614 0.77940383 0. ; + 0.66325172 -1.02726613 1.14270252; + -0.61723435 0.90708619 -1.1562954 ] + @test Φ0 ≈ [1. -0.40715364 -0.21440101; + -0.40715364 0.84839559 -0.44820615; + -0.21440101 -0.44820615 0.84002127] + @test Φ1 ≈ [1. 0.40715364 -0.21440101; + 0.40715364 0.84839559 0.44820615; + -0.21440101 0.44820615 0.84002127] end end From 09d085ccc809b63e3eb77f54cfd8ce7edc781999 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sat, 26 Mar 2022 23:36:25 +0800 Subject: [PATCH 09/10] add WaveletTransform --- src/Transform/wavelet_transform.jl | 153 +++++----------------------- src/operator_kernel.jl | 148 ++++++++++++++++++++++++++- test/{ => Transform}/polynomials.jl | 0 test/Transform/wavelet_transform.jl | 63 ++++-------- test/operator_kernel.jl | 54 ++++++++++ 5 files changed, 247 insertions(+), 171 deletions(-) rename test/{ => Transform}/polynomials.jl (100%) diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 26749c73..41d54edb 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -1,130 +1,33 @@ -export - SparseKernel, - SparseKernel1D, - SparseKernel2D, - SparseKernel3D - - -struct SparseKernel{N,T,S} - conv_blk::T - out_weight::S -end - -function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} - input_dim, emb_dim = ch - conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) -end - -function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k - emb_dim = 128 - return SparseKernel((3, ), input_dim=>emb_dim; init=init) -end - -function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - return SparseKernel((3, 3), input_dim=>emb_dim; init=init) -end - -function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) -end - -Flux.@functor SparseKernel - -function (l::SparseKernel)(X::AbstractArray) - bch_sz, _, dims_r... 
= reverse(size(X)) - dims = reverse(dims_r) - - X_ = l.conv_blk(X) # (dims..., emb_dims, B) - X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) - Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) - Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) - return collect(Y) -end - - -struct MWT_CZ1d{T,S,R,Q,P} - k::Int - L::Int - A::T - B::S - C::R - T0::Q - ec_s::P - ec_d::P - rc_e::P - rc_o::P -end - -function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) - H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) - H0r = zero_out!(H0 * Φ0) - G0r = zero_out!(G0 * Φ0) - H1r = zero_out!(H1 * Φ1) - G1r = zero_out!(G1 * Φ1) - - dim = c*k - A = SpectralConv(dim=>dim, (α,); init=init) - B = SpectralConv(dim=>dim, (α,); init=init) - C = SpectralConv(dim=>dim, (α,); init=init) - T0 = Dense(k, k) - - ec_s = vcat(H0', H1') - ec_d = vcat(G0', G1') - rc_e = vcat(H0r, G0r) - rc_o = vcat(H1r, G1r) - return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) -end - -function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - N = size(X, 3) - Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) - d = NNlib.batched_mul(Xa, l.ec_d) - s = NNlib.batched_mul(Xa, l.ec_s) +export WaveletTransform + +struct WaveletTransform{N, S}<:AbstractTransform + ec_d + ec_s + modes::NTuple{N, S} # N == ndims(x) +end + +Base.ndims(::WaveletTransform{N}) where {N} = N + +function transform(wt::WaveletTransform, 𝐱::AbstractArray) + N = size(X, ndims(wt)-1) + # 1d + Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :)) + # 2d + # Xa = vcat( + # view(𝐱, :, :, 1:2:N, 1:2:N, :), + # view(𝐱, :, :, 1:2:N, 2:2:N, :), + # view(𝐱, :, :, 2:2:N, 1:2:N, :), + # view(𝐱, :, :, 2:2:N, 2:2:N, :), + # ) + d = NNlib.batched_mul(Xa, wt.ec_d) + s = NNlib.batched_mul(Xa, wt.ec_s) return d, s end -function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - bch_sz, N, dims_r... = reverse(size(X)) - dims = reverse(dims_r) - @assert dims[1] == 2*l.k - Xₑ = NNlib.batched_mul(X, l.rc_e) - Xₒ = NNlib.batched_mul(X, l.rc_o) -# x = torch.zeros(B, N*2, c, self.k, -# device = x.device) -# x[..., ::2, :, :] = x_e -# x[..., 1::2, :, :] = x_o - return X +function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) + end -function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} - bch_sz, N, dims_r... = reverse(size(X)) - ns = floor(log2(N)) - stop = ns - l.L - - # decompose - Ud = T[] - Us = T[] - for i in 1:stop - d, X = wavelet_transform(l, X) - push!(Ud, l.A(d)+l.B(d)) - push!(Us, l.C(d)) - end - X = l.T0(X) - - # reconstruct - for i in stop:-1:1 - X += Us[i] - X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) - X = even_odd(l, X) - end - return X -end +# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray) +# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch] +# end diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index d131ad34..31c7af24 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -1,7 +1,12 @@ export - OperatorConv, - SpectralConv, - OperatorKernel + OperatorConv, + SpectralConv, + OperatorKernel, + SparseKernel, + SparseKernel1D, + SparseKernel2D, + SparseKernel3D, + MWT_CZ1d struct OperatorConv{P, T, S, TT} weight::T @@ -180,6 +185,143 @@ function (m::OperatorKernel)(𝐱) return m.σ.(m.linear(𝐱) + m.conv(𝐱)) end +""" + SparseKernel(κ, ch, σ=identity) + +Sparse kernel layer. + +## Arguments + +* `κ`: A neural network layer for approximation, e.g. 
a `Dense` layer or a MLP. +* `ch`: Channel size for linear transform, e.g. `32`. +* `σ`: Activation function. +""" +struct SparseKernel{N,T,S} + conv_blk::T + out_weight::S +end + +function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} + input_dim, emb_dim = ch + conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) +end + +function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k + emb_dim = 128 + return SparseKernel((3, ), input_dim=>emb_dim; init=init) +end + +function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + return SparseKernel((3, 3), input_dim=>emb_dim; init=init) +end + +function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) +end + +Flux.@functor SparseKernel + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) +end + + +struct MWT_CZ1d{T,S,R,Q,P} + k::Int + L::Int + A::T + B::S + C::R + T0::Q + ec_s::P + ec_d::P + rc_e::P + rc_o::P +end + +function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) + H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) + H0r = zero_out!(H0 * Φ0) + G0r = zero_out!(G0 * Φ0) + H1r = zero_out!(H1 * Φ1) + G1r = zero_out!(G1 * Φ1) + + dim = c*k + A = SpectralConv(dim=>dim, (α,); init=init) + B = SpectralConv(dim=>dim, (α,); init=init) + C = SpectralConv(dim=>dim, (α,); init=init) + T0 = Dense(k, k) + + ec_s = vcat(H0', H1') + ec_d = vcat(G0', G1') + rc_e = vcat(H0r, G0r) + rc_o = vcat(H1r, G1r) + return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) +end + +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + N = size(X, 3) + Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) + d = NNlib.batched_mul(Xa, l.ec_d) + s = NNlib.batched_mul(Xa, l.ec_s) + return d, s +end + +function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + bch_sz, N, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + @assert dims[1] == 2*l.k + Y = similar(X, bch_sz, 2N, l.c, l.k) + view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e) + view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o) + return Y +end + +function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} + bch_sz, N, dims_r... 
diff --git a/test/polynomials.jl b/test/Transform/polynomials.jl
similarity index 100%
rename from test/polynomials.jl
rename to test/Transform/polynomials.jl
diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl
index 726727eb..48705bf3 100644
--- a/test/Transform/wavelet_transform.jl
+++ b/test/Transform/wavelet_transform.jl
@@ -1,53 +1,30 @@
-@testset "SparseKernel" begin
+@testset "wavelet transform" begin
+    𝐱 = rand(30, 40, 50, 6, 7)  # where ch == 6 and batch == 7
+
+    wt = WaveletTransform((3, 4, 5))
+
+    @test size(transform(wt, 𝐱)) == (30, 40, 50, 6, 7)
+    @test size(truncate_modes(wt, transform(wt, 𝐱))) == (3, 4, 5, 6, 7)
+    @test size(inverse(wt, truncate_modes(wt, transform(wt, 𝐱)))) == (3, 4, 5, 6, 7)
+end
+
+@testset "MWT_CZ" begin
     T = Float32
     k = 3
     batch_size = 32
 
-    @testset "1D SparseKernel" begin
-        α = 4
-        c = 1
-        in_chs = 20
-        X = rand(T, in_chs, c*k, batch_size)
+    @testset "MWT_CZ1d" begin
+        mwt = MWT_CZ1d()
 
-        l1 = SparseKernel1D(k, α, c)
-        Y = l1(X)
-        @test l1 isa SparseKernel{1}
-        @test size(Y) == size(X)
+        # base functions
+        wavelet_transform(mwt, )
+        even_odd(mwt, )
 
-        gs = gradient(()->sum(l1(X)), Flux.params(l1))
-        @test length(gs.grads) == 4
-    end
+        # forward
+        Y = mwt(X)
 
-    @testset "2D SparseKernel" begin
-        α = 4
-        c = 3
-        Nx = 5
-        Ny = 7
-        X = rand(T, Nx, Ny, c*k^2, batch_size)
-
-        l2 = SparseKernel2D(k, α, c)
-        Y = l2(X)
-        @test l2 isa SparseKernel{2}
-        @test size(Y) == size(X)
-
-        gs = gradient(()->sum(l2(X)), Flux.params(l2))
-        @test length(gs.grads) == 4
+        # backward
+        g = gradient()
     end
 
-    @testset "3D SparseKernel" begin
-        α = 4
-        c = 3
-        Nx = 5
-        Ny = 7
-        Nz = 13
-        X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)
-
-        l3 = SparseKernel3D(k, α, c)
-        Y = l3(X)
-        @test l3 isa SparseKernel{3}
-        @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)
-
-        gs = gradient(()->sum(l3(X)), Flux.params(l3))
-        @test length(gs.grads) == 4
-    end
 end
diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl
index 2d00b4ff..b03e09b2 100644
--- a/test/operator_kernel.jl
+++ b/test/operator_kernel.jl
@@ -160,3 +160,57 @@ end
     data = [(𝐱, rand(Float32, 128, 1024, 5))]
     Flux.train!(loss, Flux.params(m), data, Flux.Adam())
 end
+
+@testset "SparseKernel" begin
+    T = Float32
+    k = 3
+    batch_size = 32
+
+    @testset "1D SparseKernel" begin
+        α = 4
+        c = 1
+        in_chs = 20
+        X = rand(T, in_chs, c*k, batch_size)
+
+        l1 = SparseKernel1D(k, α, c)
+        Y = l1(X)
+        @test l1 isa SparseKernel{1}
+        @test size(Y) == size(X)
+
+        gs = gradient(()->sum(l1(X)), Flux.params(l1))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "2D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        X = rand(T, Nx, Ny, c*k^2, batch_size)
+
+        l2 = SparseKernel2D(k, α, c)
+        Y = l2(X)
+        @test l2 isa SparseKernel{2}
+        @test size(Y) == size(X)
+
+        gs = gradient(()->sum(l2(X)), Flux.params(l2))
+        @test length(gs.grads) == 4
+    end
+
+    @testset "3D SparseKernel" begin
+        α = 4
+        c = 3
+        Nx = 5
+        Ny = 7
+        Nz = 13
+        X = rand(T, Nx, Ny, Nz, α*k^2, 
batch_size) + + l3 = SparseKernel3D(k, α, c) + Y = l3(X) + @test l3 isa SparseKernel{3} + @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + + gs = gradient(()->sum(l3(X)), Flux.params(l3)) + @test length(gs.grads) == 4 + end +end From 47ba0683a817e85dfbb6d9ee8fb61b487a81e4fe Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Mon, 11 Jul 2022 14:08:19 +0800 Subject: [PATCH 10/10] format code --- src/Transform/polynomials.jl | 153 +++++++++++++++------------- src/Transform/utils.jl | 20 ++-- src/Transform/wavelet_transform.jl | 12 +-- src/operator_kernel.jl | 78 +++++++------- test/Transform/polynomials.jl | 142 +++++++++++++++----------- test/Transform/wavelet_transform.jl | 5 +- test/operator_kernel.jl | 16 +-- 7 files changed, 225 insertions(+), 201 deletions(-) diff --git a/src/Transform/polynomials.jl b/src/Transform/polynomials.jl index 93ee257c..480af298 100644 --- a/src/Transform/polynomials.jl +++ b/src/Transform/polynomials.jl @@ -16,47 +16,47 @@ function legendre_ϕ_ψ(k) p = Polynomial([-1, 2]) # 2x-1 p2 = Polynomial([-1, 4]) # 4x-1 - for ki in 0:(k-1) + for ki in 0:(k - 1) l = convert(Polynomial, gen_poly(Legendre, ki)) # Legendre of n=ki - ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2*ki+1) .* coeffs(l(p)) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2*(2*ki+1)) .* coeffs(l(p2)) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * ki + 1) .* coeffs(l(p)) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 * (2 * ki + 1)) .* coeffs(l(p2)) end - + ψ1_coefs = zeros(k, k) ψ2_coefs = zeros(k, k) - for ki in 0:(k-1) - ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] - for i in 0:(k-1) - a = ϕ_2x_coefs[ki+1, 1:(ki+1)] - b = ϕ_coefs[i+1, 1:(i+1)] + for ki in 0:(k - 1) + ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :] + for i in 0:(k - 1) + a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)] + b = ϕ_coefs[i + 1, 1:(i + 1)] proj_ = proj_factor(a, b) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) end - for j in 0:(k-1) - a = ϕ_2x_coefs[ki+1, 1:(ki+1)] - b = ψ1_coefs[j+1, :] + for j in 0:(k - 1) + a = ϕ_2x_coefs[ki + 1, 1:(ki + 1)] + b = ψ1_coefs[j + 1, :] proj_ = proj_factor(a, b) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :) end - a = ψ1_coefs[ki+1, :] + a = ψ1_coefs[ki + 1, :] norm1 = proj_factor(a, a) - a = ψ2_coefs[ki+1, :] - norm2 = proj_factor(a, a, complement=true) + a = ψ2_coefs[ki + 1, :] + norm2 = proj_factor(a, a, complement = true) norm_ = sqrt(norm1 + norm2) - ψ1_coefs[ki+1, :] ./= norm_ - ψ2_coefs[ki+1, :] ./= norm_ + ψ1_coefs[ki + 1, :] ./= norm_ + ψ2_coefs[ki + 1, :] ./= norm_ zero_out!(ψ1_coefs) zero_out!(ψ2_coefs) end - ϕ = [Polynomial(ϕ_coefs[i,:]) for i in 1:k] - ψ1 = [Polynomial(ψ1_coefs[i,:]) for i in 1:k] - ψ2 = [Polynomial(ψ2_coefs[i,:]) for i in 1:k] + ϕ = [Polynomial(ϕ_coefs[i, :]) for i in 1:k] + ψ1 = [Polynomial(ψ1_coefs[i, :]) for i in 1:k] + ψ2 = [Polynomial(ψ2_coefs[i, :]) for i in 1:k] return ϕ, ψ1, ψ2 end @@ -68,14 +68,14 @@ function chebyshev_ϕ_ψ(k) p = Polynomial([-1, 2]) # 2x-1 p2 = Polynomial([-1, 4]) # 4x-1 - for ki in 0:(k-1) + for ki in 0:(k - 1) if ki == 0 - ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2 / π) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(4 / π) else c = convert(Polynomial, 
gen_poly(Chebyshev, ki)) # Chebyshev of n=ki - ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) - ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) + ϕ_coefs[ki + 1, 1:(ki + 1)] .= 2 / sqrt(π) .* coeffs(c(p)) + ϕ_2x_coefs[ki + 1, 1:(ki + 1)] .= sqrt(2) * 2 / sqrt(π) .* coeffs(c(p2)) end end @@ -87,43 +87,43 @@ function chebyshev_ϕ_ψ(k) # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) # not needed for our purpose here, we use even k always to avoid wm = π / k_use / 2 - + ψ1_coefs = zeros(k, k) ψ2_coefs = zeros(k, k) - ψ1 = Array{Any,1}(undef, k) - ψ2 = Array{Any,1}(undef, k) + ψ1 = Array{Any, 1}(undef, k) + ψ2 = Array{Any, 1}(undef, k) - for ki in 0:(k-1) - ψ1_coefs[ki+1, :] .= ϕ_2x_coefs[ki+1, :] - for i in 0:(k-1) - proj_ = sum(wm .* ϕ[i+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ϕ_coefs, i+1, :) + for ki in 0:(k - 1) + ψ1_coefs[ki + 1, :] .= ϕ_2x_coefs[ki + 1, :] + for i in 0:(k - 1) + proj_ = sum(wm .* ϕ[i + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m)) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ϕ_coefs, i + 1, :) end - for j in 0:(ki-1) - proj_ = sum(wm .* ψ1[j+1].(x_m) .* sqrt(2) .* ϕ[ki+1].(2*x_m)) - ψ1_coefs[ki+1, :] .-= proj_ .* view(ψ1_coefs, j+1, :) - ψ2_coefs[ki+1, :] .-= proj_ .* view(ψ2_coefs, j+1, :) + for j in 0:(ki - 1) + proj_ = sum(wm .* ψ1[j + 1].(x_m) .* sqrt(2) .* ϕ[ki + 1].(2 * x_m)) + ψ1_coefs[ki + 1, :] .-= proj_ .* view(ψ1_coefs, j + 1, :) + ψ2_coefs[ki + 1, :] .-= proj_ .* view(ψ2_coefs, j + 1, :) end - ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5) - ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5, ub=1.) - - norm1 = sum(wm .* ψ1[ki+1].(x_m) .* ψ1[ki+1].(x_m)) - norm2 = sum(wm .* ψ2[ki+1].(x_m) .* ψ2[ki+1].(x_m)) - + ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5) + ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5, ub = 1.0) + + norm1 = sum(wm .* ψ1[ki + 1].(x_m) .* ψ1[ki + 1].(x_m)) + norm2 = sum(wm .* ψ2[ki + 1].(x_m) .* ψ2[ki + 1].(x_m)) + norm_ = sqrt(norm1 + norm2) - ψ1_coefs[ki+1, :] ./= norm_ - ψ2_coefs[ki+1, :] ./= norm_ + ψ1_coefs[ki + 1, :] ./= norm_ + ψ2_coefs[ki + 1, :] ./= norm_ zero_out!(ψ1_coefs) zero_out!(ψ2_coefs) - - ψ1[ki+1] = ϕ_(ψ1_coefs[ki+1,:]; lb=0., ub=0.5+1e-16) - ψ2[ki+1] = ϕ_(ψ2_coefs[ki+1,:]; lb=0.5+1e-16, ub=1.) 
+ + ψ1[ki + 1] = ϕ_(ψ1_coefs[ki + 1, :]; lb = 0.0, ub = 0.5 + 1e-16) + ψ2[ki + 1] = ϕ_(ψ2_coefs[ki + 1, :]; lb = 0.5 + 1e-16, ub = 1.0) end - + return ϕ, ψ1, ψ2 end @@ -137,14 +137,18 @@ function legendre_filter(k) l = convert(Polynomial, gen_poly(Legendre, k)) x_m = roots(l(Polynomial([-1, 2]))) # 2x-1 m = 2 .* x_m .- 1 - wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k-1).(m) - - for ki in 0:(k-1) - for kpi in 0:(k-1) - H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) - G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) - H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) + wm = 1 ./ k ./ legendre_der.(k, m) ./ gen_poly(Legendre, k - 1).(m) + + for ki in 0:(k - 1) + for kpi in 0:(k - 1) + H0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m)) + G0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m)) + H1[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m)) + G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* + ϕ[kpi + 1].(x_m)) end end @@ -152,7 +156,7 @@ function legendre_filter(k) zero_out!(H1) zero_out!(G0) zero_out!(G1) - + return H0, H1, G0, G1, I(k), I(k) end @@ -172,14 +176,19 @@ function chebyshev_filter(k) # not needed for our purpose here, we use even k always to avoid wm = π / k_use / 2 - for ki in 0:(k-1) - for kpi in 0:(k-1) - H0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].(x_m/2) .* ϕ[kpi+1].(x_m)) - H1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ϕ[ki+1].((x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - G0[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, x_m/2) .* ϕ[kpi+1].(x_m)) - G1[ki+1, kpi+1] = 1/sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m.+1)/2) .* ϕ[kpi+1].(x_m)) - Φ0[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2x_m) .* ϕ[kpi+1].(2x_m)) - Φ1[ki+1, kpi+1] = 2*sum(wm .* ϕ[ki+1].(2 .* x_m .- 1) .* ϕ[kpi+1].(2 .* x_m .- 1)) + for ki in 0:(k - 1) + for kpi in 0:(k - 1) + H0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].(x_m / 2) .* ϕ[kpi + 1].(x_m)) + H1[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ϕ[ki + 1].((x_m .+ 1) / 2) .* ϕ[kpi + 1].(x_m)) + G0[ki + 1, kpi + 1] = 1 / sqrt(2) * + sum(wm .* ψ(ψ1, ψ2, ki, x_m / 2) .* ϕ[kpi + 1].(x_m)) + G1[ki + 1, kpi + 1] = 1 / sqrt(2) * sum(wm .* ψ(ψ1, ψ2, ki, (x_m .+ 1) / 2) .* + ϕ[kpi + 1].(x_m)) + Φ0[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2x_m) .* ϕ[kpi + 1].(2x_m)) + Φ1[ki + 1, kpi + 1] = 2 * sum(wm .* ϕ[ki + 1].(2 .* x_m .- 1) .* + ϕ[kpi + 1].(2 .* x_m .- 1)) end end @@ -189,6 +198,6 @@ function chebyshev_filter(k) zero_out!(G1) zero_out!(Φ0) zero_out!(Φ1) - + return H0, H1, G0, G1, Φ0, Φ1 end diff --git a/src/Transform/utils.jl b/src/Transform/utils.jl index 994fa75d..a76c13b6 100644 --- a/src/Transform/utils.jl +++ b/src/Transform/utils.jl @@ -1,6 +1,6 @@ -function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) +function ϕ_(ϕ_coefs; lb::Real = 0.0, ub::Real = 1.0) function partial(x) - mask = (lb ≤ x ≤ ub) * 1. 
+        mask = (lb ≤ x ≤ ub) * 1.0
         return Polynomial(ϕ_coefs)(x) * mask
     end
     return partial
@@ -8,27 +8,27 @@ end
 
 function ψ(ψ1, ψ2, i, inp)
     mask = (inp .> 0.5) .* 1.0
-    return ψ1[i+1].(inp) .* mask .+ ψ2[i+1].(inp) .* mask
+    return ψ1[i + 1].(inp) .* mask .+ ψ2[i + 1].(inp) .* mask
 end
 
-zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0)
+zero_out!(x; tol = 1e-8) = (x[abs.(x) .< tol] .= 0)
 
 function gen_poly(poly, n)
-    x = zeros(n+1)
+    x = zeros(n + 1)
     x[end] = 1
     return poly(x)
 end
 
 function convolve(a, b)
     n = length(b)
-    y = similar(a, length(a)+n-1)
+    y = similar(a, length(a) + n - 1)
     for i in 1:length(a)
-        y[i:(i+n-1)] .+= a[i] .* b
+        y[i:(i + n - 1)] .+= a[i] .* b
     end
     return y
 end
 
-function proj_factor(a, b; complement::Bool=false)
+function proj_factor(a, b; complement::Bool = false)
     prod_ = convolve(a, b)
     r = collect(1:length(prod_))
     s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r)
@@ -36,11 +36,11 @@ function proj_factor(a, b; complement::Bool = false)
     return proj_
 end
 
-_legendre(k, x) = (2k+1) * gen_poly(Legendre, k)(x)
+_legendre(k, x) = (2k + 1) * gen_poly(Legendre, k)(x)
 
 function legendre_der(k, x)
     out = 0
-    for i in k-1:-2:-1
+    for i in (k - 1):-2:-1
         out += _legendre(i, x)
     end
     return out
diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl
index 41d54edb..d58970bb 100644
--- a/src/Transform/wavelet_transform.jl
+++ b/src/Transform/wavelet_transform.jl
@@ -1,15 +1,15 @@
 export WaveletTransform
 
-struct WaveletTransform{N, S}<:AbstractTransform
-    ec_d
-    ec_s
+struct WaveletTransform{N, S} <: AbstractTransform
+    ec_d::Any
+    ec_s::Any
     modes::NTuple{N, S}  # N == ndims(x)
 end
 
 Base.ndims(::WaveletTransform{N}) where {N} = N
 
 function transform(wt::WaveletTransform, 𝐱::AbstractArray)
-    N = size(𝐱, ndims(wt)-1)
+    N = size(𝐱, ndims(wt) - 1)
     # 1d
     Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :))
     # 2d
@@ -24,9 +24,7 @@ function transform(wt::WaveletTransform, 𝐱::AbstractArray)
     return d, s
 end
 
-function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray)
-
-end
+function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) end
 
 # function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
 #     return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :)  # [ft.modes..., in_chs, batch]
diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl
index 31c7af24..2eab3647 100644
--- a/src/operator_kernel.jl
+++ b/src/operator_kernel.jl
@@ -1,12 +1,12 @@
 export
-    OperatorConv,
-    SpectralConv,
-    OperatorKernel,
-    SparseKernel,
-    SparseKernel1D,
-    SparseKernel2D,
-    SparseKernel3D,
-    MWT_CZ1d
+       OperatorConv,
+       SpectralConv,
+       OperatorKernel,
+       SparseKernel,
+       SparseKernel1D,
+       SparseKernel2D,
+       SparseKernel3D,
+       MWT_CZ1d
 
 struct OperatorConv{P, T, S, TT}
     weight::T
@@ -196,36 +196,37 @@ Sparse kernel layer.
 * `filter`: Filter size of the convolutional layer, e.g. `(3,)`.
 * `ch`: A `Pair` of input channel size to embedding channel size, e.g. `32=>128`.
""" -struct SparseKernel{N,T,S} +struct SparseKernel{N, T, S} conv_blk::T out_weight::S end -function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} +function SparseKernel(filter::NTuple{N, T}, ch::Pair{S, S}; + init = Flux.glorot_uniform) where {N, T, S} input_dim, emb_dim = ch - conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) + conv = Conv(filter, input_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{N, typeof(conv), typeof(W_out)}(conv, W_out) end -function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k +function SparseKernel1D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k emb_dim = 128 - return SparseKernel((3, ), input_dim=>emb_dim; init=init) + return SparseKernel((3,), input_dim => emb_dim; init = init) end -function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - return SparseKernel((3, 3), input_dim=>emb_dim; init=init) +function SparseKernel2D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + return SparseKernel((3, 3), input_dim => emb_dim; init = init) end -function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) +function SparseKernel3D(k::Int, α, c::Int = 1; init = Flux.glorot_uniform) + input_dim = c * k^2 + emb_dim = α * k^2 + conv = Conv((3, 3, 3), emb_dim => emb_dim, relu; stride = 1, pad = 1, init = init) + W_out = Dense(emb_dim, input_dim; init = init) + return SparseKernel{3, typeof(conv), typeof(W_out)}(conv, W_out) end Flux.@functor SparseKernel @@ -241,8 +242,7 @@ function (l::SparseKernel)(X::AbstractArray) return collect(Y) end - -struct MWT_CZ1d{T,S,R,Q,P} +struct MWT_CZ1d{T, S, R, Q, P} k::Int L::Int A::T @@ -255,17 +255,18 @@ struct MWT_CZ1d{T,S,R,Q,P} rc_o::P end -function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) +function MWT_CZ1d(k::Int = 3, α::Int = 5, L::Int = 0, c::Int = 1; base::Symbol = :legendre, + init = Flux.glorot_uniform) H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) H0r = zero_out!(H0 * Φ0) G0r = zero_out!(G0 * Φ0) H1r = zero_out!(H1 * Φ1) G1r = zero_out!(G1 * Φ1) - dim = c*k - A = SpectralConv(dim=>dim, (α,); init=init) - B = SpectralConv(dim=>dim, (α,); init=init) - C = SpectralConv(dim=>dim, (α,); init=init) + dim = c * k + A = SpectralConv(dim => dim, (α,); init = init) + B = SpectralConv(dim => dim, (α,); init = init) + C = SpectralConv(dim => dim, (α,); init = init) T0 = Dense(k, k) ec_s = vcat(H0', H1') @@ -275,7 +276,7 @@ function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendr return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) end -function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} N = size(X, 3) Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) d = NNlib.batched_mul(Xa, l.ec_d) @@ -283,17 +284,17 @@ function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} return d, s end -function 
even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} +function even_odd(l::MWT_CZ1d, X::AbstractArray{T, 4}) where {T} bch_sz, N, dims_r... = reverse(size(X)) dims = reverse(dims_r) - @assert dims[1] == 2*l.k + @assert dims[1] == 2 * l.k Y = similar(X, bch_sz, 2N, l.c, l.k) view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e) view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o) return Y end -function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} +function (l::MWT_CZ1d)(X::T) where {T <: AbstractArray} bch_sz, N, dims_r... = reverse(size(X)) ns = floor(log2(N)) stop = ns - l.L @@ -303,7 +304,7 @@ function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} Us = T[] for i in 1:stop d, X = wavelet_transform(l, X) - push!(Ud, l.A(d)+l.B(d)) + push!(Ud, l.A(d) + l.B(d)) push!(Us, l.C(d)) end X = l.T0(X) @@ -321,7 +322,6 @@ end # print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)") # end - ######### # utils # ######### diff --git a/test/Transform/polynomials.jl b/test/Transform/polynomials.jl index fdb53df7..2fb5a837 100644 --- a/test/Transform/polynomials.jl +++ b/test/Transform/polynomials.jl @@ -2,26 +2,44 @@ @testset "legendre_ϕ_ψ" begin ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(10) - @test all(coeffs(ϕ[1]) .≈ [1.]) + @test all(coeffs(ϕ[1]) .≈ [1.0]) @test all(coeffs(ϕ[2]) .≈ [-1.7320508075688772, 3.4641016151377544]) - @test all(coeffs(ϕ[3]) .≈ [2.23606797749979, -13.416407864998739, 13.416407864998739]) - @test all(coeffs(ϕ[4]) .≈ [-2.6457513110645907, 31.74901573277509, -79.37253933193772, 52.91502622129181]) + @test all(coeffs(ϕ[3]) .≈ + [2.23606797749979, -13.416407864998739, 13.416407864998739]) + @test all(coeffs(ϕ[4]) .≈ [ + -2.6457513110645907, + 31.74901573277509, + -79.37253933193772, + 52.91502622129181, + ]) @test all(coeffs(ϕ[5]) .≈ [3.0, -60.0, 270.0, -420.0, 210.0]) - @test all(coeffs(ϕ[6]) .≈ [-3.3166247903554, 99.498743710662, -696.491205974634, 1857.309882599024, - -2089.4736179239017, 835.7894471695607]) - @test all(coeffs(ϕ[7]) .≈ [3.605551275463989, -151.43315356948753, 1514.3315356948754, -6057.326142779501, - 11357.486517711566, -9994.588135586178, 3331.529378528726]) - @test all(coeffs(ϕ[8]) .≈ [-3.872983346207417, 216.88706738761536, -2927.9754097328073, 16266.530054071152, - -44732.957648695665, 64415.45901412176, -46522.27595464349, 13292.078844183856]) - @test all(coeffs(ϕ[9]) .≈ [4.123105625617661, -296.86360504447157, 5195.113088278253, -38097.49598070719, - 142865.60992765194, -297160.46864951606, 346687.21342443535, -212257.47760679715, - 53064.36940169929]) - @test all(coeffs(ϕ[10]) .≈ [-4.358898943540674, 392.30090491866065, -8630.619908210534, 80552.45247663166, - -392693.20582357934, 1099540.9763060221, -1832568.2938433702, 1795168.9409077913, - -953683.4998572641, 211929.66663494756]) - + @test all(coeffs(ϕ[6]) .≈ + [-3.3166247903554, 99.498743710662, -696.491205974634, 1857.309882599024, + -2089.4736179239017, 835.7894471695607]) + @test all(coeffs(ϕ[7]) .≈ + [3.605551275463989, -151.43315356948753, 1514.3315356948754, + -6057.326142779501, + 11357.486517711566, -9994.588135586178, 3331.529378528726]) + @test all(coeffs(ϕ[8]) .≈ + [-3.872983346207417, 216.88706738761536, -2927.9754097328073, + 16266.530054071152, + -44732.957648695665, 64415.45901412176, -46522.27595464349, + 13292.078844183856]) + @test all(coeffs(ϕ[9]) .≈ + [4.123105625617661, -296.86360504447157, 5195.113088278253, + -38097.49598070719, + 142865.60992765194, -297160.46864951606, 346687.21342443535, + 
-212257.47760679715, + 53064.36940169929]) + @test all(coeffs(ϕ[10]) .≈ + [-4.358898943540674, 392.30090491866065, -8630.619908210534, + 80552.45247663166, + -392693.20582357934, 1099540.9763060221, -1832568.2938433702, + 1795168.9409077913, + -953683.4998572641, 211929.66663494756]) + ϕ, ψ1, ψ2 = NeuralOperators.legendre_ϕ_ψ(3) - @test coeffs(ϕ[1]) ≈ [1.] + @test coeffs(ϕ[1]) ≈ [1.0] @test coeffs(ϕ[2]) ≈ [-1.7320508075688772, 3.4641016151377544] @test coeffs(ϕ[3]) ≈ [2.23606797749979, -13.416407864998739, 13.416407864998739] @test coeffs(ψ1[1]) ≈ [-1.0000000000000122, 6.000000000000073] @@ -36,50 +54,50 @@ ϕ, ψ1, ψ2 = NeuralOperators.chebyshev_ϕ_ψ(3) @test ϕ[1](0) ≈ 0.7978845608028654 @test ϕ[1](1) ≈ 0.7978845608028654 - @test ϕ[1](2) ≈ 0. + @test ϕ[1](2) ≈ 0.0 @test ϕ[2](0) ≈ -1.1283791670955126 @test ϕ[2](1) ≈ 1.1283791670955126 - @test ϕ[2](2) ≈ 0. + @test ϕ[2](2) ≈ 0.0 @test ϕ[3](0) ≈ 1.1283791670955126 @test ϕ[3](1) ≈ 1.1283791670955126 - @test ϕ[3](2) ≈ 0. + @test ϕ[3](2) ≈ 0.0 @test ψ1[1](0) ≈ -0.5560622352843183 - @test ψ1[1](1) ≈ 0. - @test ψ1[1](2) ≈ 0. + @test ψ1[1](1) ≈ 0.0 + @test ψ1[1](2) ≈ 0.0 @test ψ1[2](0) ≈ 0.932609257876051 - @test ψ1[2](1) ≈ 0. - @test ψ1[2](2) ≈ 0. + @test ψ1[2](1) ≈ 0.0 + @test ψ1[2](2) ≈ 0.0 @test ψ1[3](0) ≈ 1.0941547380212637 - @test ψ1[3](1) ≈ 0. - @test ψ1[3](2) ≈ 0. + @test ψ1[3](1) ≈ 0.0 + @test ψ1[3](2) ≈ 0.0 - @test ψ2[1](0) ≈ -0. + @test ψ2[1](0) ≈ -0.0 @test ψ2[1](1) ≈ 0.5560622352843181 - @test ψ2[1](2) ≈ 0. - @test ψ2[2](0) ≈ 0. + @test ψ2[1](2) ≈ 0.0 + @test ψ2[2](0) ≈ 0.0 @test ψ2[2](1) ≈ 0.9326092578760665 - @test ψ2[2](2) ≈ 0. - @test ψ2[3](0) ≈ 0. + @test ψ2[2](2) ≈ 0.0 + @test ψ2[3](0) ≈ 0.0 @test ψ2[3](1) ≈ -1.0941547380212384 - @test ψ2[3](2) ≈ 0. + @test ψ2[3](2) ≈ 0.0 end @testset "legendre_filter" begin H0, H1, G0, G1, Φ1, Φ2 = NeuralOperators.legendre_filter(3) - @test H0 ≈ [0.70710678 0. 0. ; - -0.61237244 0.35355339 0. ; - 0. -0.6846532 0.1767767] - @test H1 ≈ [0.70710678 0. 0. ; - 0.61237244 0.35355339 0. ; - 0. 0.6846532 0.1767767] - @test G0 ≈ [0.35355339 0.61237244 0. ; - 0. 0.1767767 0.6846532 ; - 0. 0. 0.70710678] - @test G1 ≈ [-0.35355339 0.61237244 0. ; - 0. -0.1767767 0.6846532 ; - 0. 0. -0.70710678] + @test H0 ≈ [0.70710678 0.0 0.0; + -0.61237244 0.35355339 0.0; + 0.0 -0.6846532 0.1767767] + @test H1 ≈ [0.70710678 0.0 0.0; + 0.61237244 0.35355339 0.0; + 0.0 0.6846532 0.1767767] + @test G0 ≈ [0.35355339 0.61237244 0.0; + 0.0 0.1767767 0.6846532; + 0.0 0.0 0.70710678] + @test G1 ≈ [-0.35355339 0.61237244 0.0; + 0.0 -0.1767767 0.6846532; + 0.0 0.0 -0.70710678] @test Φ1 == I(3) @test Φ2 == I(3) end @@ -87,23 +105,23 @@ @testset "chebyshev_filter" begin H0, H1, G0, G1, Φ0, Φ1 = NeuralOperators.chebyshev_filter(3) - @test H0 ≈ [0.70710678 0. 0. ; - -0.5 0.35355339 0. ; - -0.25 -0.70710678 0.1767767] - @test H1 ≈ [0.70710678 0. 0. ; - 0.5 0.35355339 0. ; - -0.25 0.70710678 0.1767767] - @test G0 ≈ [0.60944614 0.77940383 0. ; - 0.66325172 1.02726613 1.14270252; - 0.61723435 0.90708619 1.1562954 ] - @test G1 ≈ [-0.60944614 0.77940383 0. ; - 0.66325172 -1.02726613 1.14270252; - -0.61723435 0.90708619 -1.1562954 ] - @test Φ0 ≈ [1. -0.40715364 -0.21440101; - -0.40715364 0.84839559 -0.44820615; - -0.21440101 -0.44820615 0.84002127] - @test Φ1 ≈ [1. 
0.40715364 -0.21440101; - 0.40715364 0.84839559 0.44820615; - -0.21440101 0.44820615 0.84002127] + @test H0 ≈ [0.70710678 0.0 0.0; + -0.5 0.35355339 0.0; + -0.25 -0.70710678 0.1767767] + @test H1 ≈ [0.70710678 0.0 0.0; + 0.5 0.35355339 0.0; + -0.25 0.70710678 0.1767767] + @test G0 ≈ [0.60944614 0.77940383 0.0; + 0.66325172 1.02726613 1.14270252; + 0.61723435 0.90708619 1.1562954] + @test G1 ≈ [-0.60944614 0.77940383 0.0; + 0.66325172 -1.02726613 1.14270252; + -0.61723435 0.90708619 -1.1562954] + @test Φ0 ≈ [1.0 -0.40715364 -0.21440101; + -0.40715364 0.84839559 -0.44820615; + -0.21440101 -0.44820615 0.84002127] + @test Φ1 ≈ [1.0 0.40715364 -0.21440101; + 0.40715364 0.84839559 0.44820615; + -0.21440101 0.44820615 0.84002127] end end diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl index 48705bf3..16211220 100644 --- a/test/Transform/wavelet_transform.jl +++ b/test/Transform/wavelet_transform.jl @@ -17,8 +17,8 @@ end mwt = MWT_CZ1d() # base functions - wavelet_transform(mwt, ) - even_odd(mwt, ) + wavelet_transform(mwt) + even_odd(mwt) # forward Y = mwt(X) @@ -26,5 +26,4 @@ end # backward g = gradient() end - end diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index b03e09b2..e210a478 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -170,14 +170,14 @@ end α = 4 c = 1 in_chs = 20 - X = rand(T, in_chs, c*k, batch_size) + X = rand(T, in_chs, c * k, batch_size) l1 = SparseKernel1D(k, α, c) Y = l1(X) @test l1 isa SparseKernel{1} @test size(Y) == size(X) - gs = gradient(()->sum(l1(X)), Flux.params(l1)) + gs = gradient(() -> sum(l1(X)), Flux.params(l1)) @test length(gs.grads) == 4 end @@ -186,14 +186,14 @@ end c = 3 Nx = 5 Ny = 7 - X = rand(T, Nx, Ny, c*k^2, batch_size) - + X = rand(T, Nx, Ny, c * k^2, batch_size) + l2 = SparseKernel2D(k, α, c) Y = l2(X) @test l2 isa SparseKernel{2} @test size(Y) == size(X) - gs = gradient(()->sum(l2(X)), Flux.params(l2)) + gs = gradient(() -> sum(l2(X)), Flux.params(l2)) @test length(gs.grads) == 4 end @@ -203,14 +203,14 @@ end Nx = 5 Ny = 7 Nz = 13 - X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) + X = rand(T, Nx, Ny, Nz, α * k^2, batch_size) l3 = SparseKernel3D(k, α, c) Y = l3(X) @test l3 isa SparseKernel{3} - @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + @test size(Y) == (Nx, Ny, Nz, c * k^2, batch_size) - gs = gradient(()->sum(l3(X)), Flux.params(l3)) + gs = gradient(() -> sum(l3(X)), Flux.params(l3)) @test length(gs.grads) == 4 end end
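One property these fixtures rely on but never state outright: the stacked analysis banks that `MWT_CZ1d` builds from the Legendre filters have orthonormal columns, which is consistent with the decompose/reconstruct round trip being lossless. A quick check against the values pinned above (note that `legendre_filter` is an internal, unexported function, so the qualified name may change):

```julia
using LinearAlgebra
using NeuralOperators

k = 3
H0, H1, G0, G1, _, _ = NeuralOperators.legendre_filter(k)

ec_s = vcat(H0', H1')   # smooth/average analysis bank, as assembled in MWT_CZ1d
ec_d = vcat(G0', G1')   # detail/difference analysis bank

@assert ec_s' * ec_s ≈ I(k)   # columns are orthonormal
@assert ec_d' * ec_d ≈ I(k)
```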