Skip to content

Commit cf6f391

Browse files
committed
Make the file download delegate sendable
Motivation: Delegates can be passed from any thread and are executed on an arbitrary event loop. That means they need to be Sendable. Rather than making them all Sendable in one go, we'll do the larger ones separately. Modifications: - Make FileDownloadDelegate sendable Result: Safe to pass FileDownloadDelegate across isolation domains
1 parent a4fcd70 commit cf6f391

File tree

2 files changed

+116
-72
lines changed

2 files changed

+116
-72
lines changed

Sources/AsyncHTTPClient/FileDownloadDelegate.swift

+115-71
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
import NIOConcurrencyHelpers
1516
import NIOCore
1617
import NIOHTTP1
1718
import NIOPosix
@@ -53,20 +54,26 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
5354
}
5455
}
5556

56-
private var progress = Progress(
57-
totalBytes: nil,
58-
receivedBytes: 0
59-
)
57+
private struct State {
58+
var progress = Progress(
59+
totalBytes: nil,
60+
receivedBytes: 0
61+
)
62+
var fileIOThreadPool: NIOThreadPool?
63+
var fileHandleFuture: EventLoopFuture<NIOFileHandle>?
64+
var writeFuture: EventLoopFuture<Void>?
65+
}
66+
private let state: NIOLockedValueBox<State>
67+
68+
var _fileIOThreadPool: NIOThreadPool? {
69+
self.state.withLockedValue { $0.fileIOThreadPool }
70+
}
6071

6172
public typealias Response = Progress
6273

6374
private let filePath: String
64-
private(set) var fileIOThreadPool: NIOThreadPool?
65-
private let reportHead: ((HTTPClient.Task<Progress>, HTTPResponseHead) -> Void)?
66-
private let reportProgress: ((HTTPClient.Task<Progress>, Progress) -> Void)?
67-
68-
private var fileHandleFuture: EventLoopFuture<NIOFileHandle>?
69-
private var writeFuture: EventLoopFuture<Void>?
75+
private let reportHead: (@Sendable (HTTPClient.Task<Progress>, HTTPResponseHead) -> Void)?
76+
private let reportProgress: (@Sendable (HTTPClient.Task<Progress>, Progress) -> Void)?
7077

