diff --git a/megatron/core/dist_checkpointing/mapping.py b/megatron/core/dist_checkpointing/mapping.py index 2ddfcf3b31..d376c6374b 100644 --- a/megatron/core/dist_checkpointing/mapping.py +++ b/megatron/core/dist_checkpointing/mapping.py @@ -119,7 +119,8 @@ class with `from_rank_offsets` or `from_rank_offsets_flat` constructors. self.init_data(device='meta') if self.data.shape != real_data.shape: raise CheckpointingException( - f'Data shape doesnt match expected {self.data.shape} for {self}' + f'Data shape {real_data.shape} doesnt match' + f' expected {self.data.shape} for {self}' ) finally: self.data = real_data diff --git a/megatron/core/dist_checkpointing/validation.py b/megatron/core/dist_checkpointing/validation.py index 48e023dc39..5142ec6261 100644 --- a/megatron/core/dist_checkpointing/validation.py +++ b/megatron/core/dist_checkpointing/validation.py @@ -461,10 +461,15 @@ def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]): lambda x: x[1], _validate_sharding_for_key_flattened, ) - else: - if not torch.all(shard_access_cnt == 1): - logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}') - raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}') + # For each shard with at least 1 flattened tensor in it, the above + # `_validate_sharding_for_key_flattened` ensure a correct consistent pattern + # The only thing that can go wrong at this point is that some shard don't have + # *any* representatives which will be checked later by comparing `shard_access_cnt == 1` + shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1])) + if not torch.all(shard_access_cnt == 1): + raise CheckpointingException( + f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}' + ) def _compute_shards_access(rank_sharding): @@ -489,16 +494,10 @@ def _validate_sharding_for_key_flattened(tensors_by_shard): all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop)) starts, stops = map(np.asarray, zip(*sorted(all_slices))) - if ( - starts[0] != 0 - or stops[-1] != np.product(local_shape) - or not np.all(starts[1:] == stops[:-1]) - ): - logger.error( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' - ) + expected_size = np.product(local_shape) + if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]): raise CheckpointingException( - f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}' + f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]} of size {expected_size}. Ranges: {(starts, stops)}' ) diff --git a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py index fa00a20cad..1485eebe10 100644 --- a/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py +++ b/tests/unit_tests/dist_checkpointing/test_flattened_resharding.py @@ -1,6 +1,7 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. import io +from contextlib import nullcontext import numpy as np import pytest @@ -18,6 +19,10 @@ restore_nd_flattened_tensors_formulation, ) from megatron.core.dist_checkpointing.strategies.torch import get_reformulation_metadata +from megatron.core.dist_checkpointing.validation import ( + determine_global_metadata, + validate_sharding_integrity, +) from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -198,3 +203,66 @@ def _build_state_dict(self, random=False): ), } return state_dict + + def test_flattened_tensors_are_properly_validated(self, tmp_path_dist_ckpt): + Utils.initialize_model_parallel() + # Global tensor of shape (6, 6) is built from: + # ranks 0, 1, 2 tensors of length 1, 2, 3 + # and then ranks 3, ..., 7 tensors of length 6 + local_flat_ten = torch.ones(Utils.rank + 1 if Utils.rank <= 2 else 6) * Utils.rank + + global_flattened_len = 6 + (Utils.world_size - 3) * 6 + if Utils.world_size == 8: + assert global_flattened_len == 1 + 2 + 3 + 5 * 6 + local_ten_shape = (1, 6) + else: + local_ten_shape = (global_flattened_len,) + + if Utils.rank == 0: + local_dp_slice_start = 0 + elif Utils.rank == 1: + local_dp_slice_start = 1 + elif Utils.rank == 2: + local_dp_slice_start = 3 + else: + local_dp_slice_start = 0 + local_dp_slice = slice(local_dp_slice_start, local_dp_slice_start + len(local_flat_ten)) + + state_dict = { + 'sd_key_flat': ShardedTensor.from_rank_offsets_flat( + 'flat', + local_flat_ten, + local_ten_shape, + *((0, max(0, Utils.rank - 2), 6),) if Utils.world_size == 8 else (), + flattened_range=local_dp_slice, + replica_id=0 + ) + } + validate_sharding_integrity(determine_global_metadata(state_dict)[1]) + if Utils.rank == 1: + old_state_dict = state_dict + state_dict = {} + + with ( + pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext() + ) as exc_info: + validate_sharding_integrity(determine_global_metadata(state_dict)[1]) + if Utils.rank == 0: + assert 'Flattened ranges dont cover the whole shard ShardedTensor' in str( + exc_info.value + ) + + if Utils.rank == 1: + state_dict = old_state_dict + + if Utils.rank == 4: + state_dict = {} + + with ( + pytest.raises(CheckpointingException) if Utils.rank == 0 else nullcontext() + ) as exc_info: + validate_sharding_integrity(determine_global_metadata(state_dict)[1]) + if Utils.rank == 0: + assert 'Invalid access pattern' in str(exc_info.value) + + Utils.destroy_model_parallel()