Skip to content

Commit 9da93a6

Browse files
author
The gemma Authors
committed
Add an option to wait until params are loaded
PiperOrigin-RevId: 794898550
1 parent 532b12d commit 9da93a6

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

gemma/gm/ckpts/_checkpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def load_params(
182182
sharding: kd.sharding.ShardingTree | None = None,
183183
quantize: bool = False,
184184
use_ocdbt: bool = True,
185+
block_until_ready: bool = False,
185186
) -> Params:
186187
"""Restore the params from a checkpoint.
187188
@@ -251,6 +252,8 @@ def load_params(
251252
output_with_skip = metadata.make_tree_for_params(params)
252253
restore_fn = functools.partial(ckpt.restore, path)
253254
output = _partial_restore(restore_fn, output_with_skip)
255+
if block_until_ready:
256+
output.block_until_ready()
254257

255258
# TODO(epot): Better API. Currently this do not quantize the weights, but
256259
# just refactor the params to the QAT structure.

0 commit comments

Comments
 (0)