diff --git a/example/Burgers/Project.toml b/example/Burgers/Project.toml
new file mode 100644
index 00000000..b09fae0c
--- /dev/null
+++ b/example/Burgers/Project.toml
@@ -0,0 +1,16 @@
+name = "Burgers"
+uuid = "5b053d85-f964-4905-ae31-99551cd8d3ad"
+
+[deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
+Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
+NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
+
+[extras]
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Test"]
diff --git a/example/Burgers/src/Burgers.jl b/example/Burgers/src/Burgers.jl
new file mode 100644
index 00000000..6056d934
--- /dev/null
+++ b/example/Burgers/src/Burgers.jl
@@ -0,0 +1,70 @@
+module Burgers
+
+using NeuralOperators
+using Flux
+using CUDA
+
+include("data.jl")
+
+# Register the Burgers' equation DataDep whenever the module is loaded.
+__init__() = register_burgers()
+
+"""
+    train(; epochs=500)
+
+Train a Fourier neural operator on the Burgers' equation dataset and return the
+trained model.
+
+The model and every batch are moved to the GPU when `has_cuda()` is true and stay
+on the CPU otherwise. The mean validation loss is logged through a callback that is
+throttled to 5 seconds.
+
+# Keywords
+- `epochs`: number of passes over the training batches.
+"""
+function train(; epochs=500)
+    if has_cuda()
+        @info "CUDA is on"
+        device = gpu
+        CUDA.allowscalar(false)
+    else
+        device = cpu
+    end
+
+    modes = (16, )
+    ch = 64 => 64
+    σ = gelu
+    m = Chain(
+        Dense(2, 64),
+        FourierOperator(ch, modes, σ),
+        FourierOperator(ch, modes, σ),
+        FourierOperator(ch, modes, σ),
+        FourierOperator(ch, modes),
+        Dense(64, 128, σ),
+        Dense(128, 1),
+        flatten
+    ) |> device
+
+    # Sum of squared errors divided by the size of the last dimension of `𝐱`.
+    loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
+
+    loader_train, loader_test = get_dataloader()
+
+    function validate()
+        validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
+        @info "loss: $(sum(validation_losses)/length(loader_test))"
+    end
+
+    # The training batches are collected and moved to the device once, so every
+    # epoch iterates over the same batches in the same order.
+    data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
+    opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
+    Flux.@epochs epochs @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
+
+    # Hand the trained model back; it is local to this function and would
+    # otherwise be unreachable once training finishes.
+    return m
+end
+
+end
diff --git a/example/Burgers/src/data.jl b/example/Burgers/src/data.jl
new file mode 100644
index 
00000000..67708336
--- /dev/null
+++ b/example/Burgers/src/data.jl
@@ -0,0 +1,72 @@
+using DataDeps
+using Fetch
+using MAT
+
+export get_burgers_data
+
+"""
+    register_burgers()
+
+Register the "Burgers" `DataDep`, which is fetched with `gdownload` and
+post-processed with `unpack`.
+"""
+function register_burgers()
+    register(DataDep(
+        "Burgers",
+        """
+        Burgers' equation dataset from
+        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
+        """,
+        "https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing",
+        "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
+        fetch_method=gdownload,
+        post_fetch_method=unpack
+    ))
+end
+
+"""
+    get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
+
+Load the first `n` samples of `burgers_data_R10.mat` from the "Burgers" `DataDep`,
+keeping every `Δsamples`-th grid point.
+
+Return `(x_loc_data, y_data)`: `x_loc_data` is a `2 × grid_size × n` array whose first
+channel is `LinRange(0, 1, grid_size)` and whose second channel is the sub-sampled
+variable `"a"`; `y_data` is the sub-sampled variable `"u"`. Both have element type `T`.
+"""
+function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
+    path = joinpath(datadep"Burgers", "burgers_data_R10.mat")
+    # The `do` form closes the file even when reading or slicing throws.
+    x_data, y_data = matopen(path) do file
+        x = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
+        y = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
+        (x, y)
+    end
+
+    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
+    # Broadcasting copies the grid coordinates into every one of the `n` samples.
+    x_loc_data[1, :, :] .= LinRange(0, 1, grid_size)
+    x_loc_data[2, :, :] .= x_data
+
+    return x_loc_data, y_data
+end
+
+"""
+    get_dataloader(; n_train=1800, n_test=200, batchsize=100)
+
+Build the training and test `Flux.DataLoader`s from 2048 Burgers samples.
+
+The first `n_train` samples form the shuffled training loader and the last `n_test`
+samples form the unshuffled test loader.
+"""
+function get_dataloader(; n_train=1800, n_test=200, batchsize=100)
+    𝐱, 𝐲 = get_burgers_data(n=2048)
+
+    𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
+    loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
+
+    𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
+    loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
+
+    return loader_train, loader_test
+end
diff --git a/example/Burgers/test/data.jl b/example/Burgers/test/data.jl
new file mode 100644
index 00000000..41476f23
--- /dev/null
+++ b/example/Burgers/test/data.jl
@@ -0,0 +1,6 @@
+@testset "get burgers data" begin
+    xs, ys = get_burgers_data(n=1000)
+
+    @test size(xs) == (2, 1024, 1000)
+    @test size(ys) == (1024, 1000)
+end
diff --git a/example/Burgers/test/runtests.jl b/example/Burgers/test/runtests.jl
new file mode 100644
index 00000000..cf6eb4e0
--- /dev/null
+++ 
b/example/Burgers/test/runtests.jl @@ -0,0 +1,6 @@ +using Burgers +using Test + +@testset "Burgers" begin + include("data.jl") +end diff --git a/example/burgers.jl b/example/burgers.jl deleted file mode 100644 index fe381286..00000000 --- a/example/burgers.jl +++ /dev/null @@ -1,35 +0,0 @@ -using NeuralOperators -using Flux -using CUDA - -if has_cuda() - @info "CUDA is on" - device = gpu - CUDA.allowscalar(false) -else - device = cpu -end - -m = FourierNeuralOperator() |> device -loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end] - -n_train = 1800 -n_test = 200 -batchsize = 100 -𝐱, 𝐲 = get_burgers_data(n=2048) - -𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train] -loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true) - -𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end] -loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false) - -function validate() - validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test] - @info "loss: $(sum(validation_losses)/length(loader_test))" -end - -data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device -opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3)) -call_back = Flux.throttle(validate, 5, leading=false, trailing=true) -Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back)) diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index f450634a..3df1d34f 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -1,9 +1,4 @@ module NeuralOperators - using DataDeps - using Fetch - using MAT - using StatsBase - using Flux using FFTW using Tullio @@ -13,11 +8,6 @@ module NeuralOperators using Zygote using ChainRulesCore - function __init__() - register_datasets() - end - - include("data.jl") include("fourier.jl") include("model.jl") end diff --git a/src/data.jl b/src/data.jl deleted file mode 100644 index 376baa1c..00000000 --- a/src/data.jl +++ /dev/null @@ -1,85 +0,0 @@ -export - UnitGaussianNormalizer, 
- encode, - decode, - get_burgers_data, - get_darcy_flow_data - -struct UnitGaussianNormalizer{T} - mean::Array{T} - std::Array{T} - ϵ::T -end - -function UnitGaussianNormalizer(𝐱; ϵ=1f-5) - dims = 1:ndims(𝐱)-1 - - return UnitGaussianNormalizer(mean(𝐱, dims=dims), StatsBase.std(𝐱, dims=dims), ϵ) -end - -encode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. (𝐱-n.mean) / (n.std+n.ϵ) -decode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. 𝐱 * (n.std+n.ϵ) + n.mean - - -function register_burgers() - register(DataDep( - "Burgers", - """ - Burgers' equation dataset from - [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) - """, - "https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing", - "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd", - fetch_method=gdownload, - post_fetch_method=unpack - )) -end - -function register_darcy_flow() - register(DataDep( - "DarcyFlow", - """ - Darcy flow dataset from - [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator) - """, - "https://drive.google.com/file/d/1zzVMuGhOG70EnR5L24LWqmX9-Wh_H5Wu/view?usp=sharing", - "802825de9da7398407296c99ca9ceb2371c752f6a3bdd1801172e02ce19edda4", - fetch_method=gdownload, - post_fetch_method=unpack - )) -end - -function register_datasets() - register_burgers() - register_darcy_flow() -end - -function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32) - file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat")) - x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]')) - y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]')) - close(file) - - x_loc_data = Array{T, 3}(undef, 2, grid_size, n) - x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n)) - x_loc_data[2, :, :] .= x_data - - return x_loc_data, y_data -end - -function get_darcy_flow_data(; n=1024, Δsamples=5, T=Float32, test_data=false) - # size(training_data) == 
size(testing_data) == (1024, 421, 421) - file = test_data ? "piececonst_r421_N1024_smooth2.mat" : "piececonst_r421_N1024_smooth1.mat" - file = matopen(joinpath(datadep"DarcyFlow", file)) - x_data = T.(permutedims(read(file, "coeff")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1))) - y_data = T.(permutedims(read(file, "sol")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1))) - close(file) - - x_dims = pushfirst!([size(x_data)...], 1) - y_dims = pushfirst!([size(y_data)...], 1) - x_data, y_data = reshape(x_data, x_dims...), reshape(y_data, y_dims...) - - x_normalizer, y_normalizer = UnitGaussianNormalizer(x_data), UnitGaussianNormalizer(y_data) - - return encode(x_normalizer, x_data), encode(y_normalizer, y_data), x_normalizer, y_normalizer -end diff --git a/src/fourier.jl b/src/fourier.jl index 8f258a49..2a5b9c92 100644 --- a/src/fourier.jl +++ b/src/fourier.jl @@ -3,6 +3,7 @@ export FourierOperator struct SpectralConv{P, N, T, S, F} + permuted::Bool weight::T in_channel::S out_channel::S @@ -10,6 +11,17 @@ struct SpectralConv{P, N, T, S, F} σ::F end +function SpectralConv( + permuted::Bool, + weight::T, + in_channel::S, + out_channel::S, + modes::NTuple{N, S}, + σ::F +) where {N, T, S, F} + return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ) +end + """ SpectralConv( ch, modes, σ=identity; @@ -50,18 +62,14 @@ function SpectralConv( in_chs, out_chs = ch scale = one(T) / (in_chs * out_chs) weights = scale * init(out_chs, in_chs, prod(modes)) - W = typeof(weights) - F = typeof(σ) - return SpectralConv{permuted,N,W,S,F}(weights, in_chs, out_chs, modes, σ) + return SpectralConv(permuted, weights, in_chs, out_chs, modes, σ) end Flux.@functor SpectralConv Base.ndims(::SpectralConv{P,N}) where {P,N} = N -permuted(::SpectralConv{P}) where {P} = P - function Base.show(io::IO, l::SpectralConv{P}) where {P} print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)") end @@ -142,7 +150,7 @@ 
end Flux.@functor FourierOperator function Base.show(io::IO, l::FourierOperator) - print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(permuted(l.conv)))") + print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(l.conv.permuted))") end function (m::FourierOperator)(𝐱) diff --git a/test/data.jl b/test/data.jl deleted file mode 100644 index f20cae63..00000000 --- a/test/data.jl +++ /dev/null @@ -1,24 +0,0 @@ -@testset "get burgers data" begin - xs, ys = get_burgers_data(n=1000) - - @test size(xs) == (2, 1024, 1000) - @test size(ys) == (1024, 1000) -end - -@testset "unit gaussian normalizer" begin - dims = (3, 3, 5, 6) - 𝐱 = rand(Float32, dims) - - n = UnitGaussianNormalizer(𝐱) - - @test size(n.mean) == size(n.std) - @test size(encode(n, 𝐱)) == dims - @test size(decode(n, encode(n, 𝐱))) == dims -end - -@testset "get darcy flow data" begin - xs, ys, _, _ = get_darcy_flow_data() - - @test size(xs) == (1, 85, 85, 1024) - @test size(ys) == (1, 85, 85, 1024) -end diff --git a/test/fourier.jl b/test/fourier.jl index 1d6f0d68..8851eb6c 100644 --- a/test/fourier.jl +++ b/test/fourier.jl @@ -9,7 +9,7 @@ @test ndims(SpectralConv(ch, modes)) == 1 @test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=false)" - 𝐱, _ = get_burgers_data(n=5) + 𝐱 = rand(Float32, 2, 1024, 5) @test size(m(𝐱)) == (128, 1024, 5) loss(x, y) = Flux.mse(m(x), y) @@ -28,7 +28,7 @@ end @test ndims(SpectralConv(ch, modes, permuted=true)) == 1 @test repr(SpectralConv(ch, modes, permuted=true)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=true)" - 𝐱, _ = get_burgers_data(n=5) + 𝐱 = rand(Float32, 2, 1024, 5) 𝐱 = permutedims(𝐱, (2, 1, 3)) @test size(m(𝐱)) == (1024, 128, 5) @@ -47,7 +47,7 @@ end ) @test repr(FourierOperator(ch, modes)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=false)" - 𝐱, _ = get_burgers_data(n=5) + 𝐱 
= rand(Float32, 2, 1024, 5) @test size(m(𝐱)) == (128, 1024, 5) loss(x, y) = Flux.mse(m(x), y) @@ -65,7 +65,7 @@ end ) @test repr(FourierOperator(ch, modes, permuted=true)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=true)" - 𝐱, _ = get_burgers_data(n=5) + 𝐱 = rand(Float32, 2, 1024, 5) 𝐱 = permutedims(𝐱, (2, 1, 3)) @test size(m(𝐱)) == (1024, 128, 5) @@ -84,7 +84,7 @@ end ) @test ndims(SpectralConv(ch, modes)) == 2 - 𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20) + 𝐱 = rand(Float32, 1, 22, 22, 5) @test size(m(𝐱)) == (64, 22, 22, 5) loss(x, y) = Flux.mse(m(x), y) @@ -102,7 +102,7 @@ end ) @test ndims(SpectralConv(ch, modes, permuted=true)) == 2 - 𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20) + 𝐱 = rand(Float32, 1, 22, 22, 5) 𝐱 = permutedims(𝐱, (2, 3, 1, 4)) @test size(m(𝐱)) == (22, 22, 64, 5) @@ -120,7 +120,7 @@ end FourierOperator(ch, modes) ) - 𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20) + 𝐱 = rand(Float32, 1, 22, 22, 5) @test size(m(𝐱)) == (64, 22, 22, 5) loss(x, y) = Flux.mse(m(x), y) @@ -137,7 +137,7 @@ end FourierOperator(ch, modes, permuted=true) ) - 𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20) + 𝐱 = rand(Float32, 1, 22, 22, 5) 𝐱 = permutedims(𝐱, (2, 3, 1, 4)) @test size(m(𝐱)) == (22, 22, 64, 5) diff --git a/test/model.jl b/test/model.jl index f5198083..09ac6150 100644 --- a/test/model.jl +++ b/test/model.jl @@ -1,8 +1,7 @@ @testset "FourierNeuralOperator" begin m = FourierNeuralOperator() - 𝐱, 𝐲 = get_burgers_data() - 𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲) + 𝐱, 𝐲 = rand(Float32, 2, 1024, 5), rand(Float32, 1024, 5) @test size(m(𝐱)) == size(𝐲) loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end] diff --git a/test/runtests.jl b/test/runtests.jl index 93608045..b741a18e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,6 @@ using Test using Flux @testset "NeuralOperators.jl" begin - include("data.jl") include("fourier.jl") include("model.jl") end