diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl index 380ae1d39..d9813e833 100644 --- a/examples/ConditionalVAE/main.jl +++ b/examples/ConditionalVAE/main.jl @@ -244,7 +244,8 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f start_time = time() for (i, X) in enumerate(train_dataloader) (_, loss, _, train_state) = Training.single_train_step!( - AutoEnzyme(), loss_function, X, train_state) + AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false) + ) loss_total += loss total_samples += size(X, ndims(X)) diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl index 93a1b1279..292132759 100644 --- a/ext/LuxReactantExt/LuxReactantExt.jl +++ b/ext/LuxReactantExt/LuxReactantExt.jl @@ -4,7 +4,7 @@ using Enzyme: Enzyme, Const, Duplicated, Active using Optimisers: Optimisers using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber using Setfield: @set! -using Static: False +using Static: True, False using Lux: Lux, LuxOps, Training, Utils using Lux.Training: TrainingBackendCache, ReactantBackend diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl index 2462bd252..d6c0c1c8d 100644 --- a/ext/LuxReactantExt/training.jl +++ b/ext/LuxReactantExt/training.jl @@ -55,7 +55,7 @@ function Lux.Training.compute_gradients_impl( end function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data, - ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F} + ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F} grads, loss, stats, st = ts.cache.extras.compiled_gradient_function( obj_fn, ts.model, data, ts.parameters, ts.states) @set! ts.states = st @@ -70,7 +70,7 @@ for inplace in ("!", "") # Ideally users never hit this dispatch but it is still good to have as a fallback @eval function Lux.Training.$(apply_gradients_fn)( - ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads + ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads ) if hasfield(typeof(ts.cache.extras), :update_function) update_function = ts.cache.extras.update_function @@ -94,15 +94,15 @@ for inplace in ("!", "") @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState) where {F} maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data, - ts.parameters, ts.states, ts.optimizer_state) + ts.parameters, ts.states, ts.optimizer_state, backend.return_gradients) compiled_grad_and_step_function = @compile $(internal_fn)( objective_function, ts.model, data, ts.parameters, ts.states, - ts.optimizer_state) + ts.optimizer_state, backend.return_gradients) grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function( objective_function, ts.model, data, ts.parameters, ts.states, - ts.optimizer_state) + ts.optimizer_state, backend.return_gradients) cache = TrainingBackendCache( backend, False(), nothing, (; compiled_grad_and_step_function)) @@ -116,10 +116,11 @@ for inplace in ("!", "") return grads, loss, stats, ts end - @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data, - ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F} + @eval function Lux.Training.$(fname)(backend::ReactantBackend, obj_fn::F, data, + ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F} grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function( - obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state) + obj_fn, ts.model, data, ts.parameters, ts.states, + ts.optimizer_state, backend.return_gradients) @set! ts.states = st @set! ts.parameters = ps @@ -131,7 +132,15 @@ for inplace in ("!", "") # XXX: Inplace version not actually inplace @eval function $(internal_fn)( - objective_function::F, model, data, ps, st, opt_state) where {F} + objective_function::F, model, data, ps, st, opt_state, ::False) where {F} + dps, loss, stats, stₙ = compute_gradients_internal( + objective_function, model, data, ps, st) + opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps) + return nothing, ps, loss, stats, stₙ, opt_state + end + + @eval function $(internal_fn)( + objective_function::F, model, data, ps, st, opt_state, ::True) where {F} dps, loss, stats, stₙ = compute_gradients_internal( objective_function, model, data, ps, st) opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps) diff --git a/src/helpers/training.jl b/src/helpers/training.jl index c11f74b93..45b93b49d 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -7,7 +7,7 @@ using FastClosures: @closure using Functors: Functors, fmap using Optimisers: Optimisers using Setfield: @set! -using Static: StaticBool, Static, False, True +using Static: StaticBool, Static, False, True, static using ..Lux: Lux, Utils, ReactantCompatibleOptimisers using LuxCore: LuxCore, AbstractLuxLayer @@ -104,7 +104,9 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState) print(io, "\n objective_function: ", nameof(typeof(ts.objective_function))) end -struct ReactantBackend end +@concrete struct ReactantBackend + return_gradients <: StaticBool +end const APPLY_GRAD_DOCSTRING = """ ## Arguments @@ -198,10 +200,13 @@ function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F} return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts) end -maybe_wrap_adtype(backend::ReactantBackend, _) = backend -maybe_wrap_adtype(ad::AbstractADType, _) = ad -function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice}) - ad isa AutoEnzyme && return ReactantBackend() +maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend +maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad +function maybe_wrap_adtype( + ad::AbstractADType, ::Type{ReactantDevice}; + return_gradients::Utils.BoolType=True() +) + ad isa AutoEnzyme && return ReactantBackend(static(return_gradients)) throw(ArgumentError("Computing gradients for models on XLA is supported only with \ Enzyme.jl (`AutoEnzyme`).")) end @@ -258,12 +263,17 @@ function wrap_objective_function( end """ - single_train_step!(backend, obj_fn::F, data, ts::TrainState) + single_train_step!(backend, obj_fn::F, data, ts::TrainState; return_gradients=True()) Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and updates the parameters using [`apply_gradients!`](@ref). All backends supported via [`compute_gradients`](@ref) are supported here. +## Keyword Arguments + + - `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned + gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend. + ## Return Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`, @@ -271,13 +281,15 @@ only the parameters in `ts` are updated inplace. Users should be using the retur object for further training steps, else there is no caching and performance will be suboptimal (and absolutely terrible for backends like `AutoReactant`). """ -function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F} - backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states))) +function single_train_step!(backend, obj_fn::F, data, ts::TrainState; + return_gradients::Utils.BoolType=True()) where {F} + backend = maybe_wrap_adtype( + backend, get_device_type((ts.parameters, ts.states)); return_gradients) return single_train_step_impl!(backend, obj_fn, data, ts) end """ - single_train_step(backend, obj_fn::F, data, ts::TrainState) + single_train_step(backend, obj_fn::F, data, ts::TrainState; return_gradients=True()) Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and updates the parameters using [`apply_gradients`](@ref). All backends supported via @@ -285,12 +297,19 @@ updates the parameters using [`apply_gradients`](@ref). All backends supported v In most cases you should use [`single_train_step!`](@ref) instead of this function. +## Keyword Arguments + + - `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned + gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend. + ## Return -Returned values are the same as [`compute_gradients`](@ref). +Returned values are the same as [`single_train_step!`](@ref). """ -function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F} - backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states))) +function single_train_step(backend, obj_fn::F, data, ts::TrainState; + return_gradients::Utils.BoolType=True()) where {F} + backend = maybe_wrap_adtype( + backend, get_device_type((ts.parameters, ts.states)); return_gradients) return single_train_step_impl(backend, obj_fn, data, ts) end