diff --git a/src/garage/experiment/snapshotter.py b/src/garage/experiment/snapshotter.py index 549569fbf7..4ed71da57a 100644 --- a/src/garage/experiment/snapshotter.py +++ b/src/garage/experiment/snapshotter.py @@ -159,7 +159,7 @@ def load(self, load_dir, itr='last'): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), '*.pkl file in', load_dir) - files.sort() + files.sort(key=_extract_snapshot_itr) load_from_file = files[0] if itr == 'first' else files[-1] load_from_file = os.path.join(load_dir, load_from_file) @@ -170,5 +170,20 @@ def load(self, load_dir, itr='last'): return cloudpickle.load(file) +def _extract_snapshot_itr(filename: str) -> int: + """Extracts the integer itr from a filename. + + Args: + filename(str): The snapshot filename. + + Returns: + int: The snapshot as an integer. + + """ + base = os.path.splitext(filename)[0] + digits = base.split('itr_')[1] + return int(digits) + + class NotAFileError(Exception): """Raise when the snapshot is not a file.""" diff --git a/tests/garage/experiment/test_snapshotter.py b/tests/garage/experiment/test_snapshotter.py index 358d5887d1..b32b6870ca 100644 --- a/tests/garage/experiment/test_snapshotter.py +++ b/tests/garage/experiment/test_snapshotter.py @@ -22,6 +22,7 @@ class TestSnapshotter: def setup_method(self): + # pylint: disable=consider-using-with self.temp_dir = tempfile.TemporaryDirectory() def teardown_method(self): @@ -78,3 +79,10 @@ def test_conflicting_params(self): Snapshotter(snapshot_dir=self.temp_dir.name, snapshot_mode='gap_overwrite', snapshot_gap=1) + + def test_sorts_correctly(self): + snapshotter = Snapshotter(self.temp_dir.name, 'all', 2) + snapshotter.save_itr_params(80, {'test_itr': 80}) + snapshotter.save_itr_params(120, {'test_itr': 120}) + last = snapshotter.load(self.temp_dir.name) + assert last['test_itr'] == 120