Skip to content

Commit 038cd03

Browse files
committed
fix
1 parent 496d315 commit 038cd03

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/olmo_core/distributed/checkpoint/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ def load_model_and_optim_state(
204204
log.info(
205205
f"Mapping current key '{current_key}' to key '{original_key}' in checkpoint"
206206
)
207-
state_dict[original_key] = state_dict.pop(current_key)
207+
current_root, current_key = current_key.split(".", 1)
208+
original_root, original_key = original_key.split(".", 1)
209+
state_dict[original_root][original_key] = state_dict[current_root].pop(current_key)
208210

209211
dist_cp.load(
210212
state_dict,

0 commit comments

Comments
 (0)