fix: dedupe dead-lock on dropped requests (#2291)
Co-authored-by: Tushar Mathur <[email protected]>
meskill and tusharmath authored Jun 28, 2024
1 parent b84c8dc commit fadea96
Showing 2 changed files with 244 additions and 36 deletions.
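
In short, as reconstructed from the diff below: the cache previously stored a strong broadcast::Sender for every in-flight computation, so when the task that was supposed to produce the value was dropped (e.g. a cancelled request), other callers blocked on recv() forever. The fix keeps only a Weak sender in the cache and hands the single strong Arc to the executing task; once that task is dropped the channel closes, waiters receive an error, and they loop back to retry. A minimal standalone sketch of that mechanism, using plain tokio primitives rather than the crate's API:

use std::sync::{Arc, Weak};
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<u64>(16);
    let tx = Arc::new(tx);
    let weak: Weak<broadcast::Sender<u64>> = Arc::downgrade(&tx);

    // The executing task holds the only strong reference; dropping it here
    // stands in for that task being aborted before it could send a value.
    drop(tx);

    // The cache can no longer upgrade its Weak handle, and a waiter's recv()
    // returns Err(RecvError::Closed) instead of hanging forever.
    assert!(weak.upgrade().is_none());
    assert!(rx.recv().await.is_err());
}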
276 changes: 243 additions & 33 deletions src/core/data_loader/dedupe.rs
@@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, Weak};

use futures_util::Future;
use tokio::sync::broadcast;
@@ -11,21 +11,39 @@ impl<A: Send + Sync + Eq + Hash + Clone> Key for A {}
pub trait Value: Send + Sync + Clone {}
impl<A: Send + Sync + Clone> Value for A {}

///
/// Allows deduplication of async operations based on a key.
pub struct Dedupe<Key, Value> {
/// Cache storage for the operations.
cache: Arc<Mutex<HashMap<Key, State<Value>>>>,
/// Initial size of the multi-producer, multi-consumer channel.
size: usize,
/// When enabled allows the operations to be cached forever.
persist: bool,
}

/// Represents the current state of the operation.
enum State<Value> {
Value(Value),
Send(broadcast::Sender<Value>),
/// Means that the operation has been executed and the result is stored.
Ready(Value),

/// Means that the operation is in progress and the result can be sent via
/// the stored sender whenever it's available in the future.
Pending(Weak<broadcast::Sender<Value>>),
}

/// Represents the next step to take for the operation.
enum Step<Value> {
Value(Value),
Recv(broadcast::Receiver<Value>),
Send(broadcast::Sender<Value>),
/// The operation has been executed and the result must be returned.
Return(Value),

/// The operation is in progress and the result must be awaited on the
/// receiver.
Await(broadcast::Receiver<Value>),

/// The operation needs to be executed and the result needs to be sent to
/// the provided sender.
Init(Arc<broadcast::Sender<Value>>),
}

impl<K: Key, V: Value> Dedupe<K, V> {
@@ -38,36 +56,61 @@ impl<K: Key, V: Value> Dedupe<K, V> {
Fn: FnOnce() -> Fut,
Fut: Future<Output = V>,
{
match self.step(key) {
Step::Value(value) => value,
Step::Recv(mut rx) => rx.recv().await.unwrap(),
Step::Send(tx) => {
let value = or_else().await;
let mut guard = self.cache.lock().unwrap();
if self.persist {
guard.insert(key.to_owned(), State::Value(value.clone()));
} else {
guard.remove(key);
loop {
let value = match self.step(key) {
Step::Return(value) => value,
Step::Await(mut rx) => match rx.recv().await {
Ok(value) => value,
Err(_) => {
// An error here means the task that owned the tx (sender) was dropped,
// i.e. there is no result in the cache. Retry, since another task will
// likely take over the execution.
continue;
}
},
Step::Init(tx) => {
let value = or_else().await;
let mut guard = self.cache.lock().unwrap();
if self.persist {
guard.insert(key.to_owned(), State::Ready(value.clone()));
} else {
guard.remove(key);
}
let _ = tx.send(value.clone());
value
}
let _ = tx.send(value.clone());
value
}
};

return value;
}
}

fn step(&self, key: &K) -> Step<V> {
let mut this = self.cache.lock().unwrap();
match this.get(key) {
Some(state) => match state {
State::Value(value) => Step::Value(value.clone()),
State::Send(tx) => Step::Recv(tx.subscribe()),
},
None => {
let (tx, _) = broadcast::channel(self.size);
this.insert(key.to_owned(), State::Send(tx.clone()));
Step::Send(tx.clone())

if let Some(state) = this.get(key) {
match state {
State::Ready(value) => return Step::Return(value.clone()),
State::Pending(tx) => {
// We can upgrade from Weak to Arc only if the original tx is still
// alive; otherwise a new channel is created below.
if let Some(tx) = tx.upgrade() {
return Step::Await(tx.subscribe());
}
}
}
}

let (tx, _) = broadcast::channel(self.size);
let tx = Arc::new(tx);
// Store a Weak reference to tx in the cache and hand the actual tx to the
// caller, so we can tell whether tx is still alive and able to handle the request.
// Only a single strong reference to tx should exist, so its liveness tells us
// whether the executing task is still running and a response will arrive.
this.insert(key.to_owned(), State::Pending(Arc::downgrade(&tx)));
Step::Init(tx)
}
}

@@ -91,19 +134,21 @@ impl<K: Key, V: Value, E: Value> DedupeResult<K, V, E> {

#[cfg(test)]
mod tests {
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use pretty_assertions::assert_eq;
use tokio::join;
use tokio::time::sleep;
use tokio::time::{sleep, timeout_at, Instant};

use super::*;

#[tokio::test]
async fn test_no_key() {
let cache = Arc::new(Dedupe::<u64, u64>::new(1000, true));
let actual = cache.dedupe(&1, || Box::pin(async { 1 })).await;
pretty_assertions::assert_eq!(actual, 1);
assert_eq!(actual, 1);
}

#[tokio::test]
@@ -112,7 +157,7 @@ mod tests {
cache.dedupe(&1, || Box::pin(async { 1 })).await;

let actual = cache.dedupe(&1, || Box::pin(async { 2 })).await;
pretty_assertions::assert_eq!(actual, 1);
assert_eq!(actual, 1);
}

#[tokio::test]
@@ -124,7 +169,7 @@ mod tests {
}

let actual = cache.dedupe(&1, || Box::pin(async { 2 })).await;
pretty_assertions::assert_eq!(actual, 0);
assert_eq!(actual, 0);
}

#[tokio::test]
@@ -145,7 +190,7 @@ mod tests {
});
let (a, b) = join!(a, b);

pretty_assertions::assert_eq!(a, b);
assert_eq!(a, b);
}

async fn compute_value(counter: Arc<AtomicUsize>) -> String {
@@ -185,4 +230,169 @@
"compute_value was called more than once"
);
}

#[tokio::test]
async fn test_hanging_after_dropped() {
let cache = Arc::new(Dedupe::<u64, ()>::new(100, true));

let task = cache.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
});

// Drops the task, since the underlying sleep (100ms) is longer than the
// timeout below (10ms)

timeout_at(Instant::now() + Duration::from_millis(10), task)
.await
.expect_err("Should throw timeout error");

cache
.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
})
.await;
}

#[tokio::test]
async fn test_hanging_dropped_while_in_use() {
let cache = Arc::new(Dedupe::<u64, u64>::new(100, true));
let cache_1 = cache.clone();
let cache_2 = cache.clone();

let task_1 = tokio::spawn(async move {
cache_1
.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
100
})
.await
});

let task_2 = tokio::spawn(async move {
cache_2
.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
200
})
.await
});

sleep(Duration::from_millis(10)).await;

// drop the first task
task_1.abort();

let actual = task_2.await.unwrap();
assert_eq!(actual, 200)
}

// TODO: This is a failing test
#[tokio::test]
#[ignore]
async fn test_should_not_abort_call_1() {
#[derive(Debug, PartialEq, Clone)]
struct Status {
// Set this in the first call
call_1: bool,

// Set this in the second call
call_2: bool,
}

let status = Arc::new(Mutex::new(Status { call_1: false, call_2: false }));

let cache = Arc::new(Dedupe::<u64, ()>::new(100, true));
let cache_1 = cache.clone();
let cache_2 = cache.clone();
let status_1 = status.clone();
let status_2 = status.clone();

// Task 1 completed in 100ms
let task_1 = tokio::spawn(async move {
cache_1
.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
status_1.lock().unwrap().call_1 = true;
})
.await
});

// Wait for 10ms
sleep(Duration::from_millis(10)).await;

// Task 2 completed in 120ms
tokio::spawn(async move {
cache_2
.dedupe(&1, move || async move {
sleep(Duration::from_millis(120)).await;
status_2.lock().unwrap().call_2 = true;
})
.await
});

// Wait for 10ms
sleep(Duration::from_millis(10)).await;

// Abort the task_1
task_1.abort();

sleep(Duration::from_millis(300)).await;

// Task 1 should still have completed because others are dependent on it.
let actual = status.lock().unwrap().deref().to_owned();
assert_eq!(actual, Status { call_1: true, call_2: false })
}

#[tokio::test]
async fn test_should_abort_all() {
#[derive(Debug, PartialEq, Clone)]
struct Status {
// Set this in the first call
call_1: bool,

// Set this in the second call
call_2: bool,
}

let status = Arc::new(Mutex::new(Status { call_1: false, call_2: false }));

let cache = Arc::new(Dedupe::<u64, ()>::new(100, true));
let cache_1 = cache.clone();
let cache_2 = cache.clone();
let status_1 = status.clone();
let status_2 = status.clone();

// Task 1 completed in 100ms
let task_1 = tokio::spawn(async move {
cache_1
.dedupe(&1, move || async move {
sleep(Duration::from_millis(100)).await;
status_1.lock().unwrap().call_1 = true;
})
.await
});

// Task 2 completed in 150ms
let task_2 = tokio::spawn(async move {
cache_2
.dedupe(&1, move || async move {
sleep(Duration::from_millis(150)).await;
status_2.lock().unwrap().call_2 = true;
})
.await
});

// Wait for 50ms
sleep(Duration::from_millis(50)).await;

// Abort the task_1 & task_2
task_1.abort();
task_2.abort();

sleep(Duration::from_millis(300)).await;

// No task should have completed
let actual = status.lock().unwrap().deref().to_owned();
assert_eq!(actual, Status { call_1: false, call_2: false })
}
}
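
A hedged usage sketch, based only on the calls visible in the tests above (Dedupe::new(size, persist) and dedupe(&key, factory)); it assumes Dedupe from src/core/data_loader/dedupe.rs is importable and that a tokio runtime is available:

use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;

// Assumes `Dedupe` from src/core/data_loader/dedupe.rs is in scope.
#[tokio::main]
async fn main() {
    let cache = Arc::new(Dedupe::<u64, u64>::new(100, true));
    let (c1, c2) = (cache.clone(), cache.clone());

    // Two tasks request the same key: only one factory actually runs,
    // and both callers observe the same value.
    let t1 = tokio::spawn(async move {
        c1.dedupe(&1, || async {
            sleep(Duration::from_millis(50)).await;
            42
        })
        .await
    });
    let t2 = tokio::spawn(async move { c2.dedupe(&1, || async { 7 }).await });

    assert_eq!(t1.await.unwrap(), t2.await.unwrap());
}

With this change, aborting the task that currently owns the sender (as exercised in the tests above) makes the remaining waiter retry instead of hanging.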
4 changes: 1 addition & 3 deletions src/core/mod.rs
@@ -59,9 +59,7 @@ pub trait HttpIO: Sync + Send + 'static {
async fn execute(
&self,
request: reqwest::Request,
) -> anyhow::Result<Response<hyper::body::Bytes>> {
self.execute(request).await
}
) -> anyhow::Result<Response<hyper::body::Bytes>>;
}
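
The removed default body called self.execute(request).await from inside execute itself, i.e. it recursed instead of performing a request; with the default gone, every implementor must supply a real body. A minimal hypothetical stub (NoopHttp is illustrative only; the trait and types are the ones declared in this file, and the impl assumes HttpIO is an #[async_trait] trait, as elsewhere in this module):

// Hypothetical implementor; a real one would perform the request (e.g. via
// a reqwest::Client) and convert the reply into Response<hyper::body::Bytes>.
struct NoopHttp;

#[async_trait::async_trait]
impl HttpIO for NoopHttp {
    async fn execute(
        &self,
        _request: reqwest::Request,
    ) -> anyhow::Result<Response<hyper::body::Bytes>> {
        // This stub only shows the now-mandatory method body.
        anyhow::bail!("NoopHttp does not execute requests")
    }
}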

