diff --git a/.gitlab/stages/00.pre.yml b/.gitlab/stages/00.pre.yml index ac1bcca3fe..02b441e97b 100644 --- a/.gitlab/stages/00.pre.yml +++ b/.gitlab/stages/00.pre.yml @@ -38,6 +38,13 @@ label_merge_request: source labels curl --header "PRIVATE-TOKEN: ${PROJECT_ACCESS_TOKEN_MCORE}" --url "https://${GITLAB_ENDPOINT}/api/v4/projects/${CI_PROJECT_ID}/merge_requests/${CI_MERGE_REQUEST_IID}" --data-urlencode "add_labels=$LABELS" -X PUT +clean_docker_node: + stage: .pre + image: docker:26.1.4-dind + tags: [mcore-docker-node] + script: + - docker system prune -a --filter "until=48h" -f + check_milestone: rules: - if: $CI_PIPELINE_SOURCE == "merge_request_event" diff --git a/.gitlab/stages/01.tests.yml b/.gitlab/stages/01.tests.yml index 18b4175d93..230f5ed5b9 100644 --- a/.gitlab/stages/01.tests.yml +++ b/.gitlab/stages/01.tests.yml @@ -104,17 +104,15 @@ unit_tests: - coverage docs_build_test: - image: ${GITLAB_ENDPOINT}:5005/adlr/megatron-lm/python-format:0.0.1 + image: ${CI_MCORE_IMAGE}:${CI_PIPELINE_ID} tags: [mcore-docker-node-small] + needs: [build_image] script: - cd .. - rm -rf documentation && git clone https://gitlab-ci-token:${CI_JOB_TOKEN}@${GITLAB_ENDPOINT}/nemo-megatron-core-tme/documentation.git - mv megatron-lm/ documentation/ - cd documentation/ - ./repo docs - allow_failure: true - except: - - main formatting: extends: [.tests_common] diff --git a/docs/source/distrib_optimizer.md b/docs/source/api-guide/dist_optimizer.md similarity index 95% rename from docs/source/distrib_optimizer.md rename to docs/source/api-guide/dist_optimizer.md index def23b20eb..0f52ad7175 100644 --- a/docs/source/distrib_optimizer.md +++ b/docs/source/api-guide/dist_optimizer.md @@ -28,11 +28,11 @@ The figures below illustrate the grad buffer's sharding scheme, and the key step ## Data flow -![Data flow](images/distrib_optimizer/data_flow.png) +![Data flow](../images/distrib_optimizer/data_flow.png) ## Sharding scheme -![Sharding scheme](images/distrib_optimizer/sharding_scheme.png) +![Sharding scheme](../images/distrib_optimizer/sharding_scheme.png) ## Key steps diff --git a/docs/source/api-guide/fusions.rst b/docs/source/api-guide/fusions.rst index 694ed129f4..22782ca84e 100644 --- a/docs/source/api-guide/fusions.rst +++ b/docs/source/api-guide/fusions.rst @@ -58,7 +58,7 @@ fusions.fused\_cross\_entropy\_loss module This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches communication calls. -.. automodule:: core.fusions.fused_softmax +.. automodule:: core.fusions.fused_cross_entropy :members: :undoc-members: :show-inheritance: diff --git a/docs/source/api-guide/index.rst b/docs/source/api-guide/index.rst index d0206eb281..c2265356d4 100644 --- a/docs/source/api-guide/index.rst +++ b/docs/source/api-guide/index.rst @@ -12,6 +12,7 @@ API Guide transformer moe dist_checkpointing + dist_optimizer distributed datasets num_microbatches_calculator diff --git a/docs/source/api-guide/num_microbatches_calculator.rst b/docs/source/api-guide/num_microbatches_calculator.rst index 1c478a7a80..4790b31749 100644 --- a/docs/source/api-guide/num_microbatches_calculator.rst +++ b/docs/source/api-guide/num_microbatches_calculator.rst @@ -1,5 +1,5 @@ Microbatches Calculator -============== +======================= This api is used to calculate the number of microbatches required to fit a given model on a given batch size. diff --git a/megatron/core/dist_checkpointing/strategies/__init__.py b/megatron/core/dist_checkpointing/strategies/__init__.py index db8093f803..6d04265b16 100644 --- a/megatron/core/dist_checkpointing/strategies/__init__.py +++ b/megatron/core/dist_checkpointing/strategies/__init__.py @@ -2,4 +2,8 @@ """ Various loading and saving strategies """ -from .common import _import_trigger +# We mock imports to populate the `default_strategies` objects. +# Since they are defined in base but populated in common, we have to mock +# import both modules. +from megatron.core.dist_checkpointing.strategies.base import _import_trigger +from megatron.core.dist_checkpointing.strategies.common import _import_trigger diff --git a/megatron/core/dist_checkpointing/strategies/base.py b/megatron/core/dist_checkpointing/strategies/base.py index daa37fe43c..cc1c83b92f 100644 --- a/megatron/core/dist_checkpointing/strategies/base.py +++ b/megatron/core/dist_checkpointing/strategies/base.py @@ -6,6 +6,7 @@ from collections import defaultdict from enum import Enum from pathlib import Path +from typing import Any, DefaultDict from ..mapping import CheckpointingException, ShardedStateDict, StateDict from .async_utils import AsyncCallsQueue, AsyncRequest @@ -18,7 +19,8 @@ class StrategyAction(Enum): SAVE_SHARDED = 'save_sharded' -default_strategies = defaultdict(dict) +_import_trigger = None +default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict) async_calls = AsyncCallsQueue() @@ -35,7 +37,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int): from .torch import _import_trigger except ImportError as e: raise CheckpointingException( - f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}' + f'Cannot import a default strategy for: {(action.value, backend, version)}. ' + f'Error: {e}. Hint: {error_hint}' ) from e try: return default_strategies[action.value][(backend, version)] @@ -46,7 +49,8 @@ def get_default_strategy(action: StrategyAction, backend: str, version: int): class LoadStrategyBase(ABC): - """Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version.""" + """Base class for a load strategy. Requires implementing checks for compatibility with a + given checkpoint version.""" @abstractmethod def check_backend_compatibility(self, loaded_version): @@ -63,7 +67,8 @@ def can_handle_sharded_objects(self): class SaveStrategyBase(ABC): - """Base class for a save strategy. Requires defining a backend type and version of the saved format.""" + """Base class for a save strategy. Requires defining a backend type and + version of the saved format.""" def __init__(self, backend: str, version: int): self.backend = backend diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py index cfa55ab480..46f10733f5 100644 --- a/megatron/core/dist_checkpointing/strategies/common.py +++ b/megatron/core/dist_checkpointing/strategies/common.py @@ -4,7 +4,6 @@ import logging import os -from itertools import product from pathlib import Path import torch @@ -68,10 +67,12 @@ def load_common(self, checkpoint_dir: Path): def load_sharded_objects( self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path ): - """Replaces all ShardedObject from a given state dict with values loaded from the checkpoint. + """Replaces all ShardedObject from a given state dict with values loaded from the + checkpoint. Args: - sharded_objects_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded. + sharded_objects_state_dict (ShardedStateDict): + sharded state dict defining what objects should be loaded. checkpoint_dir (Path): checkpoint directory Returns: @@ -99,7 +100,8 @@ def load_sharded_object(sh_obj: ShardedObject): else: ckpt_files = [f.name for f in checkpoint_dir.iterdir()] logger.debug( - f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}' + f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint' + f' directory content: {ckpt_files}' ) raise CheckpointingException(err_msg) from e return loaded_obj @@ -119,7 +121,8 @@ def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict: full_key = f'{subdir.name}/{shard_file.stem}' sh_objs.append(ShardedObject.empty_from_unique_key(full_key)) - # This is a backward-compatibility fix, where the last global shape is missing in the name + # This is a backward-compatibility fix, where the last global shape is missing in the + # name if sh_objs[0].global_shape[-1] < 0: max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs)) for sh_obj in sh_objs: diff --git a/megatron/core/fusions/fused_bias_gelu.py b/megatron/core/fusions/fused_bias_gelu.py index 2b5467467c..13c5bdf705 100644 --- a/megatron/core/fusions/fused_bias_gelu.py +++ b/megatron/core/fusions/fused_bias_gelu.py @@ -4,7 +4,7 @@ from megatron.core.jit import jit_fuser -###### BIAS GELU FUSION/ NO AUTOGRAD ################ +# BIAS GELU FUSION/ NO AUTOGRAD ################ # 1/sqrt(2*pi)-> 0.3989423 # 1/sqrt(2) -> 0.70710678 # sqrt(2/pi) -> 0.79788456 @@ -46,5 +46,10 @@ def backward(ctx, grad_output): tmp = bias_gelu_back(grad_output, bias, input) return tmp, tmp + # This is required to make Sphinx happy :-( + @classmethod + def apply(cls, *args, **kwargs): + super().apply(*args, **kwargs) + bias_gelu_impl = GeLUFunction.apply diff --git a/megatron/core/num_microbatches_calculator.py b/megatron/core/num_microbatches_calculator.py index e5ed7fc6f0..16bd95a7b4 100644 --- a/megatron/core/num_microbatches_calculator.py +++ b/megatron/core/num_microbatches_calculator.py @@ -41,9 +41,12 @@ def update_num_microbatches( """Update number of microbatches. Args: - consumed_samples (int): Number of samples consumed. - consistency_check (bool, optional): Option to check current schedule's consistency. Defaults to True. - verbose (bool, optional): Option to control logging. Defaults to False. + consumed_samples (int): + Number of samples consumed. + consistency_check (bool, optional): + Option to check current schedule's consistency. Defaults to True. + verbose (bool, optional): + Option to control logging. Defaults to False. """ _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check, verbose) @@ -59,12 +62,20 @@ def init_num_microbatches_calculator( """Initialize number of microbatches calculator. Supporting backward compatibility. Args: - rank (int): Rank of the GPU, only rank 0 will log the information. - rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of [start_global_batch_size, batch_size_increment, ramup_samples]. - global_batch_size (int): Global batch size for the model. - micro_batch_size (int): Micro batch size at initialization. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure divisibility by DP size * microbatch size. Defaults to False. + rank (int): + Rank of the GPU, only rank 0 will log the information. + rampup_batch_size (Optional[List[int]]): + Rampup batch size, should be in format of [start_global_batch_size, + batch_size_increment, ramup_samples]. + global_batch_size (int): + Global batch size for the model. + micro_batch_size (int): + Micro batch size at initialization. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool, optional): + If true, scale down batch size to ensure divisibility by DP size * microbatch size. + Defaults to False. """ _configure_global_num_microbatches_calculator( rank, @@ -94,12 +105,20 @@ def reconfigure_num_microbatches_calculator( """Reconfigure number of microbatches calculator. Supporting backward compatibility. Args: - rank (int): Rank of the GPU, only rank 0 will log the information. - rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of [start_global_batch_size, batch_size_increment, ramup_samples]. - global_batch_size (int): Global batch size for the model. - micro_batch_size (int): Micro batch size at initialization. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure divisibility by DP size * microbatch size. Defaults to False. + rank (int): + Rank of the GPU, only rank 0 will log the information. + rampup_batch_size (Optional[List[int]]): + Rampup batch size, should be in format of + [start_global_batch_size, batch_size_increment, ramup_samples]. + global_batch_size (int): + Global batch size for the model. + micro_batch_size (int): + Micro batch size at initialization. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool, optional): + If true, scale down batch size to ensure divisibility by DP size * microbatch size. + Defaults to False. """ _configure_global_num_microbatches_calculator( rank, @@ -121,16 +140,26 @@ def _configure_global_num_microbatches_calculator( decrease_batch_size_if_needed: bool = False, init: bool = False, ) -> None: - """Configure number of microbatches calculator. Can be used for initialization and reconfiguration. + """Configure number of microbatches calculator. Can be used for initialization and + reconfiguration. Args: - rank (int): Rank of the GPU, only rank 0 will log the information. - rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of [start_global_batch_size, batch_size_increment, ramup_samples]. - global_batch_size (int): Global batch size for the model. - micro_batch_size (int): Micro batch size at initialization. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure divisibility by DP size * microbatch size. Defaults to False. - init (bool, optional): If true, initialize the calculator. Defaults to False. + rank (int): + Rank of the GPU, only rank 0 will log the information. + rampup_batch_size (Optional[List[int]]): + Rampup batch size, should be in format of + [start_global_batch_size, batch_size_increment, ramup_samples]. + global_batch_size (int): + Global batch size for the model. + micro_batch_size (int): + Micro batch size at initialization. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool, optional): + If true, scale down batch size to ensure divisibility by DP size * microbatch size. + Defaults to False. + init (bool, optional): + If true, initialize the calculator. Defaults to False. """ global _GLOBAL_NUM_MICROBATCHES_CALCULATOR @@ -160,12 +189,20 @@ def _build_num_microbatches_calculator( """Build number of microbatches calculator. Internal helper method. Args: - rank (int): Rank of the GPU, only rank 0 will log the information. - rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of [start_global_batch_size, batch_size_increment, ramup_samples]. - global_batch_size (int): Global batch size for the model. - micro_batch_size (int): Micro batch size at initialization. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool): If true, scale down batch size to ensure divisibility by DP size * microbatch size. + rank (int): + Rank of the GPU, only rank 0 will log the information. + rampup_batch_size (Optional[List[int]]): + Rampup batch size, should be in format of + [start_global_batch_size, batch_size_increment, ramup_samples]. + global_batch_size (int): + Global batch size for the model. + micro_batch_size (int): + Micro batch size at initialization. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool): + If true, scale down batch size to ensure divisibility by DP size * microbatch size. + """ # Constant batch size. @@ -193,7 +230,9 @@ def _build_num_microbatches_calculator( ramup_samples = int(rampup_batch_size[2]) if rank == 0: logger.info( - f'will use batch size rampup starting from global batch size {start_global_batch_size} to global batch size {global_batch_size} with batch size increments {batch_size_increment} over {ramup_samples} samples.' + f'will use batch size rampup starting from global batch size ' + f'{start_global_batch_size} to global batch size {global_batch_size} with batch' + f'size increments {batch_size_increment} over {ramup_samples} samples.' ) num_microbatches_calculator = RampupBatchsizeNumMicroBatchesCalculator( global_batch_size, @@ -236,7 +275,8 @@ def get_micro_batch_size(self) -> int: return self.micro_batch_size def get_current_running_global_batch_size(self) -> int: - """Get current running global batch size. If decrease_batch_size_if_needed is False, this just equals global batch size.""" + """Get current running global batch size. If decrease_batch_size_if_needed is False, + this just equals global batch size.""" return self.current_running_global_batch_size @abstractmethod @@ -249,11 +289,17 @@ class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator): """Calculator of number of microbatches with constant global batch size. Args: - global_batch_size (int): Global batch size. - micro_batch_size (int): Micro batch size. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool): If true, decrease batch size to ensure divisibility by DP size * microbatch size (if needed). - rank (int): Rank (to determine whether logging should be performed). + global_batch_size (int): + Global batch size. + micro_batch_size (int): + Micro batch size. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool): + If true, decrease batch size to ensure divisibility by DP size * microbatch size + (if needed). + rank (int): + Rank (to determine whether logging should be performed). """ def __init__( @@ -301,21 +347,28 @@ def update(self, consumed_samples, consistency_check, verbose=False) -> None: class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator): """Calculator of number of microbatches with batch size rampup. - Over - steps = (global-batch-size - start-batch-size) / batch_size_increment - increment batch size from start-batch-size to global-batch-size using - rampup-samples / steps + Over `steps = (global-batch-size - start-batch-size) / batch_size_increment` increment batch + size from start-batch-size to global-batch-size using rampup-samples / steps samples. Args: - global_batch_size (int): Global batch size post rampup. - micro_batch_size (int): Micro batch size. - data_parallel_size (int): Data parallel size. - decrease_batch_size_if_needed (bool): If true, decrease batch size to ensure divisibility by DP size * microbatch size (if needed). - rank (int): Rank (to determine whether logging should be performed). - start_global_batch_size (int): Global batch size to start with. - batch_size_increment (int): Global batch size increments. - ramup_samples (int): Number of samples to use ramp up global + global_batch_size (int): + Global batch size post rampup. + micro_batch_size (int): + Micro batch size. + data_parallel_size (int): + Data parallel size. + decrease_batch_size_if_needed (bool): + If true, decrease batch size to ensure divisibility by DP size * microbatch size + (if needed). + rank (int): + Rank (to determine whether logging should be performed). + start_global_batch_size (int): + Global batch size to start with. + batch_size_increment (int): + Global batch size increments. + ramup_samples (int): + Number of samples to use ramp up global batch size from `start_global_batch_size` to `global_batch_size`. """ @@ -357,15 +410,14 @@ def __init__( self.current_global_batch_size = None diff_batch_size = self.global_batch_size - self.start_global_batch_size - assert ( - diff_batch_size >= 0 - ), 'expected global batch size to be greater than or equal to start batch size, got {} and {}.'.format( - self.global_batch_size, self.start_global_batch_size + assert diff_batch_size >= 0, ( + 'expected global batch size to be greater than or equal to start batch size, ' + f'got {self.global_batch_size} and {self.start_global_batch_size}' ) assert diff_batch_size % batch_size_increment == 0, ( 'expected ' - 'global batch size interval ({}) to be divisible by global batch ' - 'size increment ({})'.format(diff_batch_size, batch_size_increment) + f'global batch size interval ({diff_batch_size}) to be divisible by global batch ' + f'size increment ({batch_size_increment})' ) num_increments = diff_batch_size // self.batch_size_increment @@ -399,7 +451,8 @@ def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = global_batch_size_changed = True if self.rank == 0 and global_batch_size_changed and verbose: logger.info( - f'ramping up batch size from {old_current_global_batch_size} to {self.current_global_batch_size}' + f'ramping up batch size from {old_current_global_batch_size} to ' + f'{self.current_global_batch_size}' ) # Check consistency of the current global batch size. @@ -423,7 +476,8 @@ def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = ) if self.rank == 0 and global_batch_size_changed and verbose: logger.info( - f'decreasing batch size from {self.current_global_batch_size} to {self.current_running_global_batch_size}' + f'decreasing batch size from {self.current_global_batch_size} to ' + f'{self.current_running_global_batch_size}' ) assert ( self.current_running_global_batch_size % self.micro_batch_times_data_parallel_size diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index b7669ccb45..d7da83cc71 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -1,7 +1,7 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import contextlib -from typing import Callable, Iterator, List, Optional, Union +from typing import Iterator, List, Union import torch from torch.autograd.variable import Variable @@ -96,7 +96,8 @@ def forward_step(data_iterator, model): collect_non_loss_data (optional, bool, default=False): TODO first_val_step (bool, optional): Is the first step of the validation phase. Used by - Transformer Engine modules to only update their fp8 weights only on the first validation step. + Transformer Engine modules to only update their fp8 weights only on the first validation + step. """ pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() @@ -187,9 +188,11 @@ def forward_step( Otherwise, the passed-in input_tensor is used. Args: - forward_step_func (callable): The forward step function for the model that takes the + forward_step_func (callable): + The forward step function for the model that takes the data iterator as the first argument, and model as the second. This user's forward step is expected to output a tuple of two elements: + 1. The output object from the forward step. This output object needs to be a tensor or some kind of collection of tensors. The only hard requirement for this object is that it needs to be acceptible as input into the second @@ -198,7 +201,8 @@ def forward_step( could be a reduction over the loss from the model, it could be a function that grabs the output from the model and reformats, it could be a function that just passes through the model output. This function must have one of the following - patterns, and depending on the pattern different things happen internally. + patterns, and depending on the pattern different things happen internally: + a. A tuple of reduced loss and some other data. Note that in this case the first argument is divided by the number of global microbatches, assuming it is a loss, so that the loss is stable as a function of @@ -212,23 +216,33 @@ def forward_step( to specify `collect_non_loss_data=True` and you may also want to specify `forward_only=True` in the call to the parent forward_backward function. - data_iterator (iterator): The data iterator. - model (nn.Module): The model to perform the forward step on. - num_microbatches (int): The number of microbatches. - input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step. - forward_data_store (list): The list to store the forward data. If you go down path 2.a or + data_iterator (iterator): + The data iterator. + model (nn.Module): + The model to perform the forward step on. + num_microbatches (int): + The number of microbatches. + input_tensor (Tensor or list[Tensor]): + The input tensor(s) for the forward step. + forward_data_store (list): + The list to store the forward data. If you go down path 2.a or 2.b for the return of your forward reduction function then this will store only the final dimension of the output, for example the metadata output by the loss function. If you go down the path of 2.c then this will store the entire output of the forward reduction function applied to the model output. - config (object): The configuration object. - collect_non_loss_data (bool, optional): Whether to collect non-loss data. Defaults to False. + config (object): + The configuration object. + collect_non_loss_data (bool, optional): + Whether to collect non-loss data. Defaults to False. This is the path to use if you want to collect arbitrary output from the model forward, such as with inference use cases. Defaults to False. - checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations. + checkpoint_activations_microbatch (int, optional): + The microbatch to checkpoint activations. Defaults to None. - is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False. - current_microbatch (int, optional): The current microbatch. Defaults to None. + is_first_microbatch (bool, optional): + Whether it is the first microbatch. Defaults to False. + current_microbatch (int, optional): + The current microbatch. Defaults to None. Returns: Tensor or list[Tensor]: The output object(s) from the forward step. @@ -285,7 +299,8 @@ def forward_step( config.timers('forward-compute').stop() # Set the loss scale for the auxiliary loss of the MoE layer. - # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly. + # Since we use a trick to do backward on the auxiliary loss, we need to set the scale + # explicitly. if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None: # Calculate the loss scale based on the grad_scale_func if available, else default to 1. loss_scale = ( @@ -685,7 +700,6 @@ def get_microbatch_id_in_model_chunk(iteration_id, forward): def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: """Check if an iteration is the first for a model chunk.""" microbatch_group_size = pipeline_parallel_size * num_model_chunks - num_microbatch_groups = total_num_microbatches // microbatch_group_size microbatch_group_id = microbatch_id // microbatch_group_size microbatch_id_in_group = microbatch_id % microbatch_group_size if microbatch_group_id == 0: @@ -814,7 +828,6 @@ def backward_step_helper(microbatch_id): for req in fwd_wait_handles: req.wait() - cur_model_chunk_id = get_model_chunk_id(k, forward=True) # Decide to checkpoint all layers' activations of the current micro-batch if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( @@ -918,7 +931,6 @@ def backward_step_helper(microbatch_id): else: checkpoint_activations_microbatch = None - cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True) current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True) if config.overlap_p2p_comm: if fwd_wait_handles is not None: @@ -1145,8 +1157,10 @@ def get_tensor_shapes( config, encoder_decoder_xattn: bool, ): - # Determine right tensor sizes (based on position of rank with respect to split rank) and model size. - # Send two tensors if model decoder requires the encoder's output (via cross-attention) and rank is in decoder stage. + # Determine right tensor sizes (based on position of rank with + # respect to split rank) and model size. + # Send two tensors if model decoder requires the encoder's output + # (via cross-attention) and rank is in decoder stage. # first tensor is decoder. # second tensor is encoder. # If model has an encoder & decoder and rank is at the boundary: @@ -1260,9 +1274,7 @@ def forward_backward_pipelining_without_interleaving( first_val_step: bool = None, ): """Run non-interleaved 1F1B schedule, with communication between pipeline - stages. - - Returns dictionary with losses if the last stage, empty dict otherwise.""" + stages. Returns dictionary with losses if the last stage, empty dict otherwise.""" if isinstance(model, list): assert ( diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md index 43643f57d6..9a43c82dae 100644 --- a/megatron/core/transformer/moe/README.md +++ b/megatron/core/transformer/moe/README.md @@ -87,7 +87,8 @@ To enable the token drop mechanism, such as GShard and SwitchTransformer, includ ``` The following figure illustrates differenting dropping strategies in MCore: -![Token Droppling Strategies](../../../../docs/source/images/moe/token_drop.png) + + 1. The default dropless strategy will not drop or pad any token. 2. By setting `--moe-expert-capacity-factor`, the tokens exceed the capacity of expert will be dropped based on their selected probabilities. @@ -97,7 +98,7 @@ The following figure illustrates differenting dropping strategies in MCore: ### Fine-tuning Mixtral Models Megatron-Core has full support for Mixtral MoE models, and we provide the checkpoint converter for Mixtral models from huggingface format to MCore format. -See more details in the [mixtral example](../../../../examples/mixtral/README.md). + ### Distributed Checkpointing MCore v0.7 introduced fully parallel and asynchronous saving capabilities to distributed checkpointing,