Skip to content

Commit

Permalink
programmatically get GPU type
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Nov 26, 2024
1 parent 4ef79a7 commit a0cf257
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
32 changes: 31 additions & 1 deletion src/olmo_core/internal/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List, Optional

from beaker import Beaker
Expand All @@ -11,14 +12,25 @@
)
from olmo_core.utils import generate_uuid

log = logging.getLogger(__name__)
# Process-wide singletons, lazily initialized by get_beaker_client() /
# get_beaker_username() so importing this module never triggers Beaker I/O.
_BEAKER_CLIENT: Optional[Beaker] = None
_BEAKER_USERNAME: Optional[str] = None


def get_beaker_client() -> Beaker:
    """Return the shared Beaker client, creating it from the environment on first use."""
    global _BEAKER_CLIENT

    client = _BEAKER_CLIENT
    if client is None:
        # First call: build the client once and memoize it for the process.
        client = _BEAKER_CLIENT = Beaker.from_env()
    return client


def get_beaker_username() -> str:
    """Return the current Beaker account name, cached after the first lookup.

    Uses the shared client from :func:`get_beaker_client` so only one Beaker
    client is ever constructed per process.
    """
    global _BEAKER_USERNAME

    if _BEAKER_USERNAME is None:
        # NOTE: the redundant `Beaker.from_env()` assignment (leftover diff
        # residue) is removed — it created a throwaway second client.
        _BEAKER_USERNAME = get_beaker_client().account.whoami().name

    return _BEAKER_USERNAME

Expand Down Expand Up @@ -93,3 +105,21 @@ def build_launch_config(
"printenv AWS_CREDENTIALS > ~/.aws/credentials",
],
)


# Known cluster -> GPU type mapping; lets callers resolve the GPU type
# without a Beaker API round-trip for the common clusters.
CLUSTER_TO_GPU_TYPE = {
    "ai2/jupiter-cirrascale-2": "NVIDIA H100 80GB HBM3",
    "ai2/pluto-cirrascale": "NVIDIA H100",
    "ai2/augusta-google-1": "NVIDIA H100",
}


def get_gpu_type(cluster: str) -> str:
    """Return the GPU type used by the given Beaker cluster.

    Known clusters are resolved from the static ``CLUSTER_TO_GPU_TYPE``
    mapping; anything else falls back to querying the Beaker API for the
    cluster's node limits.

    :param cluster: The Beaker cluster name, e.g. ``"ai2/jupiter-cirrascale-2"``.
    :raises RuntimeError: If the cluster reports no nodes or no GPU type.
    """
    if cluster in CLUSTER_TO_GPU_TYPE:
        return CLUSTER_TO_GPU_TYPE[cluster]

    # Lazy %-style args so the message is only formatted if the record is emitted.
    log.warning("Missing cluster '%s' in CLUSTER_TO_GPU_TYPE mapping", cluster)
    beaker = get_beaker_client()
    nodes = beaker.cluster.nodes(cluster)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not nodes or not nodes[0].limits.gpu_type:
        raise RuntimeError(f"Unable to determine GPU type for cluster '{cluster}'")
    return nodes[0].limits.gpu_type
4 changes: 2 additions & 2 deletions src/olmo_core/internal/model_ladder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from olmo_core.train.callbacks import CometCallback, ConfigSaverCallback, WandBCallback
from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all

from .common import build_launch_config, get_root_dir
from .common import build_launch_config, get_gpu_type, get_root_dir


@dataclass
Expand Down Expand Up @@ -102,7 +102,7 @@ def build_config(
).merge(overrides, strict=False)

dp_world_size = launch.num_nodes * launch.num_gpus
gpu_type = "h100" # TODO: get actual device name
gpu_type = get_gpu_type(cluster)

model = ladder.get_model_config(size=size)
optim = ladder.get_optim_config(size=size)
Expand Down

0 comments on commit a0cf257

Please sign in to comment.