From dfa2bb1bc8d8b187ea97f8ee3d2e5ce3759a4765 Mon Sep 17 00:00:00 2001 From: Alvaro Valdebenito Date: Mon, 20 Mar 2023 16:22:43 +0100 Subject: [PATCH] use click to define CLI sub cmd --- airbase/cli.py | 188 +++++++++++++++++++++++----------- tests/integration/test_cli.py | 9 +- tests/test_cli.py | 4 +- 3 files changed, 135 insertions(+), 66 deletions(-) diff --git a/airbase/cli.py b/airbase/cli.py index ca61a9a..f58b7ed 100644 --- a/airbase/cli.py +++ b/airbase/cli.py @@ -3,17 +3,14 @@ from datetime import date from enum import Enum from pathlib import Path -from typing import List +import click import typer from . import __version__ from .airbase import AirbaseClient -main = typer.Typer( - no_args_is_help=True, - add_completion=False, -) +app = typer.Typer(no_args_is_help=True, add_completion=False) client = AirbaseClient() @@ -47,7 +44,7 @@ def version_callback(value: bool): raise typer.Exit() -@main.callback() +@app.callback() def root_options( version: bool = typer.Option( False, @@ -60,18 +57,83 @@ def root_options( """Download Air Quality Data from the European Environment Agency (EEA)""" -@main.command() +countries = click.option( + "-c", + "--country", + "countries", + type=click.Choice(Country), # type:ignore [arg-type] + multiple=True, +) +country = click.argument( + "country", + type=click.Choice(Country), # type:ignore [arg-type] +) + +pollutants = click.option( + "-p", + "--pollutant", + "pollutants", + type=click.Choice(Pollutant), # type:ignore [arg-type] + multiple=True, +) +pollutant = click.argument( + "pollutant", + type=click.Choice(Pollutant), # type:ignore [arg-type] +) + + +path = click.option( + "--path", + default="data", + type=click.Path(exists=True, dir_okay=True, writable=True), +) +year = click.option("--year", default=date.today().year, type=int) +overwrite = click.option( + "-O", + "--overwrite", + is_flag=True, + help="Re-download existing files.", +) +quiet = click.option( + "-q", + "--quiet", + is_flag=True, + help="No progress-bar.", +) + + +def _download( + countries: list[Country], + pollutants: list[Pollutant], + path: Path, + year: int, + overwrite: bool, + quiet: bool, +): + request = client.request( + countries or None, # type:ignore[arg-type] + pollutants or None, # type:ignore[arg-type] + year_from=str(year), + year_to=str(year), + verbose=not quiet, + ) + request.download_to_directory(path, skip_existing=not overwrite) + + +@click.command() +@countries +@pollutants +@path +@year +@overwrite +@quiet def download( - countries: List[Country] = typer.Option([], "--country", "-c"), - pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"), - path: Path = typer.Option( - "data", exists=True, dir_okay=True, writable=True - ), - year: int = typer.Option(date.today().year), - overwrite: bool = typer.Option( - False, "--overwrite", "-O", help="Re-download existing files." - ), - quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."), + countries: list[Country], + pollutants: list[Pollutant], + path: Path, + year: int, + overwrite: bool, + quiet: bool, ): """Download all pollutants for all countries @@ -82,15 +144,7 @@ def download( - download only SO2, PM10 and PM2.5 observations airbase download -p SO2 -p PM10 -p PM2.5 """ - - request = client.request( - countries or None, # type:ignore[arg-type] - pollutants or None, # type:ignore[arg-type] - year_from=str(year), - year_to=str(year), - verbose=not quiet, - ) - request.download_to_directory(path, skip_existing=not overwrite) + _download(countries, pollutants, path, year, overwrite, quiet) def deprecation_message(old: str, new: str): # pragma: no cover @@ -101,55 +155,69 @@ def deprecation_message(old: str, new: str): # pragma: no cover ) -@main.command(name="all") +@click.command() +@countries +@pollutants +@path +@year +@overwrite +@quiet def download_all( - countries: List[Country] = typer.Option([], "--country", "-c"), - pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"), - path: Path = typer.Option( - "data", exists=True, dir_okay=True, writable=True - ), - year: int = typer.Option(date.today().year), - overwrite: bool = typer.Option( - False, "--overwrite", "-O", help="Re-download existing files." - ), - quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."), + countries: list[Country], + pollutants: list[Pollutant], + path: Path, + year: int, + overwrite: bool, + quiet: bool, ): # pragma: no cover """Download all pollutants for all countries (deprecated)""" deprecation_message("all", "download") - download(countries, pollutants, path, year, overwrite, quiet) + _download(countries, pollutants, path, year, overwrite, quiet) -@main.command(name="country") +@click.command() +@country +@pollutants +@path +@year +@overwrite +@quiet def download_country( country: Country, - pollutants: List[Pollutant] = typer.Option([], "--pollutant", "-p"), - path: Path = typer.Option( - "data", exists=True, dir_okay=True, writable=True - ), - year: int = typer.Option(date.today().year), - overwrite: bool = typer.Option( - False, "--overwrite", "-O", help="Re-download existing files." - ), - quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."), + pollutants: list[Pollutant], + path: Path, + year: int, + overwrite: bool, + quiet: bool, ): # pragma: no cover """Download specific pollutants for one country (deprecated)""" deprecation_message("country", "download") - download([country], pollutants, path, year, overwrite, quiet) + _download([country], pollutants, path, year, overwrite, quiet) -@main.command(name="pollutant") +@click.command() +@pollutant +@countries +@path +@year +@overwrite +@quiet def download_pollutant( pollutant: Pollutant, - countries: List[Country] = typer.Option([], "--country", "-c"), - path: Path = typer.Option( - "data", exists=True, dir_okay=True, writable=True - ), - year: int = typer.Option(date.today().year), - overwrite: bool = typer.Option( - False, "--overwrite", "-O", help="Re-download existing files." - ), - quiet: bool = typer.Option(False, "--quiet", "-q", help="No progress-bar."), + countries: list[Country], + path: Path, + year: int, + overwrite: bool, + quiet: bool, ): # pragma: no cover """Download specific countries for one pollutant (deprecated)""" deprecation_message("pollutant", "download") - download(countries, [pollutant], path, year, overwrite, quiet) + _download(countries, [pollutant], path, year, overwrite, quiet) + + +# click object +main: click.Group = typer.main.get_command(app) # type:ignore [assignment] +main.add_command(download, "download") +main.add_command(download_all, "all") +main.add_command(download_country, "country") +main.add_command(download_pollutant, "pollutant") diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index f23dd60..b830d04 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -1,6 +1,6 @@ from pathlib import Path -from typer.testing import CliRunner +from click.testing import CliRunner from airbase.cli import main @@ -10,8 +10,9 @@ def test_download(tmp_path: Path): country, year, pollutant, id = "NO", 2021, "NO2", 8 options = f"download --quiet --country {country} --pollutant {pollutant} --year {year} --path {tmp_path}" - result = runner.invoke(main, options.split()) - assert result.exit_code == 0 + with runner.isolated_filesystem(temp_dir=tmp_path): + result = runner.invoke(main, options.split()) + assert result.exit_code == 0 + files = tmp_path.glob(f"{country}_{id}_*_{year}_timeseries.csv") - files = tmp_path.glob(f"{country}_{id}_*_{year}_timeseries.csv") assert list(files) diff --git a/tests/test_cli.py b/tests/test_cli.py index 9fccac8..dac9d6c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,8 +1,8 @@ from __future__ import annotations import pytest +from click.testing import CliRunner from typer import Typer -from typer.testing import CliRunner from airbase import __version__ from airbase.cli import Country, Pollutant, main @@ -25,7 +25,7 @@ def test_version(options: str): result = runner.invoke(main, options.split()) assert result.exit_code == 0 assert "airbase" in result.output - assert __version__ in result.output + assert str(__version__) in result.output @pytest.mark.xfail(