Skip to content

Commit

Permalink
Fix atomic put!/take! on an unbuffered RemoteChannel
Browse files Browse the repository at this point in the history
  • Loading branch information
amitmurthy authored and JeffBezanson committed Dec 10, 2018
1 parent 1f5d06b commit 5e72a49
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/Distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import Base: getindex, wait, put!, take!, fetch, isready, push!, length,
using Base: Process, Semaphore, JLOptions, AnyDict, buffer_writes, wait_connected,
VERSION_STRING, binding_module, notify_error, atexit, julia_exename,
julia_cmd, AsyncGenerator, acquire, release, invokelatest,
shell_escape_posixly, uv_error, something, notnothing
shell_escape_posixly, uv_error, something, notnothing, isbuffered

using Serialization, Sockets
import Serialization: serialize, deserialize
Expand Down
35 changes: 33 additions & 2 deletions src/process_messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,34 @@ mutable struct RemoteValue

waitingfor::Int # processor we need to hear from to fill this, or 0

RemoteValue(c) = new(c, BitSet(), 0)
synctake::Union{ReentrantLock, Nothing} # A lock used to synchronize the
# specific case of a local put! / remote take! on an
# unbuffered store. github issue #29932

function RemoteValue(c)
c_is_buffered = false
try
c_is_buffered = isbuffered(c)
catch
end

if c_is_buffered
return new(c, BitSet(), 0, nothing)
else
return new(c, BitSet(), 0, ReentrantLock())
end
end
end

wait(rv::RemoteValue) = wait(rv.c)

# A wrapper type to handle issue #29932 which requires locking / unlocking of
# RemoteValue.synctake outside of lexical scope.
struct SyncTake
v::Any
rv::RemoteValue
end

## core messages: do, call, fetch, wait, ref, put! ##
struct RemoteException <: Exception
pid::Int
Expand Down Expand Up @@ -267,7 +290,15 @@ end
function handle_msg(msg::CallMsg{:call_fetch}, header, r_stream, w_stream, version)
@async begin
v = run_work_thunk(()->msg.f(msg.args...; msg.kwargs...), false)
deliver_result(w_stream, :call_fetch, header.notify_oid, v)
if isa(v, SyncTake)
try
deliver_result(w_stream, :call_fetch, header.notify_oid, v.v)
finally
unlock(v.rv.synctake)
end
else
deliver_result(w_stream, :call_fetch, header.notify_oid, v)
end
end
end

Expand Down
31 changes: 27 additions & 4 deletions src/remotecall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,16 @@ end


put!(rv::RemoteValue, args...) = put!(rv.c, args...)
put_ref(rid, args...) = (put!(lookup_ref(rid), args...); nothing)
function put_ref(rid, caller, args...)
rv = lookup_ref(rid)
put!(rv, args...)
if myid() == caller && rv.synctake !== nothing
# Wait till a "taken" value is serialized out - github issue #29932
lock(rv.synctake)
unlock(rv.synctake)
end
nothing
end

"""
put!(rr::RemoteChannel, args...)
Expand All @@ -569,15 +578,29 @@ Store a set of values to the [`RemoteChannel`](@ref).
If the channel is full, blocks until space is available.
Return the first argument.
"""
put!(rr::RemoteChannel, args...) = (call_on_owner(put_ref, rr, args...); rr)
put!(rr::RemoteChannel, args...) = (call_on_owner(put_ref, rr, myid(), args...); rr)

# take! is not supported on Future

take!(rv::RemoteValue, args...) = take!(rv.c, args...)
function take_ref(rid, caller, args...)
v=take!(lookup_ref(rid), args...)
rv = lookup_ref(rid)
synctake = false
if myid() != caller && rv.synctake !== nothing
# special handling for local put! / remote take! on unbuffered channel
# github issue #29932
synctake = true
lock(rv.synctake)
end

v=take!(rv, args...)
isa(v, RemoteException) && (myid() == caller) && throw(v)
v

if synctake
return SyncTake(v, rv)
else
return v
end
end

"""
Expand Down
20 changes: 20 additions & 0 deletions test/distributed_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,26 @@ f=Future(id_other)
remote_do(fut->put!(fut, myid()), id_other, f)
@test fetch(f) == id_other

# Github issue #29932
rc_unbuffered = RemoteChannel(()->Channel{Vector{Float64}}(0))

@async begin
# Trigger direct write (no buffering) of largish array
array_sz = Int(Base.SZ_UNBUFFERED_IO/8) + 1
largev = zeros(array_sz)
for i in 1:10
largev[1] = float(i)
put!(rc_unbuffered, largev)
end
end

@test remotecall_fetch(rc -> begin
for i in 1:10
take!(rc)[1] != float(i) && error("Failed")
end
return :OK
end, id_other, rc_unbuffered) == :OK

# github PR #14456
n = DoFullTest ? 6 : 5
for i = 1:10^n
Expand Down

0 comments on commit 5e72a49

Please sign in to comment.