From eb60e8978fb32de827046c9dcac1a5a60fa3aa5d Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 23 Mar 2018 15:02:16 +0100 Subject: [PATCH] Dissallow access and removal of Head and Tail ChannelHandlers (#225) Motivation: The Head and Tail ChannelHandlers should not be accessible as these are implementation details of the ChannelPipeline and removal of these will even make the whole ChannelPipeline not work anymore. Modifications: - Make it impossible to access / removal the Head and Tail ChannelHandlers - Exclude the Head and Tail ChannelHandler from the debug String - Add unit tests. Result: More robust ChannelPipeline. --- Sources/NIO/ChannelPipeline.swift | 25 +++++--- .../NIOTests/ChannelPipelineTest+XCTest.swift | 2 + Tests/NIOTests/ChannelPipelineTest.swift | 58 +++++++++++++++++++ 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/Sources/NIO/ChannelPipeline.swift b/Sources/NIO/ChannelPipeline.swift index 6d533b419a..7504ce3c48 100644 --- a/Sources/NIO/ChannelPipeline.swift +++ b/Sources/NIO/ChannelPipeline.swift @@ -446,9 +446,16 @@ public final class ChannelPipeline: ChannelInvoker { return promise.futureResult } + /// Returns a `ChannelHandlerContext` which matches. + /// + /// This skips head and tail (as these are internal and should not be accessible by the user). + /// + /// - parameters: + /// - body: The predicate to execute per `ChannelHandlerContext` in the `ChannelPipeline`. + /// -returns: The first `ChannelHandlerContext` that matches or `nil` if none did. private func contextForPredicate0(_ body: @escaping((ChannelHandlerContext) -> Bool)) -> ChannelHandlerContext? { - var curCtx: ChannelHandlerContext? = self.head - while let ctx = curCtx { + var curCtx: ChannelHandlerContext? = self.head?.next + while let ctx = curCtx, ctx !== self.tail { if body(ctx) { return ctx } @@ -835,8 +842,8 @@ public final class ChannelPipeline: ChannelInvoker { self._channel = channel self.eventLoop = channel.eventLoop - self.head = ChannelHandlerContext(name: "head", handler: HeadChannelHandler.sharedInstance, pipeline: self) - self.tail = ChannelHandlerContext(name: "tail", handler: TailChannelHandler.sharedInstance, pipeline: self) + self.head = ChannelHandlerContext(name: HeadChannelHandler.name, handler: HeadChannelHandler.sharedInstance, pipeline: self) + self.tail = ChannelHandlerContext(name: TailChannelHandler.name, handler: TailChannelHandler.sharedInstance, pipeline: self) self.head?.next = self.tail self.tail?.prev = self.head } @@ -876,8 +883,9 @@ extension ChannelPipeline { } /// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation. -private final class HeadChannelHandler: _ChannelOutboundHandler { +/* private but tests */ final class HeadChannelHandler: _ChannelOutboundHandler { + static let name = "head" static let sharedInstance = HeadChannelHandler() private init() { } @@ -930,8 +938,9 @@ private extension CloseMode { } /// Special `ChannelInboundHandler` which will consume all inbound events. -private final class TailChannelHandler: _ChannelInboundHandler, _ChannelOutboundHandler { +/* private but tests */ final class TailChannelHandler: _ChannelInboundHandler, _ChannelOutboundHandler { + static let name = "tail" static let sharedInstance = TailChannelHandler() private init() { } @@ -1422,8 +1431,8 @@ public final class ChannelHandlerContext: ChannelInvoker { extension ChannelPipeline: CustomDebugStringConvertible { public var debugDescription: String { var desc = "ChannelPipeline (\(ObjectIdentifier(self))):\n" - var node = self.head - while let ctx = node { + var node = self.head?.next + while let ctx = node, ctx !== self.tail { let inboundStr = ctx.handler is _ChannelInboundHandler ? "I" : "" let outboundStr = ctx.handler is _ChannelOutboundHandler ? "O" : "" desc += " \(ctx.name) (\(type(of: ctx.handler))) [\(inboundStr)\(outboundStr)]\n" diff --git a/Tests/NIOTests/ChannelPipelineTest+XCTest.swift b/Tests/NIOTests/ChannelPipelineTest+XCTest.swift index c2bbc69cb5..219dfc2cc6 100644 --- a/Tests/NIOTests/ChannelPipelineTest+XCTest.swift +++ b/Tests/NIOTests/ChannelPipelineTest+XCTest.swift @@ -43,6 +43,8 @@ extension ChannelPipelineTest { ("testAddBeforeWhileClosed", testAddBeforeWhileClosed), ("testFindHandlerByType", testFindHandlerByType), ("testFindHandlerByTypeReturnsTheFirstOfItsType", testFindHandlerByTypeReturnsTheFirstOfItsType), + ("testContextForHeadOrTail", testContextForHeadOrTail), + ("testRemoveHeadOrTail", testRemoveHeadOrTail), ] } } diff --git a/Tests/NIOTests/ChannelPipelineTest.swift b/Tests/NIOTests/ChannelPipelineTest.swift index d2a1a292f7..eedb7b7a70 100644 --- a/Tests/NIOTests/ChannelPipelineTest.swift +++ b/Tests/NIOTests/ChannelPipelineTest.swift @@ -620,4 +620,62 @@ class ChannelPipelineTest: XCTestCase { XCTAssertTrue(try h1 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler) XCTAssertFalse(try h2 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler) } + + func testContextForHeadOrTail() throws { + let channel = EmbeddedChannel() + + defer { + XCTAssertFalse(try channel.finish()) + } + + do { + _ = try channel.pipeline.context(name: HeadChannelHandler.name).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.context(handlerType: HeadChannelHandler.self).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.context(name: TailChannelHandler.name).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.context(handlerType: TailChannelHandler.self).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + } + + func testRemoveHeadOrTail() throws { + let channel = EmbeddedChannel() + + defer { + XCTAssertFalse(try channel.finish()) + } + + do { + _ = try channel.pipeline.remove(name: HeadChannelHandler.name).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + + do { + _ = try channel.pipeline.remove(name: TailChannelHandler.name).wait() + XCTFail() + } catch let err as ChannelPipelineError where err == .notFound { + /// expected + } + } }