diff --git a/src/aggregator/mod.rs b/src/aggregator/mod.rs
index f8ffbfc..55523f0 100644
--- a/src/aggregator/mod.rs
+++ b/src/aggregator/mod.rs
@@ -7,10 +7,7 @@ mod spot;
 
 use crate::aggregator::spot::check_spot_termination_status;
 use crate::epoch::EpochConfig;
-use crate::models::{
-  begin_db_transaction, commit_db_transaction, DBConnectionType, DBPool, DBStorageConnections,
-  PgStoreError,
-};
+use crate::models::{DBConnectionType, DBPool, DBStorageConnections, PgStoreError};
 use crate::profiler::{Profiler, ProfilerStat};
 use crate::record_stream::{
   get_data_channel_topic_from_env, KafkaRecordStream, KafkaRecordStreamConfig, RecordStream,
@@ -214,21 +211,15 @@ pub async fn start_aggregation(
   let mut out_stream = create_output_stream(output_measurements_to_stdout, channel_name)?;
   if let Some(out_stream) = out_stream.as_mut() {
     out_stream.init_producer_queues().await;
-    out_stream.begin_producer_transaction()?;
   }
   let db_conn = Arc::new(Mutex::new(db_pool.get().await?));
-  begin_db_transaction(db_conn.clone())?;
   process_expired_epochs(
     db_conn.clone(),
     &epoch_config,
-    out_stream.as_ref().map(|v| v.as_ref()),
+    out_stream.as_ref(),
     profiler.clone(),
   )
   .await?;
-  if let Some(out_stream) = out_stream.as_ref() {
-    wait_and_commit_producer(out_stream).await?;
-  }
-  commit_db_transaction(db_conn)?;
 
   info!("Profiler summary:\n{}", profiler.summary().await);
   info!("Finished aggregation");
diff --git a/src/aggregator/processing.rs b/src/aggregator/processing.rs
index 1bcc32c..34b7eb2 100644
--- a/src/aggregator/processing.rs
+++ b/src/aggregator/processing.rs
@@ -2,13 +2,14 @@ use super::group::{GroupedMessages, MessageChunk};
 use super::recovered::RecoveredMessages;
 use super::report::report_measurements;
 use super::AggregatorError;
+use crate::aggregator::wait_and_commit_producer;
 use crate::epoch::EpochConfig;
 use crate::models::{
-  DBConnection, DBPool, DBStorageConnections, MessageWithThreshold, PendingMessage,
-  RecoveredMessage,
+  begin_db_transaction, commit_db_transaction, DBConnection, DBPool, DBStorageConnections,
+  MessageWithThreshold, PendingMessage, RecoveredMessage,
 };
 use crate::profiler::{Profiler, ProfilerStat};
-use crate::record_stream::{DynRecordStream, RecordStreamArc};
+use crate::record_stream::RecordStreamArc;
 use crate::star::{recover_key, recover_msgs, AppSTARError, MsgRecoveryInfo};
 use star_constellation::api::NestedMessage;
 use star_constellation::Error as ConstellationError;
@@ -20,7 +21,7 @@ use tokio::task::JoinHandle;
 pub async fn process_expired_epochs(
   conn: Arc<Mutex<DBConnection>>,
   epoch_config: &EpochConfig,
-  out_stream: Option<&DynRecordStream>,
+  out_stream: Option<&RecordStreamArc>,
   profiler: Arc<Profiler>,
 ) -> Result<(), AggregatorError> {
   let epochs = RecoveredMessage::list_distinct_epochs(conn.clone()).await?;
@@ -29,6 +30,11 @@
       continue;
     }
     info!("Detected expired epoch '{}', processing...", epoch);
+    if let Some(out_stream) = out_stream.as_ref() {
+      out_stream.begin_producer_transaction()?;
+    }
+    begin_db_transaction(conn.clone())?;
+
     let mut rec_msgs = RecoveredMessages::default();
     rec_msgs
       .fetch_all_recovered_with_nonzero_count(conn.clone(), epoch as u8, profiler.clone())
@@ -39,12 +45,17 @@
       epoch_config,
       epoch as u8,
       true,
-      out_stream,
+      out_stream.as_ref().map(|v| v.as_ref()),
      profiler.clone(),
     )
     .await?;
     RecoveredMessage::delete_epoch(conn.clone(), epoch, profiler.clone()).await?;
     PendingMessage::delete_epoch(conn.clone(), epoch, profiler.clone()).await?;
+
+    if let Some(out_stream) = out_stream.as_ref() {
+      wait_and_commit_producer(out_stream).await?;
+    }
+    commit_db_transaction(conn.clone())?;
   }
   Ok(())
 }
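The net effect of the two files' hunks is that the Kafka producer transaction and the DB transaction are no longer opened once for the whole aggregation run in `start_aggregation`; instead, `process_expired_epochs` opens and commits both per expired epoch, so each epoch's reported measurements and row deletions commit atomically and independently. Below is a minimal standalone sketch of that per-epoch scoping pattern; `MockProducer` and `MockDb` are hypothetical stand-ins for the real record stream and DB connection (the crate's actual types are async and use `AggregatorError`), and only the begin/commit ordering mirrors the diff:

```rust
// Sketch of per-epoch transaction scoping. MockProducer/MockDb are
// simplified stand-ins, not the crate's real API.
struct MockProducer;
struct MockDb;

impl MockProducer {
    fn begin_producer_transaction(&self) -> Result<(), String> { Ok(()) }
    fn wait_and_commit(&self) -> Result<(), String> { Ok(()) }
}

impl MockDb {
    fn begin_transaction(&self) -> Result<(), String> { Ok(()) }
    fn commit_transaction(&self) -> Result<(), String> { Ok(()) }
}

fn process_expired_epochs(
    db: &MockDb,
    producer: Option<&MockProducer>,
    expired_epochs: &[u8],
) -> Result<(), String> {
    for epoch in expired_epochs {
        // Open both transactions at the start of each epoch...
        if let Some(p) = producer {
            p.begin_producer_transaction()?;
        }
        db.begin_transaction()?;

        println!("processing expired epoch {epoch}");
        // ... report measurements, delete the epoch's rows ...

        // ...and commit both at the end, producer first (mirroring the
        // diff), so a failed Kafka commit returns early and leaves the
        // DB transaction uncommitted.
        if let Some(p) = producer {
            p.wait_and_commit()?;
        }
        db.commit_transaction()?;
    }
    Ok(())
}

fn main() -> Result<(), String> {
    process_expired_epochs(&MockDb, Some(&MockProducer), &[1, 2, 3])
}
```

One consequence of this scoping: a failure while processing one expired epoch only rolls back that epoch's output and deletions, rather than discarding the work done for every epoch in the run.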