We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 532b12d commit 9da93a6Copy full SHA for 9da93a6
gemma/gm/ckpts/_checkpoint.py
@@ -182,6 +182,7 @@ def load_params(
182
sharding: kd.sharding.ShardingTree | None = None,
183
quantize: bool = False,
184
use_ocdbt: bool = True,
185
+ block_until_ready: bool = False,
186
) -> Params:
187
"""Restore the params from a checkpoint.
188
@@ -251,6 +252,8 @@ def load_params(
251
252
output_with_skip = metadata.make_tree_for_params(params)
253
restore_fn = functools.partial(ckpt.restore, path)
254
output = _partial_restore(restore_fn, output_with_skip)
255
+ if block_until_ready:
256
+ output.block_until_ready()
257
258
# TODO(epot): Better API. Currently this do not quantize the weights, but
259
# just refactor the params to the QAT structure.
0 commit comments