-
-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Client-Decompression support #123
base: main
Are you sure you want to change the base?
Changes from 13 commits
81785a3
49bc530
7239cae
9fc1e19
8045aef
a43978a
ac9c4e2
90051db
d5cc050
07ec835
d0f66cf
da6fd79
0954614
7c87942
f1cd76e
e90c3d7
32b517b
8e3a62a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#ifndef C_ZLIB_H | ||
#define C_ZLIB_H | ||
|
||
#include <zlib.h> | ||
|
||
static inline int CZlib_inflateInit2(z_streamp strm, int windowBits) { | ||
return inflateInit2(strm, windowBits); | ||
} | ||
|
||
static inline Bytef *CZlib_voidPtr_to_BytefPtr(void *in) { | ||
return (Bytef *)in; | ||
} | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
|
||
public enum Compression { | ||
public struct Algorithm { | ||
enum Base { | ||
case gzip | ||
case deflate | ||
} | ||
|
||
private let base: Base | ||
|
||
var window: CInt { | ||
switch base { | ||
case .deflate: | ||
return 15 | ||
case .gzip: | ||
return 15 + 16 | ||
} | ||
} | ||
|
||
public static let gzip = Self(base: .gzip) | ||
public static let deflate = Self(base: .deflate) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
import CZlib | ||
|
||
public enum Decompression { | ||
|
||
public struct Configuration { | ||
public var algorithm: Compression.Algorithm | ||
public var limit: Limit | ||
|
||
public init(algorithm: Compression.Algorithm, limit: Limit) { | ||
self.algorithm = algorithm | ||
self.limit = limit | ||
} | ||
} | ||
|
||
/// Specifies how to limit decompression inflation. | ||
public struct Limit: Sendable { | ||
private enum Base { | ||
case none | ||
case size(Int) | ||
case ratio(Int) | ||
} | ||
|
||
private var limit: Base | ||
|
||
/// No limit will be set. | ||
/// - warning: Setting `limit` to `.none` leaves you vulnerable to denial of service attacks. | ||
public static let none = Limit(limit: .none) | ||
/// Limit will be set on the request body size. | ||
public static func size(_ value: Int) -> Limit { | ||
return Limit(limit: .size(value)) | ||
} | ||
/// Limit will be set on a ratio between compressed body size and decompressed result. | ||
public static func ratio(_ value: Int) -> Limit { | ||
return Limit(limit: .ratio(value)) | ||
} | ||
|
||
func exceeded(compressed: Int, decompressed: Int) -> Bool { | ||
switch self.limit { | ||
case .none: | ||
return false | ||
case .size(let allowed): | ||
return decompressed > allowed | ||
case .ratio(let ratio): | ||
return decompressed > compressed * ratio | ||
} | ||
} | ||
} | ||
|
||
public struct DecompressionError: Error, Equatable, CustomStringConvertible { | ||
|
||
private enum Base: Error, Equatable { | ||
case limit | ||
case inflationError(Int) | ||
case initializationError(Int) | ||
case invalidTrailingData | ||
} | ||
|
||
private var base: Base | ||
|
||
/// The set ``DecompressionLimit`` has been exceeded | ||
public static let limit = Self(base: .limit) | ||
|
||
/// An error occurred when inflating. Error code is included to aid diagnosis. | ||
public static var inflationError: (Int) -> Self = { | ||
Self(base: .inflationError($0)) | ||
} | ||
|
||
/// Decoder could not be initialised. Error code is included to aid diagnosis. | ||
public static var initializationError: (Int) -> Self = { | ||
Self(base: .initializationError($0)) | ||
} | ||
|
||
/// Decompression completed but there was invalid trailing data behind the compressed data. | ||
public static var invalidTrailingData = Self(base: .invalidTrailingData) | ||
|
||
public var description: String { | ||
return String(describing: self.base) | ||
} | ||
} | ||
|
||
struct Decompressor { | ||
private let limit: Limit | ||
private var stream = z_stream() | ||
|
||
init(limit: Limit) { | ||
self.limit = limit | ||
} | ||
|
||
/// Assumes `buffer` is a new empty buffer. | ||
mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer) throws { | ||
let compressedLength = part.readableBytes | ||
var isComplete = false | ||
|
||
while part.readableBytes > 0 && !isComplete { | ||
try self.stream.inflatePart( | ||
input: &part, | ||
output: &buffer, | ||
isComplete: &isComplete | ||
) | ||
|
||
if self.limit.exceeded( | ||
compressed: compressedLength, | ||
decompressed: buffer.writerIndex + 1 | ||
) { | ||
throw DecompressionError.limit | ||
} | ||
} | ||
|
||
if part.readableBytes > 0 { | ||
throw DecompressionError.invalidTrailingData | ||
} | ||
} | ||
|
||
mutating func initializeDecoder(encoding: Compression.Algorithm) throws { | ||
self.stream.zalloc = nil | ||
self.stream.zfree = nil | ||
self.stream.opaque = nil | ||
|
||
let rc = CZlib_inflateInit2(&self.stream, encoding.window) | ||
guard rc == Z_OK else { | ||
throw DecompressionError.initializationError(Int(rc)) | ||
} | ||
} | ||
|
||
mutating func deinitializeDecoder() { | ||
inflateEnd(&self.stream) | ||
} | ||
} | ||
} | ||
|
||
//MARK: - +z_stream | ||
private extension z_stream { | ||
mutating func inflatePart( | ||
input: inout ByteBuffer, | ||
output: inout ByteBuffer, | ||
isComplete: inout Bool | ||
) throws { | ||
let minimumCapacity = input.readableBytes * 4 | ||
try input.readWithUnsafeMutableReadableBytes { pointer in | ||
self.avail_in = UInt32(pointer.count) | ||
self.next_in = CZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) | ||
|
||
defer { | ||
self.avail_in = 0 | ||
self.next_in = nil | ||
self.avail_out = 0 | ||
self.next_out = nil | ||
} | ||
|
||
isComplete = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) | ||
|
||
return pointer.count - Int(self.avail_in) | ||
} | ||
} | ||
|
||
private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Bool { | ||
var rc = Z_OK | ||
|
||
try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in | ||
self.avail_out = UInt32(pointer.count) | ||
self.next_out = CZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) | ||
|
||
rc = inflate(&self, Z_SYNC_FLUSH) | ||
guard rc == Z_OK || rc == Z_STREAM_END else { | ||
throw Decompression.DecompressionError.inflationError(Int(rc)) | ||
} | ||
|
||
return pointer.count - Int(self.avail_out) | ||
} | ||
|
||
return rc == Z_STREAM_END | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,20 +25,31 @@ public final class WebSocket { | |
} | ||
|
||
private let channel: Channel | ||
private var onTextCallback: (WebSocket, String) -> () | ||
|
||
private var onTextCallback: ((WebSocket, String) -> ())? | ||
private var onTextBufferCallback: (WebSocket, ByteBuffer) -> () | ||
private var onBinaryCallback: (WebSocket, ByteBuffer) -> () | ||
private var onPongCallback: (WebSocket) -> () | ||
private var onPingCallback: (WebSocket) -> () | ||
|
||
private var frameSequence: WebSocketFrameSequence? | ||
private let type: PeerType | ||
|
||
private var decompressor: Decompression.Decompressor? | ||
|
||
private var waitingForPong: Bool | ||
private var waitingForClose: Bool | ||
private var scheduledTimeoutTask: Scheduled<Void>? | ||
|
||
init(channel: Channel, type: PeerType) { | ||
init(channel: Channel, type: PeerType, decompression: Decompression.Configuration?) throws { | ||
self.channel = channel | ||
self.type = type | ||
if let decompression = decompression { | ||
self.decompressor = Decompression.Decompressor(limit: decompression.limit) | ||
try self.decompressor?.initializeDecoder(encoding: decompression.algorithm) | ||
} | ||
self.onTextCallback = { _, _ in } | ||
self.onTextBufferCallback = { _, _ in } | ||
self.onBinaryCallback = { _, _ in } | ||
self.onPongCallback = { _ in } | ||
self.onPingCallback = { _ in } | ||
|
@@ -51,6 +62,11 @@ public final class WebSocket { | |
self.onTextCallback = callback | ||
} | ||
|
||
/// The same as `onText`, but with raw data instead of the decoded `String`. | ||
public func onTextBuffer(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) { | ||
self.onTextBufferCallback = callback | ||
} | ||
|
||
public func onBinary(_ callback: @escaping (WebSocket, ByteBuffer) -> ()) { | ||
self.onBinaryCallback = callback | ||
} | ||
|
@@ -64,10 +80,10 @@ public final class WebSocket { | |
} | ||
|
||
/// If set, this will trigger automatic pings on the connection. If ping is not answered before | ||
/// the next ping is sent, then the WebSocket will be presumed innactive and will be closed | ||
/// the next ping is sent, then the WebSocket will be presumed inactive and will be closed | ||
/// automatically. | ||
/// These pings can also be used to keep the WebSocket alive if there is some other timeout | ||
/// mechanism shutting down innactive connections, such as a Load Balancer deployed in | ||
/// mechanism shutting down inactive connections, such as a Load Balancer deployed in | ||
/// front of the server. | ||
public var pingInterval: TimeAmount? { | ||
didSet { | ||
|
@@ -233,6 +249,7 @@ public final class WebSocket { | |
} else { | ||
frameSequence = WebSocketFrameSequence(type: frame.opcode) | ||
} | ||
|
||
// append this frame and update the sequence | ||
frameSequence.append(frame) | ||
self.frameSequence = frameSequence | ||
|
@@ -252,12 +269,27 @@ public final class WebSocket { | |
|
||
// if this frame was final and we have a non-nil frame sequence, | ||
// output it to the websocket and clear storage | ||
if let frameSequence = self.frameSequence, frame.fin { | ||
if var frameSequence = self.frameSequence, frame.fin { | ||
switch frameSequence.type { | ||
case .binary: | ||
self.onBinaryCallback(self, frameSequence.binaryBuffer) | ||
if decompressor != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We still need to support swifts older than 5.7. Other than that, iirc we can't trigger a compiler copy of the decompressor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Oops, I forgot about that. 😆
Hmm interesting. Then should There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then, we will have ARC overhead, which we currently don't have. |
||
do { | ||
var buffer = ByteBuffer() | ||
try decompressor!.decompress(part: &frameSequence.buffer, buffer: &buffer) | ||
|
||
self.onBinaryCallback(self, buffer) | ||
} catch { | ||
self.close(code: .protocolError, promise: nil) | ||
return | ||
} | ||
} else { | ||
self.onBinaryCallback(self, frameSequence.buffer) | ||
} | ||
case .text: | ||
self.onTextCallback(self, frameSequence.textBuffer) | ||
if let callback = self.onTextCallback { | ||
callback(self, String(buffer: frameSequence.buffer)) | ||
} | ||
self.onTextBufferCallback(self, frameSequence.buffer) | ||
case .ping, .pong: | ||
assertionFailure("Control frames never have a frameSequence") | ||
default: break | ||
|
@@ -293,30 +325,25 @@ public final class WebSocket { | |
} | ||
|
||
deinit { | ||
self.decompressor?.deinitializeDecoder() | ||
assert(self.isClosed, "WebSocket was not closed before deinit.") | ||
} | ||
} | ||
|
||
private struct WebSocketFrameSequence { | ||
var binaryBuffer: ByteBuffer | ||
var textBuffer: String | ||
var buffer: ByteBuffer | ||
var type: WebSocketOpcode | ||
|
||
init(type: WebSocketOpcode) { | ||
self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0) | ||
self.textBuffer = .init() | ||
self.buffer = ByteBufferAllocator().buffer(capacity: 0) | ||
self.type = type | ||
} | ||
|
||
mutating func append(_ frame: WebSocketFrame) { | ||
var data = frame.unmaskedData | ||
switch type { | ||
case .binary: | ||
self.binaryBuffer.writeBuffer(&data) | ||
case .text: | ||
if let string = data.readString(length: data.readableBytes) { | ||
self.textBuffer += string | ||
} | ||
case .binary, .text: | ||
var data = frame.unmaskedData | ||
self.buffer.writeBuffer(&data) | ||
default: break | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What purpose does this callback serve? I don't see any usage of it anywhere, including in the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently we only have
onTextCallback
. What is sent to the ws is, of course, Data, not string. but when usingonTextCallback
, ws-kit turns the data into a string and passes the string to the users of the package. The problem is that if the text is for example in JSON format, ws-kit users need to turn the string into Data again to pass it to somewhere like JSONDecoder. so we haveData -> String -> Data
instead of justData
which is wasteful. this new callback solves that problem.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
like it is mentioned in this issue: #79
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i should add some tests, never-the-less.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a test to assert both text callbacks have the same behavior.
I should also add that
onBinary
is not the same asonTextBuffer
becauseonBinary
only is activated if the ws frame is an actual binary frame.onTextBuffer
is for when the ws frame is a text frame, but users might still prefer to access the string's data directly. I did try to mix the two, but it could cause problems.