From aa73516d373d1a60455a127a2a33c96d6c9be79a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 4 Jul 2023 08:13:59 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/_relink_references.py | 6 ++- docs/_unsrt_et_al.py | 15 +++++-- docs/comparison/d2kkk.ipynb | 69 +++++++++++++++++++----------- docs/comparison/jpsi2phipipi.ipynb | 69 +++++++++++++++++++----------- docs/comparison/jpsi2pipipi.ipynb | 69 +++++++++++++++++++----------- docs/conf.py | 11 ++--- docs/jpsi2ksp.ipynb | 34 ++++++++++----- docs/lc2pkpi.ipynb | 14 +++--- src/ampform_dpd/__init__.py | 30 ++++++++----- src/ampform_dpd/_attrs.py | 9 ++-- src/ampform_dpd/angles.py | 25 ++++++----- src/ampform_dpd/decay.py | 41 +++++++++++------- src/ampform_dpd/dynamics.py | 8 +++- src/ampform_dpd/io.py | 40 +++++++++-------- src/ampform_dpd/spin.py | 20 ++++----- tests/test_angles.py | 8 ++-- tests/test_decay.py | 2 +- tests/test_io.py | 8 ++-- 18 files changed, 299 insertions(+), 179 deletions(-) diff --git a/docs/_relink_references.py b/docs/_relink_references.py index 73a591d3..a66c6515 100644 --- a/docs/_relink_references.py +++ b/docs/_relink_references.py @@ -8,11 +8,15 @@ """ from __future__ import annotations +from typing import TYPE_CHECKING + import sphinx.domains.python from docutils import nodes from sphinx.addnodes import pending_xref, pending_xref_condition from sphinx.domains.python import parse_reftarget -from sphinx.environment import BuildEnvironment + +if TYPE_CHECKING: + from sphinx.environment import BuildEnvironment __TARGET_SUBSTITUTIONS = { "Literal[- 1, 1]": "typing.Literal", diff --git a/docs/_unsrt_et_al.py b/docs/_unsrt_et_al.py index dfe91ca2..7d23a167 100644 --- a/docs/_unsrt_et_al.py +++ b/docs/_unsrt_et_al.py @@ -1,10 +1,17 @@ from pybtex.database import Entry from pybtex.richtext import Tag, Text from pybtex.style.formatting.unsrt import Style as UnsrtStyle -from pybtex.style.template import \ - _format_list # pyright: ignore[reportPrivateUsage] -from pybtex.style.template import (FieldIsMissing, Node, field, href, join, - node, sentence, words) +from pybtex.style.template import ( + FieldIsMissing, + Node, + _format_list, # pyright: ignore[reportPrivateUsage] + field, + href, + join, + node, + sentence, + words, +) class MyStyle(UnsrtStyle): # type: ignore[reportUntypedBaseClass] diff --git a/docs/comparison/d2kkk.ipynb b/docs/comparison/d2kkk.ipynb index 6aa5414f..0628f592 100644 --- a/docs/comparison/d2kkk.ipynb +++ b/docs/comparison/d2kkk.ipynb @@ -35,7 +35,7 @@ "import logging\n", "import os\n", "from textwrap import dedent\n", - "from typing import Iterable\n", + "from typing import TYPE_CHECKING, Iterable\n", "\n", "import ampform\n", "import graphviz\n", @@ -44,28 +44,49 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity import HelicityModel\n", "from ampform.kinematics import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import SVG, Latex, Markdown, clear_output, display\n", - "from ipywidgets import (Accordion, Checkbox, GridBox, HBox, Layout,\n", - " SelectMultiple, Tab, ToggleButtons, VBox,\n", - " interactive_output)\n", - "from qrules.transition import ReactionInfo\n", + "from ipywidgets import (\n", + " Accordion,\n", + " Checkbox,\n", + " GridBox,\n", + " HBox,\n", + " Layout,\n", + " SelectMultiple,\n", + " Tab,\n", + " ToggleButtons,\n", + " VBox,\n", + " interactive_output,\n", + ")\n", "from tensorwaves.data.phasespace import TFPhaseSpaceGenerator\n", "from tensorwaves.data.rng import TFUniformRealNumberGenerator\n", "from tensorwaves.data.transform import SympyDataTransformer\n", - "from tensorwaves.interface import (DataSample, ParameterValue,\n", - " ParametrizedFunction)\n", - "\n", - "from ampform_dpd import (DalitzPlotDecompositionBuilder,\n", - " simplify_latex_rendering)\n", - "from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay,\n", - " ThreeBodyDecayChain)\n", - "from ampform_dpd.io import (as_markdown_table, aslatex, get_readable_hash,\n", - " perform_cached_lambdify)\n", + "\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering\n", + "from ampform_dpd.decay import (\n", + " IsobarNode,\n", + " Particle,\n", + " ThreeBodyDecay,\n", + " ThreeBodyDecayChain,\n", + ")\n", + "from ampform_dpd.io import (\n", + " as_markdown_table,\n", + " aslatex,\n", + " get_readable_hash,\n", + " perform_cached_lambdify,\n", + ")\n", "from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n", "\n", + "if TYPE_CHECKING:\n", + " from ampform.helicity import HelicityModel\n", + " from qrules.transition import ReactionInfo\n", + " from tensorwaves.interface import (\n", + " DataSample,\n", + " ParameterValue,\n", + " ParametrizedFunction,\n", + " )\n", + "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", @@ -628,17 +649,16 @@ "outputs": [], "source": [ "def create_sliders() -> dict[str, ToggleButtons]:\n", - " all_parameters = {k: v for k, v in ampform_model.parameter_defaults.items()}\n", - " all_parameters.update({k: v for k, v in dpd_model.parameter_defaults.items()})\n", + " all_parameters = dict(ampform_model.parameter_defaults.items())\n", + " all_parameters.update(dict(dpd_model.parameter_defaults.items()))\n", " sliders = {}\n", " for symbol, value in all_parameters.items():\n", " value = \"+1\"\n", " if (\n", " symbol.name.startswith(R\"\\mathcal{H}^\\mathrm{decay}\")\n", " and \"+\" in symbol.name\n", - " ):\n", - " if any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", - " value = \"-1\"\n", + " ) and any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", + " value = \"-1\"\n", " sliders[symbol.name] = ToggleButtons(\n", " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", " options=[\"-1\", \"0\", \"+1\"],\n", @@ -651,7 +671,7 @@ "def to_unicode(particle: Particle) -> str:\n", " unicode = particle.name\n", " unicode = unicode.replace(\"pi\", \"π\")\n", - " unicode = unicode.replace(\"rho\", \"ρ\")\n", + " unicode = unicode.replace(\"rho\", \"p\")\n", " unicode = unicode.replace(\"Sigma\", \"Σ\")\n", " unicode = unicode.replace(\"~\", \"\")\n", " unicode = unicode.replace(\"Σ\", \"~Σ\")\n", @@ -659,8 +679,7 @@ " unicode = unicode.replace(\"-\", \"⁻\")\n", " unicode = unicode.replace(\"(0)\", \"₀\")\n", " unicode = unicode.replace(\"(1)\", \"₁\")\n", - " unicode = unicode.replace(\")0\", \")⁰\")\n", - " return unicode\n", + " return unicode.replace(\")0\", \")⁰\")\n", "\n", "\n", "sliders = create_sliders()\n", @@ -880,8 +899,8 @@ " )\n", "\n", " global lines\n", - " amp_kwargs = dict(color=\"r\", label=\"ampform\", linestyle=\"solid\")\n", - " dpd_kwargs = dict(color=\"blue\", label=\"dpd\", linestyle=\"dotted\")\n", + " amp_kwargs = {\"color\": \"r\", \"label\": \"ampform\", \"linestyle\": \"solid\"}\n", + " dpd_kwargs = {\"color\": \"blue\", \"label\": \"dpd\", \"linestyle\": \"dotted\"}\n", " if lines is None:\n", " sx = (s_edges[:-1] + s_edges[1:]) / 2\n", " tx = (t_edges[:-1] + t_edges[1:]) / 2\n", diff --git a/docs/comparison/jpsi2phipipi.ipynb b/docs/comparison/jpsi2phipipi.ipynb index 47b25a71..4013353f 100644 --- a/docs/comparison/jpsi2phipipi.ipynb +++ b/docs/comparison/jpsi2phipipi.ipynb @@ -35,7 +35,7 @@ "import logging\n", "import os\n", "from textwrap import dedent\n", - "from typing import Iterable\n", + "from typing import TYPE_CHECKING, Iterable\n", "\n", "import ampform\n", "import graphviz\n", @@ -44,28 +44,49 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity import HelicityModel\n", "from ampform.kinematics import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import SVG, Latex, Markdown, clear_output, display\n", - "from ipywidgets import (Accordion, Checkbox, GridBox, HBox, Layout,\n", - " SelectMultiple, Tab, ToggleButtons, VBox,\n", - " interactive_output)\n", - "from qrules.transition import ReactionInfo\n", + "from ipywidgets import (\n", + " Accordion,\n", + " Checkbox,\n", + " GridBox,\n", + " HBox,\n", + " Layout,\n", + " SelectMultiple,\n", + " Tab,\n", + " ToggleButtons,\n", + " VBox,\n", + " interactive_output,\n", + ")\n", "from tensorwaves.data.phasespace import TFPhaseSpaceGenerator\n", "from tensorwaves.data.rng import TFUniformRealNumberGenerator\n", "from tensorwaves.data.transform import SympyDataTransformer\n", - "from tensorwaves.interface import (DataSample, ParameterValue,\n", - " ParametrizedFunction)\n", - "\n", - "from ampform_dpd import (DalitzPlotDecompositionBuilder,\n", - " simplify_latex_rendering)\n", - "from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay,\n", - " ThreeBodyDecayChain)\n", - "from ampform_dpd.io import (as_markdown_table, aslatex, get_readable_hash,\n", - " perform_cached_lambdify)\n", + "\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering\n", + "from ampform_dpd.decay import (\n", + " IsobarNode,\n", + " Particle,\n", + " ThreeBodyDecay,\n", + " ThreeBodyDecayChain,\n", + ")\n", + "from ampform_dpd.io import (\n", + " as_markdown_table,\n", + " aslatex,\n", + " get_readable_hash,\n", + " perform_cached_lambdify,\n", + ")\n", "from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n", "\n", + "if TYPE_CHECKING:\n", + " from ampform.helicity import HelicityModel\n", + " from qrules.transition import ReactionInfo\n", + " from tensorwaves.interface import (\n", + " DataSample,\n", + " ParameterValue,\n", + " ParametrizedFunction,\n", + " )\n", + "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", @@ -611,17 +632,16 @@ "outputs": [], "source": [ "def create_sliders() -> dict[str, ToggleButtons]:\n", - " all_parameters = {k: v for k, v in ampform_model.parameter_defaults.items()}\n", - " all_parameters.update({k: v for k, v in dpd_model.parameter_defaults.items()})\n", + " all_parameters = dict(ampform_model.parameter_defaults.items())\n", + " all_parameters.update(dict(dpd_model.parameter_defaults.items()))\n", " sliders = {}\n", " for symbol, value in all_parameters.items():\n", " value = \"+1\"\n", " if (\n", " symbol.name.startswith(R\"\\mathcal{H}^\\mathrm{decay}\")\n", " and \"+\" in symbol.name\n", - " ):\n", - " if any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", - " value = \"-1\"\n", + " ) and any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", + " value = \"-1\"\n", " sliders[symbol.name] = ToggleButtons(\n", " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", " options=[\"-1\", \"0\", \"+1\"],\n", @@ -634,7 +654,7 @@ "def to_unicode(particle: Particle) -> str:\n", " unicode = particle.name\n", " unicode = unicode.replace(\"pi\", \"π\")\n", - " unicode = unicode.replace(\"rho\", \"ρ\")\n", + " unicode = unicode.replace(\"rho\", \"p\")\n", " unicode = unicode.replace(\"Sigma\", \"Σ\")\n", " unicode = unicode.replace(\"~\", \"\")\n", " unicode = unicode.replace(\"Σ\", \"~Σ\")\n", @@ -642,8 +662,7 @@ " unicode = unicode.replace(\"-\", \"⁻\")\n", " unicode = unicode.replace(\"(0)\", \"₀\")\n", " unicode = unicode.replace(\"(1)\", \"₁\")\n", - " unicode = unicode.replace(\")0\", \")⁰\")\n", - " return unicode\n", + " return unicode.replace(\")0\", \")⁰\")\n", "\n", "\n", "sliders = create_sliders()\n", @@ -864,8 +883,8 @@ " )\n", "\n", " global lines\n", - " amp_kwargs = dict(color=\"r\", label=\"ampform\", linestyle=\"solid\")\n", - " dpd_kwargs = dict(color=\"blue\", label=\"dpd\", linestyle=\"dotted\")\n", + " amp_kwargs = {\"color\": \"r\", \"label\": \"ampform\", \"linestyle\": \"solid\"}\n", + " dpd_kwargs = {\"color\": \"blue\", \"label\": \"dpd\", \"linestyle\": \"dotted\"}\n", " if lines is None:\n", " sx1 = (s1_edges[:-1] + s1_edges[1:]) / 2\n", " sx23 = (s23_edges[:-1] + s23_edges[1:]) / 2\n", diff --git a/docs/comparison/jpsi2pipipi.ipynb b/docs/comparison/jpsi2pipipi.ipynb index 828893f3..9a15747c 100644 --- a/docs/comparison/jpsi2pipipi.ipynb +++ b/docs/comparison/jpsi2pipipi.ipynb @@ -35,7 +35,7 @@ "import logging\n", "import os\n", "from textwrap import dedent\n", - "from typing import Iterable\n", + "from typing import TYPE_CHECKING, Iterable\n", "\n", "import ampform\n", "import graphviz\n", @@ -44,28 +44,49 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity import HelicityModel\n", "from ampform.kinematics import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import SVG, Latex, Markdown, clear_output, display\n", - "from ipywidgets import (Accordion, Checkbox, GridBox, HBox, Layout,\n", - " SelectMultiple, Tab, ToggleButtons, VBox,\n", - " interactive_output)\n", - "from qrules.transition import ReactionInfo\n", + "from ipywidgets import (\n", + " Accordion,\n", + " Checkbox,\n", + " GridBox,\n", + " HBox,\n", + " Layout,\n", + " SelectMultiple,\n", + " Tab,\n", + " ToggleButtons,\n", + " VBox,\n", + " interactive_output,\n", + ")\n", "from tensorwaves.data.phasespace import TFPhaseSpaceGenerator\n", "from tensorwaves.data.rng import TFUniformRealNumberGenerator\n", "from tensorwaves.data.transform import SympyDataTransformer\n", - "from tensorwaves.interface import (DataSample, ParameterValue,\n", - " ParametrizedFunction)\n", - "\n", - "from ampform_dpd import (DalitzPlotDecompositionBuilder,\n", - " simplify_latex_rendering)\n", - "from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay,\n", - " ThreeBodyDecayChain)\n", - "from ampform_dpd.io import (as_markdown_table, aslatex, get_readable_hash,\n", - " perform_cached_lambdify)\n", + "\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering\n", + "from ampform_dpd.decay import (\n", + " IsobarNode,\n", + " Particle,\n", + " ThreeBodyDecay,\n", + " ThreeBodyDecayChain,\n", + ")\n", + "from ampform_dpd.io import (\n", + " as_markdown_table,\n", + " aslatex,\n", + " get_readable_hash,\n", + " perform_cached_lambdify,\n", + ")\n", "from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n", "\n", + "if TYPE_CHECKING:\n", + " from ampform.helicity import HelicityModel\n", + " from qrules.transition import ReactionInfo\n", + " from tensorwaves.interface import (\n", + " DataSample,\n", + " ParameterValue,\n", + " ParametrizedFunction,\n", + " )\n", + "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"jax\").setLevel(logging.ERROR) # mute JAX\n", "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # mute TF\n", @@ -611,17 +632,16 @@ "outputs": [], "source": [ "def create_sliders() -> dict[str, ToggleButtons]:\n", - " all_parameters = {k: v for k, v in ampform_model.parameter_defaults.items()}\n", - " all_parameters.update({k: v for k, v in dpd_model.parameter_defaults.items()})\n", + " all_parameters = dict(ampform_model.parameter_defaults.items())\n", + " all_parameters.update(dict(dpd_model.parameter_defaults.items()))\n", " sliders = {}\n", " for symbol, value in all_parameters.items():\n", " value = \"+1\"\n", " if (\n", " symbol.name.startswith(R\"\\mathcal{H}^\\mathrm{decay}\")\n", " and \"+\" in symbol.name\n", - " ):\n", - " if any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", - " value = \"-1\"\n", + " ) and any(s in symbol.name for s in [\"{1}\", \"*\", \"rho\"]):\n", + " value = \"-1\"\n", " sliders[symbol.name] = ToggleButtons(\n", " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", " options=[\"-1\", \"0\", \"+1\"],\n", @@ -634,7 +654,7 @@ "def to_unicode(particle: Particle) -> str:\n", " unicode = particle.name\n", " unicode = unicode.replace(\"pi\", \"π\")\n", - " unicode = unicode.replace(\"rho\", \"ρ\")\n", + " unicode = unicode.replace(\"rho\", \"p\")\n", " unicode = unicode.replace(\"Sigma\", \"Σ\")\n", " unicode = unicode.replace(\"~\", \"\")\n", " unicode = unicode.replace(\"Σ\", \"~Σ\")\n", @@ -642,8 +662,7 @@ " unicode = unicode.replace(\"-\", \"⁻\")\n", " unicode = unicode.replace(\"(0)\", \"₀\")\n", " unicode = unicode.replace(\"(1)\", \"₁\")\n", - " unicode = unicode.replace(\")0\", \")⁰\")\n", - " return unicode\n", + " return unicode.replace(\")0\", \")⁰\")\n", "\n", "\n", "sliders = create_sliders()\n", @@ -863,8 +882,8 @@ " )\n", "\n", " global lines\n", - " amp_kwargs = dict(color=\"r\", label=\"ampform\", linestyle=\"solid\")\n", - " dpd_kwargs = dict(color=\"blue\", label=\"dpd\", linestyle=\"dotted\")\n", + " amp_kwargs = {\"color\": \"r\", \"label\": \"ampform\", \"linestyle\": \"solid\"}\n", + " dpd_kwargs = {\"color\": \"blue\", \"label\": \"dpd\", \"linestyle\": \"dotted\"}\n", " if lines is None:\n", " sx = (s_edges[:-1] + s_edges[1:]) / 2\n", " tx = (t_edges[:-1] + t_edges[1:]) / 2\n", diff --git a/docs/conf.py b/docs/conf.py index 71f1c12e..19ba096f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,6 +16,8 @@ from importlib.metadata import version as get_package_version sys.path.insert(0, os.path.abspath(".")) +import contextlib + from _relink_references import relink_references from _unsrt_et_al import MyStyle @@ -36,13 +38,12 @@ def get_execution_mode() -> str: def get_logo_path() -> str | None: path = "_static/logo.svg" - try: + with contextlib.suppress(requests.exceptions.ConnectionError): _fetch_logo( url="https://raw.githubusercontent.com/ComPWA/ComPWA/04e5199/doc/images/logo.svg", output_path=path, ) - except requests.exceptions.ConnectionError: - pass + if os.path.exists(path): return path return None @@ -69,8 +70,8 @@ def generate_api() -> None: " ".join( [ "sphinx-apidoc", - f"../src/ampform_dpd/", - f"../src/ampform_dpd/version.py", + "../src/ampform_dpd/", + "../src/ampform_dpd/version.py", "-o api/", "--force", "--no-toc", diff --git a/docs/jpsi2ksp.ipynb b/docs/jpsi2ksp.ipynb index 36c0a3ce..2dfa7cba 100644 --- a/docs/jpsi2ksp.ipynb +++ b/docs/jpsi2ksp.ipynb @@ -29,7 +29,7 @@ "import itertools\n", "import logging\n", "import os\n", - "from typing import Iterable\n", + "from typing import TYPE_CHECKING, Iterable\n", "\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", @@ -37,21 +37,34 @@ "import sympy as sp\n", "from ampform.dynamics import EnergyDependentWidth, formulate_form_factor\n", "from ampform.kinematics.phasespace import compute_third_mandelstam\n", - "from ampform.sympy import (UnevaluatedExpression, create_expression,\n", - " implement_doit_method, make_commutative,\n", - " perform_cached_doit)\n", + "from ampform.sympy import (\n", + " UnevaluatedExpression,\n", + " create_expression,\n", + " implement_doit_method,\n", + " make_commutative,\n", + " perform_cached_doit,\n", + ")\n", "from IPython.display import Latex, Markdown\n", "from tensorwaves.data.transform import SympyDataTransformer\n", - "from tensorwaves.interface import DataSample, ParametrizedFunction\n", "from tqdm.auto import tqdm\n", "\n", - "from ampform_dpd import (DalitzPlotDecompositionBuilder, _get_particle,\n", - " simplify_latex_rendering)\n", - "from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay,\n", - " ThreeBodyDecayChain)\n", + "from ampform_dpd import (\n", + " DalitzPlotDecompositionBuilder,\n", + " _get_particle,\n", + " simplify_latex_rendering,\n", + ")\n", + "from ampform_dpd.decay import (\n", + " IsobarNode,\n", + " Particle,\n", + " ThreeBodyDecay,\n", + " ThreeBodyDecayChain,\n", + ")\n", "from ampform_dpd.io import as_markdown_table, aslatex, perform_cached_lambdify\n", "from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n", "\n", + "if TYPE_CHECKING:\n", + " from tensorwaves.interface import DataSample, ParametrizedFunction\n", + "\n", "simplify_latex_rendering()\n", "logging.getLogger(\"absl\").setLevel(logging.ERROR) # mute JAX\n", "NO_TQDM = \"EXECUTE_NB\" in os.environ\n", @@ -459,8 +472,9 @@ " return s2\n", " if decay_masses == {m1, m2}:\n", " return s3\n", + " msg = f\"Cannot find Mandelstam variable for {''.join(decay_masses)}\"\n", " raise NotImplementedError(\n", - " f\"Cannot find Mandelstam variable for {''.join(decay_masses)}\"\n", + " msg\n", " )" ] }, diff --git a/docs/lc2pkpi.ipynb b/docs/lc2pkpi.ipynb index 07d0a60b..a79a3658 100644 --- a/docs/lc2pkpi.ipynb +++ b/docs/lc2pkpi.ipynb @@ -33,10 +33,13 @@ "import sympy as sp\n", "from IPython.display import Latex, Markdown\n", "\n", - "from ampform_dpd import (DalitzPlotDecompositionBuilder,\n", - " simplify_latex_rendering)\n", - "from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay,\n", - " ThreeBodyDecayChain)\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering\n", + "from ampform_dpd.decay import (\n", + " IsobarNode,\n", + " Particle,\n", + " ThreeBodyDecay,\n", + " ThreeBodyDecayChain,\n", + ")\n", "from ampform_dpd.dynamics import BreitWignerMinL\n", "from ampform_dpd.io import as_markdown_table, aslatex\n", "from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings\n", @@ -304,8 +307,9 @@ " return s2\n", " if decay_masses == {m1, m2}:\n", " return s3\n", + " msg = f\"Cannot find Mandelstam variable for {''.join(decay_masses)}\"\n", " raise NotImplementedError(\n", - " f\"Cannot find Mandelstam variable for {''.join(decay_masses)}\"\n", + " msg\n", " )\n", "\n", "\n", diff --git a/src/ampform_dpd/__init__.py b/src/ampform_dpd/__init__.py index 46d78ae2..cbe5b367 100644 --- a/src/ampform_dpd/__init__.py +++ b/src/ampform_dpd/__init__.py @@ -10,13 +10,17 @@ from attrs import field, frozen from sympy.core.symbol import Str from sympy.physics.matrices import msigma -from sympy.physics.quantum.spin import CG +from sympy.physics.quantum.spin import CG, WignerD from sympy.physics.quantum.spin import Rotation as Wigner -from sympy.physics.quantum.spin import WignerD -from ampform_dpd.decay import (IsobarNode, LSCoupling, Particle, - ThreeBodyDecay, ThreeBodyDecayChain, - get_decay_product_ids) +from ampform_dpd.decay import ( + IsobarNode, + LSCoupling, + Particle, + ThreeBodyDecay, + ThreeBodyDecayChain, + get_decay_product_ids, +) from ampform_dpd.spin import create_spin_range from .angles import formulate_scattering_angle, formulate_zeta_angle @@ -67,8 +71,9 @@ def __init__( self.use_decay_helicity_couplings, ) = min_ls else: + msg = f"Cannot configure helicity couplings with a {type(min_ls).__name__}" raise NotImplementedError( - f"Cannot configure helicity couplings with a {type(min_ls).__name__}", + msg, min_ls, ) @@ -304,14 +309,16 @@ def __get_chain(self, identifier) -> ThreeBodyDecayChain: if isinstance(identifier, ThreeBodyDecayChain): chain = identifier if chain not in set(self.__decay.chains): + msg = f"Decay does not have chain with resonance {chain.resonance.name}" raise ValueError( - f"Decay does not have chain with resonance {chain.resonance.name}" + msg ) return chain if isinstance(identifier, str): return self.__decay.find_chain(identifier) + msg = f"Cannot get decay chain for identifier type {type(identifier)}" raise NotImplementedError( - f"Cannot get decay chain for identifier type {type(identifier)}" + msg ) @property @@ -348,8 +355,9 @@ def _formulate_clebsch_gordan_factors( helicities: dict[Particle, sp.Rational | sp.Symbol], ) -> sp.Expr: if isobar.interaction is None: + msg = "Cannot formulate amplitude model in LS-basis if LS-couplings are missing" raise ValueError( - "Cannot formulate amplitude model in LS-basis if LS-couplings are missing" + msg ) # https://github.com/ComPWA/ampform/blob/65b4efa/src/ampform/helicity/__init__.py#L785-L802 # and supplementary material p.1 (https://cds.cern.ch/record/2824328/files) @@ -391,9 +399,9 @@ def formulate_polarimetry( ) -> tuple[PoolSum, PoolSum, PoolSum]: half = sp.Rational(1, 2) if builder.decay.initial_state.spin != half: + msg = f"Can only formulate polarimetry for an initial state with spin 1/2, but got {builder.decay.initial_state.spin}" raise ValueError( - "Can only formulate polarimetry for an initial state with spin 1/2, but" - f" got {builder.decay.initial_state.spin}" + msg ) model = builder.formulate(reference_subsystem) λ0, λ0_prime = sp.symbols(R"lambda \lambda^{\prime}", rational=True) diff --git a/src/ampform_dpd/_attrs.py b/src/ampform_dpd/_attrs.py index 65aa6502..1a88e18d 100644 --- a/src/ampform_dpd/_attrs.py +++ b/src/ampform_dpd/_attrs.py @@ -4,16 +4,18 @@ from typing import TYPE_CHECKING, SupportsFloat import sympy as sp -from attrs import Attribute if TYPE_CHECKING: + from attrs import Attribute + from ampform_dpd.decay import LSCoupling def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> None: if value.denominator not in {1, 2}: + msg = f"{attribute.name} value should be integer or half-integer, not {value}" raise ValueError( - f"{attribute.name} value should be integer or half-integer, not {value}" + msg ) @@ -27,7 +29,8 @@ def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling: if isinstance(obj, tuple): L, S = obj return LSCoupling(L, S) - raise TypeError(f"Cannot convert {type(obj).__name__} to {LSCoupling.__name__}") + msg = f"Cannot convert {type(obj).__name__} to {LSCoupling.__name__}" + raise TypeError(msg) def to_rational(obj: SupportsFloat) -> sp.Rational: diff --git a/src/ampform_dpd/angles.py b/src/ampform_dpd/angles.py index 1ef059a3..b1e268b9 100644 --- a/src/ampform_dpd/angles.py +++ b/src/ampform_dpd/angles.py @@ -15,17 +15,20 @@ def formulate_scattering_angle( rest frame of the isobar resonance :math:`(ij)`. """ if not {state_id, sibling_id} <= {1, 2, 3}: - raise ValueError(f"Child IDs need to be one of 1, 2, 3") + msg = "Child IDs need to be one of 1, 2, 3" + raise ValueError(msg) # pyright: ignore[reportUnnecessaryContains] if {state_id, sibling_id} in {(2, 1), (3, 2), (1, 3)}: + msg = f"Cannot compute scattering angle θ{state_id}{sibling_id}" raise NotImplementedError( - f"Cannot compute scattering angle θ{state_id}{sibling_id}" + msg ) if state_id == sibling_id: - raise ValueError(f"IDs of the decay products cannot be equal: {state_id}") + msg = f"IDs of the decay products cannot be equal: {state_id}" + raise ValueError(msg) symbol = sp.Symbol(Rf"theta_{state_id}{sibling_id}", real=True) spectator_id = next(iter({1, 2, 3} - {state_id, sibling_id})) - m0 = sp.Symbol(f"m0", nonnegative=True) + m0 = sp.Symbol("m0", nonnegative=True) mi = sp.Symbol(f"m{state_id}", nonnegative=True) mj = sp.Symbol(f"m{sibling_id}", nonnegative=True) mk = sp.Symbol(f"m{spectator_id}", nonnegative=True) @@ -50,15 +53,16 @@ def formulate_theta_hat_angle( r"""Formulate an expression for :math:`\hat\theta_{i(j)}`.""" allowed_ids = {1, 2, 3} if not {isobar_id, aligned_subsystem} <= allowed_ids: + msg = f"Child IDs need to be one of {', '.join(map(str, allowed_ids))}" raise ValueError( - f"Child IDs need to be one of {', '.join(map(str, allowed_ids))}" + msg ) symbol = sp.Symbol(Rf"\hat\theta_{isobar_id}({aligned_subsystem})", real=True) if isobar_id == aligned_subsystem: return symbol, sp.S.Zero if (isobar_id, aligned_subsystem) in {(3, 1), (1, 2), (2, 3)}: remaining_id = next(iter(allowed_ids - {isobar_id, aligned_subsystem})) - m0 = sp.Symbol(f"m0", nonnegative=True) + m0 = sp.Symbol("m0", nonnegative=True) mi = sp.Symbol(f"m{isobar_id}", nonnegative=True) mj = sp.Symbol(f"m{aligned_subsystem}", nonnegative=True) σi = sp.Symbol(f"sigma{isobar_id}", nonnegative=True) @@ -158,7 +162,8 @@ def formulate_zeta_angle( (2, 3, 1), (3, 1, 2), }: - create_symbols = lambda i: sp.symbols(f"m{i} sigma{i}", nonnegative=True) + def create_symbols(i): + return sp.symbols(f"m{i} sigma{i}", nonnegative=True) mi, σi = create_symbols(rotated_state) mj, σj = create_symbols(aligned_subsystem) mk, σk = create_symbols(reference_subsystem) @@ -174,11 +179,9 @@ def formulate_zeta_angle( (1, 3, 1), (2, 1, 2), (3, 2, 3), - # Eq (A8) (1, 1, 2), (2, 2, 3), (3, 3, 1), - # Eq (A11) (1, 3, 2), (2, 1, 3), (3, 2, 1), @@ -187,7 +190,7 @@ def formulate_zeta_angle( rotated_state, reference_subsystem, aligned_subsystem ) return zeta_symbol, -zeta + msg = f"No expression for ζ^{rotated_state}_{aligned_subsystem}({reference_subsystem})" raise NotImplementedError( - "No expression for" - f" ζ^{rotated_state}_{aligned_subsystem}({reference_subsystem})" + msg ) diff --git a/src/ampform_dpd/decay.py b/src/ampform_dpd/decay.py index a26eabd5..095c5e68 100644 --- a/src/ampform_dpd/decay.py +++ b/src/ampform_dpd/decay.py @@ -2,14 +2,16 @@ from __future__ import annotations import sys -from typing import Dict +from typing import TYPE_CHECKING, Dict -import sympy as sp from attrs import field, frozen from attrs.validators import instance_of from ampform_dpd._attrs import assert_spin_value, to_ls, to_rational +if TYPE_CHECKING: + import sympy as sp + if sys.version_info < (3, 8): from typing_extensions import Literal else: @@ -48,16 +50,17 @@ def __attrs_post_init__(self) -> None: expected_final_state = set(self.final_state.values()) for i, chain in enumerate(self.chains): if chain.parent != expected_initial_state: + msg = f"Chain {i} has initial state {chain.parent.name}, but should have {expected_initial_state.name}" raise ValueError( - f"Chain {i} has initial state {chain.parent.name}, but should have" - f" {expected_initial_state.name}" + msg ) final_state = {chain.spectator, *chain.decay_products} if final_state != expected_final_state: - to_str = lambda s: ", ".join(p.name for p in s) + def to_str(s): + return ", ".join(p.name for p in s) + msg = f"Chain {i} has final state {to_str(final_state)}, but should have {to_str(expected_final_state)}" raise ValueError( - f"Chain {i} has final state {to_str(final_state)}, but should have" - f" {to_str(expected_final_state)}" + msg ) @property @@ -72,7 +75,8 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain: for chain in self.chains: if chain.resonance.name == resonance_name: return chain - raise KeyError(f"No decay chain found for resonance {resonance_name}") + msg = f"No decay chain found for resonance {resonance_name}" + raise KeyError(msg) def get_subsystem(self, subsystem_id: Literal[1, 2, 3]) -> ThreeBodyDecay: child1_id, child2_id = get_decay_product_ids(subsystem_id) @@ -93,7 +97,8 @@ def get_decay_product_ids(spectator_id: Literal[1, 2, 3]) -> tuple[int, int]: return 3, 1 if spectator_id == 3: return 1, 2 - raise ValueError(f"Spectator ID has to be one of 1, 2, 3, not {spectator_id}") + msg = f"Spectator ID has to be one of 1, 2, 3, not {spectator_id}" + raise ValueError(msg) OuterStates = Dict[Literal[0, 1, 2, 3], Particle] @@ -106,17 +111,23 @@ class ThreeBodyDecayChain: def __attrs_post_init__(self) -> None: if not isinstance(self.decay.child1, IsobarNode): - raise TypeError(f"Child 1 has of type {IsobarNode.__name__} (the decay)") + msg = f"Child 1 has of type {IsobarNode.__name__} (the decay)" + raise TypeError(msg) if not isinstance(self.decay.child1.child1, Particle): - raise TypeError(f"Child 1 of child 1 has of type {Particle.__name__}") + msg = f"Child 1 of child 1 has of type {Particle.__name__}" + raise TypeError(msg) if not isinstance(self.decay.child1.child2, Particle): - raise TypeError(f"Child 2 of child 1 has of type {Particle.__name__}") + msg = f"Child 2 of child 1 has of type {Particle.__name__}" + raise TypeError(msg) if not isinstance(self.decay.child2, Particle): - raise TypeError(f"Child 2 has of type {Particle.__name__} (spectator)") + msg = f"Child 2 has of type {Particle.__name__} (spectator)" + raise TypeError(msg) if self.incoming_ls is None: # pyright: ignore[reportUnnecessaryComparison] - raise ValueError(f"LS-coupling for production node required") + msg = "LS-coupling for production node required" + raise ValueError(msg) if self.outgoing_ls is None: # pyright: ignore[reportUnnecessaryComparison] - raise ValueError(f"LS-coupling for decay node required") + msg = "LS-coupling for decay node required" + raise ValueError(msg) @property def parent(self) -> Particle: diff --git a/src/ampform_dpd/dynamics.py b/src/ampform_dpd/dynamics.py index baa9a49f..a085f6c2 100644 --- a/src/ampform_dpd/dynamics.py +++ b/src/ampform_dpd/dynamics.py @@ -5,8 +5,12 @@ import sympy as sp from ampform.kinematics.phasespace import Kallen -from ampform.sympy import (UnevaluatedExpression, create_expression, - implement_doit_method, make_commutative) +from ampform.sympy import ( + UnevaluatedExpression, + create_expression, + implement_doit_method, + make_commutative, +) @make_commutative diff --git a/src/ampform_dpd/io.py b/src/ampform_dpd/io.py index 88920d6a..927e2fde 100644 --- a/src/ampform_dpd/io.py +++ b/src/ampform_dpd/io.py @@ -25,21 +25,22 @@ from functools import lru_cache from os.path import abspath, dirname, expanduser from textwrap import dedent -from typing import Iterable, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, overload from warnings import warn import cloudpickle import sympy as sp from ampform.io import aslatex -from tensorwaves.function import (ParametrizedBackendFunction, - PositionalArgumentFunction) -from tensorwaves.function.sympy import (create_function, - create_parametrized_function) -from tensorwaves.interface import (Function, ParameterValue, - ParametrizedFunction) +from tensorwaves.function.sympy import create_function, create_parametrized_function -from ampform_dpd.decay import (IsobarNode, Particle, ThreeBodyDecay, - ThreeBodyDecayChain) +from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecay, ThreeBodyDecayChain + +if TYPE_CHECKING: + from tensorwaves.function import ( + ParametrizedBackendFunction, + PositionalArgumentFunction, + ) + from tensorwaves.interface import Function, ParameterValue, ParametrizedFunction _LOGGER = logging.getLogger(__name__) @@ -66,7 +67,8 @@ def _(obj: sp.Basic, **kwargs) -> str: @aslatex.register(abc.Mapping) def _(obj: Mapping, **kwargs) -> str: if len(obj) == 0: - raise ValueError("Need at least one dictionary item") + msg = "Need at least one dictionary item" + raise ValueError(msg) latex = R"\begin{array}{rcl}" + "\n" for lhs, rhs in obj.items(): latex += Rf" {aslatex(lhs, **kwargs)} &=& {aslatex(rhs, **kwargs)} \\" + "\n" @@ -78,7 +80,8 @@ def _(obj: Mapping, **kwargs) -> str: def _(obj: Iterable, **kwargs) -> str: obj = list(obj) if len(obj) == 0: - raise ValueError("Need at least one item to render as LaTeX") + msg = "Need at least one item to render as LaTeX" + raise ValueError(msg) latex = R"\begin{array}{c}" + "\n" for item in obj: item_latex = aslatex(item, **kwargs) @@ -139,8 +142,9 @@ def as_markdown_table(obj: Sequence) -> str: return _as_decay_markdown_table(obj.chains) if item_type is ThreeBodyDecayChain: return _as_decay_markdown_table(obj) + msg = f"Cannot render a sequence with {item_type.__name__} items as a Markdown table" raise NotImplementedError( - f"Cannot render a sequence with {item_type.__name__} items as a Markdown table" + msg ) @@ -148,10 +152,12 @@ def _determine_item_type(obj) -> type: if not isinstance(obj, abc.Sequence): return type(obj) if len(obj) < 1: - raise ValueError(f"Need at least one entry to render a table") + msg = "Need at least one entry to render a table" + raise ValueError(msg) item_type = type(obj[0]) - if not all(map(lambda i: isinstance(i, item_type), obj)): - raise ValueError(f"Not all items are of type {item_type.__name__}") + if not all(isinstance(i, item_type) for i in obj): + msg = f"Not all items are of type {item_type.__name__}" + raise ValueError(msg) return item_type @@ -189,7 +195,7 @@ def _as_decay_markdown_table(decay_chains: Sequence[ThreeBodyDecayChain]) -> str for chain in decay_chains: child1, child2 = map(aslatex, chain.decay_products) row_items = [ - Rf"${chain.resonance.latex} \to" Rf" {child1} {child2}$", + Rf"${chain.resonance.latex} \to {child1} {child2}$", Rf"${aslatex(chain.resonance, only_jp=True)}$", f"{int(1e3 * chain.resonance.mass):,.0f}", f"{int(1e3 * chain.resonance.width):,.0f}", @@ -207,7 +213,7 @@ def _create_markdown_table_header(column_names: list[str]): def _create_markdown_table_row(items: Iterable): - items = map(lambda i: f"{i}", items) + items = (f"{i}" for i in items) return "| " + " | ".join(items) + " |\n" diff --git a/src/ampform_dpd/spin.py b/src/ampform_dpd/spin.py index 5797a0fa..6de5ed41 100644 --- a/src/ampform_dpd/spin.py +++ b/src/ampform_dpd/spin.py @@ -12,9 +12,8 @@ def generate_ls_couplings( child2_spin: SupportsFloat, max_L: int = 3, ) -> list[tuple[int, sp.Rational]]: - r""" - >>> generate_ls_couplings(1.5, 0.5, 0) - [(1, 1/2), (2, 1/2)] + r""">>> generate_ls_couplings(1.5, 0.5, 0) + [(1, 1/2), (2, 1/2)]. """ s1 = float(child1_spin) s2 = float(child2_spin) @@ -35,12 +34,11 @@ def filter_parity_violating_ls( child1_parity: SupportsInt, child2_parity: SupportsInt, ) -> list[tuple[int, sp.Rational]]: - r""" - >>> LS = generate_ls_couplings(0.5, 1.5, 0) # Λc → Λ(1520)π + r""">>> LS = generate_ls_couplings(0.5, 1.5, 0) # Λc → Λ(1520)π >>> LS [(1, 3/2), (2, 3/2)] >>> filter_parity_violating_ls(LS, +1, -1, -1) - [(2, 3/2)] + [(2, 3/2)]. """ η0, η1, η2 = ( int(parent_parity), @@ -51,9 +49,8 @@ def filter_parity_violating_ls( def create_spin_range(spin: SupportsFloat) -> list[sp.Rational]: - """ - >>> create_spin_range(1.5) - [-3/2, -1/2, 1/2, 3/2] + """>>> create_spin_range(1.5) + [-3/2, -1/2, 1/2, 3/2]. """ return create_rational_range(-spin, spin) @@ -61,9 +58,8 @@ def create_spin_range(spin: SupportsFloat) -> list[sp.Rational]: def create_rational_range( __from: SupportsFloat, __to: SupportsFloat ) -> list[sp.Rational]: - """ - >>> create_rational_range(-0.5, +1.5) - [-1/2, 1/2, 3/2] + """>>> create_rational_range(-0.5, +1.5) + [-1/2, 1/2, 3/2]. """ spin_range = arange(float(__from), +float(__to) + 0.5) return list(map(sp.Rational, spin_range)) diff --git a/tests/test_angles.py b/tests/test_angles.py index bbc6cc56..ca8eb2cc 100644 --- a/tests/test_angles.py +++ b/tests/test_angles.py @@ -3,9 +3,11 @@ import sympy as sp from ampform.kinematics.phasespace import Kallen, compute_third_mandelstam -from ampform_dpd.angles import (formulate_scattering_angle, - formulate_theta_hat_angle, - formulate_zeta_angle) +from ampform_dpd.angles import ( + formulate_scattering_angle, + formulate_theta_hat_angle, + formulate_zeta_angle, +) m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True) σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True) diff --git a/tests/test_decay.py b/tests/test_decay.py index 68377963..455bb0f0 100644 --- a/tests/test_decay.py +++ b/tests/test_decay.py @@ -1,7 +1,7 @@ from ampform_dpd.decay import IsobarNode, Particle # https://compwa-org--129.org.readthedocs.build/report/018.html#resonances-and-ls-scheme -dummy_args = dict(mass=0, width=0) +dummy_args = {"mass": 0, "width": 0} Λc = Particle("Λc", latex=R"\Lambda_c^+", spin=0.5, parity=+1, **dummy_args) p = Particle("p", latex="p", spin=0.5, parity=+1, **dummy_args) π = Particle("π+", latex=R"\pi^+", spin=0, parity=-1, **dummy_args) diff --git a/tests/test_io.py b/tests/test_io.py index 9e8f7b73..ba832de7 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -18,7 +18,7 @@ THIS_DIR = dirname(abspath(__file__)) # https://compwa-org--129.org.readthedocs.build/report/018.html#resonances-and-ls-scheme -dummy_args = dict(mass=0, width=0) +dummy_args = {"mass": 0, "width": 0} Λc = Particle("Λc", latex=R"\Lambda_c^+", spin=0.5, parity=+1, **dummy_args) p = Particle("p", latex="p", spin=0.5, parity=+1, **dummy_args) π = Particle("π+", latex=R"\pi^+", spin=0, parity=-1, **dummy_args) @@ -53,9 +53,9 @@ def test_aslatex_isobar_node(): @pytest.mark.parametrize( ("assumptions", "expected_hash"), [ - (dict(), "pythonhashseed-0+7459658071388516764"), - (dict(real=True), "pythonhashseed-0+3665410414623666716"), - (dict(rational=True), "pythonhashseed-0-7926839224244779605"), + ({}, "pythonhashseed-0+7459658071388516764"), + ({"real": True}, "pythonhashseed-0+3665410414623666716"), + ({"rational": True}, "pythonhashseed-0-7926839224244779605"), ], ) def test_get_readable_hash(assumptions, expected_hash, caplog: LogCaptureFixture):