diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py
index f5e7af2..5eca56b 100644
--- a/jetstream_pt/cli.py
+++ b/jetstream_pt/cli.py
@@ -40,7 +40,6 @@
 flags.DEFINE_bool("enable_model_warmup", False, "enable model warmup")
 
 
-
 def shard_weights(env, weights, weight_shardings):
   """Shard weights according to weight_shardings"""
   sharded = {}
diff --git a/jetstream_pt/fetch_models.py b/jetstream_pt/fetch_models.py
index 8c3b0a9..3bee8b1 100644
--- a/jetstream_pt/fetch_models.py
+++ b/jetstream_pt/fetch_models.py
@@ -159,6 +159,10 @@ def _load_weights(directory):
       for key in f.keys():
         state_dict[key] = f.get_tensor(key).to(torch.bfloat16)
   # Load the state_dict into the model
+  if not state_dict:
+    raise AssertionError(
+        f"Tried to load weights from {directory}, but couldn't find any."
+    )
   return state_dict
 
 
@@ -177,7 +181,8 @@ def instantiate_model_from_repo_id(
   """Create model instance by hf model id."""
   model_dir = _hf_dir(repo_id)
   if not FLAGS.internal_use_random_weights and (
-      not os.path.exists(model_dir) or not os.listdir(model_dir)
+      not os.path.exists(model_dir)
+      or not glob.glob(os.path.join(model_dir, "*.safetensors"))
   ):
     # no weights has been downloaded
     _hf_download(repo_id, model_dir, FLAGS.hf_token)
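
For context, here is a minimal, self-contained sketch of why the glob-based check is stricter than the old `os.listdir` test. It is not part of the patch, and the `needs_download_*` helper names are made up for illustration: a cache directory can exist and be non-empty (say, holding only a `config.json` left behind by an interrupted download) while still containing no `*.safetensors` weight shards. The old check would skip the download in that case; the new one retries it.

```python
# Illustrative sketch only -- needs_download_* are hypothetical helpers,
# not functions from jetstream_pt.
import glob
import os
import tempfile


def needs_download_listdir(model_dir):
  # Old check: any file at all counts as "weights already downloaded".
  return not os.path.exists(model_dir) or not os.listdir(model_dir)


def needs_download_glob(model_dir):
  # New check: only actual *.safetensors shards count.
  return not os.path.exists(model_dir) or not glob.glob(
      os.path.join(model_dir, "*.safetensors")
  )


with tempfile.TemporaryDirectory() as model_dir:
  # Simulate an interrupted download: config present, weights missing.
  with open(os.path.join(model_dir, "config.json"), "w") as f:
    f.write("{}")
  print(needs_download_listdir(model_dir))  # False: would wrongly skip the download
  print(needs_download_glob(model_dir))  # True: correctly triggers a re-download
```

Together with the new `AssertionError` in `_load_weights`, a partially populated cache now fails fast with a clear message instead of silently yielding an empty `state_dict`.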