From 260bb331e06c0d39c8a9588ac9b9ef2836033dc4 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sat, 4 Jan 2025 23:30:16 -0800 Subject: [PATCH 01/11] fix websocket error handler missing masking key --- .../NIOWebSocketClientUpgrader.swift | 2 +- .../NIOWebSocketServerUpgrader.swift | 2 +- .../WebSocketProtocolErrorHandler.swift | 17 ++++- .../WebSocketFrameDecoderTest.swift | 64 ++++++++++++++++--- 4 files changed, 73 insertions(+), 12 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index cc5fa16e33..e8f1b5ca47 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -199,7 +199,7 @@ private func _upgrade( ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) ) if enableAutomaticErrorHandling { - try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) } } .flatMap { diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 4f888a3e34..b740e507d9 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -331,7 +331,7 @@ private func _upgrade( ) if automaticErrorHandling { - try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) } }.flatMap { upgradePipelineHandler(channel, upgradeRequest) diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index e367184b62..a7b7119f32 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -21,8 +21,16 @@ import NIOCore public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias InboundIn = Never public typealias OutboundOut = WebSocketFrame + + private let isServer: Bool - public init() {} + /// Initialize the `WebSocketProtocolErrorHandler` + /// + /// - Parameters: + /// - isServer: indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is false. + public init(isServer: Bool = false) { + self.isServer = isServer + } public func errorCaught(context: ChannelHandlerContext, error: Error) { let loopBoundContext = context.loopBound @@ -32,6 +40,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { let frame = WebSocketFrame( fin: true, opcode: .connectionClose, + maskKey: self.makeMaskingKey(), data: data ) context.writeAndFlush(Self.wrapOutboundOut(frame)).whenComplete { (_: Result) in @@ -44,6 +53,12 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { // forward the error on to let others see it. context.fireErrorCaught(error) } + + private func makeMaskingKey() -> WebSocketMaskingKey? { + // According to RFC 6455 Section 5, a client *must* mask all frames that it sends to the server. + // A server *must not* mask any frames that it sends to the client + self.isServer ? nil : .random() + } } @available(*, unavailable) diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index 1ae4c0bf9b..29f8ac6519 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -270,7 +270,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -288,7 +288,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -305,7 +305,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -324,7 +324,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -483,7 +483,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -504,7 +504,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -524,7 +524,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -545,7 +545,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) // A fake frame header that claims this is a fragmented ping frame. @@ -586,7 +586,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -606,6 +606,52 @@ public final class WebSocketFrameDecoderTest: XCTestCase { // We expect that an error frame will have been written out. XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes())) } + + func testErrorHandlerMaskFrameForClient() throws { + // We need to insert a decoder that doesn't do error handling, and then a separate error + // handler. + self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false))) + + // A fake frame header that claims that the length of the frame is 16385 bytes, + // larger than the frame max. + self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01]) + + XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in + XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) + } + + let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) + guard let frame = frame else { + // We expect that an error frame will have been written out. + XCTFail("WebSocketFrame should have been written out.") + return + } + XCTAssertNotNil(frame.maskKey) + } + + func testErrorHandlerNotMaskFrameForServer() throws { + // We need to insert a decoder that doesn't do error handling, and then a separate error + // handler. + self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + + // A fake frame header that claims that the length of the frame is 16385 bytes, + // larger than the frame max. + self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01]) + + XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in + XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) + } + + let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) + guard let frame = frame else { + // We expect that an error frame will have been written out. + XCTFail("WebSocketFrame should have been written out.") + return + } + XCTAssertNil(frame.maskKey) + } func testWebSocketFrameDescription() { let byteBuffer = ByteBuffer() From 10745bc23295db9f7fd3aea071673bc2a961ed28 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sat, 4 Jan 2025 23:32:23 -0800 Subject: [PATCH 02/11] remove unnecessary @unchecked --- Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index b740e507d9..060687bb58 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -62,11 +62,10 @@ extension HTTPHeaders { /// /// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to /// remove the HTTP `ChannelHandler`s. -public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable { +public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, Sendable { // This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime // the conformance is `@unchecked`. - // FIXME: remove @unchecked when 5.7 is the minimum supported version. private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture /// RFC 6455 specs this as the required entry in the Upgrade header. From 4fb6005362fe103f39b3af5890b2b34329e74a39 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sat, 4 Jan 2025 23:36:47 -0800 Subject: [PATCH 03/11] lint --- .../WebSocketProtocolErrorHandler.swift | 4 +- .../WebSocketFrameDecoderTest.swift | 52 +++++++++++++------ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index a7b7119f32..bf80405c26 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -21,7 +21,7 @@ import NIOCore public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias InboundIn = Never public typealias OutboundOut = WebSocketFrame - + private let isServer: Bool /// Initialize the `WebSocketProtocolErrorHandler` @@ -53,7 +53,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { // forward the error on to let others see it. context.fireErrorCaught(error) } - + private func makeMaskingKey() -> WebSocketMaskingKey? { // According to RFC 6455 Section 5, a client *must* mask all frames that it sends to the server. // A server *must not* mask any frames that it sends to the client diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index 29f8ac6519..bfe63ae3a8 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -270,7 +270,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -288,7 +290,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -305,7 +309,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -324,7 +330,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -483,7 +491,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -504,7 +514,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -524,7 +536,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -545,7 +559,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) // A fake frame header that claims this is a fragmented ping frame. @@ -586,7 +602,9 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -606,12 +624,14 @@ public final class WebSocketFrameDecoderTest: XCTestCase { // We expect that an error frame will have been written out. XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes())) } - + func testErrorHandlerMaskFrameForClient() throws { // We need to insert a decoder that doesn't do error handling, and then a separate error // handler. self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) + ) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -620,7 +640,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) } - + let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) guard let frame = frame else { // We expect that an error frame will have been written out. @@ -629,12 +649,14 @@ public final class WebSocketFrameDecoderTest: XCTestCase { } XCTAssertNotNil(frame.maskKey) } - + func testErrorHandlerNotMaskFrameForServer() throws { // We need to insert a decoder that doesn't do error handling, and then a separate error // handler. self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) - XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))) + XCTAssertNoThrow( + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + ) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -643,7 +665,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) } - + let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) guard let frame = frame else { // We expect that an error frame will have been written out. From 13d2e1361b8d6594d5311e131e083a8a62b84747 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sat, 4 Jan 2025 23:50:43 -0800 Subject: [PATCH 04/11] remove comments --- Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift | 2 -- 1 file changed, 2 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 060687bb58..d7515894f0 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -63,8 +63,6 @@ extension HTTPHeaders { /// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to /// remove the HTTP `ChannelHandler`s. public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, Sendable { - // This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime - // the conformance is `@unchecked`. private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture From 8f7870229297b78d551207f242bc4b3888d28e0e Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sun, 5 Jan 2025 17:51:56 -0800 Subject: [PATCH 05/11] address comments --- .../NIOWebSocketClientUpgrader.swift | 4 ++- .../NIOWebSocketServerUpgrader.swift | 2 +- .../WebSocketProtocolErrorHandler.swift | 13 ++++----- .../WebSocketClientEndToEndTests.swift | 28 +++++++++++++++++++ .../WebSocketFrameDecoderTest.swift | 25 +++++++++-------- 5 files changed, 51 insertions(+), 21 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index e8f1b5ca47..305a78e204 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -199,7 +199,9 @@ private func _upgrade( ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) ) if enableAutomaticErrorHandling { - try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) + let errorHandler = WebSocketProtocolErrorHandler() + errorHandler.isServer = false + try channel.pipeline.syncOperations.addHandler(errorHandler) } } .flatMap { diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index d7515894f0..14e3c72612 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -328,7 +328,7 @@ private func _upgrade( ) if automaticErrorHandling { - try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) } }.flatMap { upgradePipelineHandler(channel, upgradeRequest) diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index bf80405c26..59f29ad7c6 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -22,14 +22,11 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias InboundIn = Never public typealias OutboundOut = WebSocketFrame - private let isServer: Bool - - /// Initialize the `WebSocketProtocolErrorHandler` - /// - /// - Parameters: - /// - isServer: indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is false. - public init(isServer: Bool = false) { - self.isServer = isServer + /// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true. + public var isServer: Bool + + public init() { + self.isServer = true } public func errorCaught(context: ChannelHandlerContext, error: Error) { diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 71a86bdc1e..7d94f18ea9 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -466,6 +466,34 @@ class WebSocketClientEndToEndTests: XCTestCase { // Close the pipeline. XCTAssertNoThrow(try clientChannel.close().wait()) } + + func testErrorHandlerMaskFrameForClient() throws { + + let (clientChannel, _) = try self.runSuccessfulUpgrade() + let maskBitMask: UInt8 = 0x80 + + var data = clientChannel.allocator.buffer(capacity: 4) + // A fake frame header that claims that the length of the frame is 16385 bytes, + // larger than the frame max. + data.writeBytes([0x81, 0xFE, 0x40, 0x01]) + + XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in + XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) + } + + clientChannel.embeddedEventLoop.run() + var buffer = try clientChannel.readAllOutboundBuffers() + + guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else { + XCTFail("Insufficient bytes from WebSocket frame") + return + } + + let maskedBit = (secondByte & maskBitMask) + XCTAssertEqual(0x80, maskedBit) + + XCTAssertNoThrow(!clientChannel.isActive) + } } #if !canImport(Darwin) || swift(>=5.10) diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index bfe63ae3a8..9e322f8c15 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -271,7 +271,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims that the length of the frame is 16385 bytes, @@ -291,7 +291,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims this is a fragmented ping frame. @@ -310,7 +310,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims this is a ping frame with 126 bytes of data. @@ -331,7 +331,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims this is a fragmented ping frame. @@ -492,7 +492,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims that the length of the frame is 16385 bytes, @@ -515,7 +515,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims this is a fragmented ping frame. @@ -537,7 +537,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims this is a ping frame with 126 bytes of data. @@ -560,7 +560,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) @@ -603,7 +603,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims that the length of the frame is 16385 bytes, @@ -629,8 +629,11 @@ public final class WebSocketFrameDecoderTest: XCTestCase { // We need to insert a decoder that doesn't do error handling, and then a separate error // handler. self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) + + let errorHandler = WebSocketProtocolErrorHandler() + errorHandler.isServer = false XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) + try self.decoderChannel.pipeline.syncOperations.addHandler(errorHandler) ) // A fake frame header that claims that the length of the frame is 16385 bytes, @@ -655,7 +658,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { // handler. self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true)) + try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) ) // A fake frame header that claims that the length of the frame is 16385 bytes, From a25f4ec340466657f5a7db0c54d4f262880d718b Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sun, 5 Jan 2025 17:54:29 -0800 Subject: [PATCH 06/11] lint --- Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift | 2 +- .../NIOWebSocketTests/WebSocketClientEndToEndTests.swift | 8 ++++---- Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index 59f29ad7c6..02c3f4a909 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -24,7 +24,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { /// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true. public var isServer: Bool - + public init() { self.isServer = true } diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 7d94f18ea9..cc7db9eedf 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -466,9 +466,9 @@ class WebSocketClientEndToEndTests: XCTestCase { // Close the pipeline. XCTAssertNoThrow(try clientChannel.close().wait()) } - + func testErrorHandlerMaskFrameForClient() throws { - + let (clientChannel, _) = try self.runSuccessfulUpgrade() let maskBitMask: UInt8 = 0x80 @@ -476,14 +476,14 @@ class WebSocketClientEndToEndTests: XCTestCase { // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. data.writeBytes([0x81, 0xFE, 0x40, 0x01]) - + XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) } clientChannel.embeddedEventLoop.run() var buffer = try clientChannel.readAllOutboundBuffers() - + guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else { XCTFail("Insufficient bytes from WebSocket frame") return diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index 9e322f8c15..3b85964fb4 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -629,7 +629,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { // We need to insert a decoder that doesn't do error handling, and then a separate error // handler. self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) - + let errorHandler = WebSocketProtocolErrorHandler() errorHandler.isServer = false XCTAssertNoThrow( From 0ce88cae3884c2317756338e114d617acb4b4060 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sun, 5 Jan 2025 18:10:39 -0800 Subject: [PATCH 07/11] lint --- .../NIOWebSocketClientUpgrader.swift | 2 +- .../WebSocketProtocolErrorHandler.swift | 10 +++- .../WebSocketFrameDecoderTest.swift | 53 ------------------- 3 files changed, 10 insertions(+), 55 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index 305a78e204..95f18978e2 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -200,7 +200,7 @@ private func _upgrade( ) if enableAutomaticErrorHandling { let errorHandler = WebSocketProtocolErrorHandler() - errorHandler.isServer = false + errorHandler.setIsServer(false) try channel.pipeline.syncOperations.addHandler(errorHandler) } } diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index 02c3f4a909..4f9a05d2fa 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -23,7 +23,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias OutboundOut = WebSocketFrame /// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true. - public var isServer: Bool + public private(set) var isServer: Bool public init() { self.isServer = true @@ -56,6 +56,14 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { // A server *must not* mask any frames that it sends to the client self.isServer ? nil : .random() } + + /// Configure this `ChannelHandler` to be used by a WebSocket server or client. + /// + /// - Parameters: + /// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client. + public func setIsServer(_ isServer: Bool) { + self.isServer = isServer + } } @available(*, unavailable) diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index 3b85964fb4..7bcae74483 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -625,59 +625,6 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes())) } - func testErrorHandlerMaskFrameForClient() throws { - // We need to insert a decoder that doesn't do error handling, and then a separate error - // handler. - self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) - - let errorHandler = WebSocketProtocolErrorHandler() - errorHandler.isServer = false - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(errorHandler) - ) - - // A fake frame header that claims that the length of the frame is 16385 bytes, - // larger than the frame max. - self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01]) - - XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in - XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) - } - - let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) - guard let frame = frame else { - // We expect that an error frame will have been written out. - XCTFail("WebSocketFrame should have been written out.") - return - } - XCTAssertNotNil(frame.maskKey) - } - - func testErrorHandlerNotMaskFrameForServer() throws { - // We need to insert a decoder that doesn't do error handling, and then a separate error - // handler. - self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder())) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) - - // A fake frame header that claims that the length of the frame is 16385 bytes, - // larger than the frame max. - self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01]) - - XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in - XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) - } - - let frame = try self.decoderChannel.readOutbound(as: WebSocketFrame.self) - guard let frame = frame else { - // We expect that an error frame will have been written out. - XCTFail("WebSocketFrame should have been written out.") - return - } - XCTAssertNil(frame.maskKey) - } - func testWebSocketFrameDescription() { let byteBuffer = ByteBuffer() let webSocketFrame = WebSocketFrame( From 8ffb71d920ae37428c192486e599e84147dd411a Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Sun, 5 Jan 2025 18:15:19 -0800 Subject: [PATCH 08/11] revert unnecessary format changes --- .../WebSocketFrameDecoderTest.swift | 36 +++++-------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift index 7bcae74483..1ae4c0bf9b 100644 --- a/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift +++ b/Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift @@ -270,9 +270,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -290,9 +288,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -309,9 +305,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -330,9 +324,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -491,9 +483,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. @@ -514,9 +504,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims this is a fragmented ping frame. self.buffer.writeBytes([0x09, 0x00]) @@ -536,9 +524,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims this is a ping frame with 126 bytes of data. self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E]) @@ -559,9 +545,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first)) // A fake frame header that claims this is a fragmented ping frame. @@ -602,9 +586,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase { XCTAssertNoThrow( try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first) ) - XCTAssertNoThrow( - try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) - ) + XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())) // A fake frame header that claims that the length of the frame is 16385 bytes, // larger than the frame max. From 6ffaa9157a6e3a477a554f6335bd5d41810c72cd Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Mon, 6 Jan 2025 23:03:28 -0800 Subject: [PATCH 09/11] add new initializer for WebSocketProtocolErrorHandler --- .../NIOWebSocketClientUpgrader.swift | 4 +--- .../WebSocketProtocolErrorHandler.swift | 18 +++++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index 95f18978e2..e8f1b5ca47 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -199,9 +199,7 @@ private func _upgrade( ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) ) if enableAutomaticErrorHandling { - let errorHandler = WebSocketProtocolErrorHandler() - errorHandler.setIsServer(false) - try channel.pipeline.syncOperations.addHandler(errorHandler) + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) } } .flatMap { diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index 4f9a05d2fa..459d993333 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -23,11 +23,19 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias OutboundOut = WebSocketFrame /// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true. - public private(set) var isServer: Bool + private let isServer: Bool public init() { self.isServer = true } + + /// Initialize this `ChannelHandler` to be used by a WebSocket server or client. + /// + /// - Parameters: + /// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client. + public init (isServer: Bool) { + self.isServer = isServer + } public func errorCaught(context: ChannelHandlerContext, error: Error) { let loopBoundContext = context.loopBound @@ -56,14 +64,6 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { // A server *must not* mask any frames that it sends to the client self.isServer ? nil : .random() } - - /// Configure this `ChannelHandler` to be used by a WebSocket server or client. - /// - /// - Parameters: - /// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client. - public func setIsServer(_ isServer: Bool) { - self.isServer = isServer - } } @available(*, unavailable) From 78bb1d4613bc3ba0c73a98a79a64c8616acefc7f Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Mon, 6 Jan 2025 23:04:20 -0800 Subject: [PATCH 10/11] lint --- Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index 459d993333..6c63126e9e 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -28,12 +28,12 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public init() { self.isServer = true } - + /// Initialize this `ChannelHandler` to be used by a WebSocket server or client. /// /// - Parameters: /// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client. - public init (isServer: Bool) { + public init(isServer: Bool) { self.isServer = isServer } From 98d87d7ca8d3bb5ee1e920fdd54969ed9a988c39 Mon Sep 17 00:00:00 2001 From: Zhennan Zhou Date: Mon, 6 Jan 2025 23:08:00 -0800 Subject: [PATCH 11/11] add assert for a clean closed channel --- Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index cc7db9eedf..11e8f848e0 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -493,6 +493,7 @@ class WebSocketClientEndToEndTests: XCTestCase { XCTAssertEqual(0x80, maskedBit) XCTAssertNoThrow(!clientChannel.isActive) + XCTAssertTrue(try clientChannel.finish(acceptAlreadyClosed: true).isClean) } }