Skip to content

Commit

Permalink
Merge pull request #1355 from marius311/better_fix_1352
Browse files Browse the repository at this point in the history
full fix to #1352
  • Loading branch information
ToucheSir authored Aug 17, 2023
2 parents 547be70 + 8418647 commit d4562e3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 41 deletions.
5 changes: 3 additions & 2 deletions src/lib/buffer.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
grad_mut(cx::Context, b::Buffer, ::Type=Union{}) =
_get!(() -> fill!(similar(b.data, Any), nothing), cache(cx), b)
# S is the eltype we are about to set into the buffer accumulator, so allocte wide enough
grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S<:Number} =
_get!(() -> fill!(similar(b.data, float(promote_type(T, S))), 0), cache(cx), b)

@non_differentiable Buffer(::Any...)

@adjoint function getindex(b::Buffer, i...)
b[i...], function::S) where {S}
grad = grad_mut(__context__, b, S)
b[i...], function)
grad = grad_mut(__context__, b, eltype(Δ))
grad[i...] = accum(grad[i...], Δ)
return
end
Expand Down
45 changes: 6 additions & 39 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1569,46 +1569,13 @@ using Zygote: Buffer
@test ∇W2 == W2
@test ∇x == 6 .* x

@testset "incorrect promotion (#1352)" begin
u = [0.75, 0.5]
p = [-1.5, 0.05, 0.2, 0.01]

# in-place
function g1352!(du, u, p, t)
du[1, 1] = p[3] * u[1] + p[4] * u[2]
du[1, 2] = p[3] * u[1] + p[4] * u[2]
du[2, 1] = p[4] * u[1] + p[3] * u[2]
du[2, 2] = p[4] * u[1] + p[3] * u[2]
return nothing
end
du1_inplace, back_inplace = Zygote.pullback(u, p) do u, p
du = Zygote.Buffer(Matrix{Float64}(undef, 2, 2))
g1352!(du, u, p, 1.0)
return copy(du[:, 1])
end

# out-of-place
function g1352(u, p, t)
du11 = p[3] * u[1] + p[4] * u[2]
du12 = p[3] * u[1] + p[4] * u[2]
du21 = p[4] * u[1] + p[3] * u[2]
du22 = p[4] * u[1] + p[3] * u[2]
return [du11 du12
du21 du22]
end
du1, back = Zygote.pullback(u, p) do u, p
du = g1352(u, p, 1.0)
return du[:, 1]
end
# reduced mwe of #1352
@test Zygote.gradient([0,0]) do x
buf = Zygote.Buffer(similar(x))
buf[:] = x
sum(copy(buf[1:2]))
end == ([1,1],)

# comparison
@test du1_inplace du1
v = randn(2)
∇u_inplace, ∇p_inplace = back_inplace(v)
∇u, ∇p = back(v)
@test ∇u_inplace ∇u
@test ∇p_inplace ∇p
end
end

@testset "AbstractArray Addition / Subtraction / Negation" begin
Expand Down

0 comments on commit d4562e3

Please sign in to comment.