How can the rng_state be stored using the checkpoint_callback? #6445
Unanswered
dhorka
asked this question in
Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Lightning's checkpoints also save the user's hyperparameters (https://pytorch-lightning.readthedocs.io/en/latest/common/weights_loading.html). Would it be an option for you to save the states as part of the hyperparameters?

    import numpy
    import torch
    from pytorch_lightning import LightningModule

    class MyModule(LightningModule):
        def __init__(self, torch_state, numpy_state):
            super().__init__()
            # Records the __init__ arguments (here, the RNG states) as hyperparameters
            self.save_hyperparameters()

    numpy_state = numpy.random.get_state()
    torch_state = torch.get_rng_state()
    model = MyModule(torch_state, numpy_state)
Hi,
I would like to know how I can store the rng_state of the different random generators used in a model, such as the NumPy and PyTorch generators. Each framework has its own method for retrieving the state; for instance, PyTorch has torch.get_rng_state, which returns the current state of the RNG. I would like to save these states so that I can resume an experiment at exactly the point where the first run stopped. As far as I can see, this behaviour is not currently implemented in the checkpoint_callback.
Thanks,
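To make the request concrete, here is a minimal sketch of the save/restore round trip being asked for. It is not part of the checkpoint_callback; the helper names `save_rng_states` and `load_rng_states` and the `"rng_states"` key are hypothetical, and the checkpoint is assumed to be a plain dictionary. NumPy and PyTorch are imported optionally so the sketch still runs when one of them is missing.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy is optional in this sketch
    np = None

try:
    import torch
except ImportError:  # PyTorch is optional in this sketch
    torch = None

RNG_KEY = "rng_states"  # hypothetical key name, not a Lightning convention


def save_rng_states(checkpoint):
    """Store the current state of each available RNG in the checkpoint dict."""
    states = {"python": random.getstate()}
    if np is not None:
        states["numpy"] = np.random.get_state()
    if torch is not None:
        states["torch"] = torch.get_rng_state()
        if torch.cuda.is_available():
            # One state per visible CUDA device
            states["torch_cuda"] = torch.cuda.get_rng_state_all()
    checkpoint[RNG_KEY] = states


def load_rng_states(checkpoint):
    """Restore every RNG state that was saved by save_rng_states."""
    states = checkpoint.get(RNG_KEY, {})
    if "python" in states:
        random.setstate(states["python"])
    if np is not None and "numpy" in states:
        np.random.set_state(states["numpy"])
    if torch is not None and "torch" in states:
        torch.set_rng_state(states["torch"])
        if "torch_cuda" in states and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(states["torch_cuda"])


if __name__ == "__main__":
    ckpt = {}
    save_rng_states(ckpt)
    before = random.random()
    load_rng_states(ckpt)
    after = random.random()
    print(before == after)  # same draw after restoring the state
```

In Lightning, one plausible place to call such helpers would be the LightningModule hooks `on_save_checkpoint(checkpoint)` and `on_load_checkpoint(checkpoint)`, which receive the checkpoint dictionary; whether this restores the exact point of a run also depends on other sources of randomness, such as data loader workers.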