diff --git a/src/morphoclass/console/performance_table.py b/src/morphoclass/console/performance_table.py index d2c54b8..c228fbf 100644 --- a/src/morphoclass/console/performance_table.py +++ b/src/morphoclass/console/performance_table.py @@ -42,6 +42,8 @@ def make_performance_table( output_dir The report output directory. """ + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # Resolve paths from parameters files = [pathlib.Path(path).resolve() for path in checkpoint_paths] output_dir = pathlib.Path(output_dir).resolve() @@ -61,7 +63,7 @@ def make_performance_table( logger.info( f"Processing checkpoint {i + 1} of {len(files)}: {str(metrics_file)!r}" ) - data = torch.load(metrics_file) + data = torch.load(metrics_file, map_location=device) data["metrics_file"] = metrics_file @@ -129,7 +131,7 @@ def make_report_row(data: dict) -> dict: """ return { "dataset": data["dataset_name"], - "feature_extractor": data["feature_extractor_name"], + "feature_extractor": data["features_dir"], "model_class": data["model_class"], "model_params": data["model_params"], "oversampled": bool(data["oversampling"]), diff --git a/src/morphoclass/training/training_log.py b/src/morphoclass/training/training_log.py index 04bc945..26e6d4a 100644 --- a/src/morphoclass/training/training_log.py +++ b/src/morphoclass/training/training_log.py @@ -181,5 +181,7 @@ def save(self, path: pathlib.Path) -> None: @classmethod def load(cls, path: pathlib.Path) -> TrainingLog: """Load the training log from disk.""" - data = torch.load(path) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + data = torch.load(path, map_location=device) return cls.from_dict(data)