From 7d9f8f7763631d3c6bbcae2cbbc3902903bd221e Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Wed, 28 Nov 2018 16:15:06 -0500 Subject: [PATCH] Channel: upgrade to make threadsafe internally This drops the internal `notify_error` API, use `close` instead. --- base/channels.jl | 220 ++++++++++++++------------ stdlib/Distributed/src/Distributed.jl | 2 +- stdlib/Distributed/src/cluster.jl | 12 +- test/channels.jl | 43 ++--- 4 files changed, 150 insertions(+), 127 deletions(-) diff --git a/base/channels.jl b/base/channels.jl index e8cf5a977dfbe..27f1719341375 100644 --- a/base/channels.jl +++ b/base/channels.jl @@ -23,39 +23,30 @@ Other constructors: * `Channel(sz)`: equivalent to `Channel{Any}(sz)` """ mutable struct Channel{T} <: AbstractChannel{T} - cond_take::Condition # waiting for data to become available - cond_put::Condition # waiting for a writeable slot + cond_take::Threads.Condition # waiting for data to become available + cond_wait::Threads.Condition # waiting for data to become maybe available + cond_put::Threads.Condition # waiting for a writeable slot state::Symbol - excp::Union{Exception, Nothing} # exception to be thrown when state != :open + excp::Union{Exception, Nothing} # exception to be thrown when state != :open data::Vector{T} sz_max::Int # maximum size of channel - # Used when sz_max == 0, i.e., an unbuffered channel. - waiters::Int - takers::Vector{Task} - putters::Vector{Task} - - function Channel{T}(sz::Float64) where T - if sz == Inf - Channel{T}(typemax(Int)) - else - Channel{T}(convert(Int, sz)) - end - end function Channel{T}(sz::Integer) where T if sz < 0 throw(ArgumentError("Channel size must be either 0, a positive integer or Inf")) end - ch = new(Condition(), Condition(), :open, nothing, Vector{T}(), sz, 0) - if sz == 0 - ch.takers = Vector{Task}() - ch.putters = Vector{Task}() - end - return ch + lock = ReentrantLock() + cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock) + cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered + return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz) end end +function Channel{T}(sz::Float64) where T + sz = (sz == Inf ? typemax(Int) : convert(Int, sz)) + return Channel{T}(sz) +end Channel(sz) = Channel{Any}(sz) # special constructors @@ -122,22 +113,30 @@ isbuffered(c::Channel) = c.sz_max==0 ? false : true function check_channel_state(c::Channel) if !isopen(c) - c.excp !== nothing && throw(c.excp) + excp = c.excp + excp !== nothing && throw(excp) throw(closed_exception()) end end """ - close(c::Channel) + close(c::Channel[, excp::Exception]) -Close a channel. An exception is thrown by: +Close a channel. An exception (optionally given by `excp`), is thrown by: * [`put!`](@ref) on a closed channel. * [`take!`](@ref) and [`fetch`](@ref) on an empty, closed channel. """ -function close(c::Channel) - c.state = :closed - c.excp = closed_exception() - notify_error(c) +function close(c::Channel, excp::Exception=closed_exception()) + lock(c) + try + c.state = :closed + c.excp = excp + notify_error(c.cond_take, excp) + notify_error(c.cond_wait, excp) + notify_error(c.cond_put, excp) + finally + unlock(c) + end nothing end isopen(c::Channel) = (c.state == :open) @@ -195,7 +194,7 @@ Stacktrace: function bind(c::Channel, task::Task) ref = WeakRef(c) register_taskdone_hook(task, tsk->close_chnl_on_taskdone(tsk, ref)) - c + return c end """ @@ -225,17 +224,34 @@ function channeled_tasks(n::Int, funcs...; ctypes=fill(Any,n), csizes=fill(0,n)) end function close_chnl_on_taskdone(t::Task, ref::WeakRef) - if ref.value !== nothing - c = ref.value - !isopen(c) && return - if istaskfailed(t) - c.state = :closed - c.excp = task_result(t) - notify_error(c) + c = ref.value + if c isa Channel + isopen(c) || return + cleanup = () -> try + isopen(c) || return + if istaskfailed(t) + excp = task_result(t) + if excp isa Exception + close(c, excp) + return + end + end + close(c) + return + finally + unlock(c) + end + if trylock(c) + # can't use `lock`, since attempts to task-switch to wait for it + # will just silently fail and leave us with broken state + cleanup() else - close(c) + # so schedule this to happen once we are finished destroying our task + # (on a new Task) + @async (lock(c); cleanup()) end end + nothing end struct InvalidStateException <: Exception @@ -257,33 +273,39 @@ task. function put!(c::Channel{T}, v) where T check_channel_state(c) v = convert(T, v) - isbuffered(c) ? put_buffered(c,v) : put_unbuffered(c,v) + return isbuffered(c) ? put_buffered(c, v) : put_unbuffered(c, v) end function put_buffered(c::Channel, v) - while length(c.data) == c.sz_max - wait(c.cond_put) + lock(c) + try + while length(c.data) == c.sz_max + check_channel_state(c) + wait(c.cond_put) + end + push!(c.data, v) + # notify all, since some of the waiters may be on a "fetch" call. + notify(c.cond_take, nothing, true, false) + finally + unlock(c) end - push!(c.data, v) - - # notify all, since some of the waiters may be on a "fetch" call. - notify(c.cond_take, nothing, true, false) - v + return v end function put_unbuffered(c::Channel, v) - if length(c.takers) == 0 - push!(c.putters, current_task()) - c.waiters > 0 && notify(c.cond_take, nothing, false, false) - - try - wait() - catch - filter!(x->x!=current_task(), c.putters) - rethrow() + lock(c) + taker = try + while isempty(c.cond_take.waitq) + check_channel_state(c) + notify(c.cond_wait) + wait(c.cond_put) end + # unfair scheduled version of: notify(c.cond_take, v, false, false); yield() + popfirst!(c.cond_take.waitq) + finally + unlock(c) end - taker = popfirst!(c.takers) + # unfair version of: schedule(taker, v); yield() yield(taker, v) # immediately give taker a chance to run, but don't block the current task return v end @@ -298,8 +320,16 @@ remove the item. `fetch` is unsupported on an unbuffered (0-size) channel. """ fetch(c::Channel) = isbuffered(c) ? fetch_buffered(c) : fetch_unbuffered(c) function fetch_buffered(c::Channel) - wait(c) - c.data[1] + lock(c) + try + while isempty(c.data) + check_channel_state(c) + wait(c.cond_take) + end + return c.data[1] + finally + unlock(c) + end end fetch_unbuffered(c::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel.")) @@ -314,32 +344,31 @@ task. """ take!(c::Channel) = isbuffered(c) ? take_buffered(c) : take_unbuffered(c) function take_buffered(c::Channel) - wait(c) - v = popfirst!(c.data) - notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!. - v + lock(c) + try + while isempty(c.data) + check_channel_state(c) + wait(c.cond_take) + end + v = popfirst!(c.data) + notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!. + return v + finally + unlock(c) + end end popfirst!(c::Channel) = take!(c) # 0-size channel function take_unbuffered(c::Channel{T}) where T - check_channel_state(c) - push!(c.takers, current_task()) + lock(c) try - if length(c.putters) > 0 - let refputter = Ref(popfirst!(c.putters)) - return Base.try_yieldto(refputter) do putter - # if we fail to start putter, put it back in the queue - putter === current_task || pushfirst!(c.putters, putter) - end::T - end - else - return wait()::T - end - catch - filter!(x->x!=current_task(), c.takers) - rethrow() + check_channel_state(c) + notify(c.cond_put, nothing, false, false) + return wait(c.cond_take)::T + finally + unlock(c) end end @@ -353,39 +382,26 @@ For unbuffered channels returns `true` if there are tasks waiting on a [`put!`](@ref). """ isready(c::Channel) = n_avail(c) > 0 -n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.putters) +n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.cond_put.waitq) -wait(c::Channel) = isbuffered(c) ? wait_impl(c) : wait_unbuffered(c) -function wait_impl(c::Channel) - while !isready(c) - check_channel_state(c) - wait(c.cond_take) - end - nothing -end +lock(c::Channel) = lock(c.cond_take) +unlock(c::Channel) = unlock(c.cond_take) +trylock(c::Channel) = trylock(c.cond_take) -function wait_unbuffered(c::Channel) - c.waiters += 1 +function wait(c::Channel) + isready(c) && return + lock(c) try - wait_impl(c) + while !isready(c) + check_channel_state(c) + wait(c.cond_wait) + end finally - c.waiters -= 1 + unlock(c) end nothing end -function notify_error(c::Channel, err) - notify_error(c.cond_take, err) - notify_error(c.cond_put, err) - - # release tasks on a `wait()/yieldto()` call (on unbuffered channels) - if !isbuffered(c) - waiters = filter!(t->(t.state == :runnable), vcat(c.takers, c.putters)) - foreach(t->schedule(t, err; error=true), waiters) - end -end -notify_error(c::Channel) = notify_error(c, c.excp) - eltype(::Type{Channel{T}}) where {T} = T show(io::IO, c::Channel) = print(io, "$(typeof(c))(sz_max:$(c.sz_max),sz_curr:$(n_avail(c)))") @@ -394,7 +410,7 @@ function iterate(c::Channel, state=nothing) try return (take!(c), nothing) catch e - if isa(e, InvalidStateException) && e.state==:closed + if isa(e, InvalidStateException) && e.state == :closed return nothing else rethrow() diff --git a/stdlib/Distributed/src/Distributed.jl b/stdlib/Distributed/src/Distributed.jl index e0afb433a6f3f..1ab793f44d8ee 100644 --- a/stdlib/Distributed/src/Distributed.jl +++ b/stdlib/Distributed/src/Distributed.jl @@ -11,7 +11,7 @@ import Base: getindex, wait, put!, take!, fetch, isready, push!, length, # imports for use using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected, - VERSION_STRING, binding_module, notify_error, atexit, julia_exename, + VERSION_STRING, binding_module, atexit, julia_exename, julia_cmd, AsyncGenerator, acquire, release, invokelatest, shell_escape_posixly, uv_error, something, notnothing, isbuffered diff --git a/stdlib/Distributed/src/cluster.jl b/stdlib/Distributed/src/cluster.jl index 747919e56b07e..7d1f47c2dd763 100644 --- a/stdlib/Distributed/src/cluster.jl +++ b/stdlib/Distributed/src/cluster.jl @@ -1046,12 +1046,12 @@ function deregister_worker(pg, pid) ids = [] tonotify = [] lock(client_refs) do - for (id,rv) in pg.refs - if in(pid,rv.clientset) + for (id, rv) in pg.refs + if in(pid, rv.clientset) push!(ids, id) end if rv.waitingfor == pid - push!(tonotify, (id,rv)) + push!(tonotify, (id, rv)) end end for id in ids @@ -1059,11 +1059,12 @@ function deregister_worker(pg, pid) end # throw exception to tasks waiting for this pid - for (id,rv) in tonotify - notify_error(rv.c, ProcessExitedException()) + for (id, rv) in tonotify + close(rv.c, ProcessExitedException()) delete!(pg.refs, id) end end + return end @@ -1073,6 +1074,7 @@ function interrupt(pid::Integer) if isa(w, Worker) manage(w.manager, w.id, w.config, :interrupt) end + return end """ diff --git a/test/channels.jl b/test/channels.jl index 91e420e0519d3..9e49badb8ae66 100644 --- a/test/channels.jl +++ b/test/channels.jl @@ -60,6 +60,7 @@ end wait(c) @test isa(take!(c), Int64) @test_throws MethodError put!(c, "") + @assert !islocked(c.cond_take) end @testset "multiple for loops waiting on the same channel" begin @@ -86,14 +87,14 @@ end using Distributed @testset "channels bound to tasks" for N in [0, 10] # Normal exit of task - c=Channel(N) - bind(c, @async (yield();nothing)) + c = Channel(N) + bind(c, @async (yield(); nothing)) @test_throws InvalidStateException take!(c) @test !isopen(c) # Error exception in task - c=Channel(N) - bind(c, @async (yield();error("foo"))) + c = Channel(N) + bind(c, @async (yield(); error("foo"))) @test_throws ErrorException take!(c) @test !isopen(c) @@ -101,22 +102,24 @@ using Distributed cs = [Channel(N) for i in 1:5] tf2 = () -> begin if N > 0 - foreach(c->(@assert take!(c)==2), cs) + foreach(c -> (@assert take!(c) === 2), cs) end yield() error("foo") end task = Task(tf2) - foreach(c->bind(c, task), cs) + foreach(c -> bind(c, task), cs) schedule(task) if N > 0 for i in 1:5 - @test put!(cs[i], 2) == 2 + @test put!(cs[i], 2) === 2 end end for i in 1:5 - while (isopen(cs[i])); yield(); end + while isopen(cs[i]) + yield() + end @test_throws ErrorException wait(cs[i]) @test_throws ErrorException take!(cs[i]) @test_throws ErrorException put!(cs[i], 1) @@ -137,39 +140,41 @@ using Distributed tasks = [Task(()->tf3(i)) for i in 1:5] c = Channel(N) - foreach(t->bind(c,t), tasks) + foreach(t -> bind(c, t), tasks) foreach(schedule, tasks) @test_throws InvalidStateException wait(c) @test !isopen(c) @test ref[] == nth + @assert !islocked(c.cond_take) # channeled_tasks for T in [Any, Int] - chnls, tasks = Base.channeled_tasks(2, (c1,c2)->(@assert take!(c1)==1; put!(c2,2)); ctypes=[T,T], csizes=[N,N]) + tf_chnls1 = (c1, c2) -> (@assert take!(c1) == 1; put!(c2, 2)) + chnls, tasks = Base.channeled_tasks(2, tf_chnls1; ctypes=[T,T], csizes=[N,N]) put!(chnls[1], 1) - @test take!(chnls[2]) == 2 + @test take!(chnls[2]) === 2 @test_throws InvalidStateException wait(chnls[1]) @test_throws InvalidStateException wait(chnls[2]) @test istaskdone(tasks[1]) @test !isopen(chnls[1]) @test !isopen(chnls[2]) - f=Future() - tf4 = (c1,c2) -> begin - @assert take!(c1)==1 + f = Future() + tf4 = (c1, c2) -> begin + @assert take!(c1) === 1 wait(f) end - tf5 = (c1,c2) -> begin - put!(c2,2) + tf5 = (c1, c2) -> begin + put!(c2, 2) wait(f) end chnls, tasks = Base.channeled_tasks(2, tf4, tf5; ctypes=[T,T], csizes=[N,N]) put!(chnls[1], 1) - @test take!(chnls[2]) == 2 + @test take!(chnls[2]) === 2 yield() - put!(f, 1) + put!(f, 1) # allow tf4 and tf5 to exit after now, eventually closing the channel @test_throws InvalidStateException wait(chnls[1]) @test_throws InvalidStateException wait(chnls[2]) @@ -181,7 +186,7 @@ using Distributed # channel tf6 = c -> begin - @assert take!(c)==2 + @assert take!(c) === 2 error("foo") end