From 15d4fa0613923fdf8b1240b8766ead000f1bb8e8 Mon Sep 17 00:00:00 2001
From: Dave Lei
Date: Sat, 8 Mar 2025 17:03:57 -0800
Subject: [PATCH 1/6] add option to skip initial sync in Manager

Summary:
We currently always heal on step 0 to avoid synchronization issues. We want an
option to skip this sync for users who set the PyTorch seed so that all ranks
are initialized with the same values.

This diff adds an init_sync boolean flag that is passed from the manager client
in Python to the manager service in Rust. The manager service then skips the
initial sync when init_sync is false.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 proto/torchft.proto  | 1 +
 src/manager.rs       | 5 ++++-
 torchft/_torchft.pyi | 1 +
 torchft/manager.py   | 5 ++++-
 4 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/proto/torchft.proto b/proto/torchft.proto
index 7e248e5..7ca48cf 100644
--- a/proto/torchft.proto
+++ b/proto/torchft.proto
@@ -74,6 +74,7 @@ message ManagerQuorumRequest {
   int64 step = 2;
   string checkpoint_metadata = 3;
   bool shrink_only = 4;
+  optional bool init_sync = 5;
 }
 
 message ManagerQuorumResponse {
diff --git a/src/manager.rs b/src/manager.rs
index bd14783..a91bfbe 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -430,7 +430,10 @@ fn compute_quorum_results(
         .iter()
         .enumerate()
         .filter_map(|(i, p)| {
-            if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
+            if init_sync
+                || p.step != max_step
+                || max_step == 0 && primary.replica_id != p.replica_id
+            {
                 Some(i)
             } else {
                 None
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index b4afde6..b8aecd0 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -10,6 +10,7 @@ class ManagerClient:
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
+        init_sync: Optional[bool] = False,
     ) -> QuorumResult: ...
     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
diff --git a/torchft/manager.py b/torchft/manager.py
index 0da48d0..3120818 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -372,6 +372,7 @@ def start_quorum(
         allow_heal: bool = True,
         shrink_only: bool = False,
         timeout: Optional[timedelta] = None,
+        init_sync: bool = False,
     ) -> None:
         """
         .. note::
@@ -407,6 +408,7 @@ def start_quorum(
             allow_heal=allow_heal,
             shrink_only=shrink_only,
             quorum_timeout=timeout or self._quorum_timeout,
+            init_sync=True,
         )
         if not self._use_async_quorum:
             self.wait_quorum()
@@ -431,7 +433,7 @@ def wait_quorum(self) -> None:
         self._quorum_future.result()
 
     def _async_quorum(
-        self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta
+        self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta, init_sync: bool
     ) -> None:
         quorum = self._client._quorum(
             rank=self._rank,
@@ -439,6 +441,7 @@ def _async_quorum(
             checkpoint_metadata=self._checkpoint_transport.metadata(),
             shrink_only=shrink_only,
             timeout=quorum_timeout,
+            init_sync=init_sync,
         )
 
         quorum_id = quorum.quorum_id

From f803f40f43aba6934b1af29a1004699ade69b615 Mon Sep 17 00:00:00 2001
From: Dave Lei
Date: Sat, 8 Mar 2025 22:14:10 -0700
Subject: [PATCH 2/6] Fixed some rust files and tests

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 src/lib.rs     |  1 +
 src/manager.rs | 24 ++++++++++++++++--------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 2d4de57..9725fc4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -154,6 +154,7 @@ impl ManagerClient {
             step: step,
             checkpoint_metadata: checkpoint_metadata,
             shrink_only: shrink_only,
+            init_sync: Some(false),
         });
 
         // This timeout is processed on the server side so we also enable
diff --git a/src/manager.rs b/src/manager.rs
index a91bfbe..3004433 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -285,7 +285,12 @@ impl ManagerService for Arc<Manager> {
 
         info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);
 
-        let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
+        let reply = compute_quorum_results(
+            &self.replica_id,
+            rank,
+            &quorum,
+            req.init_sync.unwrap_or_default(),
+        )?;
 
         Ok(Response::new(reply))
     }
@@ -381,6 +386,7 @@ fn compute_quorum_results(
     replica_id: &str,
     rank: i64,
     quorum: &Quorum,
+    init_sync: bool,
 ) -> Result<ManagerQuorumResponse, Status> {
     let mut participants = quorum.participants.clone();
     participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
@@ -607,6 +613,7 @@ mod tests {
             step: 123,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: Some(false),
         });
         request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
@@ -666,6 +673,7 @@ mod tests {
             step: 0,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: Some(false),
         });
         request.set_timeout(Duration::from_secs(10));
 
@@ -771,13 +779,13 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, false)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![1]);
 
-        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
         assert!(results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, Some(0));
@@ -785,7 +793,7 @@ mod tests {
 
         // rank 1 assignments should be offset from rank 0 above and the primary
 
-        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, false)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
@@ -845,21 +853,21 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
let results = compute_quorum_results("replica_0", 0, &quorum, false)?; assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, Some(1)); assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum)?; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![0, 4]); - let results = compute_quorum_results("replica_3", 0, &quorum)?; + let results = compute_quorum_results("replica_3", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); assert_eq!(results.recover_src_rank, None); @@ -867,7 +875,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above - let results = compute_quorum_results("replica_1", 1, &quorum)?; + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); From 472b8fc048e57883f57a1549fdca44fa7ca7c7f7 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Sun, 9 Mar 2025 18:04:31 -0700 Subject: [PATCH 3/6] Pass init_sync parameter in manager.py Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchft/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchft/manager.py b/torchft/manager.py index 3120818..9839383 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -408,7 +408,7 @@ def start_quorum( allow_heal=allow_heal, shrink_only=shrink_only, quorum_timeout=timeout or self._quorum_timeout, - init_sync=True, + init_sync=init_sync, ) if not self._use_async_quorum: self.wait_quorum() From e3611a1cb2264591090e8016ba14263ddb3a9c94 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Mon, 10 Mar 2025 07:42:45 -0700 Subject: [PATCH 4/6] Fix init_sync logic Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- proto/torchft.proto | 2 +- src/lib.rs | 2 +- src/manager.rs | 61 +++++++++++++++++++++----------------------- torchft/_torchft.pyi | 2 +- torchft/manager.py | 2 +- 5 files changed, 33 insertions(+), 36 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index 7ca48cf..5213c17 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -74,7 +74,7 @@ message ManagerQuorumRequest { int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; - optional bool init_sync = 5; + bool init_sync = 5; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 9725fc4..cf98c99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -154,7 +154,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, - init_sync: Some(false), + init_sync: true, }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index 3004433..97623f7 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -285,12 +285,7 @@ impl ManagerService for Arc { info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); - let reply = compute_quorum_results( - &self.replica_id, - rank, - &quorum, - req.init_sync.unwrap_or_default(), - )?; + let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?; Ok(Response::new(reply)) } @@ -429,23 +424,25 @@ fn compute_quorum_results( 
     // Compute recovery assignments
 
-    // Nodes are recovering if:
-    // 1. not at the max step
-    // 2. max_step == 0 and not the primary replica
-    let all_recover_dst_ranks: Vec<usize> = participants
-        .iter()
-        .enumerate()
-        .filter_map(|(i, p)| {
-            if init_sync
-                || p.step != max_step
-                || max_step == 0 && primary.replica_id != p.replica_id
-            {
-                Some(i)
-            } else {
-                None
-            }
-        })
-        .collect();
+    let all_recover_dst_ranks = if init_sync {
+        // Nodes are recovering if
+        // 1. not at the max step
+        // 2. max_step == 0 and not the primary replica
+        participants
+            .iter()
+            .enumerate()
+            .filter_map(|(i, p)| {
+                if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
+                    Some(i)
+                } else {
+                    None
+                }
+            })
+            .collect()
+    } else {
+        Vec::<usize>::new()
+    };
+
     let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
     let up_to_date_ranks: Vec<usize> = participants
         .iter()
@@ -613,7 +610,7 @@ mod tests {
             step: 123,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
-            init_sync: Some(false),
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
@@ -673,7 +670,7 @@ mod tests {
             step: 0,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
-            init_sync: Some(false),
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
 
@@ -779,13 +776,13 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum, false)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![1]);
 
-        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, Some(0));
@@ -793,7 +790,7 @@ mod tests {
 
         // rank 1 assignments should be offset from rank 0 above and the primary
 
-        let results = compute_quorum_results("replica_1", 1, &quorum, false)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
@@ -853,21 +850,21 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum, false)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, Some(1));
         assert!(results.recover_dst_ranks.is_empty());
 
-        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.recover_src_manager_address, "".to_string());
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![0, 4]);
 
-        let results = compute_quorum_results("replica_3", 0, &quorum, false)?;
+        let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 3);
         assert_eq!(results.recover_src_rank, None);
@@ -875,7 +872,7 @@ mod tests {
 
         // rank 1 assignments should be offset from rank 0 above
 
-        let results = compute_quorum_results("replica_1", 1, &quorum, false)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index b8aecd0..9f182a3 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -10,7 +10,7 @@ class ManagerClient:
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
-        init_sync: Optional[bool] = False,
+        init_sync: bool = True,
     ) -> QuorumResult: ...
     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
diff --git a/torchft/manager.py b/torchft/manager.py
index 9839383..50587ee 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -372,7 +372,7 @@ def start_quorum(
         allow_heal: bool = True,
         shrink_only: bool = False,
         timeout: Optional[timedelta] = None,
-        init_sync: bool = False,
+        init_sync: bool = True,
     ) -> None:
         """
         .. note::

From 51c43dbaa5a9a90a7b56b55162ae0f48a6414e0b Mon Sep 17 00:00:00 2001
From: Dave Lei
Date: Tue, 11 Mar 2025 09:23:21 -0700
Subject: [PATCH 5/6] Add tests for manager.rs

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 src/manager.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/src/manager.rs b/src/manager.rs
index 97623f7..d94c9e4 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -880,4 +880,86 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
+        let quorum = Quorum {
+            quorum_id: 1,
+            participants: vec![
+                QuorumMember {
+                    replica_id: "replica_0".to_string(),
+                    address: "addr_0".to_string(),
+                    store_address: "store_addr_0".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_1".to_string(),
+                    address: "addr_1".to_string(),
+                    store_address: "store_addr_1".to_string(),
+                    step: 1,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_2".to_string(),
+                    address: "addr_2".to_string(),
+                    store_address: "store_addr_2".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_3".to_string(),
+                    address: "addr_3".to_string(),
+                    store_address: "store_addr_3".to_string(),
+                    step: 1,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+                QuorumMember {
+                    replica_id: "replica_4".to_string(),
+                    address: "addr_4".to_string(),
+                    store_address: "store_addr_4".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                },
+            ],
+            created: None,
+        };
+
+        // rank 0
+
+        let results = compute_quorum_results("replica_0", 0, &quorum, false)?;
+        assert!(!results.heal);
+        assert_eq!(results.recover_src_manager_address, "".to_string());
+        assert_eq!(results.replica_rank, 0);
+        assert_eq!(results.recover_src_rank, None);
+        assert!(results.recover_dst_ranks.is_empty());
+
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        assert!(!results.heal);
+        assert_eq!(results.recover_src_manager_address, "".to_string());
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, None);
+        assert!(results.recover_dst_ranks.is_empty());
+
+        let results = compute_quorum_results("replica_3", 0, &quorum, false)?;
+        assert!(!results.heal);
+        assert_eq!(results.recover_src_manager_address, "".to_string());
+        assert_eq!(results.replica_rank, 3);
+        assert_eq!(results.recover_src_rank, None);
+        assert!(results.recover_dst_ranks.is_empty());
+
+        let results = compute_quorum_results("replica_1", 1, &quorum, false)?;
+        assert!(!results.heal);
+        assert_eq!(results.recover_src_manager_address, "".to_string());
+        assert_eq!(results.replica_rank, 1);
+        assert_eq!(results.recover_src_rank, None);
+        assert!(results.recover_dst_ranks.is_empty());
+
+        Ok(())
+    }
 }

From 3dbcb78f527270e1e6f755c9fb2e99e8fe787a81 Mon Sep 17 00:00:00 2001
From: Dave Lei
Date: Wed, 12 Mar 2025 11:48:37 -0700
Subject: [PATCH 6/6] Added skip init_sync tests to python client

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchft/manager_test.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 05793e1..f5fd865 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -579,3 +579,35 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
             client_mock().should_commit.call_args.kwargs["timeout"],
             timedelta(seconds=23),
         )
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(use_async_quorum=False)
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], True
+        )
+
+        manager.start_quorum(init_sync=True)
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], True
+        )
+
+        manager.start_quorum(init_sync=False)
+        self.assertEqual(
+            client_mock()._quorum.call_args.kwargs["init_sync"], False
+        )
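
Usage note (editorial sketch, not part of the patches above): with this series applied,
the step 0 heal stays on by default and can be skipped per call to start_quorum. The
snippet below only illustrates the intended call pattern; the import paths, Manager
constructor arguments, seed value, and model are taken from the torchft README as
assumptions and will need adjusting to a real setup (lighthouse address, replica group
id, and so on).

    import torch
    from torchft import Manager, ProcessGroupGloo

    # Skipping the initial sync is only safe when every replica seeds PyTorch
    # identically, so all ranks start from the same parameter values.
    torch.manual_seed(42)
    model = torch.nn.Linear(8, 8)

    # Constructor arguments abbreviated; see the torchft README for a full example.
    manager = Manager(
        pg=ProcessGroupGloo(),
        min_replica_size=1,
        load_state_dict=model.load_state_dict,
        state_dict=model.state_dict,
    )

    # init_sync=True (the default) keeps the old behavior: heal/sync at step 0.
    # init_sync=False skips the initial recovery for identically seeded replicas.
    manager.start_quorum(init_sync=False)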