Skip to content

Make the ResponseAccumulator Sendable #838

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 130 additions & 110 deletions Sources/AsyncHTTPClient/HTTPHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,12 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
}
}

var history = [HTTPClient.RequestResponse]()
var state = State.idle
private struct MutableState: Sendable {
var history = [HTTPClient.RequestResponse]()
var state = State.idle
}

private let state: NIOLockedValueBox<MutableState>
let requestMethod: HTTPMethod
let requestHost: String

Expand Down Expand Up @@ -573,107 +577,126 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
self.requestMethod = request.method
self.requestHost = request.host
self.maxBodySize = maxBodySize
self.state = NIOLockedValueBox(MutableState())
}

public func didVisitURL(
task: HTTPClient.Task<HTTPClient.Response>,
_ request: HTTPClient.Request,
_ head: HTTPResponseHead
) {
self.history.append(.init(request: request, responseHead: head))
self.state.withLockedValue {
$0.history.append(.init(request: request, responseHead: head))
}
}

public func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
switch self.state {
case .idle:
if self.requestMethod != .HEAD,
let contentLength = head.headers.first(name: "Content-Length"),
let announcedBodySize = Int(contentLength),
announcedBodySize > self.maxBodySize
{
let error = ResponseTooBigError(maxBodySize: maxBodySize)
self.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}
let responseTooBig: Bool

self.state = .head(head)
case .head:
preconditionFailure("head already set")
case .body:
preconditionFailure("no head received before body")
case .end:
preconditionFailure("request already processed")
case .error:
break
if self.requestMethod != .HEAD,
let contentLength = head.headers.first(name: "Content-Length"),
let announcedBodySize = Int(contentLength),
announcedBodySize > self.maxBodySize
{
responseTooBig = true
} else {
responseTooBig = false
}

return self.state.withLockedValue {
switch $0.state {
case .idle:
if responseTooBig {
let error = ResponseTooBigError(maxBodySize: self.maxBodySize)
$0.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}

$0.state = .head(head)
case .head:
preconditionFailure("head already set")
case .body:
preconditionFailure("no head received before body")
case .end:
preconditionFailure("request already processed")
case .error:
break
}
return task.eventLoop.makeSucceededFuture(())
}
return task.eventLoop.makeSucceededFuture(())
}

