diff --git a/tests/test_utils/python_scripts/generate_local_jobs.py b/tests/test_utils/python_scripts/generate_local_jobs.py index 175492175d..a47311b52b 100644 --- a/tests/test_utils/python_scripts/generate_local_jobs.py +++ b/tests/test_utils/python_scripts/generate_local_jobs.py @@ -29,6 +29,12 @@ def load_script(config_path: str) -> str: @click.option( "--test-case", required=False, type=str, help="Returns a single test-case with matching name." ) +@click.option( + "--environment", + required=True, + type=str, + help="Pass 'lts' for PyTorch 24.01 and 'dev' for a more recent version.", +) @click.option( "--output-path", required=True, @@ -36,9 +42,20 @@ def load_script(config_path: str) -> str: help="Directory where the functional test will write its artifacts to (Tensorboard logs)", default="/opt/megatron-lm", ) -def main(model: Optional[str], scope: Optional[str], test_case: Optional[str], output_path: str): +def main( + model: Optional[str], + scope: Optional[str], + test_case: Optional[str], + environment: str, + output_path: str, +): workloads = common.load_workloads( - container_image='none', scope=scope, model=model, test_case=test_case, container_tag='none' + container_image='none', + scope=scope, + model=model, + test_case=test_case, + environment=environment, + container_tag='none', ) for workload in workloads: