From f2e3ac50cc45468256c12aaee637915ca0b594ee Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 14 Jul 2022 13:34:04 -0400 Subject: [PATCH] Shorten printing of `Tangent` and thunks (#564) * don't print internal details of thunks * don't print super-long primal types for Tangent * add tests * version * fix Int32, comments * fix tests on 1.5 --- Project.toml | 2 +- src/tangent_types/tangent.jl | 13 ++++++++++++- src/tangent_types/thunks.jl | 16 ++++++++++++++-- test/tangent_types/tangent.jl | 8 ++++++++ test/tangent_types/thunks.jl | 12 ++++++++++++ 5 files changed, 47 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 782129559..bc1ec7b45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.1" +version = "1.15.2" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/tangent.jl index 38c99eefa..f187cb3f2 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/tangent.jl @@ -75,7 +75,18 @@ Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) function Base.show(io::IO, tangent::Tangent{P}) where {P} print(io, "Tangent{") - show(io, P) + str = sprint(show, P, context = io) + i = findfirst('{', str) + if isnothing(i) + print(io, str) + else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line: + print(io, str[1:prevind(str, i)]) + if length(str) < 80 + printstyled(io, str[i:end], color=:light_black) + else + printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black) + end + end print(io, "}") if isempty(backing(tangent)) print(io, "()") # so it doesn't show `NamedTuple()` diff --git a/src/tangent_types/thunks.jl b/src/tangent_types/thunks.jl index 6e5a3540d..3735307b0 100644 --- a/src/tangent_types/thunks.jl +++ b/src/tangent_types/thunks.jl @@ -196,7 +196,13 @@ end function Base.show(io::IO, x::Thunk) print(io, "Thunk(") - show(io, x.f) + str = sprint(show, x.f, context = io) # often this name is like "ChainRules.var"#1398#1403"{Matrix{Float64}, Matrix{Float64}}" + ind = findfirst("var\"#", str) + if isnothing(ind) || length(str) < 80 + printstyled(io, str, color=:light_black) + else + printstyled(io, str[1:ind[5]], "...", color=:light_black) + end print(io, ")") end @@ -223,7 +229,13 @@ unthunk(x::InplaceableThunk) = unthunk(x.val) function Base.show(io::IO, x::InplaceableThunk) print(io, "InplaceableThunk(") - show(io, x.add!) + str = sprint(show, x.add!, context = io) + ind = findfirst("var\"#", str) # look for auto-generated function names, often with huge types + if isnothing(ind) + printstyled(io, str, color=:light_black) + else + printstyled(io, str[1:ind[5]], "...", color=:light_black) + end print(io, ", ") show(io, x.val) print(io, ")") diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl index 1e6b27878..176dd4985 100644 --- a/test/tangent_types/tangent.jl +++ b/test/tangent_types/tangent.jl @@ -380,4 +380,12 @@ end c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) @test nt + c == (; a=1, b=2.1) end + + @testset "printing" begin + t5 = Tuple(rand(3)) + nt3 = (x=t5, y=t5, z=nothing) + tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent + @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened + @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole + end end diff --git a/test/tangent_types/thunks.jl b/test/tangent_types/thunks.jl index 0d29dd706..4b3ab4eb4 100644 --- a/test/tangent_types/thunks.jl +++ b/test/tangent_types/thunks.jl @@ -190,4 +190,16 @@ @test scal!(2, 2.0, v, 1) == scal!(2, @thunk(2.0), v, 1) @test_throws MutateThunkException LAPACK.trsyl!('C', 'C', m, m, @thunk(m)) end + + @testset "printing" begin + @test !contains(sprint(show, @thunk 1+1), "...") # short thunks not abbreviated + th = let x = rand(100) + @thunk x .+ x' + end + @test contains(sprint(show, th), "...") # but long ones are + + @test contains(sprint(show, InplaceableThunk(mul!, th)), "mul!") # named functions left in InplaceableThunk + str = sprint(show, InplaceableThunk(z -> z .+ ones(100), th)) + @test length(findall("...", str)) == 2 # now both halves shortened + end end