diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 14342f49d..7631c9eae 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -1,6 +1,6 @@ module WebSockets -using Base64, LoggingExtras, UUIDs, Sockets, Random +using Base64, LoggingExtras, UUIDs, Sockets, Random, CodecZlib using MbedTLS: digest, MD_SHA1, SSLContext using ..IOExtras, ..Streams, ..ConnectionPool, ..Messages, ..Conditions, ..Servers import ..open @@ -55,7 +55,7 @@ FrameFlags(final::Bool, opcode::OpCode, masked::Bool, len::Integer; rsv1::Bool=f ) Base.show(io::IO, x::FrameFlags) = - print(io, "FrameFlags(", "final=", x.final, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") + print(io, "FrameFlags(", "final=", x.final, ", isdeflate=", x.rsv1, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") primitive type Mask 32 end Base.UInt32(x::Mask) = Base.bitcast(UInt32, x) @@ -90,6 +90,20 @@ function mask!(bytes::Vector{UInt8}, mask) end return end +function final_deflate_codecs(t::Tuple) + CodecZlib.TranscodingStreams.finalize(t[1]) + CodecZlib.TranscodingStreams.finalize(t[2]) +end + +function init_deflate_codecs() + codecco = DeflateCompressor() + CodecZlib.TranscodingStreams.initialize(codecco) + codecde = DeflateDecompressor() + CodecZlib.TranscodingStreams.initialize(codecde) + + return (codecco, codecde) +end + # send method Frame constructor function Frame(final::Bool, opcode::OpCode, client::Bool, payload::AbstractVector{UInt8}; rsv1::Bool=false, rsv2::Bool=false, rsv3::Bool=false) @@ -293,12 +307,13 @@ mutable struct WebSocket writebuffer::Vector{UInt8} readclosed::Bool writeclosed::Bool + deflate::Union{Nothing, Tuple{CodecZlib.CompressorCodec, CodecZlib.DecompressorCodec}} end const DEFAULT_MAX_FRAG = 1024 -WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = - WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false) +WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) = + WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate ? init_deflate_codecs() : nothing) """ WebSockets.isclosed(ws) -> Bool @@ -306,6 +321,7 @@ WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, max Check whether a `WebSocket` has sent and received CLOSE frames. """ isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed +isdeflate(ws::WebSocket) = !isnothing(ws.deflate) # Handshake "Check whether a HTTP.Request or HTTP.Response is a websocket upgrade request/response" @@ -347,7 +363,7 @@ WebSockets.open(url) do ws end ``` """ -function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) headers = [ "Upgrade" => "websocket", @@ -363,13 +379,14 @@ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, if header(http, "Sec-WebSocket-Accept") != hashedkey(key) throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)")) end + isdeflate = occursin("permessage-deflate", header(http, "Sec-Websocket-Extensions")) # later stream logic checks to see if the HTTP message is "complete" # by seeing if ntoread is 0, which is typemax(Int) for websockets by default # so set it to 0 so it's correctly viewed as "complete" once we're done # doing websocket things http.ntoread = 0 io = http.stream - ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation) + ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket opened" try f(ws) @@ -416,7 +433,8 @@ function listen end listen(f, args...; kw...) = Servers.listen(http -> upgrade(f, http; kw...), args...; kw...) listen!(f, args...; kw...) = Servers.listen!(http -> upgrade(f, http; kw...), args...; kw...) -function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) @debugv 2 "Server websocket upgrade requested" isupgrade(http.message) || handshakeerror() if !hasheader(http, "Sec-WebSocket-Version", "13") @@ -430,10 +448,11 @@ function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=f setheader(http, "Connection" => "Upgrade") key = header(http, "Sec-WebSocket-Key") setheader(http, "Sec-WebSocket-Accept" => hashedkey(key)) + isdeflate && setheader(http, "Sec-Websocket-Extensions" => "permessage-deflate; client_no_context_takeover") startwrite(http) io = http.stream req = http.message - ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation) + ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket upgraded; connection established" try f(ws) @@ -507,7 +526,7 @@ function Sockets.send(ws::WebSocket, x) # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item))) + n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? isdeflate(ws) : false)) first = false nextstate === nothing && break item, st = nextstate @@ -516,7 +535,8 @@ function Sockets.send(ws::WebSocket, x) else # single binary or text frame for message @label write_single_frame - return writeframe(ws.io, Frame(true, opcode(x), ws.client, payload(ws, x))) + pl = isdeflate(ws) ? compress(ws, payload(ws, x)) : payload(ws, x) + return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=isdeflate(ws))) end end @@ -531,7 +551,7 @@ to when a PING message is received by a websocket connection. function ping(ws::WebSocket, data=UInt8[]) @require !ws.writeclosed @debugv 2 "$(ws.id): sending ping" - return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, data))) + return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, isdeflate(ws) ? compress(ws, data) : data))) end """ @@ -592,18 +612,41 @@ function Base.close(ws::WebSocket, body::CloseFrameBody=CloseFrameBody(1000, "") @assert ws.readclosed # if we're the server, it's our job to close the underlying socket !ws.client && isopen(ws.io) && close(ws.io) + final_deflate_codecs(ws.deflate) return end # Receiving messages +function compress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return compressed +end + +function compress(ws::WebSocket, data::String) + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return String(compressed) +end + +function decompress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + decompressed = transcode(ws.deflate[2], append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return decompressed +end + +function decompress(ws::WebSocket, data::String) + decompressed = transcode(ws.deflate[2], append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return String(decompressed) +end + # returns whether additional frames should be read # true if fragmented message or a ping/pong frame was handled @noinline control_len_check(len) = len > 125 && throw(WebSocketError(CloseFrameBody(1002, "Invalid length for control frame"))) @noinline utf8check(x) = isvalid(x) || throw(WebSocketError(CloseFrameBody(1007, "Invalid UTF-8"))) function checkreadframe!(ws::WebSocket, frame::Frame) - if frame.flags.rsv1 || frame.flags.rsv2 || frame.flags.rsv3 + if frame.flags.rsv2 || frame.flags.rsv3 throw(WebSocketError(CloseFrameBody(1002, "Reserved bits set in control frame"))) end opcode = frame.flags.opcode @@ -616,7 +659,7 @@ function checkreadframe!(ws::WebSocket, frame::Frame) if !ws.writeclosed close(ws) end - throw(WebSocketError(frame.payload)) + throw(WebSocketError(isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload)) elseif opcode == PING control_len_check(frame.flags.len) pong(ws, frame.payload) @@ -624,8 +667,6 @@ function checkreadframe!(ws::WebSocket, frame::Frame) elseif opcode == PONG control_len_check(frame.flags.len) return false - elseif frame.flags.final && frame.flags.opcode == TEXT && frame.payload isa String - utf8check(frame.payload) end return frame.flags.final end @@ -659,7 +700,11 @@ function receive(ws::WebSocket) @debugv 2 "$(ws.id): Received frame: $frame" done = checkreadframe!(ws, frame) # common case of reading single non-control frame - done && return frame.payload + if done + payload = isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload + payload isa String && utf8check(payload) + return payload + end opcode = frame.flags.opcode iscontrol(opcode) && return receive(ws) # if we're here, we're reading a fragmented message @@ -674,6 +719,7 @@ function receive(ws::WebSocket) end done && break end + payload = isdeflate(ws) ? decompress(ws, payload) : payload payload isa String && utf8check(payload) @debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])" return payload