Skip to content

Commit ddfe758

Browse files
committed
Merge
Remote Python Support for Benchmark Runner Add Remote Python Recipe Clean up Add Delete Capabilities Minor edits Fix Fix problems Changes Changes Fixes Undo Weird Merge Save work Break out args helper Fix file imports Move around stuff cleanup Fix Clean up comment whitespace Fix issues Fix lint
1 parent 1c3d418 commit ddfe758

File tree

6 files changed

+293
-0
lines changed

6 files changed

+293
-0
lines changed

benchmarks/benchmark_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def add_xpk_runner_arguments(custom_parser: argparse.ArgumentParser):
126126
default='us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/maxtext_jax_stable:latest',
127127
help='version of pathways runner image to be benchmarked command.',
128128
)
129+
custom_parser.add_argument(
130+
'--remote_python_sidecar_image',
131+
type=str,
132+
help='version of remote python sidecar image to be benchmarked command.',
133+
)
129134
custom_parser.add_argument(
130135
'--use_pathways',
131136
type=bool,
@@ -246,6 +251,7 @@ def main() -> None:
246251
server_image=options.pathways_server_image,
247252
proxy_image=options.pathways_proxy_image,
248253
runner_image=options.pathways_runner_image,
254+
remote_python_sidecar_image=options.remote_python_sidecar_image,
249255
)
250256

251257
workload_config = WorkloadConfig(

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,38 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
6060
)
6161

6262

