fix checkpoint test
epwalsh committed May 7, 2024
1 parent dc16384 commit 3aa30bf
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/test/distributed/checkpoint_test.py
@@ -97,28 +97,28 @@ def save_and_load_checkpoint_with_different_sharding_spec(dir):
        [
            # save_tensor: |x x x|x x x
            # load_tensor: |x x|x x x x
-           (((0, 3), (3, 6)), ((0, 2), (2, 6))),
+           ((((0, 3),), ((3, 6),)), (((0, 2),), ((2, 6),))),
            # save_tensor: |x x x|x x x
            # load_tensor: |x x x x|x x
-           (((0, 3), (3, 6)), ((0, 4), (4, 6))),
+           ((((0, 3),), ((3, 6),)), (((0, 4),), ((4, 6),))),
            # save_tensor: |x x x x x x|
            # load_tensor: |x x x x|x x
-           (((0, 6), (6, 6)), ((0, 4), (4, 6))),
+           ((((0, 6),), ((6, 6),)), (((0, 4),), ((4, 6),))),
        ]
    ):
        checkpointer = Checkpointer()

        state_dict_to_save = {
            "x": ShardedFlatParameter.shard(
                torch.rand(2, 3, device=get_default_device()),
-               ShardingSpec(unsharded_shape=(2, 3), unsharded_flattened_offsets=(offsets_to_save,)),
+               ShardingSpec(unsharded_shape=(2, 3), unsharded_flattened_offsets=offsets_to_save),
            ),
        }

        state_dict_to_load = {
            "x": ShardedFlatParameter.shard(
                torch.rand(2, 3, device=get_default_device()),
-               ShardingSpec(unsharded_shape=(2, 3), unsharded_flattened_offsets=(offsets_to_load,)),
+               ShardingSpec(unsharded_shape=(2, 3), unsharded_flattened_offsets=offsets_to_load),
            ),
        }

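What changed: the parametrized offsets gain an extra level of nesting, and the ShardingSpec calls now pass them through directly instead of wrapping them in another tuple. Read together, this suggests unsharded_flattened_offsets is a tuple with one entry per rank, where each rank's entry is itself a tuple of (start, end) half-open ranges into the flattened tensor. Below is a minimal, self-contained sketch of that interpretation; the local_shard helper is hypothetical and only illustrates how the nested offsets map onto a flattened tensor, it is not the library's implementation.

import torch

# Offsets from the first parametrization above: one entry per rank, and each
# rank's entry is a tuple of (start, end) half-open ranges into the flattened tensor.
offsets_to_save = (((0, 3),), ((3, 6),))  # save_tensor: |x x x|x x x
offsets_to_load = (((0, 2),), ((2, 6),))  # load_tensor: |x x|x x x x


def local_shard(flat: torch.Tensor, rank_offsets) -> torch.Tensor:
    """Hypothetical helper: collect the flattened elements a single rank owns."""
    return torch.cat([flat[start:end] for start, end in rank_offsets])


# Stand-in for torch.rand(2, 3).flatten(); arange makes the ownership easy to see.
flat = torch.arange(6, dtype=torch.float32)
for rank, rank_offsets in enumerate(offsets_to_save):
    print(f"rank {rank} saves {local_shard(flat, rank_offsets).tolist()}")
for rank, rank_offsets in enumerate(offsets_to_load):
    print(f"rank {rank} loads {local_shard(flat, rank_offsets).tolist()}")
# rank 0 saves [0.0, 1.0, 2.0]   rank 1 saves [3.0, 4.0, 5.0]
# rank 0 loads [0.0, 1.0]        rank 1 loads [2.0, 3.0, 4.0, 5.0]

Under that reading, each of the three parametrizations describes two ranks saving the 6-element flattened (2, 3) tensor with one partitioning and loading it back with a different one, which is exactly what the ASCII diagrams in the comments depict.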
