diff --git a/python/keepsake/pl_callback.py b/python/keepsake/pl_callback.py index 7bbfedae..56089de3 100644 --- a/python/keepsake/pl_callback.py +++ b/python/keepsake/pl_callback.py @@ -1,5 +1,6 @@ from copy import deepcopy from typing import Optional, Dict, Tuple, Any +from pathlib import Path import keepsake from pytorch_lightning.callbacks.base import Callback @@ -56,7 +57,7 @@ def __init__( """ super().__init__() - self.filepath = filepath + self.filepath = Path(filepath).resolve() self.params = params self.primary_metric = primary_metric self.period = period @@ -64,7 +65,10 @@ def __init__( self.last_global_step_saved = -1 def on_pretrain_routine_start(self, trainer, pl_module): - self.experiment = keepsake.init(path=".", params=self.params) + self.experiment = keepsake.init( + path=str(self.filepath.parent), + params=self.params, + ) def on_epoch_end(self, trainer, pl_module): self._save_model(trainer, pl_module) @@ -89,7 +93,7 @@ def _save_model(self, trainer, pl_module): return if self.filepath != None: - trainer.save_checkpoint(self.filepath, self.save_weights_only) + trainer.save_checkpoint(self.filepath.name, self.save_weights_only) self.last_global_step_saved = global_step @@ -99,7 +103,7 @@ def _save_model(self, trainer, pl_module): metrics.update({"global_step": trainer.global_step}) self.experiment.checkpoint( - path=self.filepath, + path=self.filepath.name, step=trainer.current_epoch, metrics=metrics, primary_metric=self.primary_metric,