Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Channel: add ability to handle threads #30186

Merged
merged 1 commit into from
Jan 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 118 additions & 102 deletions base/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,30 @@ Other constructors:
* `Channel(sz)`: equivalent to `Channel{Any}(sz)`
"""
mutable struct Channel{T} <: AbstractChannel{T}
cond_take::Condition # waiting for data to become available
cond_put::Condition # waiting for a writeable slot
cond_take::Threads.Condition # waiting for data to become available
cond_wait::Threads.Condition # waiting for data to become maybe available
cond_put::Threads.Condition # waiting for a writeable slot
state::Symbol
excp::Union{Exception, Nothing} # exception to be thrown when state != :open
excp::Union{Exception, Nothing} # exception to be thrown when state != :open

data::Vector{T}
sz_max::Int # maximum size of channel

# Used when sz_max == 0, i.e., an unbuffered channel.
waiters::Int
takers::Vector{Task}
putters::Vector{Task}

function Channel{T}(sz::Float64) where T
if sz == Inf
Channel{T}(typemax(Int))
else
Channel{T}(convert(Int, sz))
end
end
function Channel{T}(sz::Integer) where T
if sz < 0
throw(ArgumentError("Channel size must be either 0, a positive integer or Inf"))
end
ch = new(Condition(), Condition(), :open, nothing, Vector{T}(), sz, 0)
if sz == 0
ch.takers = Vector{Task}()
ch.putters = Vector{Task}()
end
return ch
lock = ReentrantLock()
cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock)
cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered
return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz)
end
end

function Channel{T}(sz::Float64) where T
sz = (sz == Inf ? typemax(Int) : convert(Int, sz))
return Channel{T}(sz)
end
Channel(sz) = Channel{Any}(sz)

# special constructors
Expand Down Expand Up @@ -122,22 +113,30 @@ isbuffered(c::Channel) = c.sz_max==0 ? false : true

function check_channel_state(c::Channel)
if !isopen(c)
c.excp !== nothing && throw(c.excp)
excp = c.excp
excp !== nothing && throw(excp)
throw(closed_exception())
end
end
"""
close(c::Channel)
close(c::Channel[, excp::Exception])

Close a channel. An exception is thrown by:
Close a channel. An exception (optionally given by `excp`), is thrown by:

* [`put!`](@ref) on a closed channel.
* [`take!`](@ref) and [`fetch`](@ref) on an empty, closed channel.
"""
function close(c::Channel)
c.state = :closed
c.excp = closed_exception()
notify_error(c)
function close(c::Channel, excp::Exception=closed_exception())
lock(c)
try
c.state = :closed
c.excp = excp
notify_error(c.cond_take, excp)
notify_error(c.cond_wait, excp)
notify_error(c.cond_put, excp)
finally
unlock(c)
end
nothing
end
isopen(c::Channel) = (c.state == :open)
Expand Down Expand Up @@ -195,7 +194,7 @@ Stacktrace:
function bind(c::Channel, task::Task)
ref = WeakRef(c)
register_taskdone_hook(task, tsk->close_chnl_on_taskdone(tsk, ref))
c
return c
end

"""
Expand Down Expand Up @@ -225,17 +224,34 @@ function channeled_tasks(n::Int, funcs...; ctypes=fill(Any,n), csizes=fill(0,n))
end

function close_chnl_on_taskdone(t::Task, ref::WeakRef)
if ref.value !== nothing
c = ref.value
!isopen(c) && return
if istaskfailed(t)
c.state = :closed
c.excp = task_result(t)
notify_error(c)
c = ref.value
if c isa Channel
isopen(c) || return
cleanup = () -> try
isopen(c) || return
if istaskfailed(t)
excp = task_result(t)
if excp isa Exception
close(c, excp)
return
end
end
close(c)
return
finally
unlock(c)
end
if trylock(c)
# can't use `lock`, since attempts to task-switch to wait for it
# will just silently fail and leave us with broken state
cleanup()
else
close(c)
# so schedule this to happen once we are finished destroying our task
# (on a new Task)
@async (lock(c); cleanup())
end
end
nothing
end

struct InvalidStateException <: Exception
Expand All @@ -257,33 +273,39 @@ task.
function put!(c::Channel{T}, v) where T
check_channel_state(c)
v = convert(T, v)
isbuffered(c) ? put_buffered(c,v) : put_unbuffered(c,v)
return isbuffered(c) ? put_buffered(c, v) : put_unbuffered(c, v)
end

function put_buffered(c::Channel, v)
while length(c.data) == c.sz_max
wait(c.cond_put)
lock(c)
try
while length(c.data) == c.sz_max
check_channel_state(c)
wait(c.cond_put)
end
push!(c.data, v)
# notify all, since some of the waiters may be on a "fetch" call.
notify(c.cond_take, nothing, true, false)
finally
unlock(c)
end
push!(c.data, v)

# notify all, since some of the waiters may be on a "fetch" call.
notify(c.cond_take, nothing, true, false)
v
return v
end

function put_unbuffered(c::Channel, v)
if length(c.takers) == 0
push!(c.putters, current_task())
c.waiters > 0 && notify(c.cond_take, nothing, false, false)

