docs: migrate most examples to Reactant (#1180)
* docs: Basics run on CPU

* docs: Run Polynomial Fitting using Reactant

* feat: allow users to dump the HLO

* docs: update Optimization tutorial

* docs: use Reactant for CPU in SimpleChains

* docs: update PINN2DPDE

* docs: partially move HyperNet to reactant

* chore: run formatter [skip tests]

* docs: highlight Reactant more prominently

* docs: update SimpleRNN

* fix: incorrect check in Embedding

* fix: bump enzyme in project

* feat: handle weight initializers for reactant RNGs

* fix: workaround for #1186

* fix: simpleRNN works with reactant

* fix: failing tests and use overlay

* revert: Hypernet keep in CUDA for now
avik-pal authored Jan 8, 2025
1 parent 476f3f4 commit 10ea255
Showing 40 changed files with 362 additions and 169 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.4"
version = "1.5.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -88,7 +88,7 @@ Compat = "4.16"
ComponentArrays = "0.15.18"
ConcreteStructs = "0.2.3"
DispatchDoctor = "0.4.12"
Enzyme = "0.13.16"
Enzyme = "0.13.28"
EnzymeCore = "0.8.8"
FastClosures = "0.3.2"
Flux = "0.15, 0.16"
32 changes: 32 additions & 0 deletions README.md
@@ -170,6 +170,38 @@ gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(
(x, dev(rand(rng, Float32, 10, 2))), train_state)
```

## 🤸 Quickstart with Reactant

```julia
using Lux, Random, Optimisers, Reactant, Enzyme

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))

dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, 128, 2) |> dev

# We need to compile the model before we can use it.
model_forward = @compile model(x, ps, Lux.testmode(st))
model_forward(x, ps, Lux.testmode(st))

# Gradients can be computed using Enzyme
@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st))

# All of this can be automated using the TrainState API
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

gs, loss, stats, train_state = Training.single_train_step!(
AutoEnzyme(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
```
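
A quick note on the two Reactant macros used above (a sketch, reusing the `model`, `ps`,
`st`, and `x` from the snippet): `@compile` returns a compiled callable that can be
invoked repeatedly without recompiling, while `@jit` compiles and runs in a single step.

```julia
y_compiled, _ = model_forward(x, ps, Lux.testmode(st)) # reuse the compiled forward pass
y_jitted, _ = @jit model(x, ps, Lux.testmode(st))      # compile-and-run in one call
```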

## 📚 Examples

Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -66,6 +66,7 @@ julia = "1.10"
[sources]
Lux = { path = "../" }
LuxLib = { path = "../lib/LuxLib" }
LuxCUDA = { path = "../lib/LuxCUDA" }
LuxCore = { path = "../lib/LuxCore" }
MLDataDevices = { path = "../lib/MLDataDevices" }
LuxTestUtils = { path = "../lib/LuxTestUtils" }
10 changes: 6 additions & 4 deletions docs/make.jl
@@ -1,7 +1,6 @@
using Documenter, DocumenterVitepress, Pkg
using Lux, LuxCore, LuxLib, WeightInitializers, NNlib
using LuxTestUtils, MLDataDevices
using LuxCUDA

using Optimisers # for some docstrings

@@ -78,8 +77,10 @@ pages = [
#! format: on

deploy_config = Documenter.auto_detect_deploy_system()
deploy_decision = Documenter.deploy_folder(deploy_config; repo="github.com/LuxDL/Lux.jl",
devbranch="main", devurl="dev", push_preview=true)
deploy_decision = Documenter.deploy_folder(
deploy_config; repo="github.com/LuxDL/Lux.jl",
devbranch="main", devurl="dev", push_preview=true
)

makedocs(;
sitename="Lux.jl Docs",
@@ -96,7 +97,8 @@ makedocs(;
repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}",
format=DocumenterVitepress.MarkdownVitepress(;
repo="github.com/LuxDL/Lux.jl", devbranch="main", devurl="dev",
deploy_url="https://lux.csail.mit.edu", deploy_decision),
deploy_url="https://lux.csail.mit.edu", deploy_decision
),
draft=false,
pages
)
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -23,8 +23,8 @@ hero:
features:
- icon: 🚀
title: Fast & Extendible
details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
title: Fast & Extendable
details: Lux.jl is written in Julia itself, making it extremely extendable. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs.
link: /introduction
- icon: 🐎
50 changes: 33 additions & 17 deletions docs/src/introduction/index.md
@@ -25,8 +25,7 @@ Pkg.add("Lux")

```@example quickstart
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
```

We take randomness very seriously
@@ -66,26 +65,33 @@ y, st = Lux.apply(model, x, ps, st)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state)
gs, loss, stats, train_state = Lux.Training.compute_gradients(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state)
gs, loss, stats, train_state = Training.single_train_step!(
AutoZygote(), MSELoss(),
(x, dev(rand(rng, Float32, 10, 2))), train_state
)
```

## Defining Custom Layers

We can train our model using the above code, but let's go ahead and see how to use Reactant.
Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after
running fancy optimizations). It is currently the recommended way to train large models in
Lux. For more details on using Reactant, see the [manual](@ref reactant-compilation).

```@example custom_compact
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing
dev = gpu_device()
dev = reactant_device()
```

We will define a custom MLP using the `@compact` macro. The macro takes in a list of
@@ -97,10 +103,12 @@ n_in = 1
n_out = 1
nlayers = 3
model = @compact(w1=Dense(n_in => 32),
model = @compact(
w1=Dense(n_in => 32),
w2=[Dense(32 => 32) for i in 1:nlayers],
w3=Dense(32 => n_out),
act=relu) do x
act=relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
@@ -116,21 +124,24 @@ We can initialize the model and train it with the same code as before!
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model) |> dev
ps, st = Lux.setup(rng, model) |> dev
x = rand(rng, Float32, n_in, 32) |> dev
model(x, ps, st) # 1×32 Matrix and updated state as output.
@jit model(x, ps, st) # 1×32 Matrix and updated state as output.
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :) |> dev
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)
function train_model!(model, ps, st, x_data, y_data)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
for iter in 1:1000
_, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
(x_data, y_data), train_state)
_, loss, _, train_state = Lux.Training.single_train_step!(
AutoEnzyme(), MSELoss(),
(x_data, y_data), train_state
)
if iter % 100 == 1 || iter == 1000
@printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
end
@@ -155,6 +166,11 @@ packages mentioned in this documentation are available via the Julia General Reg

You can install all those packages via `import Pkg; Pkg.add(<package name>)`.

## XLA (CPU/GPU/TPU) Support

Lux.jl supports XLA compilation for CPU, GPU, and TPU using
[Reactant.jl](https://github.com/EnzymeAD/Reactant.jl).
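
As a rough sketch (assuming Reactant.jl is installed alongside Lux), compiling a model for
whichever XLA backend Reactant selects looks like this:

```julia
using Lux, Random, Reactant

model = Dense(2 => 3, tanh)
dev = reactant_device() # CPU, GPU, or TPU depending on the available XLA backend

ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 4) |> dev

y, _ = @jit model(x, ps, st) # compile via XLA and run
```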

## GPU Support

GPU Support for Lux.jl requires loading additional packages:
10 changes: 10 additions & 0 deletions docs/src/manual/compiling_lux_models.md
@@ -124,6 +124,16 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())

## [Using the `TrainState` API](@id compile_lux_model_trainstate)

!!! tip "Debugging TrainState API Failures"

If the code fails to compile with Reactant, it is useful to dump the HLO. Starting the
Julia session with the `LUX_DUMP_REACTANT_HLO_OPTIMIZE` environment variable set to
`no_enzyme`, `false`, or `true` will dump the HLO to a file (the filename will be
displayed). This is useful information to provide when opening an issue.

Alternatively, you can set the global reference `Lux.DUMP_REACTANT_HLO_OPT_MODE` to a
symbol corresponding to the `optimize` keyword argument to `@code_hlo`.

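For illustration, here is a rough sketch of the two debugging hooks described in the tip
above. The script name is hypothetical, and the `Ref`-style assignment assumes that is
how the global reference is exposed:

```julia
# Option 1 (before Julia starts): set the environment variable from the shell, e.g.
#   LUX_DUMP_REACTANT_HLO_OPTIMIZE=no_enzyme julia --project train_model.jl
# A failing compilation will then write the HLO to a file and print its name.

# Option 2 (inside a running session): point the global reference at an `optimize`
# mode accepted by `@code_hlo` (assuming the global is exposed as a `Ref`).
Lux.DUMP_REACTANT_HLO_OPT_MODE[] = :no_enzyme
```
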
Now that we have seen the low-level API, let's see how to train the model without any of this
boilerplate. Simply follow these steps:

13 changes: 6 additions & 7 deletions docs/src/manual/gpu_management.md
@@ -1,12 +1,5 @@
# GPU Management

!!! info

Starting from `v0.5`, Lux has transitioned to a new GPU management system. The old
system using `cpu` and `gpu` functions is still in place but will be removed in `v1`.
Using the old functions might lead to performance regressions if used inside
performance critical code.

`Lux.jl` can handle multiple GPU backends. Currently, the following backends are supported:

```@example gpu_management
@@ -16,6 +9,12 @@ using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
supported_gpu_backends()
```

!!! tip "GPU Support via Reactant"

If you are using Reactant, you can use the [`reactant_device`](@ref) function to
automatically select the Reactant backend if it is available. Additionally, to force
Reactant to use the `gpu` backend, you can run `Reactant.set_default_backend("gpu")`
(though this is usually done automatically).

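For illustration, a minimal sketch of the Reactant-based flow described in the tip above
(assuming Reactant.jl is installed alongside Lux):

```julia
using Lux, Reactant

Reactant.set_default_backend("gpu")  # force the GPU backend; usually selected automatically

dev = reactant_device()        # Reactant device wrapping the chosen XLA backend
x = rand(Float32, 3, 3) |> dev # move data onto that device
```
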
!!! danger "Metal Support"

Support for Metal GPUs should be considered extremely experimental at this point.
5 changes: 3 additions & 2 deletions docs/tutorials.jl
@@ -1,10 +1,11 @@
#! format: off
const BEGINNER_TUTORIALS = [
"Basics/main.jl" => "CUDA",
"Basics/main.jl" => "CPU",
"PolynomialFitting/main.jl" => "CUDA",
"SimpleRNN/main.jl" => "CUDA",
# Technically this is run on CPU but we need a better machine to run it
"SimpleChains/main.jl" => "CUDA",
"OptimizationIntegration/main.jl" => "CUDA",
"OptimizationIntegration/main.jl" => "CPU",
]
const INTERMEDIATE_TUTORIALS = [
"NeuralODE/main.jl" => "CUDA",
2 changes: 0 additions & 2 deletions examples/Basics/Project.toml
@@ -2,7 +2,6 @@
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -12,6 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ComponentArrays = "0.15.18"
ForwardDiff = "0.10"
Lux = "1"
LuxCUDA = "0.3"
Optimisers = "0.4.1"
Zygote = "0.6"
14 changes: 8 additions & 6 deletions examples/Basics/main.jl
@@ -109,12 +109,14 @@ W * x
# the `cu` function (or the `gpu` function exported by `Lux`), and it supports all of the
# above operations with the same syntax.

using LuxCUDA

if LuxCUDA.functional()
x_cu = cu(rand(5, 3))
@show x_cu
end
# ```julia
# using LuxCUDA
#
# if LuxCUDA.functional()
# x_cu = cu(rand(5, 3))
# @show x_cu
# end
# ```

# ## (Im)mutability

11 changes: 5 additions & 6 deletions examples/ConditionalVAE/main.jl
@@ -151,7 +151,8 @@ end
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
total_images = grid_rows * grid_cols
imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img
cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img)
cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) :
colorview(RGB, permutedims(img, (3, 1, 2)))
return cimg'
end
return create_image_grid(imgs, grid_rows, grid_cols)
@@ -239,23 +240,21 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
for epoch in 1:epochs
loss_total = 0.0f0
total_samples = 0
total_time = 0.0

start_time = time()
for (i, X) in enumerate(train_dataloader)
throughput_tic = time()
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state)
throughput_toc = time()

loss_total += loss
total_samples += size(X, ndims(X))
total_time += throughput_toc - throughput_tic

if i % 250 == 0 || i == length(train_dataloader)
throughput = total_samples / total_time
throughput = total_samples / (time() - start_time)
@printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
end
end
total_time = time() - start_time

train_loss = loss_total / length(train_dataloader)
throughput = total_samples / total_time
6 changes: 0 additions & 6 deletions examples/HyperNet/Project.toml
@@ -1,5 +1,4 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
@@ -9,19 +8,14 @@ OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1.10"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
Setfield = "1"
Statistics = "1"
Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/HyperNet/main.jl
@@ -2,8 +2,8 @@

# ## Package Imports

using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Setfield, Statistics, Zygote
using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Zygote

CUDA.allowscalar(false)

Expand Down
2 changes: 0 additions & 2 deletions examples/OptimizationIntegration/Project.toml
@@ -2,7 +2,6 @@
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -16,7 +15,6 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
CairoMakie = "0.12.10"
ComponentArrays = "0.15.18"
Lux = "1"
LuxCUDA = "0.3.3"
MLUtils = "0.4.4"
Optimization = "4"
OptimizationOptimJL = "0.4"

6 comments on commit 10ea255

@avik-pal
Member Author

@JuliaRegistrator register subdir=lib/LuxCore

@avik-pal
Member Author

@JuliaRegistrator register subdir=lib/LuxLib

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122618

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a LuxCore-v1.2.2 -m "<description of version>" 10ea255face7fdd6b198cf9e2e641a8414afe037
git push origin LuxCore-v1.2.2

@avik-pal
Member Author

@JuliaRegistrator register subdir=lib/WeightInitializers

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122619

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a LuxLib-v1.4.1 -m "<description of version>" 10ea255face7fdd6b198cf9e2e641a8414afe037
git push origin LuxLib-v1.4.1

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/122620

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a WeightInitializers-v1.1.0 -m "<description of version>" 10ea255face7fdd6b198cf9e2e641a8414afe037
git push origin WeightInitializers-v1.1.0
