Skip to content

Commit

Permalink
Merge pull request #34 from superfly/watch-to-sub
Browse files Browse the repository at this point in the history
Rename watch -> subscription
  • Loading branch information
jeromegn authored Aug 24, 2023
2 parents 7635917 + 961f009 commit f9c83ad
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 107 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

81 changes: 51 additions & 30 deletions crates/corro-agent/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
api::{
client::{api_v1_db_schema, api_v1_queries, api_v1_transactions},
peer::{bidirectional_sync, gossip_client_endpoint, gossip_server_endpoint, SyncError},
pubsub::{api_v1_watch_by_id, api_v1_watches, MatcherCache},
pubsub::{api_v1_sub_by_id, api_v1_subs, MatcherCache},
},
broadcast::runtime_loop,
transport::{ConnectError, Transport},
Expand Down Expand Up @@ -418,6 +418,13 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> {
}
}
}

if buf.capacity() >= 64 * 1024 {
info!(
"big buf from processing unidirectional stream: {}",
buf.capacity()
);
}
}
});
}
Expand Down Expand Up @@ -459,35 +466,48 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> {
return;
}
Some(res) => match res {
Ok(b) => match BiPayload::read_from_buffer(&b) {
Ok(payload) => {
match payload {
BiPayload::V1(
BiPayloadV1::SyncState(state),
) => {
// println!("got sync state: {state:?}");
if let Err(e) = bidirectional_sync(
&agent,
generate_sync(
agent.bookie(),
agent.actor_id(),
),
Some(state),
framed.into_inner(),
tx,
)
.await
{
warn!("could not complete bidirectional sync: {e}");
Ok(b) => {
if b.capacity() >= 64 * 1024 {
info!(
"big buf from processing bidirectional stream: {}",
b.capacity()
);
}
match BiPayload::read_from_buffer(&b) {
Ok(payload) => {
match payload {
BiPayload::V1(
BiPayloadV1::SyncState(state),
) => {
// println!("got sync state: {state:?}");
if let Err(e) =
bidirectional_sync(
&agent,
generate_sync(
agent.bookie(),
agent.actor_id(),
),
Some(state),
framed.into_inner(),
tx,
)
.await
{
warn!("could not complete bidirectional sync: {e}");
}
break;
}
break;
}
}

Err(e) => {
warn!(
"could not decode BiPayload: {e}"
);
}
}
Err(e) => {
warn!("could not decode BiPayload: {e}");
}
},
}

Err(e) => {
error!("could not read framed payload from bidirectional stream: {e}");
}
Expand Down Expand Up @@ -722,8 +742,8 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> {
),
)
.route(
"/v1/watches",
post(api_v1_watches).route_layer(
"/v1/subscriptions",
post(api_v1_subs).route_layer(
tower::ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_error: BoxError| async {
Ok::<_, Infallible>((
Expand All @@ -736,8 +756,8 @@ pub async fn run(agent: Agent, opts: AgentOptions) -> eyre::Result<()> {
),
)
.route(
"/v1/watches/:id",
get(api_v1_watch_by_id).route_layer(
"/v1/subscriptions/:id",
get(api_v1_sub_by_id).route_layer(
tower::ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_error: BoxError| async {
Ok::<_, Infallible>((
Expand Down Expand Up @@ -999,6 +1019,7 @@ async fn handle_gossip_to_send(transport: Transport, mut to_send_rx: Receiver<(A
});

increment_counter!("corro.gossip.send.count", "actor_id" => actor.id().to_string());
gauge!("corro.gossip.send.buffer.capacity", buf.capacity() as f64);
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/corro-agent/src/api/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,13 @@ pub async fn bidirectional_sync(

debug!(actor_id = %agent.actor_id(), "done writing sync messages (count: {count})");

if buf.capacity() >= 64 * 1024 {
info!(
"big buffer from bidirectional sync sender: {}",
buf.capacity()
);
}

Ok::<_, SyncError>(count)
},
async move {
Expand Down
48 changes: 25 additions & 23 deletions crates/corro-agent/src/api/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use uuid::Uuid;

pub async fn api_v1_watch_by_id(
pub async fn api_v1_sub_by_id(
Extension(agent): Extension<Agent>,
axum::extract::Path(id): axum::extract::Path<Uuid>,
) -> impl IntoResponse {
watch_by_id(agent, id).await
sub_by_id(agent, id).await
}

async fn watch_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {
async fn sub_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {
let matcher = match { agent.matchers().read().get(&id).cloned() } {
Some(matcher) => matcher,
None => {
return hyper::Response::builder()
.status(StatusCode::NOT_FOUND)
.body(
serde_json::to_vec(&QueryEvent::Error(format_compact!(
"could not find watcher with id {id}"
"could not find subscription with id {id}"
)))
.expect("could not serialize queries stream error")
.into(),
Expand All @@ -56,7 +56,7 @@ async fn watch_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {
let change_rx = matcher.subscribe();
let cancel = matcher.cancel();

tokio::spawn(process_watch_channel(
tokio::spawn(process_sub_channel(
agent.clone(),
id,
tx,
Expand Down Expand Up @@ -143,7 +143,7 @@ async fn watch_by_id(agent: Agent, id: Uuid) -> hyper::Response<hyper::Body> {
.expect("could not build query response body")
}

async fn process_watch_channel(
async fn process_sub_channel(
agent: Agent,
matcher_id: Uuid,
mut tx: hyper::body::Sender,
Expand Down Expand Up @@ -229,15 +229,17 @@ async fn process_watch_channel(

match recv.try_recv() {
Ok(new) => query_evt = new,
Err(e) => match e {
TryRecvError::Empty => break,
TryRecvError::Closed => break,

TryRecvError::Lagged(lagged) => {
error!("change recv lagged by {lagged}, stopping watch processing");
return;
Err(e) => {
match e {
TryRecvError::Empty => break,
TryRecvError::Closed => break,

TryRecvError::Lagged(lagged) => {
error!("change recv lagged by {lagged}, stopping subscription processing");
return;
}
}
},
}
}
}
}
Expand Down Expand Up @@ -361,9 +363,9 @@ async fn expand_sql(

pub type MatcherCache = Arc<TokioRwLock<HashMap<String, Uuid>>>;

pub async fn api_v1_watches(
pub async fn api_v1_subs(
Extension(agent): Extension<Agent>,
Extension(watch_cache): Extension<MatcherCache>,
Extension(subscription_cache): Extension<MatcherCache>,
axum::extract::Json(stmt): axum::extract::Json<Statement>,
) -> impl IntoResponse {
let stmt = match expand_sql(&agent, &stmt).await {
Expand All @@ -380,15 +382,15 @@ pub async fn api_v1_watches(
}
};

let matcher_id = { watch_cache.read().await.get(&stmt).cloned() };
let matcher_id = { subscription_cache.read().await.get(&stmt).cloned() };

if let Some(matcher_id) = matcher_id {
let contains = { agent.matchers().read().contains_key(&matcher_id) };
if contains {
info!("reusing matcher id {matcher_id}");
return watch_by_id(agent, matcher_id).await;
return sub_by_id(agent, matcher_id).await;
} else {
watch_cache.write().await.remove(&stmt);
subscription_cache.write().await.remove(&stmt);
}
}

Expand Down Expand Up @@ -441,10 +443,10 @@ pub async fn api_v1_watches(

{
agent.matchers().write().insert(matcher_id, matcher.clone());
watch_cache.write().await.insert(stmt, matcher_id);
subscription_cache.write().await.insert(stmt, matcher_id);
}

tokio::spawn(process_watch_channel(
tokio::spawn(process_sub_channel(
agent.clone(),
matcher_id,
tx,
Expand Down Expand Up @@ -477,7 +479,7 @@ mod tests {
use super::*;

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_api_v1_watches() -> eyre::Result<()> {
async fn test_api_v1_subs() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();

let (tripwire, _tripwire_worker, _tripwire_tx) = Tripwire::new_simple();
Expand Down Expand Up @@ -548,7 +550,7 @@ mod tests {

assert!(body.0.results.len() == 2);

let res = api_v1_watches(
let res = api_v1_subs(
Extension(agent.clone()),
Extension(Default::default()),
axum::Json(Statement::Simple("select * from tests".into())),
Expand Down
8 changes: 8 additions & 0 deletions crates/corro-agent/src/broadcast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,14 @@ pub fn runtime_loop(
bcast_tx.capacity() as f64
);
gauge!("corro.broadcast.pending.count", idle_pendings.len() as f64);
gauge!(
"corro.broadcast.buffer.capacity",
bcast_buf.capacity() as f64
);
gauge!(
"corro.broadcast.serialization.buffer.capacity",
ser_buf.capacity() as f64
);
}
}

Expand Down
16 changes: 8 additions & 8 deletions crates/corro-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl CorrosionApiClient {
Ok(res.into_body())
}

pub async fn watch(
pub async fn subscribe(
&self,
statement: &Statement,
) -> Result<
Expand All @@ -79,7 +79,7 @@ impl CorrosionApiClient {
> {
let req = hyper::Request::builder()
.method(hyper::Method::POST)
.uri(format!("http://{}/v1/watches", self.api_addr))
.uri(format!("http://{}/v1/subscriptions", self.api_addr))
.header(hyper::header::CONTENT_TYPE, "application/json")
.header(hyper::header::ACCEPT, "application/json")
.body(Body::from(serde_json::to_vec(statement)?))?;
Expand All @@ -97,16 +97,16 @@ impl CorrosionApiClient {
.and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
.ok_or(Error::ExpectedQueryId)?;

Ok((id, watch_stream(res.into_body())))
Ok((id, subscription_stream(res.into_body())))
}

pub async fn watched_query(
pub async fn subscription(
&self,
id: Uuid,
) -> Result<impl Stream<Item = io::Result<QueryEvent>> + Send + Sync + 'static, Error> {
let req = hyper::Request::builder()
.method(hyper::Method::GET)
.uri(format!("http://{}/v1/watches/{}", self.api_addr, id))
.uri(format!("http://{}/v1/subscriptions/{}", self.api_addr, id))
.header(hyper::header::ACCEPT, "application/json")
.body(hyper::Body::empty())?;

Expand All @@ -116,7 +116,7 @@ impl CorrosionApiClient {
return Err(Error::UnexpectedStatusCode(res.status()));
}

Ok(watch_stream(res.into_body()))
Ok(subscription_stream(res.into_body()))
}

pub async fn execute(&self, statements: &[Statement]) -> Result<RqliteResponse, Error> {
Expand Down Expand Up @@ -240,7 +240,7 @@ impl CorrosionApiClient {
}
}

fn watch_stream(body: hyper::Body) -> impl Stream<Item = io::Result<QueryEvent>> {
fn subscription_stream(body: hyper::Body) -> impl Stream<Item = io::Result<QueryEvent>> {
let body = StreamReader::new(body.map_err(|e| {
if let Some(io_error) = e
.source()
Expand Down Expand Up @@ -426,6 +426,6 @@ pub enum Error {
#[error("unexpected result: {0:?}")]
UnexpectedResult(RqliteResult),

#[error("could not retrieve watch id from headers")]
#[error("could not retrieve subscription id from headers")]
ExpectedQueryId,
}
Loading

0 comments on commit f9c83ad

Please sign in to comment.