7178
/// Initializes a new file download delegate.
7279
///
@@ -78,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
7885
/// the total byte count and download byte count passed to it as arguments. The callbacks
7986
/// will be invoked in the same threading context that the delegate itself is invoked,
8087
/// as controlled by `EventLoopPreference`.
88+
@preconcurrency
8189
public init(
8290
path: String,
8391
pool: NIOThreadPool? = nil,
84-
reportHead: ((HTTPClient.Task<Response>, HTTPResponseHead) -> Void)? = nil,
85-
reportProgress: ((HTTPClient.Task<Response>, Progress) -> Void)? = nil
92+
reportHead: (@Sendable (HTTPClient.Task<Response>, HTTPResponseHead) -> Void)? = nil,
93+
reportProgress: (@Sendable (HTTPClient.Task<Response>, Progress) -> Void)? = nil
8694
) throws {
87-
if let pool = pool {
88-
self.fileIOThreadPool = pool
89-
} else {
90-
// we should use the shared thread pool from the HTTPClient which
91-
// we will get from the `HTTPClient.Task`
92-
self.fileIOThreadPool = nil
93-
}
94-
95+
self.state = NIOLockedValueBox(State(fileIOThreadPool: pool))
9596
self.filePath = path
9697

9798
self.reportHead = reportHead
@@ -108,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
108109
/// the total byte count and download byte count passed to it as arguments. The callbacks
109110
/// will be invoked in the same threading context that the delegate itself is invoked,
110111
/// as controlled by `EventLoopPreference`.
112+
@preconcurrency
111113
public convenience init(
112114
path: String,
113115
pool: NIOThreadPool,
114-
reportHead: ((HTTPResponseHead) -> Void)? = nil,
115-
reportProgress: ((Progress) -> Void)? = nil
116+
reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil,
117+
reportProgress: (@Sendable (Progress) -> Void)? = nil
116118
) throws {
117119
try self.init(
118120
path: path,
119121
pool: .some(pool),
120122
reportHead: reportHead.map { reportHead in
121-
{ _, head in
123+
{ @Sendable _, head in
122124
reportHead(head)
123125
}
124126
},
125127
reportProgress: reportProgress.map { reportProgress in
126-
{ _, head in
128+
{ @Sendable _, head in
127129
reportProgress(head)
128130
}
129131
}
@@ -139,99 +141,141 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
139141
/// the total byte count and download byte count passed to it as arguments. The callbacks
140142
/// will be invoked in the same threading context that the delegate itself is invoked,
141143
/// as controlled by `EventLoopPreference`.
144+
@preconcurrency
142145
public convenience init(
143146
path: String,
144-
reportHead: ((HTTPResponseHead) -> Void)? = nil,
145-
reportProgress: ((Progress) -> Void)? = nil
147+
reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil,
148+
reportProgress: (@Sendable (Progress) -> Void)? = nil
146149
) throws {
147150
try self.init(
148151
path: path,
149152
pool: nil,
150153
reportHead: reportHead.map { reportHead in
151-
{ _, head in
154+
{ @Sendable _, head in
152155
reportHead(head)
153156
}
154157
},
155158
reportProgress: reportProgress.map { reportProgress in
156-
{ _, head in
159+
{ @Sendable _, head in
157160
reportProgress(head)
158161
}
159162
}
160163
)
161164
}
162165

163166
public func didVisitURL(task: HTTPClient.Task<Progress>, _ request: HTTPClient.Request, _ head: HTTPResponseHead) {
164-
self.progress.history.append(.init(request: request, responseHead: head))
167+
self.state.withLockedValue {
168+
$0.progress.history.append(.init(request: request, responseHead: head))
169+
}
165170
}
166171

167172
public func didReceiveHead(
168173
task: HTTPClient.Task<Response>,
169174
_ head: HTTPResponseHead
170175
) -> EventLoopFuture<Void> {
171-
self.progress._head = head
176+
self.state.withLockedValue {
177+
$0.progress._head = head
172178

173-
self.reportHead?(task, head)
174-
175-
if let totalBytesString = head.headers.first(name: "Content-Length"),
176-
let totalBytes = Int(totalBytesString)
177-
{
178-
self.progress.totalBytes = totalBytes
179+
if let totalBytesString = head.headers.first(name: "Content-Length"),
180+
let totalBytes = Int(totalBytesString)
181+
{
182+
$0.progress.totalBytes = totalBytes
183+
}
179184
}
180185

186+
self.reportHead?(task, head)
187+
181188
return task.eventLoop.makeSucceededFuture(())
182189
}
183190

184191
public func didReceiveBodyPart(
185192
task: HTTPClient.Task<Response>,
186193
_ buffer: ByteBuffer
187194
) -> EventLoopFuture<Void> {
188-
let threadPool: NIOThreadPool = {
189-
guard let pool = self.fileIOThreadPool else {
190-
let pool = task.fileIOThreadPool
191-
self.fileIOThreadPool = pool
195+
let (progress, io) = self.state.withLockedValue { state in
196+
let threadPool: NIOThreadPool = {
197+
guard let pool = state.fileIOThreadPool else {
198+
let pool = task.fileIOThreadPool
199+
state.fileIOThreadPool = pool
200+
return pool
201+
}
192202
return pool
203+
}()
204+
205+
let io = NonBlockingFileIO(threadPool: threadPool)
206+
state.progress.receivedBytes += buffer.readableBytes
207+
return (state.progress, io)
208+
}
209+
self.reportProgress?(task, progress)
210+
211+
let writeFuture = self.state.withLockedValue { state in
212+
let writeFuture: EventLoopFuture<Void>
213+
if let fileHandleFuture = state.fileHandleFuture {
214+
writeFuture = fileHandleFuture.flatMap {
215+
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
216+
}
217+
} else {
218+
let fileHandleFuture = io.openFile(
219+
_deprecatedPath: self.filePath,
220+
mode: .write,
221+
flags: .allowFileCreation(),
222+
eventLoop: task.eventLoop
223+
)
224+
state.fileHandleFuture = fileHandleFuture
225+
writeFuture = fileHandleFuture.flatMap {
226+
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
227+
}
193228
}
194-
return pool
195-
}()
196-
let io = NonBlockingFileIO(threadPool: threadPool)
197-
self.progress.receivedBytes += buffer.readableBytes
198-
self.reportProgress?(task, self.progress)
199-
200-
let writeFuture: EventLoopFuture<Void>
201-
if let fileHandleFuture = self.fileHandleFuture {
202-
writeFuture = fileHandleFuture.flatMap {
203-
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
204-
}
205-
} else {
206-
let fileHandleFuture = io.openFile(
207-
_deprecatedPath: self.filePath,
208-
mode: .write,
209-
flags: .allowFileCreation(),
210-
eventLoop: task.eventLoop
211-
)
212-
self.fileHandleFuture = fileHandleFuture
213-
writeFuture = fileHandleFuture.flatMap {
214-
io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop)
215-
}
229+
230+
state.writeFuture = writeFuture
231+
return writeFuture
216232
}
217233

218-
self.writeFuture = writeFuture
219234
return writeFuture
220235
}
221236

222237
private func close(fileHandle: NIOFileHandle) {
223238
try! fileHandle.close()
224-
self.fileHandleFuture = nil
239+
self.state.withLockedValue {
240+
$0.fileHandleFuture = nil
241+
}
225242
}
226243

227244
private func finalize() {
228-
if let writeFuture = self.writeFuture {
229-
writeFuture.whenComplete { _ in
230-
self.fileHandleFuture?.whenSuccess(self.close(fileHandle:))
231-
self.writeFuture = nil
245+
enum Finalize {
246+
case writeFuture(EventLoopFuture<Void>)
247+
case fileHandleFuture(EventLoopFuture<NIOFileHandle>)
248+
case none
249+
}
250+
251+
let finalize: Finalize = self.state.withLockedValue { state in
252+
if let writeFuture = state.writeFuture {
253+
return .writeFuture(writeFuture)
254+
} else if let fileHandleFuture = state.fileHandleFuture {
255+
return .fileHandleFuture(fileHandleFuture)
256+
} else {
257+
return .none
258+
}
259+
}
260+
261+
switch finalize {
262+
case .writeFuture(let future):
263+
future.whenComplete { _ in
264+
let fileHandleFuture = self.state.withLockedValue { state in
265+
let future = state.fileHandleFuture
266+
state.fileHandleFuture = nil
267+
state.writeFuture = nil
268+
return future
269+
}
270+
271+
fileHandleFuture?.whenSuccess {
272+
self.close(fileHandle: $0)
273+
}
232274
}
233-
} else {
234-
self.fileHandleFuture?.whenSuccess(self.close(fileHandle:))
275+
case .fileHandleFuture(let future):
276+
future.whenSuccess { self.close(fileHandle: $0) }
277+
case .none:
278+
()
235279
}
236280
}
237281

@@ -241,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate {
241285

242286
public func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response {
243287
self.finalize()
244-
return self.progress
288+
return self.state.withLockedValue { $0.progress }
245289
}
246290
}

Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ class HTTPClientInternalTests: XCTestCase {
658658
).futureResult
659659
}
660660
_ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait()
661-
let threadPools = delegates.map { $0.fileIOThreadPool }
661+
let threadPools = delegates.map { $0._fileIOThreadPool }
662662
let firstThreadPool = threadPools.first ?? nil
663663
XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool })
664664
}

0 commit comments

Comments
 (0)