Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround for Typer optional default values with Python calls #10788

Merged
merged 17 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
dd8c03c
Workaround for Typer optional default values with Python calls: added…
rmitsch May 12, 2022
fe0730f
@rmitsch Workaround for Typer optional default values with Python cal…
rmitsch May 18, 2022
457a5c9
@rmitsch Workaround for Typer optional default values with Python cal…
rmitsch May 18, 2022
ef66c3d
Workaround for Typer optional default values with Python calls: fixed…
rmitsch May 18, 2022
9021176
Workaround for Typer optional default values with Python calls: remov…
rmitsch May 23, 2022
e19a14b
Workaround for Typer optional default values with Python calls: remov…
rmitsch May 23, 2022
e2555e5
Workaround for Typer optional default values with Python calls: fixed…
rmitsch May 23, 2022
5c7eed2
Workaround for Typer optional default values with Python calls: remo…
rmitsch May 24, 2022
a569674
Merge branch 'master' into fix/typer-option-default-values
rmitsch May 24, 2022
2abe11e
Workaround for Typer optional default values with Python calls: renam…
rmitsch May 25, 2022
596fcf0
Merge branch 'fix/typer-option-default-values' of github.com:rmitsch/…
rmitsch May 25, 2022
527bee1
Workaround for Typer optional default values with Python calls: remov…
rmitsch May 25, 2022
da5ae2e
remove introduced newlines
svlandeg Jun 9, 2022
19a4d7a
Remove test_init_config_from_python_without_optional_args().
rmitsch Jun 9, 2022
16f8944
remove leftover import
svlandeg Jun 10, 2022
58047f8
reformat import
svlandeg Jun 10, 2022
64db12e
remove duplicate
svlandeg Jun 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions spacy/cli/init_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, List, Tuple
from enum import Enum
from pathlib import Path

svlandeg marked this conversation as resolved.
Show resolved Hide resolved
from wasabi import Printer, diff_strings
from thinc.api import Config
import srsly
Expand All @@ -12,7 +13,7 @@
from ..schemas import RecommendationSchema
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
from ._util import string_to_list, import_code
svlandeg marked this conversation as resolved.
Show resolved Hide resolved

from ..util import SimpleFrozenList
svlandeg marked this conversation as resolved.
Show resolved Hide resolved

ROOT = Path(__file__).parent / "templates"
TEMPLATE_PATH = ROOT / "quickstart_training.jinja"
Expand All @@ -24,16 +25,31 @@ class Optimizations(str, Enum):
accuracy = "accuracy"


class InitDefaultValues:
    """
    Shared default values for config initialization. Keeping them in one dedicated class lets init_config_cli()
    (initialization via the CLI) and init_config() (initialization via Python) use the same defaults.
    """

    # No usable default: the Ellipsis is a placeholder, mirroring the Ellipsis passed to Arg() for this argument.
    output_file: Path = ...  # type: ignore
    # Language code of the pipeline.
    lang = "en"
    # Trainable pipeline components, stored as a frozen list.
    pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
    # Optimization target (member of the Optimizations enum).
    optimize = Optimizations.efficiency
    # Whether the model can run on GPU.
    gpu = False
    # Whether to include config for pretraining.
    pretraining = False
    # Whether an existing output file may be overwritten.
    force_overwrite = False


@init_cli.command("config")
def init_config_cli(
# fmt: off
output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
lang: str = Opt("en", "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: str = Opt("tagger,parser,ner", "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(Optimizations.efficiency.value, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(False, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
force_overwrite: bool = Opt(False, "--force", "-F", help="Force overwriting the output file"),
lang: str = Opt(InitDefaultValues.lang, "--lang", "-l", help="Two-letter code of the language to use"),
pipeline: str = Opt(",".join(InitDefaultValues.pipeline), "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
optimize: Optimizations = Opt(InitDefaultValues.optimize, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
gpu: bool = Opt(InitDefaultValues.gpu, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
pretraining: bool = Opt(InitDefaultValues.pretraining, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
force_overwrite: bool = Opt(InitDefaultValues.force_overwrite, "--force", "-F", help="Force overwriting the output file"),
# fmt: on
):
"""
Expand All @@ -44,6 +60,7 @@ def init_config_cli(

DOCS: https://spacy.io/api/cli#init-config
"""

rmitsch marked this conversation as resolved.
Show resolved Hide resolved
pipeline = string_to_list(pipeline)
is_stdout = str(output_file) == "-"
if not is_stdout and output_file.exists() and not force_overwrite:
Expand All @@ -52,6 +69,7 @@ def init_config_cli(
"The provided output file already exists. To force overwriting the config file, set the --force or -F flag.",
exits=1,
)

svlandeg marked this conversation as resolved.
Show resolved Hide resolved
config = init_config(
lang=lang,
pipeline=pipeline,
Expand Down Expand Up @@ -133,11 +151,11 @@ def fill_config(

def init_config(
*,
lang: str,
pipeline: List[str],
optimize: str,
gpu: bool,
pretraining: bool = False,
lang: str = InitDefaultValues.lang,
pipeline: List[str] = InitDefaultValues.pipeline,
optimize: str = InitDefaultValues.optimize,
gpu: bool = InitDefaultValues.gpu,
pretraining: bool = InitDefaultValues.pretraining,
silent: bool = True,
) -> Config:
msg = Printer(no_print=silent)
Expand Down
24 changes: 24 additions & 0 deletions spacy/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.init_config import init_config_cli, Optimizations
svlandeg marked this conversation as resolved.
Show resolved Hide resolved
from spacy.cli.package import get_third_party_dependencies
from spacy.cli.package import _is_permitted_package_name
from spacy.cli.validate import get_model_pkgs
Expand Down Expand Up @@ -740,3 +741,26 @@ def test_debug_data_compile_gold():
eg = Example(pred, ref)
data = _compile_gold([eg], ["ner"], nlp, True)
assert data["boundary_cross_ents"] == 1


@pytest.mark.issue(10727)
def test_init_config_from_python_without_optional_args():
    """
    Calls init_config_cli() directly from Python, once with only the output file and once with every optional
    argument spelled out. This should detect whether Typer automatically sets default values for arguments when
    decorated functions are called from Python instead of from the CLI.
    """
    # Omitting the optional arguments is expected to raise an AttributeError.
    with pytest.raises(AttributeError):
        with make_tempdir() as tmp_root:
            init_config_cli(output_file=tmp_root / "config.cfg")

    # Passing every optional argument explicitly must run without raising.
    with make_tempdir() as tmp_root:
        explicit_options = {
            "lang": "en",
            "pipeline": "ner",
            "optimize": Optimizations.efficiency,
            "gpu": False,
            "pretraining": False,
        }
        init_config_cli(output_file=tmp_root / "config.cfg", **explicit_options)