Skip to content

Commit 8f78702

Browse files
committed
address comments
1 parent 13d2e13 commit 8f78702

File tree

5 files changed

+51
-21
lines changed

5 files changed

+51
-21
lines changed

Sources/NIOWebSocket/NIOWebSocketClientUpgrader.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ private func _upgrade<UpgradeResult>(
199199
ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: maxFrameSize))
200200
)
201201
if enableAutomaticErrorHandling {
202-
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false))
202+
let errorHandler = WebSocketProtocolErrorHandler()
203+
errorHandler.isServer = false
204+
try channel.pipeline.syncOperations.addHandler(errorHandler)
203205
}
204206
}
205207
.flatMap {

Sources/NIOWebSocket/NIOWebSocketServerUpgrader.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ private func _upgrade<UpgradeResult>(
328328
)
329329

330330
if automaticErrorHandling {
331-
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
331+
try channel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
332332
}
333333
}.flatMap {
334334
upgradePipelineHandler(channel, upgradeRequest)

Sources/NIOWebSocket/WebSocketProtocolErrorHandler.swift

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,11 @@ public final class WebSocketProtocolErrorHandler: ChannelInboundHandler {
2222
public typealias InboundIn = Never
2323
public typealias OutboundOut = WebSocketFrame
2424

25-
private let isServer: Bool
26-
27-
/// Initialize the `WebSocketProtocolErrorHandler`
28-
///
29-
/// - Parameters:
30-
/// - isServer: indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is false.
31-
public init(isServer: Bool = false) {
32-
self.isServer = isServer
25+
/// Indicate that this `ChannelHandeler` is used by a WebSocket server or client. Default is true.
26+
public var isServer: Bool
27+
28+
public init() {
29+
self.isServer = true
3330
}
3431

3532
public func errorCaught(context: ChannelHandlerContext, error: Error) {

Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,34 @@ class WebSocketClientEndToEndTests: XCTestCase {
466466
// Close the pipeline.
467467
XCTAssertNoThrow(try clientChannel.close().wait())
468468
}
469+
470+
func testErrorHandlerMaskFrameForClient() throws {
471+
472+
let (clientChannel, _) = try self.runSuccessfulUpgrade()
473+
let maskBitMask: UInt8 = 0x80
474+
475+
var data = clientChannel.allocator.buffer(capacity: 4)
476+
// A fake frame header that claims that the length of the frame is 16385 bytes,
477+
// larger than the frame max.
478+
data.writeBytes([0x81, 0xFE, 0x40, 0x01])
479+
480+
XCTAssertThrowsError(try clientChannel.writeInbound(data)) { error in
481+
XCTAssertEqual(.invalidFrameLength, error as? NIOWebSocketError)
482+
}
483+
484+
clientChannel.embeddedEventLoop.run()
485+
var buffer = try clientChannel.readAllOutboundBuffers()
486+
487+
guard let (_, secondByte) = buffer.readMultipleIntegers(as: (UInt8, UInt8).self) else {
488+
XCTFail("Insufficient bytes from WebSocket frame")
489+
return
490+
}
491+
492+
let maskedBit = (secondByte & maskBitMask)
493+
XCTAssertEqual(0x80, maskedBit)
494+
495+
XCTAssertNoThrow(!clientChannel.isActive)
496+
}
469497
}
470498

471499
#if !canImport(Darwin) || swift(>=5.10)

Tests/NIOWebSocketTests/WebSocketFrameDecoderTest.swift

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
271271
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
272272
)
273273
XCTAssertNoThrow(
274-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
274+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
275275
)
276276

277277
// A fake frame header that claims that the length of the frame is 16385 bytes,
@@ -291,7 +291,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
291291
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
292292
)
293293
XCTAssertNoThrow(
294-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
294+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
295295
)
296296

297297
// A fake frame header that claims this is a fragmented ping frame.
@@ -310,7 +310,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
310310
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
311311
)
312312
XCTAssertNoThrow(
313-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
313+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
314314
)
315315

316316
// A fake frame header that claims this is a ping frame with 126 bytes of data.
@@ -331,7 +331,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
331331
)
332332
XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first))
333333
XCTAssertNoThrow(
334-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
334+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
335335
)
336336

337337
// A fake frame header that claims this is a fragmented ping frame.
@@ -492,7 +492,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
492492
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
493493
)
494494
XCTAssertNoThrow(
495-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
495+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
496496
)
497497

498498
// A fake frame header that claims that the length of the frame is 16385 bytes,
@@ -515,7 +515,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
515515
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
516516
)
517517
XCTAssertNoThrow(
518-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
518+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
519519
)
520520

521521
// A fake frame header that claims this is a fragmented ping frame.
@@ -537,7 +537,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
537537
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
538538
)
539539
XCTAssertNoThrow(
540-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
540+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
541541
)
542542

543543
// A fake frame header that claims this is a ping frame with 126 bytes of data.
@@ -560,7 +560,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
560560
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
561561
)
562562
XCTAssertNoThrow(
563-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
563+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
564564
)
565565
XCTAssertNoThrow(try self.decoderChannel.pipeline.syncOperations.addHandler(swallower, position: .first))
566566

@@ -603,7 +603,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
603603
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketFrameEncoder(), position: .first)
604604
)
605605
XCTAssertNoThrow(
606-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
606+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
607607
)
608608

609609
// A fake frame header that claims that the length of the frame is 16385 bytes,
@@ -629,8 +629,11 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
629629
// We need to insert a decoder that doesn't do error handling, and then a separate error
630630
// handler.
631631
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
632+
633+
let errorHandler = WebSocketProtocolErrorHandler()
634+
errorHandler.isServer = false
632635
XCTAssertNoThrow(
633-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: false))
636+
try self.decoderChannel.pipeline.syncOperations.addHandler(errorHandler)
634637
)
635638

636639
// A fake frame header that claims that the length of the frame is 16385 bytes,
@@ -655,7 +658,7 @@ public final class WebSocketFrameDecoderTest: XCTestCase {
655658
// handler.
656659
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
657660
XCTAssertNoThrow(
658-
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler(isServer: true))
661+
try self.decoderChannel.pipeline.syncOperations.addHandler(WebSocketProtocolErrorHandler())
659662
)
660663

661664
// A fake frame header that claims that the length of the frame is 16385 bytes,

0 commit comments

Comments
 (0)