diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index 763d08aede..3de68b5091 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -16,7 +16,7 @@ from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import get_batch_on_this_cp_rank, log_single_rank +from megatron.core.utils import log_single_rank try: import transformer_engine # pylint: disable=unused-import @@ -637,6 +637,8 @@ def _process_embedding_token_parallel( if self.context_parallel_lm > 1: # Distribute sequence across CP ranks + from megatron.training.utils import get_batch_on_this_cp_rank + batch = get_batch_on_this_cp_rank( { "combined_embeddings": combined_embeddings, diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 2e0480a146..7a97603eba 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -262,48 +262,56 @@ def _get_megatron_optimizer_based_on_param_groups( Returns: Instance of MegatronOptimizer. """ - if config.optimizer == 'adam': - kwargs = { - "params": param_groups, - "lr": config.lr, - "weight_decay": config.weight_decay, - "betas": (config.adam_beta1, config.adam_beta2), - "eps": config.adam_eps, - } - if config.use_precision_aware_optimizer: - kwargs.update( - { - "master_weights": True, - "use_decoupled_grad": True, - "master_weight_dtype": config.main_params_dtype, - "exp_avg_dtype": config.exp_avg_dtype, - "exp_avg_sq_dtype": config.exp_avg_sq_dtype, - } - ) + # when freezing sub-models we may have no trainable parameters on a rank and + # hence an empty param_groups. 
However, we still need to create an optimizer + # for the purposes of grad stats reductions + if param_groups: + if config.optimizer == 'adam': + kwargs = { + "params": param_groups, + "lr": config.lr, + "weight_decay": config.weight_decay, + "betas": (config.adam_beta1, config.adam_beta2), + "eps": config.adam_eps, + } + + if config.use_precision_aware_optimizer: + kwargs.update( + { + "master_weights": True, + "use_decoupled_grad": True, + "master_weight_dtype": config.main_params_dtype, + "exp_avg_dtype": config.exp_avg_dtype, + "exp_avg_sq_dtype": config.exp_avg_sq_dtype, + } + ) - optimizer = Adam(**kwargs) - - def init_state_fn(opt, config=None): - for group in opt.param_groups: - for p in group['params']: - if len(opt.state[p]) == 0: - if config is None or not config.use_precision_aware_optimizer: - opt.state[p]['exp_avg'] = torch.zeros_like(p.data) - opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data) - else: - opt.initialize_state(p) - - elif config.optimizer == 'sgd': - optimizer = SGD( - param_groups, - lr=config.lr, - weight_decay=config.weight_decay, - momentum=config.sgd_momentum, - ) - init_state_fn = None + optimizer = Adam(**kwargs) + + def init_state_fn(opt, config=None): + for group in opt.param_groups: + for p in group['params']: + if len(opt.state[p]) == 0: + if config is None or not config.use_precision_aware_optimizer: + opt.state[p]['exp_avg'] = torch.zeros_like(p.data) + opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data) + else: + opt.initialize_state(p) + + elif config.optimizer == 'sgd': + optimizer = SGD( + param_groups, + lr=config.lr, + weight_decay=config.weight_decay, + momentum=config.sgd_momentum, + ) + init_state_fn = None + else: + raise Exception('{} optimizer is not supported.'.format(config.optimizer)) else: - raise Exception('{} optimizer is not supported.'.format(config.optimizer)) + optimizer = None + init_state_fn = None # Mixed precision optimizer. # - Note: both the Float16Optimizer and the DistributedOptimizer inherit @@ -423,6 +431,7 @@ def get_megatron_optimizer( model_chunk.overlap_param_gather_with_optimizer_step = ( overlap_param_gather_with_optimizer_step ) + optimizers.append( _get_megatron_optimizer_based_on_param_groups( config, diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py index aab7bde9ed..6b3c53efca 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -483,10 +483,16 @@ def __init__( for model_chunk in self.model_chunks: assert self.ddp_config == model_chunk.ddp_config - assert isinstance( - optimizer, Adam + assert ( + isinstance(optimizer, Adam) or optimizer is None ), "Only Adam currently supported, due to checkpointing requirements." + # when freezing sub-models we have no real optimizer + # but still need a stub DistributedOptimizer class + if optimizer is None: + self.is_stub_optimizer = True + return + # Model grad buffer ranges. 
assert per_model_buffers is not None, "per_model_buffers must be provided" self.buffers = list(itertools.chain(*per_model_buffers.values())) @@ -551,6 +557,8 @@ def __init__( self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges] self.optimizer.load_state_dict(self.optimizer.state_dict()) + self.is_stub_optimizer = False + def _get_model_param_range_map(self, param: torch.nn.Parameter): """ Given a model param, get the index sub-range of the param that this @@ -1635,6 +1643,8 @@ def load_parameter_state(self, filename: str, *, update_legacy_format=False): Args: filename (str): path to load parameter state from. """ + if self.is_stub_optimizer: + return state_dict = None if torch.distributed.get_rank(self.data_parallel_group) == 0: state_dict = torch.load(filename) @@ -1653,6 +1663,8 @@ def zero_grad(self, set_to_none: bool = True): Args: set_to_none (bool): if true, set grads to None. """ + if self.is_stub_optimizer: + return total_groups = [ self.model_float16_groups, self.model_fp32_groups, @@ -1710,6 +1722,8 @@ def _copy_model_grads_to_main_grads(self): buffer, this method is responsible for copying the updated grads from the grad buffer to the main shard's grad field. """ + if self.is_stub_optimizer: + return # Utility method for copying group grads. def copy_group_grads(model_groups, shard_main_groups): @@ -1748,6 +1762,8 @@ def _copy_main_params_to_model_params(self): buffer, this method is responsible for copying the updated params from the main shards into the correct position in the grad buffer. """ + if self.is_stub_optimizer: + return # Utility method for copying group params. def copy_group_params(shard_main_groups, model_groups): @@ -1831,6 +1847,8 @@ def _update_fp8_scale_inv_and_amax(self): If detect FP8 parameters, update their `_scale_inv` and do reduce-max for their `amax_history`. """ + if self.is_stub_optimizer: + return amaxes = [] scales = [] scale_invs = [] diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 785eda795f..e830bea88d 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -4,6 +4,7 @@ import copy import math +import warnings from abc import ABC, abstractmethod from itertools import chain from logging import getLogger @@ -109,7 +110,11 @@ def __init__( ): """Input optimizer is the base optimizer (e.g., Adam).""" self.optimizer = optimizer - assert self.optimizer, 'no optimizer is provided.' + if self.optimizer is None: + warnings.warn( + f"WARNING: there is no optimizer on RANK {torch.distributed.get_rank()}. " + "This may be expected if you have frozen sub-models." + ) self.config = config self.init_state_fn = init_state_fn @@ -118,9 +123,10 @@ def get_parameters(self) -> List[torch.nn.Parameter]: Get list of parameters wrapped in optimizer. 
""" params = [] - for param_group in self.optimizer.param_groups: - for param in param_group['params']: - params.append(param) + if hasattr(self.optimizer, 'param_groups'): + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + params.append(param) return params def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: @@ -189,13 +195,18 @@ def get_grad_norm(self): def clip_grad_norm(self, clip_grad: float) -> float: """Compute and return grad norm, also clip grads.""" params = self.get_parameters() - grads_for_norm = self.get_main_grads_for_grad_norm() + if params: + grads_for_norm = self.get_main_grads_for_grad_norm() + else: + grads_for_norm = [] grad_norm = get_grad_norm_fp32( grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group() ) - clip_grad_by_total_norm_fp32( - params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer - ) + + if params: + clip_grad_by_total_norm_fp32( + params, clip_grad, grad_norm, self.config.use_precision_aware_optimizer + ) return grad_norm def count_zeros(self) -> float: @@ -257,7 +268,10 @@ def _set_state(self, value): # "optimizer_instance.param_groups" # (for example, to adjust the learning rate) def _get_param_groups(self): - return self.optimizer.param_groups + if self.is_stub_optimizer: + return [] + else: + return self.optimizer.param_groups def _set_param_groups(self, value): self.optimizer.param_groups = value @@ -365,15 +379,17 @@ def reload_model_params(self): def _unscale_main_grads_and_check_for_nan(self): # Collect main grads. - main_grads = self._collect_main_grad_data_for_unscaling() + if not self.is_stub_optimizer: + main_grads = self._collect_main_grad_data_for_unscaling() # Reset found inf. self.found_inf.fill_(0.0) - # Unscale and set found inf/nan - torch._amp_foreach_non_finite_check_and_unscale_( - main_grads, self.found_inf, self.grad_scaler.inv_scale - ) + if not self.is_stub_optimizer: + # Unscale and set found inf/nan + torch._amp_foreach_non_finite_check_and_unscale_( + main_grads, self.found_inf, self.grad_scaler.inv_scale + ) # Update across all model parallel instances. 
torch.distributed.all_reduce( @@ -397,7 +413,8 @@ def prepare_grads(self) -> bool: timers('optimizer-copy-to-main-grad', log_level=1).start( barrier=self.config.barrier_with_L1_time ) - self._copy_model_grads_to_main_grads() + if not self.is_stub_optimizer: + self._copy_model_grads_to_main_grads() if timers is not None: timers('optimizer-copy-to-main-grad').stop() @@ -431,7 +448,8 @@ def step_with_ready_grads(self) -> bool: timers('optimizer-inner-step', log_level=1).start( barrier=self.config.barrier_with_L1_time ) - self.optimizer.step() + if not self.is_stub_optimizer: + self.optimizer.step() if timers is not None: timers('optimizer-inner-step').stop() @@ -440,7 +458,8 @@ def step_with_ready_grads(self) -> bool: timers('optimizer-copy-main-to-model-params', log_level=1).start( barrier=self.config.barrier_with_L1_time ) - self._copy_main_params_to_model_params() + if not self.is_stub_optimizer: + self._copy_main_params_to_model_params() if timers is not None: timers('optimizer-copy-main-to-model-params').stop() @@ -459,7 +478,7 @@ def step(self): timers('optimizer-clip-main-grad', log_level=1).start( barrier=self.config.barrier_with_L1_time ) - grad_norm = None + grad_norm = 0.0 if self.config.clip_grad > 0.0: grad_norm = self.clip_grad_norm(self.config.clip_grad) if timers is not None: @@ -470,7 +489,7 @@ def step(self): timers('optimizer-count-zeros', log_level=1).start( barrier=self.config.barrier_with_L1_time ) - num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None + num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else 0 if timers is not None: timers('optimizer-count-zeros').stop() @@ -506,56 +525,60 @@ def __init__( # Handle main parameters. - # Three groups of parameters: - # float16_groups: original float16 parameters - # fp32_from_float16_groups: fp32 copy of float16 parameters - # fp32_from_fp32_groups: original fp32 parameters - self.float16_groups = [] - self.fp32_from_float16_groups = [] - self.fp32_from_fp32_groups = [] - - # For all the groups in the original optimizer: - for param_group in self.optimizer.param_groups: - float16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_float16_params_this_group = [] - # For all the parameters in this group: - for i, param in enumerate(param_group['params']): - if param.requires_grad: - - # float16 params: - if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: - float16_params_this_group.append(param) - # Create a copy - main_param = param.detach().clone().float() - # Copy tensor model parallel attributes. - tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param) - if hasattr(param, 'shared'): - main_param.shared = param.shared - # Replace the optimizer params with the new fp32 copy. - param_group['params'][i] = main_param - - fp32_from_float16_params_this_group.append(main_param) - # Reset existing state dict key to the new main param. - if param in self.optimizer.state: - self.optimizer.state[main_param] = self.optimizer.state.pop(param) - # fp32 params. - elif param.type() == 'torch.cuda.FloatTensor': - fp32_params_this_group.append(param) - param_group['params'][i] = param - - else: - raise TypeError( - 'Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. 
' - 'Received {}'.format(param.type()) - ) - - self.float16_groups.append(float16_params_this_group) - self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) + if optimizer: + # Three groups of parameters: + # float16_groups: original float16 parameters + # fp32_from_float16_groups: fp32 copy of float16 parameters + # fp32_from_fp32_groups: original fp32 parameters + self.float16_groups = [] + self.fp32_from_float16_groups = [] + self.fp32_from_fp32_groups = [] + + # For all the groups in the original optimizer: + for param_group in self.optimizer.param_groups: + float16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_float16_params_this_group = [] + # For all the parameters in this group: + for i, param in enumerate(param_group['params']): + if param.requires_grad: + + # float16 params: + if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + float16_params_this_group.append(param) + # Create a copy + main_param = param.detach().clone().float() + # Copy tensor model parallel attributes. + tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param) + if hasattr(param, 'shared'): + main_param.shared = param.shared + # Replace the optimizer params with the new fp32 copy. + param_group['params'][i] = main_param + + fp32_from_float16_params_this_group.append(main_param) + # Reset existing state dict key to the new main param. + if param in self.optimizer.state: + self.optimizer.state[main_param] = self.optimizer.state.pop(param) + # fp32 params. + elif param.type() == 'torch.cuda.FloatTensor': + fp32_params_this_group.append(param) + param_group['params'][i] = param + + else: + raise TypeError( + 'Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. 
' + 'Received {}'.format(param.type()) + ) + + self.float16_groups.append(float16_params_this_group) + self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + self.is_stub_optimizer = False + else: + self.is_stub_optimizer = True def zero_grad(self, set_to_none=True): """We only need to zero the model related parameters, i.e., @@ -563,6 +586,8 @@ def zero_grad(self, set_to_none=True): fp32_from_float16_groups as a memory optimization to reduce fragmentation; in the case of set_to_none==True, the space used by this field can be safely deallocated at this point.""" + if self.is_stub_optimizer: + return for group in self.float16_groups: _zero_grad_group_helper(group, set_to_none) for group in self.fp32_from_float16_groups: @@ -571,6 +596,8 @@ def zero_grad(self, set_to_none=True): _zero_grad_group_helper(group, set_to_none) def _collect_main_grad_data_for_unscaling(self): + if self.is_stub_optimizer: + return main_grads = [] @@ -739,9 +766,12 @@ def __init__( super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn) self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda') + self.is_stub_optimizer = True if optimizer is None else False def zero_grad(self, set_to_none=True): """Copied from torch.optim.optimizer""" + if self.is_stub_optimizer: + return for group in self.optimizer.param_groups: _zero_grad_group_helper(group['params'], set_to_none) @@ -752,6 +782,8 @@ def get_loss_scale(self): @torch.no_grad() def prepare_grads(self) -> bool: """Pre-processing gradients before the optimizer step, returns whether inf/nan is found.""" + if self.is_stub_optimizer: + return False timers = self.config.timers # Copy main_grads to grads. @@ -771,6 +803,8 @@ def prepare_grads(self) -> bool: @torch.no_grad() def step_with_ready_grads(self) -> bool: """Step the optimizer with ready gradients, return successful.""" + if self.is_stub_optimizer: + return True timers = self.config.timers # Update parameters. 
@@ -904,13 +938,19 @@ class ChainedOptimizer(MegatronOptimizer): def __init__(self, chained_optimizers: List[MegatronOptimizer]): self.model_chunks = [] - self.config = getattr(chained_optimizers[0], 'config', None) - for optimizer in chained_optimizers: - if hasattr(optimizer, 'model_chunks'): - for model_chunk in optimizer.model_chunks: - if model_chunk not in self.model_chunks: - self.model_chunks.append(model_chunk) - assert self.config == getattr(optimizer, 'config', None) + # chained_optimizers would be empty in the case that a rank + # has no trainable parameters + if chained_optimizers: + self.config = getattr(chained_optimizers[0], 'config', None) + for optimizer in chained_optimizers: + if hasattr(optimizer, 'model_chunks'): + for model_chunk in optimizer.model_chunks: + if model_chunk not in self.model_chunks: + self.model_chunks.append(model_chunk) + assert self.config == getattr(optimizer, 'config', None) + self.is_stub_optimizer = False + else: + self.is_stub_optimizer = True self.chained_optimizers = chained_optimizers @property @@ -934,7 +974,10 @@ def zero_grad(self, set_to_none=True): optimizer.zero_grad(set_to_none) def get_loss_scale(self): - return self.chained_optimizers[0].get_loss_scale() + if self.chained_optimizers: + return self.chained_optimizers[0].get_loss_scale() + else: + return torch.tensor([1.0], dtype=torch.float32, device=torch.cuda.current_device()) def reload_model_params(self): for optimizer in self.chained_optimizers: @@ -991,6 +1034,8 @@ def step_with_ready_grads(self) -> bool: @torch.no_grad() def step(self): """ChainedOptimizer will step all optimizers one by one.""" + if self.is_stub_optimizer: + return True, 0.0, 0 found_inf_flag = self.prepare_grads() if found_inf_flag: return False, None, None diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bfbe3d4283..ff430957d1 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -196,12 +196,10 @@ def validate_args(args, defaults={}): # Set args.use_dist_ckpt from args.ckpt_format. update_use_dist_ckpt(args) - if args.encoder_pipeline_model_parallel_size == 0 and args.num_experts == 0: assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder." if args.encoder_tensor_model_parallel_size > 0: - assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined." assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0 assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder." 
@@ -675,7 +673,7 @@ def validate_args(args, defaults={}): args.num_experts = None if args.num_experts is not None: assert args.spec is None, "Model Spec must be None when using MoEs" - + if args.moe_ffn_hidden_size is None: args.moe_ffn_hidden_size = args.ffn_hidden_size diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index e24bf7d2f4..b51a6c7c78 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -385,7 +385,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati optim_checkpoint_name = \ get_distributed_optimizer_checkpoint_name(checkpoint_name) ensure_directory_exists(optim_checkpoint_name) - optimizer.save_parameter_state(optim_checkpoint_name) + if not optimizer.is_stub_optimizer: + optimizer.save_parameter_state(optim_checkpoint_name) async_save_request = None if args.async_save: @@ -620,7 +621,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler, model[i].state_dict_for_save_checkpoint()) # Optimizer stuff. if not args.no_save_optim: - if optimizer is not None: + if optimizer is not None and not optimizer.is_stub_optimizer: state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {})) if use_dist_ckpt else optimizer.state_dict()) @@ -1161,7 +1162,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state, use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state ) - + # When "--fp8-param-gather" is disabled, this function doesn't modify anything. fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict']) @@ -1244,7 +1245,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri if not release and not args.finetune and not args.no_load_optim: try: # Load state dict. - if not skip_load_to_model_and_opt and optimizer is not None: + if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer: optimizer.load_state_dict(state_dict['optimizer']) # Load distributed optimizer's custom parameter state. diff --git a/megatron/training/training.py b/megatron/training/training.py index f640eec37c..4d00bd1c8a 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -69,14 +69,16 @@ from .async_utils import maybe_finalize_async_save from .utils import ( + append_to_progress_log, calc_params_l2_norm, check_adlr_autoresume_termination, + logical_and_across_model_parallel_group, + reduce_max_stat_across_model_parallel_group, is_last_rank, print_rank_0, print_rank_last, report_memory, unwrap_model, - append_to_progress_log, update_use_dist_ckpt, ) from .global_vars import ( @@ -86,7 +88,8 @@ get_timers, get_tensorboard_writer, get_wandb_writer, - get_one_logger) + get_one_logger, +) from . import one_logger_utils from . 
import ft_integration @@ -212,7 +215,7 @@ def _get_field(string, type): def preprocess_common_state_dict(common_state_dict): import copy - # Convert args key of type namespace to dictionary + # Convert args key of type namespace to dictionary preprocessed_common_state_dict = copy.deepcopy(common_state_dict) preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args']) # Remove rank and local rank from state dict if it exists, since they are expected to be different @@ -746,7 +749,7 @@ def setup_model_and_optimizer(model_provider_func, def train_step(forward_step_func, data_iterator, - model, optimizer, opt_param_scheduler, config): + model, optimizer, opt_param_scheduler, config): """Single training step.""" args = get_args() timers = get_timers() @@ -783,10 +786,20 @@ def train_step(forward_step_func, data_iterator, unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) # Update parameters. + timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) update_successful, grad_norm, num_zeros_in_grad = optimizer.step() timers('optimizer').stop() + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, + # so we must gather across mp ranks + update_successful = logical_and_across_model_parallel_group(update_successful) + # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, + # so we must gather across mp ranks + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + if args.log_num_zeros_in_grad: + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad) + # Vision momentum. if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino": unwrapped_model = unwrap_model(model[0]) @@ -825,7 +838,6 @@ def train_step(forward_step_func, data_iterator, numerator += val denominator += 1 loss_reduced[key] = numerator / denominator - return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad @@ -906,6 +918,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r total_iterations = total_loss_dict[advanced_iters_key] + \ total_loss_dict[skipped_iters_key] + # learning rate will be None on ranks without trainable params, so we must gather across mp ranks + learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate) # Tensorboard values. # Timer requires all the ranks to call. 
if args.log_timers_to_tensorboard and \ @@ -923,12 +937,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r wandb_writer.log({'samples vs steps': args.consumed_train_samples}, iteration) writer.add_scalar('learning-rate', learning_rate, iteration) - if args.decoupled_lr is not None: - writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) writer.add_scalar('learning-rate vs samples', learning_rate, - args.consumed_train_samples) + args.consumed_train_samples) if wandb_writer: wandb_writer.log({'learning-rate': learning_rate}, iteration) + if args.decoupled_lr is not None: + writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) if args.skipped_train_samples > 0: writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration) if wandb_writer: @@ -1028,7 +1042,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r writer.add_scalar('throughput', throughput, iteration) if wandb_writer: wandb_writer.log({'throughput': throughput}, iteration) - assert learning_rate is not None # Decoupled_learning_rate should be not None only on first and last pipeline stage. log_string += f' learning rate: {learning_rate:.6E} |' if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or @@ -1061,7 +1074,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 print_rank_last(log_string) - if report_memory_flag and learning_rate > 0.: + if report_memory_flag: # Report memory after optimizer state has been initialized. if torch.distributed.get_rank() == 0: num_microbatches = get_num_microbatches() @@ -1511,8 +1524,12 @@ def get_e2e_base_metrics(): num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch # Logging. - loss_scale = optimizer.get_loss_scale().item() + if not optimizer.is_stub_optimizer: + loss_scale = optimizer.get_loss_scale().item() + else: + loss_scale = 1.0 params_norm = None + if args.log_params_norm: params_norm = calc_params_l2_norm(model) learning_rate = None @@ -1715,7 +1732,9 @@ def evaluate(forward_step_func, timers('evaluate').stop() timers.log(['evaluate']) - + + rerun_state_machine.set_mode(rerun_mode) + rerun_state_machine.set_mode(rerun_mode) return total_loss_dict, collected_non_loss_data, False diff --git a/megatron/training/utils.py b/megatron/training/utils.py index b91c8e90cf..2f517d2be3 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -94,13 +94,16 @@ def calc_params_l2_norm(model): # Calculate dense param norm dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') - norm, _ = multi_tensor_applier( - multi_tensor_l2norm, - dummy_overflow_buf, - [params_data], - False # no per-parameter norm - ) - norm_2 = norm * norm + if len(params_data) > 0: + norm, _ = multi_tensor_applier( + multi_tensor_l2norm, + dummy_overflow_buf, + [params_data], + False # no per-parameter norm + ) + norm_2 = norm * norm + else: + norm_2 = torch.tensor([0.0], dtype=torch.float32, device='cuda') if data_parallel_group is not None: torch.distributed.all_reduce(norm_2, @@ -144,6 +147,41 @@ def average_losses_across_data_parallel_group(losses): return averaged_losses +def reduce_max_stat_across_model_parallel_group(stat: float) -> float: + """ + Ranks without an optimizer will have no grad_norm or num_zeros_in_grad stats. 
+ We need to ensure the logging and writer rank has those values. + This function reduces a stat tensor across the model parallel group. + + We use an all_reduce max since the values have already been summed across optimizer ranks where possible + """ + if stat is None: + stat = -1.0 + stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.all_reduce( + stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group() + ) + if stat.item() == -1.0: + return None + else: + return stat.item() + + +def logical_and_across_model_parallel_group(input: bool) -> bool: + """ + This function gathers a bool value across the model parallel group + """ + if input is True: + input = 1 + else: + input = 0 + input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device()) + torch.distributed.all_reduce( + input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group() + ) + return bool(input.item()) + + def report_memory(name): """Simple GPU memory report.""" mega_bytes = 1024.0 * 1024.0 @@ -402,11 +440,11 @@ def _broadcast(item): _broadcast(loss_mask) _broadcast(attention_mask) _broadcast(position_ids) - + elif mpu.is_pipeline_first_stage(): labels=None loss_mask=None - + _broadcast(tokens) _broadcast(attention_mask) _broadcast(position_ids) @@ -414,11 +452,11 @@ def _broadcast(item): elif mpu.is_pipeline_last_stage(): tokens=None position_ids=None - + _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) - + batch = { 'tokens': tokens, 'labels': labels, diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 605634060f..1870a77d61 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -83,12 +83,6 @@ def model_provider( assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently." - if args.pipeline_model_parallel_size > 1: - assert not args.freeze_LM, "Freezing a pipeline parallel language model is not currently supported" - - if args.encoder_pipeline_model_parallel_size == 1: - assert not args.freeze_ViT, "Freezing a vision encoder on its own pipeline rank is not currently supported" - num_image_embeddings = get_num_image_embeddings( args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, class_token_len=1, pixel_shuffle=False, use_tile_tags=False @@ -129,7 +123,7 @@ def model_provider( language_transformer_layer_spec = decoder_model_with_local_default_spec( args.num_experts, args.moe_grouped_gemm ) - + # Prepare mask type for any required padding to support CP/SP sequence sharding. 
if mp_padding_needed > 0: if language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.causal: @@ -351,10 +345,10 @@ def _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed): labels = data_i["labels"].long() loss_mask = data_f["loss_mask"].float() images = data_f["image"].float() - + if cp_size > 1 or args.sequence_parallel: vision_model_type = "clip" - # Calculate the number of image embedding tokens will be added to text tokens + # Calculate the number of image embedding tokens will be added to text tokens num_image_embeddings_per_tile = get_num_image_embeddings( args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, 1 ) @@ -367,7 +361,7 @@ def _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed): num_images_per_sample = torch.sum(image_token_mask, dim=-1) img_seq_len = (num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample).max() packed_seq_params = _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text) - + # slice batch along sequence dimension for context parallelism batch = get_batch_on_this_cp_rank({"tokens": tokens, "position_ids": position_ids}) attention_mask = None # Use the attention mask type defined in layer spec. Typically no mask for the vision model and causal mask for the vision model. diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_dev.json new file mode 100644 index 0000000000..a2ef225d83 --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_dev.json @@ -0,0 +1 @@ +{"forward-backward-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [7.96777, 0.62507, 0.62176, 0.62042, 0.62061, 0.62067, 0.62001, 0.61924, 0.61823, 0.6178]}, "forward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3.04896, 0.30356, 0.30062, 0.29886, 0.29955, 0.29936, 0.29825, 0.29839, 0.2968, 0.29625]}, "backward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.99454, 0.28657, 0.28691, 0.28667, 0.28654, 0.28672, 0.28654, 0.2861, 0.28657, 0.28683]}, "batch-generator-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.3938, 0.01749, 0.01695, 0.01841, 0.01751, 0.01736, 0.01792, 0.01739, 0.01667, 0.01628]}, "forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3.32161, 0.03012, 0.02986, 0.02994, 0.02968, 0.02964, 0.03016, 0.02977, 0.02991, 0.02985]}, "forward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.53192, 0.00018, 0.00018, 0.00018, 0.00019, 0.0002, 0.00019, 0.00019, 0.00019, 0.00018]}, "backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.07283, 0.07198, 0.07135, 0.07044, 0.07023, 0.07085, 0.07065, 0.07057, 0.0704, 0.07021]}, "backward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00023, 0.00029, 0.0002, 0.00027, 0.00027, 0.00032, 0.00032, 0.00028, 0.00027, 0.00021]}, "forward-send-backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3.46399, 0.30175, 0.30094, 0.29597, 0.29703, 0.29641, 0.2959, 0.29432, 0.29344, 0.29317]}, 
"backward-send-forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.82172, 0.00243, 0.00247, 0.00234, 0.00236, 0.00228, 0.0023, 0.00235, 0.00232, 0.00233]}, "layernorm-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [4e-05, 3e-05, 3e-05, 3e-05, 3e-05, 2e-05, 3e-05, 3e-05, 3e-05, 3e-05]}, "embedding-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [7e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05]}, "all-grads-sync-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.16382, 0.00025, 0.00025, 0.00025, 0.00024, 0.00024, 0.00024, 0.00024, 0.00023, 0.00026]}, "optimizer-copy-to-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [7e-05, 5e-05, 5e-05, 5e-05, 5e-05, 5e-05, 5e-05, 5e-05, 5e-05, 5e-05]}, "optimizer-clip-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.83319, 0.00053, 0.00052, 0.00044, 0.00052, 0.00043, 0.00043, 0.00043, 0.00043, 0.00043]}, "optimizer-count-zeros-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00895, 0.00069, 0.00069, 0.00068, 0.00069, 0.00069, 0.00068, 0.00068, 0.00068, 0.00069]}, "optimizer-inner-step-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00119, 0.00025, 0.00024, 0.00023, 0.00023, 0.00025, 0.00024, 0.00024, 0.00024, 0.00025]}, "optimizer-copy-main-to-model-params-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00014, 9e-05, 9e-05, 8e-05, 8e-05, 9e-05, 9e-05, 8e-05, 9e-05, 9e-05]}, "optimizer-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.84455, 0.00225, 0.00226, 0.00214, 0.00221, 0.00216, 0.00214, 0.00213, 0.00214, 0.00214]}, "learning-rate": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "learning-rate vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "batch-size": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "batch-size vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19947, 9.20335, 9.20248, 9.19723, 9.19172, 9.18973, 9.18517, 9.17532, 9.17374, 9.1609]}, "lm loss vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19947, 9.20335, 9.20248, 9.19723, 9.19172, 9.18973, 9.18517, 9.17532, 9.17374, 9.1609]}, "loss-scale": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "loss-scale vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "grad-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93277, 1.00171, 1.00056, 0.944, 1.16867, 0.98576, 0.91686, 0.9042, 0.83078, 0.88219]}, "grad-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93277, 1.00171, 1.00056, 0.944, 1.16867, 0.98576, 0.91686, 0.9042, 0.83078, 0.88219]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [106.0, 114.0, 108.0, 110.0, 81.0, 105.0, 85.0, 109.0, 146.0, 122.0]}, "num-zeros vs samples": {"start_step": 0, "end_step": 50, 
"step_interval": 5, "values": [106.0, 114.0, 108.0, 110.0, 81.0, 105.0, 85.0, 109.0, 146.0, 122.0]}, "params-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67851, 35.6785, 35.67848, 35.67848]}, "params-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67851, 35.6785, 35.67848, 35.67848]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [8.83079, 0.64044, 0.63692, 0.63516, 0.63554, 0.63541, 0.63471, 0.63399, 0.63285, 0.63245]}, "lm loss validation": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.1542]}, "lm loss validation vs samples": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9.1542]}, "lm loss validation ppl": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9454.09668]}, "lm loss validation ppl vs samples": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9454.09668]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_lts.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_lts.json new file mode 100644 index 0000000000..3c933e0123 --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_lts.json @@ -0,0 +1 @@ +{"forward-backward-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [12.7291, 0.62672, 0.60589, 0.60528, 0.60867, 0.60545, 0.60403, 0.61268, 0.61851, 0.60357]}, "forward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5.56178, 0.30066, 0.28459, 0.28176, 0.28541, 0.27947, 0.28138, 0.28895, 0.29453, 0.28039]}, "backward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.12115, 0.28858, 0.28597, 0.28809, 0.28772, 0.28811, 0.28721, 0.28849, 0.28849, 0.28829]}, "batch-generator-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [2.85702, 0.03903, 0.0338, 0.03035, 0.03224, 0.03016, 0.02978, 0.03435, 0.03368, 0.02954]}, "forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5.26228, 0.03127, 0.02963, 0.02987, 0.02952, 0.03226, 0.02962, 0.02934, 0.02956, 0.02928]}, "forward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [2.5072, 0.00017, 0.00015, 0.00018, 0.00016, 0.00015, 0.00015, 0.00015, 0.00017, 0.00015]}, "backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.07163, 0.07147, 0.0696, 0.06982, 0.07399, 0.0702, 0.06973, 0.07326, 0.07023, 0.06973]}, "backward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00026, 0.00021, 0.00019, 0.00019, 0.00019, 0.00018, 0.00019, 0.0002, 0.0002, 0.00019]}, "forward-send-backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [6.16563, 0.28249, 0.27763, 0.28103, 0.27952, 0.28051, 0.2813, 0.28172, 0.29124, 0.28177]}, "backward-send-forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.92523, 0.00228, 0.00214, 0.00215, 0.00226, 0.00213, 0.00217, 0.00235, 0.00224, 0.00219]}, "layernorm-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": 
[4e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05]}, "embedding-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [6e-05, 3e-05, 3e-05, 3e-05, 3e-05, 4e-05, 3e-05, 3e-05, 3e-05, 4e-05]}, "all-grads-sync-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.19033, 0.00022, 0.00021, 0.00022, 0.00022, 0.00023, 0.00022, 0.00022, 0.00022, 0.00022]}, "optimizer-copy-to-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [6e-05, 4e-05, 4e-05, 4e-05, 5e-05, 4e-05, 4e-05, 4e-05, 4e-05, 5e-05]}, "optimizer-clip-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [2.24661, 0.00048, 0.00047, 0.00038, 0.00047, 0.00039, 0.00039, 0.00039, 0.00039, 0.0004]}, "optimizer-count-zeros-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00926, 0.00069, 0.00062, 0.00063, 0.00063, 0.00063, 0.00062, 0.00063, 0.00062, 0.00062]}, "optimizer-inner-step-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00112, 0.0002, 0.0002, 0.00021, 0.00021, 0.00021, 0.00021, 0.00021, 0.00022, 0.00021]}, "optimizer-copy-main-to-model-params-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00014, 8e-05, 8e-05, 8e-05, 8e-05, 8e-05, 8e-05, 8e-05, 8e-05, 8e-05]}, "optimizer-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [2.25814, 0.0021, 0.00203, 0.00193, 0.00201, 0.00193, 0.00195, 0.00196, 0.00197, 0.00195]}, "learning-rate": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "learning-rate vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "batch-size": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "batch-size vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19948, 9.20339, 9.20246, 9.19721, 9.1917, 9.18976, 9.18512, 9.17531, 9.17379, 9.16091]}, "lm loss vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19948, 9.20339, 9.20246, 9.19721, 9.1917, 9.18976, 9.18512, 9.17531, 9.17379, 9.16091]}, "loss-scale": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "loss-scale vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "grad-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93282, 1.00192, 1.00046, 0.94405, 1.16906, 0.98576, 0.91648, 0.90421, 0.83062, 0.8822]}, "grad-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93282, 1.00192, 1.00046, 0.94405, 1.16906, 0.98576, 0.91648, 0.90421, 0.83062, 0.8822]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [103.0, 122.0, 112.0, 97.0, 93.0, 105.0, 109.0, 107.0, 125.0, 130.0]}, "num-zeros vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [103.0, 122.0, 112.0, 97.0, 93.0, 105.0, 109.0, 107.0, 125.0, 130.0]}, "params-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67849, 35.67848]}, "params-norm 
vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67849, 35.67848]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [15.00501, 0.64144, 0.62022, 0.6193, 0.62312, 0.61981, 0.61869, 0.62693, 0.63288, 0.61782]}, "lm loss validation": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.15419]}, "lm loss validation vs samples": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.15419]}, "lm loss validation ppl": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9453.99707]}, "lm loss validation ppl vs samples": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9453.99707]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..e2ef184e5e --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,57 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + GPUS_PER_NODE: 8 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 624 + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + --num-attention-heads: 12 + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 2 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_PATH} + --load: ${CHECKPOINT_PATH} + --split: 949,50,1 + --tokenizer-type: NullTokenizer + --vocab-size: 8192 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 10000 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 4 + --pipeline-model-parallel-size: 1 + --encoder-pipeline-model-parallel-size: 1 + --encoder-tensor-model-parallel-size: 4 + --deterministic-mode: true + --attention-softmax-in-fp32: true + --ckpt-format: torch + --no-gradient-accumulation-fusion: true + --bf16: true + --img-h: 336 + --img-w: 336 + --patch-dim: 14 + --mock-data: true + --freeze-ViT: true + --freeze-LM: true +TEST_TYPE: regular diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_dev.json new file mode 100644 index 0000000000..c4c1cffa46 --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_dev.json @@ -0,0 +1 @@ +{"forward-backward-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80164, 0.62602, 0.62115, 0.61347, 0.61356, 0.6148, 0.61452, 0.61389, 
0.61239, 0.61187]}, "forward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5.14549, 0.30295, 0.29758, 0.29055, 0.29096, 0.29124, 0.29129, 0.2913, 0.29037, 0.28939]}, "backward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.12619, 0.28782, 0.28877, 0.28732, 0.28777, 0.28808, 0.28786, 0.28769, 0.28753, 0.28791]}, "batch-generator-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.29859, 0.02375, 0.02123, 0.01897, 0.01822, 0.01828, 0.01866, 0.01876, 0.01889, 0.01783]}, "forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3.69025, 0.02974, 0.02963, 0.03036, 0.03015, 0.03018, 0.03047, 0.03047, 0.03, 0.03017]}, "forward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.06877, 0.00017, 0.00016, 0.00015, 0.00015, 0.00015, 0.00018, 0.00015, 0.00016, 0.00014]}, "backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.07001, 0.07185, 0.07034, 0.07062, 0.07068, 0.07076, 0.07093, 0.07034, 0.07033, 0.07056]}, "backward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00032, 0.00023, 0.00027, 0.00028, 0.00026, 0.0003, 0.00028, 0.00029, 0.00028, 0.00029]}, "forward-send-backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5.90985, 0.29772, 0.29629, 0.28867, 0.29204, 0.29221, 0.29134, 0.28969, 0.29014, 0.29351]}, "backward-send-forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.86713, 0.00263, 0.0025, 0.00238, 0.00246, 0.00238, 0.00237, 0.00259, 0.00243, 0.00254]}, "layernorm-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3e-05, 2e-05, 2e-05, 2e-05, 2e-05, 2e-05, 3e-05, 3e-05, 2e-05, 2e-05]}, "embedding-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05]}, "all-grads-sync-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.20519, 0.00031, 0.00025, 0.00025, 0.00026, 0.00025, 0.00025, 0.00025, 0.00025, 0.00025]}, "params-all-gather-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00016, 0.00013, 0.00012, 0.00011, 0.00011, 0.00011, 0.00011, 0.00011, 0.00011, 0.00011]}, "optimizer-copy-to-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00015, 0.00013, 0.00011, 0.00011, 0.00011, 0.00011, 0.0001, 0.0001, 0.0001, 0.0001]}, "optimizer-clip-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.01362, 0.00058, 0.00048, 0.00041, 0.00047, 0.0004, 0.0004, 0.00039, 0.0004, 0.0004]}, "optimizer-count-zeros-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00823, 0.00068, 0.00072, 0.00073, 0.00068, 0.00069, 0.00069, 0.0007, 0.00069, 0.00066]}, "optimizer-inner-step-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00098, 0.00026, 0.00023, 0.00023, 0.00025, 0.00023, 0.00023, 0.00024, 0.00024, 0.00023]}, "optimizer-copy-main-to-model-params-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00019, 0.00018, 0.00015, 0.00016, 0.00015, 0.00016, 0.00016, 0.00015, 0.00015, 0.00015]}, "optimizer-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.02427, 0.00277, 0.00256, 0.00257, 0.00249, 0.00243, 0.00242, 0.00241, 0.00241, 0.00237]}, "learning-rate": {"start_step": 0, "end_step": 50, "step_interval": 
5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "learning-rate vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "batch-size": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "batch-size vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19947, 9.20335, 9.20248, 9.19723, 9.19172, 9.18973, 9.18517, 9.17532, 9.17374, 9.1609]}, "lm loss vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19947, 9.20335, 9.20248, 9.19723, 9.19172, 9.18973, 9.18517, 9.17532, 9.17374, 9.1609]}, "loss-scale": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "loss-scale vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "grad-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93277, 1.00171, 1.00056, 0.944, 1.16867, 0.98576, 0.91686, 0.9042, 0.83078, 0.88219]}, "grad-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93277, 1.00171, 1.00056, 0.944, 1.16867, 0.98576, 0.91686, 0.9042, 0.83078, 0.88219]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [106.0, 114.0, 108.0, 110.0, 81.0, 105.0, 85.0, 109.0, 146.0, 122.0]}, "num-zeros vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [106.0, 114.0, 108.0, 110.0, 81.0, 105.0, 85.0, 109.0, 146.0, 122.0]}, "params-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67851, 35.6785, 35.67848, 35.67848]}, "params-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67851, 35.6785, 35.67848, 35.67848]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [11.71205, 0.64203, 0.63681, 0.62887, 0.62867, 0.62983, 0.6294, 0.62857, 0.62698, 0.62637]}, "lm loss validation": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.1542]}, "lm loss validation vs samples": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9.1542]}, "lm loss validation ppl": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9454.09668]}, "lm loss validation ppl vs samples": {"start_step": 0, "end_step": 1, "step_interval": 5, "values": [9454.09668]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_lts.json b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_lts.json new file mode 100644 index 0000000000..bfdacf168e --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_lts.json @@ -0,0 +1 @@ +{"forward-backward-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.12533, 0.61523, 0.612, 0.61274, 0.60959, 0.61563, 0.61043, 0.62211, 0.61259, 0.61475]}, "forward-compute-time": {"start_step": 
0, "end_step": 50, "step_interval": 5, "values": [3.2886, 0.29298, 0.28952, 0.29035, 0.28755, 0.29301, 0.28608, 0.30023, 0.28978, 0.29236]}, "backward-compute-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.10925, 0.28738, 0.28707, 0.28715, 0.28829, 0.28813, 0.29022, 0.28846, 0.29053, 0.29005]}, "batch-generator-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.76471, 0.01852, 0.01694, 0.02369, 0.02029, 0.01651, 0.01633, 0.02469, 0.01956, 0.01684]}, "forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3.67666, 0.02972, 0.02965, 0.02942, 0.02811, 0.0288, 0.0288, 0.02849, 0.02832, 0.02838]}, "forward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.9526, 0.00016, 0.00016, 0.00016, 0.00016, 0.00018, 0.00017, 0.00017, 0.00014, 0.00015]}, "backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.07105, 0.07081, 0.07084, 0.07037, 0.06972, 0.07299, 0.06941, 0.06963, 0.07091, 0.07042]}, "backward-send-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00019, 0.0002, 0.00021, 0.00019, 0.0002, 0.00019, 0.00019, 0.00018, 0.00018, 0.00018]}, "forward-send-backward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [4.17022, 0.29888, 0.30073, 0.30472, 0.30255, 0.30377, 0.30116, 0.3082, 0.3045, 0.30713]}, "backward-send-forward-recv-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.89549, 0.00229, 0.00225, 0.00218, 0.00224, 0.00218, 0.00214, 0.00228, 0.00208, 0.00209]}, "layernorm-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3e-05, 3e-05, 4e-05, 2e-05, 3e-05, 2e-05, 2e-05, 3e-05, 2e-05, 2e-05]}, "embedding-grads-all-reduce-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [5e-05, 3e-05, 5e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05, 3e-05]}, "all-grads-sync-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.19492, 0.00027, 0.00039, 0.00025, 0.00027, 0.00025, 0.00024, 0.00025, 0.00022, 0.00022]}, "params-all-gather-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00015, 0.0001, 0.00011, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 9e-05, 9e-05]}, "optimizer-copy-to-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00013, 0.00011, 0.00011, 0.0001, 0.0001, 0.0001, 0.0001, 0.00011, 9e-05, 9e-05]}, "optimizer-clip-main-grad-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.02498, 0.00052, 0.00052, 0.00039, 0.00051, 0.00039, 0.00041, 0.00041, 0.00037, 0.00036]}, "optimizer-count-zeros-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00735, 0.00064, 0.00064, 0.00064, 0.00063, 0.00065, 0.00068, 0.00065, 0.00065, 0.00065]}, "optimizer-inner-step-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00093, 0.00021, 0.00021, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.00018, 0.00018]}, "optimizer-copy-main-to-model-params-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.00018, 0.00015, 0.00015, 0.00015, 0.00014, 0.00014, 0.00014, 0.00014, 0.00014, 0.00014]}, "optimizer-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.03475, 0.00249, 0.00249, 0.0023, 0.00258, 0.0023, 0.00234, 0.00235, 0.00223, 0.00223]}, "learning-rate": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, 
"learning-rate vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}, "batch-size": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "batch-size vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0, 32.0]}, "lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19948, 9.20339, 9.20246, 9.19721, 9.1917, 9.18976, 9.18515, 9.17526, 9.1738, 9.16094]}, "lm loss vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19948, 9.20339, 9.20246, 9.19721, 9.1917, 9.18976, 9.18515, 9.17526, 9.1738, 9.16094]}, "loss-scale": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "loss-scale vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}, "grad-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93282, 1.00192, 1.00046, 0.94405, 1.16906, 0.98576, 0.91623, 0.90401, 0.83116, 0.88246]}, "grad-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [0.93282, 1.00192, 1.00046, 0.94405, 1.16906, 0.98576, 0.91623, 0.90401, 0.83116, 0.88246]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [103.0, 122.0, 112.0, 97.0, 93.0, 105.0, 105.0, 101.0, 126.0, 120.0]}, "num-zeros vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [103.0, 122.0, 112.0, 97.0, 93.0, 105.0, 105.0, 101.0, 126.0, 120.0]}, "params-norm": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67849, 35.67848]}, "params-norm vs samples": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.67851, 35.6785, 35.67849, 35.67848]}, "iteration-time": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [11.25871, 0.63103, 0.62702, 0.628, 0.62436, 0.6304, 0.62504, 0.63626, 0.62666, 0.62873]}, "lm loss validation": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.1542]}, "lm loss validation vs samples": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9.1542]}, "lm loss validation ppl": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9454.09668]}, "lm loss validation ppl vs samples": {"start_step": 0, "end_step": 2, "step_interval": 5, "values": [9454.09668]}} \ No newline at end of file diff --git a/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/model_config.yaml new file mode 100644 index 0000000000..9a40c4406e --- /dev/null +++ b/tests/functional_tests/test_cases/multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/model_config.yaml @@ -0,0 +1,58 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Tree + CUBLAS_WORKSPACE_CONFIG: :4096:8 + GPUS_PER_NODE: 8 +MODEL_ARGS: + --num-layers: 12 + --hidden-size: 624 + --attention-dropout: 0.0 + --hidden-dropout: 0.0 + --num-attention-heads: 12 + 
+  --log-params-norm: true
+  --log-num-zeros-in-grad: true
+  --log-validation-ppl-to-tensorboard: true
+  --log-timers-to-tensorboard: true
+  --tensorboard-dir: ${TENSORBOARD_PATH}
+  --micro-batch-size: 4
+  --global-batch-size: 32
+  --seq-length: 1024
+  --max-position-embeddings: 1024
+  --train-iters: 50
+  --timing-log-level: 2
+  --lr-decay-iters: 320000
+  --save: ${CHECKPOINT_PATH}
+  --load: ${CHECKPOINT_PATH}
+  --split: 949,50,1
+  --tokenizer-type: NullTokenizer
+  --vocab-size: 8192
+  --distributed-backend: nccl
+  --lr: 0.00015
+  --lr-decay-style: cosine
+  --min-lr: 1.0e-5
+  --weight-decay: 1e-2
+  --clip-grad: 1.0
+  --lr-warmup-fraction: .01
+  --log-interval: 1
+  --save-interval: 10000
+  --eval-interval: 1000
+  --eval-iters: 10
+  --transformer-impl: transformer_engine
+  --tensor-model-parallel-size: 4
+  --pipeline-model-parallel-size: 1
+  --encoder-pipeline-model-parallel-size: 1
+  --encoder-tensor-model-parallel-size: 4
+  --deterministic-mode: true
+  --attention-softmax-in-fp32: true
+  --ckpt-format: torch
+  --no-gradient-accumulation-fusion: true
+  --bf16: true
+  --img-h: 336
+  --img-w: 336
+  --patch-dim: 14
+  --mock-data: true
+  --freeze-ViT: true
+  --freeze-LM: true
+  --use-distributed-optimizer: true
+TEST_TYPE: regular
diff --git a/tests/test_utils/recipes/multimodal-llava.yaml b/tests/test_utils/recipes/multimodal-llava.yaml
index 3989ebeefa..0d43c64bad 100644
--- a/tests/test_utils/recipes/multimodal-llava.yaml
+++ b/tests/test_utils/recipes/multimodal-llava.yaml
@@ -40,6 +40,8 @@ products:
     test_case:
       - multimodal_llava_mr_mcore_te_tp1_pp1_dgx_a100_1N8G
       - multimodal_llava_mr_mcore_te_tp2_pp3_dgx_a100_1N8G
+      - multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G
+      - multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G
   - environment: [lts, dev]
     scope: [mr]
     n_repeat: [5]
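
For readers unfamiliar with the `--freeze-ViT` / `--freeze-LM` flags used in the new test configuration above, the following is a minimal, self-contained sketch of the general PyTorch pattern they refer to. It is illustrative only: `ToyVLM` and `freeze` are hypothetical names invented for this example and are not part of Megatron-LM.

```python
# Illustrative sketch only -- not Megatron-LM's implementation. Freezing a
# sub-module means disabling gradients on its parameters, so a rank may end
# up with no trainable parameters at all.
import torch
import torch.nn as nn


class ToyVLM(nn.Module):
    """Hypothetical stand-in for a vision-language model."""

    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Linear(16, 8)  # stands in for the ViT
        self.language_model = nn.Linear(8, 8)   # stands in for the LM
        self.projector = nn.Linear(8, 8)        # stays trainable


def freeze(module: nn.Module) -> None:
    """Disable gradients for every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad_(False)


model = ToyVLM()
freeze(model.vision_encoder)  # analogous to --freeze-ViT: true
freeze(model.language_model)  # analogous to --freeze-LM: true

# Only parameters that still require gradients are handed to the optimizer;
# if the list is empty on a rank, the training loop must handle that case.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1.5e-4) if trainable else None
print(f"trainable tensors: {len(trainable)}")
```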