From 3f93a8965c6eb68ea51db8f9d139d084ecb9a473 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Wed, 27 Nov 2024 10:09:21 -0800 Subject: [PATCH] Fix local restore by re-mapping device ids directly instead of inferring them from how process indexes changed across restarts, which relied on some false assumptions. PiperOrigin-RevId: 700737164 --- MaxText/max_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 550d22bf..c74297ef 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -325,6 +325,7 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): jax.distributed.initialize() ocp.multihost.initialize_runtime_to_distributed_ids() + ocp.multihost.initialize_distributed_to_device_ids() def _retrieve_jax_init_info(raw_keys):