Skip to content

Commit 326c8e5

Browse files
schwarzmxfacebook-github-bot
authored andcommitted
Improve error message in _read_snapshot_metadata
Summary: This is to mitigate confusion about what went wrong about the snapshot. Reviewed By: JKSenthil Differential Revision: D54705863
1 parent 514e43d commit 326c8e5

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

tests/test_snapshot.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import copy
1111
from pathlib import Path
1212
from typing import Any, Dict, List
13+
from unittest.mock import MagicMock
1314

1415
import pytest
1516

@@ -226,3 +227,21 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
226227
snapshot = Snapshot.take(app_state={"state": src}, path=str(tmp_path))
227228
snapshot.restore(app_state={"state": dst})
228229
assert check_state_dict_eq(src.state_dict(), dst.state_dict())
230+
231+
232+
@pytest.mark.usefixtures("toggle_batching")
233+
def test_snapshot_metadata_error(tmp_path: Path) -> None:
234+
mock_storage_plugin = MagicMock()
235+
mock_event_loop = MagicMock()
236+
mock_storage_plugin.sync_read.side_effect = Exception(
237+
"Mock error reading from storage"
238+
)
239+
with pytest.raises(
240+
expected_exception=RuntimeError,
241+
match=(
242+
"Failed to read .snapshot_metadata. "
243+
"Ensure path to snapshot is correct, "
244+
"otherwise snapshot is likely incomplete or corrupted."
245+
),
246+
):
247+
Snapshot._read_snapshot_metadata(mock_storage_plugin, mock_event_loop)

torchsnapshot/snapshot.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,14 @@ def _read_snapshot_metadata(
840840
storage: StoragePlugin, event_loop: asyncio.AbstractEventLoop
841841
) -> SnapshotMetadata:
842842
read_io = ReadIO(path=SNAPSHOT_METADATA_FNAME)
843-
storage.sync_read(read_io=read_io, event_loop=event_loop)
843+
try:
844+
storage.sync_read(read_io=read_io, event_loop=event_loop)
845+
except Exception as e:
846+
raise RuntimeError(
847+
f"Failed to read {SNAPSHOT_METADATA_FNAME}. "
848+
"Ensure path to snapshot is correct, "
849+
"otherwise snapshot is likely incomplete or corrupted."
850+
) from e
844851
yaml_str = read_io.buf.getvalue().decode("utf-8")
845852
return SnapshotMetadata.from_yaml(yaml_str)
846853

0 commit comments

Comments
 (0)