Skip to content

Commit

Permalink
ADLR/megatron-lm!1962 - docs: Fixes to allow building docs again
Browse files Browse the repository at this point in the history
  • Loading branch information
ko3n1g committed Aug 28, 2024
1 parent 34e607e commit 46736de
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 100 deletions.
7 changes: 7 additions & 0 deletions .gitlab/stages/00.pre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ label_merge_request:
source labels
curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" --data-urlencode "add_labels=$LABELS" -X PUT
clean_docker_node:
stage: .pre
image: docker:26.1.4-dind
tags: [mcore-docker-node]
script:
- docker system prune -a --filter "until=48h" -f

check_milestone:
rules:
- if: $CI_PIPELINE_SOURCE == "merge_request_event"
Expand Down
6 changes: 2 additions & 4 deletions .gitlab/stages/01.tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,15 @@ unit_tests:
- coverage

docs_build_test:
image: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/python-format:0.0.1
image: ${CI_MCORE_IMAGE}:${CI_PIPELINE_ID}
tags: [mcore-docker-node-small]
needs: [build_image]
script:
- cd ..
- rm -rf documentation && git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/nemo-megatron-core-tme/documentation.git
- mv megatron-lm/ documentation/
- cd documentation/
- ./repo docs
allow_failure: true
except:
- main

formatting:
extends: [.tests_common]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ The figures below illustrate the grad buffer's sharding scheme, and the key step

## Data flow

![Data flow](images/distrib_optimizer/data_flow.png)
![Data flow](../images/distrib_optimizer/data_flow.png)

## Sharding scheme

![Sharding scheme](images/distrib_optimizer/sharding_scheme.png)
![Sharding scheme](../images/distrib_optimizer/sharding_scheme.png)

## Key steps

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api-guide/fusions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ fusions.fused\_cross\_entropy\_loss module

This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches communication calls.

.. automodule:: core.fusions.fused_softmax
.. automodule:: core.fusions.fused_cross_entropy
:members:
:undoc-members:
:show-inheritance:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api-guide/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ API Guide
transformer
moe
dist_checkpointing
dist_optimizer
distributed
datasets
num_microbatches_calculator
2 changes: 1 addition & 1 deletion docs/source/api-guide/num_microbatches_calculator.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Microbatches Calculator
==============
=======================
This api is used to calculate the number of microbatches required to fit a given model on a given batch size.


Expand Down
6 changes: 5 additions & 1 deletion megatron/core/dist_checkpointing/strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@

""" Various loading and saving strategies """

from .common import _import_trigger
# We mock imports to populate the `default_strategies` objects.
# Since they are defined in base but populated in common, we have to mock
# import both modules.
from megatron.core.dist_checkpointing.strategies.base import _import_trigger
from megatron.core.dist_checkpointing.strategies.common import _import_trigger
13 changes: 9 additions & 4 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, DefaultDict

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest
Expand All @@ -18,7 +19,8 @@ class StrategyAction(Enum):
SAVE_SHARDED = 'save_sharded'


default_strategies = defaultdict(dict)
_import_trigger = None
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)

async_calls = AsyncCallsQueue()

Expand All @@ -35,7 +37,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):
from .torch import _import_trigger
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}'
f'Cannot import a default strategy for: {(action.value, backend, version)}. '
f'Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
Expand All @@ -46,7 +49,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int):


class LoadStrategyBase(ABC):
"""Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version."""
"""Base class for a load strategy. Requires implementing checks for compatibility with a
given checkpoint version."""

@abstractmethod
def check_backend_compatibility(self, loaded_version):
Expand All @@ -63,7 +67,8 @@ def can_handle_sharded_objects(self):


class SaveStrategyBase(ABC):
"""Base class for a save strategy. Requires defining a backend type and version of the saved format."""
"""Base class for a save strategy. Requires defining a backend type and
version of the saved format."""

def __init__(self, backend: str, version: int):
self.backend = backend
Expand Down
13 changes: 8 additions & 5 deletions megatron/core/dist_checkpointing/strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import os
from itertools import product
from pathlib import Path

import torch
Expand Down Expand Up @@ -68,10 +67,12 @@ def load_common(self, checkpoint_dir: Path):
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Replaces all ShardedObject from a given state dict with values loaded from the checkpoint.
"""Replaces all ShardedObject from a given state dict with values loaded from the
checkpoint.
Args:
sharded_objects_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
sharded_objects_state_dict (ShardedStateDict):
sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
Expand Down Expand Up @@ -99,7 +100,8 @@ def load_sharded_object(sh_obj: ShardedObject):
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}'
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint'
f' directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
Expand All @@ -119,7 +121,8 @@ def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
full_key = f'{subdir.name}/{shard_file.stem}'
sh_objs.append(ShardedObject.empty_from_unique_key(full_key))

# This is a backward-compatibility fix, where the last global shape is missing in the name
# This is a backward-compatibility fix, where the last global shape is missing in the
# name
if sh_objs[0].global_shape[-1] < 0:
max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs))
for sh_obj in sh_objs:
Expand Down
7 changes: 6 additions & 1 deletion megatron/core/fusions/fused_bias_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from megatron.core.jit import jit_fuser

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
Expand Down Expand Up @@ -46,5 +46,10 @@ def backward(ctx, grad_output):
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp

# This is required to make Sphinx happy :-(
@classmethod
def apply(cls, *args, **kwargs):
super().apply(*args, **kwargs)


bias_gelu_impl = GeLUFunction.apply
Loading

0 comments on commit 46736de

Please sign in to comment.