try
wait()
catch
filter!(x->x!=current_task(), c.putters)
rethrow()
lock(c)
taker = try
while isempty(c.cond_take.waitq)
check_channel_state(c)
notify(c.cond_wait)
wait(c.cond_put)
end
# unfair scheduled version of: notify(c.cond_take, v, false, false); yield()
popfirst!(c.cond_take.waitq)
finally
unlock(c)
end
taker = popfirst!(c.takers)
# unfair version of: schedule(taker, v); yield()
yield(taker, v) # immediately give taker a chance to run, but don't block the current task
return v
end
Expand All @@ -298,8 +320,16 @@ remove the item. `fetch` is unsupported on an unbuffered (0-size) channel.
"""
fetch(c::Channel) = isbuffered(c) ? fetch_buffered(c) : fetch_unbuffered(c)
function fetch_buffered(c::Channel)
wait(c)
c.data[1]
lock(c)
try
while isempty(c.data)
check_channel_state(c)
wait(c.cond_take)
end
return c.data[1]
finally
unlock(c)
end
end
fetch_unbuffered(c::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel."))

Expand All @@ -314,32 +344,31 @@ task.
"""
take!(c::Channel) = isbuffered(c) ? take_buffered(c) : take_unbuffered(c)
function take_buffered(c::Channel)
wait(c)
v = popfirst!(c.data)
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!.
v
lock(c)
try
while isempty(c.data)
check_channel_state(c)
wait(c.cond_take)
end
v = popfirst!(c.data)
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!.
return v
finally
unlock(c)
end
end

popfirst!(c::Channel) = take!(c)

# 0-size channel
function take_unbuffered(c::Channel{T}) where T
check_channel_state(c)
push!(c.takers, current_task())
lock(c)
try
if length(c.putters) > 0
let refputter = Ref(popfirst!(c.putters))
return Base.try_yieldto(refputter) do putter
# if we fail to start putter, put it back in the queue
putter === current_task || pushfirst!(c.putters, putter)
end::T
end
else
return wait()::T
end
catch
filter!(x->x!=current_task(), c.takers)
rethrow()
check_channel_state(c)
notify(c.cond_put, nothing, false, false)
return wait(c.cond_take)::T
finally
unlock(c)
end
end

Expand All @@ -353,39 +382,26 @@ For unbuffered channels returns `true` if there are tasks waiting
on a [`put!`](@ref).
"""
isready(c::Channel) = n_avail(c) > 0
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.putters)
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.cond_put.waitq)

wait(c::Channel) = isbuffered(c) ? wait_impl(c) : wait_unbuffered(c)
function wait_impl(c::Channel)
while !isready(c)
check_channel_state(c)
wait(c.cond_take)
end
nothing
end
lock(c::Channel) = lock(c.cond_take)
unlock(c::Channel) = unlock(c.cond_take)
trylock(c::Channel) = trylock(c.cond_take)

function wait_unbuffered(c::Channel)
c.waiters += 1
function wait(c::Channel)
isready(c) && return
lock(c)
try
wait_impl(c)
while !isready(c)
check_channel_state(c)
wait(c.cond_wait)
end
finally
c.waiters -= 1
unlock(c)
end
nothing
end

function notify_error(c::Channel, err)
notify_error(c.cond_take, err)
notify_error(c.cond_put, err)

# release tasks on a `wait()/yieldto()` call (on unbuffered channels)
if !isbuffered(c)
waiters = filter!(t->(t.state == :runnable), vcat(c.takers, c.putters))
foreach(t->schedule(t, err; error=true), waiters)
end
end
notify_error(c::Channel) = notify_error(c, c.excp)

eltype(::Type{Channel{T}}) where {T} = T

show(io::IO, c::Channel) = print(io, "$(typeof(c))(sz_max:$(c.sz_max),sz_curr:$(n_avail(c)))")
Expand All @@ -394,7 +410,7 @@ function iterate(c::Channel, state=nothing)
try
return (take!(c), nothing)
catch e
if isa(e, InvalidStateException) && e.state==:closed
if isa(e, InvalidStateException) && e.state == :closed
return nothing
else
rethrow()
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Distributed/src/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Base: getindex, wait, put!, take!, fetch, isready, push!, length,

# imports for use
using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected,
VERSION_STRING, binding_module, notify_error, atexit, julia_exename,
VERSION_STRING, binding_module, atexit, julia_exename,
julia_cmd, AsyncGenerator, acquire, release, invokelatest,
shell_escape_posixly, uv_error, something, notnothing, isbuffered

Expand Down
12 changes: 7 additions & 5 deletions stdlib/Distributed/src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1083,24 +1083,25 @@ function deregister_worker(pg, pid)
ids = []
tonotify = []
lock(client_refs) do
for (id,rv) in pg.refs
if in(pid,rv.clientset)
for (id, rv) in pg.refs
if in(pid, rv.clientset)
push!(ids, id)
end
if rv.waitingfor == pid
push!(tonotify, (id,rv))
push!(tonotify, (id, rv))
end
end
for id in ids
del_client(pg, id, pid)
end

# throw exception to tasks waiting for this pid
for (id,rv) in tonotify
notify_error(rv.c, ProcessExitedException())
for (id, rv) in tonotify
close(rv.c, ProcessExitedException())
delete!(pg.refs, id)
end
end
return
end


Expand All @@ -1110,6 +1111,7 @@ function interrupt(pid::Integer)
if isa(w, Worker)
manage(w.manager, w.id, w.config, :interrupt)
end
return
end

"""
Expand Down
Loading