Skip to content

Commit a0cf257

Browse files
committed
programmatically get GPU type
1 parent 4ef79a7 commit a0cf257

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

src/olmo_core/internal/common.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import List, Optional
23

34
from beaker import Beaker
@@ -11,14 +12,25 @@
1112
)
1213
from olmo_core.utils import generate_uuid
1314

15+
log = logging.getLogger(__name__)
16+
_BEAKER_CLIENT: Optional[Beaker] = None
1417
_BEAKER_USERNAME: Optional[str] = None
1518

1619

20+
def get_beaker_client() -> Beaker:
21+
global _BEAKER_CLIENT
22+
23+
if _BEAKER_CLIENT is None:
24+
_BEAKER_CLIENT = Beaker.from_env()
25+
26+
return _BEAKER_CLIENT
27+
28+
1729
def get_beaker_username() -> str:
1830
global _BEAKER_USERNAME
1931

2032
if _BEAKER_USERNAME is None:
21-
_BEAKER_USERNAME = Beaker.from_env().account.whoami().name
33+
_BEAKER_USERNAME = get_beaker_client().account.whoami().name
2234

2335
return _BEAKER_USERNAME
2436

@@ -93,3 +105,21 @@ def build_launch_config(
93105
"printenv AWS_CREDENTIALS > ~/.aws/credentials",
94106
],
95107
)
108+
109+
110+
CLUSTER_TO_GPU_TYPE = {
111+
"ai2/jupiter-cirrascale-2": "NVIDIA H100 80GB HBM3",
112+
"ai2/pluto-cirrascale": "NVIDIA H100",
113+
"ai2/augusta-google-1": "NVIDIA H100",
114+
}
115+
116+
117+
def get_gpu_type(cluster: str) -> str:
118+
if cluster in CLUSTER_TO_GPU_TYPE:
119+
return CLUSTER_TO_GPU_TYPE[cluster]
120+
else:
121+
log.warning(f"Missing cluster '{cluster}' in CLUSTER_TO_GPU_TYPE mapping")
122+
beaker = get_beaker_client()
123+
nodes = beaker.cluster.nodes(cluster)
124+
assert nodes and nodes[0].limits.gpu_type
125+
return nodes[0].limits.gpu_type

src/olmo_core/internal/model_ladder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from olmo_core.train.callbacks import CometCallback, ConfigSaverCallback, WandBCallback
1919
from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all
2020

21-
from .common import build_launch_config, get_root_dir
21+
from .common import build_launch_config, get_gpu_type, get_root_dir
2222

2323

2424
@dataclass
@@ -102,7 +102,7 @@ def build_config(
102102
).merge(overrides, strict=False)
103103

104104
dp_world_size = launch.num_nodes * launch.num_gpus
105-
gpu_type = "h100" # TODO: get actual device name
105+
gpu_type = get_gpu_type(cluster)
106106

107107
model = ladder.get_model_config(size=size)
108108
optim = ladder.get_optim_config(size=size)

0 commit comments

Comments
 (0)