Skip to content

Commit

Permalink
FEAT: define normalize_state_ids() for QRules (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Apr 26, 2024
1 parent a7c58b8 commit 80e3e6d
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 29 deletions.
5 changes: 2 additions & 3 deletions docs/comparison/d2kkk.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"import matplotlib.pyplot as plt\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.helicity.align.dpd import relabel_edge_ids\n",
"from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n",
"from ampform.sympy import perform_cached_doit\n",
"from IPython.display import Latex, Markdown, clear_output, display\n",
Expand All @@ -67,7 +66,7 @@
"from tensorwaves.data.transform import SympyDataTransformer\n",
"\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder\n",
"from ampform_dpd.adapter.qrules import to_three_body_decay\n",
"from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n",
"from ampform_dpd.decay import Particle\n",
"from ampform_dpd.io import (\n",
" as_markdown_table,\n",
Expand Down Expand Up @@ -122,7 +121,7 @@
" mass_conservation_factor=0.2,\n",
" formalism=\"helicity\",\n",
")\n",
"REACTION123 = relabel_edge_ids(REACTION)\n",
"REACTION123 = normalize_state_ids(REACTION)\n",
"dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n",
"graphviz.Source(dot)"
]
Expand Down
5 changes: 2 additions & 3 deletions docs/comparison/jpsi2phipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"import matplotlib.pyplot as plt\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.helicity.align.dpd import relabel_edge_ids\n",
"from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n",
"from ampform.sympy import perform_cached_doit\n",
"from IPython.display import Latex, Markdown, clear_output, display\n",
Expand All @@ -67,7 +66,7 @@
"from tensorwaves.data.transform import SympyDataTransformer\n",
"\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder\n",
"from ampform_dpd.adapter.qrules import to_three_body_decay\n",
"from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n",
"from ampform_dpd.decay import Particle\n",
"from ampform_dpd.io import (\n",
" as_markdown_table,\n",
Expand Down Expand Up @@ -122,7 +121,7 @@
" mass_conservation_factor=0,\n",
" formalism=\"helicity\",\n",
")\n",
"REACTION123 = relabel_edge_ids(REACTION)\n",
"REACTION123 = normalize_state_ids(REACTION)\n",
"dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n",
"graphviz.Source(dot)"
]
Expand Down
5 changes: 2 additions & 3 deletions docs/comparison/jpsi2pipipi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"import matplotlib.pyplot as plt\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.helicity.align.dpd import relabel_edge_ids\n",
"from ampform.kinematics.lorentz import FourMomentumSymbol, InvariantMass\n",
"from ampform.sympy import perform_cached_doit\n",
"from IPython.display import Latex, Markdown, clear_output, display\n",
Expand All @@ -67,7 +66,7 @@
"from tensorwaves.data.transform import SympyDataTransformer\n",
"\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder\n",
"from ampform_dpd.adapter.qrules import to_three_body_decay\n",
"from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n",
"from ampform_dpd.decay import Particle\n",
"from ampform_dpd.io import (\n",
" as_markdown_table,\n",
Expand Down Expand Up @@ -122,7 +121,7 @@
" mass_conservation_factor=0,\n",
" formalism=\"helicity\",\n",
")\n",
"REACTION123 = relabel_edge_ids(REACTION)\n",
"REACTION123 = normalize_state_ids(REACTION)\n",
"dot = qrules.io.asdot(REACTION123, collapse_graphs=True)\n",
"graphviz.Source(dot)"
]
Expand Down
14 changes: 2 additions & 12 deletions docs/jpsi2ksp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@
"import qrules\n",
"import sympy as sp\n",
"from ampform.dynamics import EnergyDependentWidth, formulate_form_factor\n",
"from ampform.helicity.align.dpd import relabel_edge_ids\n",
"from ampform.kinematics.phasespace import compute_third_mandelstam\n",
"from ampform.sympy import perform_cached_doit, unevaluated\n",
"from IPython.display import Latex, Markdown\n",
"from tensorwaves.data.transform import SympyDataTransformer\n",
"from tqdm.auto import tqdm\n",
"\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n",
"from ampform_dpd.adapter.qrules import to_three_body_decay\n",
"from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n",
"from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n",
"from ampform_dpd.io import (\n",
" as_markdown_table,\n",
Expand Down Expand Up @@ -101,7 +100,7 @@
" formalism=\"canonical-helicity\",\n",
" mass_conservation_factor=0.05,\n",
")\n",
"REACTION = relabel_edge_ids(REACTION)\n",
"REACTION = normalize_state_ids(REACTION)\n",
"dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n",
"graphviz.Source(dot)"
]
Expand Down Expand Up @@ -300,26 +299,17 @@
" s = _get_mandelstam_s(decay_chain)\n",
" parameter_defaults = {}\n",
" production_ff, new_pars = _create_form_factor(s, production_node)\n",
" _is_messed_up(new_pars)\n",
" parameter_defaults.update(new_pars)\n",
" decay_ff, new_pars = _create_form_factor(s, decay_node)\n",
" _is_messed_up(new_pars)\n",
" parameter_defaults.update(new_pars)\n",
" breit_wigner, new_pars = _create_breit_wigner(s, decay_node)\n",
" _is_messed_up(new_pars)\n",
" parameter_defaults.update(new_pars)\n",
" return (\n",
" production_ff * decay_ff * breit_wigner,\n",
" parameter_defaults,\n",
" )\n",
"\n",
"\n",
"def _is_messed_up(parameter_defaults):\n",
" for key in parameter_defaults:\n",
" if key.name == \"m0\" and not key.is_nonnegative:\n",
" raise ValueError\n",
"\n",
"\n",
"def _create_form_factor(\n",
" s: sp.Symbol, isobar: IsobarNode\n",
") -> tuple[sp.Expr, dict[sp.Symbol, float]]:\n",
Expand Down
9 changes: 6 additions & 3 deletions docs/lc2pkpi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@
"import graphviz\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.helicity.align.dpd import relabel_edge_ids\n",
"from IPython.display import Latex, Markdown\n",
"\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n",
"from ampform_dpd.adapter.qrules import load_particles, to_three_body_decay\n",
"from ampform_dpd.adapter.qrules import (\n",
" load_particles,\n",
" normalize_state_ids,\n",
" to_three_body_decay,\n",
")\n",
"from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n",
"from ampform_dpd.dynamics import BreitWignerMinL\n",
"from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n",
Expand Down Expand Up @@ -90,7 +93,7 @@
"STM.set_allowed_interaction_types([qrules.InteractionType.STRONG], node_id=1)\n",
"problem_sets = STM.create_problem_sets()\n",
"REACTION = STM.find_solutions(problem_sets)\n",
"REACTION = relabel_edge_ids(REACTION)\n",
"REACTION = normalize_state_ids(REACTION)\n",
"dot = qrules.io.asdot(REACTION, collapse_graphs=True)\n",
"graphviz.Source(dot)"
]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"ampform >=0.15.0", # relabel_edge_ids
"ampform >=0.14.8", # Kibble and Kallen functions, perform_cached_doit, @unevaluated
"attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen
"cloudpickle",
"qrules >=0.10.0",
Expand Down
57 changes: 54 additions & 3 deletions src/ampform_dpd/adapter/qrules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from collections import defaultdict
from collections import abc, defaultdict
from functools import singledispatch
from pathlib import Path
from typing import Any, Iterable
from typing import Any, Iterable, TypeVar, overload

