Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sync::batch_semaphore::Semaphore::poll_acquire #7046

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 75 additions & 71 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,108 +394,109 @@ impl Semaphore {
}
}

/// Does the actual work of polling an `Acquire` waiting for permits.
/// Tries to acquire as many permits as needed, and no more than the
/// available ones.
///
/// The returned status is `Ready` if all necessary permits have been
/// acquired, or otherwise `Pending`.
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: usize,
total_needed: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
let mut acquired = 0;

let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
let still_needed = if queued {
node.state.load(Acquire)
} else {
num_permits << Self::PERMIT_SHIFT
};
total_needed
} << Self::PERMIT_SHIFT;

let mut lock = None;
// First, try to take the requested number of permits from the
// semaphore.
let mut curr = self.permits.load(Acquire);
let mut waiters = loop {
let mut available = self.permits.load(Acquire);
let mut acquired: usize;

// We could acquire the lock right now, but that's going to be
// expensive. To optimize, we try once without holding the mutex
// and only lock it if really necessary.
let mut queue_guard = None;

// CAS loop to acquire _some_ permits from the semaphore. Either the
// waiter gets all needed permits, or it consumes all available ones.
loop {
// Has the semaphore closed?
if curr & Self::CLOSED > 0 {
if available & Self::CLOSED > 0 {
return Poll::Ready(Err(AcquireError::closed()));
}

let mut remaining = 0;
let total = curr
.checked_add(acquired)
.expect("number of permits must not overflow");
let (next, acq) = if total >= needed {
let next = curr - (needed - acquired);
(next, needed >> Self::PERMIT_SHIFT)
} else {
remaining = (needed - acquired) - curr;
(0, curr >> Self::PERMIT_SHIFT)
};

if remaining > 0 && lock.is_none() {
// No permits were immediately available, so this permit will
// (probably) need to wait. We'll need to acquire a lock on the
// wait queue before continuing. We need to do this _before_ the
// CAS that sets the new value of the semaphore's `permits`
// counter. Otherwise, if we subtract the permits and then
// acquire the lock, we might miss additional permits being
// added while waiting for the lock.
lock = Some(self.waiters.lock());
acquired = cmp::min(available, still_needed);
let remaining = available - acquired;

// Not all permits were immediately available, so this waiter will
// (probably) need to wait. We'll need to acquire a lock on the wait
// queue before continuing. We need to do this _before_ the CAS that
// sets the new value of the semaphore's `permits` counter.
// Otherwise, if we subtract the permits and then acquire the lock,
// we might miss additional permits being added while waiting for
// the lock.
if acquired < still_needed && queue_guard.is_none() {
queue_guard = Some(self.waiters.lock());
}

match self.permits.compare_exchange(curr, next, AcqRel, Acquire) {
match self
.permits
.compare_exchange(available, remaining, AcqRel, Acquire)
{
Ok(_) => {
acquired += acq;
if remaining == 0 {
if !queued {
#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::state_update",
permits = acquired,
permits.op = "sub",
);
tracing::trace!(
target: "runtime::resource::async_op::state_update",
permits_obtained = acquired,
permits.op = "add",
)
});

return Poll::Ready(Ok(()));
} else if lock.is_none() {
break self.waiters.lock();
}
}
break lock.expect("lock must be acquired before waiting");
break;
}
Err(actual) => curr = actual,
Err(actual) => available = actual,
}
};

if waiters.closed {
return Poll::Ready(Err(AcquireError::closed()));
}

#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::state_update",
permits = acquired,
permits = (acquired >> Self::PERMIT_SHIFT),
permits.op = "sub",
)
);
});

// If the waiter gets all the permits at the first try, don't bother
// enqueuing it.
if acquired == still_needed && !queued {
#[cfg(all(tokio_unstable, feature = "tracing"))]
self.resource_span.in_scope(|| {
tracing::trace!(
target: "runtime::resource::async_op::state_update",
permits_obtained = (acquired >> Self::PERMIT_SHIFT),
permits.op = "add",
)
});
return Poll::Ready(Ok(()));
}

let mut queue_guard = queue_guard.unwrap_or_else(|| self.waiters.lock());
if queue_guard.closed {
return Poll::Ready(Err(AcquireError::closed()));
}

acquired >>= Self::PERMIT_SHIFT;

if node.assign_permits(&mut acquired) {
self.add_permits_locked(acquired, waiters);
// If the waiter is happy with the acquired permits, return the
// leftover to the semaphore and bail out.
self.add_permits_locked(acquired, queue_guard);
return Poll::Ready(Ok(()));
}
assert_eq!(acquired, 0, "acquired more permits than necessary");

assert_eq!(acquired, 0);
let mut old_waker = None;

// Otherwise, register the waker & enqueue the node.
node.waker.with_mut(|waker| {
// Safety: the wait list is locked, so we may modify the waker.
// SAFETY: The wait list is locked, so we may modify the waker.
let waker = unsafe { &mut *waker };
// Do we need to register the new waker?
if waker
Expand All @@ -508,14 +509,17 @@ impl Semaphore {

// If the waiter is not already in the wait queue, enqueue it.
if !queued {
// SAFETY: Upholding the contract of the (unsafe) `Link` trait means
// that the pointee of `node` will not be moved by the linked list.
// Moreover, `node` cannot be null because it comes from a `Pin`.
let node = unsafe {
let node = Pin::into_inner_unchecked(node) as *mut _;
NonNull::new_unchecked(node)
};

waiters.queue.push_front(node);
queue_guard.queue.push_front(node);
}
drop(waiters);

drop(queue_guard);
drop(old_waker);

Poll::Pending
Expand Down
Loading