diff --git a/pyproject.toml b/pyproject.toml index 27e4b8b..4310217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,10 @@ docker_persistent = "nemo_run.run.torchx_backend.schedulers.docker:create_schedu [project.optional-dependencies] skypilot = [ - "skypilot[kubernetes]>=v0.6.1", + "skypilot-nightly[kubernetes]", ] skypilot-all = [ - "skypilot[all]>=v0.6.1", + "skypilot-nightly[all]", ] [build-system] diff --git a/requirements-dev.lock b/requirements-dev.lock index a7a2817..07ea6b0 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -11,461 +11,167 @@ -e file:. absl-py==2.1.0 - # via fiddle antlr4-python3-runtime==4.9.3 - # via omegaconf anyio==4.4.0 - # via httpx - # via jupyter-server -appnope==0.1.4 - # via ipykernel argon2-cffi==23.1.0 - # via jupyter-server argon2-cffi-bindings==21.2.0 - # via argon2-cffi arrow==1.3.0 - # via isoduration asttokens==2.4.1 - # via stack-data async-lru==2.0.4 - # via jupyterlab attrs==24.2.0 - # via jsonschema - # via referencing babel==2.16.0 - # via jupyterlab-server bcrypt==4.2.0 - # via paramiko beautifulsoup4==4.12.3 - # via nbconvert bleach==6.1.0 - # via nbconvert cachetools==5.5.0 - # via google-auth - # via skypilot catalogue==2.0.10 - # via nemo-run certifi==2024.7.4 - # via httpcore - # via httpx - # via kubernetes - # via requests cffi==1.17.0 - # via argon2-cffi-bindings - # via cryptography - # via pynacl cfgv==3.4.0 - # via pre-commit charset-normalizer==3.3.2 - # via requests click==8.1.7 - # via skypilot - # via typer colorama==0.4.6 - # via skypilot comm==0.2.2 - # via ipykernel - # via ipywidgets coverage==7.6.1 cryptography==42.0.8 - # via nemo-run - # via paramiko - # via skypilot debugpy==1.8.5 - # via ipykernel decorator==5.1.1 - # via fabric - # via ipython defusedxml==0.7.1 - # via nbconvert deprecated==1.2.14 - # via fabric distlib==0.3.8 - # via virtualenv docker==7.1.0 - # via torchx docstring-parser==0.16 - # via torchx executing==2.0.1 - # via stack-data fabric==3.2.2 - # via nemo-run fastjsonschema==2.20.0 - # via nbformat fiddle==0.3.0 - # via nemo-run filelock==3.15.4 - # via skypilot - # via torchx - # via virtualenv fqdn==1.5.1 - # via jsonschema fsspec==2024.6.1 - # via torchx google-auth==2.34.0 - # via kubernetes graphviz==0.20.3 - # via fiddle h11==0.14.0 - # via httpcore httpcore==1.0.5 - # via httpx httpx==0.27.2 - # via jupyterlab identify==2.6.0 - # via pre-commit idna==3.7 - # via anyio - # via httpx - # via jsonschema - # via requests importlib-metadata==8.3.0 - # via torchx iniconfig==2.0.0 - # via pytest inquirerpy==0.3.4 - # via nemo-run invoke==2.2.0 - # via fabric ipykernel==6.29.5 - # via jupyter - # via jupyter-console - # via jupyterlab ipython==8.26.0 - # via ipykernel - # via ipywidgets - # via jupyter-console ipywidgets==8.1.3 - # via jupyter isoduration==20.11.0 - # via jsonschema jedi==0.19.1 - # via ipython jinja2==3.1.4 - # via jupyter-server - # via jupyterlab - # via jupyterlab-server - # via nbconvert - # via nemo-run - # via skypilot json5==0.9.25 - # via jupyterlab-server jsonpointer==3.0.0 - # via jsonschema jsonschema==4.23.0 - # via jupyter-events - # via jupyterlab-server - # via nbformat - # via skypilot jsonschema-specifications==2023.12.1 - # via jsonschema jupyter==1.1.1 jupyter-client==8.6.2 - # via ipykernel - # via jupyter-console - # via jupyter-server - # via nbclient jupyter-console==6.6.3 - # via jupyter jupyter-core==5.7.2 - # via ipykernel - # via jupyter-client - # via jupyter-console - # via jupyter-server - # via jupyterlab - # via nbclient - # via nbconvert - # via nbformat jupyter-events==0.10.0 - # via jupyter-server jupyter-lsp==2.2.5 - # via jupyterlab jupyter-server==2.14.2 - # via jupyter-lsp - # via jupyterlab - # via jupyterlab-server - # via notebook - # via notebook-shim jupyter-server-terminals==0.5.3 - # via jupyter-server jupyterlab==4.2.5 - # via jupyter - # via notebook jupyterlab-pygments==0.3.0 - # via nbconvert jupyterlab-server==2.27.3 - # via jupyterlab - # via notebook jupyterlab-widgets==3.0.11 - # via ipywidgets kubernetes==30.1.0 - # via skypilot libcst==1.4.0 - # via fiddle markdown-it-py==3.0.0 - # via rich markupsafe==2.1.5 - # via jinja2 - # via nbconvert matplotlib-inline==0.1.7 - # via ipykernel - # via ipython mdurl==0.1.2 - # via markdown-it-py mistune==3.0.2 - # via nbconvert mypy-extensions==1.0.0 - # via typing-inspect nbclient==0.10.0 - # via nbconvert nbconvert==7.16.4 - # via jupyter - # via jupyter-server nbformat==5.10.4 - # via jupyter-server - # via nbclient - # via nbconvert nest-asyncio==1.6.0 - # via ipykernel networkx==3.3 - # via nemo-run - # via skypilot nodeenv==1.9.1 - # via pre-commit notebook==7.2.2 - # via jupyter notebook-shim==0.2.4 - # via jupyterlab - # via notebook numpy==2.1.0 - # via pandas oauthlib==3.2.2 - # via kubernetes - # via requests-oauthlib omegaconf==2.3.0 - # via nemo-run overrides==7.7.0 - # via jupyter-server packaging==24.1 - # via ipykernel - # via jupyter-server - # via jupyterlab - # via jupyterlab-server - # via nbconvert - # via pytest - # via pytest-sugar - # via skypilot pandas==2.2.2 - # via skypilot pandocfilters==1.5.1 - # via nbconvert paramiko==3.4.1 - # via fabric parso==0.8.4 - # via jedi pendulum==3.0.0 - # via skypilot pexpect==4.9.0 - # via ipython pfzy==0.3.4 - # via inquirerpy platformdirs==4.2.2 - # via jupyter-core - # via virtualenv pluggy==1.5.0 - # via pytest pre-commit==3.8.0 prettytable==3.11.0 - # via skypilot prometheus-client==0.20.0 - # via jupyter-server prompt-toolkit==3.0.47 - # via inquirerpy - # via ipython - # via jupyter-console psutil==6.0.0 - # via ipykernel - # via skypilot ptyprocess==0.7.0 - # via pexpect - # via terminado pulp==2.9.0 - # via skypilot pure-eval==0.2.3 - # via stack-data pyasn1==0.6.0 - # via pyasn1-modules - # via rsa pyasn1-modules==0.4.0 - # via google-auth pycparser==2.22 - # via cffi pygments==2.18.0 - # via ipython - # via jupyter-console - # via nbconvert - # via rich pynacl==1.5.0 - # via paramiko pyre-extensions==0.0.30 - # via torchx pytest==8.3.2 - # via pytest-mock - # via pytest-sugar pytest-mock==3.14.0 pytest-sugar==1.0.0 python-dateutil==2.9.0.post0 - # via arrow - # via jupyter-client - # via kubernetes - # via pandas - # via pendulum - # via time-machine python-dotenv==1.0.1 - # via skypilot python-json-logger==2.0.7 - # via jupyter-events pytz==2024.1 - # via pandas pyyaml==6.0.2 - # via jupyter-events - # via kubernetes - # via libcst - # via omegaconf - # via pre-commit - # via skypilot - # via torchx pyzmq==26.1.1 - # via ipykernel - # via jupyter-client - # via jupyter-console - # via jupyter-server referencing==0.35.1 - # via jsonschema - # via jsonschema-specifications - # via jupyter-events requests==2.32.3 - # via docker - # via jupyterlab-server - # via kubernetes - # via requests-oauthlib - # via skypilot requests-oauthlib==2.0.0 - # via kubernetes rfc3339-validator==0.1.4 - # via jsonschema - # via jupyter-events rfc3986-validator==0.1.1 - # via jsonschema - # via jupyter-events rich==13.7.1 - # via nemo-run - # via skypilot - # via typer rpds-py==0.20.0 - # via jsonschema - # via referencing rsa==4.9 - # via google-auth ruff==0.6.1 send2trash==1.8.3 - # via jupyter-server setuptools==74.1.2 - # via jupyterlab shellingham==1.5.4 - # via typer six==1.16.0 - # via asttokens - # via bleach - # via kubernetes - # via python-dateutil - # via rfc3339-validator -skypilot==0.6.1 - # via nemo-run +skypilot-nightly==1.0.0.dev20241205 sniffio==1.3.1 - # via anyio - # via httpx soupsieve==2.6 - # via beautifulsoup4 stack-data==0.6.3 - # via ipython tabulate==0.9.0 - # via skypilot - # via torchx termcolor==2.4.0 - # via pytest-sugar terminado==0.18.1 - # via jupyter-server - # via jupyter-server-terminals time-machine==2.15.0 - # via pendulum tinycss2==1.3.0 - # via nbconvert torchx==0.7.0 - # via nemo-run tornado==6.4.1 - # via ipykernel - # via jupyter-client - # via jupyter-server - # via jupyterlab - # via notebook - # via terminado traitlets==5.14.3 - # via comm - # via ipykernel - # via ipython - # via ipywidgets - # via jupyter-client - # via jupyter-console - # via jupyter-core - # via jupyter-events - # via jupyter-server - # via jupyterlab - # via matplotlib-inline - # via nbclient - # via nbconvert - # via nbformat typer==0.12.4 - # via nemo-run types-python-dateutil==2.9.0.20240906 - # via arrow typing-extensions==4.12.2 - # via fiddle - # via ipython - # via pyre-extensions - # via skypilot - # via typer - # via typing-inspect typing-inspect==0.9.0 - # via pyre-extensions tzdata==2024.1 - # via pandas - # via pendulum uri-template==1.3.0 - # via jsonschema urllib3==1.26.19 - # via docker - # via kubernetes - # via requests - # via torchx virtualenv==20.26.3 - # via pre-commit wcwidth==0.2.13 - # via prettytable - # via prompt-toolkit webcolors==24.8.0 - # via jsonschema webencodings==0.5.1 - # via bleach - # via tinycss2 websocket-client==1.8.0 - # via jupyter-server - # via kubernetes wheel==0.44.0 - # via skypilot widgetsnbextension==4.0.11 - # via ipywidgets wrapt==1.16.0 - # via deprecated zipp==3.20.0 - # via importlib-metadata diff --git a/requirements.lock b/requirements.lock index 50511b1..d072ce9 100644 --- a/requirements.lock +++ b/requirements.lock @@ -11,199 +11,82 @@ -e file:. absl-py==2.1.0 - # via fiddle antlr4-python3-runtime==4.9.3 - # via omegaconf attrs==24.2.0 - # via jsonschema - # via referencing bcrypt==4.2.0 - # via paramiko cachetools==5.5.0 - # via google-auth - # via skypilot catalogue==2.0.10 - # via nemo-run certifi==2024.7.4 - # via kubernetes - # via requests cffi==1.17.0 - # via cryptography - # via pynacl charset-normalizer==3.3.2 - # via requests click==8.1.7 - # via skypilot - # via typer colorama==0.4.6 - # via skypilot cryptography==42.0.8 - # via nemo-run - # via paramiko - # via skypilot decorator==5.1.1 - # via fabric deprecated==1.2.14 - # via fabric docker==7.1.0 - # via torchx docstring-parser==0.16 - # via torchx fabric==3.2.2 - # via nemo-run fiddle==0.3.0 - # via nemo-run filelock==3.15.4 - # via skypilot - # via torchx fsspec==2024.6.1 - # via torchx google-auth==2.34.0 - # via kubernetes graphviz==0.20.3 - # via fiddle idna==3.7 - # via requests importlib-metadata==8.3.0 - # via torchx inquirerpy==0.3.4 - # via nemo-run invoke==2.2.0 - # via fabric jinja2==3.1.4 - # via nemo-run - # via skypilot jsonschema==4.23.0 - # via skypilot jsonschema-specifications==2023.12.1 - # via jsonschema kubernetes==30.1.0 - # via skypilot libcst==1.4.0 - # via fiddle markdown-it-py==3.0.0 - # via rich markupsafe==2.1.5 - # via jinja2 mdurl==0.1.2 - # via markdown-it-py mypy-extensions==1.0.0 - # via typing-inspect networkx==3.3 - # via nemo-run - # via skypilot numpy==2.1.0 - # via pandas oauthlib==3.2.2 - # via kubernetes - # via requests-oauthlib omegaconf==2.3.0 - # via nemo-run packaging==24.1 - # via skypilot pandas==2.2.2 - # via skypilot paramiko==3.4.1 - # via fabric pendulum==3.0.0 - # via skypilot pfzy==0.3.4 - # via inquirerpy prettytable==3.11.0 - # via skypilot prompt-toolkit==3.0.47 - # via inquirerpy psutil==6.0.0 - # via skypilot pulp==2.9.0 - # via skypilot pyasn1==0.6.0 - # via pyasn1-modules - # via rsa pyasn1-modules==0.4.0 - # via google-auth pycparser==2.22 - # via cffi pygments==2.18.0 - # via rich pynacl==1.5.0 - # via paramiko pyre-extensions==0.0.30 - # via torchx python-dateutil==2.9.0.post0 - # via kubernetes - # via pandas - # via pendulum - # via time-machine python-dotenv==1.0.1 - # via skypilot pytz==2024.1 - # via pandas pyyaml==6.0.2 - # via kubernetes - # via libcst - # via omegaconf - # via skypilot - # via torchx referencing==0.35.1 - # via jsonschema - # via jsonschema-specifications requests==2.32.3 - # via docker - # via kubernetes - # via requests-oauthlib - # via skypilot requests-oauthlib==2.0.0 - # via kubernetes rich==13.7.1 - # via nemo-run - # via skypilot - # via typer rpds-py==0.20.0 - # via jsonschema - # via referencing rsa==4.9 - # via google-auth shellingham==1.5.4 - # via typer six==1.16.0 - # via kubernetes - # via python-dateutil -skypilot==0.6.1 - # via nemo-run +skypilot-nightly==1.0.0.dev20241205 tabulate==0.9.0 - # via skypilot - # via torchx time-machine==2.15.0 - # via pendulum torchx==0.7.0 - # via nemo-run typer==0.12.4 - # via nemo-run typing-extensions==4.12.2 - # via fiddle - # via pyre-extensions - # via skypilot - # via typer - # via typing-inspect typing-inspect==0.9.0 - # via pyre-extensions tzdata==2024.1 - # via pandas - # via pendulum urllib3==1.26.19 - # via docker - # via kubernetes - # via requests - # via torchx wcwidth==0.2.13 - # via prettytable - # via prompt-toolkit websocket-client==1.8.0 - # via kubernetes wheel==0.44.0 - # via skypilot wrapt==1.16.0 - # via deprecated zipp==3.20.0 - # via importlib-metadata diff --git a/src/nemo_run/core/execution/skypilot.py b/src/nemo_run/core/execution/skypilot.py index 7b4a1d4..dacf627 100644 --- a/src/nemo_run/core/execution/skypilot.py +++ b/src/nemo_run/core/execution/skypilot.py @@ -18,7 +18,7 @@ import subprocess from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Type, Union +from typing import Any, Optional, Type, Union from invoke.context import Context @@ -104,6 +104,7 @@ class SkypilotExecutor(Executor): autodown: bool = False idle_minutes_to_autostop: Optional[int] = None torchrun_nproc_per_node: Optional[int] = None + cluster_config_overrides: Optional[dict[str, Any]] = None packager: Packager = field(default_factory=lambda: GitArchivePackager()) # type: ignore # noqa: F821 def __post_init__(self): @@ -153,12 +154,12 @@ def parse_attr(attr: str): if len(any_of) < i + 1: any_of.append({}) - if val.lower() == "none": - any_of[i][attr] = val + if isinstance(val, str) and val.lower() == "none": + any_of[i][attr] = None else: any_of[i][attr] = val else: - if value.lower() == "none": + if isinstance(value, str) and value.lower() == "none": resources_cfg[attr] = None else: resources_cfg[attr] = value @@ -182,6 +183,9 @@ def parse_attr(attr: str): parse_attr(attr) resources_cfg["any_of"] = any_of + if self.cluster_config_overrides: + resources_cfg["_cluster_config_overrides"] = self.cluster_config_overrides + resources = Resources.from_yaml_config(resources_cfg) return resources # type: ignore @@ -411,8 +415,8 @@ def launch( backend=backend, idle_minutes_to_autostop=self.idle_minutes_to_autostop, down=self.autodown, + fast=True, # retry_until_up=retry_until_up, - no_setup=True if (self.cluster_name and not self.setup) else False, # clone_disk_from=clone_disk_from, )