From 9a0e78df73a45ff48e74faeebd08f4f4b282d949 Mon Sep 17 00:00:00 2001
From: Mikolaj Blaz
Date: Fri, 6 Sep 2024 13:59:19 -0700
Subject: [PATCH] ADLR/megatron-lm!2042 - Fix `zarr` deprecation notice

---
 .../dist_checkpointing/strategies/__init__.py |  8 ++--
 .../dist_checkpointing/strategies/base.py     | 45 +++++++++++++++---
 .../dist_checkpointing/strategies/common.py   | 27 +++++++----
 .../strategies/tensorstore.py                 | 26 +++++------
 .../dist_checkpointing/strategies/torch.py    | 26 +++++++----
 .../dist_checkpointing/strategies/zarr.py     | 46 ++++++++++++-------
 6 files changed, 118 insertions(+), 60 deletions(-)

diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py
index 6d04265b16..a786b8e84a 100644
--- a/megatron/core/dist_checkpointing/strategies/__init__.py
+++ b/megatron/core/dist_checkpointing/strategies/__init__.py
@@ -1,9 +1,7 @@
 # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 
 """ Various loading and saving strategies """
+from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies
 
-# We mock imports to populate the `default_strategies` objects.
-# Since they are defined in base but populated in common, we have to mock
-# import both modules.
-from megatron.core.dist_checkpointing.strategies.base import _import_trigger
-from megatron.core.dist_checkpointing.strategies.common import _import_trigger
+# We load the "common" strategies by default so they are always available
+register_default_common_strategies()
diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py
index cc1c83b92f..35fca1f350 100644
--- a/megatron/core/dist_checkpointing/strategies/base.py
+++ b/megatron/core/dist_checkpointing/strategies/base.py
@@ -6,20 +6,21 @@
 from collections import defaultdict
 from enum import Enum
 from pathlib import Path
-from typing import Any, DefaultDict
+from typing import Any, DefaultDict, Union
 
 from ..mapping import CheckpointingException, ShardedStateDict, StateDict
 from .async_utils import AsyncCallsQueue, AsyncRequest
 
 
 class StrategyAction(Enum):
+    """Specifies save vs load and sharded vs common action."""
+
     LOAD_COMMON = 'load_common'
     LOAD_SHARDED = 'load_sharded'
     SAVE_COMMON = 'save_common'
     SAVE_SHARDED = 'save_sharded'
 
 
-_import_trigger = None
 default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
 
 async_calls = AsyncCallsQueue()
@@ -30,11 +31,17 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
     try:
         if backend == 'zarr':
             error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
-            from .tensorstore import _import_trigger
-            from .zarr import _import_trigger
+            from .tensorstore import register_default_tensorstore_strategies
+
+            register_default_tensorstore_strategies()
+            from .zarr import register_default_zarr_strategies
+
+            register_default_zarr_strategies()
         elif backend == 'torch_dist':
             error_hint = ' Please use PyTorch version >=2.1'
-            from .torch import _import_trigger
+            from .torch import register_default_torch_strategies
+
+            register_default_torch_strategies()
     except ImportError as e:
         raise CheckpointingException(
             f'Cannot import a default strategy for: {(action.value, backend, version)}. '
@@ -48,16 +55,35 @@
         ) from e
 
 
+def register_default_strategy(
+    action: StrategyAction,
+    backend: str,
+    version: int,
+    strategy: Union['SaveStrategyBase', 'LoadStrategyBase'],
+):
+    """Adds a given strategy to the registry of default strategies.
+
+    Args:
+        action (StrategyAction): specifies save/load and sharded/common
+        backend (str): backend that the strategy becomes a default for
+        version (int): version that the strategy becomes a default for
+        strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
+    """
+    default_strategies[action.value][(backend, version)] = strategy
+
+
 class LoadStrategyBase(ABC):
     """Base class for a load strategy. Requires implementing checks for compatibility with a
     given checkpoint version."""
 
     @abstractmethod
-    def check_backend_compatibility(self, loaded_version):
+    def check_backend_compatibility(self, loaded_backend):
+        """Verifies if this strategy is compatible with `loaded_backend`."""
         raise NotImplementedError
 
     @abstractmethod
     def check_version_compatibility(self, loaded_version):
+        """Verifies if this strategy is compatible with `loaded_version`."""
         raise NotImplementedError
 
     @property
@@ -88,15 +114,18 @@ class LoadCommonStrategy(LoadStrategyBase):
 
     @abstractmethod
     def load_common(self, checkpoint_dir: Path):
+        """Load the common part of the checkpoint."""
         raise NotImplementedError
 
     @abstractmethod
     def load_sharded_objects(
         self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
     ):
+        """Load sharded objects from the checkpoint."""
         raise NotImplementedError
 
     def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
+        """Load just the metadata from the checkpoint."""
         if not self.can_handle_sharded_objects:
             return {}
         raise NotImplementedError
@@ -107,6 +136,7 @@ class LoadShardedStrategy(LoadStrategyBase):
 
     @abstractmethod
     def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
+        """Load the sharded part of the checkpoint."""
         raise NotImplementedError
 
     @abstractmethod
@@ -145,11 +175,13 @@ class SaveCommonStrategy(SaveStrategyBase):
 
     @abstractmethod
     def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
+        """Save the common part of the state dict."""
         raise NotImplementedError
 
     def save_sharded_objects(
         self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
     ):
+        """Save sharded objects from the state dict."""
         raise NotImplementedError
 
 
@@ -158,6 +190,7 @@ class SaveShardedStrategy(SaveStrategyBase):
 
     @abstractmethod
     def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
+        """Save the sharded part of the state dict."""
         raise NotImplementedError
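Note on the new API: `register_default_strategy` above replaces the old `_import_trigger` module attribute, so populating `default_strategies` becomes an explicit call rather than an import side effect. A minimal sketch of how a strategy ends up in the registry (the `MyBackendSaveStrategy` class and the `'my_backend'` name are hypothetical, purely for illustration):

```python
from pathlib import Path

from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.strategies.base import (
    SaveShardedStrategy,
    StrategyAction,
    default_strategies,
    register_default_strategy,
)


class MyBackendSaveStrategy(SaveShardedStrategy):  # hypothetical example strategy
    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        ...  # write the sharded tensors under checkpoint_dir


# Make this strategy the default for sharded saves with backend 'my_backend', version 1.
register_default_strategy(
    StrategyAction.SAVE_SHARDED, 'my_backend', 1, MyBackendSaveStrategy('my_backend', 1)
)

# The entry lands in the same registry that get_default_strategy() consults.
assert ('my_backend', 1) in default_strategies[StrategyAction.SAVE_SHARDED.value]
```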
diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py
index 46f10733f5..f2c87b4d60 100644
--- a/megatron/core/dist_checkpointing/strategies/common.py
+++ b/megatron/core/dist_checkpointing/strategies/common.py
@@ -12,29 +12,38 @@
 from megatron.core.dist_checkpointing.strategies.base import (
     SaveCommonStrategy,
     StrategyAction,
-    default_strategies,
+    register_default_strategy,
 )
 
 from ..dict_utils import dict_list_map_inplace, nested_values
 from ..mapping import CheckpointingException, ShardedObject, is_main_replica
 from ..strategies.base import LoadCommonStrategy
 
-_import_trigger = None
-
 COMMON_STATE_FNAME = 'common.pt'
 
 logger = logging.getLogger(__name__)
 
 
+def register_default_common_strategies():
+    """Register default common strategies."""
+    register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy())
+    register_default_strategy(
+        StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1)
+    )
+
+
 class TorchCommonSaveStrategy(SaveCommonStrategy):
+    """Common save strategy leveraging native torch save/load."""
+
     def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
+        """Save the common part of the state dict."""
         if torch.distributed.get_rank() == 0:
             torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
 
     def save_sharded_objects(
         self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
     ):
-
+        """Save sharded objects from the state dict."""
         for sh_obj in nested_values(sharded_objects_state_dict):
             if is_main_replica(sh_obj.replica_id):
                 save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
@@ -42,10 +51,13 @@ def save_sharded_objects(
                 torch.save(sh_obj.data, save_path)
 
     def can_handle_sharded_objects(self):
+        """This strategy can handle ShardedObjects."""
         return True
 
 
 class TorchCommonLoadStrategy(LoadCommonStrategy):
+    """Common load strategy leveraging native torch save/load."""
+
     def load_common(self, checkpoint_dir: Path):
         """Load common (non-sharded) objects state dict from the checkpoint.
@@ -135,6 +147,7 @@ def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
 
     @property
     def can_handle_sharded_objects(self):
+        """This strategy can handle ShardedObjects."""
         return True
 
     def check_backend_compatibility(self, loaded_version):
@@ -142,9 +155,3 @@ def check_backend_compatibility(self, loaded_version):
         pass
 
     def check_version_compatibility(self, loaded_version):
         pass
-
-
-default_strategies[StrategyAction.LOAD_COMMON.value][('torch', 1)] = TorchCommonLoadStrategy()
-default_strategies[StrategyAction.SAVE_COMMON.value][('torch', 1)] = TorchCommonSaveStrategy(
-    'torch', 1
-)
diff --git a/megatron/core/dist_checkpointing/strategies/tensorstore.py b/megatron/core/dist_checkpointing/strategies/tensorstore.py
index 9b4eeb3185..0b20bf5e77 100644
--- a/megatron/core/dist_checkpointing/strategies/tensorstore.py
+++ b/megatron/core/dist_checkpointing/strategies/tensorstore.py
@@ -13,19 +13,22 @@
 from ..core import CheckpointingException
 from ..dict_utils import dict_list_map_inplace
 from ..mapping import ShardedStateDict, ShardedTensor
-from .base import LoadShardedStrategy, StrategyAction, default_strategies
-from .zarr import (
-    load_zarr_based_sharded_metadata,
-    numpy_to_torch_dtype_dict,
-    postprocess_numpy_array,
-)
-
-_import_trigger = None
+from .base import LoadShardedStrategy, StrategyAction, register_default_strategy
+from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array
 
 logger = getLogger(__name__)
 
 
+def register_default_tensorstore_strategies():
+    """Register default strategies leveraging tensorstore."""
+    register_default_strategy(
+        StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy()
+    )
+
+
 class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
+    """Load strategy for the Zarr backend using `tensorstore` for loading."""
+
     def __init__(self, load_directly_on_device: bool = False):
         super().__init__()
         self.load_directly_on_device = load_directly_on_device
@@ -58,6 +61,8 @@ def check_version_compatibility(self, loaded_version):
 
 
 def merge_global_slice_with_shape(global_slice, actual_shape, key):
+    """Intersects the global slice with the actual shape (prevents overflow)."""
+
     def _merge_slice(dim_slice, dim_size):
         if isinstance(dim_slice, slice):
             assert (
@@ -121,8 +126,3 @@ def open_ts_array(arr_path: Path):
     except Exception as e:
         raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
     return arr
-
-
-default_strategies[StrategyAction.LOAD_SHARDED.value][
-    ('zarr', 1)
-] = TensorStoreLoadShardedStrategy()
diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py
index be959bff7a..96f2c316c5 100644
--- a/megatron/core/dist_checkpointing/strategies/torch.py
+++ b/megatron/core/dist_checkpointing/strategies/torch.py
@@ -44,7 +44,12 @@
     is_main_replica,
 )
 from .async_utils import AsyncRequest
-from .base import AsyncSaveShardedStrategy, LoadShardedStrategy, StrategyAction, default_strategies
+from .base import (
+    AsyncSaveShardedStrategy,
+    LoadShardedStrategy,
+    StrategyAction,
+    register_default_strategy,
+)
 from .filesystem_async import FileSystemWriterAsync
 from .resharding import (
     TensorReformulationMetadata,
@@ -64,7 +69,16 @@
 except ImportError:
     HAVE_TE = False
 
-_import_trigger = None
+
+def register_default_torch_strategies():
+    """Register default strategies related to the PyTorch Distributed backend."""
+    register_default_strategy(
+        StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy()
+    )
+    register_default_strategy(
+        StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1)
+    )
+
 
 logger = getLogger(__name__)
 
@@ -818,11 +832,3 @@ def check_backend_compatibility(self, loaded_version):
 
     def check_version_compatibility(self, loaded_version):
         pass  # TODO
-
-
-default_strategies[StrategyAction.LOAD_SHARDED.value][
-    ('torch_dist', 1)
-] = TorchDistLoadShardedStrategy()
-default_strategies[StrategyAction.SAVE_SHARDED.value][('torch_dist', 1)] = (
-    TorchDistSaveShardedStrategy('torch_dist', 1)
-)
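With the hunks above, the `zarr` and `torch_dist` defaults are registered lazily: nothing is imported until `get_default_strategy` is asked for that backend, at which point the matching `register_default_*_strategies()` helper runs (and a missing dependency surfaces as the `CheckpointingException` with the hint from base.py). A sketch of the typical lookup path, assuming the `torch_dist` requirements (PyTorch >= 2.1 per the error hint) are installed:

```python
from megatron.core.dist_checkpointing.strategies.base import (
    StrategyAction,
    get_default_strategy,
)

# The first call imports .torch and runs register_default_torch_strategies(),
# resolving to TorchDistSaveShardedStrategy('torch_dist', 1).
save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1)

# Later calls are served from the already-populated default_strategies registry.
load_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, 'torch_dist', 1)
```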
""" import logging import os -import threading from functools import partial from logging import getLogger from pathlib import Path @@ -16,7 +15,12 @@ from ..core import CheckpointingException from ..dict_utils import dict_list_map_inplace, nested_values from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica -from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies +from .base import ( + LoadShardedStrategy, + SaveShardedStrategy, + StrategyAction, + register_default_strategy, +) logger = logging.getLogger(__name__) @@ -38,7 +42,8 @@ try: - import tensorstore + # Register a bfloat16 type with this import + import tensorstore # pylint: disable=unused-import HAS_BFLOAT16 = True numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16 @@ -46,12 +51,19 @@ except ImportError: HAS_BFLOAT16 = False -_import_trigger = None - logger = getLogger(__name__) +def register_default_zarr_strategies(): + """Register default strategies related to Zarr backend.""" + register_default_strategy( + StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1) + ) + + class ZarrSaveShardedStrategy(SaveShardedStrategy): + """Save strategy for Zarr backend.""" + def __init__(self, backend: str, version: int): super().__init__(backend, version) logger.warning( @@ -74,11 +86,13 @@ def _create_or_open_zarr_arrays( For a sharded tensors that: a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array - b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process) + b) is main replica but not the first chunk, + opens the arrays created in (a) (possibly by other process) c) otherwise, sets the corresponding array to None since it won't be used Args: - sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint + sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank + that will be saved to checkpoint checkpoint_dir (Path): checkpoint in which the arrays will be created """ arrays = [] @@ -159,6 +173,8 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path): class ZarrLoadShardedStrategy(LoadShardedStrategy): + """Load strategy for the Zarr backend.""" + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): dict_list_map_inplace( partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict @@ -210,6 +226,7 @@ def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs): def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True): + """Turn numpy array to torch tensor.""" x = loaded_array if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'): x = x.astype(np.dtype('float32')) @@ -237,10 +254,12 @@ def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range= def flatten_range(sharded_tensor, x): + """Apply flattened range to a tensor.""" return x.flatten()[sharded_tensor.flattened_range] def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor): + """Pad tensor to the expected shape.""" pad_args = [] assert len(x.shape) == len(expected_sharded_ten.local_shape) # Reversed iteration order because F.pad expects so @@ -252,9 +271,10 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor): if x_sh == exp_sh: pad_args.extend((0, 0)) elif x_sh > exp_sh: - assert ( - False - ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}' 
+ assert False, ( + f'Expected shape ({exp_sh}) smaller than actual ({x_sh})' + f' for {repr(expected_sharded_ten)}' + ) else: pad_args.extend((0, exp_sh - x_sh)) # TODO: behavior control with envvar is for testing purposes only, remove it @@ -299,9 +319,3 @@ def load_zarr_based_sharded_metadata( tuple(1 for _ in arr_shape), ) return sharded_state_dict - - -# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy() -default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy( - 'zarr', 1 -)
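The common strategies, by contrast, are registered eagerly: the `__init__.py` change at the top of this patch calls `register_default_common_strategies()` at package import time, so the `torch`-backed common save/load defaults are always available. A quick sanity check of that behavior (a sketch; assumes `megatron.core` is importable):

```python
# Importing the strategies package triggers register_default_common_strategies().
import megatron.core.dist_checkpointing.strategies  # noqa: F401

from megatron.core.dist_checkpointing.strategies.base import (
    StrategyAction,
    default_strategies,
)

# The 'torch' common strategies are pre-registered; no lazy import is needed.
assert ('torch', 1) in default_strategies[StrategyAction.LOAD_COMMON.value]
assert ('torch', 1) in default_strategies[StrategyAction.SAVE_COMMON.value]
```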