Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert to using quiescing helper #453

Merged
merged 5 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 39 additions & 59 deletions Sources/HummingbirdCore/Server/HTTP/HTTPChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,75 +32,55 @@ enum HTTPChannelError: Error {
case unexpectedHTTPPart(HTTPRequestPart)
}

enum HTTPState: Int, Sendable {
case idle
case processing
case cancelled
}

extension HTTPChannelHandler {
public func handleHTTP(asyncChannel: NIOAsyncChannel<HTTPRequestPart, HTTPResponsePart>, logger: Logger) async {
let processingRequest = NIOLockedValueBox(HTTPState.idle)
do {
try await withTaskCancellationHandler {
try await withGracefulShutdownHandler {
try await asyncChannel.executeThenClose { inbound, outbound in
let responseWriter = HTTPServerBodyWriter(outbound: outbound)
var iterator = inbound.makeAsyncIterator()
try await asyncChannel.executeThenClose { inbound, outbound in
let responseWriter = HTTPServerBodyWriter(outbound: outbound)
var iterator = inbound.makeAsyncIterator()

// read first part, verify it is a head
guard let part = try await iterator.next() else { return }
guard case .head(var head) = part else {
throw HTTPChannelError.unexpectedHTTPPart(part)
}

// read first part, verify it is a head
guard let part = try await iterator.next() else { return }
guard case .head(var head) = part else {
throw HTTPChannelError.unexpectedHTTPPart(part)
while true {
let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let response = await self.responder(request, asyncChannel.channel)
do {
try await outbound.write(.head(response.head))
let tailHeaders = try await response.body.write(responseWriter)
try await outbound.write(.end(tailHeaders))
} catch {
throw error
}
if request.headers[.connection] == "close" {
return
}

// Flush current request
// read until we don't have a body part
var part: HTTPRequestPart?
while true {
// set to processing unless it is cancelled then exit
guard processingRequest.exchange(.processing) == .idle else { break }

let bodyStream = NIOAsyncChannelRequestBody(iterator: iterator)
let request = Request(head: head, body: .init(asyncSequence: bodyStream))
let response: Response = await self.responder(request, asyncChannel.channel)
do {
try await outbound.write(.head(response.head))
let tailHeaders = try await response.body.write(responseWriter)
try await outbound.write(.end(tailHeaders))
} catch {
throw error
}
if request.headers[.connection] == "close" {
return
}
// set to idle unless it is cancelled then exit
guard processingRequest.exchange(.idle) == .processing else { break }

// Flush current request
// read until we don't have a body part
var part: HTTPRequestPart?
while true {
part = try await iterator.next()
guard case .body = part else { break }
}
// if we have an end then read the next part
if case .end = part {
part = try await iterator.next()
}

// if part is nil break out of loop
guard let part else {
break
}
part = try await iterator.next()
guard case .body = part else { break }
}
// if we have an end then read the next part
if case .end = part {
part = try await iterator.next()
}

// part should be a head, if not throw error
guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) }
head = newHead
// if part is nil break out of loop
guard let part else {
break
}
}
} onGracefulShutdown: {
// set to cancelled
if processingRequest.exchange(.cancelled) == .idle {
// only close the channel input if it is idle
asyncChannel.channel.close(mode: .input, promise: nil)

// part should be a head, if not throw error
guard case .head(let newHead) = part else { throw HTTPChannelError.unexpectedHTTPPart(part) }
head = newHead
}
}
} onCancel: {
Expand Down
22 changes: 12 additions & 10 deletions Sources/HummingbirdCore/Server/HTTPUserEventHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH
let part = unwrapOutboundIn(data)
if case .end = part {
self.requestsInProgress -= 1
context.write(data, promise: promise)
context.writeAndFlush(data, promise: promise)
if self.closeAfterResponseWritten {
context.close(promise: nil)
self.closeAfterResponseWritten = false
Expand All @@ -61,21 +61,23 @@ public final class HTTPUserEventHandler: ChannelDuplexHandler, RemovableChannelH

public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case is ChannelShouldQuiesceEvent:
// we received a quiesce event. If we have any requests in progress we should
// wait for them to finish
if self.requestsInProgress > 0 {
self.closeAfterResponseWritten = true
} else {
context.close(promise: nil)
}

case IdleStateHandler.IdleStateEvent.read:
// if we get an idle read event and we haven't completed reading the request
// close the connection
if self.requestsBeingRead > 0 {
// close the connection, or a request hasnt been initiated
if self.requestsBeingRead > 0 || self.requestsInProgress == 0 {
self.logger.trace("Idle read timeout, so close channel")
context.close(promise: nil)
}

case IdleStateHandler.IdleStateEvent.write:
// if we get an idle write event and are not currently processing a request
if self.requestsInProgress == 0 {
self.logger.trace("Idle write timeout, so close channel")
context.close(mode: .input, promise: nil)
}

default:
context.fireUserInboundEventTriggered(event)
}
Expand Down
50 changes: 40 additions & 10 deletions Sources/HummingbirdCore/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
onServerRunning: (@Sendable (Channel) async -> Void)?
)
case starting
case running(asyncChannel: AsyncServerChannel)
case running(
asyncChannel: AsyncServerChannel,
quiescingHelper: ServerQuiescingHelper
)
case shuttingDown(shutdownPromise: EventLoopPromise<Void>)
case shutdown

Expand Down Expand Up @@ -96,7 +99,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
self.state = .starting

do {
let asyncChannel = try await self.makeServer(
let (asyncChannel, quiescingHelper) = try await self.makeServer(
childChannelSetup: childChannelSetup,
configuration: configuration
)
Expand All @@ -107,7 +110,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
fatalError("We should only be running once")

case .starting:
self.state = .running(asyncChannel: asyncChannel)
self.state = .running(asyncChannel: asyncChannel, quiescingHelper: quiescingHelper)

await withGracefulShutdownHandler {
await onServerRunning?(asyncChannel.channel)
Expand Down Expand Up @@ -138,13 +141,14 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
}

case .shuttingDown, .shutdown:
self.logger.info("Shutting down")
try await asyncChannel.channel.close()
}
} catch {
self.state = .shutdown
throw error
}
self.state = .shutdown

case .starting, .running:
fatalError("Run should only be called once")

Expand All @@ -162,10 +166,20 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
case .initial, .starting:
self.state = .shutdown

case .running(let channel):
case .running(let channel, let quiescingHelper):
let shutdownPromise = channel.channel.eventLoop.makePromise(of: Void.self)
channel.channel.close(promise: shutdownPromise)
self.state = .shuttingDown(shutdownPromise: shutdownPromise)
quiescingHelper.initiateShutdown(promise: shutdownPromise)
try await shutdownPromise.futureResult.get()

// We need to check the state here again since we just awaited above
switch self.state {
case .initial, .starting, .running, .shutdown:
fatalError("Unexpected state \(self.state)")

case .shuttingDown:
self.state = .shutdown
}

case .shuttingDown(let shutdownPromise):
// We are just going to queue up behind the current graceful shutdown
Expand All @@ -179,8 +193,8 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
/// Start server
/// - Parameter responder: Object that provides responses to requests sent to the server
/// - Returns: EventLoopFuture that is fulfilled when server has started
nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> AsyncServerChannel {
let bootstrap: ServerBootstrapProtocol
nonisolated func makeServer(childChannelSetup: ChildChannel, configuration: ServerConfiguration) async throws -> (AsyncServerChannel, ServerQuiescingHelper) {
var bootstrap: ServerBootstrapProtocol
#if canImport(Network)
if let tsBootstrap = self.createTSBootstrap(configuration: configuration) {
bootstrap = tsBootstrap
Expand All @@ -199,6 +213,11 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
#endif

let quiescingHelper = ServerQuiescingHelper(group: self.eventLoopGroup)
bootstrap = bootstrap.serverChannelInitializer { channel in
channel.pipeline.addHandler(quiescingHelper.makeServerChannelHandler(channel: channel))
}

do {
switch configuration.address.value {
case .hostname(let host, let port):
Expand All @@ -213,7 +232,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
}
self.logger.info("Server started and listening on \(host):\(asyncChannel.channel.localAddress?.port ?? port)")
return asyncChannel
return (asyncChannel, quiescingHelper)

case .unixDomainSocket(let path):
let asyncChannel = try await bootstrap.bind(
Expand All @@ -227,7 +246,7 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {
)
}
self.logger.info("Server started and listening on socket path \(path)")
return asyncChannel
return (asyncChannel, quiescingHelper)
}
} catch {
// should we close the channel here
Expand Down Expand Up @@ -271,6 +290,17 @@ public actor Server<ChildChannel: ServerChildChannel>: Service {

/// Protocol for bootstrap.
protocol ServerBootstrapProtocol {
/// Initialize the `ServerSocketChannel` with `initializer`. The most common task in initializer is to add
/// `ChannelHandler`s to the `ChannelPipeline`.
///
/// The `ServerSocketChannel` uses the accepted `Channel`s as inbound messages.
///
/// - note: To set the initializer for the accepted `SocketChannel`s, look at `ServerBootstrap.childChannelInitializer`.
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self
Joannis marked this conversation as resolved.
Show resolved Hide resolved

func bind<Output: Sendable>(
host: String,
port: Int,
Expand Down
4 changes: 2 additions & 2 deletions Sources/HummingbirdTesting/LiveTestFramework.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,12 @@ final class LiveTestFramework<App: ApplicationProtocol>: ApplicationTestFramewor
client.connect()
do {
let value = try await test(Client(client: client))
await serviceGroup.triggerGracefulShutdown()
try await client.shutdown()
await serviceGroup.triggerGracefulShutdown()
return value
} catch {
await serviceGroup.triggerGracefulShutdown()
try await client.shutdown()
await serviceGroup.triggerGracefulShutdown()
throw error
}
}
Expand Down
55 changes: 52 additions & 3 deletions Tests/HummingbirdCoreTests/CoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class HummingBirdCoreTests: XCTestCase {
}
}

func testReadIdleHandler() async throws {
func testUnfinishedReadIdleHandler() async throws {
/// Channel Handler for serializing request header and data
final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPRequestPart
Expand Down Expand Up @@ -304,7 +304,56 @@ class HummingBirdCoreTests: XCTestCase {
}
}

func testWriteIdleTimeout() async throws {
func testUninitiatedReadIdleHandler() async throws {
/// Channel Handler for serializing request header and data
final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPRequestPart
typealias InboundOut = HTTPRequestPart

func channelRead(context: ChannelHandlerContext, data: NIOAny) {}
}
try await testServer(
responder: { request, _ in
do {
_ = try await request.body.collect(upTo: .max)
} catch {
return Response(status: .contentTooLarge)
}
return .init(status: .ok)
},
httpChannelSetup: .http1(additionalChannelHandlers: [HTTPServerIncompleteRequest(), IdleStateHandler(readTimeout: .seconds(1))]),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
) { client in
try await withTimeout(.seconds(5)) {
do {
_ = try await client.get("/", headers: [.connection: "keep-alive"])
XCTFail("Should not get here")
} catch TestClient.Error.connectionClosing {
} catch {
XCTFail("Unexpected error: \(error)")
}
}
}
}

func testLeftOpenReadIdleHandler() async throws {
/// Channel Handler for serializing request header and data
final class HTTPServerIncompleteRequest: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPRequestPart
typealias InboundOut = HTTPRequestPart
var readOneRequest = false
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let part = self.unwrapInboundIn(data)
if !self.readOneRequest {
context.fireChannelRead(data)
}
if case .end = part {
self.readOneRequest = true
}
}
}
try await testServer(
responder: { request, _ in
do {
Expand All @@ -314,7 +363,7 @@ class HummingBirdCoreTests: XCTestCase {
}
return .init(status: .ok)
},
httpChannelSetup: .http1(additionalChannelHandlers: [IdleStateHandler(writeTimeout: .seconds(1))]),
httpChannelSetup: .http1(additionalChannelHandlers: [HTTPServerIncompleteRequest(), IdleStateHandler(readTimeout: .seconds(1))]),
configuration: .init(address: .hostname(port: 0)),
eventLoopGroup: Self.eventLoopGroup,
logger: Logger(label: "Hummingbird")
Expand Down
4 changes: 3 additions & 1 deletion Tests/HummingbirdCoreTests/HTTP2Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ class HummingBirdHTTP2Tests: XCTestCase {
func testConnect() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
var logger = Logger(label: "Hummingbird")
logger.logLevel = .trace
try await testServer(
responder: { _, _ in
.init(status: .ok)
},
httpChannelSetup: .http2Upgrade(tlsConfiguration: getServerTLSConfiguration()),
configuration: .init(address: .hostname(port: 0), serverName: testServerName),
eventLoopGroup: eventLoopGroup,
logger: Logger(label: "Hummingbird")
logger: logger
) { port in
var tlsConfiguration = try getClientTLSConfiguration()
// no way to override the SSL server name with AsyncHTTPClient so need to set
Expand Down
Loading
Loading