diff --git a/docs/comparison/d2kkk.ipynb b/docs/comparison/d2kkk.ipynb index c321153c..d960ba24 100644 --- a/docs/comparison/d2kkk.ipynb +++ b/docs/comparison/d2kkk.ipynb @@ -46,7 +46,6 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity.align.dpd import relabel_edge_ids\n", "from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import Latex, Markdown, clear_output, display\n", @@ -67,7 +66,7 @@ "from tensorwaves.data.transform import SympyDataTransformer\n", "\n", "from ampform_dpd import DalitzPlotDecompositionBuilder\n", - "from ampform_dpd.adapter.qrules import to_three_body_decay\n", + "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", "from ampform_dpd.decay import Particle\n", "from ampform_dpd.io import (\n", " as_markdown_table,\n", @@ -122,7 +121,7 @@ " mass_conservation_factor=0.2,\n", " formalism=\"helicity\",\n", ")\n", - "REACTION123 = relabel_edge_ids(REACTION)\n", + "REACTION123 = normalize_state_ids(REACTION)\n", "dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n", "graphviz.Source(dot)" ] diff --git a/docs/comparison/jpsi2phipipi.ipynb b/docs/comparison/jpsi2phipipi.ipynb index 819171c6..19928907 100644 --- a/docs/comparison/jpsi2phipipi.ipynb +++ b/docs/comparison/jpsi2phipipi.ipynb @@ -46,7 +46,6 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity.align.dpd import relabel_edge_ids\n", "from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import Latex, Markdown, clear_output, display\n", @@ -67,7 +66,7 @@ "from tensorwaves.data.transform import SympyDataTransformer\n", "\n", "from ampform_dpd import DalitzPlotDecompositionBuilder\n", - "from ampform_dpd.adapter.qrules import to_three_body_decay\n", + "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", "from ampform_dpd.decay import Particle\n", "from ampform_dpd.io import (\n", " as_markdown_table,\n", @@ -122,7 +121,7 @@ " mass_conservation_factor=0,\n", " formalism=\"helicity\",\n", ")\n", - "REACTION123 = relabel_edge_ids(REACTION)\n", + "REACTION123 = normalize_state_ids(REACTION)\n", "dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n", "graphviz.Source(dot)" ] diff --git a/docs/comparison/jpsi2pipipi.ipynb b/docs/comparison/jpsi2pipipi.ipynb index 62ffd5ac..4f5b77d1 100644 --- a/docs/comparison/jpsi2pipipi.ipynb +++ b/docs/comparison/jpsi2pipipi.ipynb @@ -46,7 +46,6 @@ "import matplotlib.pyplot as plt\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity.align.dpd import relabel_edge_ids\n", "from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n", "from ampform.sympy import perform_cached_doit\n", "from IPython.display import Latex, Markdown, clear_output, display\n", @@ -67,7 +66,7 @@ "from tensorwaves.data.transform import SympyDataTransformer\n", "\n", "from ampform_dpd import DalitzPlotDecompositionBuilder\n", - "from ampform_dpd.adapter.qrules import to_three_body_decay\n", + "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", "from ampform_dpd.decay import Particle\n", "from ampform_dpd.io import (\n", " as_markdown_table,\n", @@ -122,7 +121,7 @@ " mass_conservation_factor=0,\n", " formalism=\"helicity\",\n", ")\n", - "REACTION123 = relabel_edge_ids(REACTION)\n", + "REACTION123 = normalize_state_ids(REACTION)\n", "dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n", "graphviz.Source(dot)" ] diff --git a/docs/jpsi2ksp.ipynb b/docs/jpsi2ksp.ipynb index 9068f7ed..7f06e26a 100644 --- a/docs/jpsi2ksp.ipynb +++ b/docs/jpsi2ksp.ipynb @@ -39,7 +39,6 @@ "import qrules\n", "import sympy as sp\n", "from ampform.dynamics import EnergyDependentWidth, formulate_form_factor\n", - "from ampform.helicity.align.dpd import relabel_edge_ids\n", "from ampform.kinematics.phasespace import compute_third_mandelstam\n", "from ampform.sympy import perform_cached_doit, unevaluated\n", "from IPython.display import Latex, Markdown\n", @@ -47,7 +46,7 @@ "from tqdm.auto import tqdm\n", "\n", "from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n", - "from ampform_dpd.adapter.qrules import to_three_body_decay\n", + "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", "from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n", "from ampform_dpd.io import (\n", " as_markdown_table,\n", @@ -101,7 +100,7 @@ " formalism=\"canonical-helicity\",\n", " mass_conservation_factor=0.05,\n", ")\n", - "REACTION = relabel_edge_ids(REACTION)\n", + "REACTION = normalize_state_ids(REACTION)\n", "dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n", "graphviz.Source(dot)" ] diff --git a/docs/lc2pkpi.ipynb b/docs/lc2pkpi.ipynb index e2c1721f..e9c7845e 100644 --- a/docs/lc2pkpi.ipynb +++ b/docs/lc2pkpi.ipynb @@ -31,11 +31,14 @@ "import graphviz\n", "import qrules\n", "import sympy as sp\n", - "from ampform.helicity.align.dpd import relabel_edge_ids\n", "from IPython.display import Latex, Markdown\n", "\n", "from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n", - "from ampform_dpd.adapter.qrules import load_particles, to_three_body_decay\n", + "from ampform_dpd.adapter.qrules import (\n", + " load_particles,\n", + " normalize_state_ids,\n", + " to_three_body_decay,\n", + ")\n", "from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n", "from ampform_dpd.dynamics import BreitWignerMinL\n", "from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n", @@ -90,7 +93,7 @@ "STM.set_allowed_interaction_types([qrules.InteractionType.STRONG], node_id=1)\n", "problem_sets = STM.create_problem_sets()\n", "REACTION = STM.find_solutions(problem_sets)\n", - "REACTION = relabel_edge_ids(REACTION)\n", + "REACTION = normalize_state_ids(REACTION)\n", "dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n", "graphviz.Source(dot)" ] diff --git a/pyproject.toml b/pyproject.toml index d8ed943a..5cb6e6ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ - "ampform >=0.15.0", # relabel_edge_ids + "ampform >=0.14.8", # Kibble and Kallen functions, perform_cached_doit, @unevaluated "attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen "cloudpickle", "qrules >=0.10.0", diff --git a/src/ampform_dpd/adapter/qrules.py b/src/ampform_dpd/adapter/qrules.py index 9605a5f3..a303e3d5 100644 --- a/src/ampform_dpd/adapter/qrules.py +++ b/src/ampform_dpd/adapter/qrules.py @@ -1,13 +1,15 @@ from __future__ import annotations -from collections import defaultdict +from collections import abc, defaultdict +from functools import singledispatch from pathlib import Path -from typing import Any, Iterable +from typing import Any, Iterable, TypeVar, overload +import attrs import qrules from qrules.quantum_numbers import InteractionProperties from qrules.topology import EdgeType, FrozenTransition, NodeType -from qrules.transition import State +from qrules.transition import ReactionInfo, State, StateTransition, Topology from ampform_dpd.decay import ( IsobarNode, @@ -149,3 +151,52 @@ def load_particles() -> qrules.particle.ParticleCollection: additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") # type:ignore[arg-type] particle_database.update(additional_definitions) # type:ignore[arg-type] return particle_database + + +@overload +def normalize_state_ids(obj: T) -> T: ... +@overload +def normalize_state_ids(obj: Iterable[T]) -> list[T]: ... +def normalize_state_ids(obj): # pyright:ignore[reportInconsistentOverload] + """Relabel the state IDs so that they lie in the range :math:`[0, N)`.""" + return _impl_normalize_state_ids(obj) + + +@singledispatch +def _impl_normalize_state_ids(obj): + """Relabel the state IDs so that they lie in the range :math:`[0, N)`.""" + msg = f"Cannot relabel edge IDs of a {type(obj).__name__}" + raise NotImplementedError(msg) + + +@_impl_normalize_state_ids.register(ReactionInfo) # type:ignore[attr-defined] +def _(obj: ReactionInfo) -> ReactionInfo: + return ReactionInfo( + # no attrs.evolve() in order to call __attrs_post_init__() + transitions=[_impl_normalize_state_ids(g) for g in obj.transitions], + formalism=obj.formalism, + ) + + +@_impl_normalize_state_ids.register(FrozenTransition) # type:ignore[attr-defined] +def _(obj: StateTransition) -> StateTransition: + return attrs.evolve( + obj, + topology=_impl_normalize_state_ids(obj.topology), + states={new: obj.states[old] for new, old in enumerate(sorted(obj.states))}, + ) + + +@_impl_normalize_state_ids.register(Topology) # type:ignore[attr-defined] +def _(obj: Topology) -> Topology: + mapping = {old: new for new, old in enumerate(sorted(obj.edges))} + return obj.relabel_edges(mapping) + + +@_impl_normalize_state_ids.register(abc.Iterable) # type:ignore[attr-defined] +def _(obj: abc.Iterable[T]) -> list[T]: + return [_impl_normalize_state_ids(x) for x in obj] + + +T = TypeVar("T", ReactionInfo, StateTransition, Topology) +"""Type variable for the input and output of :func:`normalize_state_ids`.""" diff --git a/tests/adapter/test_qrules.py b/tests/adapter/test_qrules.py index f2d9b51e..9d691cff 100644 --- a/tests/adapter/test_qrules.py +++ b/tests/adapter/test_qrules.py @@ -5,7 +5,11 @@ import pytest import qrules -from ampform_dpd.adapter.qrules import filter_min_ls, to_three_body_decay +from ampform_dpd.adapter.qrules import ( + filter_min_ls, + normalize_state_ids, + to_three_body_decay, +) from ampform_dpd.decay import LSCoupling, Particle if TYPE_CHECKING: @@ -50,6 +54,27 @@ def test_filter_min_ls(reaction: ReactionInfo): ] +def test_normalize_state_ids_reaction(reaction: ReactionInfo): + reaction012 = reaction + reaction123 = normalize_state_ids(reaction012) + assert set(reaction123.initial_state) == {0} + assert set(reaction123.final_state) == {1, 2, 3} + + transitions123 = normalize_state_ids(reaction012.transitions) + for transition012, transition123 in zip(reaction012.transitions, transitions123): + assert set(transition123.initial_states) == {0} + assert set(transition123.final_states) == {1, 2, 3} + assert set(transition123.intermediate_states) == {4} + + topology123 = normalize_state_ids(transition123.topology) + assert topology123.incoming_edge_ids == {0} + assert topology123.outgoing_edge_ids == {1, 2, 3} + assert topology123.intermediate_edge_ids == {4} + + for i in transition012.states: + assert transition012.states[i] == transition123.states[i + 1] + + @pytest.mark.parametrize("min_ls", [False, True]) def test_to_three_body_decay(reaction: ReactionInfo, min_ls: bool): decay = to_three_body_decay(reaction.transitions, min_ls)