
Commit 827ce6f: migrate
yuehhua committed May 21, 2022 (1 parent: d259ce3)
Showing 5 changed files with 207 additions and 178 deletions.
137 changes: 13 additions & 124 deletions src/Transform/wavelet_transform.jl
@@ -1,130 +1,19 @@
export
    SparseKernel,
    SparseKernel1D,
    SparseKernel2D,
    SparseKernel3D
export WaveletTransform


struct SparseKernel{N,T,S}
    conv_blk::T
    out_weight::S
end

function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
    input_dim, emb_dim = ch
    conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
end

function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k
    emb_dim = 128
    return SparseKernel((3, ), input_dim=>emb_dim; init=init)
end

function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k^2
    emb_dim = α*k^2
    return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
end

function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k^2
    emb_dim = α*k^2
    conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
end

Flux.@functor SparseKernel

function (l::SparseKernel)(X::AbstractArray)
    bch_sz, _, dims_r... = reverse(size(X))
    dims = reverse(dims_r)

    X_ = l.conv_blk(X)  # (dims..., emb_dims, B)
    X_ = reshape(X_, prod(dims), :, bch_sz)  # (prod(dims), emb_dims, B)
    Y = l.out_weight(batched_transpose(X_))  # (in_dims, prod(dims), B)
    Y = reshape(batched_transpose(Y), dims..., :, bch_sz)  # (dims..., in_dims, B)
    return collect(Y)
end


struct MWT_CZ1d{T,S,R,Q,P}
    k::Int
    L::Int
    A::T
    B::S
    C::R
    T0::Q
    ec_s::P
    ec_d::P
    rc_e::P
    rc_o::P
end

function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
    H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
    H0r = zero_out!(H0 * Φ0)
    G0r = zero_out!(G0 * Φ0)
    H1r = zero_out!(H1 * Φ1)
    G1r = zero_out!(G1 * Φ1)

    dim = c*k
    A = SpectralConv(dim=>dim, (α,); init=init)
    B = SpectralConv(dim=>dim, (α,); init=init)
    C = SpectralConv(dim=>dim, (α,); init=init)
    T0 = Dense(k, k)

    ec_s = vcat(H0', H1')
    ec_d = vcat(G0', G1')
    rc_e = vcat(H0r, G0r)
    rc_o = vcat(H1r, G1r)
    return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
end

function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
    N = size(X, 3)
    Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
    d = NNlib.batched_mul(Xa, l.ec_d)
    s = NNlib.batched_mul(Xa, l.ec_s)
    return d, s
end

struct WaveletTransform{N, S} <: AbstractTransform
    modes::NTuple{N, S}  # N == ndims(x)
end

function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
    bch_sz, N, dims_r... = reverse(size(X))
    dims = reverse(dims_r)
    @assert dims[1] == 2*l.k
    Xₑ = NNlib.batched_mul(X, l.rc_e)
    Xₒ = NNlib.batched_mul(X, l.rc_o)
    # TODO: interleave Xₑ and Xₒ along the length axis, as in the reference
    # torch implementation; until then X is returned unchanged:
    #   x = torch.zeros(B, N*2, c, self.k, device=x.device)
    #   x[..., ::2, :, :] = x_e
    #   x[..., 1::2, :, :] = x_o
    return X
end

Base.ndims(::WaveletTransform{N}) where {N} = N

function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
    bch_sz, N, dims_r... = reverse(size(X))
    ns = floor(log2(N))
    stop = ns - l.L

    # decompose
    Ud = T[]
    Us = T[]
    for i in 1:stop
        d, X = wavelet_transform(l, X)
        push!(Ud, l.A(d)+l.B(d))
        push!(Us, l.C(d))
    end
    X = l.T0(X)

    # reconstruct
    for i in stop:-1:1
        X += Us[i]
        X = vcat(X, Ud[i])  # x = torch.cat((x, Ud[i]), -1)
        X = even_odd(l, X)
    end
    return X
end

# function transform(wt::WaveletTransform, 𝐱::AbstractArray)
#     return fft(Zygote.hook(real, 𝐱), 1:ndims(wt))  # [size(x)..., in_chs, batch]
# end

# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
#     return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :)  # [wt.modes..., in_chs, batch]
# end

# function inverse(wt::WaveletTransform, 𝐱_fft::AbstractArray)
#     return real(ifft(𝐱_fft, 1:ndims(wt)))  # [size(x_fft)..., out_chs, batch]
# end
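
After this commit, wavelet_transform.jl keeps only the WaveletTransform container and its Base.ndims overload; the transform/truncate_modes/inverse methods remain commented out. A minimal sketch of the part that is live, with hypothetical mode counts:

wt = WaveletTransform((16, 16))  # retain 16 modes along each of two dimensions
ndims(wt)                        # returns 2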
141 changes: 140 additions & 1 deletion src/operator_kernel.jl
@@ -2,7 +2,11 @@ export
    OperatorConv,
    SpectralConv,
    OperatorKernel,
    GraphKernel,
    SparseKernel,
    SparseKernel1D,
    SparseKernel2D,
    SparseKernel3D

struct OperatorConv{P, T, S, TT}
    weight::T
@@ -216,6 +220,141 @@ function Base.show(io::IO, l::GraphKernel)
    print(io, ")")
end

"""
SparseKernel(κ, ch, σ=identity)
Sparse kernel layer.
## Arguments
* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
* `ch`: Channel size for linear transform, e.g. `32`.
* `σ`: Activation function.
"""
struct SparseKernel{N,T,S}
    conv_blk::T
    out_weight::S
end

function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
    input_dim, emb_dim = ch
    conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
end

function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k
    emb_dim = 128
    return SparseKernel((3, ), input_dim=>emb_dim; init=init)
end

function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k^2
    emb_dim = α*k^2
    return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
end

function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
    input_dim = c*k^2
    emb_dim = α*k^2
    conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
    W_out = Dense(emb_dim, input_dim; init=init)
    return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
end
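
# Shape sketch (illustrative, not part of the committed file): SparseKernel3D
# is the odd one out, since its Conv maps emb_dim=>emb_dim. Per the test
# suite it consumes α*k^2 channels and emits c*k^2 channels, rather than
# preserving the input size like the 1-D/2-D variants.
l3_demo = SparseKernel3D(3, 4, 3)              # k=3, α=4, c=3
X3_demo = rand(Float32, 5, 7, 13, 4*3^2, 32)   # (Nx, Ny, Nz, α*k², batch)
size(l3_demo(X3_demo))                         # (5, 7, 13, 3*3², 32)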

Flux.@functor SparseKernel

function (l::SparseKernel)(X::AbstractArray)
    bch_sz, _, dims_r... = reverse(size(X))
    dims = reverse(dims_r)

    X_ = l.conv_blk(X)  # (dims..., emb_dims, B)
    X_ = reshape(X_, prod(dims), :, bch_sz)  # (prod(dims), emb_dims, B)
    Y = l.out_weight(batched_transpose(X_))  # (in_dims, prod(dims), B)
    Y = reshape(batched_transpose(Y), dims..., :, bch_sz)  # (dims..., in_dims, B)
    return collect(Y)
end
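
# Usage sketch (illustrative, not part of the committed file); shapes follow
# the 2-D test case: input and output sizes match for the 1-D/2-D variants.
l2_demo = SparseKernel2D(3, 4, 3)              # k=3, α=4, c=3
X2_demo = rand(Float32, 5, 7, 3*3^2, 32)       # (Nx, Ny, c*k², batch)
size(l2_demo(X2_demo)) == size(X2_demo)        # true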


struct MWT_CZ1d{T,S,R,Q,P}
    k::Int
    L::Int
    A::T
    B::S
    C::R
    T0::Q
    ec_s::P
    ec_d::P
    rc_e::P
    rc_o::P
end

function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
    H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
    H0r = zero_out!(H0 * Φ0)
    G0r = zero_out!(G0 * Φ0)
    H1r = zero_out!(H1 * Φ1)
    G1r = zero_out!(G1 * Φ1)

    dim = c*k
    A = SpectralConv(dim=>dim, (α,); init=init)
    B = SpectralConv(dim=>dim, (α,); init=init)
    C = SpectralConv(dim=>dim, (α,); init=init)
    T0 = Dense(k, k)

    ec_s = vcat(H0', H1')
    ec_d = vcat(G0', G1')
    rc_e = vcat(H0r, G0r)
    rc_o = vcat(H1r, G1r)
    return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
end
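
# Filter-bank note (illustrative): get_filter returns the k×k multiwavelet
# filters H0, H1, G0, G1 plus the Φ projections for the chosen polynomial
# base, so the stacked encoders ec_s/ec_d and reconstructors rc_e/rc_o are
# 2k×k matrices, matching the `dims[1] == 2*l.k` assertion in even_odd below.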

function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
    N = size(X, 3)
    Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
    d = NNlib.batched_mul(Xa, l.ec_d)
    s = NNlib.batched_mul(Xa, l.ec_s)
    return d, s
end
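
# A minimal sketch (illustrative, not part of the file) of the even/odd split
# performed above, shown without the filter multiplication:
X_seq  = reshape(collect(1.0:8.0), 1, 1, 8, 1)  # length-8 signal on axis 3
X_odd  = view(X_seq, :, :, 1:2:8, :)            # samples 1, 3, 5, 7
X_even = view(X_seq, :, :, 2:2:8, :)            # samples 2, 4, 6, 8
Xa_demo = vcat(X_odd, X_even)                   # stacked along the first axis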

function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
    bch_sz, N, dims_r... = reverse(size(X))
    dims = reverse(dims_r)
    @assert dims[1] == 2*l.k
    Xₑ = NNlib.batched_mul(X, l.rc_e)
    Xₒ = NNlib.batched_mul(X, l.rc_o)
    # TODO: interleave Xₑ and Xₒ along the length axis, as in the reference
    # torch implementation; until then X is returned unchanged:
    #   x = torch.zeros(B, N*2, c, self.k, device=x.device)
    #   x[..., ::2, :, :] = x_e
    #   x[..., 1::2, :, :] = x_o
    return X
end

function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
    bch_sz, N, dims_r... = reverse(size(X))
    ns = floor(log2(N))
    stop = ns - l.L

    # decompose
    Ud = T[]
    Us = T[]
    for i in 1:stop
        d, X = wavelet_transform(l, X)
        push!(Ud, l.A(d)+l.B(d))
        push!(Us, l.C(d))
    end
    X = l.T0(X)

    # reconstruct
    for i in stop:-1:1
        X += Us[i]
        X = vcat(X, Ud[i])  # x = torch.cat((x, Ud[i]), -1)
        X = even_odd(l, X)
    end
    return X
end
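
# Construction sketch (illustrative): with the defaults k=3, α=5, L=0 and the
# :legendre base, the forward pass runs floor(log2(N)) decompose steps on an
# input whose third axis has length N, applies the spectral convolutions A, B,
# C at each scale plus the coarse map T0, then reconstructs via even_odd.
l_demo = MWT_CZ1d()  # equivalent to MWT_CZ1d(3, 5, 0, 1; base=:legendre)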


#########
# utils #
File renamed without changes.
53 changes: 0 additions & 53 deletions test/Transform/wavelet_transform.jl
@@ -1,53 +0,0 @@
@testset "SparseKernel" begin
    T = Float32
    k = 3
    batch_size = 32

    @testset "1D SparseKernel" begin
        α = 4
        c = 1
        in_chs = 20
        X = rand(T, in_chs, c*k, batch_size)

        l1 = SparseKernel1D(k, α, c)
        Y = l1(X)
        @test l1 isa SparseKernel{1}
        @test size(Y) == size(X)

        gs = gradient(()->sum(l1(X)), Flux.params(l1))
        @test length(gs.grads) == 4
    end

    @testset "2D SparseKernel" begin
        α = 4
        c = 3
        Nx = 5
        Ny = 7
        X = rand(T, Nx, Ny, c*k^2, batch_size)

        l2 = SparseKernel2D(k, α, c)
        Y = l2(X)
        @test l2 isa SparseKernel{2}
        @test size(Y) == size(X)

        gs = gradient(()->sum(l2(X)), Flux.params(l2))
        @test length(gs.grads) == 4
    end

    @testset "3D SparseKernel" begin
        α = 4
        c = 3
        Nx = 5
        Ny = 7
        Nz = 13
        X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)

        l3 = SparseKernel3D(k, α, c)
        Y = l3(X)
        @test l3 isa SparseKernel{3}
        @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)

        gs = gradient(()->sum(l3(X)), Flux.params(l3))
        @test length(gs.grads) == 4
    end
end