diff --git a/docs/docs/installation/index.md b/docs/docs/installation/index.md index 524b5325f..128701378 100644 --- a/docs/docs/installation/index.md +++ b/docs/docs/installation/index.md @@ -94,6 +94,73 @@ Configuration is updated at ~/.dstack/config.yml This configuration is stored in `~/.dstack/config.yml`. +### (Optional) CLI Autocompletion + +`dstack` supports shell autocompletion for `bash` and `zsh`. + +=== "bash" + + First, validate if completion scripts load correctly in your current shell session: + +
+ + ```shell + $ eval "$(dstack completion bash)" + ``` + +
+ + If completions work as expected and you would like them to persist across shell sessions, add the completion script to your shell profile using these commands: + +
+ + ```shell + $ mkdir -p ~/.dstack + $ dstack completion bash > ~/.dstack/completion.sh + $ echo 'source ~/.dstack/completion.sh' >> ~/.bashrc + ``` + +
+ +=== "zsh" + + First, validate if completion scripts load correctly in your current shell session: + +
+ + ```shell + $ eval "$(dstack completion zsh)" + ``` + +
+ + If completions work as expected and you would like them to persist across shell sessions, you can install them via Oh My Zsh using these commands: + +
+ + ```shell + $ mkdir -p ~/.oh-my-zsh/completions + $ dstack completion zsh > ~/.oh-my-zsh/completions/_dstack + ``` + +
+ + And if you don't use Oh My Zsh: + +
+ + ```shell + $ mkdir -p ~/.dstack + $ dstack completion zsh > ~/.dstack/completion.sh + $ echo 'source ~/.dstack/completion.sh' >> ~/.zshrc + ``` + +
+ + > If you get an error similar to `2: command not found: compdef`, then add the following line to the beginning of your `~/.zshrc` file: + > `autoload -Uz compinit && compinit`. + + !!! info "What's next?" 1. Check the [server/config.yml reference](../reference/server/config.yml.md) on how to configure backends 2. Check [SSH fleets](../concepts/fleets.md#ssh) to learn about running on your on-prem servers diff --git a/setup.py b/setup.py index c726b00be..e04d018ae 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def get_long_description(): "filelock", "psutil", "gpuhunt>=0.0.19,<0.1.0", + "argcomplete>=3.5.0", ] GATEWAY_AND_SERVER_COMMON_DEPS = [ diff --git a/src/dstack/_internal/cli/commands/__init__.py b/src/dstack/_internal/cli/commands/__init__.py index 768c7f8e1..7a6f6eb61 100644 --- a/src/dstack/_internal/cli/commands/__init__.py +++ b/src/dstack/_internal/cli/commands/__init__.py @@ -5,6 +5,7 @@ from rich_argparse import RichHelpFormatter +from dstack._internal.cli.services.completion import ProjectNameCompleter from dstack._internal.cli.utils.common import configure_logging from dstack.api import Client @@ -61,7 +62,7 @@ def _register(self): help="The name of the project. Defaults to [code]$DSTACK_PROJECT[/]", metavar="NAME", default=os.getenv("DSTACK_PROJECT"), - ) + ).completer = ProjectNameCompleter() def _command(self, args: argparse.Namespace): configure_logging() diff --git a/src/dstack/_internal/cli/commands/apply.py b/src/dstack/_internal/cli/commands/apply.py index 7283969b0..d3420441e 100644 --- a/src/dstack/_internal/cli/commands/apply.py +++ b/src/dstack/_internal/cli/commands/apply.py @@ -1,6 +1,8 @@ import argparse from pathlib import Path +from argcomplete import FilesCompleter + from dstack._internal.cli.commands import APIBaseCommand from dstack._internal.cli.services.configurators import ( get_apply_configurator_class, @@ -42,7 +44,7 @@ def _register(self): metavar="FILE", help="The path to the configuration file. Defaults to [code]$PWD/.dstack.yml[/]", dest="configuration_file", - ) + ).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"]) self._parser.add_argument( "-y", "--yes", @@ -57,7 +59,7 @@ def _register(self): self._parser.add_argument( "-d", "--detach", - help="Exit immediately after sumbitting configuration", + help="Exit immediately after submitting configuration", action="store_true", ) repo_group = self._parser.add_argument_group("Repo Options") diff --git a/src/dstack/_internal/cli/commands/attach.py b/src/dstack/_internal/cli/commands/attach.py index f92c3de41..bca40ea37 100644 --- a/src/dstack/_internal/cli/commands/attach.py +++ b/src/dstack/_internal/cli/commands/attach.py @@ -6,6 +6,7 @@ from dstack._internal.cli.commands import APIBaseCommand from dstack._internal.cli.services.args import port_mapping +from dstack._internal.cli.services.completion import RunNameCompleter from dstack._internal.cli.utils.common import console from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.errors import CLIError @@ -57,7 +58,7 @@ def _register(self): type=int, default=0, ) - self._parser.add_argument("run_name") + self._parser.add_argument("run_name").completer = RunNameCompleter() def _command(self, args: argparse.Namespace): super()._command(args) diff --git a/src/dstack/_internal/cli/commands/completion.py b/src/dstack/_internal/cli/commands/completion.py new file mode 100644 index 000000000..961d2cdd9 --- /dev/null +++ b/src/dstack/_internal/cli/commands/completion.py @@ -0,0 +1,20 @@ +import argcomplete + +from dstack._internal.cli.commands import BaseCommand + + +class CompletionCommand(BaseCommand): + NAME = "completion" + DESCRIPTION = "Generate shell completion scripts" + + def _register(self): + super()._register() + self._parser.add_argument( + "shell", + help="The shell to generate the completion script for", + choices=["bash", "zsh"], + ) + + def _command(self, args): + super()._command(args) + print(argcomplete.shellcode(["dstack"], shell=args.shell)) diff --git a/src/dstack/_internal/cli/commands/delete.py b/src/dstack/_internal/cli/commands/delete.py index 673205670..b651cd7a4 100644 --- a/src/dstack/_internal/cli/commands/delete.py +++ b/src/dstack/_internal/cli/commands/delete.py @@ -1,6 +1,8 @@ import argparse from pathlib import Path +from argcomplete import FilesCompleter + from dstack._internal.cli.commands import APIBaseCommand from dstack._internal.cli.services.configurators import ( get_apply_configurator_class, @@ -22,7 +24,7 @@ def _register(self): metavar="FILE", help="The path to the configuration file. Defaults to [code]$PWD/.dstack.yml[/]", dest="configuration_file", - ) + ).completer = FilesCompleter(allowednames=["*.yml", "*.yaml"]) self._parser.add_argument( "-y", "--yes", diff --git a/src/dstack/_internal/cli/commands/fleet.py b/src/dstack/_internal/cli/commands/fleet.py index a3b1d8a3b..244563e12 100644 --- a/src/dstack/_internal/cli/commands/fleet.py +++ b/src/dstack/_internal/cli/commands/fleet.py @@ -4,6 +4,7 @@ from rich.live import Live from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import FleetNameCompleter from dstack._internal.cli.utils.common import ( LIVE_TABLE_PROVISION_INTERVAL_SECS, LIVE_TABLE_REFRESH_RATE_PER_SEC, @@ -47,7 +48,7 @@ def _register(self): delete_parser.add_argument( "name", help="The name of the fleet", - ) + ).completer = FleetNameCompleter() delete_parser.add_argument( "-i", "--instance", diff --git a/src/dstack/_internal/cli/commands/gateway.py b/src/dstack/_internal/cli/commands/gateway.py index 195e08308..aa9d28958 100644 --- a/src/dstack/_internal/cli/commands/gateway.py +++ b/src/dstack/_internal/cli/commands/gateway.py @@ -4,6 +4,7 @@ from rich.live import Live from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import GatewayNameCompleter from dstack._internal.cli.utils.common import ( LIVE_TABLE_PROVISION_INTERVAL_SECS, LIVE_TABLE_REFRESH_RATE_PER_SEC, @@ -59,7 +60,9 @@ def _register(self): "delete", help="Delete a gateway", formatter_class=self._parser.formatter_class ) delete_parser.set_defaults(subfunc=self._delete) - delete_parser.add_argument("name", help="The name of the gateway") + delete_parser.add_argument( + "name", help="The name of the gateway" + ).completer = GatewayNameCompleter() delete_parser.add_argument( "-y", "--yes", action="store_true", help="Don't ask for confirmation" ) @@ -68,7 +71,9 @@ def _register(self): "update", help="Update a gateway", formatter_class=self._parser.formatter_class ) update_parser.set_defaults(subfunc=self._update) - update_parser.add_argument("name", help="The name of the gateway") + update_parser.add_argument( + "name", help="The name of the gateway" + ).completer = GatewayNameCompleter() update_parser.add_argument( "--set-default", action="store_true", help="Set it the default gateway for the project" ) diff --git a/src/dstack/_internal/cli/commands/logs.py b/src/dstack/_internal/cli/commands/logs.py index 369e5c507..ffcf55101 100644 --- a/src/dstack/_internal/cli/commands/logs.py +++ b/src/dstack/_internal/cli/commands/logs.py @@ -3,6 +3,7 @@ from pathlib import Path from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import RunNameCompleter from dstack._internal.core.errors import CLIError from dstack._internal.utils.logging import get_logger @@ -33,7 +34,7 @@ def _register(self): ) self._parser.add_argument( "--replica", - help="The relica number. Defaults to 0.", + help="The replica number. Defaults to 0.", type=int, default=0, ) @@ -43,7 +44,7 @@ def _register(self): type=int, default=0, ) - self._parser.add_argument("run_name") + self._parser.add_argument("run_name").completer = RunNameCompleter(all=True) def _command(self, args: argparse.Namespace): super()._command(args) diff --git a/src/dstack/_internal/cli/commands/stats.py b/src/dstack/_internal/cli/commands/stats.py index 20c3bdeff..12f52346b 100644 --- a/src/dstack/_internal/cli/commands/stats.py +++ b/src/dstack/_internal/cli/commands/stats.py @@ -7,6 +7,7 @@ from rich.table import Table from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import RunNameCompleter from dstack._internal.cli.utils.common import ( LIVE_TABLE_PROVISION_INTERVAL_SECS, LIVE_TABLE_REFRESH_RATE_PER_SEC, @@ -25,7 +26,7 @@ class StatsCommand(APIBaseCommand): def _register(self): super()._register() - self._parser.add_argument("run_name") + self._parser.add_argument("run_name").completer = RunNameCompleter() self._parser.add_argument( "-w", "--watch", diff --git a/src/dstack/_internal/cli/commands/stop.py b/src/dstack/_internal/cli/commands/stop.py index ad08228b3..dced8d77e 100644 --- a/src/dstack/_internal/cli/commands/stop.py +++ b/src/dstack/_internal/cli/commands/stop.py @@ -1,6 +1,7 @@ import argparse from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import RunNameCompleter from dstack._internal.cli.utils.common import confirm_ask from dstack._internal.core.errors import CLIError @@ -13,7 +14,7 @@ def _register(self): super()._register() self._parser.add_argument("-x", "--abort", action="store_true") self._parser.add_argument("-y", "--yes", action="store_true") - self._parser.add_argument("run_name") + self._parser.add_argument("run_name").completer = RunNameCompleter() def _command(self, args: argparse.Namespace): super()._command(args) diff --git a/src/dstack/_internal/cli/commands/volume.py b/src/dstack/_internal/cli/commands/volume.py index a4c10b342..5f7bc833e 100644 --- a/src/dstack/_internal/cli/commands/volume.py +++ b/src/dstack/_internal/cli/commands/volume.py @@ -4,6 +4,7 @@ from rich.live import Live from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import VolumeNameCompleter from dstack._internal.cli.utils.common import ( LIVE_TABLE_PROVISION_INTERVAL_SECS, LIVE_TABLE_REFRESH_RATE_PER_SEC, @@ -47,7 +48,7 @@ def _register(self): delete_parser.add_argument( "name", help="The name of the volume", - ) + ).completer = VolumeNameCompleter() delete_parser.add_argument( "-y", "--yes", help="Don't ask for confirmation", action="store_true" ) diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index 431e07862..800595b73 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -1,10 +1,12 @@ import argparse +import argcomplete from rich.markup import escape from rich_argparse import RichHelpFormatter from dstack._internal.cli.commands.apply import ApplyCommand from dstack._internal.cli.commands.attach import AttachCommand +from dstack._internal.cli.commands.completion import CompletionCommand from dstack._internal.cli.commands.config import ConfigCommand from dstack._internal.cli.commands.delete import DeleteCommand from dstack._internal.cli.commands.fleet import FleetCommand @@ -72,9 +74,13 @@ def main(): StatsCommand.register(subparsers) StopCommand.register(subparsers) VolumeCommand.register(subparsers) + CompletionCommand.register(subparsers) + + argcomplete.autocomplete(parser, always_complete_options=False) args, unknown_args = parser.parse_known_args() args.unknown = unknown_args + try: check_for_updates() get_ssh_client_info() diff --git a/src/dstack/_internal/cli/services/completion.py b/src/dstack/_internal/cli/services/completion.py new file mode 100644 index 000000000..ed8ce26ad --- /dev/null +++ b/src/dstack/_internal/cli/services/completion.py @@ -0,0 +1,86 @@ +import argparse +import os +from abc import ABC, abstractmethod +from typing import Iterable, List, Optional + +import argcomplete +from argcomplete.completers import BaseCompleter + +from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.services.configs import ConfigManager +from dstack.api import Client + + +class BaseAPINameCompleter(BaseCompleter, ABC): + """ + Base class for name completers that fetch resource names via the API. + """ + + def __init__(self): + super().__init__() + + def get_api(self, parsed_args: argparse.Namespace) -> Optional[Client]: + argcomplete.debug(f"{self.__class__.__name__}: Retrieving API client") + project = getattr(parsed_args, "project", os.getenv("DSTACK_PROJECT")) + try: + return Client.from_config(project_name=project) + except ConfigurationError as e: + argcomplete.debug(f"{self.__class__.__name__}: Error initializing API client: {e}") + return None + + def __call__(self, prefix: str, parsed_args: argparse.Namespace, **kwargs) -> List[str]: + api = self.get_api(parsed_args) + if api is None: + return [] + + argcomplete.debug(f"{self.__class__.__name__}: Fetching completions") + try: + resource_names = self.fetch_resource_names(api) + return [name for name in resource_names if name.startswith(prefix)] + except Exception as e: + argcomplete.debug( + f"{self.__class__.__name__}: Error fetching resource completions: {e}" + ) + return [] + + @abstractmethod + def fetch_resource_names(self, api: Client) -> Iterable[str]: + """ + Returns an iterable of resource names. + """ + pass + + +class RunNameCompleter(BaseAPINameCompleter): + def __init__(self, all: bool = False): + super().__init__() + self.all = all + + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.runs.list(self.all)] + + +class FleetNameCompleter(BaseAPINameCompleter): + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.client.fleets.list(api.project)] + + +class VolumeNameCompleter(BaseAPINameCompleter): + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.client.volumes.list(api.project)] + + +class GatewayNameCompleter(BaseAPINameCompleter): + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.client.gateways.list(api.project)] + + +class ProjectNameCompleter(BaseCompleter): + """ + Completer for local project names. + """ + + def __call__(self, prefix: str, parsed_args: argparse.Namespace, **kwargs) -> List[str]: + argcomplete.debug(f"{self.__class__.__name__}: Listing projects from ConfigManager") + projects = ConfigManager().list_projects() + return [p for p in projects if p.startswith(prefix)] diff --git a/src/dstack/_internal/core/services/configs/__init__.py b/src/dstack/_internal/core/services/configs/__init__.py index b0fa73444..f11983e38 100644 --- a/src/dstack/_internal/core/services/configs/__init__.py +++ b/src/dstack/_internal/core/services/configs/__init__.py @@ -65,6 +65,9 @@ def configure_project(self, name: str, url: str, token: str, default: bool): if len(self.config.projects) == 1: self.config.projects[0].default = True + def list_projects(self): + return [project.name for project in self.config.projects] + def delete_project(self, name: str): self.config.projects = [p for p in self.config.projects if p.name != name]