Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Commit

Permalink
Merge pull request #16 from foldfelis/n-d_example
Browse files Browse the repository at this point in the history
Fix functor bug and build project for Burgers' equation
  • Loading branch information
foldfelis authored Aug 22, 2021
2 parents fa998db + 29d55da commit cf8eddd
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 171 deletions.
16 changes: 16 additions & 0 deletions example/Burgers/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
name = "Burgers"
uuid = "5b053d85-f964-4905-ae31-99551cd8d3ad"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
49 changes: 49 additions & 0 deletions example/Burgers/src/Burgers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module Burgers

using NeuralOperators
using Flux
using CUDA

include("data.jl")

__init__() = register_burgers()

function train()
if has_cuda()
@info "CUDA is on"
device = gpu
CUDA.allowscalar(false)
else
device = cpu
end

modes = (16, )
ch = 64 => 64
σ = gelu
m = Chain(
Dense(2, 64),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes, σ),
FourierOperator(ch, modes),
Dense(64, 128, σ),
Dense(128, 1),
flatten
) |> device

loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]

loader_train, loader_test = get_dataloader()

function validate()
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
@info "loss: $(sum(validation_losses)/length(loader_test))"
end

data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
end

end
44 changes: 44 additions & 0 deletions example/Burgers/src/data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using DataDeps
using Fetch
using MAT

export get_burgers_data

function register_burgers()
register(DataDep(
"Burgers",
"""
Burgers' equation dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
""",
"https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing",
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
fetch_method=gdownload,
post_fetch_method=unpack
))
end

function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
close(file)

x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
x_loc_data[1, :, :] .= reshape(repeat(LinRange(0, 1, grid_size), n), (grid_size, n))
x_loc_data[2, :, :] .= x_data

return x_loc_data, y_data
end

function get_dataloader(; n_train=1800, n_test=200, batchsize=100)
𝐱, 𝐲 = get_burgers_data(n=2048)

𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)

𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)

return loader_train, loader_test
end
6 changes: 6 additions & 0 deletions example/Burgers/test/data.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@testset "get burgers data" begin
xs, ys = get_burgers_data(n=1000)

@test size(xs) == (2, 1024, 1000)
@test size(ys) == (1024, 1000)
end
6 changes: 6 additions & 0 deletions example/Burgers/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using Burgers
using Test

@testset "Burgers" begin
include("data.jl")
end
35 changes: 0 additions & 35 deletions example/burgers.jl

This file was deleted.

10 changes: 0 additions & 10 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
module NeuralOperators
using DataDeps
using Fetch
using MAT
using StatsBase

using Flux
using FFTW
using Tullio
Expand All @@ -13,11 +8,6 @@ module NeuralOperators
using Zygote
using ChainRulesCore

function __init__()
register_datasets()
end

include("data.jl")
include("fourier.jl")
include("model.jl")
end
85 changes: 0 additions & 85 deletions src/data.jl

This file was deleted.

20 changes: 14 additions & 6 deletions src/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,25 @@ export
FourierOperator

struct SpectralConv{P, N, T, S, F}
permuted::Bool
weight::T
in_channel::S
out_channel::S
modes::NTuple{N, S}
σ::F
end

function SpectralConv(
permuted::Bool,
weight::T,
in_channel::S,
out_channel::S,
modes::NTuple{N, S},
σ::F
) where {N, T, S, F}
return SpectralConv{permuted, N, T, S, F}(permuted, weight, in_channel, out_channel, modes, σ)
end

"""
SpectralConv(
ch, modes, σ=identity;
Expand Down Expand Up @@ -50,18 +62,14 @@ function SpectralConv(
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(out_chs, in_chs, prod(modes))
W = typeof(weights)
F = typeof(σ)

return SpectralConv{permuted,N,W,S,F}(weights, in_chs, out_chs, modes, σ)
return SpectralConv(permuted, weights, in_chs, out_chs, modes, σ)
end

Flux.@functor SpectralConv

Base.ndims(::SpectralConv{P,N}) where {P,N} = N

permuted(::SpectralConv{P}) where {P} = P

function Base.show(io::IO, l::SpectralConv{P}) where {P}
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
end
Expand Down Expand Up @@ -142,7 +150,7 @@ end
Flux.@functor FourierOperator

function Base.show(io::IO, l::FourierOperator)
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(permuted(l.conv)))")
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(l.conv.permuted))")
end

function (m::FourierOperator)(𝐱)
Expand Down
24 changes: 0 additions & 24 deletions test/data.jl

This file was deleted.

Loading

2 comments on commit cf8eddd

@foldfelis
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/43279

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" cf8eddd7befd6a90de85a8ca1321780dc6bd447a
git push origin v0.1.0

Please sign in to comment.