Skip to content

Commit

Permalink
sync: refactor Semaphore::poll_acquire
Browse files Browse the repository at this point in the history
This commit tries to make the logic in `poll_acquire` flow more naturally,
while also using more descriptive names and improving commentary (including
the addition of a missing "SAFETY:" annotation).

There should be no functional changes.
  • Loading branch information
cip999 committed Dec 19, 2024
1 parent 10e23d1 commit 74df9e7
Showing 1 changed file with 75 additions and 71 deletions.
146 changes: 75 additions & 71 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,108 +394,109 @@ impl Semaphore {
}
}

/// Does the actual work of polling an `Acquire` waiting for permits.
/// Tries to acquire as many permits as needed, and no more than the
/// available ones.
///
/// The returned status is `Ready` if all necessary permits have been
/// acquired, or otherwise `Pending`.
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: usize,
total_needed: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
let mut acquired = 0;

let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
let still_needed = if queued {
node.state.load(Acquire)
} else {
num_permits << Self::PERMIT_SHIFT
};
total_needed
} << Self::PERMIT_SHIFT;

let mut lock = None;
// First, try to take the requested number of permits from the
// semaphore.
let mut curr = self.permits.load(Acquire);
let mut waiters = loop {
let mut available = self.permits.load(Acquire);
let mut acquired: usize;

// We could acquire the lock right now, but that's going to be
// expensive. To optimize, we try once without holding the mutex
// and only lock it if really necessary.
let mut queue_guard = None;

// CAS loop to acquire _some_ permits from the semaphore. Either the
// waiter gets all needed permits, or it consumes all available ones.
loop {
// Has the semaphore closed?
if curr & Self::CLOSED > 0 {
if available & Self::CLOSED > 0 {
return Poll::Ready(Err(AcquireError::closed()));
}

let mut remaining = 0;
let total = curr
.checked_add(acquired)
.expect("number of permits must not overflow");
let (next, acq) = if total >= needed {
let next = curr - (needed - acquired);
(next, needed >> Self::PERMIT_SHIFT)
} else {
remaining = (needed - acquired) - curr;
(0, curr >> Self::PERMIT_SHIFT)
};

if remaining > 0 && lock.is_none() {
// No permits were immediately available, so this permit will
// (probably) need to wait. We'll need to acquire a lock on the
// wait queue before continuing. We need to do this _before_ the
// CAS that sets the new value of the semaphore's `permits`
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
lock = Some(self.waiters.lock());
acquired = cmp::min(available, still_needed);
let remaining = available - acquired;

// Not all permits were immediately available, so this waiter will
// (probably) need to wait. We'll need to acquire a lock on the wait
// queue before continuing. We need to do this _before_ the CAS that
// sets the new value of the semaphore's `permits` counter.
// Otherwise, if we subtract the permits and then acquire the lock,
// we might miss additional permits being added while waiting for
// the lock.
if acquired < still_needed && queue_guard.is_none() {
queue_guard = Some(self.waiters.lock());
}

match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
match self
.permits
.compare_exchange(available, remaining, AcqRel, Acquire)
{
Ok(_) => {
acquired += acq;
if remaining == 0 {
if !queued {
#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::state_update",
permits = acquired,
permits.op = "sub",
);
tracing::trace!(
target: "runtime::resource::async_op::state_update",
permits_obtained = acquired,
permits.op = "add",
)
});

return Poll::Ready(Ok(()));
} else if lock.is_none() {
break self.waiters.lock();
}
}
break lock.expect("lock must be acquired before waiting");
break;
}
Err(actual) => curr = actual,
Err(actual) => available = actual,
}
};

if waiters.closed {
return Poll::Ready(Err(AcquireError::closed()));
}

#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::state_update",
permits = acquired,
permits = (acquired >> Self::PERMIT_SHIFT),
permits.op = "sub",
)
);
});

// If the waiter gets all the permits at the first try, don't bother
// enqueuing it.
if acquired == still_needed && !queued {
#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::async_op::state_update",
permits_obtained = (acquired >> Self::PERMIT_SHIFT),
permits.op = "add",
)
});
return Poll::Ready(Ok(()));
}

let mut queue_guard = queue_guard.unwrap_or_else(|| self.waiters.lock());
if queue_guard.closed {
return Poll::Ready(Err(AcquireError::closed()));
}

acquired >>= Self::PERMIT_SHIFT;

if node.assign_permits(&mut acquired) {
self.add_permits_locked(acquired, waiters);
// If the waiter is happy with the acquired permits, return the
// leftover to the semaphore and bail out.
self.add_permits_locked(acquired, queue_guard);
return Poll::Ready(Ok(()));
}
assert_eq!(acquired, 0, "acquired more permits than necessary");

assert_eq!(acquired, 0);
let mut old_waker = None;

// Otherwise, register the waker & enqueue the node.
node.waker.with_mut(|waker| {
// Safety: the wait list is locked, so we may modify the waker.
// SAFETY: The wait list is locked, so we may modify the waker.
let waker = unsafe { &mut *waker };
// Do we need to register the new waker?
if waker
Expand All @@ -508,14 +509,17 @@ impl Semaphore {

// If the waiter is not already in the wait queue, enqueue it.
if !queued {
// SAFETY: Upholding the contract of the (unsafe) `Link` trait means
// that the pointee of `node` will not be moved by the linked list.
// Moreover, `node` cannot be null because it comes from a `Pin`.
let node = unsafe {
let node = Pin::into_inner_unchecked(node) as *mut _;
NonNull::new_unchecked(node)
};

waiters.queue.push_front(node);
queue_guard.queue.push_front(node);
}
drop(waiters);

drop(queue_guard);
drop(old_waker);

Poll::Pending
Expand Down

0 comments on commit 74df9e7

Please sign in to comment.