From 10ea255face7fdd6b198cf9e2e641a8414afe037 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 8 Jan 2025 13:20:22 -0500 Subject: [PATCH] docs: migrate most examples to Reactant (#1180) * docs: Basics run on CPU * docs: Run Polynomial Fitting using Reactant * feat: allow users to bump the HLO * docs: update Optimization tutorial * docs: use Reactant for CPU in SimpleChains * docs: update PINN2DPDE * docs: partially move HyperNet to reactant * chore: run formatter [skip tests] * docs: highlight Reactant more prominently * docs: update SimpleRNN * fix: incorrect check in Embedding * fix: bump enzyme in project * feat: handle weight initializers for reactant RNGs * fix: workaround for #1186 * fix: simpleRNN works with reactant * fix: failing tests and use overlay * revert: Hypernet keep in CUDA for now --- Project.toml | 4 +- README.md | 32 ++++++++++ docs/Project.toml | 1 + docs/make.jl | 10 +-- docs/src/index.md | 4 +- docs/src/introduction/index.md | 50 ++++++++++----- docs/src/manual/compiling_lux_models.md | 10 +++ docs/src/manual/gpu_management.md | 13 ++-- docs/tutorials.jl | 5 +- examples/Basics/Project.toml | 2 - examples/Basics/main.jl | 14 ++-- examples/ConditionalVAE/main.jl | 11 ++-- examples/HyperNet/Project.toml | 6 -- examples/HyperNet/main.jl | 4 +- examples/OptimizationIntegration/Project.toml | 2 - examples/OptimizationIntegration/main.jl | 1 - examples/PINN2DPDE/Project.toml | 6 +- examples/PINN2DPDE/main.jl | 64 +++++++++++-------- examples/PolynomialFitting/Project.toml | 10 +-- examples/PolynomialFitting/main.jl | 26 +++++--- examples/SimpleChains/Project.toml | 1 + examples/SimpleChains/main.jl | 59 +++++++++++------ examples/SimpleRNN/Project.toml | 7 +- examples/SimpleRNN/main.jl | 49 +++++++++----- ext/LuxReactantExt/LuxReactantExt.jl | 7 +- ext/LuxReactantExt/patches.jl | 5 ++ ext/LuxReactantExt/training.jl | 24 ++++++- lib/LuxCore/Project.toml | 2 +- lib/LuxCore/ext/LuxCoreReactantExt.jl | 5 +- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl | 2 +- lib/LuxLib/src/LuxLib.jl | 2 +- lib/LuxLib/src/traits.jl | 1 + lib/WeightInitializers/Project.toml | 5 +- .../ext/WeightInitializersReactantExt.jl | 29 +++++++++ src/Lux.jl | 15 +++++ src/extended_ops.jl | 12 ++-- src/layers/basic.jl | 20 ++++-- test/qa_tests.jl | 3 +- test/reactant/layer_tests.jl | 6 +- 40 files changed, 362 insertions(+), 169 deletions(-) create mode 100644 lib/WeightInitializers/ext/WeightInitializersReactantExt.jl diff --git a/Project.toml b/Project.toml index 77788a3803..5caf1ff9a7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.4.4" +version = "1.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -88,7 +88,7 @@ Compat = "4.16" ComponentArrays = "0.15.18" ConcreteStructs = "0.2.3" DispatchDoctor = "0.4.12" -Enzyme = "0.13.16" +Enzyme = "0.13.28" EnzymeCore = "0.8.8" FastClosures = "0.3.2" Flux = "0.15, 0.16" diff --git a/README.md b/README.md index 56daadc150..56dbb93919 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,38 @@ gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss (x, dev(rand(rng, Float32, 10, 2))), train_state) ``` +## 🤸 Quickstart with Reactant + +```julia +using Lux, Random, Optimisers, Reactant, Enzyme + +rng = Random.default_rng() +Random.seed!(rng, 0) + +model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10))) + +dev = reactant_device() + +ps, st = 
Lux.setup(rng, model) |> dev + +x = rand(rng, Float32, 128, 2) |> dev + +# We need to compile the model before we can use it. +model_forward = @compile model(x, ps, Lux.testmode(st)) +model_forward(x, ps, Lux.testmode(st)) + +# Gradients can be computed using Enzyme +@jit Enzyme.gradient(Reverse, sum ∘ first ∘ Lux.apply, Const(model), x, ps, Const(st)) + +# All of this can be automated using the TrainState API +train_state = Training.TrainState(model, ps, st, Adam(0.001f0)) + +gs, loss, stats, train_state = Training.single_train_step!( + AutoEnzyme(), MSELoss(), + (x, dev(rand(rng, Float32, 10, 2))), train_state +) +``` + ## 📚 Examples Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories. diff --git a/docs/Project.toml b/docs/Project.toml index e0d9b02476..e252fcf933 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -66,6 +66,7 @@ julia = "1.10" [sources] Lux = { path = "../" } LuxLib = { path = "../lib/LuxLib" } +LuxCUDA = { path = "../lib/LuxCUDA" } LuxCore = { path = "../lib/LuxCore" } MLDataDevices = { path = "../lib/MLDataDevices" } LuxTestUtils = { path = "../lib/LuxTestUtils" } diff --git a/docs/make.jl b/docs/make.jl index 8d407f3d2d..59a871bf91 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,6 @@ using Documenter, DocumenterVitepress, Pkg using Lux, LuxCore, LuxLib, WeightInitializers, NNlib using LuxTestUtils, MLDataDevices -using LuxCUDA using Optimisers # for some docstrings @@ -78,8 +77,10 @@ pages = [ #! format: on deploy_config = Documenter.auto_detect_deploy_system() -deploy_decision = Documenter.deploy_folder(deploy_config; repo="github.com/LuxDL/Lux.jl", - devbranch="main", devurl="dev", push_preview=true) +deploy_decision = Documenter.deploy_folder( + deploy_config; repo="github.com/LuxDL/Lux.jl", + devbranch="main", devurl="dev", push_preview=true +) makedocs(; sitename="Lux.jl Docs", @@ -96,7 +97,8 @@ makedocs(; repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}", format=DocumenterVitepress.MarkdownVitepress(; repo="github.com/LuxDL/Lux.jl", devbranch="main", devurl="dev", - deploy_url="https://lux.csail.mit.edu", deploy_decision), + deploy_url="https://lux.csail.mit.edu", deploy_decision + ), draft=false, pages ) diff --git a/docs/src/index.md b/docs/src/index.md index f30a7a476d..1d19139f89 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -23,8 +23,8 @@ hero: features: - icon: 🚀 - title: Fast & Extendible - details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs. + title: Fast & Extendable + details: Lux.jl is written in Julia itself, making it extremely extendable. CUDA and AMDGPU are supported first-class, with experimental support for Metal and Intel GPUs. 
link: /introduction - icon: 🐎 diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index b67ce45d62..4be0de857e 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -25,8 +25,7 @@ Pkg.add("Lux") ```@example quickstart using Lux, Random, Optimisers, Zygote -using LuxCUDA # For CUDA support -# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support +# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support ``` We take randomness very seriously @@ -66,26 +65,33 @@ y, st = Lux.apply(model, x, ps, st) train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0)) ## We can compute the gradients using Training.compute_gradients -gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(), - (x, dev(rand(rng, Float32, 10, 2))), train_state) +gs, loss, stats, train_state = Lux.Training.compute_gradients( + AutoZygote(), MSELoss(), + (x, dev(rand(rng, Float32, 10, 2))), train_state +) ## Optimization train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end) # Both these steps can be combined into a single call -gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(), - (x, dev(rand(rng, Float32, 10, 2))), train_state) +gs, loss, stats, train_state = Training.single_train_step!( + AutoZygote(), MSELoss(), + (x, dev(rand(rng, Float32, 10, 2))), train_state +) ``` ## Defining Custom Layers +We can train our model using the above code, but let's go ahead and see how to use Reactant. +Reactant is a julia frontend that generates MLIR and then compiles it using XLA (after +running fancy optimizations). It is the current recommended way to train large models in +Lux. For more details on using Reactant, see the [manual](@ref reactant-compilation). + ```@example custom_compact -using Lux, Random, Optimisers, Zygote -using LuxCUDA # For CUDA support -# using AMDGPU, Metal, oneAPI # Other pptional packages for GPU support +using Lux, Random, Optimisers, Reactant, Enzyme using Printf # For pretty printing -dev = gpu_device() +dev = reactant_device() ``` We will define a custom MLP using the `@compact` macro. The macro takes in a list of @@ -97,10 +103,12 @@ n_in = 1 n_out = 1 nlayers = 3 -model = @compact(w1=Dense(n_in => 32), +model = @compact( + w1=Dense(n_in => 32), w2=[Dense(32 => 32) for i in 1:nlayers], w3=Dense(32 => n_out), - act=relu) do x + act=relu +) do x embed = act(w1(x)) for w in w2 embed = act(w(embed)) @@ -116,21 +124,24 @@ We can initialize the model and train it with the same code as before! rng = Random.default_rng() Random.seed!(rng, 0) -ps, st = Lux.setup(Xoshiro(0), model) |> dev +ps, st = Lux.setup(rng, model) |> dev x = rand(rng, Float32, n_in, 32) |> dev -model(x, ps, st) # 1×32 Matrix and updated state as output. +@jit model(x, ps, st) # 1×32 Matrix and updated state as output. 
-x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :) |> dev
+x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
 y_data = 2 .* x_data .- x_data .^ 3
+x_data, y_data = dev(x_data), dev(y_data)
 
 function train_model!(model, ps, st, x_data, y_data)
     train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
 
     for iter in 1:1000
-        _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
-            (x_data, y_data), train_state)
+        _, loss, _, train_state = Lux.Training.single_train_step!(
+            AutoEnzyme(), MSELoss(),
+            (x_data, y_data), train_state
+        )
         if iter % 100 == 1 || iter == 1000
             @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
         end
@@ -155,6 +166,11 @@ packages mentioned in this documentation are available via the Julia General Reg
 
 You can install all those packages via `import Pkg; Pkg.add()`.
 
+## XLA (CPU/GPU/TPU) Support
+
+Lux.jl supports XLA compilation for CPU, GPU, and TPU using
+[Reactant.jl](https://github.com/EnzymeAD/Reactant.jl).
+
 ## GPU Support
 
 GPU Support for Lux.jl requires loading additional packages:
diff --git a/docs/src/manual/compiling_lux_models.md b/docs/src/manual/compiling_lux_models.md
index 0264f44268..cfbd407486 100644
--- a/docs/src/manual/compiling_lux_models.md
+++ b/docs/src/manual/compiling_lux_models.md
@@ -124,6 +124,16 @@ fmap(Broadcast.BroadcastFunction(-), ∂ps_zyg, ∂ps_enzyme |> cpu_device())
 
 ## [Using the `TrainState` API](@id compile_lux_model_trainstate)
 
+!!! tip "Debugging TrainState API Failures"
+
+    If the code fails to compile with Reactant, it is useful to dump the HLO. Starting the
+    Julia session with the `LUX_DUMP_REACTANT_HLO_OPTIMIZE` environment variable set to
+    `no_enzyme`, `false`, or `true` will dump the HLO to a file (the filename will be
+    displayed). This is useful information to provide when opening an issue.
+
+    Alternatively, you can set the global reference `Lux.DUMP_REACTANT_HLO_OPT_MODE` to a
+    symbol corresponding to the `optimize` keyword argument to `@code_hlo`.
+
 Now that we saw the low-level API let's see how to train the model without any of this
 boilerplate. Simply follow the following steps:
 
diff --git a/docs/src/manual/gpu_management.md b/docs/src/manual/gpu_management.md
index dcb15ce3b0..50cd3cc954 100644
--- a/docs/src/manual/gpu_management.md
+++ b/docs/src/manual/gpu_management.md
@@ -1,12 +1,5 @@
 # GPU Management
 
-!!! info
-
-    Starting from `v0.5`, Lux has transitioned to a new GPU management system. The old
-    system using `cpu` and `gpu` functions is still in place but will be removed in `v1`.
-    Using the old functions might lead to performance regressions if used inside
-    performance critical code.
-
 `Lux.jl` can handle multiple GPU backends. Currently, the following backends are supported:
 
 ```@example gpu_management
@@ -16,6 +9,12 @@ using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
 supported_gpu_backends()
 ```
 
+!!! tip "GPU Support via Reactant"
+
+    If you are using Reactant, you can use the [`reactant_device`](@ref) function to
+    automatically select the Reactant backend if it is available. Additionally, to force
+    Reactant to use the GPU, run `Reactant.set_default_backend("gpu")` (this is automatic).
+
 !!! danger "Metal Support"
 
     Support for Metal GPUs should be considered extremely experimental at this point.
diff --git a/docs/tutorials.jl b/docs/tutorials.jl
index 5bdcf3056f..abbfd6581f 100644
--- a/docs/tutorials.jl
+++ b/docs/tutorials.jl
@@ -1,10 +1,11 @@
 #! 
format: off const BEGINNER_TUTORIALS = [ - "Basics/main.jl" => "CUDA", + "Basics/main.jl" => "CPU", "PolynomialFitting/main.jl" => "CUDA", "SimpleRNN/main.jl" => "CUDA", + # Technically this is run on CPU but we need a better machine to run it "SimpleChains/main.jl" => "CUDA", - "OptimizationIntegration/main.jl" => "CUDA", + "OptimizationIntegration/main.jl" => "CPU", ] const INTERMEDIATE_TUTORIALS = [ "NeuralODE/main.jl" => "CUDA", diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml index c7ea884bf5..2abe510aa1 100644 --- a/examples/Basics/Project.toml +++ b/examples/Basics/Project.toml @@ -2,7 +2,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -12,6 +11,5 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArrays = "0.15.18" ForwardDiff = "0.10" Lux = "1" -LuxCUDA = "0.3" Optimisers = "0.4.1" Zygote = "0.6" diff --git a/examples/Basics/main.jl b/examples/Basics/main.jl index ba00fae84e..401fc8a961 100644 --- a/examples/Basics/main.jl +++ b/examples/Basics/main.jl @@ -109,12 +109,14 @@ W * x # the `cu` function (or the `gpu` function exported by `Lux``), and it supports all of the # above operations with the same syntax. -using LuxCUDA - -if LuxCUDA.functional() - x_cu = cu(rand(5, 3)) - @show x_cu -end +# ```julia +# using LuxCUDA +# +# if LuxCUDA.functional() +# x_cu = cu(rand(5, 3)) +# @show x_cu +# end +# ``` # ## (Im)mutability diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl index 99f90d321d..380ae1d399 100644 --- a/examples/ConditionalVAE/main.jl +++ b/examples/ConditionalVAE/main.jl @@ -151,7 +151,8 @@ end function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int) total_images = grid_rows * grid_cols imgs = map(eachslice(imgs[:, :, :, 1:total_images]; dims=4)) do img - cimg = size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img) + cimg = size(img, 3) == 1 ? 
colorview(Gray, view(img, :, :, 1)) : + colorview(RGB, permutedims(img, (3, 1, 2))) return cimg' end return create_image_grid(imgs, grid_rows, grid_cols) @@ -239,23 +240,21 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f for epoch in 1:epochs loss_total = 0.0f0 total_samples = 0 - total_time = 0.0 + start_time = time() for (i, X) in enumerate(train_dataloader) - throughput_tic = time() (_, loss, _, train_state) = Training.single_train_step!( AutoEnzyme(), loss_function, X, train_state) - throughput_toc = time() loss_total += loss total_samples += size(X, ndims(X)) - total_time += throughput_toc - throughput_tic if i % 250 == 0 || i == length(train_dataloader) - throughput = total_samples / total_time + throughput = total_samples / (time() - start_time) @printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput end end + total_time = time() - start_time train_loss = loss_total / length(train_dataloader) throughput = total_samples / total_time diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml index 9654279710..0e56fe62db 100644 --- a/examples/HyperNet/Project.toml +++ b/examples/HyperNet/Project.toml @@ -1,5 +1,4 @@ [deps] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" @@ -9,12 +8,9 @@ OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "1.10" ComponentArrays = "0.15.18" Lux = "1" LuxCUDA = "0.3" @@ -22,6 +18,4 @@ MLDatasets = "0.7" MLUtils = "0.4" OneHotArrays = "0.2.5" Optimisers = "0.4.1" -Setfield = "1" -Statistics = "1" Zygote = "0.6" diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 9afad5e384..b38be28073 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -2,8 +2,8 @@ # ## Package Imports -using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, - Printf, Random, Setfield, Statistics, Zygote +using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, + Printf, Random, Zygote CUDA.allowscalar(false) diff --git a/examples/OptimizationIntegration/Project.toml b/examples/OptimizationIntegration/Project.toml index e43832d037..35f00f5f73 100644 --- a/examples/OptimizationIntegration/Project.toml +++ b/examples/OptimizationIntegration/Project.toml @@ -2,7 +2,6 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" @@ -16,7 +15,6 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" CairoMakie = "0.12.10" ComponentArrays = "0.15.18" Lux = "1" -LuxCUDA = "0.3.3" MLUtils = "0.4.4" Optimization = "4" OptimizationOptimJL = "0.4" diff --git a/examples/OptimizationIntegration/main.jl b/examples/OptimizationIntegration/main.jl index 84ed27e6ee..481bdbd280 100644 --- a/examples/OptimizationIntegration/main.jl +++ 
b/examples/OptimizationIntegration/main.jl @@ -18,7 +18,6 @@ using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5, SciMLSensitivity, Random, MLUtils, CairoMakie, ComponentArrays, Printf -using LuxCUDA const gdev = gpu_device() const cdev = cpu_device() diff --git a/examples/PINN2DPDE/Project.toml b/examples/PINN2DPDE/Project.toml index 4b2b24bf0c..a89163bede 100644 --- a/examples/PINN2DPDE/Project.toml +++ b/examples/PINN2DPDE/Project.toml @@ -1,25 +1,23 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.10" CairoMakie = "0.12.10" Lux = "1" -LuxCUDA = "0.3.3" MLUtils = "0.4.4" OnlineStats = "1.7.1" Optimisers = "0.4.1" Printf = "1.10" Random = "1.10" Statistics = "1.10" -Zygote = "0.6.70" diff --git a/examples/PINN2DPDE/main.jl b/examples/PINN2DPDE/main.jl index f2921ec63b..c2ba37a4d2 100644 --- a/examples/PINN2DPDE/main.jl +++ b/examples/PINN2DPDE/main.jl @@ -10,13 +10,10 @@ # ## Package Imports -using ADTypes, Lux, Optimisers, Zygote, Random, Printf, Statistics, MLUtils, OnlineStats, - CairoMakie -using LuxCUDA +using Lux, Optimisers, Random, Printf, Statistics, MLUtils, OnlineStats, CairoMakie, + Reactant, Enzyme -CUDA.allowscalar(false) - -const gdev = gpu_device() +const xdev = reactant_device(; force=true) const cdev = cpu_device() # ## Problem Definition @@ -60,12 +57,13 @@ end # will use the following loss function @views function physics_informed_loss_function( - u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray) - ∂u_∂xyt = only(Zygote.gradient(sum ∘ u, xyt)) + u::StatefulLuxLayer, v::StatefulLuxLayer, w::StatefulLuxLayer, xyt::AbstractArray +) + ∂u_∂xyt = Enzyme.gradient(Enzyme.Reverse, sum ∘ u, xyt)[1] ∂u_∂x, ∂u_∂y, ∂u_∂t = ∂u_∂xyt[1:1, :], ∂u_∂xyt[2:2, :], ∂u_∂xyt[3:3, :] - ∂v_∂x = only(Zygote.gradient(sum ∘ v, xyt))[1:1, :] + ∂v_∂x = Enzyme.gradient(Enzyme.Reverse, sum ∘ v, xyt)[1][1:1, :] v_xyt = v(xyt) - ∂w_∂y = only(Zygote.gradient(sum ∘ w, xyt))[2:2, :] + ∂w_∂y = Enzyme.gradient(Enzyme.Reverse, sum ∘ w, xyt)[1][2:2, :] w_xyt = w(xyt) return ( mean(abs2, ∂u_∂t .- ∂v_∂x .- ∂w_∂y) + @@ -141,37 +139,45 @@ nothing #hide # ## Training -function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, - maxiters::Int=50000, hidden_dims::Int=32) +function train_model( + xyt, target_data, xyt_bc, target_bc; seed::Int=0, + maxiters::Int=50000, hidden_dims::Int=32 +) rng = Random.default_rng() Random.seed!(rng, seed) pinn = PINN(; hidden_dims) - ps, st = Lux.setup(rng, pinn) |> gdev + ps, st = Lux.setup(rng, pinn) |> xdev - bc_dataloader = DataLoader((xyt_bc, target_bc); batchsize=32, shuffle=true) |> gdev - pde_dataloader = DataLoader((xyt, target_data); batchsize=32, shuffle=true) |> gdev + bc_dataloader = DataLoader( + (xyt_bc, target_bc); batchsize=32, shuffle=true, partial=false + ) |> xdev + pde_dataloader = DataLoader( + (xyt, target_data); batchsize=32, shuffle=true, partial=false + ) 
|> xdev train_state = Training.TrainState(pinn, ps, st, Adam(0.05f0)) lr = i -> i < 5000 ? 0.05f0 : (i < 10000 ? 0.005f0 : 0.0005f0) total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker = ntuple( - _ -> Lag(Float32, 32), 4) + _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 4) iter = 1 for ((xyt_batch, target_data_batch), (xyt_bc_batch, target_bc_batch)) in zip( - Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader)) + Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader) + ) Optimisers.adjust!(train_state, lr(iter)) _, loss, stats, train_state = Training.single_train_step!( - AutoZygote(), loss_function, ( - xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch), - train_state) + AutoEnzyme(), loss_function, + (xyt_batch, target_data_batch, xyt_bc_batch, target_bc_batch), + train_state + ) - fit!(total_loss_tracker, loss) - fit!(physics_loss_tracker, stats.physics_loss) - fit!(data_loss_tracker, stats.data_loss) - fit!(bc_loss_tracker, stats.bc_loss) + fit!(total_loss_tracker, Float32(loss)) + fit!(physics_loss_tracker, Float32(stats.physics_loss)) + fit!(data_loss_tracker, Float32(stats.data_loss)) + fit!(bc_loss_tracker, Float32(stats.bc_loss)) mean_loss = mean(OnlineStats.value(total_loss_tracker)) mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker)) @@ -181,7 +187,7 @@ function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, isnan(loss) && throw(ArgumentError("NaN Loss Detected")) if iter % 1000 == 1 || iter == maxiters - @printf "Iteration: [%5d / %5d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \ + @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f \ (%.9f) \t Data Loss: %.9f (%.9f) \t BC \ Loss: %.9f (%.9f)\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss end @@ -191,12 +197,14 @@ function train_model(xyt, target_data, xyt_bc, target_bc; seed::Int=0, end return StatefulLuxLayer{true}( - pinn, cdev(train_state.parameters), cdev(train_state.states)) + pinn, cdev(train_state.parameters), cdev(train_state.states) + ) end trained_model = train_model(xyt, target_data, xyt_bc, target_bc) -trained_u = Lux.testmode(StatefulLuxLayer{true}( - trained_model.model.u, trained_model.ps.u, trained_model.st.u)) +trained_u = Lux.testmode( + StatefulLuxLayer{true}(trained_model.model.u, trained_model.ps.u, trained_model.st.u) +) nothing #hide # ## Visualizing the Results diff --git a/examples/PolynomialFitting/Project.toml b/examples/PolynomialFitting/Project.toml index 168865fdd6..66e1f5fda8 100644 --- a/examples/PolynomialFitting/Project.toml +++ b/examples/PolynomialFitting/Project.toml @@ -2,18 +2,18 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1.10" CairoMakie = "0.12" Lux = "1" -LuxCUDA = "0.3" Optimisers = "0.4.1" -Statistics = "1" -Zygote = "0.6" +Printf = "1.10" +Random = "1.10" +Reactant = "0.2.14" +Statistics = "1.10" diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 50f32b447f..bf2d7e1b90 100644 --- 
a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -5,8 +5,7 @@ # ## Package Imports -using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote -using CairoMakie +using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie # ## Dataset @@ -55,10 +54,10 @@ opt = Adam(0.03f0) # functions provided by Lux. const loss_function = MSELoss() -const dev_cpu = cpu_device() -const dev_gpu = gpu_device() +const cdev = cpu_device() +const xdev = reactant_device() -ps, st = Lux.setup(rng, model) |> dev_gpu +ps, st = Lux.setup(rng, model) |> xdev # ## Training @@ -67,14 +66,14 @@ ps, st = Lux.setup(rng, model) |> dev_gpu tstate = Training.TrainState(model, ps, st, opt) -# Now we will use Zygote for our AD requirements. +# Now we will use Enzyme (Reactant) for our AD requirements. -vjp_rule = AutoZygote() +vjp_rule = AutoEnzyme() # Finally the training loop. function main(tstate::Training.TrainState, vjp, data, epochs) - data = data .|> gpu_device() + data = data |> xdev for epoch in 1:epochs _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate) if epoch % 50 == 1 || epoch == epochs @@ -85,7 +84,16 @@ function main(tstate::Training.TrainState, vjp, data, epochs) end tstate = main(tstate, vjp_rule, (x, y), 250) -y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1]) + +# Since we are using Reactant, we need to compile the model before we can use it. + +forward_pass = @compile Lux.apply( + tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states) +) + +y_pred = cdev(first(forward_pass( + tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states) +))) nothing #hide # Let's plot the results diff --git a/examples/SimpleChains/Project.toml b/examples/SimpleChains/Project.toml index 3f1b9b2a41..0045a3d0c8 100644 --- a/examples/SimpleChains/Project.toml +++ b/examples/SimpleChains/Project.toml @@ -7,6 +7,7 @@ OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 1ff12bc23f..7a356f99f0 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -7,10 +7,12 @@ # reference. 
# ## Package Imports -using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf +using Lux, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf, Reactant using MLDatasets: MNIST using SimpleChains: SimpleChains +Reactant.set_default_backend("cpu") + # ## Loading MNIST function loadmnist(batchsize, train_split) ## Load MNIST @@ -31,16 +33,26 @@ function loadmnist(batchsize, train_split) return ( ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true), + DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true, partial=false), ## Don't shuffle the test data - DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)) + DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false, partial=false) + ) end # ## Define the Model -lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), - Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), - Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10))) +lux_model = Chain( + Conv((5, 5), 1 => 6, relu), + MaxPool((2, 2)), + Conv((5, 5), 6 => 16, relu), + MaxPool((2, 2)), + FlattenLayer(3), + Chain( + Dense(256 => 128, relu), + Dense(128 => 84, relu), + Dense(84 => 10) + ) +) # We now need to convert the lux_model to SimpleChains.jl. We need to do this by defining # the [`ToSimpleChainsAdaptor`](@ref) and providing the input dimensions. @@ -49,7 +61,7 @@ adaptor = ToSimpleChainsAdaptor((28, 28, 1)) simple_chains_model = adaptor(lux_model) # ## Helper Functions -const loss = CrossEntropyLoss(; logits=Val(true)) +const lossfn = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 @@ -64,16 +76,20 @@ function accuracy(model, ps, st, dataloader) end # ## Define the Training Loop -function train(model; rng=Xoshiro(0), kwargs...) - train_dataloader, test_dataloader = loadmnist(128, 0.9) - ps, st = Lux.setup(rng, model) +function train(model, dev=cpu_device(); rng=Random.default_rng(), kwargs...) + train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev + ps, st = Lux.setup(rng, model) |> dev + + vjp = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote() train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) - ### Warmup the model - x_proto = randn(rng, Float32, 28, 28, 1, 1) - y_proto = onehotbatch([1], 0:9) - Training.compute_gradients(AutoZygote(), loss, (x_proto, y_proto), train_state) + if dev isa ReactantDevice + x_ra = first(test_dataloader)[1] + model_compiled = @compile model(x_ra, ps, Lux.testmode(st)) + else + model_compiled = model + end ### Lets train the model nepochs = 10 @@ -81,15 +97,18 @@ function train(model; rng=Xoshiro(0), kwargs...) 
for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - gs, _, _, train_state = Training.single_train_step!( - AutoZygote(), loss, (x, y), train_state) + _, _, _, train_state = Training.single_train_step!( + vjp, lossfn, (x, y), train_state + ) end ttime = time() - stime tr_acc = accuracy( - model, train_state.parameters, train_state.states, train_dataloader) * 100 + model_compiled, train_state.parameters, train_state.states, train_dataloader) * + 100 te_acc = accuracy( - model, train_state.parameters, train_state.states, test_dataloader) * 100 + model_compiled, train_state.parameters, train_state.states, test_dataloader) * + 100 @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \ %.2f%%\n" epoch nepochs ttime tr_acc te_acc @@ -101,12 +120,12 @@ end # ## Finally Training the Model # First we will train the Lux model -tr_acc, te_acc = train(lux_model) +tr_acc, te_acc = train(lux_model, reactant_device()) @assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide # Now we will train the SimpleChains model -train(simple_chains_model) +tr_acc, te_acc = train(simple_chains_model) @assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml index 81e54f61e5..4eba4ce69e 100644 --- a/examples/SimpleRNN/Project.toml +++ b/examples/SimpleRNN/Project.toml @@ -2,20 +2,15 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" [compat] ADTypes = "1.10" JLD2 = "0.5" Lux = "1" -LuxCUDA = "0.3" MLUtils = "0.4" Optimisers = "0.4.1" -Statistics = "1" -Zygote = "0.6" diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index a11a2c5cbc..bc642d2f01 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -9,7 +9,7 @@ # ## Package Imports -using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics +using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random # ## Dataset @@ -34,9 +34,11 @@ function get_dataloaders(; dataset_size=1000, sequence_length=50) ## Create DataLoaders return ( ## Use DataLoader to automatically minibatch and shuffle the data - DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true), + DataLoader( + collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false), ## Don't shuffle the validation data - DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false)) + DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false) + ) end # ## Creating a Classifier @@ -51,7 +53,7 @@ end # To understand more about container layers, please look at # [Container Layer](@ref Container-Layer). 
-struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)} +struct SpiralClassifier{L, C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)} lstm_cell::L classifier::C end @@ -128,35 +130,52 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model function main(model_type) - dev = gpu_device() + dev = reactant_device() + cdev = cpu_device() ## Get the dataloaders - train_loader, val_loader = get_dataloaders() .|> dev + train_loader, val_loader = get_dataloaders() |> dev ## Create the model model = model_type(2, 8, 1) - rng = Xoshiro(0) - ps, st = Lux.setup(rng, model) |> dev + ps, st = Lux.setup(Random.default_rng(), model) |> dev train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) + model_compiled = if dev isa ReactantDevice + @compile model(first(train_loader)[1], ps, Lux.testmode(st)) + else + model + end + ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote() for epoch in 1:25 ## Train the model + total_loss = 0.0f0 + total_samples = 0 for (x, y) in train_loader (_, loss, _, train_state) = Training.single_train_step!( - AutoZygote(), lossfn, (x, y), train_state) - - @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss + ad, lossfn, (x, y), train_state + ) + total_loss += loss * length(y) + total_samples += length(y) end + @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples) ## Validate the model + total_acc = 0.0f0 + total_loss = 0.0f0 + total_samples = 0 + st_ = Lux.testmode(train_state.states) for (x, y) in val_loader - ŷ, st_ = model(x, train_state.parameters, st_) - loss = lossfn(ŷ, y) - acc = accuracy(ŷ, y) - @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc + ŷ, st_ = model_compiled(x, train_state.parameters, st_) + ŷ, y = cdev(ŷ), cdev(y) + total_acc += accuracy(ŷ, y) * length(y) + total_loss += lossfn(ŷ, y) * length(y) + total_samples += length(y) end + + @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples) end return (train_state.parameters, train_state.states) |> cpu_device() diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 6f49b076d0..93a1b1279a 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -2,7 +2,7 @@ module LuxReactantExt using Enzyme: Enzyme, Const, Duplicated, Active using Optimisers: Optimisers -using Reactant: Reactant, @compile, AnyTracedRArray, TracedRArray, TracedRNumber +using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber using Setfield: @set! using Static: False @@ -13,6 +13,11 @@ Lux.is_extension_loaded(::Val{:Reactant}) = true Utils.to_rarray(x; kwargs...) = Reactant.to_rarray(x; kwargs...) 
+Utils.contiguous(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(x) + +Utils.eltype(::Type{<:TracedRArray{T, N}}) where {T, N} = T +Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T + function Utils.promote_to(::Type{T}, x::Number) where {T <: Number} x isa Reactant.TracedType && return x return Reactant.ConcreteRNumber{T}(x) diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index f9f4519e0a..6d79f2b60f 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -2,3 +2,8 @@ Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(ve # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g + +# Embedding +function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple) + return ps.weight[:, x], st +end diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index c35d5cb054..2462bd252b 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -23,9 +23,24 @@ function compute_gradients_internal(objective_function::F, model, data, ps, st) return dps, loss, stats_wrapper.stats, stats_wrapper.st end +function maybe_dump_to_mlir_file!(f::F, args...) where {F} + if Lux.DUMP_REACTANT_HLO_OPT_MODE[] !== nothing + hlo = @code_hlo optimize=Lux.DUMP_REACTANT_HLO_OPT_MODE[] f(args...) + fname = tempname() * ".mlir" + io = open(fname, "w") + write(io, string(hlo)) + close(io) + @info "HLO dumped to $fname" + end + return +end + function Lux.Training.compute_gradients_impl( backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} + maybe_dump_to_mlir_file!(compute_gradients_internal, objective_function, ts.model, + data, ts.parameters, ts.states) + compiled_gradient_function = @compile compute_gradients_internal( objective_function, ts.model, data, ts.parameters, ts.states) @@ -60,6 +75,9 @@ for inplace in ("!", "") if hasfield(typeof(ts.cache.extras), :update_function) update_function = ts.cache.extras.update_function else + maybe_dump_to_mlir_file!(update_function, ts.optimizer_state, ts.parameters, + grads) + update_function = @compile Optimisers.$(update_fn)( ts.optimizer_state, ts.parameters, grads) @set! ts.cache.extras = merge(ts.cache.extras, (; update_function)) @@ -72,10 +90,12 @@ for inplace in ("!", "") return ts end - # XXX: Should we add a check to ensure the inputs to this function is same as the one - # used in the compiled function? 
We can re-trigger the compilation with a warning + # XXX: recompile with a warning if new input types are used @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} + maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data, + ts.parameters, ts.states, ts.optimizer_state) + compiled_grad_and_step_function = @compile $(internal_fn)( objective_function, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml index acb9f2ec12..5b095d97d3 100644 --- a/lib/LuxCore/Project.toml +++ b/lib/LuxCore/Project.toml @@ -1,7 +1,7 @@ name = "LuxCore" uuid = "bb33d45b-7691-41d6-9220-0943567d0623" authors = ["Avik Pal and contributors"] -version = "1.2.1" +version = "1.2.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/lib/LuxCore/ext/LuxCoreReactantExt.jl b/lib/LuxCore/ext/LuxCoreReactantExt.jl index 3ad0c0dc21..f6e7770964 100644 --- a/lib/LuxCore/ext/LuxCoreReactantExt.jl +++ b/lib/LuxCore/ext/LuxCoreReactantExt.jl @@ -1,6 +1,6 @@ module LuxCoreReactantExt -using LuxCore: AbstractLuxLayer +using LuxCore: AbstractLuxLayer, LuxCore using Reactant: Reactant # Avoid tracing though models since it won't contain anything useful @@ -10,4 +10,7 @@ function Reactant.make_tracer( return model end +LuxCore.replicate(rng::Reactant.TracedRNG) = copy(rng) +LuxCore.replicate(rng::Reactant.ConcreteRNG) = copy(rng) + end diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index b1cfb2d23f..a9ec915cb1 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.4.0" +version = "1.4.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl index 87a912bec9..9faff0e511 100644 --- a/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl +++ b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl @@ -15,7 +15,7 @@ for serial in (true, false) opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! 
@eval @inline function LuxLib.Impl.$(opname)( C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) - if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + if !iszero(β) # Special case this because Base.FastMath.mul_fast(NaN, false) = NaN @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) Cⱼₖ = zero(eltype(C)) for I in indices((A, B), (2, 1)) diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl index f0e5ca707c..8eef4df1d3 100644 --- a/lib/LuxLib/src/LuxLib.jl +++ b/lib/LuxLib/src/LuxLib.jl @@ -8,7 +8,7 @@ using Static: Static, known using ChainRulesCore: ChainRulesCore, NoTangent using LuxCore: LuxCore -using MLDataDevices: get_device_type, AbstractGPUDevice +using MLDataDevices: get_device_type, AbstractGPUDevice, ReactantDevice using NNlib: NNlib const Optional{T} = Union{Nothing, T} diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl index 6df0fc8f7a..092ef3bd71 100644 --- a/lib/LuxLib/src/traits.jl +++ b/lib/LuxLib/src/traits.jl @@ -216,6 +216,7 @@ function internal_operation_mode(xs::Tuple) dev = get_device_type(xs) dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() + dev <: ReactantDevice && return GenericBroadcastOp() # This check needs to be done after the GPU Check known(Utils.unrolled_any(!Traits.fast_scalar_indexing, xs)) && diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml index 4f22301a32..fbe612a00f 100644 --- a/lib/WeightInitializers/Project.toml +++ b/lib/WeightInitializers/Project.toml @@ -1,7 +1,7 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = ["Avik Pal and contributors"] -version = "1.0.5" +version = "1.1.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -18,6 +18,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] @@ -27,6 +28,7 @@ WeightInitializersChainRulesCoreExt = "ChainRulesCore" WeightInitializersGPUArraysExt = "GPUArrays" WeightInitializersMetalExt = ["Metal", "GPUArrays"] WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] +WeightInitializersReactantExt = "Reactant" [compat] AMDGPU = "0.9.6, 1" @@ -39,6 +41,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" Metal = "1.3.0" Random = "1.10" +Reactant = "0.2.16" SpecialFunctions = "2.4" Statistics = "1.10" julia = "1.10" diff --git a/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl b/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl new file mode 100644 index 0000000000..2abc129b63 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersReactantExt.jl @@ -0,0 +1,29 @@ +module WeightInitializersReactantExt + +using Random: AbstractRNG +using Reactant: Reactant, TracedUtils, TracedRNG, ConcreteRNG, TracedRArray, + @reactant_overlay +using WeightInitializers: DeviceAgnostic + +# random numbers are automatically handled +for op in (:zeros, :ones) + @eval begin + function DeviceAgnostic.$(op)( + ::ConcreteRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Reactant.to_rarray($(op)(T, dims...)) + end + + function DeviceAgnostic.$(op)( + ::TracedRNG, ::Type{T}, dims::Integer...) 
where {T <: Number} + return TracedUtils.promote_to(TracedRArray{T, length(dims)}, $(op)(T, dims...)) + end + + @reactant_overlay @noinline function DeviceAgnostic.$(op)( + ::AbstractRNG, ::Type{T}, dims::Integer... + ) where {T} + return TracedUtils.promote_to(TracedRArray{T, length(dims)}, $(op)(T, dims...)) + end + end +end + +end diff --git a/src/Lux.jl b/src/Lux.jl index 64f0af07f1..ceda3df255 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -33,6 +33,21 @@ const CRC = ChainRulesCore const NAME_TYPE = Union{Nothing, String, Symbol} const Optional{T} = Union{T, Nothing} +const DUMP_REACTANT_HLO_OPT_MODE = Ref{Union{Symbol, Nothing, Bool}}(nothing) + +function __init__() + HLO_DUMP = get(ENV, "LUX_DUMP_REACTANT_HLO_OPTIMIZE", nothing) + if HLO_DUMP !== nothing + if HLO_DUMP == "true" || HLO_DUMP == "1" + DUMP_REACTANT_HLO_OPT_MODE[] = true + elseif HLO_DUMP == "false" || HLO_DUMP == "0" + DUMP_REACTANT_HLO_OPT_MODE[] = false + else + DUMP_REACTANT_HLO_OPT_MODE[] = Symbol(HLO_DUMP) + end + end +end + is_extension_loaded(::Val) = false # Preferences diff --git a/src/extended_ops.jl b/src/extended_ops.jl index 0223d775c3..15ebebfa4f 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -13,7 +13,7 @@ using EnzymeCore: EnzymeCore using FastClosures: @closure using Static: StaticBool, StaticSymbol, known -using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice +using MLDataDevices: get_device_type, AbstractGPUDevice, ReactantDevice, AbstractDevice using ..Utils: Utils @@ -103,10 +103,12 @@ GPUArray. Additional dispatches for RNN helpers are also provided for `TimeLastIndex` and `BatchLastIndex`. """ -function eachslice(x::AbstractArray, dims::Val) - return eachslice(get_device_type(x), x, dims) -end -function eachslice(::Type{<:AbstractGPUDevice}, x::AbstractArray, ::Val{dims}) where {dims} +eachslice(x::AbstractArray, dims::Val) = eachslice(get_device_type(x), x, dims) + +function eachslice( + ::Type{<:Union{<:ReactantDevice, <:AbstractGPUDevice}}, + x::AbstractArray, ::Val{dims} +) where {dims} return [Utils.contiguous(selectdim(x, dims, i)) for i in axes(x, dims)] end function eachslice(::Type{<:AbstractDevice}, x::AbstractArray, ::Val{dims}) where {dims} diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 305af824c0..0e6d8f8ac0 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -611,22 +611,30 @@ end outputsize(e::Embedding, _, ::AbstractRNG) = (e.out_dims,) -(e::Embedding)(x::Integer, ps, st::NamedTuple) = view(ps.weight, :, x), st -function (e::Embedding)(x::AbstractVector{<:Integer}, ps, st::NamedTuple) +function (e::Embedding)(x::Number, ps, st::NamedTuple) + @assert Utils.eltype(x) <: Integer + return view(ps.weight, :, x), st +end +function (e::Embedding)(x::AbstractVector, ps, st::NamedTuple) + @assert Utils.eltype(x) <: Integer return NNlib.gather(ps.weight, x), st end -function (e::Embedding)(x::AbstractArray{<:Integer}, ps, st::NamedTuple) +function (e::Embedding)(x::AbstractArray, ps, st::NamedTuple) + @assert Utils.eltype(x) <: Integer y, stₙ = e(vec(x), ps, st) return reshape(y, :, size(x)...), stₙ end -function (e::Embedding)(x::NTuple{<:Any, <:Integer}, ps, st::NamedTuple) +function (e::Embedding)(x::NTuple{N, T}, ps, st::NamedTuple) where {N, T} + @assert Utils.eltype(T) <: Integer return view(ps.weight, :, x...), st end -function (e::Embedding)(x::NTuple{<:Any, <:AbstractVector{<:Integer}}, ps, st::NamedTuple) +function (e::Embedding)(x::NTuple{N, <:AbstractVector{T}}, ps, st::NamedTuple) where {N, T} + @assert Utils.eltype(T) <: 
Integer @argcheck allequal(size, x) DimensionMismatch("Input vectors must have the same shape") return NNlib.gather(ps.weight, x...), st end -function (e::Embedding)(x::NTuple{<:Any, <:AbstractArray{<:Integer}}, ps, st::NamedTuple) +function (e::Embedding)(x::NTuple{N, <:AbstractArray{T}}, ps, st::NamedTuple) where {N, T} + @assert Utils.eltype(T) <: Integer @argcheck allequal(size, x) DimensionMismatch("Input arrays must have the same shape") y, stₙ = e(vec.(x), ps, st) return reshape(y, :, size(first(x))...), stₙ diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 173af3dcd1..37b05155e4 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,13 +1,14 @@ @testitem "Aqua: Quality Assurance" tags=[:misc] begin using Aqua, ChainRulesCore, ForwardDiff - Aqua.test_all(Lux; ambiguities=false, piracies=false) + Aqua.test_all(Lux; ambiguities=false, piracies=false, unbound_args=false) Aqua.test_ambiguities(Lux; exclude=[ForwardDiff.jacobian, ForwardDiff.gradient, Lux.AutoDiffInternalImpl.batched_jacobian, Lux.AutoDiffInternalImpl.jacobian_vector_product, Lux.AutoDiffInternalImpl.jacobian_vector_product_impl]) Aqua.test_piracies(Lux; treat_as_own=[Lux.outputsize]) + Aqua.test_unbound_args(Lux; broken=true) end @testitem "Explicit Imports: Quality Assurance" tags=[:misc] begin diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl index e0e0fb5266..b2b5d8021b 100644 --- a/test/reactant/layer_tests.jl +++ b/test/reactant/layer_tests.jl @@ -42,12 +42,14 @@ end Reactant.set_default_backend("cpu") end + dev = reactant_device(; force=true) + @testset for cell in (RNNCell, LSTMCell, GRUCell) model = Recurrence(cell(4 => 4)) ps, st = Lux.setup(rng, model) - ps_ra, st_ra = (ps, st) |> Reactant.to_rarray + ps_ra, st_ra = (ps, st) |> dev x = rand(Float32, 4, 16, 12) - x_ra = x |> Reactant.to_rarray + x_ra = x |> dev y_ra, _ = @jit model(x_ra, ps_ra, st_ra) y, _ = model(x, ps, st)
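
Reviewer note: the HLO-dump debugging hook added above in `src/Lux.jl` and
`ext/LuxReactantExt/training.jl` is easiest to follow with a small usage sketch. The
snippet below is illustrative only and not part of the patch; the model, optimiser, and
data are placeholder choices, and it assumes a Reactant-capable environment with Lux,
Reactant, Optimisers, and Enzyme installed.

```julia
# Sketch: exercising the HLO-dump hook introduced in this patch.
# The environment variable is read in `Lux.__init__`, so it must be set before `using Lux`:
#     ENV["LUX_DUMP_REACTANT_HLO_OPTIMIZE"] = "no_enzyme"   # or "true" / "false"
using Lux, Reactant, Optimisers, Random, Enzyme

# Alternatively, flip the global reference at runtime (any value accepted by the
# `optimize` keyword of `@code_hlo`):
Lux.DUMP_REACTANT_HLO_OPT_MODE[] = :no_enzyme

dev = reactant_device()
model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))   # placeholder model
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 32) |> dev                     # placeholder data
y = rand(Float32, 1, 32) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# The first step compiles through Reactant; with the flag set, `maybe_dump_to_mlir_file!`
# writes the (optionally optimized) HLO to a temporary `.mlir` file and logs its path.
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), train_state
)
```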