diff --git a/.cspell.json b/.cspell.json index 72da7eb2..dfb23319 100644 --- a/.cspell.json +++ b/.cspell.json @@ -71,6 +71,7 @@ "indentless", "ipynb", "jsonschema", + "jupyterlab", "linkcheck", "maxdepth", "maxsplit", diff --git a/src/repoma/check_dev_files/__init__.py b/src/repoma/check_dev_files/__init__.py index 1cfdb2aa..25f3748b 100644 --- a/src/repoma/check_dev_files/__init__.py +++ b/src/repoma/check_dev_files/__init__.py @@ -16,6 +16,7 @@ github_labels, github_workflows, gitpod, + jupyter, mypy, nbstripout, precommit, @@ -58,6 +59,8 @@ def main(argv: Optional[Sequence[str]] = None) -> int: skip_tests=_to_list(args.ci_skipped_tests), test_extras=_to_list(args.ci_test_extras), ) + if has_notebooks: + executor(jupyter.main) executor(nbstripout.main) executor(toml.main) # has to run before pre-commit executor(prettier.main, args.no_prettierrc) diff --git a/src/repoma/check_dev_files/jupyter.py b/src/repoma/check_dev_files/jupyter.py new file mode 100644 index 00000000..9fa4dd93 --- /dev/null +++ b/src/repoma/check_dev_files/jupyter.py @@ -0,0 +1,26 @@ +"""Update the developer setup when using Jupyter notebooks.""" + +from repoma.utilities.executor import Executor +from repoma.utilities.project_info import get_supported_python_versions +from repoma.utilities.pyproject import add_dependency + + +def main() -> None: + _update_dev_requirements() + + +def _update_dev_requirements() -> None: + if "3.6" in get_supported_python_versions(): + return + hierarchy = ["jupyter", "dev"] + dependencies = [ + "jupyterlab", + "jupyterlab-code-formatter", + "jupyterlab-lsp", + "jupyterlab-myst", + "python-lsp-server[rope]", + ] + executor = Executor() + for dependency in dependencies: + executor(add_dependency, dependency, optional_key=hierarchy) + executor.finalize() diff --git a/src/repoma/check_dev_files/ruff.py b/src/repoma/check_dev_files/ruff.py index c966ceaa..65cee492 100644 --- a/src/repoma/check_dev_files/ruff.py +++ b/src/repoma/check_dev_files/ruff.py @@ -1,7 +1,6 @@ """Check `Ruff `_ configuration.""" import os -from copy import deepcopy from textwrap import dedent from typing import List, Set @@ -27,6 +26,7 @@ open_setup_cfg, ) from repoma.utilities.pyproject import ( + add_dependency, complies_with_subset, get_sub_table, load_pyproject, @@ -230,39 +230,12 @@ def _update_pyproject() -> None: f" [{CONFIG_PATH.pyproject}]" ) raise PrecommitError(msg) - project = get_sub_table(pyproject, "project", create=True) - old_dependencies = project.get("optional-dependencies") - new_dependencies = deepcopy(old_dependencies) python_versions = project_info.supported_python_versions if python_versions is not None and "3.6" in python_versions: ruff = 'ruff; python_version >="3.7.0"' else: ruff = "ruff" - if new_dependencies is None: - new_dependencies = dict( - dev=[f"{package}[sty]"], - lint=[ruff], - sty=[f"{package}[lint]"], - ) - else: - __add_package(new_dependencies, "dev", f"{package}[sty]") - __add_package(new_dependencies, "lint", ruff) - __add_package(new_dependencies, "sty", f"{package}[lint]") - if old_dependencies != new_dependencies: - project["optional-dependencies"] = new_dependencies - write_pyproject(pyproject) - msg = f"Updated [project.optional-dependencies] in {CONFIG_PATH.pyproject}" - raise PrecommitError(msg) - - -def __add_package(optional_dependencies: Table, key: str, package: str) -> None: - section = optional_dependencies.get(key) - if section is None: - optional_dependencies[key] = [package] - elif package not in section: - optional_dependencies[key] = to_toml_array( - sorted({package, *section}, key=lambda s: ('"' in s, s)) # Taplo sorting - ) + add_dependency(ruff, optional_key=["lint", "sty", "dev"]) def _remove_nbqa() -> None: diff --git a/src/repoma/utilities/project_info.py b/src/repoma/utilities/project_info.py index 66ecf712..4729a3ab 100644 --- a/src/repoma/utilities/project_info.py +++ b/src/repoma/utilities/project_info.py @@ -1,6 +1,7 @@ """Helper functions for reading from and writing to :file:`setup.cfg`.""" import os +import sys from configparser import ConfigParser from textwrap import dedent from typing import Dict, List, Optional @@ -14,11 +15,19 @@ from . import CONFIG_PATH from .cfg import open_config +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + +PythonVersion = Literal["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + @frozen class ProjectInfo: name: Optional[str] = None - supported_python_versions: Optional[List[str]] = None + supported_python_versions: Optional[List[PythonVersion]] = None urls: Dict[str, str] = field(factory=dict) def is_empty(self) -> bool: @@ -80,13 +89,13 @@ def get_project_info(pyproject: Optional[TOMLDocument] = None) -> ProjectInfo: raise PrecommitError(msg) -def _extract_python_versions(classifiers: List[str]) -> Optional[List[str]]: +def _extract_python_versions(classifiers: List[str]) -> Optional[List[PythonVersion]]: identifier = "Programming Language :: Python :: 3." version_classifiers = [s for s in classifiers if s.startswith(identifier)] if not version_classifiers: return None prefix = identifier[:-2] - return [s.replace(prefix, "") for s in version_classifiers] + return [s.replace(prefix, "") for s in version_classifiers] # type: ignore[misc] def get_pypi_name(pyproject: Optional[TOMLDocument] = None) -> str: @@ -107,7 +116,7 @@ def get_pypi_name(pyproject: Optional[TOMLDocument] = None) -> str: def get_supported_python_versions( pyproject: Optional[TOMLDocument] = None, -) -> List[str]: +) -> List[PythonVersion]: """Extract supported Python versions from package classifiers. >>> get_supported_python_versions() diff --git a/src/repoma/utilities/pyproject.py b/src/repoma/utilities/pyproject.py index 488272b1..1a3431ea 100644 --- a/src/repoma/utilities/pyproject.py +++ b/src/repoma/utilities/pyproject.py @@ -1,7 +1,9 @@ """Tools for loading, inspecting, and updating :code:`pyproject.toml`.""" -import os -from typing import Any, Iterable, Optional +import io +from collections import abc +from pathlib import Path +from typing import IO, Any, Iterable, List, Optional, Sequence, Set, Union import tomlkit from tomlkit.container import Container @@ -10,20 +12,115 @@ from repoma.errors import PrecommitError from repoma.utilities import CONFIG_PATH +from repoma.utilities.executor import Executor from repoma.utilities.precommit import find_repo, load_round_trip_precommit_config +def add_dependency( # noqa: C901, PLR0912 + package: str, + optional_key: Optional[Union[str, Sequence[str]]] = None, + source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject, + target: Optional[Union[IO, Path, str]] = None, +) -> None: + if isinstance(source, TOMLDocument): + pyproject = source + else: + pyproject = load_pyproject(source) + if target is None: + if isinstance(source, TOMLDocument): + msg = "If the source is a TOML document, you have to specify a target" + raise TypeError(msg) + target = source + if optional_key is None: + project = get_sub_table(pyproject, "project", create=True) + existing_dependencies: Set[str] = set(project.get("dependencies", [])) + if package in existing_dependencies: + return + existing_dependencies.add(package) + project["dependencies"] = to_toml_array(_sort_taplo(existing_dependencies)) + elif isinstance(optional_key, str): + optional_dependencies = get_sub_table( + pyproject, "project.optional-dependencies", create=True + ) + existing_dependencies = set(optional_dependencies.get(optional_key, [])) + if package in existing_dependencies: + return + existing_dependencies.add(package) + existing_dependencies = set(existing_dependencies) + optional_dependencies[optional_key] = to_toml_array( + _sort_taplo(existing_dependencies) + ) + elif isinstance(optional_key, abc.Sequence): + if len(optional_key) < 2: # noqa: PLR2004 + msg = "Need at least two keys to define nested optional dependencies" + raise ValueError(msg) + this_package = get_package_name_safe(pyproject) + executor = Executor() + for key, previous in zip(optional_key, [None, *optional_key]): + if previous is None: + executor(add_dependency, package, key, source, target) + else: + executor( + add_dependency, f"{this_package}[{previous}]", key, source, target + ) + if executor.finalize() == 0: + return + else: + msg = f"Unsupported type for optional_key: {type(optional_key)}" + raise NotImplementedError(msg) + write_pyproject(pyproject, target) + msg = f"Listed {package} as a dependency under {CONFIG_PATH.pyproject}" + raise PrecommitError(msg) + + +def _sort_taplo(items: Iterable[str]) -> List[str]: + return sorted(items, key=lambda s: ('"' in s, s)) + + def complies_with_subset(settings: dict, minimal_settings: dict) -> bool: return all(settings.get(key) == value for key, value in minimal_settings.items()) -def load_pyproject(content: Optional[str] = None) -> TOMLDocument: - if not os.path.exists(CONFIG_PATH.pyproject): - return TOMLDocument() - if content is None: - with open(CONFIG_PATH.pyproject) as stream: - return tomlkit.loads(stream.read()) - return tomlkit.loads(content) +def load_pyproject( + source: Union[IO, Path, str] = CONFIG_PATH.pyproject +) -> TOMLDocument: + if isinstance(source, io.IOBase): + source.seek(0) + return tomlkit.load(source) + if isinstance(source, Path): + with open(source) as stream: + return load_pyproject(stream) + if isinstance(source, str): + return tomlkit.loads(source) + msg = f"Source of type {type(source).__name__} is not supported" + raise TypeError(msg) + + +def get_package_name( + source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject +) -> Optional[str]: + if isinstance(source, TOMLDocument): + pyproject = source + else: + pyproject = load_pyproject(source) + project = get_sub_table(pyproject, "project", create=True) + package_name = project.get("name") + if package_name is None: + return None + return package_name + + +def get_package_name_safe( + source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject +) -> str: + package_name = get_package_name(source) + if package_name is None: + msg = ( + "Please provide a name for the package under the [project] table in" + f" {CONFIG_PATH.pyproject}" + ) + raise PrecommitError(msg) + return package_name def get_sub_table(config: Container, dotted_header: str, create: bool = False) -> Table: @@ -40,10 +137,20 @@ def get_sub_table(config: Container, dotted_header: str, create: bool = False) - return current_table -def write_pyproject(config: TOMLDocument) -> None: - src = tomlkit.dumps(config, sort_keys=True) - with open(CONFIG_PATH.pyproject, "w") as stream: - stream.write(src) +def write_pyproject( + config: TOMLDocument, target: Union[IO, Path, str] = CONFIG_PATH.pyproject +) -> None: + if isinstance(target, io.IOBase): + target.seek(0) + tomlkit.dump(config, target, sort_keys=True) + elif isinstance(target, (Path, str)): + src = tomlkit.dumps(config, sort_keys=True) + src = f"{src.strip()}\n" + with open(target, "w") as stream: + stream.write(src) + else: + msg = f"Target of type {type(target).__name__} is not supported" + raise TypeError(msg) def to_toml_array(items: Iterable[Any], enforce_multiline: bool = False) -> Array: diff --git a/tests/utilities/test_pyproject.py b/tests/utilities/test_pyproject.py new file mode 100644 index 00000000..48656af0 --- /dev/null +++ b/tests/utilities/test_pyproject.py @@ -0,0 +1,172 @@ +import io +from pathlib import Path +from textwrap import dedent, indent +from typing import Optional + +import pytest +from tomlkit.items import Table + +from repoma.errors import PrecommitError +from repoma.utilities.pyproject import ( + add_dependency, + get_package_name_safe, + get_sub_table, + load_pyproject, + to_toml_array, + write_pyproject, +) + +REPOMA_DIR = Path(__file__).absolute().parent.parent.parent + + +def test_add_dependency(): + stream = io.StringIO(dedent(""" + [project] + name = "my-package" + """)) + stream.seek(0) + dependency = "attrs" + with pytest.raises( + PrecommitError, + match=f"Listed {dependency} as a dependency under pyproject.toml", + ): + add_dependency(dependency, source=stream) + result = stream.getvalue() + print(result) # noqa: T201 # run with pytest -s + assert result == dedent(""" + [project] + name = "my-package" + dependencies = ["attrs"] + """) + + +def test_add_dependency_nested(): + stream = io.StringIO(dedent(""" + [project] + name = "my-package" + """)) + stream.seek(0) + with pytest.raises(PrecommitError): + add_dependency("ruff", optional_key=["lint", "sty", "dev"], source=stream) + result = stream.getvalue() + print(result) # noqa: T201 # run with pytest -s + assert result == dedent(""" + [project] + name = "my-package" + + [project.optional-dependencies] + lint = ["ruff"] + sty = ["my-package[lint]"] + dev = ["my-package[sty]"] + """) + + +def test_add_dependency_optional(): + stream = io.StringIO(dedent(""" + [project] + name = "my-package" + """)) + stream.seek(0) + with pytest.raises(PrecommitError): + add_dependency("ruff", optional_key="lint", source=stream) + result = stream.getvalue() + print(result) # noqa: T201 # run with pytest -s + assert result == dedent(""" + [project] + name = "my-package" + + [project.optional-dependencies] + lint = ["ruff"] + """) + + +def test_edit_toml(): + src = dedent(""" + [owner] + name = "John Smith" + age = 30 + + [owner.address] + city = "Wonderland" + street = "123 Main St" + """) + config = load_pyproject(src) + + address = get_sub_table(config, "owner.address") + address["city"] = "New York" + work = get_sub_table(config, "owner.work", create=True) + work["type"] = "scientist" + tools = get_sub_table(config, "tool", create=True) + tools["black"] = to_toml_array(["--line-length=79"], enforce_multiline=True) + + stream = io.StringIO() + write_pyproject(config, target=stream) + result = stream.getvalue() + print(indent(result, prefix=4 * " ")) # noqa: T201 # run with pytest -s + assert result == dedent(""" + [owner] + name = "John Smith" + age = 30 + + [owner.address] + city = "New York" + street = "123 Main St" + + [owner.work] + type = "scientist" + + [tool] + black = [ + "--line-length=79", + ] + """) + + +def test_get_package_name_safe(): + correct_input = io.StringIO(dedent(""" + [project] + name = "my-package" + """)) + assert get_package_name_safe(correct_input) == "my-package" + + with pytest.raises(PrecommitError, match=r"^Please provide a name for the package"): + _ = get_package_name_safe(io.StringIO("[project]")) + with pytest.raises(PrecommitError, match=r"^Please provide a name for the package"): + _ = get_package_name_safe(io.StringIO()) + + +@pytest.mark.parametrize("path", [None, REPOMA_DIR / "pyproject.toml"]) +def test_load_pyproject(path: Optional[Path]): + if path is None: + pyproject = load_pyproject() + else: + pyproject = load_pyproject(path) + assert "build-system" in pyproject + assert "tool" in pyproject + + +def test_load_pyproject_str(): + src = dedent(""" + [build-system] + build-backend = "setuptools.build_meta" + requires = [ + "setuptools>=61.2", + "setuptools_scm", + ] + + [project] + dependencies = [ + "attrs", + "sympy >=1.10", + ] + name = "my-package" + requires-python = ">=3.7" + """) + pyproject = load_pyproject(src) + assert isinstance(pyproject["build-system"], Table) + assert pyproject["project"]["dependencies"] == ["attrs", "sympy >=1.10"] # type: ignore[index] + + +def test_load_pyproject_type_error(): + with pytest.raises(TypeError, match="Source of type int is not supported"): + _ = load_pyproject(1) # type: ignore[arg-type]