From a0cf257a24a68f80772765db2f35896e36dc0cdd Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 25 Nov 2024 16:12:35 -0800 Subject: [PATCH] programmatically get GPU type --- src/olmo_core/internal/common.py | 32 +++++++++++++++++++++++++- src/olmo_core/internal/model_ladder.py | 4 ++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py index 9373dca2..9399e3cf 100644 --- a/src/olmo_core/internal/common.py +++ b/src/olmo_core/internal/common.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional from beaker import Beaker @@ -11,14 +12,25 @@ ) from olmo_core.utils import generate_uuid +log = logging.getLogger(__name__) +_BEAKER_CLIENT: Optional[Beaker] = None _BEAKER_USERNAME: Optional[str] = None +def get_beaker_client() -> Beaker: + global _BEAKER_CLIENT + + if _BEAKER_CLIENT is None: + _BEAKER_CLIENT = Beaker.from_env() + + return _BEAKER_CLIENT + + def get_beaker_username() -> str: global _BEAKER_USERNAME if _BEAKER_USERNAME is None: - _BEAKER_USERNAME = Beaker.from_env().account.whoami().name + _BEAKER_USERNAME = get_beaker_client().account.whoami().name return _BEAKER_USERNAME @@ -93,3 +105,21 @@ def build_launch_config( "printenv AWS_CREDENTIALS > ~/.aws/credentials", ], ) + + +CLUSTER_TO_GPU_TYPE = { + "ai2/jupiter-cirrascale-2": "NVIDIA H100 80GB HBM3", + "ai2/pluto-cirrascale": "NVIDIA H100", + "ai2/augusta-google-1": "NVIDIA H100", +} + + +def get_gpu_type(cluster: str) -> str: + if cluster in CLUSTER_TO_GPU_TYPE: + return CLUSTER_TO_GPU_TYPE[cluster] + else: + log.warning(f"Missing cluster '{cluster}' in CLUSTER_TO_GPU_TYPE mapping") + beaker = get_beaker_client() + nodes = beaker.cluster.nodes(cluster) + assert nodes and nodes[0].limits.gpu_type + return nodes[0].limits.gpu_type diff --git a/src/olmo_core/internal/model_ladder.py b/src/olmo_core/internal/model_ladder.py index 9db92e0f..1cddc6ca 100644 --- a/src/olmo_core/internal/model_ladder.py +++ b/src/olmo_core/internal/model_ladder.py @@ -18,7 +18,7 @@ from olmo_core.train.callbacks import CometCallback, ConfigSaverCallback, WandBCallback from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all -from .common import build_launch_config, get_root_dir +from .common import build_launch_config, get_gpu_type, get_root_dir @dataclass @@ -102,7 +102,7 @@ def build_config( ).merge(overrides, strict=False) dp_world_size = launch.num_nodes * launch.num_gpus - gpu_type = "h100" # TODO: get actual device name + gpu_type = get_gpu_type(cluster) model = ladder.get_model_config(size=size) optim = ladder.get_optim_config(size=size)