public func didReceiveBodyPart(task: HTTPClient.Task<Response>, _ part: ByteBuffer) -> EventLoopFuture<Void> {
switch self.state {
case .idle:
preconditionFailure("no head received before body")
case .head(let head):
guard part.readableBytes <= self.maxBodySize else {
let error = ResponseTooBigError(maxBodySize: self.maxBodySize)
self.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}
self.state = .body(head, part)
case .body(let head, var body):
let newBufferSize = body.writerIndex + part.readableBytes
guard newBufferSize <= self.maxBodySize else {
let error = ResponseTooBigError(maxBodySize: self.maxBodySize)
self.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}
self.state.withLockedValue {
switch $0.state {
case .idle:
preconditionFailure("no head received before body")
case .head(let head):
guard part.readableBytes <= self.maxBodySize else {
let error = ResponseTooBigError(maxBodySize: self.maxBodySize)
$0.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}
$0.state = .body(head, part)
case .body(let head, var body):
let newBufferSize = body.writerIndex + part.readableBytes
guard newBufferSize <= self.maxBodySize else {
let error = ResponseTooBigError(maxBodySize: self.maxBodySize)
$0.state = .error(error)
return task.eventLoop.makeFailedFuture(error)
}

// The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's
// a cross-module call in the way) so we need to drop the original reference to `body` in
// `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which
// has no associated data). We'll fix it at the bottom of this block.
self.state = .end
var part = part
body.writeBuffer(&part)
self.state = .body(head, body)
case .end:
preconditionFailure("request already processed")
case .error:
break
// The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's
// a cross-module call in the way) so we need to drop the original reference to `body` in
// `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which
// has no associated data). We'll fix it at the bottom of this block.
$0.state = .end
var part = part
body.writeBuffer(&part)
$0.state = .body(head, body)
case .end:
preconditionFailure("request already processed")
case .error:
break
}
return task.eventLoop.makeSucceededFuture(())
}
return task.eventLoop.makeSucceededFuture(())
}

public func didReceiveError(task: HTTPClient.Task<Response>, _ error: Error) {
self.state = .error(error)
self.state.withLockedValue {
$0.state = .error(error)
}
}

public func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response {
switch self.state {
case .idle:
preconditionFailure("no head received before end")
case .head(let head):
return Response(
host: self.requestHost,
status: head.status,
version: head.version,
headers: head.headers,
body: nil,
history: self.history
)
case .body(let head, let body):
return Response(
host: self.requestHost,
status: head.status,
version: head.version,
headers: head.headers,
body: body,
history: self.history
)
case .end:
preconditionFailure("request already processed")
case .error(let error):
throw error
try self.state.withLockedValue {
switch $0.state {
case .idle:
preconditionFailure("no head received before end")
case .head(let head):
return Response(
host: self.requestHost,
status: head.status,
version: head.version,
headers: head.headers,
body: nil,
history: $0.history
)
case .body(let head, let body):
return Response(
host: self.requestHost,
status: head.status,
version: head.version,
headers: head.headers,
body: body,
history: $0.history
)
case .end:
preconditionFailure("request already processed")
case .error(let error):
throw error
}
}
}
}
Expand Down Expand Up @@ -709,8 +732,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate {
/// released together with the `HTTPTaskHandler` when channel is closed.
/// Users of the library are not required to keep a reference to the
/// object that implements this protocol, but may do so if needed.
public protocol HTTPClientResponseDelegate: AnyObject {
associatedtype Response
@preconcurrency
public protocol HTTPClientResponseDelegate: AnyObject, Sendable {
associatedtype Response: Sendable

/// Called when the request head is sent. Will be called once.
///
Expand Down Expand Up @@ -885,7 +909,7 @@ extension URL {
}
}

protocol HTTPClientTaskDelegate {
protocol HTTPClientTaskDelegate: Sendable {
func fail(_ error: Error)
}

Expand All @@ -894,49 +918,54 @@ extension HTTPClient {
///
/// Will be created by the library and could be used for obtaining
/// `EventLoopFuture<Response>` of the execution or cancellation of the execution.
public final class Task<Response> {
public final class Task<Response>: Sendable {
/// The `EventLoop` the delegate will be executed on.
public let eventLoop: EventLoop
/// The `Logger` used by the `Task` for logging.
public let logger: Logger // We are okay to store the logger here because a Task is for only one request.

let promise: EventLoopPromise<Response>

struct State: Sendable {
var isCancelled: Bool
var taskDelegate: HTTPClientTaskDelegate?
}

private let state: NIOLockedValueBox<State>

var isCancelled: Bool {
self.lock.withLock { self._isCancelled }
self.state.withLockedValue { $0.isCancelled }
}

var taskDelegate: HTTPClientTaskDelegate? {
get {
self.lock.withLock { self._taskDelegate }
self.state.withLockedValue { $0.taskDelegate }
}
set {
self.lock.withLock { self._taskDelegate = newValue }
self.state.withLockedValue { $0.taskDelegate = newValue }
}
}

private var _isCancelled: Bool = false
private var _taskDelegate: HTTPClientTaskDelegate?
private let lock = NIOLock()
private let makeOrGetFileIOThreadPool: () -> NIOThreadPool
private let makeOrGetFileIOThreadPool: @Sendable () -> NIOThreadPool

/// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access.
internal var fileIOThreadPool: NIOThreadPool {
self.makeOrGetFileIOThreadPool()
}

init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) {
init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool) {
self.eventLoop = eventLoop
self.promise = eventLoop.makePromise()
self.logger = logger
self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool
self.state = NIOLockedValueBox(State(isCancelled: false, taskDelegate: nil))
}

static func failedTask(
eventLoop: EventLoop,
error: Error,
logger: Logger,
makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool
makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool
) -> Task<Response> {
let task = self.init(
eventLoop: eventLoop,
Expand All @@ -957,7 +986,8 @@ extension HTTPClient {
/// - returns: The value of ``futureResult`` when it completes.
/// - throws: The error value of ``futureResult`` if it errors.
@available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()")
public func wait() throws -> Response {
@preconcurrency
public func wait() throws -> Response where Response: Sendable {
try self.promise.futureResult.wait()
}

Expand All @@ -968,7 +998,8 @@ extension HTTPClient {
/// - returns: The value of ``futureResult`` when it completes.
/// - throws: The error value of ``futureResult`` if it errors.
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
public func get() async throws -> Response {
@preconcurrency
public func get() async throws -> Response where Response: Sendable {
try await self.promise.futureResult.get()
}

Expand All @@ -985,23 +1016,14 @@ extension HTTPClient {
///
/// - Parameter error: the error that is used to fail the promise
public func fail(reason error: Error) {
let taskDelegate = self.lock.withLock { () -> HTTPClientTaskDelegate? in
self._isCancelled = true
return self._taskDelegate
let taskDelegate = self.state.withLockedValue { state in
state.isCancelled = true
return state.taskDelegate
}

taskDelegate?.fail(error)
}

func succeed<Delegate: HTTPClientResponseDelegate>(
promise: EventLoopPromise<Response>?,
with value: Response,
delegateType: Delegate.Type,
closing: Bool
) {
promise?.succeed(value)
}

func fail<Delegate: HTTPClientResponseDelegate>(
with error: Error,
delegateType: Delegate.Type
Expand All @@ -1011,13 +1033,11 @@ extension HTTPClient {
}
}

extension HTTPClient.Task: @unchecked Sendable {}

internal struct TaskCancelEvent {}

// MARK: - RedirectHandler

internal struct RedirectHandler<ResponseType> {
internal struct RedirectHandler<ResponseType: Sendable> {
let request: HTTPClient.Request
let redirectState: RedirectState
let execute: (HTTPClient.Request, RedirectState) -> HTTPClient.Task<ResponseType>
Expand Down
Loading