From 30ebf9e1adfe645ea117cc7baa67daeabcaeb50a Mon Sep 17 00:00:00 2001 From: Andrea Ciprietti Date: Wed, 18 Dec 2024 22:29:52 +0100 Subject: [PATCH] sync: refactor `Semaphore::poll_acquire` This commit tries to make the logic in `poll_acquire` flow more naturally, while also using more descriptive names and improving commentary (including the addition of a missing "SAFETY:" annotation). There should be no functional changes. --- tokio/src/sync/batch_semaphore.rs | 145 +++++++++++++++--------------- 1 file changed, 74 insertions(+), 71 deletions(-) diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index aabee0f5c0e..ecf4b11b767 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -394,108 +394,109 @@ impl Semaphore { } } + /// Does the actual work of polling an `Acquire` waiting for permits. + /// Tries to acquire as many permits as needed, and no more than the + /// available ones. + /// + /// The returned status is `Ready` if all necessary permits have been + /// acquired, or otherwise `Pending`. fn poll_acquire( &self, cx: &mut Context<'_>, - num_permits: usize, + total_needed: usize, node: Pin<&mut Waiter>, queued: bool, ) -> Poll> { - let mut acquired = 0; - - let needed = if queued { - node.state.load(Acquire) << Self::PERMIT_SHIFT + let still_needed = if queued { + node.state.load(Acquire) } else { - num_permits << Self::PERMIT_SHIFT - }; + total_needed + } << Self::PERMIT_SHIFT; - let mut lock = None; - // First, try to take the requested number of permits from the - // semaphore. - let mut curr = self.permits.load(Acquire); - let mut waiters = loop { + let mut available = self.permits.load(Acquire); + let mut acquired: usize; + + // We could acquire the lock right now, but that's going to be + // expensive. To optimize, we try once without holding the mutex + // and only lock it if really necessary. + let mut queue_guard = None; + + // CAS loop to acquire _some_ permits from the semaphore. Either the + // waiter gets all needed permits, or it consumes all available ones. + loop { // Has the semaphore closed? - if curr & Self::CLOSED > 0 { + if available & Self::CLOSED > 0 { return Poll::Ready(Err(AcquireError::closed())); } - let mut remaining = 0; - let total = curr - .checked_add(acquired) - .expect("number of permits must not overflow"); - let (next, acq) = if total >= needed { - let next = curr - (needed - acquired); - (next, needed >> Self::PERMIT_SHIFT) - } else { - remaining = (needed - acquired) - curr; - (0, curr >> Self::PERMIT_SHIFT) - }; - - if remaining > 0 && lock.is_none() { - // No permits were immediately available, so this permit will - // (probably) need to wait. We'll need to acquire a lock on the - // wait queue before continuing. We need to do this _before_ the - // CAS that sets the new value of the semaphore's `permits` - // counter. Otherwise, if we subtract the permits and then - // acquire the lock, we might miss additional permits being - // added while waiting for the lock. - lock = Some(self.waiters.lock()); + acquired = cmp::min(available, still_needed); + let remaining = available - acquired; + + // Not all permits were immediately available, so this waiter will + // (probably) need to wait. We'll need to acquire a lock on the wait + // queue before continuing. We need to do this _before_ the CAS that + // sets the new value of the semaphore's `permits` counter. + // Otherwise, if we subtract the permits and then acquire the lock, + // we might miss additional permits being added while waiting for + // the lock. + if acquired < still_needed && queue_guard.is_none() { + queue_guard = Some(self.waiters.lock()); } - match self.permits.compare_exchange(curr, next, AcqRel, Acquire) { + match self + .permits + .compare_exchange(available, remaining, AcqRel, Acquire) + { Ok(_) => { - acquired += acq; - if remaining == 0 { - if !queued { - #[cfg(all(tokio_unstable, feature = "tracing"))] - self.resource_span.in_scope(|| { - tracing::trace!( - target: "runtime::resource::state_update", - permits = acquired, - permits.op = "sub", - ); - tracing::trace!( - target: "runtime::resource::async_op::state_update", - permits_obtained = acquired, - permits.op = "add", - ) - }); - - return Poll::Ready(Ok(())); - } else if lock.is_none() { - break self.waiters.lock(); - } - } - break lock.expect("lock must be acquired before waiting"); + break; } - Err(actual) => curr = actual, + Err(actual) => available = actual, } - }; - - if waiters.closed { - return Poll::Ready(Err(AcquireError::closed())); } #[cfg(all(tokio_unstable, feature = "tracing"))] self.resource_span.in_scope(|| { tracing::trace!( target: "runtime::resource::state_update", - permits = acquired, + permits = (acquired >> Self::PERMIT_SHIFT), permits.op = "sub", - ) + ); }); + // If the waiter gets all the permits at the first try, don't bother + // enqueuing it. + if acquired == still_needed && !queued { + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::async_op::state_update", + permits_obtained = (acquired >> Self::PERMIT_SHIFT), + permits.op = "add", + ) + }); + return Poll::Ready(Ok(())); + } + + let mut queue_guard = queue_guard.unwrap_or_else(|| self.waiters.lock()); + if queue_guard.closed { + return Poll::Ready(Err(AcquireError::closed())); + } + + acquired >>= Self::PERMIT_SHIFT; + if node.assign_permits(&mut acquired) { - self.add_permits_locked(acquired, waiters); + // If the waiter is happy with the acquired permits, return the + // leftover to the semaphore and bail out. + self.add_permits_locked(acquired, queue_guard); return Poll::Ready(Ok(())); } + assert_eq!(acquired, 0, "acquired more permits than necessary"); - assert_eq!(acquired, 0); let mut old_waker = None; // Otherwise, register the waker & enqueue the node. node.waker.with_mut(|waker| { - // Safety: the wait list is locked, so we may modify the waker. + // SAFETY: The wait list is locked, so we may modify the waker. let waker = unsafe { &mut *waker }; // Do we need to register the new waker? if waker @@ -508,14 +509,16 @@ impl Semaphore { // If the waiter is not already in the wait queue, enqueue it. if !queued { + // SAFETY: An `Acquire` never moves the `node` field until dropped. + // This also implies that the pointer is non-null. let node = unsafe { let node = Pin::into_inner_unchecked(node) as *mut _; NonNull::new_unchecked(node) }; - - waiters.queue.push_front(node); + queue_guard.queue.push_front(node); } - drop(waiters); + + drop(queue_guard); drop(old_waker); Poll::Pending