This repository has been archived by the owner on Sep 28, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from foldfelis/dp
add double pendulum example and the docs
- Loading branch information
Showing
14 changed files
with
421 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Burgers' equation | ||
|
||
In this example, a [Burgers' equation](https://en.wikipedia.org/wiki/Burgers%27_equation) | ||
is learned by a one-dimensional Fourier neural operator network. | ||
Change directory to `example/Burgers` and use following commend to train model: | ||
|
||
```julia | ||
$ julia --proj | ||
|
||
julia> using Burgers; Burgers.train() | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
name = "DoublePendulum" | ||
uuid = "0c23c1c1-5f41-4617-a685-ac46aae913c3" | ||
|
||
[deps] | ||
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" | ||
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" | ||
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" | ||
|
||
[extras] | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Test"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Double pendulum | ||
|
||
A [double pendulum](https://www.wikiwand.com/en/Double_pendulum) is a pendulum with another pendulum attached to its end. | ||
In is example, instead of learning from a well-described equation of motion, | ||
we train the model with the famous dataset provided by IBM. | ||
The equation of motion to the real experiments of double pendulum is learned by a two-dimensional Fourier neural operator. | ||
It learns to inference the next 30 steps with the given first 30 steps. | ||
The result of this example can be found [here](https://foldfelis.github.io/NeuralOperators.jl/dev/assets/notebook/double_pendulum.jl.html). | ||
|
||
![](gallery/result.gif) | ||
|
||
By inference the result recurrently, we can generate up to 150 steps with the first 30 initial steps. | ||
|
||
Change directory to `example/DoublePendulum` and use following commend to train model: | ||
|
||
```julia | ||
$ julia --proj | ||
|
||
julia> using DoublePendulum; DoublePendulum.train() | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
### A Pluto.jl notebook ### | ||
# v0.16.0 | ||
|
||
using Markdown | ||
using InteractiveUtils | ||
|
||
# ╔═╡ 194baef2-0417-11ec-05ab-4527ef614024 | ||
using Pkg; Pkg.develop(path=".."); Pkg.activate("..") | ||
|
||
# ╔═╡ 38c9ced5-dcf8-4e03-ac07-7c435687861b | ||
begin | ||
using DoublePendulum | ||
using Plots | ||
end | ||
|
||
# ╔═╡ 396b5d7a-a7a4-4f22-a87e-39b405e8d62a | ||
md" | ||
# Double Pendulum | ||
JingYu Ning | ||
" | ||
|
||
# ╔═╡ 2a606ecf-acf0-41ad-9290-7569dbb22b5a | ||
md" | ||
The data is provided by [IBM](https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/) | ||
> In this dataset, videos of the double pendulum were taken using a high-speed Phantom Miro EX2 camera. To make the extraction of the arm positions easier, a matte black background was used, and the three datums were marked with red, green and blue fiducial markers. The camera was placed at 2 meters from the pendulum, with the axis of the objective aligned with the first pendulum datum. The pendulum was launched by hand, and the camera was motion triggered. The dataset was generated on the basis of 21 individual runs of the pendulum. Each of the recorded sequences lasted around 40s and consisted of around 17500 frames. | ||
" | ||
|
||
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7 | ||
data_x, data_y, _, _ = DoublePendulum.preprocess( | ||
DoublePendulum.get_data(i=20), | ||
ratio=1 | ||
); | ||
|
||
# ╔═╡ 4d0b08a4-8a54-41fd-997f-ad54d4c984cd | ||
m = DoublePendulum.get_model(); | ||
|
||
# ╔═╡ ad6302b2-3d62-4a3f-b8bf-f69bab80c7a4 | ||
ground_truth_data = cat( | ||
[data_x[:, :, :, i] | ||
for i in 1:size(data_x, 3):size(data_x)[end]]..., dims=3 | ||
)[1, :, :]; | ||
|
||
# ╔═╡ 794374ce-6674-481d-8a3b-04db0f32d233 | ||
begin | ||
n = 5 | ||
inferenced_data = data_x[:, :, :, 1:1] | ||
for i in 1:n | ||
inferenced_data = cat( | ||
inferenced_data, | ||
m(inferenced_data[:, :, :, i:i]), | ||
dims=4 | ||
) | ||
end | ||
|
||
inferenced_data = cat( | ||
[inferenced_data[:, :, :, i] for i in 1:n]..., dims=3 | ||
)[1, :, :] | ||
end; | ||
|
||
# ╔═╡ 9c8b3f8a-1b85-4c32-a416-ead51b244b94 | ||
begin | ||
c = [ | ||
RGB([239, 71, 111]/255...), | ||
RGB([6, 214, 160]/255...), | ||
RGB([17, 138, 178]/255...) | ||
] | ||
xi, yi = [2, 4, 6], [1, 3, 5] | ||
|
||
anim = @animate for i in 1:size(inferenced_data)[end] | ||
i_data = [0, 0, inferenced_data[:, i]...] | ||
g_data = [0, 0, ground_truth_data[:, i]...] | ||
|
||
scatter( | ||
legend=false, ticks=false, | ||
xlim=(-1500, 1500), ylim=(-1500, 1500), size=(400, 350) | ||
) | ||
plot!(i_data[xi], i_data[yi], color=:black) | ||
scatter!(i_data[xi], i_data[yi], color=c, markersize=8) | ||
plot!(g_data[xi], g_data[yi], color=:gray) | ||
scatter!(g_data[xi], g_data[yi], color=c, markersize=4) | ||
|
||
if i ≤ 30 | ||
annotate!(-1400, -1400, text("t=$i", :left, color=:black)) | ||
else | ||
annotate!(-1400, -1400, text("t=$i", :left, color=:red)) | ||
end | ||
end | ||
|
||
gif(anim) | ||
end | ||
|
||
# ╔═╡ Cell order: | ||
# ╟─396b5d7a-a7a4-4f22-a87e-39b405e8d62a | ||
# ╟─2a606ecf-acf0-41ad-9290-7569dbb22b5a | ||
# ╟─194baef2-0417-11ec-05ab-4527ef614024 | ||
# ╠═38c9ced5-dcf8-4e03-ac07-7c435687861b | ||
# ╠═5268feee-bda2-4612-9d4c-a1db424a11c7 | ||
# ╠═4d0b08a4-8a54-41fd-997f-ad54d4c984cd | ||
# ╠═ad6302b2-3d62-4a3f-b8bf-f69bab80c7a4 | ||
# ╠═794374ce-6674-481d-8a3b-04db0f32d233 | ||
# ╟─9c8b3f8a-1b85-4c32-a416-ead51b244b94 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
module DoublePendulum | ||
|
||
using NeuralOperators | ||
using Flux | ||
using CUDA | ||
using JLD2 | ||
|
||
include("data.jl") | ||
|
||
__init__() = register_double_pendulum_chaotic() | ||
|
||
function update_model!(model_file_path, model) | ||
model = cpu(model) | ||
jldsave(model_file_path; model) | ||
@warn "model updated!" | ||
end | ||
|
||
function train(; Δt=1) | ||
if has_cuda() | ||
@info "CUDA is on" | ||
device = gpu | ||
CUDA.allowscalar(false) | ||
else | ||
device = cpu | ||
end | ||
|
||
m = Chain( | ||
Dense(2, 64), | ||
FourierOperator(64=>64, (4, 16), gelu), | ||
FourierOperator(64=>64, (4, 16), gelu), | ||
FourierOperator(64=>64, (4, 16), gelu), | ||
FourierOperator(64=>64, (4, 16)), | ||
Dense(64, 128, gelu), | ||
Dense(128, 2), | ||
) |> device | ||
|
||
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end] | ||
|
||
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3)) | ||
|
||
loader_train, loader_test = get_dataloader(Δt=Δt) | ||
|
||
losses = Float32[] | ||
function validate() | ||
validation_loss = sum(loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test)/length(loader_test) | ||
@info "loss: $validation_loss" | ||
|
||
push!(losses, validation_loss) | ||
(losses[end] == minimum(losses)) && update_model!(joinpath(@__DIR__, "../model/model.jld2"), m) | ||
end | ||
call_back = Flux.throttle(validate, 10, leading=false, trailing=true) | ||
|
||
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device | ||
for e in 1:20 | ||
@info "Epoch $e\n η: $(opt.os[2].eta)" | ||
@time Flux.train!(loss, params(m), data, opt, cb=call_back) | ||
(e%3 == 0) && (opt.os[2].eta /= 2) | ||
end | ||
end | ||
|
||
function get_model() | ||
f = jldopen(joinpath(@__DIR__, "../model/model.jld2")) | ||
model = f["model"] | ||
close(f) | ||
|
||
return model | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
using DataDeps | ||
using CSV | ||
using DataFrames | ||
|
||
function register_double_pendulum_chaotic() | ||
register(DataDep( | ||
"DoublePendulumChaotic", | ||
""" | ||
Dataset was generated on the basis of 21 individual runs of a double pendulum. | ||
Each of the recorded sequences lasted around 40s and consisted of around 17500 frames. | ||
* `x_red`: Horizontal pixel coordinate of the red point (the central pivot to the first pendulum) | ||
* `y_red`: Vertical pixel coordinate of the red point (the central pivot to the first pendulum) | ||
* `x_green`: Horizontal pixel coordinate of the green point (the first pendulum) | ||
* `y_green`: Vertical pixel coordinate of the green point (the first pendulum) | ||
* `x_blue`: Horizontal pixel coordinate of the blue point (the second pendulum) | ||
* `y_blue`: Vertical pixel coordinate of the blue point (the second pendulum) | ||
Page: https://developer.ibm.com/exchanges/data/all/double-pendulum-chaotic/ | ||
""", | ||
"https://dax-cdn.cdn.appdomain.cloud/dax-double-pendulum-chaotic/2.0.1/double-pendulum-chaotic.tar.gz", | ||
"4ca743b4b783094693d313ebedc2e8e53cf29821ee8b20abd99f8fb4c0866f8d", | ||
post_fetch_method=unpack | ||
)) | ||
end | ||
|
||
function get_data(; i=0, n=-1) | ||
data_path = joinpath(datadep"DoublePendulumChaotic", "original", "dpc_dataset_csv") | ||
df = CSV.read( | ||
joinpath(data_path, "$i.csv"), | ||
DataFrame, | ||
header=[:x_red, :y_red, :x_green, :y_green, :x_blue, :y_blue] | ||
) | ||
data = (n < 0) ? collect(Matrix(df)') : collect(Matrix(df)')[:, 1:n] | ||
|
||
return Float32.(data) | ||
end | ||
|
||
function preprocess(𝐱; Δt=1, nx=30, ny=30, ratio=0.9) | ||
# move red point to (0, 0) | ||
xs_red, ys_red = 𝐱[1, :], 𝐱[2, :] | ||
𝐱[3, :] -= xs_red; 𝐱[5, :] -= xs_red | ||
𝐱[4, :] -= ys_red; 𝐱[6, :] -= ys_red | ||
|
||
# needs only green and blue points | ||
𝐱 = reshape(𝐱[3:6, 1:Δt:end], 1, 4, :) | ||
# velocity of green and blue points | ||
∇𝐱 = 𝐱[:, :, 2:end] - 𝐱[:, :, 1:(end-1)] | ||
# merge info of pos and velocity | ||
𝐱 = cat(𝐱[:, :, 1:(end-1)], ∇𝐱, dims=1) | ||
|
||
# with info of first nx steps to inference next ny steps | ||
n = size(𝐱)[end] - (nx + ny) + 1 | ||
𝐱s = Array{Float32}(undef, size(𝐱)[1:2]..., nx, n) | ||
𝐲s = Array{Float32}(undef, size(𝐱)[1:2]..., ny, n) | ||
for i in 1:n | ||
𝐱s[:, :, :, i] .= 𝐱[:, :, i:(i+nx-1)] | ||
𝐲s[:, :, :, i] .= 𝐱[:, :, (i+nx):(i+nx+ny-1)] | ||
end | ||
|
||
n_train = floor(Int, ratio*n) | ||
𝐱_train, 𝐲_train = 𝐱s[:, :, :, 1:n_train], 𝐲s[:, :, :, 1:n_train] | ||
𝐱_test, 𝐲_test = 𝐱s[:, :, :, (n_train+1):end], 𝐲s[:, :, :, (n_train+1):end] | ||
|
||
return 𝐱_train, 𝐲_train, 𝐱_test, 𝐲_test | ||
end | ||
|
||
function get_dataloader(; n_file=20, Δt=1, nx=30, ny=30, ratio=0.9, batchsize=100) | ||
𝐱_train, 𝐲_train = Array{Float32}(undef, 2, 4, nx, 0), Array{Float32}(undef, 2, 4, ny, 0) | ||
𝐱_test, 𝐲_test = Array{Float32}(undef, 2, 4, nx, 0), Array{Float32}(undef, 2, 4, ny, 0) | ||
for i in 0:(n_file-1) | ||
𝐱_train_i, 𝐲_train_i, 𝐱_test_i, 𝐲_test_i = preprocess(get_data(i=i), Δt=Δt, nx=nx, ny=ny, ratio=ratio) | ||
|
||
𝐱_train, 𝐲_train = cat(𝐱_train, 𝐱_train_i, dims=4), cat(𝐲_train, 𝐲_train_i, dims=4) | ||
𝐱_test, 𝐲_test = cat(𝐱_test, 𝐱_test_i, dims=4), cat(𝐲_test, 𝐲_test_i, dims=4) | ||
end | ||
|
||
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true) | ||
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false) | ||
|
||
return loader_train, loader_test | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
@testset "double pendulum" begin | ||
xs = DoublePendulum.get_data(i=0, n=100) | ||
|
||
@test size(xs) == (6, 100) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
using DoublePendulum | ||
using Test | ||
|
||
@testset "DoublePendulum" begin | ||
include("data.jl") | ||
end |
Oops, something went wrong.