diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 26749c73..5c5d4dd1 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -1,130 +1,19 @@ -export - SparseKernel, - SparseKernel1D, - SparseKernel2D, - SparseKernel3D +export WaveletTransform - -struct SparseKernel{N,T,S} - conv_blk::T - out_weight::S -end - -function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} - input_dim, emb_dim = ch - conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) -end - -function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k - emb_dim = 128 - return SparseKernel((3, ), input_dim=>emb_dim; init=init) -end - -function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - return SparseKernel((3, 3), input_dim=>emb_dim; init=init) -end - -function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) -end - -Flux.@functor SparseKernel - -function (l::SparseKernel)(X::AbstractArray) - bch_sz, _, dims_r... = reverse(size(X)) - dims = reverse(dims_r) - - X_ = l.conv_blk(X) # (dims..., emb_dims, B) - X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) - Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) - Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) - return collect(Y) -end - - -struct MWT_CZ1d{T,S,R,Q,P} - k::Int - L::Int - A::T - B::S - C::R - T0::Q - ec_s::P - ec_d::P - rc_e::P - rc_o::P -end - -function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) - H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) - H0r = zero_out!(H0 * Φ0) - G0r = zero_out!(G0 * Φ0) - H1r = zero_out!(H1 * Φ1) - G1r = zero_out!(G1 * Φ1) - - dim = c*k - A = SpectralConv(dim=>dim, (α,); init=init) - B = SpectralConv(dim=>dim, (α,); init=init) - C = SpectralConv(dim=>dim, (α,); init=init) - T0 = Dense(k, k) - - ec_s = vcat(H0', H1') - ec_d = vcat(G0', G1') - rc_e = vcat(H0r, G0r) - rc_o = vcat(H1r, G1r) - return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) -end - -function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - N = size(X, 3) - Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) - d = NNlib.batched_mul(Xa, l.ec_d) - s = NNlib.batched_mul(Xa, l.ec_s) - return d, s +struct WaveletTransform{N, S}<:AbstractTransform + modes::NTuple{N, S} # N == ndims(x) end -function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - bch_sz, N, dims_r... = reverse(size(X)) - dims = reverse(dims_r) - @assert dims[1] == 2*l.k - Xₑ = NNlib.batched_mul(X, l.rc_e) - Xₒ = NNlib.batched_mul(X, l.rc_o) -# x = torch.zeros(B, N*2, c, self.k, -# device = x.device) -# x[..., ::2, :, :] = x_e -# x[..., 1::2, :, :] = x_o - return X -end +Base.ndims(::WaveletTransform{N}) where {N} = N -function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} - bch_sz, N, dims_r... = reverse(size(X)) - ns = floor(log2(N)) - stop = ns - l.L +# function transform(wt::WaveletTransform, 𝐱::AbstractArray) +# return fft(Zygote.hook(real, 𝐱), 1:ndims(wt)) # [size(x)..., in_chs, batch] +# end - # decompose - Ud = T[] - Us = T[] - for i in 1:stop - d, X = wavelet_transform(l, X) - push!(Ud, l.A(d)+l.B(d)) - push!(Us, l.C(d)) - end - X = l.T0(X) +# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray) +# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch] +# end - # reconstruct - for i in stop:-1:1 - X += Us[i] - X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) - X = even_odd(l, X) - end - return X -end +# function inverse(wt::WaveletTransform, 𝐱_fft::AbstractArray) +# return real(ifft(𝐱_fft, 1:ndims(wt))) # [size(x_fft)..., out_chs, batch] +# end diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index 9f332da1..8325bacf 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -2,7 +2,11 @@ export OperatorConv, SpectralConv, OperatorKernel, - GraphKernel + GraphKernel, + SparseKernel, + SparseKernel1D, + SparseKernel2D, + SparseKernel3D struct OperatorConv{P, T, S, TT} weight::T @@ -216,6 +220,141 @@ function Base.show(io::IO, l::GraphKernel) print(io, ")") end +""" + SparseKernel(κ, ch, σ=identity) + +Sparse kernel layer. + +## Arguments + +* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP. +* `ch`: Channel size for linear transform, e.g. `32`. +* `σ`: Activation function. +""" +struct SparseKernel{N,T,S} + conv_blk::T + out_weight::S +end + +function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} + input_dim, emb_dim = ch + conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) +end + +function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k + emb_dim = 128 + return SparseKernel((3, ), input_dim=>emb_dim; init=init) +end + +function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + return SparseKernel((3, 3), input_dim=>emb_dim; init=init) +end + +function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) +end + +Flux.@functor SparseKernel + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) +end + + +struct MWT_CZ1d{T,S,R,Q,P} + k::Int + L::Int + A::T + B::S + C::R + T0::Q + ec_s::P + ec_d::P + rc_e::P + rc_o::P +end + +function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) + H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) + H0r = zero_out!(H0 * Φ0) + G0r = zero_out!(G0 * Φ0) + H1r = zero_out!(H1 * Φ1) + G1r = zero_out!(G1 * Φ1) + + dim = c*k + A = SpectralConv(dim=>dim, (α,); init=init) + B = SpectralConv(dim=>dim, (α,); init=init) + C = SpectralConv(dim=>dim, (α,); init=init) + T0 = Dense(k, k) + + ec_s = vcat(H0', H1') + ec_d = vcat(G0', G1') + rc_e = vcat(H0r, G0r) + rc_o = vcat(H1r, G1r) + return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) +end + +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + N = size(X, 3) + Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) + d = NNlib.batched_mul(Xa, l.ec_d) + s = NNlib.batched_mul(Xa, l.ec_s) + return d, s +end + +function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + bch_sz, N, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + @assert dims[1] == 2*l.k + Xₑ = NNlib.batched_mul(X, l.rc_e) + Xₒ = NNlib.batched_mul(X, l.rc_o) +# x = torch.zeros(B, N*2, c, self.k, +# device = x.device) +# x[..., ::2, :, :] = x_e +# x[..., 1::2, :, :] = x_o + return X +end + +function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} + bch_sz, N, dims_r... = reverse(size(X)) + ns = floor(log2(N)) + stop = ns - l.L + + # decompose + Ud = T[] + Us = T[] + for i in 1:stop + d, X = wavelet_transform(l, X) + push!(Ud, l.A(d)+l.B(d)) + push!(Us, l.C(d)) + end + X = l.T0(X) + + # reconstruct + for i in stop:-1:1 + X += Us[i] + X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) + X = even_odd(l, X) + end + return X +end + ######### # utils # diff --git a/test/polynomials.jl b/test/Transform/polynomials.jl similarity index 100% rename from test/polynomials.jl rename to test/Transform/polynomials.jl diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl index 726727eb..e69de29b 100644 --- a/test/Transform/wavelet_transform.jl +++ b/test/Transform/wavelet_transform.jl @@ -1,53 +0,0 @@ -@testset "SparseKernel" begin - T = Float32 - k = 3 - batch_size = 32 - - @testset "1D SparseKernel" begin - α = 4 - c = 1 - in_chs = 20 - X = rand(T, in_chs, c*k, batch_size) - - l1 = SparseKernel1D(k, α, c) - Y = l1(X) - @test l1 isa SparseKernel{1} - @test size(Y) == size(X) - - gs = gradient(()->sum(l1(X)), Flux.params(l1)) - @test length(gs.grads) == 4 - end - - @testset "2D SparseKernel" begin - α = 4 - c = 3 - Nx = 5 - Ny = 7 - X = rand(T, Nx, Ny, c*k^2, batch_size) - - l2 = SparseKernel2D(k, α, c) - Y = l2(X) - @test l2 isa SparseKernel{2} - @test size(Y) == size(X) - - gs = gradient(()->sum(l2(X)), Flux.params(l2)) - @test length(gs.grads) == 4 - end - - @testset "3D SparseKernel" begin - α = 4 - c = 3 - Nx = 5 - Ny = 7 - Nz = 13 - X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) - - l3 = SparseKernel3D(k, α, c) - Y = l3(X) - @test l3 isa SparseKernel{3} - @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) - - gs = gradient(()->sum(l3(X)), Flux.params(l3)) - @test length(gs.grads) == 4 - end -end diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index e6c9d186..730468ae 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -171,3 +171,57 @@ end g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l)) @test length(g.grads) == 3 end + +@testset "SparseKernel" begin + T = Float32 + k = 3 + batch_size = 32 + + @testset "1D SparseKernel" begin + α = 4 + c = 1 + in_chs = 20 + X = rand(T, in_chs, c*k, batch_size) + + l1 = SparseKernel1D(k, α, c) + Y = l1(X) + @test l1 isa SparseKernel{1} + @test size(Y) == size(X) + + gs = gradient(()->sum(l1(X)), Flux.params(l1)) + @test length(gs.grads) == 4 + end + + @testset "2D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + X = rand(T, Nx, Ny, c*k^2, batch_size) + + l2 = SparseKernel2D(k, α, c) + Y = l2(X) + @test l2 isa SparseKernel{2} + @test size(Y) == size(X) + + gs = gradient(()->sum(l2(X)), Flux.params(l2)) + @test length(gs.grads) == 4 + end + + @testset "3D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + Nz = 13 + X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) + + l3 = SparseKernel3D(k, α, c) + Y = l3(X) + @test l3 isa SparseKernel{3} + @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + + gs = gradient(()->sum(l3(X)), Flux.params(l3)) + @test length(gs.grads) == 4 + end +end