Skip to content

Commit

Permalink
Map checkpoint data to current device (cuda/cpu) when torch.load()
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescoCasalegno committed Apr 20, 2022
1 parent f4f0d49 commit 8fca841
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/morphoclass/console/performance_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def make_performance_table(
output_dir
The report output directory.
"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Resolve paths from parameters
files = [pathlib.Path(path).resolve() for path in checkpoint_paths]
output_dir = pathlib.Path(output_dir).resolve()
Expand All @@ -61,7 +63,7 @@ def make_performance_table(
logger.info(
f"Processing checkpoint {i + 1} of {len(files)}: {str(metrics_file)!r}"
)
data = torch.load(metrics_file)
data = torch.load(metrics_file, map_location=device)

data["metrics_file"] = metrics_file

Expand Down Expand Up @@ -129,7 +131,7 @@ def make_report_row(data: dict) -> dict:
"""
return {
"dataset": data["dataset_name"],
"feature_extractor": data["feature_extractor_name"],
"feature_extractor": data["features_dir"],
"model_class": data["model_class"],
"model_params": data["model_params"],
"oversampled": bool(data["oversampling"]),
Expand Down
4 changes: 3 additions & 1 deletion src/morphoclass/training/training_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,5 +181,7 @@ def save(self, path: pathlib.Path) -> None:
@classmethod
def load(cls, path: pathlib.Path) -> TrainingLog:
"""Load the training log from disk."""
data = torch.load(path)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data = torch.load(path, map_location=device)
return cls.from_dict(data)

0 comments on commit 8fca841

Please sign in to comment.