Skip to content

Commit

Permalink
Introduce a recursive cycle check when writing (#345)
Browse files Browse the repository at this point in the history
I noticed that writing can blow up when there are recusrive objects that reference each other.
(For example, in HTTP.jl, `Response` and `Request` reference each other)
This PR proposes a simple API for the `CompactContext` and `PrettyContext` where
`objectid` of objects will be tracked recursively when writing and it's configurable
what should be written out when a recursive cycle is detected.

Custom contexts can "hook in" to this behavior by subtyping `RecursiveCheckContext` and
including the required fields (see docs for new context). Otherwise, there shouldn't
be any functional change to APIs in any way.
  • Loading branch information
quinnj authored Jan 10, 2023
1 parent 4906286 commit 2219019
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
66 changes: 50 additions & 16 deletions src/Writer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,43 @@ abstract type JSONContext <: StructuralContext end
"""
Internal implementation detail.
To handle recursive references in objects/arrays when writing, by default we want
to track references to objects seen and break recursion cycles to avoid stack overflows.
Subtypes of `RecursiveCheckContext` must include two fields in order to allow recursive
cycle checking to work properly when writing:
* `objectids::Set{UInt64}`: set of object ids in the current stack of objects being written
* `recursive_cycle_token::Any`: Any string, `nothing`, or object to be written when a cycle is detected
"""
abstract type RecursiveCheckContext <: JSONContext end

"""
Internal implementation detail.
Keeps track of the current location in the array or object, which winds and
unwinds during serialization.
"""
mutable struct PrettyContext{T<:IO} <: JSONContext
mutable struct PrettyContext{T<:IO} <: RecursiveCheckContext
io::T
step::Int # number of spaces to step
state::Int # number of steps at present
first::Bool # whether an object/array was just started
objectids::Set{UInt64}
recursive_cycle_token
end
PrettyContext(io::IO, step) = PrettyContext(io, step, 0, false)
PrettyContext(io::IO, step, recursive_cycle_token=nothing) = PrettyContext(io, step, 0, false, Set{UInt64}(), recursive_cycle_token)

"""
Internal implementation detail.
For compact printing, which in JSON is fully recursive.
"""
mutable struct CompactContext{T<:IO} <: JSONContext
mutable struct CompactContext{T<:IO} <: RecursiveCheckContext
io::T
first::Bool
objectids::Set{UInt64}
recursive_cycle_token
end
CompactContext(io::IO) = CompactContext(io, false)
CompactContext(io::IO, recursive_cycle_token=nothing) = CompactContext(io, false, Set{UInt64}(), recursive_cycle_token)

"""
Internal implementation detail.
Expand Down Expand Up @@ -265,12 +281,26 @@ end
show_json(io::SC, ::CS, ::Nothing) = show_null(io)
show_json(io::SC, ::CS, ::Missing) = show_null(io)

recursive_cycle_check(f, io, s, id) = f()

function recursive_cycle_check(f, io::RecursiveCheckContext, s, id)
if id in io.objectids
show_json(io, s, io.recursive_cycle_token)
else
push!(io.objectids, id)
f()
delete!(io.objectids, id)
end
end

function show_json(io::SC, s::CS, x::Union{AbstractDict, NamedTuple})
begin_object(io)
for kv in pairs(x)
show_pair(io, s, kv)
recursive_cycle_check(io, s, objectid(x)) do
begin_object(io)
for kv in pairs(x)
show_pair(io, s, kv)
end
end_object(io)
end
end_object(io)
end

function show_json(io::SC, s::CS, kv::Pair)
Expand All @@ -280,19 +310,23 @@ function show_json(io::SC, s::CS, kv::Pair)
end

function show_json(io::SC, s::CS, x::CompositeTypeWrapper)
begin_object(io)
for fn in x.fns
show_pair(io, s, fn, getproperty(x.wrapped, fn))
recursive_cycle_check(io, s, objectid(x.wrapped)) do
begin_object(io)
for fn in x.fns
show_pair(io, s, fn, getproperty(x.wrapped, fn))
end
end_object(io)
end
end_object(io)
end

function show_json(io::SC, s::CS, x::Union{AbstractVector, Tuple})
begin_array(io)
for elt in x
show_element(io, s, elt)
recursive_cycle_check(io, s, objectid(x)) do
begin_array(io)
for elt in x
show_element(io, s, elt)
end
end_array(io)
end
end_array(io)
end

"""
Expand Down
32 changes: 32 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,36 @@ end
end
end

mutable struct R1
id::Int
obj
end

struct MyCustomWriteContext <: JSON.Writer.RecursiveCheckContext
io
objectids::Set{UInt64}
recursive_cycle_token
end
MyCustomWriteContext(io) = MyCustomWriteContext(io, Set{UInt64}(), nothing)
Base.print(io::MyCustomWriteContext, x::UInt8) = Base.print(io.io, x)
for delegate in [:indent,
:delimit,
:separate,
:begin_array,
:end_array,
:begin_object,
:end_object]
@eval JSON.Writer.$delegate(io::MyCustomWriteContext) = JSON.Writer.$delegate(io.io)
end

@testset "RecursiveCheckContext" begin
x = R1(1, nothing)
x.obj = x
str = JSON.json(x)
@test str == "{\"id\":1,\"obj\":null}"
io = IOBuffer()
str = JSON.show_json(MyCustomWriteContext(JSON.Writer.CompactContext(io)), JSON.Serializations.StandardSerialization(), x)
@test String(take!(io)) == "{\"id\":1,\"obj\":null}"
end

# Check that printing to the default stdout doesn't fail

0 comments on commit 2219019

Please sign in to comment.