import attrs
import qrules
from qrules.quantum_numbers import InteractionProperties
from qrules.topology import EdgeType, FrozenTransition, NodeType
from qrules.transition import State
from qrules.transition import ReactionInfo, State, StateTransition, Topology

from ampform_dpd.decay import (
IsobarNode,
Expand Down Expand Up @@ -149,3 +151,52 @@ def load_particles() -> qrules.particle.ParticleCollection:
additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") # type:ignore[arg-type]
particle_database.update(additional_definitions) # type:ignore[arg-type]
return particle_database


@overload
def normalize_state_ids(obj: T) -> T: ...
@overload
def normalize_state_ids(obj: Iterable[T]) -> list[T]: ...
def normalize_state_ids(obj): # pyright:ignore[reportInconsistentOverload]
"""Relabel the state IDs so that they lie in the range :math:`[0, N)`."""
return _impl_normalize_state_ids(obj)


@singledispatch
def _impl_normalize_state_ids(obj):
"""Relabel the state IDs so that they lie in the range :math:`[0, N)`."""
msg = f"Cannot relabel edge IDs of a {type(obj).__name__}"
raise NotImplementedError(msg)


@_impl_normalize_state_ids.register(ReactionInfo) # type:ignore[attr-defined]
def _(obj: ReactionInfo) -> ReactionInfo:
return ReactionInfo(
# no attrs.evolve() in order to call __attrs_post_init__()
transitions=[_impl_normalize_state_ids(g) for g in obj.transitions],
formalism=obj.formalism,
)


@_impl_normalize_state_ids.register(FrozenTransition) # type:ignore[attr-defined]
def _(obj: StateTransition) -> StateTransition:
return attrs.evolve(
obj,
topology=_impl_normalize_state_ids(obj.topology),
states={new: obj.states[old] for new, old in enumerate(sorted(obj.states))},
)


@_impl_normalize_state_ids.register(Topology) # type:ignore[attr-defined]
def _(obj: Topology) -> Topology:
mapping = {old: new for new, old in enumerate(sorted(obj.edges))}
return obj.relabel_edges(mapping)


@_impl_normalize_state_ids.register(abc.Iterable) # type:ignore[attr-defined]
def _(obj: abc.Iterable[T]) -> list[T]:
return [_impl_normalize_state_ids(x) for x in obj]


T = TypeVar("T", ReactionInfo, StateTransition, Topology)
"""Type variable for the input and output of :func:`normalize_state_ids`."""
27 changes: 26 additions & 1 deletion tests/adapter/test_qrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest
import qrules

from ampform_dpd.adapter.qrules import filter_min_ls, to_three_body_decay
from ampform_dpd.adapter.qrules import (
filter_min_ls,
normalize_state_ids,
to_three_body_decay,
)
from ampform_dpd.decay import LSCoupling, Particle

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +54,27 @@ def test_filter_min_ls(reaction: ReactionInfo):
]


def test_normalize_state_ids_reaction(reaction: ReactionInfo):
reaction012 = reaction
reaction123 = normalize_state_ids(reaction012)
assert set(reaction123.initial_state) == {0}
assert set(reaction123.final_state) == {1, 2, 3}

transitions123 = normalize_state_ids(reaction012.transitions)
for transition012, transition123 in zip(reaction012.transitions, transitions123):
assert set(transition123.initial_states) == {0}
assert set(transition123.final_states) == {1, 2, 3}
assert set(transition123.intermediate_states) == {4}

topology123 = normalize_state_ids(transition123.topology)
assert topology123.incoming_edge_ids == {0}
assert topology123.outgoing_edge_ids == {1, 2, 3}
assert topology123.intermediate_edge_ids == {4}

for i in transition012.states:
assert transition012.states[i] == transition123.states[i + 1]


@pytest.mark.parametrize("min_ls", [False, True])
def test_to_three_body_decay(reaction: ReactionInfo, min_ls: bool):
decay = to_three_body_decay(reaction.transitions, min_ls)
Expand Down

0 comments on commit 80e3e6d

Please sign in to comment.