From 7b550aab4fe0bb7d4fc7ef740d41445c812e5a7e Mon Sep 17 00:00:00 2001
From: WarrenZhu050413
Date: Mon, 19 May 2025 11:33:56 +0800
Subject: [PATCH 1/2] Added proactive heartbeat timeout failure propagation
 (#164) (#188)

---
 Cargo.toml                 |   2 +
 README.md                  |  32 +++
 proto/torchft.proto        |  14 ++
 src/lib.rs                 | 101 +++++++-
 src/lighthouse.rs          | 464 +++++++++++++++++++++++++++++++++++--
 src/manager.rs             |  15 +-
 torchft/_torchft.pyi       |  16 +-
 torchft/data.py            |  10 +-
 torchft/data_test.py       |   2 +-
 torchft/lighthouse_test.py |  69 ++++++
 torchft/manager.py         | 209 ++++++++++++++++-
 torchft/manager_test.py    | 148 +++++++++++-
 train_ddp.py               |   2 +-
 train_ddp_proactive.py     | 218 +++++++++++++++++
 14 files changed, 1254 insertions(+), 48 deletions(-)
 create mode 100644 train_ddp_proactive.py

diff --git a/Cargo.toml b/Cargo.toml
index 0c6ae6e9..ec90c111 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,9 @@ slog-stdlog = "4.1.1"
 stderrlog = "0.6.0"
 structopt = "0.3.26"
 tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
+tokio-stream = {version = "0.1.14", features = ["sync"]}
 tonic = "0.12.2"
+futures-core = "0.3"

 [build-dependencies]
 tonic-build = "0.12.2"

diff --git a/README.md b/README.md
index cb07b47c..ff3cfeab 100644
--- a/README.md
+++ b/README.md
@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast
 By observing the outputs from both shells, you should observe process group
 reconfiguration and live checkpoint recovery.

+### Proactive Failure Recovery Mode (Experimental)
+
+You can experiment with proactive failure recovery mode by setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=1
+```
+
+With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce.
+
+You can test this out by running `train_ddp_proactive.py`.
+
+On shell 1 (one replica group starts initial training):
+```sh
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+On shell 2 (a second replica group joins):
+```sh
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
+```
+
+You should observe that the process with replica group id 1 will exit early, and the process with replica group id 0 will quickly resume training. If the same script is run again after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, you should observe that the process with replica group id 1 will hang for dozens of seconds before continuing.
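+
+The failure stream that drives proactive recovery is also exposed through the Python bindings added in this patch (`LighthouseClient.subscribe_failures` and `LighthouseServer.inject_failure`), and the manager can opt in programmatically via `Manager(..., proactive_recovery=True)` instead of the environment variable. Below is a minimal sketch of the subscription API, mirroring the unit tests; the replica id is a placeholder:
+
+```py
+from datetime import timedelta
+
+from torchft._torchft import LighthouseClient, LighthouseServer
+
+# Stand up a local lighthouse and subscribe to its failure notifications.
+lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)
+client = LighthouseClient(
+    addr=lighthouse.address(),
+    connect_timeout=timedelta(seconds=1),
+)
+stream = client.subscribe_failures(timeout=timedelta(seconds=5))
+
+# Simulate a heartbeat-timeout failure; during training this notification is
+# broadcast by the lighthouse's failure tick when a replica stops heartbeating.
+lighthouse.inject_failure("replica_0")
+
+# Blocks until a notification arrives or the subscribe timeout fires.
+note = next(stream)
+print(f"received failure notification for {note.replica_id}")
+
+lighthouse.shutdown()
+```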
+ ### Example Parameter Server torchft has a fault tolerant parameter server implementation built on it's diff --git a/proto/torchft.proto b/proto/torchft.proto index 7c086eb9..cf7c403d 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -67,9 +67,17 @@ message LighthouseHeartbeatRequest { message LighthouseHeartbeatResponse {} +message SubscribeFailuresRequest {} + +message FailureNotification { + string replica_id = 1; + string error_message = 2; +} + service LighthouseService { rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse); rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse); + rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification); } message ManagerQuorumRequest { @@ -126,3 +134,9 @@ service ManagerService { rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); rpc Kill(KillRequest) returns (KillResponse); } + +message LighthouseClientRequest { + string replica_id = 1; +} + + diff --git a/src/lib.rs b/src/lib.rs index 32a7a37e..5e9c53a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod timeout; use anyhow::Result; use atty::Stream; use core::time::Duration; -use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; +use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError}; use std::cmp; use std::env; use std::sync::Arc; @@ -21,6 +21,7 @@ use std::thread::available_parallelism; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio::task::JoinHandle; +use tokio_stream::StreamExt; use tonic::transport::Channel; use tonic::Status; @@ -35,11 +36,13 @@ pub mod torchftpb { use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ - CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, - ManagerQuorumRequest, ShouldCommitRequest, + CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification, + LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ShouldCommitRequest, + SubscribeFailuresRequest, }; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use pyo3::{PyRef, PyRefMut}; // Get the number of threads to use for the tokio runtime fn num_threads() -> usize { @@ -290,6 +293,45 @@ struct QuorumResult { heal: bool, } +#[pyclass(unsendable)] +struct FailureStream { + runtime: Arc, + stream: tonic::Streaming, + timeout: Duration, +} + +#[pymethods] +impl FailureStream { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult { + let runtime = slf.runtime.clone(); + let timeout = slf.timeout; + // borrow stream mutably for the whole async block + let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await }; + + match runtime.block_on(fut) { + Ok(Some(Ok(note))) => Ok(FailureNotification { + replica_id: note.replica_id, + error_message: note.error_message, + }), + Ok(Some(Err(status))) => Err(StatusError(status).into()), + Ok(None) => Err(PyStopIteration::new_err(())), + Err(_) => Err(PyTimeoutError::new_err( + "Timeout waiting for failure notification", + )), + } + } +} + +#[pyclass(get_all, set_all)] +#[derive(Clone)] +struct FailureNotification { + replica_id: String, + error_message: String, +} + #[pymethods] impl QuorumResult { #[new] @@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult { #[pyclass] struct LighthouseClient { client: LighthouseServiceClient, - 
runtime: Runtime, + runtime: Arc, } #[pymethods] @@ -487,11 +529,13 @@ impl LighthouseClient { #[new] fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult { py.allow_threads(move || { - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads()) - .thread_name("torchft-lhclnt") - .enable_all() - .build()?; + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads()) + .thread_name("torchft-lhclnt") + .enable_all() + .build()?, + ); let client = runtime .block_on(manager::lighthouse_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -586,6 +630,22 @@ impl LighthouseClient { Ok(()) }) } + + #[pyo3(signature = (timeout = Duration::from_secs(5)))] + fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult { + py.allow_threads(move || { + let req = tonic::Request::new(SubscribeFailuresRequest {}); + let response = self + .runtime + .block_on(self.client.clone().subscribe_failures(req)) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(FailureStream { + runtime: self.runtime.clone(), + stream: response.into_inner(), + timeout: timeout, + }) + }) + } } /// LighthouseServer is a GRPC server for the lighthouse service. @@ -610,7 +670,7 @@ struct LighthouseServer { #[pymethods] impl LighthouseServer { - #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))] + #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None))] #[new] fn new( py: Python<'_>, @@ -619,10 +679,12 @@ impl LighthouseServer { join_timeout_ms: Option, quorum_tick_ms: Option, heartbeat_timeout_ms: Option, + failure_tick_ms: Option, ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000); + let failure_tick_ms = failure_tick_ms.unwrap_or(1000); py.allow_threads(move || { let rt = tokio::runtime::Builder::new_multi_thread() @@ -638,6 +700,7 @@ impl LighthouseServer { join_timeout_ms: join_timeout_ms, quorum_tick_ms: quorum_tick_ms, heartbeat_timeout_ms: heartbeat_timeout_ms, + failure_tick_ms: failure_tick_ms, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -663,6 +726,22 @@ impl LighthouseServer { self.handle.abort(); }) } + + /// inject_failure broadcasts a failure notification for the given replica. + /// + /// This helper is intended for testing `subscribe_failures` from Python. 
+ #[pyo3(signature = (replica_id))] + fn inject_failure(&self, py: Python<'_>, replica_id: String) { + let lighthouse = self.lighthouse.clone(); + let runtime = &self._runtime; + py.allow_threads(move || { + let _ = runtime.block_on(async { + if let Err(e) = lighthouse.inject_failure(replica_id).await { + eprintln!("Failed to inject failure: {}", e); + } + }); + }); + } } struct StatusError(Status); @@ -750,6 +829,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; Ok(()) diff --git a/src/lighthouse.rs b/src/lighthouse.rs index a5760032..063fcd29 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::{Instant, SystemTime}; +use crate::torchftpb::FailureNotification; use anyhow::{anyhow, Result}; use askama::Template; use axum::{ @@ -28,16 +29,23 @@ use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::interval; +use tokio_stream::wrappers::{ + errors::BroadcastStreamRecvError as TokioStreamBroadcastStreamRecvError, BroadcastStream, +}; +use tokio_stream::StreamExt; use tonic::service::Routes; use tonic::transport::server::TcpIncoming; use tonic::transport::Server; use tonic::{Request, Response, Status}; +use futures_core::Stream; +use std::pin::Pin; + use crate::manager::manager_client_new; use crate::torchftpb::{ lighthouse_service_server::{LighthouseService, LighthouseServiceServer}, KillRequest, LighthouseHeartbeatRequest, LighthouseHeartbeatResponse, LighthouseQuorumRequest, - LighthouseQuorumResponse, Quorum, QuorumMember, + LighthouseQuorumResponse, Quorum, QuorumMember, SubscribeFailuresRequest, }; #[derive(Clone)] @@ -47,14 +55,28 @@ struct QuorumMemberDetails { } struct State { - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, + // Tracks currently active participants in the process of forming a quorum. + // Replicas are added upon receiving a `LighthouseQuorumRequest`. + // Replicas are cleared after a quorum is successfully formed OR + // removed by `_failure_tick` if their heartbeat expires. participants: HashMap, prev_quorum: Option, quorum_id: i64, - // heartbeat information - // replica_id -> last heartbeat + // Stores the last heartbeat time for each replica ID. + // Replicas are added/updated upon receiving `LighthouseHeartbeatRequest` or `LighthouseQuorumRequest`. + // Replicas are removed by `_failure_tick` if their heartbeat expires and a failure notification is sent. heartbeats: HashMap, + + // Stores the timestamp of when a replica was first detected as failed (heartbeat expired). + // This is used to ensure only one `FailureNotification` is sent per failure event. + // Replicas are added by `_failure_tick` upon detecting a new failure. + // Replicas are removed by `_failure_tick` if a subsequent heartbeat is received (signifying recovery). + failures: HashMap, + + // Broadcast channel for sending failure notifications to subscribers. + pub failure_channel: broadcast::Sender, } pub struct Lighthouse { @@ -83,7 +105,7 @@ impl ChangeLogger { } } -#[derive(StructOpt, Debug)] +#[derive(StructOpt, Debug, Clone)] #[structopt()] pub struct LighthouseOpt { // bind is the address to bind the server to. @@ -120,6 +142,13 @@ pub struct LighthouseOpt { help = "How long to wait for a heartbeat before considering a replica dead." 
)] pub heartbeat_timeout_ms: u64, + + #[structopt( + long = "failure_tick_ms", + default_value = "1000", + help = "How frequently to check for failures." + )] + pub failure_tick_ms: u64, } fn quorum_changed(a: &Vec, b: &Vec) -> bool { @@ -265,14 +294,45 @@ impl Lighthouse { let listener = tokio::net::TcpListener::bind(&opt.bind).await?; let (tx, _) = broadcast::channel(16); + let (failure_tx, failure_rx) = broadcast::channel::(16); + + // Create a task to monitor the failure channel + let mut failure_rx_cloned: broadcast::Receiver = + failure_rx.resubscribe(); + tokio::spawn(async move { + use tokio::time::{sleep, Duration}; + info!("Starting permanent failure channel subscriber"); + loop { + match failure_rx_cloned.recv().await { + Ok(note) => { + info!( + "Healthy replicas received failure notification for {} with error message: {}", + note.replica_id, + note.error_message + ); + } + Err(e) => { + error!("Healthy replicas error: {}", e); + // If the channel is closed, break the loop + if matches!(e, tokio::sync::broadcast::error::RecvError::Closed) { + break; + } + } + } + sleep(Duration::from_millis(100)).await; // Prevent thrashing if there are continuous errors + } + info!("Permanent failure channel subscriber exiting"); + }); Ok(Arc::new(Self { state: Mutex::new(State { participants: HashMap::new(), - channel: tx, + quorum_channel: tx, prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: failure_tx, }), opt: opt, local_addr: listener.local_addr()?, @@ -326,7 +386,7 @@ impl Lighthouse { state.prev_quorum = Some(quorum.clone()); state.participants.clear(); - match state.channel.send(quorum) { + match state.quorum_channel.send(quorum) { Ok(_) => (), Err(e) => error!("failed to send quorum {}", e), } @@ -391,6 +451,76 @@ impl Lighthouse { .map_err(|e| e.into()) } + async fn _run_failure_tick(self: Arc) -> Result<()> { + let mut interval = interval(Duration::from_millis(self.opt.failure_tick_ms)); + loop { + interval.tick().await; // Wait for the next tick + let mut state = self.state.lock().await; + self.clone()._failure_tick(&mut state)?; + } + } + + fn _failure_tick(self: Arc, state: &mut State) -> Result<()> { + let now = Instant::now(); + let timeout = Duration::from_millis(self.opt.heartbeat_timeout_ms); + + // Use a temporary list to collect replica IDs to remove from heartbeats + // to avoid modifying the map while iterating over it. + let mut failed_replica_ids_to_remove_from_heartbeats = Vec::new(); + let mut failure_detected = false; + + for (replica_id, last_heartbeat) in state.heartbeats.iter() { + if now.duration_since(*last_heartbeat) > timeout { + if !state.failures.contains_key(replica_id) { + info!( + "Replica {} timed out (last heartbeat: {:?}), sending failure notification.", + replica_id, + last_heartbeat + ); + if let Err(e) = state.failure_channel.send(FailureNotification { + replica_id: replica_id.clone(), + error_message: "heartbeat timeout".to_string(), + }) { + error!( + "Failed to send failure notification for {}: {} (receiver count: {})", + replica_id, + e, + state.failure_channel.receiver_count() + ); + } else { + failure_detected = true; // Set flag if notification sent successfully + } + // Record failure information + state.failures.insert(replica_id.clone(), now); + state.participants.remove(replica_id); + failed_replica_ids_to_remove_from_heartbeats.push(replica_id.clone()); + } + } else { + // If the participant sends heartbeat again, remove it from failures. 
+ if state.failures.remove(replica_id).is_some() { + info!("Replica {} recovered from failure.", replica_id); + } + } + } + + // Remove failed replicas from heartbeats + for replica_id in failed_replica_ids_to_remove_from_heartbeats { + state.heartbeats.remove(&replica_id); + info!( + "Removed replica {} from heartbeats and participants due to timeout.", + replica_id + ); + } + + // If a new failure was detected and broadcasted, reset participants to restart quorum formation + if failure_detected { + info!("New failure detected, resetting all participants for quorum formation."); + state.participants.clear(); + } + + Ok(()) + } + pub async fn run(self: Arc) -> Result<()> { let mut set = JoinSet::new(); @@ -398,6 +528,8 @@ impl Lighthouse { set.spawn(self.clone()._run_grpc()); + set.spawn(self.clone()._run_failure_tick()); + while let Some(res) = set.join_next().await { res??; } @@ -469,6 +601,18 @@ impl Lighthouse { Ok(()) } + + pub async fn inject_failure(self: Arc, replica_id: String) -> Result<()> { + let state = self.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id, + error_message: "injected failure".to_string(), + }) + .map_err(|e| anyhow!("Failed to send failure notification: {}", e))?; + Ok(()) + } } #[tonic::async_trait] @@ -502,7 +646,7 @@ impl LighthouseService for Arc { member: requester.clone(), }, ); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); // proactively run quorum tick self.clone() @@ -556,6 +700,35 @@ impl LighthouseService for Arc { let reply = LighthouseHeartbeatResponse {}; Ok(Response::new(reply)) } + + type SubscribeFailuresStream = + Pin> + Send + 'static>>; + + async fn subscribe_failures( + &self, + _req: Request, + ) -> Result, Status> { + // clone a receiver + let rx = { + let state = self.state.lock().await; + let receiver_count = state.failure_channel.receiver_count(); + info!( + "subscribe_failures: Creating new subscriber (current count: {})", + receiver_count + ); + state.failure_channel.subscribe() + }; + + // Wrap the receiver; map its *internal* error into `tonic::Status` + let stream = BroadcastStream::new(rx).filter_map(|res| match res { + Ok(note) => Some(Ok(note)), + Err(TokioStreamBroadcastStreamRecvError::Lagged(n)) => Some(Err( + Status::resource_exhausted(format!("client lagged {n} messages")), + )), + }); + + Ok(Response::new(Box::pin(stream))) + } } #[derive(Template)] @@ -605,6 +778,8 @@ where mod tests { use super::*; use std::ops::Sub; + use tokio::sync::broadcast::error::RecvError as TokioBroadcastRecvError; + use tokio::time::timeout as tokio_timeout; use tonic::transport::Channel; @@ -624,14 +799,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -703,14 +881,17 @@ mod tests { join_timeout_ms: 0, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -789,14 +970,17 
@@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -879,14 +1063,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -974,6 +1161,7 @@ mod tests { join_timeout_ms: 1, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let lighthouse = Lighthouse::new(opt).await?; @@ -1020,14 +1208,17 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, }; let now = Instant::now(); @@ -1103,6 +1294,185 @@ mod tests { assert!(quorum_changed(&a, &c)); } + // Helper to create a default QuorumMember for tests + fn test_quorum_member(replica_id: &str) -> QuorumMember { + QuorumMember { + replica_id: replica_id.to_string(), + address: format!("addr_{}", replica_id), + store_address: format!("store_{}", replica_id), + step: 1, + world_size: 2, // Assuming 2 for this test context + shrink_only: false, + data: String::new(), + commit_failures: 0, + } + } + + /// Test that `_failure_tick` correctly identifies timed-out replicas, + /// broadcasts a failure notification exactly once per failure, and + /// cleans up the replica from `heartbeats` and `participants` while + /// adding it to `failures`. Subsequent ticks should not re-notify + /// or change the state for an already failed replica. 
+ #[tokio::test] + async fn test_failure_tick_single_notification_and_cleanup() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, // Not relevant for this test + quorum_tick_ms: 10, // Not directly relevant but keep it small + heartbeat_timeout_ms: 100, // Reasonably short for testing + failure_tick_ms: 50, // How often _failure_tick would be called + }; + let lighthouse = Lighthouse::new(opt.clone()).await?; + + let mut failure_rx = { + let state_guard = lighthouse.state.lock().await; + state_guard.failure_channel.subscribe() + }; + + let replica_id_failing = "failing_one"; + + let now = Instant::now(); + // Ensure expired_time is definitively older than heartbeat_timeout_ms + let expired_time = now - Duration::from_millis(opt.heartbeat_timeout_ms * 2); + + // Setup initial state: one about to fail + { + let mut state_guard = lighthouse.state.lock().await; + let state = &mut *state_guard; + + // Failing replica + state.participants.insert( + replica_id_failing.to_string(), + QuorumMemberDetails { + joined: now, // Joined time doesn't prevent failure due to heartbeat + member: test_quorum_member(replica_id_failing), + }, + ); + state + .heartbeats + .insert(replica_id_failing.to_string(), expired_time); + } + + // --- First call to _failure_tick --- + // This call should detect the failure, send a notification, and update state. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after first tick + // 1. Check notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + assert_eq!( + notification.replica_id, replica_id_failing, + "Notification should be for the failing replica" + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + panic!( + "Broadcast channel lagged by {} messages, missed the failure notification", + n + ); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + panic!("Broadcast channel closed unexpectedly after first tick"); + } + Err(_) => panic!( + "Did not receive failure notification for {} in time", + replica_id_failing + ), + } + + // 2. Verify state changes + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + // Failing replica assertions + assert!( + state.failures.contains_key(replica_id_failing), + "{} should be in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should be removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should be removed from participants", + replica_id_failing + ); + } + + // --- Second call to _failure_tick --- + // This call should *not* detect a *new* failure for the same replica + // and should not send another notification. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after second tick + // 1. 
No new notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + panic!( + "Received unexpected second failure notification for {}", + notification.replica_id + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + // This might happen if the test environment is slow and ticks are processed faster than receives. + // For this specific assertion (no *new* message), lagging is an acceptable outcome. + info!("Broadcast channel lagged by {} messages on second check, implies no new distinct message.", n); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + // Channel might close if sender is dropped, implies no new message. + info!("Broadcast channel closed on second check, implies no new distinct message."); + } + Err(_) => { + // Expected: Timeout, meaning no new message was received for failing_replica. + } + } + + // 2. Verify state remains consistent for failing_replica + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + assert!( + state.failures.contains_key(replica_id_failing), + "{} should remain in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should remain removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should remain removed from participants", + replica_id_failing + ); + } + Ok(()) + } + #[tokio::test] async fn test_lighthouse_join_during_shrink() -> Result<()> { fn create_member(id: &str, addr_num: &str, step: i64, shrink_only: bool) -> QuorumMember { @@ -1130,6 +1500,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; // Start the lighthouse service @@ -1237,6 +1608,7 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }; // Start the lighthouse service @@ -1281,4 +1653,70 @@ mod tests { lighthouse_task.abort(); Ok(()) } + + #[tokio::test] + async fn test_lighthouse_subscribe_failures_basic() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let request = tonic::Request::new(SubscribeFailuresRequest {}); + client.subscribe_failures(request).await?; + + lighthouse_task.abort(); + Ok(()) + } + + #[tokio::test] + async fn test_subscribe_failures_delivers_notifications() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + }; + let lighthouse = Lighthouse::new(opt).await?; + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + // 1. Subscribe with a deadline + let mut req = tonic::Request::new(SubscribeFailuresRequest {}); + req.set_timeout(Duration::from_secs(5)); + let mut stream = client.subscribe_failures(req).await?.into_inner(); + + // 2. 
Trigger a failure notification + { + let state = lighthouse.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id: "replica_id_X".into(), + error_message: "injected failure".to_string(), + }) + .unwrap(); + } + + // 3. Ensure we receive it + match stream.next().await { + Some(Ok(note)) => { + assert_eq!(note.replica_id, "replica_id_X"); + assert_eq!(note.error_message, "injected failure"); + } + other => panic!("Expected notification, got {:?}", other), + } + + lighthouse_task.abort(); + Ok(()) + } } diff --git a/src/manager.rs b/src/manager.rs index e28cbeb5..affda55d 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -15,6 +15,7 @@ use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::sleep; +use tokio::time::timeout as tokio_timeout; use tonic::transport::server::TcpIncoming; use tonic::transport::Channel; use tonic::transport::Server; @@ -54,7 +55,7 @@ macro_rules! info_with_replica { struct ManagerState { checkpoint_metadata: HashMap, - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, participants: HashMap, should_commit_channel: broadcast::Sender, @@ -126,7 +127,7 @@ impl Manager { heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_metadata: HashMap::new(), - channel: tx, + quorum_channel: tx, participants: HashMap::new(), should_commit_channel: should_commit_tx, @@ -204,7 +205,7 @@ impl Manager { }); lighthouse_request.set_timeout(timeout); - let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request)) + let response = tokio_timeout(timeout, client.quorum(lighthouse_request)) .await .unwrap_or_else(|e| { Err(Status::cancelled(format!( @@ -217,7 +218,7 @@ impl Manager { info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp); state - .channel + .quorum_channel .send( resp.quorum .ok_or_else(|| Status::internal("missing quorum"))?, @@ -273,7 +274,7 @@ impl ManagerService for Arc { }; // TODO check step state.participants.insert(group_rank, member.clone()); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); self._run_quorum(&mut state, member, timeout).await?; @@ -550,6 +551,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -597,6 +599,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -652,6 +655,7 @@ mod tests { min_replicas: 2, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -724,6 +728,7 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 9614d1b0..01d1f9e7 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,8 +11,8 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - commit_failures: int, init_sync: bool = True, + commit_failures: int = 0, ) -> QuorumResult: ... def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... 
def should_commit( @@ -60,9 +60,11 @@ class LighthouseServer: join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None, heartbeat_timeout_ms: Optional[int] = None, + failure_tick_ms: Optional[int] = None, ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... + def inject_failure(self, replica_id: str) -> None: ... @dataclass class QuorumMember: @@ -85,6 +87,14 @@ class Quorum: participants: List[QuorumMember] created: Timestamp +@dataclass +class FailureNotification: + replica_id: str + +class FailureStream: + def __iter__(self) -> "FailureStream": ... + def __next__(self) -> FailureNotification: ... + @dataclass class LighthouseClient: addr: str @@ -106,3 +116,7 @@ class LighthouseClient: replica_id: str, timeout: timedelta = timedelta(seconds=5), ) -> None: ... + def subscribe_failures( + self, + timeout: timedelta = timedelta(seconds=5), + ) -> FailureStream: ... diff --git a/torchft/data.py b/torchft/data.py index 02e5b3be..77ec1de7 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -38,15 +38,15 @@ class DistributedSampler(data.distributed.DistributedSampler): This will shard the input dataset into ``num_replicas*num_replica_group`` number of shards. - Each shard rank is calculated via: ``rank + num_replicas*replica_rank`` + Each shard rank is calculated via: ``rank + num_replicas*replica_group_id`` - num_replicas and replica_rank must be the same on all workers. + num_replicas and replica_group_id must be the same on all workers. """ def __init__( self, dataset: data.Dataset, - replica_rank: int, + replica_group_id: int, num_replica_groups: int, group_rank: Optional[int] = None, num_replicas: Optional[int] = None, @@ -55,7 +55,7 @@ def __init__( """ Args: data: the dataset to use - replica_rank: the group ID (0-num_replica_groups) to use for this shard of data. + replica_group_id: the group ID (0-num_replica_groups) to use for this shard of data. 
num_replica_groups: the max number of global replica groups rank: the local group rank num_replicas: the local group world size @@ -65,7 +65,7 @@ def __init__( if num_replicas is None: num_replicas = dist.get_world_size() - self.global_rank: int = group_rank + num_replicas * replica_rank + self.global_rank: int = group_rank + num_replicas * replica_group_id self.global_world_size: int = num_replicas * num_replica_groups super().__init__( diff --git a/torchft/data_test.py b/torchft/data_test.py index 8dae190e..5b7c6b6e 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -27,7 +27,7 @@ def test_distributed_sampler(self) -> None: dataset = DummyDataset(1000) sampler = DistributedSampler( dataset, - replica_rank=1, + replica_group_id=1, num_replica_groups=2, group_rank=3, num_replicas=4, diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index 067a6222..bbe3a974 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -155,3 +155,72 @@ def test_heartbeat_round_trip(self) -> None: finally: lighthouse.shutdown() + + def test_subscribe_failures(self) -> None: + """Test that subscribe_failures can be called without raising an exception.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(milliseconds=100)) + finally: + lighthouse.shutdown() + + def test_subscribe_failures_notification(self) -> None: + """Test that failure notifications are delivered to subscribers.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(seconds=1)) + lighthouse.inject_failure("nodeX") + note = next(stream) + assert note.replica_id == "nodeX" + finally: + lighthouse.shutdown() + + def test_inject_failure(self) -> None: + """Test that inject failure delivers a failure notification to subscribers""" + # Start a lighthouse server + server = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + print(f"Server address: {server.address()}") + + # Create a client to subscribe to failures + client = LighthouseClient(server.address(), timedelta(seconds=5)) + failure_stream = client.subscribe_failures(timedelta(seconds=5)) + + # Inject a failure + replica_id = "test_replica" + print(f"Injecting failure for replica: {replica_id}") + server.inject_failure(replica_id) + + # Wait a bit for the notification to be processed + time.sleep(1) + + # Try to get the failure notification + try: + notification = next(failure_stream) + print( + f"Received failure notification for replica: {notification.replica_id}" + ) + assert notification.replica_id == replica_id, "Received wrong replica_id" + print("Test passed!") + except Exception as e: + print(f"Error: {e}") + + # Clean up + server.shutdown() diff --git a/torchft/manager.py b/torchft/manager.py index 2c1c6406..ae48cfd4 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -24,25 +24,29 @@ and Hybrid FSDP. 
""" - import concurrent.futures import logging +import multiprocessing import os import socket +import threading +import time import traceback import uuid from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from datetime import timedelta from enum import Enum +from multiprocessing.connection import Connection from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore -from torchft._torchft import ManagerClient, ManagerServer +from torchft._torchft import LighthouseClient, ManagerClient, ManagerServer from torchft.checkpointing import CheckpointTransport, HTTPTransport from torchft.futures import future_timeout +from torchft.multiprocessing import _MonitoredPipe if TYPE_CHECKING: from torchft.process_group import ProcessGroup @@ -103,6 +107,7 @@ def __init__( timeout: timedelta = timedelta(seconds=60), quorum_timeout: timedelta = timedelta(seconds=60), connect_timeout: timedelta = timedelta(seconds=60), + proactive_recovery_subscribe_timeout: timedelta = timedelta(milliseconds=100), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, @@ -116,6 +121,7 @@ def __init__( checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> None: """ Args: @@ -166,6 +172,9 @@ def __init__( self._timeout = timeout self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout + self._proactive_recovery_subscribe_timeout = ( + proactive_recovery_subscribe_timeout + ) self._replica_world_size_mode = world_size_mode self._init_sync = init_sync self._max_retries = max_retries @@ -187,9 +196,7 @@ def __init__( self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = ( checkpoint_transport ) - self._executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="async_quorum" - ) + self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="") self._quorum_future: Optional[concurrent.futures.Future] = None self._store = TCPStore( @@ -205,12 +212,57 @@ def __init__( torch.cuda.Stream() if torch.cuda.is_available() else None ) + lighthouse_addr: Optional[str] = lighthouse_addr + if os.environ.get("TORCHFT_LIGHTHOUSE") is not None: + lighthouse_addr = ( + lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] + ) # Else error in tests, since TORCHFT_LIGHTHOUSE may not be set + + self._proactive_recovery = proactive_recovery or int( + os.environ.get("TORCHFT_PROACTIVE_RECOVERY", 0) + ) + + if lighthouse_addr is not None and self._proactive_recovery: + ctx = multiprocessing.get_context("spawn") + error_local, error_remote = ctx.Pipe() + self._error_pipe = _MonitoredPipe(error_local) + self._error_remote = _MonitoredPipe(error_remote) + self._failure_listener_stop_event = ctx.Event() + + self._failure_listener_process = ctx.Process( + target=_failure_listener_process_main, + args=( + lighthouse_addr, + self._connect_timeout, + self._failure_listener_stop_event, + error_remote, + self._proactive_recovery_subscribe_timeout, + ), + daemon=True, + ) + self._failure_listener_process.start() + else: + self._failure_listener_process = None + self._error_pipe = None + self._failure_listener_stop_event = None + + # Initialize and start the error processing thread if the listener process is active + self._error_processor_thread: Optional[threading.Thread] = None + self._error_processor_stop_event: 
Optional[threading.Event] = None + if self._failure_listener_process is not None: + self._error_processor_stop_event = threading.Event() + self._error_processor_thread = threading.Thread( + target=self._error_processor_loop, + name="TorchFTErrorProcessor", + daemon=True, + ) + self._error_processor_thread.start() + if self._group_rank == 0: if port is None: port = int(os.environ.get(MANAGER_PORT_ENV, 0)) bind = f"[::]:{port}" - lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] # We need a unique identifier in the case that a worker restarts quickly and # replaces the previous worker with the same ID. @@ -219,6 +271,7 @@ def __init__( replica_id = new_uuid else: replica_id = f"{replica_id}:{new_uuid}" + self._manager = ManagerServer( replica_id=replica_id, lighthouse_addr=lighthouse_addr, @@ -229,13 +282,11 @@ def __init__( heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, ) - self._store.set(MANAGER_ADDR_KEY, self._manager.address()) self._store.set(REPLICA_ID_KEY, replica_id) addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8") self._client = ManagerClient(addr, connect_timeout=connect_timeout) - replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8") self._logger = _ManagerLogger( manager=self, replica_id=replica_id or "", group_rank=group_rank @@ -258,13 +309,96 @@ def set_state_dict_fns( self._load_state_dict = load_state_dict self._user_state_dict = state_dict + def _error_handler(self, err): + self._logger.info(f"Received error: {err}") + self.report_error(err) + self._pg.abort() + + def _error_processor_loop(self) -> None: + """Continuously checks the error pipe from the listener process and reports errors.""" + assert ( + self._error_pipe is not None + ), "Error pipe must be initialized for error processor loop." + assert ( + self._error_processor_stop_event is not None + ), "Stop event must be initialized for error processor loop." + + try: + while not self._error_processor_stop_event.is_set(): + try: + item = self._error_pipe.recv(0.1) + except TimeoutError: + continue + except OSError: + break + except Exception as e: + self._error_handler(e) + finally: + pass + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. """ - self._checkpoint_transport.shutdown(wait=wait) if self._manager is not None: self._manager.shutdown() + + # Stop the error processor thread first + if ( + self._error_processor_thread is not None + and self._error_processor_stop_event is not None + ): + self._logger.info("Setting error processor thread stop event") + self._error_processor_stop_event.set() + if wait: + self._logger.info("Waiting for error processor thread to complete") + try: + self._error_processor_thread.join(timeout=5) # Short timeout + if self._error_processor_thread.is_alive(): + self._logger.warn( + "Error processor thread did not terminate in time." 
+ ) + else: + self._logger.info("Error processor thread shutdown completed.") + except Exception as e: + self._logger.warn(f"Error waiting for error processor thread: {e}") + + # Stop the failure listener process if it exists + if ( + hasattr(self, "_failure_listener_process") + and self._failure_listener_process is not None + ): + self._logger.info("Setting failure listener stop event for process") + if ( + hasattr(self, "_failure_listener_stop_event") + and self._failure_listener_stop_event is not None + ): + self._failure_listener_stop_event.set() + + if wait: + self._logger.info("Waiting for failure listener process to complete") + try: + self._failure_listener_process.join(timeout=10) # Process join + if self._failure_listener_process.is_alive(): + self._logger.warn( + "Failure listener process did not terminate, attempting to terminate." + ) + self._failure_listener_process.terminate() # Force terminate if join times out + self._failure_listener_process.join( + timeout=1 + ) # Wait for terminate + else: + self._logger.info("Failure listener process shutdown completed") + except Exception as e: + self._logger.warn( + f"Error waiting for/terminating failure listener process: {e}" + ) + + # Clean up pipe + if hasattr(self, "_error_pipe") and self._error_pipe is not None: + self._error_pipe.close() + + self._checkpoint_transport.shutdown(wait=wait) self._executor.shutdown(wait=wait) def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: @@ -824,3 +958,60 @@ def warn(self, msg: str) -> None: def exception(self, msg: str) -> None: self._logger.exception(f"{self.prefix()} {msg}") + + +def _failure_listener_process_main( + lighthouse_addr_str: Optional[str], + connect_timeout: timedelta, + stop_event: multiprocessing.Event, + error_pipe: Connection, + subscribe_timeout: timedelta = timedelta(milliseconds=100), +): + """ + Background process that monitors lighthouse for failures through gRPC stream (with an iterator interface) and reports them via error_pipe. + """ + if not lighthouse_addr_str: + return + + while not stop_event.is_set(): + try: + lighthouse_client = LighthouseClient( + lighthouse_addr_str, connect_timeout=connect_timeout + ) + stream = lighthouse_client.subscribe_failures(timeout=subscribe_timeout) + while not stop_event.is_set(): + try: + note = next( + stream + ) # This will block until a new item or timeout if stream supports it + if note: + if stop_event.is_set(): + break + error = Exception( + f"Peer failure detected in listener process: replica {note.replica_id} has failed" + ) + error_pipe.send(ExceptionWithTraceback(error)) + except StopIteration: + # Stream has ended, break out to outer loop to reconnect + if not stop_event.is_set(): + logging.warning( + "Failure Listener: Stream ended unexpectedly, attempting to reconnect..." + ) + break # Break the inner loop to reconnect + else: + break + except Exception as e_stream: + if not stop_event.is_set(): + continue # Break due to subscribe_timeout. Allows the process to check stop_event again. + else: + break + if stop_event.is_set(): + break + time.sleep(0.01) # Prevent CPU thrashing + except Exception as e_outer: + if not stop_event.is_set(): + logging.warning( + f"Failure Listener: Connection error: {e_outer}, retrying in 1 second..." + ) + time.sleep(1) + pass diff --git a/torchft/manager_test.py b/torchft/manager_test.py index bb058e4e..2fb0373b 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -5,18 +5,34 @@ # LICENSE file in the root directory of this source tree. 
import concurrent +import multiprocessing +import time +from dataclasses import dataclass from datetime import timedelta from typing import Optional from unittest import TestCase from unittest.mock import MagicMock, create_autospec, patch import torch +import torch.distributed as dist from torch.distributed import TCPStore -from torchft._torchft import QuorumResult +from torchft._torchft import ( + FailureStream, + LighthouseClient, + LighthouseServer, + QuorumResult, +) from torchft.checkpointing.transport import CheckpointTransport -from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode -from torchft.process_group import ProcessGroup, _DummyWork +from torchft.manager import ( + MANAGER_ADDR_KEY, + REPLICA_ID_KEY, + ExceptionWithTraceback, + Manager, + WorldSizeMode, + _failure_listener_process_main, +) +from torchft.process_group import ProcessGroup, ProcessGroupGloo, _DummyWork def mock_should_commit( @@ -43,6 +59,7 @@ def _create_manager( timeout: timedelta = timedelta(seconds=10), init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -72,6 +89,7 @@ def _create_manager( timeout=timeout, init_sync=init_sync, max_retries=max_retries, + proactive_recovery=proactive_recovery, ) self.manager = manager return manager @@ -773,3 +791,127 @@ def test_max_retries(self, client_mock: MagicMock) -> None: # This should succeed and reset the counter self.assertTrue(manager.should_commit()) self.assertEqual(manager._commit_failures, 0) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_manager_error_handler(self, client_mock: MagicMock) -> None: + """Test that the Manager correctly processes exceptions sent from the failure_listener_process.""" + # Create a manager + manager = self._create_manager() + + # Create an exception simulating what would be sent from _failure_listener_process_main + error = Exception("Peer failure detected: replica failed_replica has failed") + exception = ExceptionWithTraceback(error) + + # Directly test the error handling mechanism + manager._error_handler(error) + + # Verify the error was properly processed + captured_error = manager.errored() + self.assertIsNotNone(captured_error) + self.assertEqual(str(captured_error.original_exception), str(error)) + + def test_direct_error_pipe(self) -> None: + """Test sending an exception to the Manager's _error_pipe.""" + # Create a manager with proactive_recovery=True to ensure it has an error pipe + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + proactive_recovery=True, + ) + + # Make sure the error pipe is created + self.assertIsNotNone(manager._error_pipe, "Manager should have an error pipe") + time.sleep(1) + # Create a mock error message + mock_error_msg = "Test failure detected from direct pipe test" + test_exception = Exception(mock_error_msg) + + # Create an ExceptionWithTraceback and send it through the pipe + exc_with_tb = 
ExceptionWithTraceback(test_exception) + manager._error_remote.send(exc_with_tb) + + # Wait a short time for the error processor thread to process the message + time.sleep(1) + + # Verify that the error was properly processed by the Manager + error_obj = manager.errored() + self.assertIsNotNone( + error_obj, "Error should have been captured by the Manager" + ) + + # Clean up + manager.shutdown(wait=True) + + def test_manager_failure_e2e(self) -> None: + """Test that the Manager correctly handles errors from the failure_listener_process.""" + # Create a manager with proactive_recovery=True to ensure it has an error pipe + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + proactive_recovery=True, + ) + + time.sleep(1.5) + + failed_replica_id = "failed_replica" + lighthouse.inject_failure(failed_replica_id) + + time.sleep(1.5) # Prevent flakyness + error_obj = manager.errored() + + # Verify that the manager received the error notification + self.assertIsNotNone(error_obj, "Manager should have captured the failure") + self.assertIn( + failed_replica_id, + str(error_obj.original_exception), + f"Error should mention the failed replica: {error_obj.original_exception}", + ) + + # Clean up resources + manager.shutdown(wait=True) diff --git a/train_ddp.py b/train_ddp.py index fd79b8ad..96c2c139 100644 --- a/train_ddp.py +++ b/train_ddp.py @@ -51,7 +51,7 @@ def main() -> None: # majority of groups will be available so few batches will be dropped. sampler = DistributedSampler( trainset, - replica_group=REPLICA_GROUP_ID, + replica_group_id=REPLICA_GROUP_ID, num_replica_groups=NUM_REPLICA_GROUPS, group_rank=0, # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. diff --git a/train_ddp_proactive.py b/train_ddp_proactive.py new file mode 100644 index 00000000..3d0002ce --- /dev/null +++ b/train_ddp_proactive.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +import sys +import time +from datetime import timedelta + +REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) +os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4) +os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) + +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import ( + DistributedDataParallel, + DistributedSampler, + Manager, + Optimizer, + ProcessGroupGloo, + ProcessGroupNCCL, +) +from torchft.checkpointing.pg_transport import PGTransport + +logging.basicConfig(level=logging.INFO) + + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = torchvision.datasets.CIFAR10( + root="./cifar", train=True, download=True, transform=transform + ) + + # This shards the training set across all ranks and replica groups. We manage + # the dataloaders on a per replica group basis with the assumption that the + # majority of groups will be available so few batches will be dropped. + sampler = DistributedSampler( + trainset, + replica_group_id=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + group_rank=0, + # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. + num_replicas=1, + shuffle=True, + ) + + # This uses the torchdata StatefulDataLoader to be able to checkpoint and + # restore the per worker dataloader position. + trainloader = StatefulDataLoader( + trainset, batch_size=64, num_workers=2, sampler=sampler + ) + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + + device = "cuda" if torch.cuda.is_available() else "cpu" + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=30), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. 
+ target_size = 1_000_000_000 + self.useless = nn.Embedding(target_size // final_dim // 4, final_dim) + + self.classifier = nn.Sequential( + nn.Linear(16 * 5 * 5, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, final_dim), + ) + + def forward(self, x): + x = self.cnn(x) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = self.classifier(x) + x += self.useless.weight[0] + return x + + m = Net().to(device) + m = DistributedDataParallel(manager, m) + optimizer = Optimizer(manager, optim.AdamW(m.parameters())) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + sort_by_keyword = "self_" + device + "_time_total" + + def trace_handler(p): + output = p.key_averages().table( + sort_by=sort_by_keyword, + row_limit=100, + ) + print(output) + p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json") + + # You can use an epoch based training but with faults it's easier to use step + # based training. + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + ) + + prof.start() + while True: + for i, (inputs, labels) in enumerate(trainloader): + prof.step() + + time.sleep(0.5) # Else each iteration runs too quickly + + inputs = inputs.to(device) + labels = labels.to(device) + + # must be called at the beginning of each train loop + # Quorum computation is triggered here but only needed in the backwards pass. + optimizer.zero_grad() + + out = m(inputs) + criterion = nn.CrossEntropyLoss() + loss = criterion(out, labels) + + # Gradient allreduce overlaps with the backwards pass. + loss.backward() + if manager.current_step() == 3: + if REPLICA_GROUP_ID == 0: + manager.shutdown() + exit(0) + # If proactive recovery, then the surviving process will reconfigure + # If not proactive recovery, then the surviving process will wait until timeout + + test_tensor = torch.tensor([1.0]).to(device) + manager.allreduce(test_tensor) + + # must be called at the end of the train loop + # This may not actually step the optimizer if an error occured during grad allreduce. + optimizer.step() + + if manager.current_step() % 100 == 0: + print(f"[{manager.current_step()}] loss = {loss.item()}") + + # TODO (by the user): periodically checkpoint model, optim, manager and dataloader + + # You typically want to checkpoint dataloader frequently (every step?) to + # avoid repeated batches as it's replica group specific. + + # Model, optim and manager checkpoints can be done more infrequently as + # they're shared across all groups and will load from existing replicas as + # long as not every worker goes down. 
+ + if manager.current_step() >= 10000: + # complete training + prof.stop() + exit() + + +if __name__ == "__main__": + main() From f5ee70455dde0353d0d64eaa60710bd3b8cc38f5 Mon Sep 17 00:00:00 2001 From: WarrenZhu050413 Date: Thu, 22 May 2025 08:52:50 +0800 Subject: [PATCH 2/2] Added example training scripts for localsgd, DiLoCo, Live Checkpoint Recovery, and proactive failure detection with DDP, along with CI (#198) --- .github/workflows/examples.yaml | 58 +++++ README.md | 45 +--- examples/README.md | 37 +++ examples/ddp_proactive/.torchxconfig | 7 + examples/ddp_proactive/README.md | 177 +++++++++++++ .../ddp_proactive/train_ddp_proactive.py | 30 ++- examples/diloco/.torchxconfig | 7 + examples/diloco/README.md | 67 +++++ examples/diloco/train_diloco.py | 240 ++++++++++++++++++ .../live_checkpoint_recovery/.torchxconfig | 7 + examples/live_checkpoint_recovery/README.md | 102 ++++++++ .../live_checkpoint_recovery/train_ddp_lcr.py | 221 ++++++++++++++++ examples/localsgd/.torchxconfig | 7 + examples/localsgd/README.md | 73 ++++++ examples/localsgd/train_localsgd.py | 227 +++++++++++++++++ examples/test_examples.py | 200 +++++++++++++++ examples/utils/utils.py | 143 +++++++++++ torchft/local_sgd.py | 2 +- torchft/manager.py | 1 + 19 files changed, 1600 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/examples.yaml create mode 100644 examples/README.md create mode 100644 examples/ddp_proactive/.torchxconfig create mode 100644 examples/ddp_proactive/README.md rename train_ddp_proactive.py => examples/ddp_proactive/train_ddp_proactive.py (88%) create mode 100644 examples/diloco/.torchxconfig create mode 100644 examples/diloco/README.md create mode 100644 examples/diloco/train_diloco.py create mode 100644 examples/live_checkpoint_recovery/.torchxconfig create mode 100644 examples/live_checkpoint_recovery/README.md create mode 100644 examples/live_checkpoint_recovery/train_ddp_lcr.py create mode 100644 examples/localsgd/.torchxconfig create mode 100644 examples/localsgd/README.md create mode 100644 examples/localsgd/train_localsgd.py create mode 100644 examples/test_examples.py create mode 100644 examples/utils/utils.py diff --git a/.github/workflows/examples.yaml b/.github/workflows/examples.yaml new file mode 100644 index 00000000..ec2185c7 --- /dev/null +++ b/.github/workflows/examples.yaml @@ -0,0 +1,58 @@ +name: Examples + +on: + push: + branches: + - main + pull_request: + +jobs: + unittest: + strategy: + fail-fast: false + matrix: + include: + - runs-on: "linux.2xlarge" + gpu-arch-type: "cpu" + gpu-arch-version: "" + torch-version: "stable" + - runs-on: "linux.g5.12xlarge.nvidia.gpu" + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + torch-version: "stable" + - runs-on: "linux.g5.12xlarge.nvidia.gpu" + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" + torch-version: "nightly" + + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 120 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + script: | + set -ex + + # install python and protobuf + conda create -n venv python=3.12 libprotobuf -y + conda activate venv + python -m pip install --upgrade pip + + # install recent version of Rust via rustup + curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=default -y + . 
"$HOME/.cargo/env" + + # Optionally install torch nightly, pulls latest CUDA from pip otherwise + if [ "${{ matrix.torch-version }}" = "nightly" ]; then + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + fi + if [ "${{ matrix.torch-version }}" = "test" ]; then + pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 + fi + + # Install dependencies + pip install -e .[dev] -v + + # Run tests + pytest examples/test_examples.py diff --git a/README.md b/README.md index ff3cfeab..7370c6ea 100644 --- a/README.md +++ b/README.md @@ -79,15 +79,14 @@ We have a minimal DDP train loop that highlights all of the key components in to See [train_ddp.py](./train_ddp.py) for more info. +### Advanced Examples -### DiLoCo - -LocalSGD and DiLoCo are currently experimental. - -See -[the diloco_train_loop/local_sgd_train_loop tests](./torchft/local_sgd_integ_test.py) -for an example on how to integrate these algorithms into your training loop. +See the [examples/README.md](./examples/README.md) for advanced examples. Currently, the following examples are available: +- [DDP with proactive failure recovery](./examples/ddp_proactive/README.md): Demonstrates DDP with proactive failure recovery mode +- [DiLoCo](./examples/diloco/README.md): Demonstrates Distributed Local Convergence training +- [LocalSGD](./examples/localsgd/README.md): Demonstrates Local SGD with periodic synchronization +- [Live Checkpoint Recovery](./examples/live_checkpoint_recovery/README.md): Demonstrates live checkpoint recovery ## Design @@ -246,38 +245,6 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast By observing the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery. -### Proactive Failure Recovery Mode (Experimental) - -You can experiment with proactive failure recovery mode by: - -```sh -export TORCHFT_PROACTIVE_RECOVERY=1 -``` - -With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break from a hanging allreduce. - -You can test this out by running `train_ddp_proactive.py` - -On shell 1 (one replica groups starts initial training): -```sh -export REPLICA_GROUP_ID=0 -export NUM_REPLICA_GROUPS=2 -export TORCHFT_PROACTIVE_RECOVERY=1 - -CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py -``` - -On shell 2 (a second replica group joins): -```sh -export REPLICA_GROUP_ID=1 -export NUM_REPLICA_GROUPS=2 -export TORCHFT_PROACTIVE_RECOVERY=1 - -CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py -``` - -You should observe that the process with replica group id 1 will exit early, and the process with replica group id 0 will quickly resume training. If the same script is ran with after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, you should observe that the process with replica group id 1 will hang for dozens of seconds before continuing. 
-
 ### Example Parameter Server
 
 torchft has a fault tolerant parameter server implementation built on it's
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..4bd7f071
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,37 @@
+# TorchFT Examples
+
+This directory contains advanced examples demonstrating various fault tolerance features and training approaches in TorchFT beyond the basic `train_ddp.py` example in the [README](../README.md).
+
+Each directory contains a README with more detailed instructions, as well as extensive documentation on the feature being showcased and how to interpret the outputs.
+
+## List of Examples
+
+- [DDP with proactive failure recovery](./ddp_proactive/README.md): Demonstrates DDP with proactive failure recovery mode
+- [DiLoCo](./diloco/README.md): Demonstrates Distributed Local Convergence training
+- [LocalSGD](./localsgd/README.md): Demonstrates Local SGD with periodic synchronization
+- [Live Checkpoint Recovery](./live_checkpoint_recovery/README.md): Demonstrates live checkpoint recovery
+
+## Running the examples
+
+After starting the lighthouse server by running:
+
+```sh
+RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
+```
+
+You can `cd` into the example directory:
+
+```sh
+cd examples/[example_directory]
+```
+
+and then launch the example with torchX:
+
+```sh
+export QUICK_RUN=1
+torchx run
+```
+
+The QUICK_RUN environment variable runs the examples for far fewer steps and uses a synthetic dataset instead of a downloaded one, which is useful for testing the examples quickly.
+
+See the `.torchxconfig` file in each example directory for configuration details, and [torchx.py](../torchft/torchx.py) and the [torchX documentation](https://pytorch.org/torchx/latest/) to understand how DDP is being run.
\ No newline at end of file
diff --git a/examples/ddp_proactive/.torchxconfig b/examples/ddp_proactive/.torchxconfig
new file mode 100644
index 00000000..0b198471
--- /dev/null
+++ b/examples/ddp_proactive/.torchxconfig
@@ -0,0 +1,7 @@
+[cli:run]
+component=../../torchft/torchx.py:hsdp
+scheduler=local_cwd
+
+
+[component:../../torchft/torchx.py:hsdp]
+script=train_ddp_proactive.py
diff --git a/examples/ddp_proactive/README.md b/examples/ddp_proactive/README.md
new file mode 100644
index 00000000..13e799a6
--- /dev/null
+++ b/examples/ddp_proactive/README.md
@@ -0,0 +1,177 @@
+# DDP Proactive Recovery Example
+
+This example demonstrates DDP with proactive failure recovery in torchft.
+
+Proactive recovery enables the training process to quickly detect and respond to worker failures without waiting for timeout periods, significantly reducing recovery time.
+
+Note that setting `TORCHFT_PROACTIVE_RECOVERY=0` does not disable the Lighthouse heartbeat timeout detection logic; it only stops each training process from spawning the failure listener process.
+
+## Implementation Details
+
+The example is based on [train_ddp.py](../../train_ddp.py).
We add the following logic to the training loop, right after the backward pass:
+
+```python
+    if manager.current_step() == 3:
+        if REPLICA_GROUP_ID == 0:
+            manager.shutdown()
+            exit(0)
+        # If proactive recovery, then the surviving process will reconfigure
+        # If not proactive recovery, then the surviving process will wait until timeout
+    test_tensor = torch.tensor([1.0]).to(device)
+    manager.allreduce(test_tensor)
+```
+
+Here, without proactive error recovery, after Replica Group ID 0 shuts down, Replica Group ID 1 will wait until the allreduce times out (the process group timeout, set to 120 seconds in this example).
+
+However, with proactive error recovery enabled, the Lighthouse will detect that Replica Group ID 0's heartbeat has timed out and send a message to Replica Group ID 1 to reconfigure its process group to exclude the failed replica.
+
+## How to Run
+
+You can enable proactive failure recovery mode by setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=1
+```
+
+On shell 1 (one replica group starts initial training):
+```sh
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- examples/ddp_proactive/train_ddp_proactive.py
+```
+
+On shell 2 (a second replica group joins):
+```sh
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+export TORCHFT_PROACTIVE_RECOVERY=1
+
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- examples/ddp_proactive/train_ddp_proactive.py
+```
+
+Then contrast this with running the same commands after setting:
+
+```sh
+export TORCHFT_PROACTIVE_RECOVERY=0
+```
+
+## Example Outputs
+
+### With `TORCHFT_PROACTIVE_RECOVERY=1`
+
+#### Lighthouse
+
+You should see the following output from the Lighthouse:
+
+```txt
+2025-05-22T11:47:49.435 [INFO] [torchft::lighthouse] - Replica train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6 timed out (last heartbeat: Instant { tv_sec: 5334992, tv_nsec: 45173926 }), sending failure notification.
+2025-05-22T11:47:49.435 [INFO] [torchft::lighthouse] - Removed replica train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6 from heartbeats and participants due to timeout.
+2025-05-22T11:47:49.435 [INFO] [torchft::lighthouse] - New failure detected, resetting all participants for quorum formation.
+2025-05-22T11:47:49.435 [INFO] [torchft::lighthouse] - Healthy replicas received failure notification for train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6 with error message: heartbeat timeout
+```
+
+Here, the Lighthouse detects the heartbeat timeout and sends failure notifications to the healthy replicas.
+
+#### Replica Group ID 0
+
+You should see the following output from Replica Group ID 0:
+
+```txt
+INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Setting error processor thread stop event
+INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Waiting for error processor thread to complete
+INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Error processor thread shutdown completed.
+INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Setting failure listener stop event for process +INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Waiting for failure listener process to complete +INFO:torchft.manager:[train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6/0 - step 3] Failure listener process shutdown completed +``` + +This is the shutdown logic of the error processor thread and failure listener process. The failure listener process listens to the failure notifications from the Lighthouse, and transmits it to the error processor thread in the main training process. + +#### Replica Group ID 1 + +Replica Group ID 1 will recovery quickly after the shutdown of Replica Group ID 0. In the middle, there will be errors relating to the TCPStore due to Replica Group ID 1 aborting its process group in the middle of allreduce. The output is shown in full to show that these error traces are expected. + +```txt +[W522 11:47:45.949247853 TCPStore.cpp:125] [c10d] recvValue failed on SocketImpl(fd=77, addr=[::ffff:127.0.0.1]:39732, remote=[::ffff:127.0.0.1]:29600): failed to recv, got 0 bytes +Exception raised from recvBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:678 (most recent call first): +frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x98 (0x7f9a28f795e8 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libc10.so) +frame #1: + 0x5ba8afe (0x7f9a6d09cafe in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #2: + 0x5baae40 (0x7f9a6d09ee40 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #3: + 0x5bab74a (0x7f9a6d09f74a in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #4: c10d::TCPStore::check(std::vector, std::allocator >, std::allocator, std::allocator > > > const&) + 0x2a9 (0x7f9a6d0991a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #5: c10d::ProcessGroupNCCL::heartbeatMonitor() + 0x379 (0x7f9a2a2929a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so) +frame #6: + 0xdbbf4 (0x7f9a1a25bbf4 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/bin/../lib/libstdc++.so.6) +frame #7: + 0x8609 (0x7f9a84eb4609 in /lib/x86_64-linux-gnu/libpthread.so.0) +frame #8: clone + 0x43 (0x7f9a84c7f353 in /lib/x86_64-linux-gnu/libc.so.6) + +[W522 11:47:45.952720714 ProcessGroupNCCL.cpp:1659] [PG ID 0 PG GUID Rank 1] Failed to check the "should dump" flag on TCPStore, (maybe TCPStore server has shut down too early), with error: failed to recv, got 0 bytes +[W522 11:47:46.952883182 TCPStore.cpp:106] [c10d] sendBytes failed on SocketImpl(fd=77, addr=[::ffff:127.0.0.1]:39732, remote=[::ffff:127.0.0.1]:29600): Broken pipe +Exception raised from sendBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:653 (most recent call first): +frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x98 (0x7f9a28f795e8 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libc10.so) +frame #1: + 0x5ba8afe (0x7f9a6d09cafe in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #2: + 0x5baa358 
(0x7f9a6d09e358 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #3: + 0x5babb3e (0x7f9a6d09fb3e in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #4: c10d::TCPStore::check(std::vector, std::allocator >, std::allocator, std::allocator > > > const&) + 0x298 (0x7f9a6d099198 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #5: c10d::ProcessGroupNCCL::heartbeatMonitor() + 0x379 (0x7f9a2a2929a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so) +frame #6: + 0xdbbf4 (0x7f9a1a25bbf4 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/bin/../lib/libstdc++.so.6) +frame #7: + 0x8609 (0x7f9a84eb4609 in /lib/x86_64-linux-gnu/libpthread.so.0) +frame #8: clone + 0x43 (0x7f9a84c7f353 in /lib/x86_64-linux-gnu/libc.so.6) + +[W522 11:47:46.955975324 ProcessGroupNCCL.cpp:1659] [PG ID 0 PG GUID Rank 1] Failed to check the "should dump" flag on TCPStore, (maybe TCPStore server has shut down too early), with error: Broken pipe +[W522 11:47:47.956137177 TCPStore.cpp:106] [c10d] sendBytes failed on SocketImpl(fd=77, addr=[::ffff:127.0.0.1]:39732, remote=[::ffff:127.0.0.1]:29600): Broken pipe +Exception raised from sendBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:653 (most recent call first): +frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x98 (0x7f9a28f795e8 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libc10.so) +frame #1: + 0x5ba8afe (0x7f9a6d09cafe in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #2: + 0x5baa358 (0x7f9a6d09e358 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #3: + 0x5babb3e (0x7f9a6d09fb3e in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #4: c10d::TCPStore::check(std::vector, std::allocator >, std::allocator, std::allocator > > > const&) + 0x298 (0x7f9a6d099198 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #5: c10d::ProcessGroupNCCL::heartbeatMonitor() + 0x379 (0x7f9a2a2929a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so) +frame #6: + 0xdbbf4 (0x7f9a1a25bbf4 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/bin/../lib/libstdc++.so.6) +frame #7: + 0x8609 (0x7f9a84eb4609 in /lib/x86_64-linux-gnu/libpthread.so.0) +frame #8: clone + 0x43 (0x7f9a84c7f353 in /lib/x86_64-linux-gnu/libc.so.6) + +[W522 11:47:47.959256423 ProcessGroupNCCL.cpp:1659] [PG ID 0 PG GUID Rank 1] Failed to check the "should dump" flag on TCPStore, (maybe TCPStore server has shut down too early), with error: Broken pipe +[W522 11:47:48.959394571 TCPStore.cpp:106] [c10d] sendBytes failed on SocketImpl(fd=77, addr=[::ffff:127.0.0.1]:39732, remote=[::ffff:127.0.0.1]:29600): Broken pipe +Exception raised from sendBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:653 (most recent call first): +frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x98 (0x7f9a28f795e8 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libc10.so) +frame #1: + 0x5ba8afe 
(0x7f9a6d09cafe in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #2: + 0x5baa358 (0x7f9a6d09e358 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #3: + 0x5babb3e (0x7f9a6d09fb3e in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #4: c10d::TCPStore::check(std::vector, std::allocator >, std::allocator, std::allocator > > > const&) + 0x298 (0x7f9a6d099198 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #5: c10d::ProcessGroupNCCL::heartbeatMonitor() + 0x379 (0x7f9a2a2929a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so) +frame #6: + 0xdbbf4 (0x7f9a1a25bbf4 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/bin/../lib/libstdc++.so.6) +frame #7: + 0x8609 (0x7f9a84eb4609 in /lib/x86_64-linux-gnu/libpthread.so.0) +frame #8: clone + 0x43 (0x7f9a84c7f353 in /lib/x86_64-linux-gnu/libc.so.6) + +[W522 11:47:48.962514781 ProcessGroupNCCL.cpp:1659] [PG ID 0 PG GUID Rank 1] Failed to check the "should dump" flag on TCPStore, (maybe TCPStore server has shut down too early), with error: Broken pipe +INFO:torchft.manager:[train_ddp_1:9b6cf09c-8747-43dd-bf2c-067cc4d77550/0 - step 3] Received error: Peer failure detected in listener process: replica train_ddp_0:782b3df7-ac82-4d1c-9c95-ff05b4c2ddb6 has failed +NoneType: None + +NoneType: None +``` + +### With `TORCHFT_PROACTIVE_RECOVERY=0` + +Execute the following command on Replica Group ID 1: +```sh +export TORCHFT_PROACTIVE_RECOVERY=0 +``` + +You should observe that Replica Group ID 1 stalls for 30 seconds before resuming training. 
+ +```txt +[W522 11:47:47.959256423 ProcessGroupNCCL.cpp:1659] [PG ID 0 PG GUID Rank 1] Failed to check the "should dump" flag on TCPStore, (maybe TCPStore server has shut down too early), with error: Broken pipe +[W522 11:47:48.959394571 TCPStore.cpp:106] [c10d] sendBytes failed on SocketImpl(fd=77, addr=[::ffff:127.0.0.1]:39732, remote=[::ffff:127.0.0.1]:29600): Broken pipe +Exception raised from sendBytes at /pytorch/torch/csrc/distributed/c10d/Utils.hpp:653 (most recent call first): +frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string, std::allocator >) + 0x98 (0x7f9a28f795e8 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libc10.so) +frame #1: + 0x5ba8afe (0x7f9a6d09cafe in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #2: + 0x5baa358 (0x7f9a6d09e358 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #3: + 0x5babb3e (0x7f9a6d09fb3e in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #4: c10d::TCPStore::check(std::vector, std::allocator >, std::allocator, std::allocator > > > const&) + 0x298 (0x7f9a6d099198 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so) +frame #5: c10d::ProcessGroupNCCL::heartbeatMonitor() + 0x379 (0x7f9a2a2929a9 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so) +frame #6: + 0xdbbf4 (0x7f9a1a25bbf4 in /srv/apps/danny/miniconda3/envs/warren/torchtitan/bin/../lib/libstdc++.so.6) +frame #7: + 0x8609 (0x7f9a84eb4609 in /lib/x86_64-linux-gnu/libpthread.so.0) +frame #8: clone + 0x43 (0x7f9a84c7f353 in /lib/x86_64-linux-gnu/libc.so.6) +``` \ No newline at end of file diff --git a/train_ddp_proactive.py b/examples/ddp_proactive/train_ddp_proactive.py similarity index 88% rename from train_ddp_proactive.py rename to examples/ddp_proactive/train_ddp_proactive.py index 3d0002ce..d0aed08b 100644 --- a/train_ddp_proactive.py +++ b/examples/ddp_proactive/train_ddp_proactive.py @@ -10,10 +10,6 @@ import time from datetime import timedelta -REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) -os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4) -os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) - import torch import torch.nn.functional as F import torchvision @@ -32,6 +28,9 @@ ) from torchft.checkpointing.pg_transport import PGTransport +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils")) +from utils import get_cifar10_dataset + logging.basicConfig(level=logging.INFO) @@ -39,12 +38,17 @@ def main() -> None: REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + QUICK_RUN = bool(os.environ.get("QUICK_RUN", False)) transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) - trainset = torchvision.datasets.CIFAR10( - root="./cifar", train=True, download=True, transform=transform + trainset = get_cifar10_dataset( + root="./cifar", + train=True, + download=True, + transform=transform, + quick_run=QUICK_RUN, ) # This shards the training set across all ranks and replica groups. 
We manage @@ -79,10 +83,10 @@ def state_dict(): device = "cuda" if torch.cuda.is_available() else "cpu" pg = ( ProcessGroupNCCL( - timeout=timedelta(seconds=30), + timeout=timedelta(seconds=120), ) if torch.cuda.is_available() - else ProcessGroupGloo(timeout=timedelta(seconds=5)) + else ProcessGroupGloo(timeout=timedelta(seconds=120)) ) transport = PGTransport( @@ -167,8 +171,6 @@ def trace_handler(p): for i, (inputs, labels) in enumerate(trainloader): prof.step() - time.sleep(0.5) # Else each iteration runs too quickly - inputs = inputs.to(device) labels = labels.to(device) @@ -188,8 +190,10 @@ def trace_handler(p): exit(0) # If proactive recovery, then the surviving process will reconfigure # If not proactive recovery, then the surviving process will wait until timeout + print("Starting Hanging Allreduce") test_tensor = torch.tensor([1.0]).to(device) + print("manager.current_step()", manager.current_step(), "world_size", manager._participating_replica_world_size, "rank", manager._participating_replica_rank) manager.allreduce(test_tensor) # must be called at the end of the train loop @@ -208,11 +212,15 @@ def trace_handler(p): # they're shared across all groups and will load from existing replicas as # long as not every worker goes down. - if manager.current_step() >= 10000: + max_steps = 10 if QUICK_RUN else 10000 + if manager.current_step() >= max_steps: # complete training prof.stop() exit() + sleep_time = 0.001 if QUICK_RUN else 0.5 + time.sleep(sleep_time) + if __name__ == "__main__": main() diff --git a/examples/diloco/.torchxconfig b/examples/diloco/.torchxconfig new file mode 100644 index 00000000..b3d3c285 --- /dev/null +++ b/examples/diloco/.torchxconfig @@ -0,0 +1,7 @@ +[cli:run] +component=../../torchft/torchx.py:hsdp +scheduler=local_cwd + + +[component:../../torchft/torchx.py:hsdp] +script=train_diloco.py diff --git a/examples/diloco/README.md b/examples/diloco/README.md new file mode 100644 index 00000000..d5ca9cd3 --- /dev/null +++ b/examples/diloco/README.md @@ -0,0 +1,67 @@ +# DiLoCo Example + +This example demonstrates DiLoCo training. + +From the doc strings of the [DiLoCo class](../../torchft/local_sgd.py#L157): + +```txt +DiLoCo is a subclass of LocalSGD that overrides the synchronization +mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights). + +This algorithm requires a backup copy of the +weights. By default these are stored in CPU memory. If any error occurs +during the DiLoCo step, the step will be discarded and the model +parameters will reset back to the last time DiLoCo synchronized. + +DiLoCo paper: https://arxiv.org/pdf/2311.08105 +``` + +## Implementation Details + +As seen in the training script, DiLoCo defines two optimizers, one for the inner loop and one for the outer loop. The paper found that using Adam for the inner loop and SGD for the outer loop worked best over a range of configurations. + +A backup device is specified. This backup device is used to store a copy of the model parameters at the beginning of the inner optimization loop, so that the outer step can be applied to the model parameters before undergoing inner optimization, keeping the model parameters in sync. + +## How to Run + +These assumes that you are in the root directory of the torchft repository. + +1. Start the Lighthouse server: +```bash +RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 +``` + +2. 
Run using torchx: +```bash +cd examples/diloco +torchx run +``` + +3. Or manually run multiple replica groups: + +Shell 1 (first replica group): +```bash +export REPLICA_GROUP_ID=0 +export NUM_REPLICA_GROUPS=2 +CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 examples/diloco/train_diloco.py +``` + +Shell 2 (second replica group): +```bash +export REPLICA_GROUP_ID=1 +export NUM_REPLICA_GROUPS=2 +CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 examples/diloco/train_diloco.py +``` + +## Example Outputs + +You should see snippets like the following from the output: + +```sh +DiLoCo: Number of inner optimizer steps completed: 500 +DiLoCo: Number of outer optimizer steps completed: [5] loss = 7.605193614959717 +``` + +These tell you that the inner optimizer has completed 500 steps and the outer optimizer has completed 5 steps. + + diff --git a/examples/diloco/train_diloco.py b/examples/diloco/train_diloco.py new file mode 100644 index 00000000..5a1a48ca --- /dev/null +++ b/examples/diloco/train_diloco.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from datetime import timedelta + +import os +import sys +import time + +import torch +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import DistributedSampler, Manager, ProcessGroupGloo, ProcessGroupNCCL +from torchft.checkpointing.pg_transport import PGTransport +from torchft.local_sgd import DiLoCo + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils")) +from utils import get_cifar10_dataset + +logging.basicConfig(level=logging.INFO) + + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + QUICK_RUN = bool(os.environ.get("QUICK_RUN", False)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = get_cifar10_dataset( + root="./cifar", + train=True, + download=True, + transform=transform, + quick_run=QUICK_RUN, + ) + + # This shards the training set across all ranks and replica groups. We manage + # the dataloaders on a per replica group basis with the assumption that the + # majority of groups will be available so few batches will be dropped. + sampler = DistributedSampler( + trainset, + replica_group_id=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + group_rank=0, + # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. + num_replicas=1, + shuffle=True, + ) + + # This uses the torchdata StatefulDataLoader to be able to checkpoint and + # restore the per worker dataloader position. 
+ trainloader = StatefulDataLoader( + trainset, batch_size=64, num_workers=2, sampler=sampler + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=30), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. + target_size = 1_000_000_000 + self.useless = nn.Embedding(target_size // final_dim // 4, final_dim) + + self.classifier = nn.Sequential( + nn.Linear(16 * 5 * 5, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, final_dim), + ) + + def forward(self, x): + x = self.cnn(x) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = self.classifier(x) + x += self.useless.weight[0] + return x + + m = Net().to(device) + inner_optimizer = optim.AdamW( + m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) + ) + outer_optimizer = optim.SGD(m.parameters(), lr=0.7, momentum=0.9, nesterov=True) + criterion = nn.CrossEntropyLoss() + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + m.to(device) + diloco.original_parameters = state_dict["original_params"] + for name in diloco.original_parameters.keys(): + diloco.original_parameters[name] = diloco.original_parameters[name].to( + device + ) + inner_optimizer.load_state_dict(state_dict["inner_optim"]) + outer_optimizer.load_state_dict(state_dict["outer_optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "original_params": diloco.original_parameters, + "inner_optim": inner_optimizer.state_dict(), + "outer_optim": outer_optimizer.state_dict(), + } + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + use_async_quorum=False, + ) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"DiLoCo: Total number of parameters: {num_params}") + + sort_by_keyword = "self_" + device + "_time_total" + + def trace_handler(p): + output = p.key_averages().table( + sort_by=sort_by_keyword, + row_limit=100, + ) + print(output) + p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json") + + # You can use an epoch based training but with faults it's easier to use step + # based training. + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + ) + + prof.start() + + num_local_steps = 0 + sync_every = 5 if QUICK_RUN else 100 + with DiLoCo( + manager, + m, + inner_optimizer, + outer_optimizer, + backup_device=device, + sync_every=sync_every, + ) as diloco: + while True: + for i, (inputs, labels) in enumerate(trainloader): + prof.step() + + inputs = inputs.to(device) + labels = labels.to(device) + + # must be called at the beginning of each train loop + # Quorum computation is triggered here but only needed in the backwards pass. 
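+                # Note: in this DiLoCo example the quorum and pseudogradient synchronization are
+                # driven by the DiLoCo context manager every sync_every steps (see
+                # torchft/local_sgd.py); inner_optimizer is a plain torch optimizer and does not
+                # trigger them on its own.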
+ inner_optimizer.zero_grad() + + out = m(inputs) + loss = criterion(out, labels) + + # Gradient allreduce overlaps with the backwards pass. + loss.backward() + + # must be called at the end of the train loop + # This may not actually step the optimizer if an error occured during grad allreduce. + inner_optimizer.step() + num_local_steps += 1 + + if num_local_steps % sync_every == 0: + print( + f"DiLoCo: Number of inner optimizer steps completed: {num_local_steps}" + ) + print( + f"DiLoCo: Number of outer optimizer steps completed: {manager.current_step()} loss = {loss.item()}" + ) + + # TODO (by the user): periodically checkpoint model, optim, manager and dataloader + + # You typically want to checkpoint dataloader frequently (every step?) to + # avoid repeated batches as it's replica group specific. + + # Model, optim and manager checkpoints can be done more infrequently as + # they're shared across all groups and will load from existing replicas as + # long as not every worker goes down. + + max_steps = 3 if QUICK_RUN else 10000 + if manager.current_step() >= max_steps: + # complete training + prof.stop() + exit() + + sleep_time = 0.001 if QUICK_RUN else 0.01 + time.sleep(sleep_time) + + +if __name__ == "__main__": + main() diff --git a/examples/live_checkpoint_recovery/.torchxconfig b/examples/live_checkpoint_recovery/.torchxconfig new file mode 100644 index 00000000..ecef62cd --- /dev/null +++ b/examples/live_checkpoint_recovery/.torchxconfig @@ -0,0 +1,7 @@ +[cli:run] +component=../../torchft/torchx.py:hsdp +scheduler=local_cwd + + +[component:../../torchft/torchx.py:hsdp] +script=train_ddp_lcr.py diff --git a/examples/live_checkpoint_recovery/README.md b/examples/live_checkpoint_recovery/README.md new file mode 100644 index 00000000..6eafd61f --- /dev/null +++ b/examples/live_checkpoint_recovery/README.md @@ -0,0 +1,102 @@ +# Live Checkpoint Recovery Example + +This example demonstrates live checkpoint recovery in torchft using [process group based transport](../../torchft/checkpointing/pg_transport.py). + +The description of Live Checkpoint Recovery from [Fault Tolerance Poster](../../media/fault_tolerance_poster.pdf) + +```txt +# Live Checkpoint Recovery +- We’re developing a novel way to live recover from failures by asynchronously saving checkpoints and serving them directly to newly joined and recovering workers. +- On worker start, the checkpoint is transferred via HTTP from an existing healthy worker. +- The weights are copied from the GPU in a non-blocking way during the forward pass using a separate CUDA stream. +- We use leader election to identify live workers and exchange step information to recover from failures. +``` + +## Implementation Details + +The example is based on [train_ddp.py](../../train_ddp.py). We add the following logic before the construction of the `Manager` object: + +```python +if REPLICA_GROUP_ID == 0: + time.sleep(10) +``` + +This simulates a worker that joins the training in the middle. Because the worker that joins has a step value that is less than the step value of the workers currently in training, Live Checkpoint Recovery will be triggered. + + +## How to Run + +You can experiment with live checkpoint recovery mode by launching the following commands in two shells at around the same time. 
+ +On shell 1 (one replica group starts initial training): +```sh +export REPLICA_GROUP_ID=0 +export NUM_REPLICA_GROUPS=2 +export TORCHFT_PROACTIVE_RECOVERY=1 + +CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- examples/live_checkpoint_recovery/train_ddp_lcr.py +``` + +On shell 2 (a second replica group joins): +```sh +export REPLICA_GROUP_ID=1 +export NUM_REPLICA_GROUPS=2 +export TORCHFT_PROACTIVE_RECOVERY=1 + +CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- examples/live_checkpoint_recovery/train_ddp_lcr.py +``` + +## Example Outputs + +Below are snippets that you should see from the terminal outputs. Comments have been added to faciliate understanding. + +You should see snippets like the following from the output for Replica Group ID 0 that shows the steps of Live Checkpoint Recovery. + +```txt +2025-05-22T11:16:34.425 [INFO] [torchft::manager] - Creating LighthouseClient: establishing connection to http://localhost:29510 + +2025-05-22T11:16:34.817 [INFO] [torchft::manager] - [Replica train_ddp_0] Start quorum for group_rank 0 + +2025-05-22T11:16:34.817 [INFO] [torchft::manager] - [Replica train_ddp_0] All workers joined - starting quorum + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_0] got lighthouse quorum LighthouseQuorumResponse { quorum: Some(Quorum { quorum_id: 12, participants: [QuorumMember { replica_id: "train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471", address: "http://sz-k8s-master:33997", store_address: "127.0.0.1:29600", step: 0, world_size: 1, shrink_only: false, commit_failures: 0, data: "" }, QuorumMember { replica_id: "train_ddp_1:99b3452a-ef27-4523-8b4f-cdab6cbbc004", address: "http://sz-k8s-master:42813", store_address: "127.0.0.1:29601", step: 17, world_size: 1, shrink_only: false, commit_failures: 0, data: "" }], created: Some(Timestamp { seconds: 1747883794, nanos: 905534965 }) }) } # train_ddp_0 has step=0, as it just joined, whilst train_ddp_1 has step=17 + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_0] Finished quorum for group_rank 0 + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_0] healing is required step=0, max_step=17, recover_src_replica_rank=1 # This discrepancy in the steps triggers live checkpoint recovery. train_ddp_0 initiates an MPI recv call to recover_src_replica_rank + +INFO:torchft.manager:[train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471/0 - step 0] reconfiguring for quorum_id=12 store_prefixed_addr='127.0.0.1:29601/torchft/12/0' # Process group needs to be reconfigured after a new replica joins. The rendezvous store adds a new prefixed address to prevent data from previous rendezvous attempts affecting the current rendezvous + +INFO:torchft.manager:[train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471/0 - step 0] healing required, fetching checkpoint metadata from recover_src_manager_address='http://sz-k8s-master:42813' max_step=17 + +2025-05-22T11:16:34.910 [INFO] [torchft::manager] - Creating ManagerClient: establishing connection to http://sz-k8s-master:42813 +INFO:torchft.manager:[train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471/0 - step 0] fetching checkpoint from recover_src_replica_rank=1 with checkpoint_metadata='' # Checkpoint metadata is only needed if using a checkpoint server. 
Here, we directly use MPI send/recv + +[W522 11:16:35.452918836 reducer.cpp:1430] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) + +INFO:torchft.manager:[train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471/0 - step 17] applying pending state dict # State dict contains the optimizer state and model states that the recover_src_replica_rank sends over. + +INFO:torchft.manager:[train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471/0 - step 17] Loaded state dict. +``` + +And from the output for Replica Group ID 1 that shows the steps of Live Checkpoint Recovery. + +```txt +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_1] Start quorum for group_rank 0 + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_1] All workers joined - starting quorum + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_1] got lighthouse quorum LighthouseQuorumResponse { quorum: Some(Quorum { quorum_id: 12, participants: [QuorumMember { replica_id: "train_ddp_0:9f5b5624-112d-4b51-a995-c8a08b640471", address: "http://sz-k8s-master:33997", store_address: "127.0.0.1:29600", step: 0, world_size: 1, shrink_only: false, commit_failures: 0, data: "" }, QuorumMember { replica_id: "train_ddp_1:99b3452a-ef27-4523-8b4f-cdab6cbbc004", address: "http://sz-k8s-master:42813", store_address: "127.0.0.1:29601", step: 17, world_size: 1, shrink_only: false, commit_failures: 0, data: "" }], created: Some(Timestamp { seconds: 1747883794, nanos: 905534965 }) }) } # train_ddp_0 has step=0, as it just joined, whilst train_ddp_1 has step=17, this should be the same as the output for REPLICA Group ID 0 + +2025-05-22T11:16:34.905 [INFO] [torchft::manager] - [Replica train_ddp_1] Finished quorum for group_rank 0 + +INFO:torchft.manager:[train_ddp_1:99b3452a-ef27-4523-8b4f-cdab6cbbc004/0 - step 17] reconfiguring for quorum_id=12 store_prefixed_addr='127.0.0.1:29601/torchft/12/0' # Process group needs to be reconfigured after a new replica joins. The rendezvous store adds a new prefixed address to prevent data from previous rendezvous attempts affecting the current rendezvous + +INFO:torchft.manager:[train_ddp_1:99b3452a-ef27-4523-8b4f-cdab6cbbc004/0 - step 17] peers need recovery from us [0] +INFO:torchft.checkpointing.pg_transport:preparing state_dict took 0.0023363223299384117s # [0] is the replica_group_rank of the peer that needs recovery. Does an MPI send to that replica_group_rank + +/srv/apps/torchft/torchft/checkpointing/pg_transport.py:208: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1577.) 
+ buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device) +INFO:torchft.checkpointing.pg_transport:send pickle took 0.08277195505797863s # Time taken to pickle the tensor for sending +INFO:torchft.checkpointing.pg_transport:send tensors took 1.706640336662531s # Time taken until the tensor is sent +``` \ No newline at end of file diff --git a/examples/live_checkpoint_recovery/train_ddp_lcr.py b/examples/live_checkpoint_recovery/train_ddp_lcr.py new file mode 100644 index 00000000..1a5fa8d9 --- /dev/null +++ b/examples/live_checkpoint_recovery/train_ddp_lcr.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +import time +from datetime import timedelta + +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import ( + DistributedDataParallel, + DistributedSampler, + Manager, + Optimizer, + ProcessGroupGloo, + ProcessGroupNCCL, +) +from torchft.checkpointing.pg_transport import PGTransport + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils")) +from utils import get_cifar10_dataset + +logging.basicConfig(level=logging.INFO) + + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + QUICK_RUN = bool(os.environ.get("QUICK_RUN", False)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = get_cifar10_dataset( + root="./cifar", + train=True, + download=True, + transform=transform, + quick_run=QUICK_RUN, + ) + + # This shards the training set across all ranks and replica groups. We manage + # the dataloaders on a per replica group basis with the assumption that the + # majority of groups will be available so few batches will be dropped. + sampler = DistributedSampler( + trainset, + replica_group_id=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + group_rank=0, + # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. + num_replicas=1, + shuffle=True, + ) + + # This uses the torchdata StatefulDataLoader to be able to checkpoint and + # restore the per worker dataloader position. 
+ trainloader = StatefulDataLoader( + trainset, batch_size=64, num_workers=2, sampler=sampler + ) + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + + device = "cuda" if torch.cuda.is_available() else "cpu" + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=30), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) + + if REPLICA_GROUP_ID == 0: + # Reduce initial sleep for quick runs + initial_sleep = 3 if QUICK_RUN else 10 + time.sleep(initial_sleep) + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. + target_size = 1_000_000_000 + self.useless = nn.Embedding(target_size // final_dim // 4, final_dim) + + self.classifier = nn.Sequential( + nn.Linear(16 * 5 * 5, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, final_dim), + ) + + def forward(self, x): + x = self.cnn(x) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = self.classifier(x) + x += self.useless.weight[0] + return x + + m = Net().to(device) + m = DistributedDataParallel(manager, m) + optimizer = Optimizer(manager, optim.AdamW(m.parameters())) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + sort_by_keyword = "self_" + device + "_time_total" + + def trace_handler(p): + output = p.key_averages().table( + sort_by=sort_by_keyword, + row_limit=100, + ) + print(output) + p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json") + + # You can use an epoch based training but with faults it's easier to use step + # based training. + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=True, + ) + + prof.start() + + while True: + for i, (inputs, labels) in enumerate(trainloader): + prof.step() + + inputs = inputs.to(device) + labels = labels.to(device) + + # must be called at the beginning of each train loop + # Quorum computation is triggered here but only needed in the backwards pass. + optimizer.zero_grad() + + out = m(inputs) + criterion = nn.CrossEntropyLoss() + loss = criterion(out, labels) + + # Gradient allreduce overlaps with the backwards pass. + loss.backward() + + # must be called at the end of the train loop + # This may not actually step the optimizer if an error occured during grad allreduce. 
+ optimizer.step() + + if manager.current_step() % 100 == 0: + print(f"[{manager.current_step()}] loss = {loss.item()}") + + # TODO (by the user): periodically checkpoint model, optim, manager and dataloader + + # You typically want to checkpoint dataloader frequently (every step?) to + # avoid repeated batches as it's replica group specific. + + # Model, optim and manager checkpoints can be done more infrequently as + # they're shared across all groups and will load from existing replicas as + # long as not every worker goes down. + + max_steps = 10 if QUICK_RUN else 10000 + if manager.current_step() >= max_steps: + # complete training + prof.stop() + exit() + + sleep_time = 0.5 if QUICK_RUN else 0.5 + time.sleep(sleep_time) + + +if __name__ == "__main__": + main() diff --git a/examples/localsgd/.torchxconfig b/examples/localsgd/.torchxconfig new file mode 100644 index 00000000..2b8cbaa5 --- /dev/null +++ b/examples/localsgd/.torchxconfig @@ -0,0 +1,7 @@ +[cli:run] +component=../../torchft/torchx.py:hsdp +scheduler=local_cwd + + +[component:../../torchft/torchx.py:hsdp] +script=train_localsgd.py diff --git a/examples/localsgd/README.md b/examples/localsgd/README.md new file mode 100644 index 00000000..a0eb0686 --- /dev/null +++ b/examples/localsgd/README.md @@ -0,0 +1,73 @@ +# LocalSGD Example + +This example demonstrates localSGD training with torchft. + +From the docstrings of the LocalSGD class: + +```txt +LocalSGD is a context manager that +implements the algorithm described in https://arxiv.org/pdf/1805.09767 + +This will synchronize the model parameters periodically in a fault tolerant +way using a torchft Manager. The allreduce on the parameters will happen +every sync_every steps after the optimizer.step call. + +The torchft quorum is computed at the beginning of ``sync_every`` steps. If +any error occurs, or a worker fails between syncs, ``sync_every`` steps will be +discarded and a new quorum will be computed on the next step. + +If running in async mode, on a joining worker the first ``sync_every`` steps +will discarded as the model will be recovering during that period. When +using sync mode, the checkpoint will be restored prior to the first step. +``` + +## Implementation Details + +For localSGD training, there is no need to wrap the optimizer with the [OptimizerWrapper](../../torchft/optim.py#L24). This is because the LocalSGD context manager handles the calls to `manager.start_quorum()` and `manager.should_commit()` + +## How to Run + +These assumes that you are in the root directory of the torchft repository. + +1. Start the Lighthouse server: +```bash +RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 +``` + +2. Run using torchx: +```bash +cd examples/localsgd +torchx run +``` + +3. 
Or manually run multiple replica groups:
+
+Shell 1 (first replica group):
+```bash
+export REPLICA_GROUP_ID=0
+export NUM_REPLICA_GROUPS=2
+CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 examples/localsgd/train_localsgd.py
+```
+
+Shell 2 (second replica group):
+```bash
+export REPLICA_GROUP_ID=1
+export NUM_REPLICA_GROUPS=2
+CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 examples/localsgd/train_localsgd.py
+```
+
+## Interpreting Outputs
+
+You should see snippets like the following from the output:
+
+```sh
+LocalSGD: Number of local optimizer steps completed: 100
+```
+
+And
+
+```sh
+2025-05-22T10:52:30.575 [INFO] [torchft::manager] - [Replica train_ddp_0] got lighthouse quorum LighthouseQuorumResponse { quorum: Some(Quorum { quorum_id: 3, participants: [QuorumMember { replica_id: "train_ddp_0:4e6f882a-35c9-4f50-af99-ee1e0aae1310", address: "http://sz-k8s-master:45321", store_address: "127.0.0.1:29600", step: 8, world_size: 1, shrink_only: false, commit_failures: 0, data: "" }], created: Some(Timestamp { seconds: 1747882350, nanos: 574746342 }) }) }
+```
+
+The step above is the number of local optimizer steps completed by LocalSGD, and the `step` in `QuorumMember` below is the number of global sync steps completed.
\ No newline at end of file
diff --git a/examples/localsgd/train_localsgd.py b/examples/localsgd/train_localsgd.py
new file mode 100644
index 00000000..4392d5ce
--- /dev/null
+++ b/examples/localsgd/train_localsgd.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import sys
+from datetime import timedelta
+
+import time
+
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch import nn, optim
+from torch.distributed.elastic.multiprocessing.errors import record
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from torchft import (
+    DistributedDataParallel,
+    DistributedSampler,
+    Manager,
+    Optimizer,
+    ProcessGroupGloo,
+    ProcessGroupNCCL,
+)
+from torchft.checkpointing.pg_transport import PGTransport
+from torchft.local_sgd import LocalSGD
+
+sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils"))
+from utils import get_cifar10_dataset
+
+logging.basicConfig(level=logging.INFO)
+
+
+@record
+def main() -> None:
+    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
+    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
+    QUICK_RUN = bool(os.environ.get("QUICK_RUN", False))
+
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+    )
+    trainset = get_cifar10_dataset(
+        root="./cifar",
+        train=True,
+        download=True,
+        transform=transform,
+        quick_run=QUICK_RUN,
+    )
+
+    # This shards the training set across all ranks and replica groups. We manage
+    # the dataloaders on a per replica group basis with the assumption that the
+    # majority of groups will be available so few batches will be dropped.
+    sampler = DistributedSampler(
+        trainset,
+        replica_group_id=REPLICA_GROUP_ID,
+        num_replica_groups=NUM_REPLICA_GROUPS,
+        group_rank=0,
+        # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
+        num_replicas=1,
+        shuffle=True,
+    )
+
+    # This uses the torchdata StatefulDataLoader to be able to checkpoint and
+    # restore the per worker dataloader position.
+    trainloader = StatefulDataLoader(
+        trainset, batch_size=64, num_workers=2, sampler=sampler
+    )
+
+    def load_state_dict(state_dict):
+        m.load_state_dict(state_dict["model"])
+        optimizer.load_state_dict(state_dict["optim"])
+
+    def state_dict():
+        return {
+            "model": m.state_dict(),
+            "optim": optimizer.state_dict(),
+        }
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pg = (
+        ProcessGroupNCCL(
+            timeout=timedelta(seconds=30),
+        )
+        if torch.cuda.is_available()
+        else ProcessGroupGloo(timeout=timedelta(seconds=5))
+    )
+    print(f"LocalSGD: Process group: {pg}")
+
+    transport = PGTransport(
+        pg,
+        timeout=timedelta(seconds=10),
+        device=("cuda" if torch.cuda.is_available() else "cpu"),
+    )
+
+    manager = Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=load_state_dict,
+        state_dict=state_dict,
+        replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
+        timeout=timedelta(seconds=30),
+        checkpoint_transport=transport,
+    )
+
+    class Net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.cnn = nn.Sequential(
+                nn.Conv2d(3, 6, 5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, 2),
+                nn.Conv2d(6, 16, 5),
+                nn.ReLU(),
+                nn.MaxPool2d(2, 2),
+            )
+
+            final_dim = 10
+            # We add a useless 1GB intermediate layer so we spend more time in dist
+            # communication so injected failures are more likely to cause issues
+            # if they exist.
+            target_size = 1_000_000_000
+            self.useless = nn.Embedding(target_size // final_dim // 4, final_dim)
+
+            self.classifier = nn.Sequential(
+                nn.Linear(16 * 5 * 5, 120),
+                nn.ReLU(),
+                nn.Linear(120, 84),
+                nn.ReLU(),
+                nn.Linear(84, final_dim),
+            )
+
+        def forward(self, x):
+            x = self.cnn(x)
+            x = torch.flatten(x, 1)  # flatten all dimensions except batch
+            x = self.classifier(x)
+            x += self.useless.weight[0]
+            return x
+
+    m = Net().to(device)
+    optimizer = optim.Adam(m.parameters())
+    criterion = nn.CrossEntropyLoss()
+
+    print(m)
+    num_params = sum(p.numel() for p in m.parameters())
+    print(f"LocalSGD: Total number of parameters: {num_params}")
+
+    sort_by_keyword = "self_" + device + "_time_total"
+
+    def trace_handler(p):
+        output = p.key_averages().table(
+            sort_by=sort_by_keyword,
+            row_limit=100,
+        )
+        print(output)
+        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
+
+    # You can use epoch based training, but with faults it's easier to use step
+    # based training.
+    prof = torch.profiler.profile(
+        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
+        on_trace_ready=trace_handler,
+        record_shapes=True,
+        profile_memory=True,
+    )
+
+    prof.start()
+
+    num_local_steps = 0
+    sync_every = 5 if QUICK_RUN else 100
+    with LocalSGD(manager, m, optimizer, sync_every=sync_every):
+        while True:
+            for i, (inputs, labels) in enumerate(trainloader):
+                prof.step()
+
+                inputs = inputs.to(device)
+                labels = labels.to(device)
+
+                # Clear gradients for this step. Unlike the DDP example, quorum
+                # handling is done by the LocalSGD context manager rather than by
+                # a wrapped optimizer.
+                optimizer.zero_grad()
+
+                out = m(inputs)
+                loss = criterion(out, labels)
+
+                # There is no per-step gradient allreduce here; LocalSGD syncs the
+                # model parameters periodically instead.
+                loss.backward()
+
+                # must be called at the end of the train loop
+                # LocalSGD averages the parameters across replica groups every
+                # `sync_every` steps after this call; per its docstring, if an error
+                # occurred or a worker failed since the last sync, those steps are
+                # discarded and a new quorum is computed on the next step.
+ optimizer.step() + num_local_steps += 1 + + if manager.current_step() % 100 == 0: + print(f"LocalSGD: [{manager.current_step()}] loss = {loss.item()}") + + if num_local_steps % 100 == 0: + print( + f"LocalSGD: Number of local optimizer steps completed: {num_local_steps}" + ) + + # TODO (by the user): periodically checkpoint model, optim, manager and dataloader + + # You typically want to checkpoint dataloader frequently (every step?) to + # avoid repeated batches as it's replica group specific. + + # Model, optim and manager checkpoints can be done more infrequently as + # they're shared across all groups and will load from existing replicas as + # long as not every worker goes down. + + max_steps = 3 if QUICK_RUN else 10000 + if manager.current_step() >= max_steps: + # complete training + prof.stop() + exit() + + sleep_time = 0.001 if QUICK_RUN else 0.01 + time.sleep(sleep_time) + + +if __name__ == "__main__": + main() diff --git a/examples/test_examples.py b/examples/test_examples.py new file mode 100644 index 00000000..fa5db9d6 --- /dev/null +++ b/examples/test_examples.py @@ -0,0 +1,200 @@ +import os +import subprocess +import tempfile +import time +import socket +from pathlib import Path + +import pytest + + +def get_example_directories(): + """Get all example directories that have .torchxconfig files.""" + current_dir = Path.cwd() + if current_dir.name == "examples": + examples_root = current_dir + else: + examples_root = Path("examples") + + return [ + d for d in examples_root.iterdir() + if d.is_dir() and (d / ".torchxconfig").exists() + ] + +@pytest.fixture(scope="session") +def lighthouse_server(): + """Start a lighthouse server for the tests.""" + default_port = 29510 + lighthouse_url = f"http://localhost:{default_port}" + + # Kill any existing process using the lighthouse port + try: + result = subprocess.run(["lsof", "-ti", f":{default_port}"], capture_output=True, text=True) + if result.stdout.strip(): + pids = result.stdout.strip().split('\n') + for pid in pids: + subprocess.run(["kill", pid], capture_output=True) + time.sleep(1) # Give time for cleanup + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + lighthouse_proc = subprocess.Popen( + ["torchft_lighthouse", + "--min_replicas", "1", + "--quorum_tick_ms", "100", + "--join_timeout_ms", "10000", + "--bind", f"[::]:{default_port}" + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + time.sleep(1) + + yield lighthouse_url + + lighthouse_proc.terminate() + try: + lighthouse_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + lighthouse_proc.kill() + lighthouse_proc.wait() + +class TestTorchXExamples: + """Test that torchx run works for all examples.""" + + @pytest.mark.parametrize("example_dir", get_example_directories()) + def test_training_script_exists(self, example_dir): + """Test that the training script referenced in config exists.""" + config_path = example_dir / ".torchxconfig" + + import configparser + config = configparser.ConfigParser() + config.read(config_path) + + # Find component sections + for section_name in config.sections(): + if section_name.startswith("component:"): + section = config[section_name] + if "script" in section: + script_path = example_dir / section["script"] + assert script_path.exists(), f"Missing training script {script_path} referenced in {config_path}" + + @pytest.mark.parametrize("example_dir", get_example_directories()) + def test_torchx_config_valid(self, example_dir): + """Test that .torchxconfig files are valid.""" + config_path = 
example_dir / ".torchxconfig" + assert config_path.exists(), f"Missing .torchxconfig in {example_dir}" + + # Try to parse the config + import configparser + config = configparser.ConfigParser() + config.read(config_path) + + # Should have cli:run section + assert "cli:run" in config.sections(), f"Missing [cli:run] section in {config_path}" + + # Should have component reference + cli_section = config["cli:run"] + assert "component" in cli_section, f"Missing component in [cli:run] section of {config_path}" + + component_ref = cli_section["component"] + assert ":" in component_ref, f"Invalid component reference format in {config_path}: {component_ref}" + + @pytest.mark.parametrize("example_dir", get_example_directories()) + def test_torchx_run_quick(self, example_dir, lighthouse_server): + """Test that torchx run works with QUICK_RUN for each example.""" + + timeout_seconds = 120 + self._test_example(example_dir, lighthouse_server, timeout_seconds) + + def _test_example(self, example_dir, lighthouse_server, timeout_seconds=120): + """Test regular examples (non-ddp_proactive).""" + print(f"\n=== Testing {example_dir.name} ===") + + # Set environment for quick run with memory management + + env = os.environ.copy() + env["QUICK_RUN"] = "1" + env["TORCHFT_LIGHTHOUSE"] = lighthouse_server + # Enable PyTorch memory management features to prevent fragmentation + env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + if example_dir.name == "ddp_proactive": + env["TORCHFT_PROACTIVE_RECOVERY"] = "1" + else: + env["TORCHFT_PROACTIVE_RECOVERY"] = "0" + + cmd = ["torchx", "run"] + + # Start the process with a new process group for better cleanup control + try: + result = subprocess.run( + cmd, + cwd=example_dir, + env=env, + timeout=timeout_seconds, # Increased timeout to allow for optimized sync frequencies + capture_output=False, # Let training logs print to console + text=True, + preexec_fn=os.setsid, # Create new process group for easier cleanup + ) + + # Training should complete successfully + # When capture_output=False, stdout/stderr are None, so we just check return code + assert result.returncode == 0, f"torchx run failed for {example_dir} with return code {result.returncode}" + + print("-" * 30) + print(f"✅ {example_dir.name} completed successfully! 
") + print("-" * 30) + + except subprocess.TimeoutExpired: + # If timeout occurs, try to clean up any remaining processes + print(f"Test timed out for {example_dir.name}, cleaning up processes...") + self._cleanup_training_processes() + raise + except Exception as e: + # On any other error, also try cleanup + print(f"Test failed for {example_dir.name}: {e}, cleaning up processes...") + self._cleanup_training_processes() + raise + finally: + # Always attempt cleanup after each test to prevent accumulation + time.sleep(1) # Give processes time to naturally terminate + self._cleanup_training_processes() + + def _cleanup_training_processes(self): + """Clean up any remaining training processes to prevent GPU memory accumulation.""" + try: + # Find and kill any remaining training processes + result = subprocess.run( + ["ps", "aux"], capture_output=True, text=True, timeout=10 + ) + if result.returncode == 0: + lines = result.stdout.split('\n') + pids_to_kill = [] + + for line in lines: + # Look for training scripts and torchrun processes + if any(pattern in line for pattern in ['train_diloco.py', 'train_localsgd.py', 'train_ddp_proactive.py', 'torchrun']): + parts = line.split() + if len(parts) > 1: + try: + pid = int(parts[1]) + pids_to_kill.append(pid) + except (ValueError, IndexError): + continue + + # Kill the processes + for pid in pids_to_kill: + try: + subprocess.run(["kill", "-9", str(pid)], capture_output=True, timeout=5) + except (subprocess.TimeoutExpired, subprocess.CalledProcessError): + pass # Process might already be dead + + if pids_to_kill: + print(f"🧹 Cleaned up {len(pids_to_kill)} stale training processes") + + except (subprocess.TimeoutExpired, subprocess.CalledProcessError, FileNotFoundError): + # If cleanup fails, log but don't crash the test + print("Process cleanup encountered issues, but continuing...") + pass \ No newline at end of file diff --git a/examples/utils/utils.py b/examples/utils/utils.py new file mode 100644 index 00000000..76581013 --- /dev/null +++ b/examples/utils/utils.py @@ -0,0 +1,143 @@ +""" +Utility functions and classes for TorchFT examples. +""" + +import torch +import torch.utils.data as data +from PIL import Image +import numpy as np + + +class SyntheticCIFAR10(data.Dataset): + """ + Synthetic CIFAR10-like dataset for testing purposes. + + This dataset generates synthetic 32x32 RGB images with 10 classes, + mimicking the structure and interface of torchvision.datasets.CIFAR10 + without requiring network downloads. 
+ + Args: + root (str): Not used, kept for interface compatibility with CIFAR10 + train (bool): Whether to generate training or test data + download (bool): Not used, kept for interface compatibility + transform: Optional transform to be applied on samples + target_transform: Optional transform to be applied on targets + size (int): Number of samples to generate (default: 1000) + """ + + def __init__(self, root=None, train=True, download=None, transform=None, + target_transform=None, size=1000): + self.train = train + self.transform = transform + self.target_transform = target_transform + self.size = size + + # Use fixed seed for deterministic generation + np.random.seed(42 if train else 123) + torch.manual_seed(42 if train else 123) + + # Generate synthetic data + self.data = self._generate_images() + self.targets = self._generate_labels() + + # Class names matching CIFAR10 + self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', + 'dog', 'frog', 'horse', 'ship', 'truck'] + + def _generate_images(self): + """Generate synthetic 32x32x3 images.""" + # Create synthetic images with some structure (not just random noise) + images = [] + for i in range(self.size): + # Create base pattern + img = np.zeros((32, 32, 3), dtype=np.uint8) + + # Add some structured patterns based on index + pattern_type = i % 4 + if pattern_type == 0: + # Gradient pattern + for row in range(32): + img[row, :, 0] = (row * 8) % 256 + img[:, row, 1] = (row * 8) % 256 + elif pattern_type == 1: + # Checkerboard pattern + for row in range(32): + for col in range(32): + if (row // 4 + col // 4) % 2: + img[row, col] = [255, 255, 255] + elif pattern_type == 2: + # Circular pattern + center = 16 + for row in range(32): + for col in range(32): + dist = ((row - center) ** 2 + (col - center) ** 2) ** 0.5 + intensity = int((np.sin(dist / 3) + 1) * 127) + img[row, col] = [intensity, intensity // 2, intensity] + else: + # Random noise pattern + img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) + + images.append(img) + + return np.array(images) + + def _generate_labels(self): + """Generate synthetic labels (0-9).""" + # Distribute labels evenly across classes + labels = [] + for i in range(self.size): + labels.append(i % 10) + return labels + + def __getitem__(self, index): + """Get a sample from the dataset.""" + img, target = self.data[index], self.targets[index] + + # Convert numpy array to PIL Image (matching CIFAR10 behavior) + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + """Return the size of the dataset.""" + return self.size + + +def get_cifar10_dataset(root="./cifar", train=True, download=True, transform=None, quick_run=False): + """ + Get CIFAR10 dataset - either real or synthetic based on quick_run flag. 
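+
+    Example (illustrative; mirrors the call in train_localsgd.py):
+
+        trainset = get_cifar10_dataset(
+            root="./cifar", train=True, download=True,
+            transform=None, quick_run=True,
+        )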
+ + Args: + root (str): Root directory for real CIFAR10 data + train (bool): Whether to get training or test data + download (bool): Whether to download real CIFAR10 data + transform: Transform to apply to the data + quick_run (bool): If True, use synthetic data; if False, use real CIFAR10 + + Returns: + Dataset: Either SyntheticCIFAR10 or torchvision.datasets.CIFAR10 + """ + if quick_run: + print("Using synthetic CIFAR10 dataset") + return SyntheticCIFAR10( + root=root, + train=train, + download=download, + transform=transform, + size=1000 # Smaller dataset for quick testing + ) + else: + print("Using real CIFAR10 dataset") + import torchvision.datasets + return torchvision.datasets.CIFAR10( + root=root, + train=train, + download=download, + transform=transform + ) diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index 5748def1..90b6db9b 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -222,7 +222,7 @@ def __init__( p = extract_local_tensor(p.data) backup_device = self._backup_device or torch.device("cpu") - t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device) + t: torch.Tensor = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device) if ( self._pin_memory and t.device == torch.device("cpu") diff --git a/torchft/manager.py b/torchft/manager.py index ae48cfd4..21ce7bd6 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -223,6 +223,7 @@ def __init__( ) if lighthouse_addr is not None and self._proactive_recovery: + print("In proactive recovery mode: ", self._proactive_recovery) ctx = multiprocessing.get_context("spawn") error_local, error_remote = ctx.Pipe() self._error_pipe = _MonitoredPipe(error_local)