Skip to content

Commit

Permalink
Make the _Message API require using Ref{_Message}'s
Browse files Browse the repository at this point in the history
Using plain _Message's with most of these functions is likely incorrect.
  • Loading branch information
JamesWrigley committed May 21, 2024
1 parent 6124ac4 commit 56d3a9b
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/_message.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# zmq_msg_t in the header: char _[64];
const _Message = lib.zmq_msg_t

const _MessageOrRef = Union{_Message,Base.RefValue{_Message}}
const _MessageRef = Base.RefValue{_Message}

function msg_init()
zmsg = Ref{_Message}()
Expand All @@ -22,15 +22,14 @@ function msg_init(nbytes::Int)
end

# note: no finalizer for _Message, so we need to call close manually!
function Base.close(zmsg::_MessageOrRef)
function Base.close(zmsg::_MessageRef)
rc = lib.zmq_msg_close(zmsg)
rc != 0 && throw(StateError(jl_zmq_error_str()))
return nothing
end

Base.length(zmsg::_MessageOrRef) = lib.zmq_msg_size(zmsg) % Int
Base.unsafe_convert(::Type{Ptr{UInt8}}, zmsg::_Message) = Ptr{UInt8}(lib.zmq_msg_data(Ref(zmsg)))
Base.unsafe_convert(::Type{Ptr{UInt8}}, zmsg::Base.RefValue{_Message}) = Ptr{UInt8}(lib.zmq_msg_data(zmsg))
Base.length(zmsg::Base.RefValue{_Message}) = lib.zmq_msg_size(zmsg) % Int
Base.unsafe_convert(::Type{Ptr{UInt8}}, zmsg::_MessageRef) = Ptr{UInt8}(lib.zmq_msg_data(zmsg))

# isbits data, vectors thereof, and strings can be converted to/from _Message

Expand All @@ -57,7 +56,7 @@ function _MessageRef(x::String)
return zmsg
end

function unsafe_copy(::Type{Vector{T}}, zmsg::_MessageOrRef) where {T}
function unsafe_copy(::Type{Vector{T}}, zmsg::_MessageRef) where {T}
isbitstype(T) || throw(MethodError(unsafe_copy, (T, zmsg,)))
n = length(zmsg)
len, remainder = divrem(n, sizeof(T))
Expand All @@ -67,16 +66,16 @@ function unsafe_copy(::Type{Vector{T}}, zmsg::_MessageOrRef) where {T}
return a
end

function unsafe_copy(::Type{T}, zmsg::_MessageOrRef) where {T}
function unsafe_copy(::Type{T}, zmsg::_MessageRef) where {T}
isbitstype(T) || throw(MethodError(unsafe_copy, (T, zmsg,)))
n = length(zmsg)
n == sizeof(T) || error("message length $n ≠ sizeof($T)")
return @preserve zmsg unsafe_load(Ptr{T}(Base.unsafe_convert(Ptr{UInt8}, zmsg)))
end

function unsafe_copy(::Type{String}, zmsg::_MessageOrRef)
function unsafe_copy(::Type{String}, zmsg::_MessageRef)
n = length(zmsg)
return @preserve zmsg unsafe_string(Base.unsafe_convert(Ptr{UInt8}, zmsg), n)
end

unsafe_copy(::Type{IOBuffer}, zmsg::_MessageOrRef) = IOBuffer(unsafe_copy(Vector{UInt8}, zmsg))
unsafe_copy(::Type{IOBuffer}, zmsg::_MessageRef) = IOBuffer(unsafe_copy(Vector{UInt8}, zmsg))

0 comments on commit 56d3a9b

Please sign in to comment.