Commit

Fix local restore by re-mapping device ids directly instead of inferring them from how process indexes changed across restarts with some false assumptions.

PiperOrigin-RevId: 700737164
cpgaffney1 authored and maxtext authors committed Nov 27, 2024
1 parent 392a23e commit 3f93a89
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions MaxText/max_utils.py
@@ -325,6 +325,7 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys):
jax.distributed.initialize()

ocp.multihost.initialize_runtime_to_distributed_ids()
+ocp.multihost.initialize_distributed_to_device_ids()


def _retrieve_jax_init_info(raw_keys):
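
The idea in the commit message, re-mapping device ids directly rather than inferring them from process index changes, can be illustrated with a small toy sketch. This is not Orbax's implementation: `build_device_id_remap`, the host keys, and the id layouts are all hypothetical, and the sketch assumes each host can be identified by a key that stays stable across restarts.

```python
from typing import Dict, List


def build_device_id_remap(
    before: Dict[int, List[int]],
    after: Dict[int, List[int]],
) -> Dict[int, int]:
    """Map each pre-restart device id to its post-restart device id.

    `before` and `after` map a stable host key (hypothetical) to the ordered
    device ids that host owned before and after the restart.
    """
    remap: Dict[int, int] = {}
    for host_key, old_ids in before.items():
        if host_key not in after:
            raise ValueError(f"host {host_key} is missing after restart")
        new_ids = after[host_key]
        if len(old_ids) != len(new_ids):
            raise ValueError(f"host {host_key} changed its device count")
        # Pair ids position by position; no assumption is made about how
        # process indexes shifted between the two runs.
        remap.update(zip(old_ids, new_ids))
    return remap


# Toy layout: two hosts whose device ids were swapped by the restart.
before = {0: [0, 1], 1: [2, 3]}
after = {0: [2, 3], 1: [0, 1]}
remap = build_device_id_remap(before, after)
```

Because the mapping is built from recorded ids on both sides, it does not depend on any particular pattern in how process indexes were reassigned.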