From a0b1e55682a346c9522144557baf49278e6359e2 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 25 Oct 2024 18:05:28 +0000 Subject: [PATCH 1/2] Update the huggingface_hub version in setup script so that pip install can work. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 114fbbda..6cc926ac 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.13.2", + "huggingface-hub==0.24.7", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", From 07563ba568714b671c711f6f03196e05fe9cd081 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Fri, 25 Oct 2024 18:12:16 +0000 Subject: [PATCH 2/2] Fix the GPU multi host segmentation fault on XPK. --- src/maxdiffusion/max_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 6a5eef2c..438a0740 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -397,8 +397,7 @@ def setup_initial_state( if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - init_train_state_partial = functools.partial( - init_train_state, + state = init_train_state( model=model, tx=tx, weights_init_fn=weights_init_fn, @@ -407,8 +406,6 @@ def setup_initial_state( eval_only=False, ) - state = jax.jit(init_train_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)() - state = unbox_logicallypartioned_trainstate(state) return state, state_mesh_shardings