diff --git a/src/Writer.jl b/src/Writer.jl index 6ebea04..b3ca285 100644 --- a/src/Writer.jl +++ b/src/Writer.jl @@ -83,27 +83,43 @@ abstract type JSONContext <: StructuralContext end """ Internal implementation detail. +To handle recursive references in objects/arrays when writing, by default we want +to track references to objects seen and break recursion cycles to avoid stack overflows. +Subtypes of `RecursiveCheckContext` must include two fields in order to allow recursive +cycle checking to work properly when writing: + * `objectids::Set{UInt64}`: set of object ids in the current stack of objects being written + * `recursive_cycle_token::Any`: Any string, `nothing`, or object to be written when a cycle is detected +""" +abstract type RecursiveCheckContext <: JSONContext end + +""" +Internal implementation detail. + Keeps track of the current location in the array or object, which winds and unwinds during serialization. """ -mutable struct PrettyContext{T<:IO} <: JSONContext +mutable struct PrettyContext{T<:IO} <: RecursiveCheckContext io::T step::Int # number of spaces to step state::Int # number of steps at present first::Bool # whether an object/array was just started + objectids::Set{UInt64} + recursive_cycle_token end -PrettyContext(io::IO, step) = PrettyContext(io, step, 0, false) +PrettyContext(io::IO, step, recursive_cycle_token=nothing) = PrettyContext(io, step, 0, false, Set{UInt64}(), recursive_cycle_token) """ Internal implementation detail. For compact printing, which in JSON is fully recursive. """ -mutable struct CompactContext{T<:IO} <: JSONContext +mutable struct CompactContext{T<:IO} <: RecursiveCheckContext io::T first::Bool + objectids::Set{UInt64} + recursive_cycle_token end -CompactContext(io::IO) = CompactContext(io, false) +CompactContext(io::IO, recursive_cycle_token=nothing) = CompactContext(io, false, Set{UInt64}(), recursive_cycle_token) """ Internal implementation detail. @@ -265,12 +281,26 @@ end show_json(io::SC, ::CS, ::Nothing) = show_null(io) show_json(io::SC, ::CS, ::Missing) = show_null(io) +recursive_cycle_check(f, io, s, id) = f() + +function recursive_cycle_check(f, io::RecursiveCheckContext, s, id) + if id in io.objectids + show_json(io, s, io.recursive_cycle_token) + else + push!(io.objectids, id) + f() + delete!(io.objectids, id) + end +end + function show_json(io::SC, s::CS, x::Union{AbstractDict, NamedTuple}) - begin_object(io) - for kv in pairs(x) - show_pair(io, s, kv) + recursive_cycle_check(io, s, objectid(x)) do + begin_object(io) + for kv in pairs(x) + show_pair(io, s, kv) + end + end_object(io) end - end_object(io) end function show_json(io::SC, s::CS, kv::Pair) @@ -280,19 +310,23 @@ function show_json(io::SC, s::CS, kv::Pair) end function show_json(io::SC, s::CS, x::CompositeTypeWrapper) - begin_object(io) - for fn in x.fns - show_pair(io, s, fn, getproperty(x.wrapped, fn)) + recursive_cycle_check(io, s, objectid(x.wrapped)) do + begin_object(io) + for fn in x.fns + show_pair(io, s, fn, getproperty(x.wrapped, fn)) + end + end_object(io) end - end_object(io) end function show_json(io::SC, s::CS, x::Union{AbstractVector, Tuple}) - begin_array(io) - for elt in x - show_element(io, s, elt) + recursive_cycle_check(io, s, objectid(x)) do + begin_array(io) + for elt in x + show_element(io, s, elt) + end + end_array(io) end - end_array(io) end """ diff --git a/test/runtests.jl b/test/runtests.jl index d585154..37d398d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,4 +85,36 @@ end end end +mutable struct R1 + id::Int + obj +end + +struct MyCustomWriteContext <: JSON.Writer.RecursiveCheckContext + io + objectids::Set{UInt64} + recursive_cycle_token +end +MyCustomWriteContext(io) = MyCustomWriteContext(io, Set{UInt64}(), nothing) +Base.print(io::MyCustomWriteContext, x::UInt8) = Base.print(io.io, x) +for delegate in [:indent, + :delimit, + :separate, + :begin_array, + :end_array, + :begin_object, + :end_object] +@eval JSON.Writer.$delegate(io::MyCustomWriteContext) = JSON.Writer.$delegate(io.io) +end + +@testset "RecursiveCheckContext" begin + x = R1(1, nothing) + x.obj = x + str = JSON.json(x) + @test str == "{\"id\":1,\"obj\":null}" + io = IOBuffer() + str = JSON.show_json(MyCustomWriteContext(JSON.Writer.CompactContext(io)), JSON.Serializations.StandardSerialization(), x) + @test String(take!(io)) == "{\"id\":1,\"obj\":null}" +end + # Check that printing to the default stdout doesn't fail