@@ -1,3 +1,4 @@
+import logging
 from typing import List, Optional
 
 from beaker import Beaker
@@ -11,14 +12,25 @@
 )
 from olmo_core.utils import generate_uuid
 
+log = logging.getLogger(__name__)
+_BEAKER_CLIENT: Optional[Beaker] = None
 _BEAKER_USERNAME: Optional[str] = None
 
 
+def get_beaker_client() -> Beaker:
+    global _BEAKER_CLIENT
+
+    if _BEAKER_CLIENT is None:
+        _BEAKER_CLIENT = Beaker.from_env()
+
+    return _BEAKER_CLIENT
+
+
 def get_beaker_username() -> str:
     global _BEAKER_USERNAME
 
     if _BEAKER_USERNAME is None:
-        _BEAKER_USERNAME = Beaker.from_env().account.whoami().name
+        _BEAKER_USERNAME = get_beaker_client().account.whoami().name
 
     return _BEAKER_USERNAME
 
@@ -93,3 +105,21 @@ def build_launch_config(
             "printenv AWS_CREDENTIALS > ~/.aws/credentials",
         ],
     )
+
+
+CLUSTER_TO_GPU_TYPE = {
+    "ai2/jupiter-cirrascale-2": "NVIDIA H100 80GB HBM3",
+    "ai2/pluto-cirrascale": "NVIDIA H100",
+    "ai2/augusta-google-1": "NVIDIA H100",
+}
+
+
+def get_gpu_type(cluster: str) -> str:
+    if cluster in CLUSTER_TO_GPU_TYPE:
+        return CLUSTER_TO_GPU_TYPE[cluster]
+    else:
+        log.warning(f"Missing cluster '{cluster}' in CLUSTER_TO_GPU_TYPE mapping")
+        beaker = get_beaker_client()
+        nodes = beaker.cluster.nodes(cluster)
+        assert nodes and nodes[0].limits.gpu_type
+        return nodes[0].limits.gpu_type
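Below is a minimal usage sketch (not part of the diff) showing how the new memoized Beaker client and the GPU-type lookup might be called from a launch script. The `olmo_core.internal.common` import path and the example cluster name are assumptions for illustration; only the helper functions themselves come from this change.

```python
# Hypothetical usage of the helpers added in this diff. The import path is an
# assumption; adjust it to wherever these functions actually live.
from olmo_core.internal.common import get_beaker_client, get_gpu_type


def main() -> None:
    cluster = "ai2/jupiter-cirrascale-2"  # example cluster name

    # Known clusters resolve from the hard-coded CLUSTER_TO_GPU_TYPE mapping;
    # unknown clusters fall back to querying Beaker for the nodes' GPU limits.
    gpu_type = get_gpu_type(cluster)
    print(f"{cluster} -> {gpu_type}")

    # The module-level client is created once from the environment and reused,
    # so repeated calls don't re-read credentials or re-authenticate.
    beaker = get_beaker_client()
    print(f"Authenticated as {beaker.account.whoami().name}")


if __name__ == "__main__":
    main()
```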