diff --git a/torchft/manager.py b/torchft/manager.py index 668189c..967365e 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -151,6 +151,7 @@ def __init__( self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout self._world_size_mode = world_size_mode + self._first_quorum = True store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) @@ -404,8 +405,11 @@ def start_quorum( shrink_only=shrink_only, quorum_timeout=timeout or self._quorum_timeout, ) - if not self._use_async_quorum: + # If this is the first quorum sync, we need to wait for the result. + # Otherwise, we may mistakenly perform the forward with incorrect weights. + if not self._use_async_quorum or self._first_quorum: self.wait_quorum() + self._first_quorum = False if self._healing: # eagerly apply pending state_dict so we can run the forwards pass