Skip to content

Commit

Permalink
Merge branch 'mblaz/fix-deprecation-notice' into 'main'
Browse files Browse the repository at this point in the history
Fix `zarr` deprecation notice

See merge request ADLR/megatron-lm!2042
  • Loading branch information
ericharper committed Sep 6, 2024
2 parents cc16182 + 9a0e78d commit 7a113e7
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 60 deletions.
8 changes: 3 additions & 5 deletions megatron/core/dist_checkpointing/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Various loading and saving strategies """
from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies

# We mock imports to populate the `default_strategies` objects.
# Since they are defined in base but populated in common, we have to mock
# import both modules.
from megatron.core.dist_checkpointing.strategies.base import _import_trigger
from megatron.core.dist_checkpointing.strategies.common import _import_trigger
# We load "common" strategies by default to be always available
register_default_common_strategies()
45 changes: 39 additions & 6 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, DefaultDict
from typing import Any, DefaultDict, Union

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest


class StrategyAction(Enum):
"""Specifies save vs load and sharded vs common action."""

LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'


_import_trigger = None
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)

async_calls = AsyncCallsQueue()
Expand All @@ -30,11 +31,17 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
from .tensorstore import _import_trigger
from .zarr import _import_trigger
from .tensorstore import register_default_tensorstore_strategies

register_default_tensorstore_strategies()
from .zarr import register_default_zarr_strategies

register_default_zarr_strategies()
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import _import_trigger
from .torch import register_default_torch_strategies

register_default_torch_strategies()
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. '
Expand All @@ -48,16 +55,35 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
) from e


def register_default_strategy(
action: StrategyAction,
backend: str,
version: int,
strategy: Union['SaveStrategyBase', 'LoadStrategyBase'],
):
"""Adds a given strategy to the registry of default strategies.
Args:
action (StrategyAction): specifies save/load and sharded/common
backend (str): backend that the strategy becomes a default for
version (int): version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
"""
default_strategies[action.value][(backend, version)] = strategy


class LoadStrategyBase(ABC):
"""Base class for a load strategy. Requires implementing checks for compatibility with a
given checkpoint version."""

@abstractmethod
def check_backend_compatibility(self, loaded_version):
def check_backend_compatibility(self, loaded_backend):
"""Verifies if this strategy is compatible with `loaded_backend`."""
raise NotImplementedError

@abstractmethod
def check_version_compatibility(self, loaded_version):
"""Verifies if this strategy is compatible with `loaded_version`."""
raise NotImplementedError

@property
Expand Down Expand Up @@ -88,15 +114,18 @@ class LoadCommonStrategy(LoadStrategyBase):

@abstractmethod
def load_common(self, checkpoint_dir: Path):
"""Load common part of the checkpoint."""
raise NotImplementedError

@abstractmethod
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Load sharded objects from the checkpoint."""
raise NotImplementedError

def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Load just the metadata from the checkpoint."""
if not self.can_handle_sharded_objects:
return {}
raise NotImplementedError
Expand All @@ -107,6 +136,7 @@ class LoadShardedStrategy(LoadStrategyBase):

@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Load the sharded part of the checkpoint."""
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -145,11 +175,13 @@ class SaveCommonStrategy(SaveStrategyBase):

@abstractmethod
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
raise NotImplementedError

def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Save sharded objects from the state dict."""
raise NotImplementedError


Expand All @@ -158,6 +190,7 @@ class SaveShardedStrategy(SaveStrategyBase):

@abstractmethod
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Save the sharded part of the state dict."""
raise NotImplementedError


Expand Down
27 changes: 17 additions & 10 deletions megatron/core/dist_checkpointing/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,52 @@
from megatron.core.dist_checkpointing.strategies.base import (
SaveCommonStrategy,
StrategyAction,
default_strategies,
register_default_strategy,
)

from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import CheckpointingException, ShardedObject, is_main_replica
from ..strategies.base import LoadCommonStrategy

_import_trigger = None

COMMON_STATE_FNAME = 'common.pt'

logger = logging.getLogger(__name__)


def register_default_common_strategies():
"""Register default common strategies."""
register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy())
register_default_strategy(
StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1)
)


class TorchCommonSaveStrategy(SaveCommonStrategy):
"""Common save strategy leveraging native torch save/load."""

def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)

def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):

"""Save sharded objects from the state dict."""
for sh_obj in nested_values(sharded_objects_state_dict):
if is_main_replica(sh_obj.replica_id):
save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)

def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True


class TorchCommonLoadStrategy(LoadCommonStrategy):
"""Common load strategy leveraging native torch save/load."""

def load_common(self, checkpoint_dir: Path):
"""Load common (non-sharded) objects state dict from the checkpoint.
Expand Down Expand Up @@ -135,16 +147,11 @@ def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:

@property
def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True

def check_backend_compatibility(self, loaded_version):
pass

def check_version_compatibility(self, loaded_version):
pass


default_strategies[StrategyAction.LOAD_COMMON.value][('torch', 1)] = TorchCommonLoadStrategy()
default_strategies[StrategyAction.SAVE_COMMON.value][('torch', 1)] = TorchCommonSaveStrategy(
'torch', 1
)
26 changes: 13 additions & 13 deletions megatron/core/dist_checkpointing/strategies/tensorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import (
load_zarr_based_sharded_metadata,
numpy_to_torch_dtype_dict,
postprocess_numpy_array,
)

_import_trigger = None
from .base import LoadShardedStrategy, StrategyAction, register_default_strategy
from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array

logger = getLogger(__name__)


def register_default_tensorstore_strategies():
"""Register default strategies leveraging tensorstore."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy()
)


class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
"""Load strategy for Zarr backend using `tensorstore` for loading."""

def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
Expand Down Expand Up @@ -58,6 +61,8 @@ def check_version_compatibility(self, loaded_version):


def merge_global_slice_with_shape(global_slice, actual_shape, key):
"""Intersects the global slice with the actual shape (prevent overflow)."""

def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
Expand Down Expand Up @@ -121,8 +126,3 @@ def open_ts_array(arr_path: Path):
except Exception as e:
raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
return arr


default_strategies[StrategyAction.LOAD_SHARDED.value][
('zarr', 1)
] = TensorStoreLoadShardedStrategy()
26 changes: 16 additions & 10 deletions megatron/core/dist_checkpointing/strategies/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@
is_main_replica,
)
from .async_utils import AsyncRequest
from .base import AsyncSaveShardedStrategy, LoadShardedStrategy, StrategyAction, default_strategies
from .base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
StrategyAction,
register_default_strategy,
)
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
TensorReformulationMetadata,
Expand All @@ -64,7 +69,16 @@
except ImportError:
HAVE_TE = False

_import_trigger = None

def register_default_torch_strategies():
"""Register default strategies related to PyT Distributed backend."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy()
)
register_default_strategy(
StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1)
)


logger = getLogger(__name__)

Expand Down Expand Up @@ -818,11 +832,3 @@ def check_backend_compatibility(self, loaded_version):

def check_version_compatibility(self, loaded_version):
pass # TODO


default_strategies[StrategyAction.LOAD_SHARDED.value][
('torch_dist', 1)
] = TorchDistLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('torch_dist', 1)] = (
TorchDistSaveShardedStrategy('torch_dist', 1)
)
Loading

0 comments on commit 7a113e7

Please sign in to comment.