diff --git a/docs/docs/installation/index.md b/docs/docs/installation/index.md
index 524b5325f..128701378 100644
--- a/docs/docs/installation/index.md
+++ b/docs/docs/installation/index.md
@@ -94,6 +94,73 @@ Configuration is updated at ~/.dstack/config.yml
This configuration is stored in `~/.dstack/config.yml`.
+### (Optional) CLI Autocompletion
+
+`dstack` supports shell autocompletion for `bash` and `zsh`.
+
+=== "bash"
+
+ First, validate if completion scripts load correctly in your current shell session:
+
+
+
+ ```shell
+ $ eval "$(dstack completion bash)"
+ ```
+
+
+
+ If completions work as expected and you would like them to persist across shell sessions, add the completion script to your shell profile using these commands:
+
+
+
+ ```shell
+ $ mkdir -p ~/.dstack
+ $ dstack completion bash > ~/.dstack/completion.sh
+ $ echo 'source ~/.dstack/completion.sh' >> ~/.bashrc
+ ```
+
+
+
+=== "zsh"
+
+ First, validate if completion scripts load correctly in your current shell session:
+
+
+
+ ```shell
+ $ eval "$(dstack completion zsh)"
+ ```
+
+
+
+ If completions work as expected and you would like them to persist across shell sessions, you can install them via Oh My Zsh using these commands:
+
+
+
+ ```shell
+ $ mkdir -p ~/.oh-my-zsh/completions
+ $ dstack completion zsh > ~/.oh-my-zsh/completions/_dstack
+ ```
+
+
+
+ And if you don't use Oh My Zsh:
+
+
+
+ ```shell
+ $ mkdir -p ~/.dstack
+ $ dstack completion zsh > ~/.dstack/completion.sh
+ $ echo 'source ~/.dstack/completion.sh' >> ~/.zshrc
+ ```
+
+
+
+ > If you get an error similar to `2: command not found: compdef`, then add the following line to the beginning of your `~/.zshrc` file:
+ > `autoload -Uz compinit && compinit`.
+
+
!!! info "What's next?"
1. Check the [server/config.yml reference](../reference/server/config.yml.md) on how to configure backends
2. Check [SSH fleets](../concepts/fleets.md#ssh) to learn about running on your on-prem servers
diff --git a/setup.py b/setup.py
index c726b00be..e04d018ae 100644
--- a/setup.py
+++ b/setup.py
@@ -49,6 +49,7 @@ def get_long_description():
"filelock",
"psutil",
"gpuhunt>=0.0.19,<0.1.0",
+ "argcomplete>=3.5.0",
]
GATEWAY_AND_SERVER_COMMON_DEPS = [
diff --git a/src/dstack/_internal/cli/commands/__init__.py b/src/dstack/_internal/cli/commands/__init__.py
index 768c7f8e1..7a6f6eb61 100644
--- a/src/dstack/_internal/cli/commands/__init__.py
+++ b/src/dstack/_internal/cli/commands/__init__.py
@@ -5,6 +5,7 @@
from rich_argparse import RichHelpFormatter
+from dstack._internal.cli.services.completion import ProjectNameCompleter
from dstack._internal.cli.utils.common import configure_logging
from dstack.api import Client
@@ -61,7 +62,7 @@ def _register(self):
help="The name of the project. Defaults to [code]$DSTACK_PROJECT[/]",
metavar="NAME",
default=os.getenv("DSTACK_PROJECT"),
- )
+ ).completer = ProjectNameCompleter()
def _command(self, args: argparse.Namespace):
configure_logging()
diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py
index 7283969b0..d3420441e 100644
--- a/src/dstack/_internal/cli/commands/apply.py
+++ b/src/dstack/_internal/cli/commands/apply.py
@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
+from argcomplete import FilesCompleter
+
from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.services.configurators import (
get_apply_configurator_class,
@@ -42,7 +44,7 @@ def _register(self):
metavar="FILE",
help="The path to the configuration file. Defaults to [code]$PWD/.dstack.yml[/]",
dest="configuration_file",
- )
+ ).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"])
self._parser.add_argument(
"-y",
"--yes",
@@ -57,7 +59,7 @@ def _register(self):
self._parser.add_argument(
"-d",
"--detach",
- help="Exit immediately after sumbitting configuration",
+ help="Exit immediately after submitting configuration",
action="store_true",
)
repo_group = self._parser.add_argument_group("Repo Options")
diff --git a/src/dstack/_internal/cli/commands/attach.py b/src/dstack/_internal/cli/commands/attach.py
index f92c3de41..bca40ea37 100644
--- a/src/dstack/_internal/cli/commands/attach.py
+++ b/src/dstack/_internal/cli/commands/attach.py
@@ -6,6 +6,7 @@
from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.services.args import port_mapping
+from dstack._internal.cli.services.completion import RunNameCompleter
from dstack._internal.cli.utils.common import console
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
from dstack._internal.core.errors import CLIError
@@ -57,7 +58,7 @@ def _register(self):
type=int,
default=0,
)
- self._parser.add_argument("run_name")
+ self._parser.add_argument("run_name").completer = RunNameCompleter()
def _command(self, args: argparse.Namespace):
super()._command(args)
diff --git a/src/dstack/_internal/cli/commands/completion.py b/src/dstack/_internal/cli/commands/completion.py
new file mode 100644
index 000000000..961d2cdd9
--- /dev/null
+++ b/src/dstack/_internal/cli/commands/completion.py
@@ -0,0 +1,20 @@
+import argcomplete
+
+from dstack._internal.cli.commands import BaseCommand
+
+
+class CompletionCommand(BaseCommand):
+ NAME = "completion"
+ DESCRIPTION = "Generate shell completion scripts"
+
+ def _register(self):
+ super()._register()
+ self._parser.add_argument(
+ "shell",
+ help="The shell to generate the completion script for",
+ choices=["bash", "zsh"],
+ )
+
+ def _command(self, args):
+ super()._command(args)
+ print(argcomplete.shellcode(["dstack"], shell=args.shell))
diff --git a/src/dstack/_internal/cli/commands/delete.py b/src/dstack/_internal/cli/commands/delete.py
index 673205670..b651cd7a4 100644
--- a/src/dstack/_internal/cli/commands/delete.py
+++ b/src/dstack/_internal/cli/commands/delete.py
@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
+from argcomplete import FilesCompleter
+
from dstack._internal.cli.commands import APIBaseCommand
from dstack._internal.cli.services.configurators import (
get_apply_configurator_class,
@@ -22,7 +24,7 @@ def _register(self):
metavar="FILE",
help="The path to the configuration file. Defaults to [code]$PWD/.dstack.yml[/]",
dest="configuration_file",
- )
+ ).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"])
self._parser.add_argument(
"-y",
"--yes",
diff --git a/src/dstack/_internal/cli/commands/fleet.py b/src/dstack/_internal/cli/commands/fleet.py
index a3b1d8a3b..244563e12 100644
--- a/src/dstack/_internal/cli/commands/fleet.py
+++ b/src/dstack/_internal/cli/commands/fleet.py
@@ -4,6 +4,7 @@
from rich.live import Live
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import FleetNameCompleter
from dstack._internal.cli.utils.common import (
LIVE_TABLE_PROVISION_INTERVAL_SECS,
LIVE_TABLE_REFRESH_RATE_PER_SEC,
@@ -47,7 +48,7 @@ def _register(self):
delete_parser.add_argument(
"name",
help="The name of the fleet",
- )
+ ).completer = FleetNameCompleter()
delete_parser.add_argument(
"-i",
"--instance",
diff --git a/src/dstack/_internal/cli/commands/gateway.py b/src/dstack/_internal/cli/commands/gateway.py
index 195e08308..aa9d28958 100644
--- a/src/dstack/_internal/cli/commands/gateway.py
+++ b/src/dstack/_internal/cli/commands/gateway.py
@@ -4,6 +4,7 @@
from rich.live import Live
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import GatewayNameCompleter
from dstack._internal.cli.utils.common import (
LIVE_TABLE_PROVISION_INTERVAL_SECS,
LIVE_TABLE_REFRESH_RATE_PER_SEC,
@@ -59,7 +60,9 @@ def _register(self):
"delete", help="Delete a gateway", formatter_class=self._parser.formatter_class
)
delete_parser.set_defaults(subfunc=self._delete)
- delete_parser.add_argument("name", help="The name of the gateway")
+ delete_parser.add_argument(
+ "name", help="The name of the gateway"
+ ).completer = GatewayNameCompleter()
delete_parser.add_argument(
"-y", "--yes", action="store_true", help="Don't ask for confirmation"
)
@@ -68,7 +71,9 @@ def _register(self):
"update", help="Update a gateway", formatter_class=self._parser.formatter_class
)
update_parser.set_defaults(subfunc=self._update)
- update_parser.add_argument("name", help="The name of the gateway")
+ update_parser.add_argument(
+ "name", help="The name of the gateway"
+ ).completer = GatewayNameCompleter()
update_parser.add_argument(
"--set-default", action="store_true", help="Set it the default gateway for the project"
)
diff --git a/src/dstack/_internal/cli/commands/logs.py b/src/dstack/_internal/cli/commands/logs.py
index 369e5c507..ffcf55101 100644
--- a/src/dstack/_internal/cli/commands/logs.py
+++ b/src/dstack/_internal/cli/commands/logs.py
@@ -3,6 +3,7 @@
from pathlib import Path
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import RunNameCompleter
from dstack._internal.core.errors import CLIError
from dstack._internal.utils.logging import get_logger
@@ -33,7 +34,7 @@ def _register(self):
)
self._parser.add_argument(
"--replica",
- help="The relica number. Defaults to 0.",
+ help="The replica number. Defaults to 0.",
type=int,
default=0,
)
@@ -43,7 +44,7 @@ def _register(self):
type=int,
default=0,
)
- self._parser.add_argument("run_name")
+ self._parser.add_argument("run_name").completer = RunNameCompleter(all=True)
def _command(self, args: argparse.Namespace):
super()._command(args)
diff --git a/src/dstack/_internal/cli/commands/stats.py b/src/dstack/_internal/cli/commands/stats.py
index 20c3bdeff..12f52346b 100644
--- a/src/dstack/_internal/cli/commands/stats.py
+++ b/src/dstack/_internal/cli/commands/stats.py
@@ -7,6 +7,7 @@
from rich.table import Table
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import RunNameCompleter
from dstack._internal.cli.utils.common import (
LIVE_TABLE_PROVISION_INTERVAL_SECS,
LIVE_TABLE_REFRESH_RATE_PER_SEC,
@@ -25,7 +26,7 @@ class StatsCommand(APIBaseCommand):
def _register(self):
super()._register()
- self._parser.add_argument("run_name")
+ self._parser.add_argument("run_name").completer = RunNameCompleter()
self._parser.add_argument(
"-w",
"--watch",
diff --git a/src/dstack/_internal/cli/commands/stop.py b/src/dstack/_internal/cli/commands/stop.py
index ad08228b3..dced8d77e 100644
--- a/src/dstack/_internal/cli/commands/stop.py
+++ b/src/dstack/_internal/cli/commands/stop.py
@@ -1,6 +1,7 @@
import argparse
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import RunNameCompleter
from dstack._internal.cli.utils.common import confirm_ask
from dstack._internal.core.errors import CLIError
@@ -13,7 +14,7 @@ def _register(self):
super()._register()
self._parser.add_argument("-x", "--abort", action="store_true")
self._parser.add_argument("-y", "--yes", action="store_true")
- self._parser.add_argument("run_name")
+ self._parser.add_argument("run_name").completer = RunNameCompleter()
def _command(self, args: argparse.Namespace):
super()._command(args)
diff --git a/src/dstack/_internal/cli/commands/volume.py b/src/dstack/_internal/cli/commands/volume.py
index a4c10b342..5f7bc833e 100644
--- a/src/dstack/_internal/cli/commands/volume.py
+++ b/src/dstack/_internal/cli/commands/volume.py
@@ -4,6 +4,7 @@
from rich.live import Live
from dstack._internal.cli.commands import APIBaseCommand
+from dstack._internal.cli.services.completion import VolumeNameCompleter
from dstack._internal.cli.utils.common import (
LIVE_TABLE_PROVISION_INTERVAL_SECS,
LIVE_TABLE_REFRESH_RATE_PER_SEC,
@@ -47,7 +48,7 @@ def _register(self):
delete_parser.add_argument(
"name",
help="The name of the volume",
- )
+ ).completer = VolumeNameCompleter()
delete_parser.add_argument(
"-y", "--yes", help="Don't ask for confirmation", action="store_true"
)
diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py
index 431e07862..800595b73 100644
--- a/src/dstack/_internal/cli/main.py
+++ b/src/dstack/_internal/cli/main.py
@@ -1,10 +1,12 @@
import argparse
+import argcomplete
from rich.markup import escape
from rich_argparse import RichHelpFormatter
from dstack._internal.cli.commands.apply import ApplyCommand
from dstack._internal.cli.commands.attach import AttachCommand
+from dstack._internal.cli.commands.completion import CompletionCommand
from dstack._internal.cli.commands.config import ConfigCommand
from dstack._internal.cli.commands.delete import DeleteCommand
from dstack._internal.cli.commands.fleet import FleetCommand
@@ -72,9 +74,13 @@ def main():
StatsCommand.register(subparsers)
StopCommand.register(subparsers)
VolumeCommand.register(subparsers)
+ CompletionCommand.register(subparsers)
+
+ argcomplete.autocomplete(parser, always_complete_options=False)
args, unknown_args = parser.parse_known_args()
args.unknown = unknown_args
+
try:
check_for_updates()
get_ssh_client_info()
diff --git a/src/dstack/_internal/cli/services/completion.py b/src/dstack/_internal/cli/services/completion.py
new file mode 100644
index 000000000..ed8ce26ad
--- /dev/null
+++ b/src/dstack/_internal/cli/services/completion.py
@@ -0,0 +1,86 @@
+import argparse
+import os
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
+
+import argcomplete
+from argcomplete.completers import BaseCompleter
+
+from dstack._internal.core.errors import ConfigurationError
+from dstack._internal.core.services.configs import ConfigManager
+from dstack.api import Client
+
+
+class BaseAPINameCompleter(BaseCompleter, ABC):
+ """
+ Base class for name completers that fetch resource names via the API.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def get_api(self, parsed_args: argparse.Namespace) -> Optional[Client]:
+ argcomplete.debug(f"{self.__class__.__name__}: Retrieving API client")
+ project = getattr(parsed_args, "project", os.getenv("DSTACK_PROJECT"))
+ try:
+ return Client.from_config(project_name=project)
+ except ConfigurationError as e:
+ argcomplete.debug(f"{self.__class__.__name__}: Error initializing API client: {e}")
+ return None
+
+ def __call__(self, prefix: str, parsed_args: argparse.Namespace, **kwargs) -> List[str]:
+ api = self.get_api(parsed_args)
+ if api is None:
+ return []
+
+ argcomplete.debug(f"{self.__class__.__name__}: Fetching completions")
+ try:
+ resource_names = self.fetch_resource_names(api)
+ return [name for name in resource_names if name.startswith(prefix)]
+ except Exception as e:
+ argcomplete.debug(
+ f"{self.__class__.__name__}: Error fetching resource completions: {e}"
+ )
+ return []
+
+ @abstractmethod
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ """
+ Returns an iterable of resource names.
+ """
+ pass
+
+
+class RunNameCompleter(BaseAPINameCompleter):
+ def __init__(self, all: bool = False):
+ super().__init__()
+ self.all = all
+
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ return [r.name for r in api.runs.list(self.all)]
+
+
+class FleetNameCompleter(BaseAPINameCompleter):
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ return [r.name for r in api.client.fleets.list(api.project)]
+
+
+class VolumeNameCompleter(BaseAPINameCompleter):
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ return [r.name for r in api.client.volumes.list(api.project)]
+
+
+class GatewayNameCompleter(BaseAPINameCompleter):
+ def fetch_resource_names(self, api: Client) -> Iterable[str]:
+ return [r.name for r in api.client.gateways.list(api.project)]
+
+
+class ProjectNameCompleter(BaseCompleter):
+ """
+ Completer for local project names.
+ """
+
+ def __call__(self, prefix: str, parsed_args: argparse.Namespace, **kwargs) -> List[str]:
+ argcomplete.debug(f"{self.__class__.__name__}: Listing projects from ConfigManager")
+ projects = ConfigManager().list_projects()
+ return [p for p in projects if p.startswith(prefix)]
diff --git a/src/dstack/_internal/core/services/configs/__init__.py b/src/dstack/_internal/core/services/configs/__init__.py
index b0fa73444..f11983e38 100644
--- a/src/dstack/_internal/core/services/configs/__init__.py
+++ b/src/dstack/_internal/core/services/configs/__init__.py
@@ -65,6 +65,9 @@ def configure_project(self, name: str, url: str, token: str, default: bool):
if len(self.config.projects) == 1:
self.config.projects[0].default = True
+ def list_projects(self):
+ return [project.name for project in self.config.projects]
+
def delete_project(self, name: str):
self.config.projects = [p for p in self.config.projects if p.name != name]