Skip to content

Commit

Permalink
Merge pull request #662 from JuliaDiff/ox/donot_zero
Browse files Browse the repository at this point in the history
Define a bunch of zero_tangents that should just NoTangent
  • Loading branch information
oxinabox authored Feb 8, 2024
2 parents 0385ea8 + a105ba9 commit 6a8c3c2
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
19 changes: 16 additions & 3 deletions src/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ function zero_tangent end

zero_tangent(x::Number) = zero(x)

zero_tangent(::Type) = NoTangent()

function zero_tangent(x::MutableTangent{P}) where {P}
zb = backing(zero_tangent(backing(x)))
return MutableTangent{P}(zb)
Expand Down Expand Up @@ -171,10 +169,25 @@ function zero_tangent(x::Array{P,N}) where {P,N}
return y
end

function zero_tangent(::T) where {K,V,T<:AbstractDict{K,V}}
return Tangent{T}(Dict{K,guess_zero_tangent_type(V)}())
end

# Sad heauristic methods we need because of unassigned values
guess_zero_tangent_type(::Type{T}) where {T<:Number} = T
guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T)))
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N}
return Array{guess_zero_tangent_type(T),N}
end
guess_zero_tangent_type(T::Type) = Any
guess_zero_tangent_type(T::Type) = Any

# Stuff that conceptually has its own identity regardless of structual implementation and doesn't have a tangent
zero_tangent(::Base.AbstractLogger) = NoTangent()

# Prevent zero_tangent going wild on the internals
zero_tangent(::Type) = NoTangent()
zero_tangent(::Expr) = NoTangent()
zero_tangent(::Core.Compiler.AbstractInterpreter) = NoTangent()
zero_tangent(::Core.Compiler.InstructionStream) = NoTangent()
zero_tangent(::Core.CodeInfo) = NoTangent()
zero_tangent(::Core.MethodInstance) = NoTangent()
2 changes: 1 addition & 1 deletion src/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ function Tangent{P}() where {P<:Tuple}
return Tangent{P,typeof(backing)}(backing)
end

function Tangent{P}(d::Dict) where {P<:Dict}
function Tangent{P}(d::Dict) where {P<:AbstractDict}
return Tangent{P,typeof(d)}(d)
end

Expand Down
14 changes: 14 additions & 0 deletions test/tangent_types/abstract_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ end

@test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0)

@test ==(
zero_tangent(Dict{Int, Float64}(1 => 2.4)),
Tangent{Dict{Int,Float64}}(Dict{Int, Float64}())
)
if isdefined(Base, :PersistentDict)
@test ==(
zero_tangent(Base.PersistentDict(1 => 2.4)),
Tangent{Base.PersistentDict{Int,Float64}}(Dict{Int, Float64}())
)
end


# Higher order
# StructuralTangents are valid tangents for themselves (just like Numbers)
# and indeed we prefer that, otherwise higher order structural tangents are kinda
Expand All @@ -200,6 +212,8 @@ end
@test iszero(zero_tangent(:abc))
@test iszero(zero_tangent("abc"))
@test iszero(zero_tangent(sin))

@test iszero(zero_tangent(:(1 + 1)))
end

@testset "undef elements Vector" begin
Expand Down

0 comments on commit 6a8c3c2

Please sign in to comment.