Skip to content

Commit

Permalink
Dissallow access and removal of Head and Tail ChannelHandlers (#225)
Browse files Browse the repository at this point in the history
Motivation:

The Head and Tail ChannelHandlers should not be accessible as these are implementation details of the ChannelPipeline and removal of these will even make the whole ChannelPipeline not work anymore.

Modifications:

- Make it impossible to access / removal the Head and Tail ChannelHandlers
- Exclude the Head and Tail ChannelHandler from the debug String
- Add unit tests.

Result:

More robust ChannelPipeline.
  • Loading branch information
normanmaurer authored Mar 23, 2018
1 parent 61f03a7 commit eb60e89
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
25 changes: 17 additions & 8 deletions Sources/NIO/ChannelPipeline.swift
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,16 @@ public final class ChannelPipeline: ChannelInvoker {
return promise.futureResult
}

/// Returns a `ChannelHandlerContext` which matches.
///
/// This skips head and tail (as these are internal and should not be accessible by the user).
///
/// - parameters:
/// - body: The predicate to execute per `ChannelHandlerContext` in the `ChannelPipeline`.
/// -returns: The first `ChannelHandlerContext` that matches or `nil` if none did.
private func contextForPredicate0(_ body: @escaping((ChannelHandlerContext) -> Bool)) -> ChannelHandlerContext? {
var curCtx: ChannelHandlerContext? = self.head
while let ctx = curCtx {
var curCtx: ChannelHandlerContext? = self.head?.next
while let ctx = curCtx, ctx !== self.tail {
if body(ctx) {
return ctx
}
Expand Down Expand Up @@ -835,8 +842,8 @@ public final class ChannelPipeline: ChannelInvoker {
self._channel = channel
self.eventLoop = channel.eventLoop

self.head = ChannelHandlerContext(name: "head", handler: HeadChannelHandler.sharedInstance, pipeline: self)
self.tail = ChannelHandlerContext(name: "tail", handler: TailChannelHandler.sharedInstance, pipeline: self)
self.head = ChannelHandlerContext(name: HeadChannelHandler.name, handler: HeadChannelHandler.sharedInstance, pipeline: self)
self.tail = ChannelHandlerContext(name: TailChannelHandler.name, handler: TailChannelHandler.sharedInstance, pipeline: self)
self.head?.next = self.tail
self.tail?.prev = self.head
}
Expand Down Expand Up @@ -876,8 +883,9 @@ extension ChannelPipeline {
}

/// Special `ChannelHandler` that forwards all events to the `Channel.Unsafe` implementation.
private final class HeadChannelHandler: _ChannelOutboundHandler {
/* private but tests */ final class HeadChannelHandler: _ChannelOutboundHandler {

static let name = "head"
static let sharedInstance = HeadChannelHandler()

private init() { }
Expand Down Expand Up @@ -930,8 +938,9 @@ private extension CloseMode {
}

/// Special `ChannelInboundHandler` which will consume all inbound events.
private final class TailChannelHandler: _ChannelInboundHandler, _ChannelOutboundHandler {
/* private but tests */ final class TailChannelHandler: _ChannelInboundHandler, _ChannelOutboundHandler {

static let name = "tail"
static let sharedInstance = TailChannelHandler()

private init() { }
Expand Down Expand Up @@ -1422,8 +1431,8 @@ public final class ChannelHandlerContext: ChannelInvoker {
extension ChannelPipeline: CustomDebugStringConvertible {
public var debugDescription: String {
var desc = "ChannelPipeline (\(ObjectIdentifier(self))):\n"
var node = self.head
while let ctx = node {
var node = self.head?.next
while let ctx = node, ctx !== self.tail {
let inboundStr = ctx.handler is _ChannelInboundHandler ? "I" : ""
let outboundStr = ctx.handler is _ChannelOutboundHandler ? "O" : ""
desc += " \(ctx.name) (\(type(of: ctx.handler))) [\(inboundStr)\(outboundStr)]\n"
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOTests/ChannelPipelineTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ extension ChannelPipelineTest {
("testAddBeforeWhileClosed", testAddBeforeWhileClosed),
("testFindHandlerByType", testFindHandlerByType),
("testFindHandlerByTypeReturnsTheFirstOfItsType", testFindHandlerByTypeReturnsTheFirstOfItsType),
("testContextForHeadOrTail", testContextForHeadOrTail),
("testRemoveHeadOrTail", testRemoveHeadOrTail),
]
}
}
Expand Down
58 changes: 58 additions & 0 deletions Tests/NIOTests/ChannelPipelineTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -620,4 +620,62 @@ class ChannelPipelineTest: XCTestCase {
XCTAssertTrue(try h1 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler)
XCTAssertFalse(try h2 === channel.pipeline.context(handlerType: TestHandler.self).wait().handler)
}

func testContextForHeadOrTail() throws {
let channel = EmbeddedChannel()

defer {
XCTAssertFalse(try channel.finish())
}

do {
_ = try channel.pipeline.context(name: HeadChannelHandler.name).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}

do {
_ = try channel.pipeline.context(handlerType: HeadChannelHandler.self).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}

do {
_ = try channel.pipeline.context(name: TailChannelHandler.name).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}

do {
_ = try channel.pipeline.context(handlerType: TailChannelHandler.self).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}
}

func testRemoveHeadOrTail() throws {
let channel = EmbeddedChannel()

defer {
XCTAssertFalse(try channel.finish())
}

do {
_ = try channel.pipeline.remove(name: HeadChannelHandler.name).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}

do {
_ = try channel.pipeline.remove(name: TailChannelHandler.name).wait()
XCTFail()
} catch let err as ChannelPipelineError where err == .notFound {
/// expected
}
}
}

0 comments on commit eb60e89

Please sign in to comment.