diff --git a/docs/cli.rst b/docs/cli.rst index a85af4309..0e48ce521 100644 --- a/docs/cli.rst +++ b/docs/cli.rst @@ -190,9 +190,11 @@ For instance :: ValueError("invalid literal for int() with base 10: 'foo'",) The toolkit includes two additional "types" (or rather, *validators*): ``Range`` and ``Set``. -``Range`` takes a minimal value and a maximal value and expects an integer in that range -(inclusive). ``Set`` takes a set of allowed values, and expects the argument to match one of -these values. Here's an example :: +``Range`` takes a minimal value and a maximal value and expects an integer in +that range (inclusive). ``Set`` takes a set of allowed values, and expects the +argument to match one of these values. You can set ``case_sensitive=False``, or +add ``all_markers={"*", "all"}`` if you want to have a "trigger all markers" +marker. Here's an example :: class MyApp(cli.Application): _port = 8080 diff --git a/plumbum/cli/switches.py b/plumbum/cli/switches.py index 830e0008c..692b599fb 100644 --- a/plumbum/cli/switches.py +++ b/plumbum/cli/switches.py @@ -1,6 +1,8 @@ +import collections.abc import contextlib import inspect from abc import ABC, abstractmethod +from typing import Callable, Generator, List, Union from plumbum import local from plumbum.cli.i18n import get_translation_for @@ -456,41 +458,62 @@ class MyApp(Application): comparison or not. The default is ``False`` :param csv: splits the input as a comma-separated-value before validating and returning a list. Accepts ``True``, ``False``, or a string for the separator + :param all_markers: When a user inputs any value from this set, all values are iterated + over. Something like {"*", "all"} would be a potential setting for + this option. """ - def __init__(self, *values, **kwargs): - self.case_sensitive = kwargs.pop("case_sensitive", False) - self.csv = kwargs.pop("csv", False) - if self.csv is True: - self.csv = "," - if kwargs: - raise TypeError( - _("got unexpected keyword argument(s): {0}").format(kwargs.keys()) - ) + def __init__( + self, + *values: Union[str, Callable[[str], str]], + case_sensitive: bool = False, + csv: Union[bool, str] = False, + all_markers: "collections.abc.Set[str]" = frozenset(), + ) -> None: + self.case_sensitive = case_sensitive + if isinstance(csv, bool): + self.csv = "," if csv else "" + else: + self.csv = csv self.values = values + self.all_markers = all_markers def __repr__(self): items = ", ".join(v if isinstance(v, str) else v.__name__ for v in self.values) return f"{{{items}}}" - def __call__(self, value, check_csv=True): + def _call_iter( + self, value: str, check_csv: bool = True + ) -> Generator[str, None, None]: if self.csv and check_csv: - return [self(v.strip(), check_csv=False) for v in value.split(",")] + for v in value.split(self.csv): + yield from self._call_iter(v.strip(), check_csv=False) + if not self.case_sensitive: value = value.lower() + for opt in self.values: if isinstance(opt, str): if not self.case_sensitive: opt = opt.lower() - if opt == value: - return opt # always return original value + if opt == value or value in self.all_markers: + yield opt # always return original value continue with contextlib.suppress(ValueError): - return opt(value) - raise ValueError(f"Invalid value: {value} (Expected one of {self.values})") + yield opt(value) + + def __call__(self, value: str, check_csv: bool = True) -> Union[str, List[str]]: + items = list(self._call_iter(value, check_csv)) + if not items: + msg = f"Invalid value: {value} (Expected one of {self.values})" + raise ValueError(msg) + if self.csv and check_csv or len(items) > 1: + return items + return items[0] def choices(self, partial=""): choices = {opt if isinstance(opt, str) else f"({opt})" for opt in self.values} + choices |= self.all_markers if partial: choices = {opt for opt in choices if opt.lower().startswith(partial)} return choices diff --git a/pyproject.toml b/pyproject.toml index 27ddb9908..bf9042f4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,4 +105,5 @@ messages_control.disable = [ "too-many-statements", "unidiomatic-typecheck", # TODO: might be able to remove "unnecessary-lambda-assignment", # TODO: 4 instances + "unused-import", # identical to flake8 but has typing false positives ] diff --git a/tests/test_cli.py b/tests/test_cli.py index b7a5e5aa2..c2512a419 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -27,7 +27,9 @@ def bacon(self, param): text wrapping in help messages as well""", ) - csv = cli.SwitchAttr(["--csv"], cli.Set("MIN", "MAX", int, csv=True)) + csv = cli.SwitchAttr( + ["--csv"], cli.Set("MIN", "MAX", int, csv=True, all_markers={"all"}) + ) num = cli.SwitchAttr(["--num"], cli.Set("MIN", "MAX", int)) def main(self, *args): @@ -36,6 +38,8 @@ def main(self, *args): self.eggs = old self.tailargs = args + print(self.csv) + class PositionalApp(cli.Application): def main(self, one): @@ -163,7 +167,7 @@ def test_meta_switches(self): _, rc = SimpleApp.run(["foo", "--version"], exit=False) assert rc == 0 - def test_okay(self): + def test_okay(self, capsys): _, rc = SimpleApp.run(["foo", "--bacon=81"], exit=False) assert rc == 0 @@ -195,6 +199,14 @@ def test_okay(self): _, rc = SimpleApp.run(["foo", "--bacon=81", "--num=100"], exit=False) assert rc == 0 + capsys.readouterr() + _, rc = SimpleApp.run(["foo", "--bacon=81", "--csv=all,100"], exit=False) + assert rc == 0 + output = capsys.readouterr() + assert "min" in output.out + assert "max" in output.out + assert "100" in output.out + _, rc = SimpleApp.run(["foo", "--bacon=81", "--num=MAX"], exit=False) assert rc == 0