diff --git a/src/ZMQ.jl b/src/ZMQ.jl index 99ef452..1c260b0 100644 --- a/src/ZMQ.jl +++ b/src/ZMQ.jl @@ -29,9 +29,9 @@ export include("constants.jl") include("optutil.jl") include("error.jl") +include("context.jl") include("socket.jl") include("sockopts.jl") -include("context.jl") include("message.jl") include("comm.jl") diff --git a/src/context.jl b/src/context.jl index 88bec5b..9ba4aae 100644 --- a/src/context.jl +++ b/src/context.jl @@ -29,6 +29,15 @@ mutable struct Context end end +function Context(f::Function, args...) + ctx = Context(args...) + try + f(ctx) + finally + close(ctx) + end +end + Base.unsafe_convert(::Type{Ptr{Cvoid}}, c::Context) = getfield(c, :data) # define a global context that is initialized lazily diff --git a/src/socket.jl b/src/socket.jl index 2c914e5..2757d75 100644 --- a/src/socket.jl +++ b/src/socket.jl @@ -3,8 +3,7 @@ mutable struct Socket data::Ptr{Cvoid} pollfd::_FDWatcher - # ctx should be ::Context, but forward type references are not allowed - function Socket(ctx, typ::Integer) + function Socket(ctx::Context, typ::Integer) p = ccall((:zmq_socket, libzmq), Ptr{Cvoid}, (Ptr{Cvoid}, Cint), ctx, typ) if p == C_NULL throw(StateError(jl_zmq_error_str())) @@ -18,6 +17,15 @@ mutable struct Socket Socket(typ::Integer) = Socket(context(), typ) end +function Socket(f::Function, args...) + socket = Socket(args...) + try + f(socket) + finally + close(socket) + end +end + Base.unsafe_convert(::Type{Ptr{Cvoid}}, s::Socket) = getfield(s, :data) Base.isopen(socket::Socket) = getfield(socket, :data) != C_NULL @@ -56,4 +64,4 @@ function Sockets.connect(socket::Socket, endpoint::AbstractString) if rc != 0 throw(StateError(jl_zmq_error_str())) end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 14e0479..c6b91d5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -112,3 +112,33 @@ end # ZMQ.close(s1); ZMQ.close(s2) # should happen when context is closed ZMQ.close(ZMQ._context) # immediately close global context rather than waiting for exit end + +@testset "ZMQ resource management" begin + local leaked_req_socket, leaked_rep_socket + ZMQ.Socket(ZMQ.REQ) do req_socket + leaked_req_socket = req_socket + + ZMQ.Socket(ZMQ.REP) do rep_socket + leaked_rep_socket = rep_socket + + ZMQ.bind(rep_socket, "inproc://tester") + ZMQ.connect(req_socket, "inproc://tester") + + ZMQ.send(req_socket, "Mr. Watson, come here, I want to see you.") + @test unsafe_string(ZMQ.recv(rep_socket)) == "Mr. Watson, come here, I want to see you." + ZMQ.send(rep_socket, "Coming, Mr. Bell.") + @test unsafe_string(ZMQ.recv(req_socket)) == "Coming, Mr. Bell." + end + + @test !ZMQ.isopen(leaked_rep_socket) + end + @test !ZMQ.isopen(leaked_req_socket) + + local leaked_ctx + ZMQ.Context() do ctx + leaked_ctx = ctx + + @test isopen(ctx) + end + @test !isopen(leaked_ctx) +end