feat: update to Functors v0.5 (#1069)

* feat: update LuxCore to latest Functors

* fix: lock in LuxCore versions

* feat: update MLDataDevices to support latest Functors

* test: disable Flux tests for now

* ci: use MLDataDevices in LuxTestUtils

* refactor: use fmap to implement recursive ops (sketched below)

* chore: bump ADTypes to 1.10

* fix: don't preserve structure in checks

* chore: bump minimum versions
avik-pal authored Nov 14, 2024
1 parent 8f3d749 commit fc75290
Showing 49 changed files with 240 additions and 182 deletions.
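
The fmap refactor is the heart of this change: Lux's hand-rolled recursive helpers (`Lux.recursive_map`, `Lux.recursive_make_zero`, ...) give way to Functors.jl's `fmap`, paired with `MLDataDevices.isleaf` to decide where traversal stops. A minimal sketch of the pattern on a toy parameter tree (illustrative only, not code from this commit):

```julia
using Functors: fmap
using MLDataDevices: isleaf

# A nested parameter tree, shaped like the NamedTuples Lux.setup returns.
ps = (layer_1 = (weight = rand(Float32, 4, 4), bias = zeros(Float32, 4)),
      layer_2 = (weight = rand(Float32, 2, 4), bias = zeros(Float32, 2)))

# fmap applies the function at every leaf and rebuilds the same nested
# structure; exclude=isleaf stops recursion at MLDataDevices' leaf types.
dps = fmap(zero, ps; exclude=isleaf)

size(dps.layer_1.weight)  # (4, 4), all zeros: a fresh gradient buffer
```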
14 changes: 12 additions & 2 deletions .buildkite/testing_luxtestutils.yml
@@ -9,12 +9,22 @@ steps:
       codecov: true
       dirs:
         - lib/LuxTestUtils/src
+        - lib/MLDataDevices/src
+        - lib/MLDataDevices/ext
     command: |
       julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils -e '
         import Pkg;
         Pkg.Registry.update();
-        Pkg.instantiate();
-        Pkg.test(; coverage="user")'
+        dev_pkgs = Pkg.PackageSpec[];
+        for pkg in ("lib/MLDataDevices",)
+            push!(dev_pkgs, Pkg.PackageSpec(path=pkg));
+        end;
+        Pkg.develop(dev_pkgs);
+        Pkg.instantiate()'
+      julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils/test -e '
+        import Pkg, LuxTestUtils
+        dir = dirname(pathof(LuxTestUtils))
+        include(joinpath(dir, "../test/runtests.jl"))'
   agents:
     queue: "juliagpu"
     cuda: "*"
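
Both CI setups now `Pkg.develop` the in-repo MLDataDevices before instantiating, so LuxTestUtils is exercised against the monorepo sources rather than the registered release (presumably because the Functors-0.5-compatible MLDataDevices was not yet registered when this landed). Outside CI the loop collapses to a one-liner; a sketch, assuming it runs from the repository root:

```julia
import Pkg

# Substitute the local subpackage for the registered version, then
# resolve the remaining dependencies against it.
Pkg.develop(Pkg.PackageSpec(path="lib/MLDataDevices"))
Pkg.instantiate()
```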
12 changes: 11 additions & 1 deletion .github/workflows/CI_LuxTestUtils.yml
@@ -45,13 +45,18 @@ jobs:
       - name: "Install Dependencies and Run Tests"
         run: |
           import Pkg
+          dev_pkgs = Pkg.PackageSpec[]
+          for pkg in ("lib/MLDataDevices",)
+              push!(dev_pkgs, Pkg.PackageSpec(path=pkg))
+          end
+          Pkg.develop(dev_pkgs)
           Pkg.Registry.update()
           Pkg.instantiate()
           Pkg.test(; coverage="user")
         shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0}
       - uses: julia-actions/julia-processcoverage@v1
         with:
-          directories: lib/LuxTestUtils/src
+          directories: lib/LuxTestUtils/src,lib/MLDataDevices/src,lib/MLDataDevices/ext
       - uses: codecov/codecov-action@v4
         with:
           files: lcov.info
@@ -75,6 +80,11 @@ jobs:
       - name: "Install Dependencies and Run Tests"
         run: |
           import Pkg
+          dev_pkgs = Pkg.PackageSpec[]
+          for pkg in ("lib/MLDataDevices",)
+              push!(dev_pkgs, Pkg.PackageSpec(path=pkg))
+          end
+          Pkg.develop(dev_pkgs)
           Pkg.Registry.update()
           Pkg.instantiate()
           Pkg.test(; coverage="user")
18 changes: 9 additions & 9 deletions Project.toml
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.2.3"
+version = "1.3.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -66,14 +66,14 @@ LuxTrackerExt = "Tracker"
 LuxZygoteExt = "Zygote"
 
 [compat]
-ADTypes = "1.8.1"
+ADTypes = "1.10"
 Adapt = "4.1"
 ArgCheck = "2.3"
 ArrayInterface = "7.10"
 CUDA = "5.3.2"
 ChainRulesCore = "1.24"
-Compat = "4.15"
-ComponentArrays = "0.15.16"
+Compat = "4.16"
+ComponentArrays = "0.15.18"
 ConcreteStructs = "0.2.3"
 DispatchDoctor = "0.4.12"
 Enzyme = "0.13.13"
@@ -82,20 +82,20 @@ FastClosures = "0.3.2"
 Flux = "0.14.25"
 ForwardDiff = "0.10.36"
 FunctionWrappers = "1.1.3"
-Functors = "0.4.12"
+Functors = "0.5"
 GPUArraysCore = "0.1.6, 0.2"
 LinearAlgebra = "1.10"
 LossFunctions = "0.11.1"
-LuxCore = "1"
+LuxCore = "1.2"
 LuxLib = "1.3.7"
-MLDataDevices = "1.5"
+MLDataDevices = "1.6"
 MLUtils = "0.4.4"
 MPI = "0.20.19"
 MacroTools = "0.5.13"
 Markdown = "1.10"
 NCCL = "0.1.1"
 NNlib = "0.9.24"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Preferences = "1.4.3"
 Random = "1.10"
 Reactant = "0.2.4"
@@ -107,7 +107,7 @@ SimpleChains = "0.4.7"
 Static = "1.1.1"
 StaticArraysCore = "1.4.3"
 Statistics = "1.10"
-Tracker = "0.2.34"
+Tracker = "0.2.36"
 WeightInitializers = "1"
 Zygote = "0.6.70"
 julia = "1.10"
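
A note on the bounds above, for readers less familiar with Pkg's compat syntax: bare versions use caret semantics, so `Functors = "0.5"` permits 0.5.x only and `ADTypes = "1.10"` permits 1.10.0 up to (not including) 2.0.0, while comma-separated entries such as `Optimisers = "0.3.4, 0.4"` are unions of ranges. This can be checked with the parser Pkg itself uses (unexported, assumed stable here):

```julia
import Pkg

# Caret semantics: for 0.x versions the minor version is the breaking one.
Pkg.Types.semver_spec("0.5")         # matches 0.5.0 up to, not including, 0.6.0
Pkg.Types.semver_spec("0.3.4, 0.4")  # union: [0.3.4, 0.4.0) and [0.4.0, 0.5.0)
```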
12 changes: 6 additions & 6 deletions docs/Project.toml
@@ -30,27 +30,27 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1.3"
+ADTypes = "1.10"
 Adapt = "4"
 ChainRulesCore = "1.24"
-ComponentArrays = "0.15"
+ComponentArrays = "0.15.18"
 Documenter = "1.4"
 DocumenterVitepress = "0.1.3"
 Enzyme = "0.13.13"
 FiniteDiff = "2.23.1"
 ForwardDiff = "0.10.36"
-Functors = "0.4.12"
+Functors = "0.5"
 GPUArraysCore = "0.1, 0.2"
 KernelAbstractions = "0.9"
 LinearAlgebra = "1.10"
 Literate = "2.18.0"
 Lux = "1"
 LuxCUDA = "0.3.2"
-LuxCore = "1"
+LuxCore = "1.2"
 LuxLib = "1.3.4"
 LuxTestUtils = "1.5"
-MLDataDevices = "1.4"
-Optimisers = "0.3.3, 0.4"
+MLDataDevices = "1.6"
+Optimisers = "0.3.4, 0.4"
 Pkg = "1.10"
 Printf = "1.10"
 Random = "1.10"
6 changes: 3 additions & 3 deletions docs/src/manual/nn_inside_gpu_kernels.md
@@ -21,7 +21,7 @@ making it compatible with multiple GPU backends.
 input data and let Lux handle the batching internally.
 
 ```@example nn_in_gpu_kernels
-using Lux, LuxCUDA, Random
+using Lux, LuxCUDA, Random, Functors
 using KernelAbstractions, StaticArrays
 ```
@@ -45,8 +45,8 @@ nn = Chain(Dense(4, 4, relu), Dense(4, 4))
 ps, st = Lux.setup(Xoshiro(123), nn)
 to_sarray(x) = SArray{Tuple{size(x)...}}(x)
-ps_static = Lux.recursive_map(to_sarray, ps)
-st_static = Lux.recursive_map(to_sarray, st)
+ps_static = fmap(to_sarray, ps)
+st_static = fmap(to_sarray, st)
 ```
 
 First we will run it on CPU.
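
The manual swap above is behavior-preserving: `fmap` walks the `ps`/`st` trees just as `Lux.recursive_map` did (hence the new `Functors` import at the top of the example) and converts every array leaf to a static array. A self-contained CPU sketch of what the two new lines compute, mirroring the manual's Chain:

```julia
using Lux, Random, StaticArrays
using Functors: fmap

nn = Chain(Dense(4, 4, relu), Dense(4, 4))
ps, st = Lux.setup(Xoshiro(123), nn)

# Rebuild the parameter tree with every Array leaf converted to an SArray,
# making the whole tree isbits so it can be passed into a GPU kernel.
to_sarray(x) = SArray{Tuple{size(x)...}}(x)
ps_static = fmap(to_sarray, ps)

typeof(ps_static.layer_1.weight)  # SMatrix{4, 4, Float32, 16}
```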
4 changes: 2 additions & 2 deletions examples/Basics/Project.toml
@@ -9,9 +9,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ComponentArrays = "0.15"
+ComponentArrays = "0.15.18"
 ForwardDiff = "0.10"
 Lux = "1"
 LuxCUDA = "0.3"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/BayesianNN/Project.toml
@@ -10,10 +10,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 CairoMakie = "0.12"
-Functors = "0.4"
+Functors = "0.5"
 LinearAlgebra = "1"
 Lux = "1"
 Random = "1"
-Tracker = "0.2"
+Tracker = "0.2.36"
 Turing = "0.34, 0.35"
 Zygote = "0.6.69"
2 changes: 1 addition & 1 deletion examples/GravitationalWaveForm/Project.toml
@@ -12,7 +12,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 
 [compat]
 CairoMakie = "0.12"
-ComponentArrays = "0.15"
+ComponentArrays = "0.15.18"
 LineSearches = "7"
 Lux = "1"
 Optimization = "4"
6 changes: 3 additions & 3 deletions examples/HyperNet/Project.toml
@@ -14,14 +14,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1"
-ComponentArrays = "0.15"
+ADTypes = "1.10"
+ComponentArrays = "0.15.18"
 Lux = "1"
 LuxCUDA = "0.3"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Setfield = "1"
 Statistics = "1"
 Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/ImageNet/Project.toml
@@ -33,12 +33,12 @@ ImageMagick = "1"
 JLD2 = "0.5.1"
 Lux = "1"
 LuxCUDA = "0.3.3"
-MLDataDevices = "1.3"
+MLDataDevices = "1.6"
 MLUtils = "0.4.4"
 MPI = "0.20.21"
 NCCL = "0.1.1"
 OneHotArrays = "0.2.5"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 ParameterSchedulers = "0.4.2"
 Random = "1.10"
 Setfield = "1.1.1"
4 changes: 2 additions & 2 deletions examples/NeuralODE/Project.toml
@@ -15,13 +15,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ComponentArrays = "0.15"
+ComponentArrays = "0.15.18"
 Lux = "1"
 LuxCUDA = "0.3"
 MLDatasets = "0.7"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 OrdinaryDiffEqTsit5 = "1"
 SciMLSensitivity = "7.63"
 Statistics = "1"
2 changes: 1 addition & 1 deletion examples/OptimizationIntegration/Project.toml
@@ -14,7 +14,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 
 [compat]
 CairoMakie = "0.12.10"
-ComponentArrays = "0.15.17"
+ComponentArrays = "0.15.18"
 Lux = "1"
 LuxCUDA = "0.3.3"
 MLUtils = "0.4.4"
4 changes: 2 additions & 2 deletions examples/PINN2DPDE/Project.toml
@@ -12,13 +12,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1.8.1"
+ADTypes = "1.10"
 CairoMakie = "0.12.10"
 Lux = "1"
 LuxCUDA = "0.3.3"
 MLUtils = "0.4.4"
 OnlineStats = "1.7.1"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Printf = "1.10"
 Random = "1.10"
 Statistics = "1.10"
4 changes: 2 additions & 2 deletions examples/PolynomialFitting/Project.toml
@@ -10,10 +10,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1"
+ADTypes = "1.10"
 CairoMakie = "0.12"
 Lux = "1"
 LuxCUDA = "0.3"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Statistics = "1"
 Zygote = "0.6"
4 changes: 2 additions & 2 deletions examples/SimpleChains/Project.toml
@@ -11,12 +11,12 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1"
+ADTypes = "1.10"
 Lux = "1"
 MLDatasets = "0.7.14"
 MLUtils = "0.4"
 OneHotArrays = "0.2.5"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Random = "1"
 SimpleChains = "0.4.6"
 Zygote = "0.6.69"
4 changes: 2 additions & 2 deletions examples/SimpleRNN/Project.toml
@@ -11,11 +11,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-ADTypes = "1"
+ADTypes = "1.10"
 JLD2 = "0.5"
 Lux = "1"
 LuxCUDA = "0.3"
 MLUtils = "0.4"
-Optimisers = "0.3.3, 0.4"
+Optimisers = "0.3.4, 0.4"
 Statistics = "1"
 Zygote = "0.6"
4 changes: 3 additions & 1 deletion ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -3,11 +3,13 @@ module LuxEnzymeExt
 using ADTypes: AutoEnzyme
 using Enzyme: Enzyme, Active, Const, Duplicated
 using EnzymeCore: EnzymeCore
+using Functors: fmap
 using Setfield: @set!
 using Static: False, True
 
-using Lux: Lux
+using Lux: Lux, Utils
 using Lux.Training: TrainingBackendCache, TrainState
+using MLDataDevices: isleaf
 
 include("training.jl")
 
11 changes: 6 additions & 5 deletions ext/LuxEnzymeExt/training.jl
@@ -1,6 +1,6 @@
 function Lux.Training.compute_gradients_impl(
         ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
-    dps = Lux.recursive_make_zero(ts.parameters)
+    dps = Lux.Training.dparameters(ts.cache)
 
     obj_fn_wrap, st_wrap, stats_wrap = Lux.Training.wrap_objective_function(
         obj_fn, ts.model, ts.parameters, ts.states, data, True())
@@ -22,8 +22,7 @@ const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
 
 function Lux.Training.compute_gradients_impl(
         ::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
-    Enzyme.make_zero!(ts.cache.dparameters)
-    dps = ts.cache.dparameters
+    dps = Lux.Training.dparameters(ts.cache)
 
     _, loss = Enzyme.autodiff(
         EnzymeCore.ReverseWithPrimal, Const(ts.cache.extras.obj_fn), Active,
@@ -57,14 +56,16 @@ const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
 
 function Lux.Training.compute_gradients_impl(::AutoEnzyme, obj_fn::F, data,
         ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
-    dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
+    dps = Lux.Training.dparameters(ts.cache)
     params = Duplicated(ts.parameters, dps)
 
     tape, (loss, st_, stats), _ = ts.cache.extras.forward(
         Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data))
     ts.cache.extras.reverse(
         Const(obj_fn), Const(ts.model), params, Const(ts.states), Const(data),
-        (one(loss), Lux.recursive_make_zero(st_), Lux.recursive_make_zero(stats)), tape)
+        (one(loss), fmap(Utils.zero, st_; exclude=isleaf),
+            fmap(Utils.zero, stats; exclude=isleaf)), tape
+    )
 
     @set! ts.objective_function = obj_fn
     @set! ts.states = st_
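
All three Enzyme code paths now obtain their gradient shadow from a single `Lux.Training.dparameters(ts.cache)` call instead of zeroing storage inline. Its definition is not in the hunks shown here; judging only from the code it replaces, a plausible (hypothetical) sketch of the job it does:

```julia
using Functors: fmap
using MLDataDevices: isleaf

zero!!(x::AbstractArray{<:Number}) = fill!(x, false)  # zero in place, reuse storage
zero!!(x) = zero(x)                                   # fallback for non-array leaves

# Hypothetical stand-in for Lux.Training.dparameters: hand back the cached
# shadow parameters with every leaf zeroed, ready to pass to Enzyme.autodiff.
dparameters(cache) = fmap(zero!!, cache.dparameters; exclude=isleaf)
```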
3 changes: 2 additions & 1 deletion ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -3,6 +3,7 @@ module LuxReverseDiffExt
 using ADTypes: ADTypes, AbstractADType, AutoReverseDiff
 using ArrayInterface: ArrayInterface
 using FunctionWrappers: FunctionWrapper
+using Functors: fmap
 using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal,
                    @grad_from_chainrules
 using Setfield: @set!
@@ -11,7 +12,7 @@ using Static: False, True
 using Lux: Lux, Utils
 using Lux.Training: Training, TrainingBackendCache, TrainState
 using LuxCore: LuxCore
-using MLDataDevices: CPUDevice
+using MLDataDevices: CPUDevice, isleaf
 
 include("utils.jl")
 include("rules.jl")
(diff truncated: the remaining 30 of 49 changed files are not shown)