Skip to content

Commit

Permalink
Channel: drop explicit API change, make always threadsafe internally
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Dec 11, 2018
1 parent a08bb2b commit 385afe8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 27 deletions.
39 changes: 21 additions & 18 deletions base/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
abstract type AbstractChannel{T} end

"""
Channel{T}(sz::Int, threadsafe::Bool)
Channel{T}(sz::Int)
Constructs a `Channel` with an internal buffer that can hold a maximum of `sz` objects
of type `T`.
Expand All @@ -12,40 +12,37 @@ of type `T`.
`Channel(0)` constructs an unbuffered channel. `put!` blocks until a matching `take!` is called.
And vice-versa.
If `threadsafe` is true, some API operations (specifically `wait`) require first acquiring
the lock on the Channel object.
Other constructors:
* `Channel(Inf)`: equivalent to `Channel{Any}(typemax(Int))`
* `Channel(sz)`: equivalent to `Channel{Any}(sz)`
"""
mutable struct Channel{T} <: AbstractChannel{T}
cond_take::Condition # waiting for data to become available
cond_wait::Condition # waiting for data to become maybe available
cond_put::Condition # waiting for a writeable slot
cond_take::Threads.Condition # waiting for data to become available
cond_wait::Threads.Condition # waiting for data to become maybe available
cond_put::Threads.Condition # waiting for a writeable slot
state::Symbol
excp::Union{Exception, Nothing} # exception to be thrown when state != :open

data::Vector{T}
sz_max::Int # maximum size of channel

function Channel{T}(sz::Integer, threadsafe::Bool=false) where T
function Channel{T}(sz::Integer) where T
if sz < 0
throw(ArgumentError("Channel size must be either 0, a positive integer or Inf"))
end
lock = threadsafe ? ReentrantLock() : AlwaysLockedST()
cond_put, cond_take = Condition(lock), Condition(lock)
cond_wait = (sz == 0 ? Condition(lock) : cond_take) # wait is distinct from take iff unbuffered
lock = ReentrantLock()
cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock)
cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered
return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz)
end
end

function Channel{T}(sz::Float64, threadsafe::Bool=false) where T
function Channel{T}(sz::Float64) where T
sz = (sz == Inf ? typemax(Int) : convert(Int, sz))
return Channel{T}(sz)
end
Channel(sz, threadsafe::Bool=false) = Channel{Any}(sz, threadsafe)
Channel(sz) = Channel{Any}(sz)

# special constructors
"""
Expand Down Expand Up @@ -94,8 +91,8 @@ julia> istaskdone(taskref[])
true
```
"""
function Channel(func::Function, threadsafe::Bool=false; ctype=Any, csize=0, taskref=nothing)
chnl = Channel{ctype}(csize, threadsafe)
function Channel(func::Function; ctype=Any, csize=0, taskref=nothing)
chnl = Channel{ctype}(csize)
task = Task(() -> func(chnl))
bind(chnl, task)
yield(task) # immediately start it
Expand Down Expand Up @@ -380,9 +377,15 @@ unlock(c::Channel) = unlock(c.cond_take)
trylock(c::Channel) = trylock(c.cond_take)

function wait(c::Channel)
while !isready(c)
check_channel_state(c)
wait(c.cond_wait)
isready(c) && return
lock(c)
try
while !isready(c)
check_channel_state(c)
wait(c.cond_wait)
end
finally
unlock(c)
end
nothing
end
Expand Down
9 changes: 0 additions & 9 deletions test/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@ end

@testset "type conversion in put!" begin
c = Channel{Int64}(0)
lock(c)
@async put!(c, Int32(1))
wait(c)
@test isa(take!(c), Int64)
@test_throws MethodError put!(c, "")
unlock(c)
@assert !islocked(c.cond_take)
end

Expand Down Expand Up @@ -122,7 +120,6 @@ using Distributed
while isopen(cs[i])
yield()
end
i < 3 && foreach(lock, cs)
@test_throws ErrorException wait(cs[i])
@test_throws ErrorException take!(cs[i])
@test_throws ErrorException put!(cs[i], 1)
Expand All @@ -145,11 +142,9 @@ using Distributed
c = Channel(N)
foreach(t -> bind(c, t), tasks)
foreach(schedule, tasks)
lock(c)
@test_throws InvalidStateException wait(c)
@test !isopen(c)
@test ref[] == nth
unlock(c)
@assert !islocked(c.cond_take)

# channeled_tasks
Expand All @@ -158,10 +153,8 @@ using Distributed
chnls, tasks = Base.channeled_tasks(2, tf_chnls1; ctypes=[T,T], csizes=[N,N])
put!(chnls[1], 1)
@test take!(chnls[2]) === 2
foreach(lock, chnls)
@test_throws InvalidStateException wait(chnls[1])
@test_throws InvalidStateException wait(chnls[2])
foreach(unlock, chnls)
@test istaskdone(tasks[1])
@test !isopen(chnls[1])
@test !isopen(chnls[2])
Expand All @@ -183,10 +176,8 @@ using Distributed
yield()
put!(f, 1) # allow tf4 and tf5 to exit after now, eventually closing the channel

foreach(lock, chnls)
@test_throws InvalidStateException wait(chnls[1])
@test_throws InvalidStateException wait(chnls[2])
foreach(unlock, chnls)
@test istaskdone(tasks[1])
@test istaskdone(tasks[2])
@test !isopen(chnls[1])
Expand Down

0 comments on commit 385afe8

Please sign in to comment.