diff --git a/setup.py b/setup.py index 114fbbd..6cc926a 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.13.2", + "huggingface-hub==0.24.7", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 6a5eef2..438a074 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -397,8 +397,7 @@ def setup_initial_state( if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - init_train_state_partial = functools.partial( - init_train_state, + state = init_train_state( model=model, tx=tx, weights_init_fn=weights_init_fn, @@ -407,8 +406,6 @@ def setup_initial_state( eval_only=False, ) - state = jax.jit(init_train_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)() - state = unbox_logicallypartioned_trainstate(state) return state, state_mesh_shardings