63+
# Basic "default" model configuration used for the Pathways long-running test:
# synthetic data, checkpointing every 100 steps, single-controller mode.
default_basic_1_pw = _add_to_model_dictionary(
  trillium_model_dict,
  MaxTextModel(
    model_name="default-basic-1-pw",
    model_type="default",
    tuning_params={
        "per_device_batch_size": 1,
        "remat_policy": "full",
        "global_parameter_scale": 1,
        "attention": "flash",
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        # Listed exactly once: a dict literal keeps only the last value of a
        # repeated key, so a second entry would silently override this one.
        # The Pathways long-running test needs checkpointing enabled.
        "enable_checkpointing": True,
        # "profiler": "xplane",

        # Additional tuning params for pathways long running test.
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        # "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags="",
  )
)
94+
6395
default_32 = _add_to_model_dictionary(
6496
trillium_model_dict,
6597
MaxTextModel(
@@ -274,6 +306,48 @@ def _add_to_model_dictionary(model_dictionary: dict[str, MaxTextModel], maxtext_
274306
)
275307
)
276308

309+
# Llama2-7B (4096 target length) configuration used for the Pathways
# long-running test: synthetic data, checkpointing every 100 steps,
# single-controller mode.
llama2_7b_4096_pw = _add_to_model_dictionary(
  trillium_model_dict,
  MaxTextModel(
    model_name="llama2-7b-4096-pw",
    model_type="llama2-7b",
    tuning_params={
        "per_device_batch_size": 4,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "full",
        "max_target_length": 4096,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        # Listed exactly once: a dict literal keeps only the last value of a
        # repeated key, so a second entry would silently override this one.
        # The Pathways long-running test needs checkpointing enabled.
        "enable_checkpointing": True,
        "profiler": "xplane",
        "sa_block_q": 1024,
        "sa_block_q_dkv": 2048,
        "sa_block_q_dq": 2048,
        "steps": 1000000,

        # Additional tuning params for pathways long running test.
        "async_checkpointing": True,
        "checkpoint_period": 100,
        "checkpoint_storage_use_ocdbt": False,
        "checkpoint_storage_use_zarr3": False,
        "metrics_file": "metrics.txt",
        "goodput_upload_interval_seconds": 30,
        # "enable_pathways_goodput": True,
        "enable_checkpoint_cloud_logger": True,
        "enable_single_controller": True,
    },
    xla_flags=(
        xla_flags_library.DENSE_VMEM_LIMIT_FLAG
        + xla_flags_library.CF_FOR_ALL_GATHER
    ),
  )
)
350+
277351

278352
llama2_70b_4096 = _add_to_model_dictionary(
279353
trillium_model_dict,

benchmarks/maxtext_xpk_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class PathwaysConfig:
5353
server_image: str
5454
proxy_image: str
5555
runner_image: str
56+
remote_python_sidecar_image: str
5657

5758

5859
# TODO(@vbarr): Split out parameters related to XPK workload and a General workload
@@ -375,6 +376,8 @@ def generate_xpk_workload_cmd(
375376
'--use-pathways'
376377
f' --server-image={pw_config.server_image}'
377378
f' --proxy-server-image={pw_config.proxy_image}'
379+
f' --remote-python-sidecar-image={pw_config.remote_python_sidecar_image}'
380+
if pw_config.remote_python_sidecar_image is not None else ''
378381
' --termination-grace-period-seconds=300'
379382
f' --pathways-gcs-location={wl_config.base_output_directory}'
380383
f' --restart-on-user-code-failure'

benchmarks/recipes/__init__.py

Whitespace-only changes.

benchmarks/recipes/args_helper.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import argparse
17+
import os
18+
import sys
19+
20+
# Needed to import files from the parent directory
21+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
22+
sys.path.append(parent_dir)
23+
24+
import maxtext_xpk_runner as mxr
25+
26+
# Constants for defining supported actions
27+
DELETE = "delete"
28+
29+
30+
def _handle_delete(
    cluster_config: mxr.XpkClusterConfig, user: str, **kwargs
) -> int:
  """Deletes the workloads whose job names match the user's prefix.

  Builds an `xpk.py workload delete` command filtered by the first five
  characters of `user`, prints it, and runs it through the shell.

  Args:
    cluster_config: Supplies the project, cluster name and zone passed to xpk.
    user: User name; only its first five characters are used as the job
      filter.
    **kwargs: Optional keyword arguments. `xpk_path` is the directory that
      contains `xpk.py` and defaults to "xpk".

  Returns:
    The status reported by `os.system` for the delete command (0 when the
    command succeeded).

  Raises:
    ValueError: If `user` is empty.
  """
  if not user:
    # An empty prefix would leave `--filter-by-job=` blank; refuse rather than
    # risk the filter matching workloads that do not belong to the caller.
    raise ValueError("Cannot delete workloads: user name is empty.")

  xpk_path = kwargs.get("xpk_path", "xpk")  # Default to "xpk" if not provided
  first_five_chars = user[:5]
  delete_command = (
      f"python3 {xpk_path}/xpk.py workload delete "
      f"--project={cluster_config.project} --cluster={cluster_config.cluster_name}"
      f" --filter-by-job={first_five_chars} --zone={cluster_config.zone}"
  )
  print(
      f"Deleting workloads starting with: {first_five_chars} using command:"
      f" {delete_command}"
  )
  # Hand the status back so the declared `int` return type is honoured and
  # callers can tell whether the deletion worked.
  return os.system(delete_command)
52+
53+
54+
def handle_cmd_args(
    cluster_config: mxr.XpkClusterConfig, *actions: str, **kwargs
) -> bool:
  """Parses command-line arguments and executes the specified actions.

  Each supported action contributes its own command-line flag; an action only
  runs when its flag is present on the command line.

  Args:
    cluster_config: Contains Cluster configuration information that's helpful
      for running the actions.
    *actions: Variable number of string arguments representing the actions the
      caller wants to expose (currently only `DELETE`).
    **kwargs: Optional keyword arguments to be passed to action handlers.

  Returns:
    True if the caller should carry on with its normal flow, False if an
    action was executed and the caller should stop.

  Raises:
    ValueError: If an unsupported action is provided or if unknown arguments
      are passed.
    KeyError: If a delete is requested and the USER environment variable is
      not set.
  """
  supported_actions = (DELETE,)
  unsupported = [action for action in actions if action not in supported_actions]
  if unsupported:
    # Fail loudly instead of silently ignoring a misspelled action name.
    raise ValueError(f"Unsupported actions: {unsupported}")

  parser = argparse.ArgumentParser()

  if DELETE in actions:
    parser.add_argument(
        "--delete",
        action="store_true",
        help="Delete workloads starting with the user's first five characters.",
    )

  known_args, unknown_args = parser.parse_known_args()

  if unknown_args:
    raise ValueError(f"Unrecognized arguments: {unknown_args}")

  if DELETE in actions and known_args.delete:
    # The user name is only needed for deletion, so look it up here rather
    # than requiring USER to be set for every invocation.
    user = os.environ["USER"]
    _handle_delete(cluster_config, user, **kwargs)
    return False

  return True
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Copyright 2025 Google LLC
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import os
17+
import sys
18+
import args_helper as helper
19+
20+
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
21+
sys.path.append(parent_dir)
22+
23+
import maxtext_trillium_model_configs as model_configs
24+
import maxtext_xpk_runner as mxr
25+
26+
27+
def main() -> int:
  """Generates and runs the remote-python Pathways workloads on the cluster.

  Lets `args_helper` handle command-line actions first (e.g. `--delete`); if
  none was executed, builds one XPK workload command per model / cluster /
  slice-count combination, prints them, and then runs them one by one.

  Returns:
    0 if an action was handled or every workload command returned 0,
    1 if at least one workload command returned a non-zero code.

  Raises:
    KeyError: If the USER environment variable is not set.
  """
  # V6e cluster config
  cluster_config = mxr.XpkClusterConfig(
      cluster_name="v6e-256-cluster",
      project="tpu-project",
      zone="us-east5-b",
      device_type="v6e-256",
  )

  xpk_path = "xpk"

  # Handle command line arguments using args_helper
  should_continue = helper.handle_cmd_args(
      cluster_config, helper.DELETE, xpk_path=xpk_path
  )

  if not should_continue:
    return 0

  # Configure test images
  user = os.environ["USER"]
  # The region is the zone without its trailing "-<letter>" component.
  region = "-".join(cluster_config.zone.split("-")[:-1])
  proxy_image = (
      f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/"
      "proxy_server:latest"
  )
  server_image = (
      f"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/{user}/"
      "server:latest"
  )
  remote_python_image = f"gcr.io/{cluster_config.project}/{user}/remote_python_sidecar_latest:latest"
  runner = f"gcr.io/{cluster_config.project}/{user}_latest:latest"
  base_output_directory = f"gs://{user}-{region}/{user}"

  list_of_models = [
      model_configs.default_basic_1_pw,
  ]
  pathways_config = mxr.PathwaysConfig(
      server_image=server_image,
      proxy_image=proxy_image,
      runner_image=runner,
      remote_python_sidecar_image=remote_python_image,
  )
  num_slices_list = [1]
  # Clusters to run the workloads on; a distinct loop variable is used below
  # so that `cluster_config` is not rebound by the loop.
  target_clusters = [
      cluster_config,
  ]

  xpk_workload_cmds = []
  xpk_workload_names = []

  for model in list_of_models:
    # Run workloads on the below clusters
    for target_cluster in target_clusters:
      # Run workloads in the following slice configurations
      for num_slices in num_slices_list:
        wl_config = mxr.WorkloadConfig(
            model=model,
            num_slices=num_slices,
            device_type=target_cluster.device_type,
            base_output_directory=base_output_directory,
            max_restarts=0,
            libtpu_type=None,
            libtpu_nightly_version="",
            base_docker_image="",
            pathways_config=pathways_config,
            xpk_path=xpk_path,
            num_steps=1000000,
        )
        command, name = mxr.generate_xpk_workload_cmd(
            cluster_config=target_cluster, wl_config=wl_config
        )

        print(f"Name of the workload is: {name} \n")
        xpk_workload_names.append(name)

        print(f"XPK command to be used is: {command} \n")
        xpk_workload_cmds.append(command)

  # Keep going after a failure so every workload gets a chance to start, but
  # remember the failure so it is reflected in the return value.
  exit_code = 0
  for xpk_workload_name, xpk_workload_cmd in zip(
      xpk_workload_names, xpk_workload_cmds
  ):
    return_code = mxr.run_command_with_updates(
        xpk_workload_cmd, xpk_workload_name
    )
    if return_code != 0:
      print(f"Unable to run xpk workload: {xpk_workload_name}")
      exit_code = 1

  return exit_code
113+
114+
115+
if __name__ == "__main__":
  # Use main()'s return value as the process exit status instead of
  # discarding it (a None result still exits with status 0).
  sys.exit(main())

0 commit comments

Comments
 (0)