
Commit 5dca802

feat: add test_checkpointer unit test suite
Signed-off-by: Charlie Doern <[email protected]>
1 parent 0f28c81 commit 5dca802


tests/unit/test_checkpointer.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# Standard
from pathlib import Path
from unittest.mock import MagicMock, patch

# Third Party
import pytest
import torch
import torch.distributed as dist

# First Party
from instructlab.training.accelerator import Accelerator
from instructlab.training.checkpointer import Checkpointer
from instructlab.training.config import DistributedBackend


@pytest.fixture(autouse=True)
def mock_distributed():
    """Mock PyTorch distributed functionality for all tests."""
    with (
        patch("torch.distributed.is_initialized", return_value=True),
        patch("torch.distributed.barrier") as mock_barrier,
        patch("torch.distributed.get_rank", return_value=0),
    ):
        yield mock_barrier


@pytest.fixture
def mock_model():
    model = MagicMock()
    model.lora_config = None
    model.model_type = "llama"
    model.module = MagicMock()
    model.module.config = MagicMock()
    model.tokenizer = MagicMock()
    return model


@pytest.fixture
def mock_optimizer():
    return MagicMock()


@pytest.fixture
def mock_accelerator():
    accelerator = MagicMock(spec=Accelerator)
    accelerator.is_main_process = True
    accelerator.distributed_type = DistributedBackend.FSDP
    accelerator.distributed_framework = "fsdp"
    accelerator.get_state_dict = MagicMock()
    # Add missing methods that are used in the checkpointer
    accelerator.save_state = MagicMock()
    accelerator.save_model = MagicMock()
    return accelerator
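For orientation, the surface these fixtures stand in for is roughly the following. This is a minimal sketch reconstructed from the assertions in the tests below, with assumed method bodies; it is not the actual implementation in instructlab.training.checkpointer.

# Hypothetical sketch of the Checkpointer API these tests exercise.
# Everything here is inferred from the test assertions, not from the real class.
class CheckpointerSketch:
    def __init__(self, model, optimizer, accelerator, strategy):
        self.model = model
        self.optimizer = optimizer
        self.accelerator = accelerator
        self.strategy = strategy

    def checkpoint(self, output_dir, epoch, samples_seen, last_epoch=False):
        if self.model.lora_config is not None:
            # LoRA models are rejected (see test_checkpointer_lora_not_supported)
            raise NotImplementedError("LoRA checkpointing is not supported")
        if self.strategy in ("full_state", "all"):
            self.accelerator.save_state(...)  # full training state plus metadata
        if self.strategy in ("hf_format", "all"):
            self.model.module.config.to_json_file(...)
            self.model.tokenizer.save_pretrained(...)
            self.accelerator.save_model(...)  # HF-format weights
        # strategy "none" falls through and writes nothing

    def load_latest_full_state(self, output_dir):
        # Assumed to resolve the newest full_state/epoch_* directory first
        self.accelerator.load_state(...)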
def test_checkpointer_initialization(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="all",
    )

    assert checkpointer.model == mock_model
    assert checkpointer.optimizer == mock_optimizer
    assert checkpointer.accelerator == mock_accelerator
    assert checkpointer.strategy == "all"


def test_checkpointer_no_checkpoint(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="none",
    )

    # Test that no checkpointing occurs
    checkpointer.checkpoint(output_dir="test_dir", epoch=1, samples_seen=100)
    mock_accelerator.save_state.assert_not_called()


def test_checkpointer_full_state(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="full_state",
    )

    output_dir = Path("test_dir")
    checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100)

    # Verify accelerator save_state was called
    mock_accelerator.save_state.assert_called_once()
    # Verify metadata was saved
    assert (output_dir / "full_state" / "epoch_1" / "training_metadata.json").exists()
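The metadata assertion above implies that the full_state path writes a small JSON file next to the accelerator state. A minimal sketch of such a write, assuming the full_state/epoch_<n> layout the test checks; write_training_metadata and the exact fields are assumptions, not the library's code:

import json
from pathlib import Path

def write_training_metadata(output_dir: Path, epoch: int, samples_seen: int) -> None:
    # Assumed layout: <output_dir>/full_state/epoch_<n>/training_metadata.json
    ckpt_dir = output_dir / "full_state" / f"epoch_{epoch}"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    metadata = {"epoch": epoch, "samples_seen": samples_seen}
    (ckpt_dir / "training_metadata.json").write_text(json.dumps(metadata))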
def test_checkpointer_hf_format(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="hf_format",
    )

    output_dir = Path("test_dir")
    checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100)

    # Verify model config and tokenizer were saved
    mock_model.module.config.to_json_file.assert_called_once()
    mock_model.tokenizer.save_pretrained.assert_called_once()
    # Verify accelerator save_model was called
    mock_accelerator.save_model.assert_called_once()


def test_checkpointer_all_strategies(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="all",
    )

    output_dir = Path("test_dir")
    checkpointer.checkpoint(output_dir=output_dir, epoch=1, samples_seen=100)

    # Verify both full state and HF format were saved
    mock_accelerator.save_state.assert_called_once()
    mock_model.module.config.to_json_file.assert_called_once()
    mock_model.tokenizer.save_pretrained.assert_called_once()
    mock_accelerator.save_model.assert_called_once()


def test_checkpointer_lora_not_supported(mock_model, mock_optimizer, mock_accelerator):
    mock_model.lora_config = MagicMock()  # Set lora_config to non-None

    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="full_state",
    )

    with pytest.raises(NotImplementedError):
        checkpointer.checkpoint(output_dir="test_dir", epoch=1, samples_seen=100)
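The LoRA test only pins down an error contract. A guard satisfying it could be as small as the following sketch; ensure_not_lora is a hypothetical helper illustrating the expectation, not the library's actual check:

def ensure_not_lora(model) -> None:
    # Raise early if the wrapped model carries a LoRA config; the test only
    # requires that checkpoint() surfaces NotImplementedError in that case.
    if model.lora_config is not None:
        raise NotImplementedError("checkpointing LoRA models is not supported")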
def test_checkpointer_load_latest_full_state(
    mock_model, mock_optimizer, mock_accelerator
):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="all",
    )

    # Mock the output directory structure
    output_dir = Path("test_dir")
    checkpoint_dir = output_dir / "full_state" / "epoch_1"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Mock the accelerator's load_state method
    mock_accelerator.load_state = MagicMock()

    checkpointer.load_latest_full_state(output_dir)

    # Verify accelerator load_state was called
    mock_accelerator.load_state.assert_called_once()
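Because this test creates full_state/epoch_1 on disk before calling load_latest_full_state, the method presumably scans for the newest epoch_<n> directory before delegating to accelerator.load_state. One way to resolve it, sketched with a hypothetical helper and assuming purely numeric suffixes:

from pathlib import Path

def latest_full_state_dir(output_dir: Path) -> Path | None:
    # Sort epoch_<n> directories by their numeric suffix and take the newest.
    candidates = sorted(
        (output_dir / "full_state").glob("epoch_*"),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    return candidates[-1] if candidates else None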
def test_checkpointer_save_last_epoch(mock_model, mock_optimizer, mock_accelerator):
    checkpointer = Checkpointer(
        model=mock_model,
        optimizer=mock_optimizer,
        accelerator=mock_accelerator,
        strategy="hf_format",
    )

    output_dir = Path("test_dir")
    checkpointer.checkpoint(
        output_dir=output_dir,
        epoch=1,
        samples_seen=100,
        last_epoch=True,
    )

    # Verify model was saved in last_epoch directory
    mock_model.module.config.to_json_file.assert_called_once()
    mock_model.tokenizer.save_pretrained.assert_called_once()
    mock_accelerator.save_model.assert_called_once()
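Because mock_distributed is autouse, the whole suite runs without a GPU or a real torch.distributed process group; it can be invoked with pytest tests/unit/test_checkpointer.py. Note that the full_state and load tests write under a relative test_dir, so they leave directories behind in the working directory.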
