Skip to content

Commit

Permalink
use click to define CLI sub cmd
Browse files Browse the repository at this point in the history
  • Loading branch information
avaldebe committed Mar 20, 2023
1 parent 6fa8f23 commit 975044a
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 64 deletions.
192 changes: 132 additions & 60 deletions airbase/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
from datetime import date
from enum import Enum
from pathlib import Path
from typing import List

import click
import typer

from . import __version__
from .airbase import AirbaseClient

main = typer.Typer(
no_args_is_help=True,
add_completion=False,
)
app = typer.Typer(no_args_is_help=True, add_completion=False)
client = AirbaseClient()


Expand Down Expand Up @@ -47,7 +44,7 @@ def version_callback(value: bool):
raise typer.Exit()


@main.callback()
@app.callback()
def root_options(
version: bool = typer.Option(
False,
Expand All @@ -60,18 +57,87 @@ def root_options(
"""Download Air Quality Data from the European Environment Agency (EEA)"""


@main.command()
countries = click.option(
"--country",
"-c",
"countries",
type=click.Choice(Country), # type:ignore [arg-type]
multiple=True,
default=[],
)
country = click.argument(
"country",
type=click.Choice(Country), # type:ignore [arg-type]
)

pollutants = click.option(
"--pollutant",
"-p",
"pollutants",
type=click.Choice(Pollutant), # type:ignore [arg-type]
multiple=True,
default=[],
)
pollutant = click.argument(
"pollutant",
type=click.Choice(Pollutant), # type:ignore [arg-type]
)


path = click.option(
"--path",
default="data",
type=click.Path(exists=True, dir_okay=True, writable=True),
)
year = click.option("--year", default=date.today().year, type=int)
overwrite = click.option(
"--overwrite",
"-O",
help="Re-download existing files.",
default=False,
type=bool,
)
quiet = click.option(
"--quiet",
"-q",
help="No progress-bar.",
default=False,
type=bool,
)


def _download(
countries: list[Country],
pollutants: list[Pollutant],
path: Path,
year: int,
overwrite: bool,
quiet: bool,
):
request = client.request(
countries or None, # type:ignore[arg-type]
pollutants or None, # type:ignore[arg-type]
year_from=str(year),
year_to=str(year),
verbose=not quiet,
)
request.download_to_directory(path, skip_existing=not overwrite)


@click.command()
@countries
@pollutants
@path
@year
@overwrite
@quiet
def download(
countries: List[Country] = typer.Option([], "--country", "-c"),
pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"),
path: Path = typer.Option(
"data", exists=True, dir_okay=True, writable=True
),
year: int = typer.Option(date.today().year),
overwrite: bool = typer.Option(
False, "--overwrite", "-O", help="Re-download existing files."
),
quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."),
countries: list[Country],
pollutants: list[Pollutant],
path: Path,
year: int,
overwrite: bool,
quiet: bool,
):
"""Download all pollutants for all countries
Expand All @@ -82,15 +148,7 @@ def download(
- download only SO2, PM10 and PM2.5 observations
airbase download -p SO2 -p PM10 -p PM2.5
"""

request = client.request(
countries or None, # type:ignore[arg-type]
pollutants or None, # type:ignore[arg-type]
year_from=str(year),
year_to=str(year),
verbose=not quiet,
)
request.download_to_directory(path, skip_existing=not overwrite)
_download(countries, pollutants, path, year, overwrite, quiet)


def deprecation_message(old: str, new: str): # pragma: no cover
Expand All @@ -101,55 +159,69 @@ def deprecation_message(old: str, new: str): # pragma: no cover
)


@main.command(name="all")
@click.command()
@countries
@pollutants
@path
@year
@overwrite
@quiet
def download_all(
countries: List[Country] = typer.Option([], "--country", "-c"),
pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"),
path: Path = typer.Option(
"data", exists=True, dir_okay=True, writable=True
),
year: int = typer.Option(date.today().year),
overwrite: bool = typer.Option(
False, "--overwrite", "-O", help="Re-download existing files."
),
quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."),
countries: list[Country],
pollutants: list[Pollutant],
path: Path,
year: int,
overwrite: bool,
quiet: bool,
): # pragma: no cover
"""Download all pollutants for all countries (deprecated)"""
deprecation_message("all", "download")
download(countries, pollutants, path, year, overwrite, quiet)
_download(countries, pollutants, path, year, overwrite, quiet)


@main.command(name="country")
@click.command()
@country
@pollutants
@path
@year
@overwrite
@quiet
def download_country(
country: Country,
pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"),
path: Path = typer.Option(
"data", exists=True, dir_okay=True, writable=True
),
year: int = typer.Option(date.today().year),
overwrite: bool = typer.Option(
False, "--overwrite", "-O", help="Re-download existing files."
),
quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."),
pollutants: list[Pollutant],
path: Path,
year: int,
overwrite: bool,
quiet: bool,
): # pragma: no cover
"""Download specific pollutants for one country (deprecated)"""
deprecation_message("country", "download")
download([country], pollutants, path, year, overwrite, quiet)
_download([country], pollutants, path, year, overwrite, quiet)


@main.command(name="pollutant")
@click.command()
@pollutant
@countries
@path
@year
@overwrite
@quiet
def download_pollutant(
pollutant: Pollutant,
countries: List[Country] = typer.Option([], "--country", "-c"),
path: Path = typer.Option(
"data", exists=True, dir_okay=True, writable=True
),
year: int = typer.Option(date.today().year),
overwrite: bool = typer.Option(
False, "--overwrite", "-O", help="Re-download existing files."
),
quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."),
countries: list[Country],
path: Path,
year: int,
overwrite: bool,
quiet: bool,
): # pragma: no cover
"""Download specific countries for one pollutant (deprecated)"""
deprecation_message("pollutant", "download")
download(countries, [pollutant], path, year, overwrite, quiet)
_download(countries, [pollutant], path, year, overwrite, quiet)


# click object
main: click.Group = typer.main.get_command(app) # type:ignore [assignment]
main.add_command(download, "download")
main.add_command(download_all, "all")
main.add_command(download_country, "country")
main.add_command(download_pollutant, "pollutant")
4 changes: 2 additions & 2 deletions tests/integration/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from typer.testing import CliRunner

from airbase.cli import main
from airbase.cli import app

runner = CliRunner()


def test_download(tmp_path: Path):
country, year, pollutant, id = "NO", 2021, "NO2", 8
options = f"download --quiet --country {country} --pollutant {pollutant} --year {year} --path {tmp_path}"
result = runner.invoke(main, options.split())
result = runner.invoke(app, options.split())
assert result.exit_code == 0

files = tmp_path.glob(f"{country}_{id}_*_{year}_timeseries.csv")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typer.testing import CliRunner

from airbase import __version__
from airbase.cli import Country, Pollutant, main
from airbase.cli import Country, Pollutant, app

runner = CliRunner()

Expand All @@ -22,7 +22,7 @@ def test_pollutant(pollutant: Pollutant):

@pytest.mark.parametrize("options", ("--version", "-V"))
def test_version(options: str):
result = runner.invoke(main, options.split())
result = runner.invoke(app, options.split())
assert result.exit_code == 0
assert "airbase" in result.output
assert __version__ in result.output
Expand Down

0 comments on commit 975044a

Please sign in to comment.