diff --git a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift index cc5fa16e33..e8f1b5ca47 100644 --- a/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift @@ -199,7 +199,7 @@ private func _upgrade( ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize)) ) if enableAutomaticErrorHandling { - try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler()) + try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false)) } } .flatMap { diff --git a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift index 4f888a3e34..14e3c72612 100644 --- a/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift +++ b/Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift @@ -62,11 +62,8 @@ extension HTTPHeaders { /// /// This upgrader assumes that the `HTTPServerUpgradeHandler` will appropriately mutate the pipeline to /// remove the HTTP `ChannelHandler`s. -public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, @unchecked Sendable { - // This type *is* Sendable but we can't express that properly until Swift 5.7. In the meantime - // the conformance is `@unchecked`. +public final class NIOWebSocketServerUpgrader: HTTPServerProtocolUpgrader, Sendable { - // FIXME: remove @unchecked when 5.7 is the minimum supported version. private typealias ShouldUpgrade = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture private typealias UpgradePipelineHandler = @Sendable (Channel, HTTPRequestHead) -> EventLoopFuture /// RFC 6455 specs this as the required entry in the Upgrade header. diff --git a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift index e367184b62..6c63126e9e 100644 --- a/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift +++ b/Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift @@ -22,7 +22,20 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { public typealias InboundIn = Never public typealias OutboundOut = WebSocketFrame - public init() {} + /// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true. + private let isServer: Bool + + public init() { + self.isServer = true + } + + /// Initialize this `ChannelHandler` to be used by a WebSocket server or client. + /// + /// - Parameters: + /// - isServer: indicate whether this `ChannelHandler` is used by a WebSocket server or client. + public init(isServer: Bool) { + self.isServer = isServer + } public func errorCaught(context: ChannelHandlerContext, error: Error) { let loopBoundContext = context.loopBound @@ -32,6 +45,7 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { let frame = WebSocketFrame( fin: true, opcode: .connectionClose, + maskKey: self.makeMaskingKey(), data: data ) context.writeAndFlush(Self.wrapOutboundOut(frame)).whenComplete { (_: Result) in @@ -44,6 +58,12 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler { // forward the error on to let others see it. context.fireErrorCaught(error) } + + private func makeMaskingKey() -> WebSocketMaskingKey? { + // According to RFC 6455 Section 5, a client *must* mask all frames that it sends to the server. + // A server *must not* mask any frames that it sends to the client + self.isServer ? nil : .random() + } } @available(*, unavailable) diff --git a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift index 71a86bdc1e..11e8f848e0 100644 --- a/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift +++ b/Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift @@ -466,6 +466,35 @@ class WebSocketClientEndToEndTests: XCTestCase { // Close the pipeline. XCTAssertNoThrow(try clientChannel.close().wait()) } + + func testErrorHandlerMaskFrameForClient() throws { + + let (clientChannel, _) = try self.runSuccessfulUpgrade() + let maskBitMask: UInt8 = 0x80 + + var data = clientChannel.allocator.buffer(capacity: 4) + // A fake frame header that claims that the length of the frame is 16385 bytes, + // larger than the frame max. + data.writeBytes([0x81, 0xFE, 0x40, 0x01]) + + XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in + XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError) + } + + clientChannel.embeddedEventLoop.run() + var buffer = try clientChannel.readAllOutboundBuffers() + + guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else { + XCTFail("Insufficient bytes from WebSocket frame") + return + } + + let maskedBit = (secondByte & maskBitMask) + XCTAssertEqual(0x80, maskedBit) + + XCTAssertNoThrow(!clientChannel.isActive) + XCTAssertTrue(try clientChannel.finish(acceptAlreadyClosed: true).isClean) + } } #if !canImport(Darwin) || swift(>=5.10)