-
Notifications
You must be signed in to change notification settings - Fork 63
/
extended_ops.jl
244 lines (204 loc) · 7.98 KB
/
extended_ops.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""
LuxOps
This module is a part of `Lux.jl`. It contains operations that are useful in a DL context.
Additionally certain operations here alias Base functions to behave more sensibly with
GPUArrays.
"""
module LuxOps
using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk, @non_differentiable
using Compat: @compat
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using Static: StaticBool, StaticSymbol, known
using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice
using ..Utils: Utils
const CRC = ChainRulesCore
const KnownSymbolType{v} = Union{Val{v}, StaticSymbol{v}}
# `xlogx` and `xlogy`
## We don't use `LogExpFunctions` since they don't support GPU broadcasting. See
## https://github.com/LuxDL/Lux.jl/pull/796. Additionally we have special broadcast rrules.
"""
xlogx(x::Number)
Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get
zero.
"""
function xlogx(x::Number)
result = x * log(x)
return ifelse(iszero(x), zero(result), result)
end
∇xlogx(Δ::Number, logx::Number) = Δ * (logx + true)
# Scalar reverse rule for `xlogx`. At x == 0 the primal is `x` itself and the
# (sub)gradient is dropped; otherwise `log(x)` from the forward pass is reused.
function CRC.rrule(::typeof(xlogx), x::Number)
    iszero(x) && return x, Returns((NoTangent(), ZeroTangent()))
    lx = log(x)
    pullback_xlogx = @closure Δ -> (NoTangent(), @thunk(∇xlogx(Δ, lx)))
    return x * lx, pullback_xlogx
end
# Broadcast (fused) reverse rule for `xlogx.(x)`; the elementwise `log.(x)` is shared
# between the forward value and the pullback.
function CRC.rrule(
        ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number})
    lx = log.(x)
    pullback_xlogx_bc = @closure Δ -> (NoTangent(), NoTangent(), @thunk(∇xlogx.(Δ, lx)))
    return x .* lx, pullback_xlogx_bc
end
"""
xlogy(x::Number, y::Number)
Return `x * log(y)` for `y > 0`, and zero when `x == 0`.
"""
function xlogy(x::Number, y::Number)
result = x * log(y)
return ifelse(iszero(x), zero(result), result)
end
# Partial derivatives of `xlogy(x, y) = x * log(y)`:
#   ∂/∂x = log(y)   and   ∂/∂y = x / y.
function ∇₁xlogy(dy::Number, logy::Number)
    return dy * logy
end
function ∇₂xlogy(dy::Number, x::Number, y::Number)
    return dy * x / y
end
# Scalar reverse rule for `xlogy`. When `x == 0` the primal is `x` and both input
# cotangents are structural zeros; otherwise `log(y)` is shared with the pullback.
function CRC.rrule(::typeof(xlogy), x::Number, y::Number)
    # BUG FIX: the pullback of a two-argument primal must return one cotangent per
    # argument plus one for the function itself — (∂self, ∂x, ∂y), three entries —
    # matching the non-zero branch below. The previous code returned only two.
    iszero(x) && return x, Returns((NoTangent(), ZeroTangent(), ZeroTangent()))
    logy = log(y)
    ∇xlogy_internal = @closure Δ -> (
        NoTangent(), @thunk(∇₁xlogy(Δ, logy)), @thunk(∇₂xlogy(Δ, x, y)))
    return x * logy, ∇xlogy_internal
end
# Broadcast (fused) reverse rule for `xlogy.(x, y)`; `log.(y)` is computed once and
# reused by the pullback for the `x`-cotangent.
function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy),
        x::AbstractArray{<:Number}, y::AbstractArray{<:Number})
    ly = log.(y)
    pullback_xlogy_bc = @closure Δ -> (
        NoTangent(), NoTangent(), @thunk(∇₁xlogy.(Δ, ly)), @thunk(∇₂xlogy.(Δ, x, y)))
    return x .* ly, pullback_xlogy_bc
end
"""
getproperty(x, ::Val{v})
getproperty(x, ::StaticSymbol{v})
Similar to `Base.getproperty` but requires a `Val` (or `Static.StaticSymbol`). Additionally,
if `v` is not present in `x`, then `nothing` is returned.
"""
function getproperty(x, ::KnownSymbolType{v}) where {v}
return v ∈ Base.propertynames(x) ? Base.getproperty(x, v) : nothing
end
# Fast path for `NamedTuple`s: `names` is available in the type domain, so the existence
# check happens at compile time and the generated body is either a direct field access or
# the literal `nothing` — no runtime `propertynames` lookup.
@generated function getproperty(x::NamedTuple{names}, ::KnownSymbolType{v}) where {names, v}
    return v ∈ names ? :(x.$v) : :(nothing)
end
"""
eachslice(x, dims::Val)
Same as `Base.eachslice` but doesn't produce a `SubArray` for the slices if `x` is a
GPUArray.
Additional dispatches for RNN helpers are also provided for `TimeLastIndex` and
`BatchLastIndex`.
"""
function eachslice(x::AbstractArray, dims::Val)
return eachslice(get_device_type(x), x, dims)
end
function eachslice(::Type{<:AbstractGPUDevice}, x::AbstractArray, ::Val{dims}) where {dims}
return [Utils.contiguous(selectdim(x, dims, i)) for i in axes(x, dims)]
end
function eachslice(::Type{<:AbstractDevice}, x::AbstractArray, ::Val{dims}) where {dims}
return [selectdim(x, dims, i) for i in axes(x, dims)]
end
# Reverse pass for `eachslice`: reassemble the per-slice cotangents `Δ′` (one entry per
# slice along `dims`) into a dense cotangent of the same shape as `x`.
function ∇eachslice(Δ′, x::AbstractArray, ::Val{dims}) where {dims}
    Δs = CRC.unthunk(Δ′)
    # If no slice received an array cotangent, the whole gradient is zero.
    idx = findfirst(Base.Fix2(isa, AbstractArray), Δs)
    idx === nothing && return zero.(x)
    # Allocate the dense gradient, zero-initialize, then copy each slice cotangent in.
    # NOTE(review): `copyto!` assumes every `Δs[i]` is copyable into a slice — confirm
    # that mixed array/`AbstractZero` cotangents cannot reach this loop.
    Δ = similar(x)
    fill!(Δ, false)
    for i in axes(x, dims)
        Δᵢ = selectdim(Δ, dims, i)
        copyto!(Δᵢ, Δs[i])
    end
    # Project back to the type/structure of `x` (e.g. restoring array flavor).
    return CRC.ProjectTo(x)(Δ)
end
# Reverse rule for the `Val`-based `eachslice`; the `dims` argument is non-differentiable.
function CRC.rrule(::typeof(eachslice), x::AbstractArray, d::Val{dims}) where {dims}
    pullback_eachslice = @closure Δ -> (NoTangent(), ∇eachslice(Δ, x, d), NoTangent())
    return eachslice(x, d), pullback_eachslice
end
"""
foldl_init(op, x)
foldl_init(op, x, init)
Exactly same as `foldl(op, x; init)` in the forward pass. But, gives gradients wrt `init`
in the backward pass.
"""
foldl_init(op, x) = foldl_init(op, x, nothing)
foldl_init(op, x, init) = foldl(op, x; init)
# Reverse rule for `foldl_init` over a `Tuple`: lower the tuple to a `Vector` (splatting
# preserves the promoted element type), defer to the AD backend's rule for the array
# case, and convert the positional cotangent back into a tuple.
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
        ::typeof(foldl_init), op::G, x::Tuple, init) where {G}
    x_vec = [x...]
    y, vec_pullback = CRC.rrule_via_ad(cfg, foldl_init, op, x_vec, init)
    ∇foldl_init_tuple = @closure Δ -> begin
        ∂self, ∂op, ∂x_vec, ∂init = vec_pullback(Δ)
        return ∂self, ∂op, Tuple(∂x_vec), ∂init
    end
    return y, ∇foldl_init_tuple
end
# Reverse rule for `foldl_init` over an `AbstractArray`. The forward pass stores a
# (value, pullback) pair per element; the backward pass replays those pullbacks in
# reverse order to thread the cotangent back through the fold.
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
        ::typeof(foldl_init), op::G, x::AbstractArray, init) where {G}
    list, start = x, init
    # Forward step: `a` is the previous (value, pullback) pair; run `op` through the AD
    # backend to obtain the next value together with its pullback.
    accum_func = @closure (a, b) -> CRC.rrule_via_ad(cfg, op, first(a), b)
    # Backward step: `x1` carries the cotangent `dc` flowing in from the later element's
    # pullback; `x2` is the stored (value, pullback) pair of the current element.
    # `back(dc)` yields (∂op, ∂carry, ∂element).
    accum_func_inner = @closure (x1, x2) -> begin
        (_d1, dc, _d3) = x1
        (_val, back) = x2
        return back(dc)
    end
    hobbits = Vector{Any}(undef, length(list)) # Unfortunately Zygote needs this for CUDA
    accumulate!(accum_func, hobbits, list; init=(start, nothing))
    # Primal result: the value part of the final accumulation step.
    y = first(last(hobbits))
    ax = axes(x)
    project = CRC.ProjectTo.(x)
    ∇foldl_init = Δ -> begin
        # Walk the stored pullbacks last-to-first; `trio` collects the
        # (∂op, ∂carry, ∂element) triples in reverse element order.
        trio = accumulate(accum_func_inner, reverse(hobbits); init=(0, Δ, 0))
        ∂op = sum(first, trio)
        # Flip back to forward order and restore `x`'s original shape.
        ∂x = reshape(map(last, reverse(trio)), ax)
        # `last(trio)[2]` is the carry cotangent at the first element, i.e. ∂init.
        return (NoTangent(), ∂op,
            [proj(∂xᵢ) for (proj, ∂xᵢ) in zip(project, ∂x)], last(trio)[2])
    end
    return y, ∇foldl_init
end
"""
multigate(x::AbstractArray, ::Val{N})
Split up `x` into `N` equally sized chunks (along dimension `1`).
"""
function multigate(x::AbstractArray, ::Val{N}) where {N}
return ntuple(i -> Utils.gate(x, size(x, 1) ÷ N, i), N)
end
# Reverse pass for `multigate`: scatter the per-chunk cotangents `Δ` back into a dense
# cotangent with the shape of `x`, chunk by chunk.
function ∇multigate(Δ, x::AbstractArray, ::Val{N}) where {N}
    ∂x = similar(x, eltype(x), axes(x))
    for (∂chunk, Δchunk) in zip(multigate(∂x, Val(N)), Δ)
        if Δchunk isa CRC.AbstractZero
            # Structural zero: this chunk received no gradient.
            fill!(∂chunk, false)
        else
            ∂chunk .= Δchunk
        end
    end
    return CRC.ProjectTo(x)(∂x)
end
# Reverse rule for `multigate`; the chunk count `Val{N}` is non-differentiable.
function CRC.rrule(::typeof(multigate), x::AbstractArray, c::Val{N}) where {N}
    pullback_multigate = @closure Δ -> (
        NoTangent(), @thunk(∇multigate(CRC.unthunk(Δ), x, c)), NoTangent())
    return multigate(x, c), pullback_multigate
end
"""
istraining(::Val{training})
istraining(::StaticBool)
istraining(::Bool)
istraining(st::NamedTuple)
Returns `true` if `training` is `true` or if `st` contains a `training` field with value
`true`. Else returns `false`.
"""
istraining(::Val{training}) where {training} = training
istraining(training::StaticBool) = known(training)
istraining(training::Bool) = training
istraining(st::NamedTuple) = hasproperty(st, :training) && istraining(st.training)
CRC.@non_differentiable istraining(::Any)
# Public API
@compat(public, (xlogx, xlogy, getproperty, eachslice, foldl_init, multigate, istraining))
end
using .LuxOps: LuxOps, multigate
const safe_getproperty = LuxOps.getproperty
const safe_eachslice = LuxOps.eachslice
# TODO: directly import them from LuxOps from 1.0
const private_xlogx = LuxOps.xlogx
const private_xlogy = LuxOps.xlogy
const private_foldl_init = LuxOps.foldl_init
# These are defined here to avoid a circular dependency among modules
# Generate `has_bias`, `has_affine`, `has_track_stats`, and `has_train_state` predicates.
# Each reads the corresponding layer field through `safe_getproperty`; a missing field
# (`nothing` from `known`) is reported as `false`.
for (fname, prop) in (:bias => :use_bias, :affine => :affine,
    :track_stats => :track_stats, :train_state => :train_state)
    @eval function $(Symbol(:has_, fname))(l::AbstractLuxLayer)
        flag = known(safe_getproperty(l, Val($(Meta.quot(prop)))))
        return ifelse(flag === nothing, false, flag)
    end
end