Skip to content

Commit

Permalink
DX: automatically add jupyterlab requirements to notebook repos (#226)
Browse files Browse the repository at this point in the history
* ENH: add `source` argument to `load_pyproject()`
* ENH: add `target` argument for `write_pyproject()`
* ENH: remove duplicate final newlines from dumped TOML
* ENH: type Python version with `Literal`
* MAINT: extract `add_dependency()` function for `pyproject.toml`
* MAINT: test `get_package_name()` function
* MAINT: test `add_dependency()` function
  • Loading branch information
redeboer committed Nov 29, 2023
1 parent f6fb9b3 commit 45ebe3e
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 46 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"indentless",
"ipynb",
"jsonschema",
"jupyterlab",
"linkcheck",
"maxdepth",
"maxsplit",
Expand Down
3 changes: 3 additions & 0 deletions src/repoma/check_dev_files/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
github_labels,
github_workflows,
gitpod,
jupyter,
mypy,
nbstripout,
precommit,
Expand Down Expand Up @@ -58,6 +59,8 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
skip_tests=_to_list(args.ci_skipped_tests),
test_extras=_to_list(args.ci_test_extras),
)
if has_notebooks:
executor(jupyter.main)
executor(nbstripout.main)
executor(toml.main) # has to run before pre-commit
executor(prettier.main, args.no_prettierrc)
Expand Down
26 changes: 26 additions & 0 deletions src/repoma/check_dev_files/jupyter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Update the developer setup when using Jupyter notebooks."""

from repoma.utilities.executor import Executor
from repoma.utilities.project_info import get_supported_python_versions
from repoma.utilities.pyproject import add_dependency


def main() -> None:
_update_dev_requirements()


def _update_dev_requirements() -> None:
if "3.6" in get_supported_python_versions():
return
hierarchy = ["jupyter", "dev"]
dependencies = [
"jupyterlab",
"jupyterlab-code-formatter",
"jupyterlab-lsp",
"jupyterlab-myst",
"python-lsp-server[rope]",
]
executor = Executor()
for dependency in dependencies:
executor(add_dependency, dependency, optional_key=hierarchy)
executor.finalize()
31 changes: 2 additions & 29 deletions src/repoma/check_dev_files/ruff.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Check `Ruff <https://ruff.rs>`_ configuration."""

import os
from copy import deepcopy
from textwrap import dedent
from typing import List, Set

Expand All @@ -27,6 +26,7 @@
open_setup_cfg,
)
from repoma.utilities.pyproject import (
add_dependency,
complies_with_subset,
get_sub_table,
load_pyproject,
Expand Down Expand Up @@ -230,39 +230,12 @@ def _update_pyproject() -> None:
f" [{CONFIG_PATH.pyproject}]"
)
raise PrecommitError(msg)
project = get_sub_table(pyproject, "project", create=True)
old_dependencies = project.get("optional-dependencies")
new_dependencies = deepcopy(old_dependencies)
python_versions = project_info.supported_python_versions
if python_versions is not None and "3.6" in python_versions:
ruff = 'ruff; python_version >="3.7.0"'
else:
ruff = "ruff"
if new_dependencies is None:
new_dependencies = dict(
dev=[f"{package}[sty]"],
lint=[ruff],
sty=[f"{package}[lint]"],
)
else:
__add_package(new_dependencies, "dev", f"{package}[sty]")
__add_package(new_dependencies, "lint", ruff)
__add_package(new_dependencies, "sty", f"{package}[lint]")
if old_dependencies != new_dependencies:
project["optional-dependencies"] = new_dependencies
write_pyproject(pyproject)
msg = f"Updated [project.optional-dependencies] in {CONFIG_PATH.pyproject}"
raise PrecommitError(msg)


def __add_package(optional_dependencies: Table, key: str, package: str) -> None:
section = optional_dependencies.get(key)
if section is None:
optional_dependencies[key] = [package]
elif package not in section:
optional_dependencies[key] = to_toml_array(
sorted({package, *section}, key=lambda s: ('"' in s, s)) # Taplo sorting
)
add_dependency(ruff, optional_key=["lint", "sty", "dev"])


