diff --git a/src/olmo_core/internal/model_ladder.py b/src/olmo_core/internal/model_ladder.py
index 5557e737..d88a76fb 100644
--- a/src/olmo_core/internal/model_ladder.py
+++ b/src/olmo_core/internal/model_ladder.py
@@ -1,3 +1,4 @@
+import logging
 import sys
 from dataclasses import dataclass
 from typing import Callable, List, cast
@@ -21,6 +22,8 @@
 
 from .common import build_launch_config, get_gpu_type, get_root_dir
 
+log = logging.getLogger(__name__)
+
 
 @dataclass
 class LadderRunConfig(Config):
@@ -116,6 +119,20 @@ def build_config(
     data_loader = ladder.get_data_loader_config(size=size)
     trainer = ladder.get_trainer_config(size=size, gpu_type=gpu_type)
 
+    # Make sure rank micro-batch size makes sense: one micro-batch per rank across
+    # ``dp_world_size`` ranks must not exceed the global batch (both counted in instances).
+    rank_mbz_instances = trainer.rank_microbatch_size // ladder.sequence_length
+    global_bz_instances = data_loader.global_batch_size // ladder.sequence_length
+    if rank_mbz_instances * dp_world_size > global_bz_instances:
+        new_rank_mbz_instances = global_bz_instances // dp_world_size
+        new_rank_mbz = new_rank_mbz_instances * ladder.sequence_length
+        log.warning(
+            f"Adjusting rank micro-batch size from {trainer.rank_microbatch_size:,d} tokens ({rank_mbz_instances:,d} instances) "
+            f"down to {new_rank_mbz:,d} tokens ({new_rank_mbz_instances:,d} instances)"
+        )
+        # Apply the adjustment announced above; otherwise the warning is logged but the config keeps the old value.
+        trainer.rank_microbatch_size = new_rank_mbz
+
     return LadderRunConfig(
         launch=launch,
         ladder=ladder,