diff --git a/stdlib/public/Concurrency/DiscardingTaskGroup.swift b/stdlib/public/Concurrency/DiscardingTaskGroup.swift index 1de2e5311abcf..0fcec55446e16 100644 --- a/stdlib/public/Concurrency/DiscardingTaskGroup.swift +++ b/stdlib/public/Concurrency/DiscardingTaskGroup.swift @@ -80,7 +80,7 @@ public func withDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = DiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -108,7 +108,7 @@ public func _unsafeInheritExecutor_withDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = DiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -347,7 +347,7 @@ public func withThrowingDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = ThrowingDiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -378,7 +378,7 @@ public func _unsafeInheritExecutor_withThrowingDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = ThrowingDiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } diff --git a/stdlib/public/Concurrency/TaskGroup.cpp b/stdlib/public/Concurrency/TaskGroup.cpp index 9cc8f4ea0712b..0d1d2be13d156 100644 --- a/stdlib/public/Concurrency/TaskGroup.cpp +++ b/stdlib/public/Concurrency/TaskGroup.cpp @@ -439,12 +439,12 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord { /// by simultaneously decrementing one Pending and one Waiting tasks. 
/// /// This is used to atomically perform a waiting task completion. - /// The change is made 'relaxed' and may have to be retried. + /// The change is made with relaxed memory ordering. /// /// This can be safely used in a discarding task group as well, /// where the "ready" change will simply be ignored, since there /// are no ready bits to change. - bool statusCompletePendingReadyWaiting(TaskGroupStatus &old); + void statusCompletePendingReadyWaiting(TaskGroupStatus &old); /// Cancel the task group and all tasks within it. /// @@ -568,7 +568,11 @@ struct TaskGroupStatus { // so if we're in "discard results" mode, we must not decrement the ready count, // as there is no ready count in the status. change += group->isAccumulatingResults() ? oneReadyTask : 0; - return TaskGroupStatus{status - change}; + + TaskGroupStatus newStatus{status - change}; + SWIFT_TASK_GROUP_DEBUG_LOG(group, "completingPendingReadyWaiting %s", + newStatus.to_string(group).c_str()); + return newStatus; } TaskGroupStatus completingPendingReady(const TaskGroupBase* _Nonnull group) { @@ -669,11 +673,12 @@ struct TaskGroupStatus { }; }; -bool TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) { - return status.compare_exchange_strong( +void TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) { + while (!status.compare_exchange_weak( old.status, old.completingPendingReadyWaiting(this).status, /*success*/ std::memory_order_relaxed, - /*failure*/ std::memory_order_relaxed); + /*failure*/ std::memory_order_relaxed)) { + } // Loop until the compare_exchange succeeds } AsyncTask *TaskGroupBase::claimWaitingTask() { @@ -681,12 +686,11 @@ AsyncTask *TaskGroupBase::claimWaitingTask() { "attempted to claim waiting task but status indicates no waiting " "task is present!"); - auto waitingTask = waitQueue.load(std::memory_order_acquire); - if (!waitQueue.compare_exchange_strong(waitingTask, nullptr, - std::memory_order_release, - std::memory_order_relaxed)) { - 
swift_Concurrency_fatalError(0, "Failed to claim waitingTask!"); - } + auto waitingTask = waitQueue.exchange(nullptr, std::memory_order_acquire); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "claimed waiting task %p", waitingTask); + if (!waitingTask) + swift_Concurrency_fatalError(0, "Claimed NULL waitingTask!"); + return waitingTask; } void TaskGroupBase::runWaitingTask(PreparedWaitingTask prepared) { @@ -737,13 +741,19 @@ uint64_t TaskGroupBase::pendingTasks() const { TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeAcquire() { auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_acquire); - return TaskGroupStatus{old | TaskGroupStatus::waiting}; + TaskGroupStatus newStatus{old | TaskGroupStatus::waiting}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusMarkWaitingAssumeAcquire %s", + newStatus.to_string(this).c_str()); + return newStatus; } TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeRelease() { auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_release); - return TaskGroupStatus{old | TaskGroupStatus::waiting}; + TaskGroupStatus newStatus{old | TaskGroupStatus::waiting}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusMarkWaitingAssumeRelease %s", + newStatus.to_string(this).c_str()); + return newStatus; } /// Add a single pending task to the status counter. 
@@ -786,6 +796,8 @@ TaskGroupStatus TaskGroupBase::statusAddPendingTaskAssumeRelaxed(bool unconditio TaskGroupStatus TaskGroupBase::statusRemoveWaitingRelease() { auto old = status.fetch_and(~TaskGroupStatus::waiting, std::memory_order_release); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusRemoveWaitingRelease %s", + TaskGroupStatus{old}.to_string(this).c_str()); return TaskGroupStatus{old}; } @@ -793,6 +805,9 @@ bool TaskGroupBase::statusCancel() { /// The cancelled bit is always the same, the first one, between all task group implementations: const uint64_t cancelled = TaskGroupStatus::cancelled; auto old = status.fetch_or(cancelled, std::memory_order_relaxed); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "statusCancel %s", + TaskGroupStatus{old | cancelled}.to_string(this).c_str()); // return if the status was already cancelled before we flipped it or not return old & cancelled; @@ -827,6 +842,8 @@ class AccumulatingTaskGroup: public TaskGroupBase { auto old = status.fetch_add(TaskGroupStatus::oneReadyTask, std::memory_order_acquire); auto s = TaskGroupStatus{old + TaskGroupStatus::oneReadyTask}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusAddReadyAssumeAcquire %s", + s.to_string(this).c_str()); assert(s.readyTasks(this) <= s.pendingTasks(this)); return s; } @@ -880,23 +897,17 @@ class DiscardingTaskGroup: public TaskGroupBase { return TaskGroupStatus{status.load(std::memory_order_acquire)}; } - /// Compare-and-set old status to a status derived from the old one, - /// by simultaneously decrementing one Pending and one Waiting tasks. - /// - /// This is used to atomically perform a waiting task completion. - bool statusCompletePendingReadyWaiting(TaskGroupStatus &old) { - return status.compare_exchange_strong( - old.status, old.completingPendingReadyWaiting(this).status, - /*success*/ std::memory_order_relaxed, - /*failure*/ std::memory_order_relaxed); - } - /// Decrement the pending status count. /// Returns the *assumed* new status, including the just performed -1.
TaskGroupStatus statusCompletePendingAssumeRelease() { auto old = status.fetch_sub(TaskGroupStatus::onePendingTask, std::memory_order_release); assert(TaskGroupStatus{old}.pendingTasks(this) > 0 && "attempted to decrement pending count when it was 0 already"); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "statusComplete = %s", + TaskGroupStatus{status.load(std::memory_order_relaxed)} + .to_string(this) + .c_str()); return TaskGroupStatus{old - TaskGroupStatus::onePendingTask}; } @@ -1323,6 +1334,8 @@ void AccumulatingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *contex // ==== a) has waiting task, so let us complete it right away if (assumed.hasWaitingTask()) { auto waitingTask = claimWaitingTask(); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, waitingTask = %p", waitingTask); + assert(waitingTask); auto prepared = prepareWaitingTaskWithTask( /*complete=*/waitingTask, /*with=*/completedTask, assumed, hadErrorResult); @@ -1480,14 +1493,7 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context) // We grab the waiting task while holding the group lock, because this // allows a single task to get the waiting task and attempt to complete it. // As another offer gets to run, it will have either a different waiting task, or no waiting task at all. 
- auto waitingTask = waitQueue.load(std::memory_order_acquire); - if (!waitQueue.compare_exchange_strong(waitingTask, nullptr, - std::memory_order_release, - std::memory_order_relaxed)) { - swift_Concurrency_fatalError(0, "Failed to claim waitingTask!"); - } - assert(waitingTask && "status claimed to have waitingTask but waitQueue was empty!"); - + auto waitingTask = claimWaitingTask(); SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, last pending task completed successfully, resume waitingTask:%p with completedTask:%p", waitingTask, completedTask); @@ -1558,8 +1564,11 @@ TaskGroupBase::PreparedWaitingTask TaskGroupBase::prepareWaitingTaskWithTask( bool hadErrorResult, bool alreadyDecremented, bool taskWasRetained) { - SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume, waitingTask = %p, completedTask = %p, alreadyDecremented:%d, error:%d", - waitingTask, alreadyDecremented, hadErrorResult, completedTask); + SWIFT_TASK_GROUP_DEBUG_LOG(this, + "resume, waitingTask = %p, completedTask = %p, " + "alreadyDecremented:%d, error:%d", + waitingTask, completedTask, alreadyDecremented, + hadErrorResult); assert(waitingTask && "waitingTask must not be null when attempting to resume it"); assert(assumed.hasWaitingTask()); #if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL @@ -1579,9 +1588,8 @@ TaskGroupBase::PreparedWaitingTask TaskGroupBase::prepareWaitingTaskWithTask( enqueueCompletedTask(completedTask, hadErrorResult); return {nullptr}; #else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */ - if (!alreadyDecremented) { - (void) statusCompletePendingReadyWaiting(assumed); - } + if (!alreadyDecremented) + statusCompletePendingReadyWaiting(assumed); // Populate the waiting task with value from completedTask. 
auto result = PollResult::get(completedTask, hadErrorResult); @@ -1643,9 +1651,8 @@ DiscardingTaskGroup::prepareWaitingTaskWithError(AsyncTask *waitingTask, _enqueueRawError(this, &readyQueue, error); return {nullptr}; #else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */ - if (!alreadyDecremented) { + if (!alreadyDecremented) statusCompletePendingReadyWaiting(assumed); - } // Run the task. auto result = PollResult::getError(error); @@ -1796,77 +1803,81 @@ reevaluate_if_taskgroup_has_results:; auto waitHead = waitQueue.load(std::memory_order_acquire); // ==== 2) Ready task was polled, return with it immediately ----------------- - if (assumed.readyTasks(this)) { + while (assumed.readyTasks(this)) { + // We loop when the compare_exchange fails. SWIFT_TASK_DEBUG_LOG("poll group = %p, tasks .ready = %d, .pending = %llu", this, assumed.readyTasks(this), assumed.pendingTasks(this)); auto assumedStatus = assumed.status; auto newStatus = TaskGroupStatus{assumedStatus}; - if (status.compare_exchange_strong( - assumedStatus, newStatus.completingPendingReadyWaiting(this).status, - /*success*/ std::memory_order_release, - /*failure*/ std::memory_order_acquire)) { - - // We're going back to running the task, so if we suspended before, - // we need to flag it as running again. - if (hasSuspended) { - waitingTask->flagAsRunning(); - } + if (!status.compare_exchange_weak( + assumedStatus, newStatus.completingPendingReadyWaiting(this).status, + /*success*/ std::memory_order_release, + /*failure*/ std::memory_order_acquire)) { + assumed = TaskGroupStatus{assumedStatus}; + continue; // We raced with something, try again. + } + SWIFT_TASK_DEBUG_LOG("poll, after CAS: %s", status.to_string().c_str()); - // Success! We are allowed to poll. - ReadyQueueItem item; - bool taskDequeued = readyQueue.dequeue(item); - assert(taskDequeued); (void) taskDequeued; - - auto futureFragment = - item.getStatus() == ReadyStatus::RawError ? 
- nullptr : - item.getTask()->futureFragment(); - - // Store the task in the result, so after we're done processing it may - // be swift_release'd; we kept it alive while it was in the readyQueue by - // an additional retain issued as we enqueued it there. - - // Note that the task was detached from the task group when it - // completed, so we don't need to do that bit of record-keeping here. - - switch (item.getStatus()) { - case ReadyStatus::Success: - // Immediately return the polled value - result.status = PollStatus::Success; - result.storage = futureFragment->getStoragePtr(); - result.successType = futureFragment->getResultType(); - result.retainedTask = item.getTask(); - assert(result.retainedTask && "polled a task, it must be not null"); - _swift_tsan_acquire(static_cast<Job *>(result.retainedTask)); - unlock(); - return result; + // We're going back to running the task, so if we suspended before, + // we need to flag it as running again. + if (hasSuspended) { + waitingTask->flagAsRunning(); + } - case ReadyStatus::Error: - // Immediately return the polled value - result.status = PollStatus::Error; - result.storage = - reinterpret_cast<OpaqueValue *>(futureFragment->getError()); - result.successType = ResultTypeInfo(); - result.retainedTask = item.getTask(); - assert(result.retainedTask && "polled a task, it must be not null"); - _swift_tsan_acquire(static_cast<Job *>(result.retainedTask)); - unlock(); - return result; + // Success! We are allowed to poll. + ReadyQueueItem item; + bool taskDequeued = readyQueue.dequeue(item); + assert(taskDequeued); (void) taskDequeued; + + auto futureFragment = + item.getStatus() == ReadyStatus::RawError ? + nullptr : + item.getTask()->futureFragment(); + + // Store the task in the result, so after we're done processing it may + // be swift_release'd; we kept it alive while it was in the readyQueue by + // an additional retain issued as we enqueued it there.
+ + // Note that the task was detached from the task group when it + // completed, so we don't need to do that bit of record-keeping here. + + switch (item.getStatus()) { + case ReadyStatus::Success: + // Immediately return the polled value + result.status = PollStatus::Success; + result.storage = futureFragment->getStoragePtr(); + result.successType = futureFragment->getResultType(); + result.retainedTask = item.getTask(); + assert(result.retainedTask && "polled a task, it must be not null"); + _swift_tsan_acquire(static_cast<Job *>(result.retainedTask)); + unlock(); + return result; + + case ReadyStatus::Error: + // Immediately return the polled value + result.status = PollStatus::Error; + result.storage = + reinterpret_cast<OpaqueValue *>(futureFragment->getError()); + result.successType = ResultTypeInfo(); + result.retainedTask = item.getTask(); + assert(result.retainedTask && "polled a task, it must be not null"); + _swift_tsan_acquire(static_cast<Job *>(result.retainedTask)); + unlock(); + return result; - case ReadyStatus::Empty: - result.status = PollStatus::Empty; - result.storage = nullptr; - result.retainedTask = nullptr; - result.successType = this->successType; - unlock(); - return result; + + case ReadyStatus::Empty: + result.status = PollStatus::Empty; + result.storage = nullptr; + result.retainedTask = nullptr; + result.successType = this->successType; + unlock(); + return result; - case ReadyStatus::RawError: - swift_Concurrency_fatalError(0, "accumulating task group should never use raw-errors!"); - } - swift_Concurrency_fatalError(0, "must return result when status compare-and-swap was successful"); - } // else, we failed status-cas (some other waiter claimed a ready pending task, try again) + case ReadyStatus::RawError: + swift_Concurrency_fatalError(0, "accumulating task group should never use raw-errors!"); + } + swift_Concurrency_fatalError(0, "must return result when status compare-and-swap was successful"); } // ==== 3) Add to wait queue
------------------------------------------------- @@ -1878,7 +1889,9 @@ reevaluate_if_taskgroup_has_results:; } while (true) { // Put the waiting task at the beginning of the wait queue. - SWIFT_TASK_GROUP_DEBUG_LOG(this, "WATCH OUT, SET WAITER ONTO waitQueue.head = %p", waitQueue.load(std::memory_order_relaxed)); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "WATCH OUT, SET WAITER %p ONTO waitQueue.head = %p", waitingTask, + waitQueue.load(std::memory_order_relaxed)); if (waitQueue.compare_exchange_weak( waitHead, waitingTask, /*success*/ std::memory_order_release, @@ -2029,6 +2042,7 @@ void TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask, if (bodyError && isDiscardingResults() && readyQueue.isEmpty()) { auto discardingGroup = asDiscardingImpl(this); auto readyItem = ReadyQueueItem::getRawError(discardingGroup, bodyError); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "enqueue %#" PRIxPTR, readyItem.storage); readyQueue.enqueue(readyItem); } diff --git a/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift b/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift new file mode 100644 index 0000000000000..96268d37c0914 --- /dev/null +++ b/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift @@ -0,0 +1,59 @@ +// RUN: %target-run-simple-swift + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch +// REQUIRES: concurrency_runtime +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime +// UNSUPPORTED: back_deploy_concurrency +// UNSUPPORTED: freestanding + +func unorderedResults<R>( + _ fns: [@Sendable () async -> R]) -> (Task<(), Never>, AsyncStream<R>) { + var capturedContinuation: AsyncStream<R>.Continuation?
= nil + let stream = AsyncStream<R> { continuation in + capturedContinuation = continuation + } + + guard let capturedContinuation = capturedContinuation else { + fatalError("failed to capture continuation") + } + + let task = Task.detached { + await withTaskGroup(of: Void.self) { group in + for fn in fns { + group.addTask { + let _ = capturedContinuation.yield(await fn()) + } + } + await group.waitForAll() + } + capturedContinuation.finish() + } + + let result = (task, stream) + + return result + } + +var fns: [@Sendable () async -> String] = [ + { + try? await Task.sleep(nanoseconds: .random(in: 0..<50000)) + return "hello" + } +] + +fns.append(fns[0]) +fns.append(fns[0]) + +// This is a race that will crash or trigger an assertion failure if there's an +// issue. If we get to the end then we pass. +for _ in 0..<1000 { + let (t, s) = unorderedResults(fns) + + for try await x in s { + _ = x + if Bool.random() { t.cancel() } + } +} diff --git a/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift b/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift index cacde94631888..bd5f54e9e7ff0 100644 --- a/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift +++ b/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift @@ -85,9 +85,24 @@ func test_discardingTaskGroup_neverConsume(sleepBeforeGroupWaitAll: Duration) as print("all tasks: \(allTasks)") } +func test_discardingTaskGroup_bigReturn() async { + print(">>> \(#function)") + + // Test returning a very large value to ensure we don't overflow memory. + let array = await withDiscardingTaskGroup { group in + group.addTask {} + try?
await Task.sleep(until: .now + .milliseconds(100), clock: .continuous) + return InlineArray<32768, Int>(repeating: 12345) + } + + // CHECK: Huge return value produced: 12345 12345 + print("Huge return value produced:", array[0], array[32767]) +} + @main struct Main { static func main() async { await test_discardingTaskGroup_neverConsume() await test_discardingTaskGroup_neverConsume(sleepBeforeGroupWaitAll: .milliseconds(500)) + await test_discardingTaskGroup_bigReturn() } }