def _remove_nbqa() -> None:
Expand Down
17 changes: 13 additions & 4 deletions src/repoma/utilities/project_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Helper functions for reading from and writing to :file:`setup.cfg`."""

import os
import sys
from configparser import ConfigParser
from textwrap import dedent
from typing import Dict, List, Optional
Expand All @@ -14,11 +15,19 @@
from . import CONFIG_PATH
from .cfg import open_config

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


PythonVersion = Literal["3.6", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]


@frozen
class ProjectInfo:
name: Optional[str] = None
supported_python_versions: Optional[List[str]] = None
supported_python_versions: Optional[List[PythonVersion]] = None
urls: Dict[str, str] = field(factory=dict)

def is_empty(self) -> bool:
Expand Down Expand Up @@ -80,13 +89,13 @@ def get_project_info(pyproject: Optional[TOMLDocument] = None) -> ProjectInfo:
raise PrecommitError(msg)


def _extract_python_versions(classifiers: List[str]) -> Optional[List[str]]:
def _extract_python_versions(classifiers: List[str]) -> Optional[List[PythonVersion]]:
identifier = "Programming Language :: Python :: 3."
version_classifiers = [s for s in classifiers if s.startswith(identifier)]
if not version_classifiers:
return None
prefix = identifier[:-2]
return [s.replace(prefix, "") for s in version_classifiers]
return [s.replace(prefix, "") for s in version_classifiers] # type: ignore[misc]


def get_pypi_name(pyproject: Optional[TOMLDocument] = None) -> str:
Expand All @@ -107,7 +116,7 @@ def get_pypi_name(pyproject: Optional[TOMLDocument] = None) -> str:

def get_supported_python_versions(
pyproject: Optional[TOMLDocument] = None,
) -> List[str]:
) -> List[PythonVersion]:
"""Extract supported Python versions from package classifiers.
>>> get_supported_python_versions()
Expand Down
133 changes: 120 additions & 13 deletions src/repoma/utilities/pyproject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tools for loading, inspecting, and updating :code:`pyproject.toml`."""

import os
from typing import Any, Iterable, Optional
import io
from collections import abc
from pathlib import Path
from typing import IO, Any, Iterable, List, Optional, Sequence, Set, Union

import tomlkit
from tomlkit.container import Container
Expand All @@ -10,20 +12,115 @@

from repoma.errors import PrecommitError
from repoma.utilities import CONFIG_PATH
from repoma.utilities.executor import Executor
from repoma.utilities.precommit import find_repo, load_round_trip_precommit_config


def add_dependency( # noqa: C901, PLR0912
package: str,
optional_key: Optional[Union[str, Sequence[str]]] = None,
source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject,
target: Optional[Union[IO, Path, str]] = None,
) -> None:
if isinstance(source, TOMLDocument):
pyproject = source
else:
pyproject = load_pyproject(source)
if target is None:
if isinstance(source, TOMLDocument):
msg = "If the source is a TOML document, you have to specify a target"
raise TypeError(msg)
target = source
if optional_key is None:
project = get_sub_table(pyproject, "project", create=True)
existing_dependencies: Set[str] = set(project.get("dependencies", []))
if package in existing_dependencies:
return
existing_dependencies.add(package)
project["dependencies"] = to_toml_array(_sort_taplo(existing_dependencies))
elif isinstance(optional_key, str):
optional_dependencies = get_sub_table(
pyproject, "project.optional-dependencies", create=True
)
existing_dependencies = set(optional_dependencies.get(optional_key, []))
if package in existing_dependencies:
return
existing_dependencies.add(package)
existing_dependencies = set(existing_dependencies)
optional_dependencies[optional_key] = to_toml_array(
_sort_taplo(existing_dependencies)
)
elif isinstance(optional_key, abc.Sequence):
if len(optional_key) < 2: # noqa: PLR2004
msg = "Need at least two keys to define nested optional dependencies"
raise ValueError(msg)
this_package = get_package_name_safe(pyproject)
executor = Executor()
for key, previous in zip(optional_key, [None, *optional_key]):
if previous is None:
executor(add_dependency, package, key, source, target)
else:
executor(
add_dependency, f"{this_package}[{previous}]", key, source, target
)
if executor.finalize() == 0:
return
else:
msg = f"Unsupported type for optional_key: {type(optional_key)}"
raise NotImplementedError(msg)
write_pyproject(pyproject, target)
msg = f"Listed {package} as a dependency under {CONFIG_PATH.pyproject}"
raise PrecommitError(msg)


def _sort_taplo(items: Iterable[str]) -> List[str]:
return sorted(items, key=lambda s: ('"' in s, s))


def complies_with_subset(settings: dict, minimal_settings: dict) -> bool:
return all(settings.get(key) == value for key, value in minimal_settings.items())


def load_pyproject(content: Optional[str] = None) -> TOMLDocument:
if not os.path.exists(CONFIG_PATH.pyproject):
return TOMLDocument()
if content is None:
with open(CONFIG_PATH.pyproject) as stream:
return tomlkit.loads(stream.read())
return tomlkit.loads(content)
def load_pyproject(
source: Union[IO, Path, str] = CONFIG_PATH.pyproject
) -> TOMLDocument:
if isinstance(source, io.IOBase):
source.seek(0)
return tomlkit.load(source)
if isinstance(source, Path):
with open(source) as stream:
return load_pyproject(stream)
if isinstance(source, str):
return tomlkit.loads(source)
msg = f"Source of type {type(source).__name__} is not supported"
raise TypeError(msg)


def get_package_name(
source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject
) -> Optional[str]:
if isinstance(source, TOMLDocument):
pyproject = source
else:
pyproject = load_pyproject(source)
project = get_sub_table(pyproject, "project", create=True)
package_name = project.get("name")
if package_name is None:
return None
return package_name


def get_package_name_safe(
source: Union[IO, Path, TOMLDocument, str] = CONFIG_PATH.pyproject
) -> str:
package_name = get_package_name(source)
if package_name is None:
msg = (
"Please provide a name for the package under the [project] table in"
f" {CONFIG_PATH.pyproject}"
)
raise PrecommitError(msg)
return package_name


def get_sub_table(config: Container, dotted_header: str, create: bool = False) -> Table:
Expand All @@ -40,10 +137,20 @@ def get_sub_table(config: Container, dotted_header: str, create: bool = False) -
return current_table


def write_pyproject(config: TOMLDocument) -> None:
src = tomlkit.dumps(config, sort_keys=True)
with open(CONFIG_PATH.pyproject, "w") as stream:
stream.write(src)
def write_pyproject(
config: TOMLDocument, target: Union[IO, Path, str] = CONFIG_PATH.pyproject
) -> None:
if isinstance(target, io.IOBase):
target.seek(0)
tomlkit.dump(config, target, sort_keys=True)
elif isinstance(target, (Path, str)):
src = tomlkit.dumps(config, sort_keys=True)
src = f"{src.strip()}\n"
with open(target, "w") as stream:
stream.write(src)
else:
msg = f"Target of type {type(target).__name__} is not supported"
raise TypeError(msg)


def to_toml_array(items: Iterable[Any], enforce_multiline: bool = False) -> Array:
Expand Down
Loading

0 comments on commit 45ebe3e

Please sign in to comment.