Skip to content

Commit

Permalink
fix codec initializing
Browse files Browse the repository at this point in the history
  • Loading branch information
Mixfair committed Aug 6, 2022
1 parent 544e283 commit c3a3b7c
Showing 1 changed file with 42 additions and 27 deletions.
69 changes: 42 additions & 27 deletions src/WebSockets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,18 @@ function mask!(bytes::Vector{UInt8}, mask)
end
return
end

function compress(data::T) where T <: AbstractVector{UInt8}
compressed = transcode(DeflateCompressor, data)
push!(compressed, 0x00)
return compressed
end

function compress(data::String)
compressed = transcode(DeflateCompressor, data)
push!(compressed, 0x00)
return String(compressed)
function final_deflate_codecs(t::Tuple)
CodecZlib.TranscodingStreams.finalize(t[1])
CodecZlib.TranscodingStreams.finalize(t[2])
end

function decompress(data::T) where T <: AbstractVector{UInt8}
decompressed = transcode(DeflateDecompressor, append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return decompressed
end
function init_deflate_codecs()
codecco = DeflateCompressor()
CodecZlib.TranscodingStreams.initialize(codecco)
codecde = DeflateDecompressor()
CodecZlib.TranscodingStreams.initialize(codecde)

function decompress(data::String)
decompressed = transcode(DeflateDecompressor, append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return String(decompressed)
return (codecco, codecde)
end


Expand Down Expand Up @@ -316,20 +307,21 @@ mutable struct WebSocket
writebuffer::Vector{UInt8}
readclosed::Bool
writeclosed::Bool
isdeflate::Bool
deflate::Union{Nothing, Tuple{CodecZlib.CompressorCodec, CodecZlib.DecompressorCodec}}
end

const DEFAULT_MAX_FRAG = 1024

WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) =
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate)
WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate ? init_deflate_codecs() : nothing)

"""
WebSockets.isclosed(ws) -> Bool
Check whether a `WebSocket` has sent and received CLOSE frames.
"""
isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed
isdeflate(ws::WebSocket) = !isnothing(ws.deflate)

# Handshake
"Check whether a HTTP.Request or HTTP.Response is a websocket upgrade request/response"
Expand Down Expand Up @@ -534,7 +526,7 @@ function Sockets.send(ws::WebSocket, x)
# so we can appropriately set the FIN bit for the last fragmented frame
nextstate = iterate(x, st)
while true
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.isdeflate : false))
n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? isdeflate(ws) : false))
first = false
nextstate === nothing && break
item, st = nextstate
Expand All @@ -543,8 +535,8 @@ function Sockets.send(ws::WebSocket, x)
else
# single binary or text frame for message
@label write_single_frame
pl = ws.isdeflate ? compress(payload(ws, x)) : payload(ws, x)
return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.isdeflate))
pl = isdeflate(ws) ? compress(ws, payload(ws, x)) : payload(ws, x)
return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=isdeflate(ws)))
end
end

Expand All @@ -559,7 +551,7 @@ to when a PING message is received by a websocket connection.
function ping(ws::WebSocket, data=UInt8[])
@require !ws.writeclosed
@debugv 2 "$(ws.id): sending ping"
return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, data)))
return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, isdeflate(ws) ? compress(ws, data) : data)))
end

"""
Expand Down Expand Up @@ -620,11 +612,34 @@ function Base.close(ws::WebSocket, body::CloseFrameBody=CloseFrameBody(1000, "")
@assert ws.readclosed
# if we're the server, it's our job to close the underlying socket
!ws.client && isopen(ws.io) && close(ws.io)
final_deflate_codecs(ws.deflate)
return
end

# Receiving messages

function compress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8}
compressed = transcode(ws.deflate[1], data)
push!(compressed, 0x00)
return compressed
end

function compress(ws::WebSocket, data::String)
compressed = transcode(ws.deflate[1], data)
push!(compressed, 0x00)
return String(compressed)
end

function decompress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8}
decompressed = transcode(ws.deflate[2], append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return decompressed
end

function decompress(ws::WebSocket, data::String)
decompressed = transcode(ws.deflate[2], append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00]))
return String(decompressed)
end

# returns whether additional frames should be read
# true if fragmented message or a ping/pong frame was handled
@noinline control_len_check(len) = len > 125 && throw(WebSocketError(CloseFrameBody(1002, "Invalid length for control frame")))
Expand All @@ -644,7 +659,7 @@ function checkreadframe!(ws::WebSocket, frame::Frame)
if !ws.writeclosed
close(ws)
end
throw(WebSocketError(frame.payload))
throw(WebSocketError(isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload))
elseif opcode == PING
control_len_check(frame.flags.len)
pong(ws, frame.payload)
Expand Down Expand Up @@ -686,7 +701,7 @@ function receive(ws::WebSocket)
done = checkreadframe!(ws, frame)
# common case of reading single non-control frame
if done
payload = ws.isdeflate ? decompress(frame.payload) : frame.payload
payload = isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload
payload isa String && utf8check(payload)
return payload
end
Expand All @@ -704,7 +719,7 @@ function receive(ws::WebSocket)
end
done && break
end
payload = ws.isdeflate ? decompress(payload) : payload
payload = isdeflate(ws) ? decompress(ws, payload) : payload
payload isa String && utf8check(payload)
@debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])"
return payload
Expand Down

0 comments on commit c3a3b7c

Please sign in to comment.