Skip to content

Commit

Permalink
Merge pull request #1172 from AI-Hypercomputer:xpk_runner
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716454014
  • Loading branch information
maxtext authors committed Jan 17, 2025
2 parents 3ad02ba + 13dbefa commit 1ea7602
Show file tree
Hide file tree
Showing 4 changed files with 603 additions and 376 deletions.
132 changes: 48 additions & 84 deletions benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,21 @@
limitations under the License.
"""

""" Script to run a benchmark/benchmarks on exsiting xpk or QR nodes (to be implmented)
""" Script to run a benchmark/benchmarks on existing xpk or QR nodes (to be implemented)
***** IMPORTANT *****
This script will run specific tuned workload on specified hardware and software environments
This script will run specific tuned workload on specified hardware and software environments
Example usages:
python3 benchmark_runner.py --project=<my-project> --zone=<zone> \
--cluster_name=<xpk_cluster_name> --base_output_directory=<output_gcloud_bucket> --device_type=v6e-256 --num_slices=1 --model_name="llama2_70b_4096" --libtpu_version=20241009 --base_docker_image=maxtext_base_image
"""
import argparse
import importlib

import maxtext_trillium_model_configs
from maxtext_xpk_runner import BenchmarkRunner
from maxtext_xpk_runner import HWConfig
from maxtext_xpk_runner import SWconfig
from maxtext_trillium_model_configs import trillium_model_dict
from maxtext_xpk_runner import PathwaysConfig
from maxtext_xpk_runner import WorkloadConfig
from maxtext_xpk_runner import xpk_benchmark_runner
from maxtext_xpk_runner import XpkConfig

from maxtext_xpk_runner import XpkClusterConfig
from maxtext_xpk_runner import LibTpuType

def add_shared_arguments(custom_parser: argparse.ArgumentParser):
"""Add shared arguments to the parser.
Expand Down Expand Up @@ -65,7 +62,7 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
'--base_output_directory',
type=str,
default=None, required=True,
help='gcloud bucket to store artifacts.',
help='gcloud bucket to store artifacts.',
)
custom_parser.add_argument(
'--device_type',
Expand All @@ -82,45 +79,10 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
custom_parser.add_argument(
'--model_name',
type=str,
choices=[
'gpt_3_175b',
'llama2_7b_4096',
'llama2_70b_4096',
'llama2_70b_4096_real_data',
'llama2_70b_4096_pw_long_run',
'llama2_70b_4096_real_data_pw_long_run',
'llama2_70b_4096_pw_rd_tfds',
'llama2_70b_4096_synthetic_pw_lr',
'llama2_70b_4096_synthetic',
'llama3_70b_8192',
'llama3_1_405b_8192_fsdp_dcn',
'mixtral_8x7b_dropped',
'mixtral_8x7b_dropped_int8',
'mixtral_8x7b_dropless',
'gemma2_9b_8192',
'gemma2_27b_8192',
'llama3_1_70b_129024',
'llama3_1_8b_8192',
'llama3_1_70b_8192',
],
default='llama2_70b_4096',
choices=list(trillium_model_dict.keys()),
default=list(trillium_model_dict.keys())[0],
help=(
'model to be benchmarked, supported models are gpt_3_175b '
'llama2_7b_4096 '
'llama2_70b_4096 '
'llama2_70b_4096_real_data '
'llama2_70b_4096_pw_long_run '
'llama2_70b_4096_real_data_pw_long_run '
'llama2_70b_4096_pw_rd_tfds '
'llama2_70b_4096_synthetic_pw_lr '
'llama2_70b_4096_synthetic '
'llama3_1_405b_8192_fsdp_dcn '
'mixtral_8x7b_dropped '
'mixtral_8x7b_dropped_int8 '
'mixtral_8x7b_dropless '
'gemma2_9b_8192 '
'gemma2_27b_8192 '
'command.'
f'model to be benchmarked, supported models are the command choices.'
),
)
custom_parser.add_argument(
Expand Down Expand Up @@ -173,6 +135,12 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
default='medium',
help='Priority the XPK workload should run with.',
)
custom_parser.add_argument(
'--num_steps',
type=int,
default=20,
help='Number of steps to run the workload for.',
)
custom_parser.add_argument(
'--max_restarts',
type=int,
Expand All @@ -188,42 +156,38 @@ def main() -> None:
add_shared_arguments(parser)
options = parser.parse_args()

cluster_config = XpkConfig(
cluster_name=options.cluster_name,
project=options.project,
zone=options.zone,
num_slices=options.num_slices,
device_type=options.device_type,
base_output_directory=options.base_output_directory,
priority=options.priority,
max_restarts=options.max_restarts,
)

v6e_env_configs = SWconfig(
base_docker_image=options.base_docker_image,
libtpu_version=options.libtpu_version,
pathways_config=PathwaysConfig(
use_pathways=options.use_pathways,
server_image=options.pathways_server_image,
proxy_image=options.pathways_proxy_image,
runner_image=options.pathways_runner_image,
),
)

v6e_256_configs = HWConfig(
num_slices=options.num_slices, device_type=options.device_type
)

model_sets = importlib.import_module('maxtext_trillium_model_configs')
benchmark_model = getattr(model_sets, options.model_name)

model_runner = BenchmarkRunner(
model_name=benchmark_model,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
)

xpk_benchmark_runner(cluster_config, [model_runner], options.xpk_path)
# Describe the target XPK cluster. Slice count is a per-workload property
# and therefore lives on WorkloadConfig below, not here.
cluster_config = XpkClusterConfig(
    cluster_name=options.cluster_name,
    project=options.project,
    zone=options.zone,
    device_type=options.device_type
)

# Pathways server/proxy/runner images are only meaningful when the run is
# driven through Pathways; otherwise leave the config unset.
pw_config = None
if options.use_pathways:
  pw_config = PathwaysConfig(
      server_image=options.pathways_server_image,
      proxy_image=options.pathways_proxy_image,
      runner_image=options.pathways_runner_image,
  )

# Look the model up once (the original called .get() twice). argparse
# `choices` should already restrict the value, so this guards against
# programmatic misuse; raise a real exception instead of `assert`, which
# is stripped when Python runs with -O.
model = trillium_model_dict.get(options.model_name)
if model is None:
  raise ValueError(f'Invalid model name: {options.model_name}')

# Bundle everything the runner needs to launch one benchmark workload.
workload_config = WorkloadConfig(
    model=model,
    num_slices=options.num_slices,
    num_steps=options.num_steps,
    device_type=options.device_type,
    base_output_directory=options.base_output_directory,
    priority=options.priority,
    max_restarts=options.max_restarts,
    libtpu_type=LibTpuType.NIGHTLY,
    libtpu_nightly_version=options.libtpu_version,
    base_docker_image=options.base_docker_image,
    xpk_path=options.xpk_path,
    pathways_config=pw_config
)

xpk_benchmark_runner(cluster_config, [workload_config])


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 1ea7602

Please sign in to comment.