Merge branch 'mblaz/fix-flat-validation' into 'main'
Improved flattened tensors validation

See merge request ADLR/megatron-lm!2409
ericharper committed Dec 18, 2024
2 parents 319c8aa + 474f9c5 commit 1b7553e
Showing 3 changed files with 82 additions and 14 deletions.
3 changes: 2 additions & 1 deletion megatron/core/dist_checkpointing/mapping.py
@@ -119,7 +119,8 @@ class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
            self.init_data(device='meta')
            if self.data.shape != real_data.shape:
                raise CheckpointingException(
-                   f'Data shape doesnt match expected {self.data.shape} for {self}'
+                   f'Data shape {real_data.shape} doesnt match'
+                   f' expected {self.data.shape} for {self}'
                )
        finally:
            self.data = real_data
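For orientation, the check exercised by this hunk can be reproduced in isolation. The sketch below is a simplified stand-in, not Megatron-LM code: `validate_data_shape` and the local `CheckpointingException` are hypothetical names, and the expected layout is materialized on PyTorch's `meta` device so only shape metadata, never real memory, is allocated. It shows why reporting both the actual and the expected shape (the change above) makes the failure easier to diagnose.

import torch


class CheckpointingException(Exception):
    """Local stand-in for Megatron-LM's checkpointing error type."""


def validate_data_shape(real_data: torch.Tensor, expected_shape: tuple) -> None:
    # Build the expected layout on the 'meta' device: shape/dtype bookkeeping only, no storage.
    expected = torch.empty(expected_shape, device='meta')
    if expected.shape != real_data.shape:
        raise CheckpointingException(
            f'Data shape {tuple(real_data.shape)} does not match expected {tuple(expected.shape)}'
        )


validate_data_shape(torch.zeros(2, 3), (2, 3))  # passes silently
try:
    validate_data_shape(torch.zeros(2, 3), (3, 2))
except CheckpointingException as exc:
    print(exc)  # reports both the actual and the expected shape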
25 changes: 12 additions & 13 deletions megatron/core/dist_checkpointing/validation.py
@@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
            lambda x: x[1],
            _validate_sharding_for_key_flattened,
        )
-    else:
-        if not torch.all(shard_access_cnt == 1):
-            logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
-            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
+        # For each shard with at least 1 flattened tensor in it, the above
+        # `_validate_sharding_for_key_flattened` ensure a correct consistent pattern
+        # The only thing that can go wrong at this point is that some shard don't have
+        # *any* representatives which will be checked later by comparing `shard_access_cnt == 1`
+        shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
+    if not torch.all(shard_access_cnt == 1):
+        raise CheckpointingException(
+            f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
+        )


def _compute_shards_access(rank_sharding):
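The comment added in the hunk above carries the key reasoning: once `_validate_sharding_for_key_flattened` has checked every shard that has at least one flattened slice, the only remaining failure mode is a shard with no representative at all, so the per-shard access counts can be clamped to 1 before the uniform `== 1` comparison. A minimal sketch of that clamping step, using a hand-built access-count tensor rather than the real `_compute_shards_access` output:

import torch

# Hypothetical per-shard access counts: shard 0 is covered by three flattened
# slices, shard 1 by one, shard 2 by none (the case the check must still catch).
shard_access_cnt = torch.tensor([3, 1, 0])

# Several flattened slices per shard are legitimate (their ranges were already
# validated above), so clamp the counts down to at most 1 ...
clamped = torch.minimum(shard_access_cnt, torch.tensor([1]))
print(clamped)  # tensor([1, 1, 0])

# ... and only a shard with *no* representative now fails the uniform check.
assert not torch.all(clamped == 1)  # shard 2 is uncovered, so this layout is invalid

fully_covered = torch.minimum(torch.tensor([3, 1, 2]), torch.tensor([1]))
assert torch.all(fully_covered == 1)  # every shard has at least one slice: valid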
@@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard):
        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
-    if (
-        starts[0] != 0
-        or stops[-1] != np.product(local_shape)
-        or not np.all(starts[1:] == stops[:-1])
-    ):
-        logger.error(
-            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
-        )
+    expected_size = np.product(local_shape)
+    if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
        raise CheckpointingException(
-            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
+            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}'
        )
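The rewritten condition above encodes one invariant: sorted by start, the flattened slices of a shard must begin at 0, be back-to-back, and end at `prod(local_shape)`. A self-contained sketch of that invariant, with the hypothetical helper name `ranges_cover_shard` (not a Megatron-LM function) and `np.prod` in place of the deprecated `np.product`:

import numpy as np


def ranges_cover_shard(slices, local_shape):
    """Return True iff the (start, stop) slices tile [0, prod(local_shape)) with no gap or overlap."""
    starts, stops = map(np.asarray, zip(*sorted(slices)))
    expected_size = np.prod(local_shape)
    return bool(
        starts[0] == 0
        and stops[-1] == expected_size
        and np.all(starts[1:] == stops[:-1])
    )


# A (1, 6) shard split into flattened slices of lengths 1, 2 and 3 (as in the new unit test):
assert ranges_cover_shard([(0, 1), (1, 3), (3, 6)], (1, 6))
# A gap (elements 3..4 covered by nobody) must be rejected:
assert not ranges_cover_shard([(0, 1), (1, 3), (4, 6)], (1, 6))
# So must an overlap between neighbouring slices:
assert not ranges_cover_shard([(0, 3), (2, 6)], (1, 6))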


68 changes: 68 additions & 0 deletions tests/unit_tests/dist_checkpointing/test_flattened_resharding.py
@@ -1,6 +1,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import io
from contextlib import nullcontext

import numpy as np
import pytest
@@ -18,6 +19,10 @@
    restore_nd_flattened_tensors_formulation,
)
from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata
from megatron.core.dist_checkpointing.validation import (
    determine_global_metadata,
    validate_sharding_integrity,
)
from tests.unit_tests.dist_checkpointing import TempNamedDir
from tests.unit_tests.test_utilities import Utils

@@ -198,3 +203,66 @@ def _build_state_dict(self, random=False):
            ),
        }
        return state_dict

    def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt):
        Utils.initialize_model_parallel()
        # Global tensor of shape (6, 6) is built from:
        # ranks 0, 1, 2 tensors of length 1, 2, 3
        # and then ranks 3, ..., 7 tensors of length 6
        local_flat_ten = torch.ones(Utils.rank + 1 if Utils.rank <= 2 else 6) * Utils.rank

        global_flattened_len = 6 + (Utils.world_size - 3) * 6
        if Utils.world_size == 8:
            assert global_flattened_len == 1 + 2 + 3 + 5 * 6
            local_ten_shape = (1, 6)
        else:
            local_ten_shape = (global_flattened_len,)

        if Utils.rank == 0:
            local_dp_slice_start = 0
        elif Utils.rank == 1:
            local_dp_slice_start = 1
        elif Utils.rank == 2:
            local_dp_slice_start = 3
        else:
            local_dp_slice_start = 0
        local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten))

        state_dict = {
            'sd_key_flat': ShardedTensor.from_rank_offsets_flat(
                'flat',
                local_flat_ten,
                local_ten_shape,
                *((0, max(0, Utils.rank - 2), 6),) if Utils.world_size == 8 else (),
                flattened_range=local_dp_slice,
                replica_id=0
            )
        }
        validate_sharding_integrity(determine_global_metadata(state_dict)[1])
        if Utils.rank == 1:
            old_state_dict = state_dict
            state_dict = {}

        with (
            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
        ) as exc_info:
            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
        if Utils.rank == 0:
            assert 'Flattened ranges dont cover the whole shard ShardedTensor' in str(
                exc_info.value
            )

        if Utils.rank == 1:
            state_dict = old_state_dict

        if Utils.rank == 4:
            state_dict = {}

        with (
            pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext()
        ) as exc_info:
            validate_sharding_integrity(determine_global_metadata(state_dict)[1])
        if Utils.rank == 0:
            assert 'Invalid access pattern' in str(exc_info.value)

        Utils.destroy_model_parallel()
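As a plain-Python cross-check of the layout sketched in the test's comment (an illustration, not part of the test): at world size 8 the global (6, 6) tensor decomposes into six (1, 6) shards, with ranks 0, 1, 2 tiling shard 0 via slices of lengths 1, 2, 3 and ranks 3..7 each owning one full shard.

# Flattened (start, stop) slices per shard of a (1, 6) local shape, world size 8.
slices_by_shard = {0: [(0, 1), (1, 3), (3, 6)]}  # ranks 0, 1, 2 share shard 0
slices_by_shard.update({shard: [(0, 6)] for shard in range(1, 6)})  # ranks 3..7

for shard, slices in sorted(slices_by_shard.items()):
    starts = [start for start, _ in sorted(slices)]
    stops = [stop for _, stop in sorted(slices)]
    # Same invariant the validation enforces: start at 0, no gaps, end at 1 * 6.
    assert starts[0] == 0 and stops[-1] == 6, f'shard {shard} not fully covered'
    assert all(s == e for s, e in zip(starts[1:], stops[:-1])), f'gap in shard {shard}'
print('all 6 shards fully covered')

Removing rank 1's slice would leave a gap in shard 0, which is the coverage failure the test asserts first; removing rank 4's tensor would leave its shard with no slices at all, which is the access-pattern failure the test asserts second.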
