From 0c0913a89ccd581faf0658e0c5e725f34dab9a5f Mon Sep 17 00:00:00 2001 From: Nicolas Tessore Date: Fri, 30 Aug 2024 12:02:55 +0100 Subject: [PATCH] API: progress protocol (#167) Implement a protocol for progress reporting. Provides an implementation using the existing `rich`-based progress bar, as well as a new `ipywidgets`-based progress bar for Jupyter notebooks. Closes: #140 --- .commitlint.rules.js | 1 + examples/example.ipynb | 301 ++++++++++++----------------------------- heracles/__init__.py | 8 ++ heracles/cli.py | 2 +- heracles/core.py | 37 ++++- heracles/fields.py | 40 +++--- heracles/mapping.py | 137 ++++++++----------- heracles/notebook.py | 76 +++++++++++ heracles/progress.py | 125 +++++++++-------- heracles/rich.py | 117 ++++++++++++++++ heracles/twopoint.py | 75 ++++------ tests/test_core.py | 33 +++++ tests/test_fields.py | 2 +- tests/test_mapping.py | 2 +- 14 files changed, 529 insertions(+), 427 deletions(-) create mode 100644 heracles/notebook.py create mode 100644 heracles/rich.py diff --git a/.commitlint.rules.js b/.commitlint.rules.js index 5d28c09..34dc91a 100644 --- a/.commitlint.rules.js +++ b/.commitlint.rules.js @@ -15,6 +15,7 @@ module.exports = { "mapper", "mapping", "progress", + "rich", "twopoint", ], ], diff --git a/examples/example.ipynb b/examples/example.ipynb index 106eebc..d3f2829 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -63,23 +63,6 @@ "## Setup" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Uncomment this to enable info-level logging. This produces quite a lot of output below, but will show exactly what is going on while you are waiting for results." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "#import logging\n", - "#logging.basicConfig(level=logging.INFO)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -89,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -106,17 +89,26 @@ "source": [ "Now the *Heracles* imports:\n", "* The top-level `heracles` module contains all general user-facing functionality.\n", - "* The `heracles.healpy` modules contains mappers based on the `healpy` package." + "* The `heracles.healpy` module contains mappers based on the `healpy` package.\n", + "* The `heracles.notebook` module contains a progress bar based on the `ipywidgets` package." ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import heracles\n", - "import heracles.healpy" + "import heracles.healpy\n", + "from heracles.notebook import Progress" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If there is an import error on the last line of the previous block, it means you need to install the `ipywidgets` package." ] }, { @@ -135,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -163,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -194,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -247,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -267,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -291,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -329,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -338,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -356,7 +348,7 @@ " 9: wlfs2_dr1n_vis24.5_nomag.fits['TOM_BIN_ID == 9']}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -376,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -385,7 +377,7 @@ "True" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -417,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -451,7 +443,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -476,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -518,66 +510,35 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "112aa2960648428db6d1d6d31d35227c", + "model_id": "dba080a903db49e89ed57d817d024ec9", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" + "VBox()" ] }, "metadata": {}, "output_type": "display_data" }, { - "data": { - "text/html": [ - "
/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:295: UserWarning: positions and visibility have \n",
-       "different size\n",
-       "  warnings.warn(\"positions and visibility have different size\")\n",
-       "
\n" - ], - "text/plain": [ - "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:295: UserWarning: positions and visibility have \n", - "different size\n", - " warnings.warn(\"positions and visibility have different size\")\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ntessore/code/heracles-ec/heracles/heracles/fields.py:299: UserWarning: positions and visibility have different size\n", + " warnings.warn(\"positions and visibility have different size\")\n" + ] } ], "source": [ - "data = heracles.map_catalogs(fields, catalogs, parallel=True, progress=True)" + "with Progress(\"mapping\") as progress:\n", + " data = heracles.map_catalogs(fields, catalogs, parallel=True, progress=progress)" ] }, { @@ -589,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -606,7 +567,7 @@ " '...']" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -631,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -679,41 +640,18 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5e458063c9564c9a93e83b3c99ad9b5d", + "model_id": "2cb5a4f82ef446429a8d4cde95f84249", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" + "VBox()" ] }, "metadata": {}, @@ -721,7 +659,8 @@ } ], "source": [ - "alms = heracles.transform(fields, data, progress=True)" + "with Progress(\"transform\") as progress:\n", + " alms = heracles.transform(fields, data, progress=progress)" ] }, { @@ -733,7 +672,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -751,7 +690,7 @@ " '...']" ] }, - "execution_count": 20, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -776,7 +715,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -792,7 +731,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -813,7 +752,7 @@ " ('G_B', 'G_B', 9, 9)]" ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -845,7 +784,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -861,41 +800,18 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d71fd15a84cd421691a88d4dd2856176", + "model_id": "bf5ea2c4a6a94c96a09adaf86ee12c37", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" + "VBox()" ] }, "metadata": {}, @@ -916,7 +832,8 @@ " ),\n", "}\n", "\n", - "data_mm = heracles.map_catalogs(fields_mm, catalogs, parallel=True, progress=True)" + "with Progress(\"mapping\") as progress:\n", + " data_mm = heracles.map_catalogs(fields_mm, catalogs, parallel=True, progress=progress)" ] }, { @@ -928,7 +845,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -951,7 +868,7 @@ " ('V', 7)]" ] }, - "execution_count": 25, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -969,41 +886,18 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "cb32835d778f40e0900f0df84d4a805c", + "model_id": "961caa0fc0f24b1badd1e0ceb355b758", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" + "VBox()" ] }, "metadata": {}, @@ -1011,7 +905,8 @@ } ], "source": [ - "alms_mm = heracles.transform(fields_mm, data_mm, progress=True)" + "with Progress(\"transform\") as progress:\n", + " alms_mm = heracles.transform(fields_mm, data_mm, progress=progress)" ] }, { @@ -1023,7 +918,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 26, "metadata": { "scrolled": true }, @@ -1041,18 +936,18 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "703478a454124cba8c0ec2ad3859f20e", + "model_id": "ed16323ea3ff4038817c7494a2799db0", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Output()" + "VBox()" ] }, "metadata": {}, @@ -1064,33 +959,11 @@ "text": [ "OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n" ] - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n",
-       "
\n" - ], - "text/plain": [ - "\n" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ - "mms = heracles.mixing_matrices(fields, cls_mm, l1max=lmax, l2max=lmax, progress=True)" + "with Progress(\"mixmats\") as progress:\n", + " mms = heracles.mixing_matrices(fields, cls_mm, l1max=lmax, l2max=lmax, progress=progress)" ] }, { @@ -1110,7 +983,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1131,7 +1004,7 @@ " ('G_E', 'G_B', 9, 9)]" ] }, - "execution_count": 29, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1156,7 +1029,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1193,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -1230,7 +1103,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -1256,7 +1129,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1415,7 +1288,7 @@ " " ] }, - "execution_count": 33, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -1438,7 +1311,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1454,7 +1327,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -1471,7 +1344,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -1494,7 +1367,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1511,7 +1384,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -1529,7 +1402,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -1569,7 +1442,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -1578,7 +1451,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -1622,7 +1495,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -1668,7 +1541,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 42, "metadata": {}, "outputs": [ { diff --git a/heracles/__init__.py b/heracles/__init__.py index 241f4bf..db894a1 100644 --- a/heracles/__init__.py +++ b/heracles/__init__.py @@ -59,6 +59,9 @@ # mapping "map_catalogs", "transform", + # progress + "NoProgress", + "Progress", # twopoint "angular_power_spectra", "debias_cls", @@ -121,6 +124,11 @@ transform, ) +from .progress import ( + NoProgress, + Progress, +) + from .twopoint import ( angular_power_spectra, debias_cls, diff --git a/heracles/cli.py b/heracles/cli.py index 31c18a0..85c3828 100644 --- a/heracles/cli.py +++ b/heracles/cli.py @@ -413,7 +413,7 @@ def map_all_selections( ) -> Iterator: """Iteratively map the catalogues defined in config.""" - from .maps import map_catalogs + from .mapping import map_catalogs # load catalogues to process catalogs = catalogs_from_config(config) diff --git a/heracles/core.py b/heracles/core.py index 218cc65..3eec1ba 100644 --- a/heracles/core.py +++ b/heracles/core.py @@ -16,7 +16,9 @@ # # You should have received a copy of the GNU Lesser General Public # License along with Heracles. If not, see . -"""module for common core functionality""" +""" +Module for common core functionality. +""" from __future__ import annotations @@ -118,3 +120,36 @@ def update_metadata(array, *sources, **metadata): raise ValueError(msg) # set the new dtype in array array.dtype = dt + + +class ExceptionExplainer: + """ + Context manager that adds a note to exceptions. + """ + + def __init__( + self, + exc_type: type[BaseException] | tuple[type[BaseException], ...], + note: str, + ) -> None: + self.exc_type = exc_type + self.note = note + + def __enter__(self) -> None: + pass + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if exc_type and issubclass(exc_type, self.exc_type): + try: + exc_value.add_note(self.note) + except AttributeError: + pass + + +external_dependency_explainer = ExceptionExplainer( + ModuleNotFoundError, + "You are trying to import a Heracles module that relies on a missing " + "external dependency. These dependencies are not part of the core " + "Heracles functionality, and are therefore not installed automatically. " + "Please install the missing packages, and this error will disappear.", +) diff --git a/heracles/fields.py b/heracles/fields.py index 5582eb0..ac504f5 100644 --- a/heracles/fields.py +++ b/heracles/fields.py @@ -38,8 +38,8 @@ from numpy.typing import ArrayLike from .catalog import Catalog, CatalogPage - from .maps import Mapper - from .progress import ProgressTask + from .mapper import Mapper + from .progress import Progress # type alias for column specification @@ -164,28 +164,32 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Implementation for mapping a catalogue.""" ... -async def _pages( +async def aiter_pages( catalog: Catalog, - progress: ProgressTask | None, + progress: Progress | None, ) -> AsyncIterable[CatalogPage]: """ Asynchronous generator for the pages of a catalogue. Also manages progress updates. """ page_size = catalog.page_size - if progress: - progress.update(completed=0, total=catalog.size) + current, total = 0, catalog.size + for page in catalog: + if progress is not None: + progress.update(current, total) + await coroutines.sleep() yield page - if progress: - progress.update(advance=page_size) + + current += page_size + # suspend again to give all concurrent loops a chance to finish await coroutines.sleep() @@ -232,7 +236,7 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Map the given catalogue.""" @@ -254,7 +258,7 @@ async def __call__( ngal = 0 # map catalogue data asynchronously - async for page in _pages(catalog, progress): + async for page in aiter_pages(catalog, progress): if page.size: lon, lat = page.get(*col) w = np.ones(page.size) @@ -316,7 +320,7 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Map real values from catalogue to HEALPix map.""" @@ -334,7 +338,7 @@ async def __call__( wmean, var = 0.0, 0.0 # go through pages in catalogue and map values - async for page in _pages(catalog, progress): + async for page in aiter_pages(catalog, progress): if wcol is not None: page.delete(page[wcol] == 0) @@ -387,7 +391,7 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Map complex values from catalogue to HEALPix map.""" @@ -405,7 +409,7 @@ async def __call__( wmean, var = 0.0, 0.0 # go through pages in catalogue and get the shear values, - async for page in _pages(catalog, progress): + async for page in aiter_pages(catalog, progress): if wcol is not None: page.delete(page[wcol] == 0) @@ -450,7 +454,7 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Create a visibility map from the given catalogue.""" @@ -488,7 +492,7 @@ async def __call__( self, catalog: Catalog, *, - progress: ProgressTask | None = None, + progress: Progress | None = None, ) -> ArrayLike: """Map catalogue weights.""" @@ -506,7 +510,7 @@ async def __call__( wmean, w2mean = 0.0, 0.0 # map catalogue - async for page in _pages(catalog, progress): + async for page in aiter_pages(catalog, progress): if wcol is not None: page.delete(page[wcol] == 0) diff --git a/heracles/mapping.py b/heracles/mapping.py index 8332792..c94fb06 100644 --- a/heracles/mapping.py +++ b/heracles/mapping.py @@ -22,12 +22,12 @@ from __future__ import annotations -from contextlib import nullcontext -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import coroutines from heracles.core import TocDict, toc_match +from heracles.progress import Progress, NoProgress if TYPE_CHECKING: from collections.abc import Mapping, MutableMapping, Sequence @@ -36,31 +36,24 @@ from heracles.catalog import Catalog from heracles.fields import Field - from heracles.progress import Progress, ProgressTask -async def _map_progress( +async def _map_field( key: tuple[Any, ...], field: Field, catalog: Catalog, - progress: Progress | None, + progress: Progress, + task_done: Callable[[], None], ) -> NDArray: """ - Coroutine that keeps track of progress. + Coroutine to map an individual field. """ - task: ProgressTask | None - if progress is not None: - name = "[" + ", ".join(map(str, key)) + "]" - task = progress.task(name, subtask=True, total=None) - else: - task = None + label = "(" + ", ".join(map(str, key)) + ")" + with progress.task(label) as task: + result = await field(catalog, progress=task) - result = await field(catalog, progress=task) - - if progress is not None: - task.remove() - progress.advance(progress.task_ids[0]) + task_done() return result @@ -73,7 +66,7 @@ def map_catalogs( out: MutableMapping[tuple[Any, Any], NDArray] | None = None, include: Sequence[tuple[Any, Any]] | None = None, exclude: Sequence[tuple[Any, Any]] | None = None, - progress: bool = False, + progress: Progress | None = None, ) -> MutableMapping[tuple[Any, Any], NDArray]: """Map a set of catalogues to fields.""" @@ -81,6 +74,10 @@ def map_catalogs( if out is None: out = TocDict() + # create dummy progress object if none was given + if progress is None: + progress = NoProgress() + # collect groups of items to go through # items are tuples of (key, field, catalog) groups = [ @@ -92,45 +89,39 @@ def map_catalogs( if parallel: groups = [sum(groups, [])] - # display a progress bar if asked to - progressbar: Progress | nullcontext - if progress: - from heracles.progress import Progress + # progress tracking + current, total = 0, sum(map(len, groups)) + progress.update(0, total) - # create the progress bar - # add the main task -- this must be the first task - progressbar = Progress() - progressbar.add_task("mapping", total=sum(map(len, groups))) - else: - progressbar = nullcontext() + def _task_done(): + """callback for async execution""" + nonlocal current + current += 1 + progress.update(current, total) # process all groups of fields and catalogues - with progressbar as prog: - for items in groups: - # fields return coroutines, which are ran concurrently - keys, coros = [], [] - for key, field, catalog in items: - if toc_match(key, include, exclude): - keys.append(key) - coros.append(_map_progress(key, field, catalog, prog)) - - # run all coroutines concurrently - try: - results = coroutines.run(coroutines.gather(*coros)) - finally: - # force-close coroutines to prevent "never awaited" warnings - for coro in coros: - coro.close() - - # store results - for key, value in zip(keys, results): - out[key] = value - - # free up memory for next group - del results - - if prog is not None: - prog.refresh() + for items in groups: + # fields return coroutines, which are ran concurrently + keys, coros = [], [] + for key, field, catalog in items: + if toc_match(key, include, exclude): + keys.append(key) + coros.append(_map_field(key, field, catalog, progress, _task_done)) + + # run all coroutines concurrently + try: + results = coroutines.run(coroutines.gather(*coros)) + finally: + # force-close coroutines to prevent "never awaited" warnings + for coro in coros: + coro.close() + + # store results + for key, value in zip(keys, results): + out[key] = value + + # free up memory for next group + del results # return the toc dict return out @@ -141,8 +132,7 @@ def transform( data: Mapping[tuple[Any, Any], NDArray], *, out: MutableMapping[tuple[Any, Any], NDArray] | None = None, - progress: bool = False, - **kwargs, + progress: Progress | None = None, ) -> MutableMapping[tuple[Any, Any], NDArray]: """transform data to alms""" @@ -150,41 +140,26 @@ def transform( if out is None: out = TocDict() - # display a progress bar if asked to - progressbar: Progress | nullcontext - if progress: - from heracles.progress import Progress + # create dummy progress object if none was given + if progress is None: + progress = NoProgress() - progressbar = Progress() - task = progressbar.task("transform", total=len(data)) - else: - progressbar = nullcontext() + # progress reporting + current, total = 0, len(data) # convert data to alms, taking care of complex and spin-weighted fields - with progressbar as prog: - for (k, i), m in data.items(): - if progress: - subtask = prog.task( - f"[{k}, {i}]", - subtask=True, - start=False, - total=None, - ) + for (k, i), m in data.items(): + current += 1 + progress.update(current, total) + with progress.task(f"({k}, {i})"): try: field = fields[k] except KeyError: msg = f"unknown field name: {k}" - raise ValueError(msg) + raise ValueError(msg) from None out[k, i] = field.mapper_or_error.transform(m) - if progress: - subtask.remove() - task.update(advance=1) - - if prog is not None: - prog.refresh() - # return the toc dict of alms return out diff --git a/heracles/notebook.py b/heracles/notebook.py new file mode 100644 index 0000000..9b554ce --- /dev/null +++ b/heracles/notebook.py @@ -0,0 +1,76 @@ +# Heracles: Euclid code for harmonic-space statistics on the sphere +# +# Copyright (C) 2023-2024 Euclid Science Ground Segment +# +# This file is part of Heracles. +# +# Heracles is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Heracles is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with Heracles. If not, see . +""" +Module for Jupyter (IPython) notebook integration. +""" + +from __future__ import annotations + +from .core import external_dependency_explainer + +with external_dependency_explainer: + import ipywidgets as widgets + from IPython.display import display + + +class Progress: + """ + Progress bar using ipywidgets. + """ + + def __init__(self, label: str, *, box: widgets.Box | None = None) -> None: + if box is None: + self.box = widgets.VBox() + else: + self.box = box + self.widget = widgets.IntProgress( + value=0, + min=0, + max=1, + description=label, + orientation="horizontal", + ) + + def __enter__(self) -> "Progress": + if not self.box.children: + display(self.box) + self.box.children += (self.widget,) + return self + + def __exit__(self, *exc) -> None: + self.widget.close() + try: + index = self.box.children.index(self.widget) + except ValueError: + pass + else: + self.box.children = ( + self.box.children[:index] + self.box.children[index + 1 :] + ) + if not self.box.children: + self.box.close() + + def update(self, current: int | None = None, total: int | None = None) -> None: + if current is not None: + self.widget.value = current + if total is not None: + self.widget.max = total + + def task(self, label: str) -> "Progress": + return self.__class__(label, box=self.box) diff --git a/heracles/progress.py b/heracles/progress.py index e066750..b013f9e 100644 --- a/heracles/progress.py +++ b/heracles/progress.py @@ -16,81 +16,78 @@ # # You should have received a copy of the GNU Lesser General Public # License along with Heracles. If not, see . -"""module for progress reporting with rich""" - -try: - import rich.box - import rich.panel - import rich.progress -except ModuleNotFoundError as exc: - try: - exc.add_note("You do not have the 'rich' package installed.") - exc.add_note("Disabling progress reports should fix this error.") - except AttributeError: - pass - raise +""" +Module for the progress reporting protocol. +""" +from __future__ import annotations -class ProgressTask: - """ - A wrapper for tasks that forwards calls with their task ID. - """ +from typing import Protocol - def __init__( - self, - progress: rich.progress.Progress, - task_id: rich.progress.TaskID, - ) -> None: - self.progress = progress - self.task_id = task_id - def start(self) -> None: - self.progress.start_task(self.task_id) +class Progress(Protocol): + """ + Protocol for progress reporting. Implementations of this protocol are + meant to be used as a context manager:: - def stop(self) -> None: - self.progress.stop_task(self.task_id) + with MyProgress("working") as progress: + for i in range(100): + # main progress update + progress.update(i + 1, 100) - def remove(self) -> None: - self.progress.remove_task(self.task_id) + # report progress of an individual task + with progress.task(f"subtask {i + 1}") as task: + for j in range(10): + task.update(j + 1, 10) - def update(self, *args, **kwargs): - self.progress.update(self.task_id, *args, **kwargs) + """ - def track(self, *args, **kwargs): - return self.progress.track(*args, task_id=self.task_id, **kwargs) + def update(self, current: int | None = None, total: int | None = None) -> None: + """ + Update progress. + """ + def task(self, label: str) -> "Progress": + """ + Create a task with the given label. + """ -class Progress(rich.progress.Progress): - """ - A progress bar that + def __enter__(self) -> "Progress": + """ + Start progress. + """ - a) returns ProgressTask instances on task creation, - b) allows creation of "subtasks" by passing subtask=True. + def __exit__(self, *exc) -> None: + """ + Stop progress. + """ - Subtasks are shown below the main tasks and separated by a divider. + +class NoProgress: + """ + Dummy progress reporter. """ - def task(self, *args, **kwargs) -> ProgressTask: - task_id = self.add_task(*args, **kwargs) - return ProgressTask(self, task_id) - - @classmethod - def get_default_columns(cls): - return ( - rich.progress.TextColumn("[progress.description]{task.description}"), - rich.progress.BarColumn(bar_width=20), - rich.progress.TaskProgressColumn(), - rich.progress.TimeElapsedColumn(), - ) - - def make_tasks_table(self, tasks): - def _is_subtask(task): - return bool(task.fields.get("subtask")) - - subtask_count = sum(map(_is_subtask, tasks)) - sorted_tasks = sorted(tasks, key=_is_subtask) - table = super().make_tasks_table(sorted_tasks) - table.box = rich.box.HORIZONTALS - if len(table.rows) > subtask_count: - table.rows[-subtask_count - 1].end_section = True - return table + def update(self, current: int | None = None, total: int | None = None) -> None: + """ + Dummy progress update (does nothing). + """ + pass + + def task(self, label: str) -> "NoProgress": + """ + Create a dummy task (does nothing). + """ + return NoProgress() + + def __enter__(self) -> "NoProgress": + """ + Start dummy progress (does nothing). + """ + return self + + def __exit__(self, *exc) -> None: + """ + Stop dummy progress (does nothing). + """ + pass diff --git a/heracles/rich.py b/heracles/rich.py new file mode 100644 index 0000000..1a91541 --- /dev/null +++ b/heracles/rich.py @@ -0,0 +1,117 @@ +# Heracles: Euclid code for harmonic-space statistics on the sphere +# +# Copyright (C) 2023-2024 Euclid Science Ground Segment +# +# This file is part of Heracles. +# +# Heracles is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Heracles is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with Heracles. If not, see . +""" +Module for the integration with the rich package. +Contains a progress bar implementation. +""" + +from __future__ import annotations + +from .core import external_dependency_explainer + +with external_dependency_explainer: + import rich.box + import rich.panel + import rich.progress + + +class _RichProgressBar(rich.progress.Progress): + """ + Rich progress bar subclass with customisations. + """ + + @classmethod + def get_default_columns(cls): + """ + Default columns for progress reporting. + """ + return ( + rich.progress.TextColumn("[progress.description]{task.description}"), + rich.progress.BarColumn(bar_width=20), + rich.progress.TaskProgressColumn(), + rich.progress.TimeElapsedColumn(), + ) + + def make_tasks_table(self, tasks): + """ + Create a table of tasks sorted by their depths. + """ + sorted_tasks = sorted(tasks, key=lambda task: int(task.fields.get("depth", -1))) + table = super().make_tasks_table(sorted_tasks) + table.box = rich.box.HORIZONTALS + depth = -1 + for i, task in enumerate(sorted_tasks): + if (d := task.fields.get("depth", -1)) != depth: + depth = d + if i > 0: + table.rows[i - 1].end_section = True + return table + + +class Progress: + """ + Progress bar using rich. + """ + + def __init__( + self, + label: str, + *, + progress: rich.progress.Progress | None = None, + depth: int = 0, + ) -> None: + if progress is None: + self.progress = _RichProgressBar() + else: + self.progress = progress + self.label = label + self.depth = depth + self.task_id: rich.progress.TaskID | None = None + + def __enter__(self) -> "Progress": + if not self.progress.tasks: + self.progress.start() + if self.task_id is None: + self.task_id = self.progress.add_task( + self.label, + start=True, + total=None, + depth=self.depth, + ) + else: + self.progress.start_task(self.task_id) + self.progress.refresh() + return self + + def __exit__(self, *exc) -> None: + if self.task_id is not None: + self.progress.stop_task(self.task_id) + self.progress.remove_task(self.task_id) + self.task_id = None + if not self.progress.tasks: + self.progress.stop() + self.progress.refresh() + + def update(self, current: int | None = None, total: int | None = None) -> None: + if self.task_id is not None: + self.progress.update(self.task_id, total=total, completed=current) + self.progress.refresh() + + def task(self, label: str) -> "Progress": + return self.__class__(label, progress=self.progress, depth=self.depth + 1) diff --git a/heracles/twopoint.py b/heracles/twopoint.py index f25a10b..71cf6b4 100644 --- a/heracles/twopoint.py +++ b/heracles/twopoint.py @@ -22,7 +22,6 @@ import logging import time -from contextlib import nullcontext from datetime import timedelta from itertools import combinations_with_replacement, product from typing import TYPE_CHECKING, Any @@ -30,6 +29,7 @@ import numpy as np from .core import TocDict, toc_match, update_metadata +from .progress import NoProgress, Progress if TYPE_CHECKING: from collections.abc import Iterator, Mapping, MutableMapping @@ -37,7 +37,6 @@ from numpy.typing import ArrayLike, NDArray from .fields import Field - from .progress import Progress # type alias for the keys of two-point data TwoPointKey = tuple[Any, Any, Any, Any] @@ -335,7 +334,7 @@ def mixing_matrices( bins: ArrayLike | None = None, weights: str | ArrayLike | None = None, out: MutableMapping[TwoPointKey, ArrayLike] | None = None, - progress: bool = False, + progress: Progress | None = None, ) -> MutableMapping[TwoPointKey, ArrayLike]: """compute mixing matrices for fields from a set of cls""" @@ -345,6 +344,10 @@ def mixing_matrices( if out is None: out = TocDict() + # create dummy progress object if none was given + if progress is None: + progress = NoProgress() + # inverse mapping of masks to fields masks: dict[str, dict[Any, Field]] = {} for key, field in fields.items(): @@ -356,47 +359,33 @@ def mixing_matrices( # keep track of combinations that have been done already done = set() - # display a progress bar if asked to - progressbar: Progress | nullcontext[None] - if progress: - from heracles.progress import Progress - - progressbar = Progress() - progressbar.task("mixing matrices", total=None) - else: - progressbar = nullcontext() - # go through the toc dict of cls and compute mixing matrices # which mixing matrix is computed depends on the `masks` mapping - with progressbar as prog: - for (k1, k2, i1, i2), cl in cls.items(): - # if the masks are not named then skip this cl - try: - fields1 = masks[k1] - fields2 = masks[k2] - except KeyError: - continue + current, total = 0, len(cls) + for (k1, k2, i1, i2), cl in cls.items(): + current += 1 + progress.update(current, total) + + # if the masks are not named then skip this cl + try: + fields1 = masks[k1] + fields2 = masks[k2] + except KeyError: + continue - # deal with structured cl arrays - if cl.dtype.names is not None: - cl = cl["CL"] - - # compute mixing matrices for all fields of this mask combination - for f1, f2 in product(fields1, fields2): - # check if this combination has been done already - if (f1, f2, i1, i2) in done or (f2, f1, i2, i1) in done: - continue - # otherwise, mark it as done - done.add((f1, f2, i1, i2)) - - if prog is not None: - subtask = prog.task( - f"[{f1}, {f2}, {i1}, {i2}]", - subtask=True, - start=False, - total=None, - ) + # deal with structured cl arrays + if cl.dtype.names is not None: + cl = cl["CL"] + + # compute mixing matrices for all fields of this mask combination + for f1, f2 in product(fields1, fields2): + # check if this combination has been done already + if (f1, f2, i1, i2) in done or (f2, f1, i2, i1) in done: + continue + # otherwise, mark it as done + done.add((f1, f2, i1, i2)) + with progress.task(f"({f1}, {f2}, {i1}, {i2})"): # get spins of fields spin1, spin2 = fields1[f1].spin, fields2[f2].spin @@ -433,12 +422,6 @@ def mixing_matrices( out[f"{f1}_E", f"{f2}_B", i1, i2] = mm_eb del mm_ee, mm_bb, mm_eb - if prog is not None: - subtask.remove() - - if prog is not None: - prog.refresh() - # return the toc dict of mixing matrices return out diff --git a/tests/test_core.py b/tests/test_core.py index 6de46c1..be7727b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -120,3 +120,36 @@ def test_update_metadata(): assert a.dtype.fields == a_fields_original assert a.dtype.metadata == {"x": 1, "y": 2} + + +def test_exception_explainer(): + from heracles.core import ExceptionExplainer + + class TestException(BaseException): + def __init__(self): + super().__init__() + self.notes = [] + + def add_note(self, note): + self.notes.append(note) + + with ExceptionExplainer(TestException, "explanation"): + pass + + with pytest.raises(TestException) as excinfo: + with ExceptionExplainer(TestException, "explanation"): + raise TestException() + + assert excinfo.value.notes == ["explanation"] + + with pytest.raises(TestException) as excinfo: + with ExceptionExplainer((TestException, ValueError), "explanation"): + raise TestException() + + assert excinfo.value.notes == ["explanation"] + + with pytest.raises(TestException) as excinfo: + with ExceptionExplainer(ValueError, "explanation"): + raise TestException() + + assert excinfo.value.notes == [] diff --git a/tests/test_fields.py b/tests/test_fields.py index e66e7ee..d8fff0a 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -67,7 +67,7 @@ def catalog(page): from unittest.mock import Mock catalog = Mock() - catalog.size = page.size + catalog.size = catalog.page_size = page.size catalog.visibility = None catalog.fsky = None catalog.metadata = {"catalog": catalog.label} diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 0eb23ab..8bb3ab8 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -27,7 +27,7 @@ def test_map_catalogs(parallel): for k in fields: for i in catalogs: - fields[k].assert_any_call(catalogs[i], progress=None) + fields[k].assert_any_call(catalogs[i], progress=unittest.mock.ANY) assert maps[k, i] is fields[k].return_value