Skip to content

Commit

Permalink
FEAT: implement AmplitudeModel.masses and .invariants (#96)
Browse files Browse the repository at this point in the history
* BREAK: move `simplify_latex_rendering()` to `io` module
* BREAK: remove `formulate_polarimetry()` function
* DOC: import docstring from ComPWA/polarimetry
* MAINT: move implementation functions closer to usage
  • Loading branch information
redeboer authored Mar 17, 2024
1 parent 0eafbba commit 31d573e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 123 deletions.
210 changes: 93 additions & 117 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# cspell:ignore msigma
"""Module for formulating the amplitude model for a three-body decay using DPD."""

from __future__ import annotations

from functools import lru_cache
from itertools import product
from typing import Literal, Protocol

import sympy as sp
from ampform.kinematics.phasespace import compute_third_mandelstam
from ampform.sympy import PoolSum
from attrs import field, frozen
from sympy.core.symbol import Str
from sympy.physics.matrices import msigma
from sympy.physics.quantum.spin import CG, WignerD
from sympy.physics.quantum.spin import Rotation as Wigner

Expand All @@ -22,6 +23,9 @@
ThreeBodyDecayChain,
get_decay_product_ids,
)
from ampform_dpd.io import (
simplify_latex_rendering, # noqa: F401 # pyright:ignore[reportUnusedImport]
)
from ampform_dpd.spin import create_spin_range


Expand All @@ -32,6 +36,8 @@ class AmplitudeModel:
amplitudes: dict[sp.Indexed, sp.Expr] = field(factory=dict)
variables: dict[sp.Symbol, sp.Expr] = field(factory=dict)
parameter_defaults: dict[sp.Symbol, float] = field(factory=dict)
masses: dict[sp.Symbol, float] = field(factory=dict)
invariants: dict[sp.Symbol, float] = field(factory=dict)

@property
def full_expression(self) -> sp.Expr:
Expand Down Expand Up @@ -91,13 +97,7 @@ def formulate(
*helicity_symbols, reference_subsystem
)
angle_definitions.update(zeta_defs)
m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
masses = {
m0: self.decay.states[0].mass,
m1: self.decay.states[1].mass,
m2: self.decay.states[2].mass,
m3: self.decay.states[3].mass,
}
masses = create_mass_symbol_mapping(self.decay)
parameter_defaults.update(masses)
if cleanup_summations:
aligned_amp = aligned_amp.cleanup()
Expand All @@ -116,6 +116,8 @@ def formulate(
amplitudes=amplitude_definitions,
variables=angle_definitions,
parameter_defaults=parameter_defaults,
masses=masses,
invariants=formulate_invariants(self.decay),
)

def formulate_subsystem_amplitude( # noqa: PLR0914
Expand Down Expand Up @@ -256,6 +258,68 @@ def formulate_aligned_amplitude(
return amp_expr, wigner_generator.angle_definitions


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
helicities: tuple[sp.Basic, sp.Basic],
interaction: LSCoupling,
typ: Literal["production", "decay"],
) -> sp.Indexed:
H = _get_coupling_base(helicity_coupling, typ)
if helicity_coupling:
λi, λj = helicities
return H[resonance, λi, λj]
return H[resonance, interaction.L, interaction.S]


@lru_cache(maxsize=None)
def _get_coupling_base(
helicity_coupling: bool, typ: Literal["production", "decay"]
) -> sp.IndexedBase:
if helicity_coupling:
return sp.IndexedBase(Rf"\mathcal{{H}}^\mathrm{{{typ}}}")
return sp.IndexedBase(Rf"\mathcal{{H}}^\mathrm{{LS,{typ}}}")


def _formulate_clebsch_gordan_factors(
isobar: IsobarNode,
helicities: dict[Particle, sp.Rational | sp.Symbol],
) -> sp.Expr:
if isobar.interaction is None:
msg = "Cannot formulate amplitude model in LS-basis if LS-couplings are missing"
raise ValueError(msg)
# https://github.com/ComPWA/ampform/blob/65b4efa/src/ampform/helicity/__init__.py#L785-L802
# and supplementary material p.1 (https://cds.cern.ch/record/2824328/files)
child1 = _get_particle(isobar.child1)
child2 = _get_particle(isobar.child2)
child1_helicity = helicities[child1]
child2_helicity = helicities[child2]
cg_ss = CG(
j1=child1.spin,
m1=child1_helicity,
j2=child2.spin,
m2=-child2_helicity,
j3=isobar.interaction.S,
m3=child1_helicity - child2_helicity,
)
cg_ll = CG(
j1=isobar.interaction.L,
m1=0,
j2=isobar.interaction.S,
m2=child1_helicity - child2_helicity,
j3=isobar.parent.spin,
m3=child1_helicity - child2_helicity,
)
sqrt_factor = sp.sqrt((2 * isobar.interaction.L + 1) / (2 * isobar.parent.spin + 1))
return sqrt_factor * cg_ll * cg_ss


def _get_particle(isobar: IsobarNode | Particle) -> Particle:
if isinstance(isobar, IsobarNode):
return isobar.parent
return isobar


@lru_cache(maxsize=None)
def _generate_amplitude_index_bases() -> dict[Literal[1, 2, 3], sp.IndexedBase]:
return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1))
Expand Down Expand Up @@ -325,116 +389,28 @@ def formulate_non_resonant(
return sp.Rational(1), {}


def simplify_latex_rendering() -> None:
"""Improve LaTeX rendering of an `~sympy.tensor.indexed.Indexed` object."""

def _print_Indexed_latex(self, printer, *args): # noqa: N802
base = printer._print(self.base)
indices = ", ".join(map(printer._print, self.indices))
return f"{base}_{{{indices}}}"

sp.Indexed._latex = _print_Indexed_latex


def _formulate_clebsch_gordan_factors(
isobar: IsobarNode,
helicities: dict[Particle, sp.Rational | sp.Symbol],
) -> sp.Expr:
if isobar.interaction is None:
msg = "Cannot formulate amplitude model in LS-basis if LS-couplings are missing"
raise ValueError(msg)
# https://github.com/ComPWA/ampform/blob/65b4efa/src/ampform/helicity/__init__.py#L785-L802
# and supplementary material p.1 (https://cds.cern.ch/record/2824328/files)
child1 = _get_particle(isobar.child1)
child2 = _get_particle(isobar.child2)
child1_helicity = helicities[child1]
child2_helicity = helicities[child2]
cg_ss = CG(
j1=child1.spin,
m1=child1_helicity,
j2=child2.spin,
m2=-child2_helicity,
j3=isobar.interaction.S,
m3=child1_helicity - child2_helicity,
)
cg_ll = CG(
j1=isobar.interaction.L,
m1=0,
j2=isobar.interaction.S,
m2=child1_helicity - child2_helicity,
j3=isobar.parent.spin,
m3=child1_helicity - child2_helicity,
)
sqrt_factor = sp.sqrt(
(2 * isobar.interaction.L + 1) / (2 * isobar.parent.spin + 1),
evaluate=False,
)
return sqrt_factor * cg_ll * cg_ss


def _get_particle(isobar: IsobarNode | Particle) -> Particle:
if isinstance(isobar, IsobarNode):
return isobar.parent
return isobar


def formulate_polarimetry(
builder: DalitzPlotDecompositionBuilder, reference_subsystem: Literal[1, 2, 3] = 1
) -> tuple[PoolSum, PoolSum, PoolSum]:
half = sp.Rational(1, 2)
if builder.decay.initial_state.spin != half:
msg = (
"Can only formulate polarimetry for an initial state with spin 1/2, but"
f" got {builder.decay.initial_state.spin}"
)
raise ValueError(msg)
model = builder.formulate(reference_subsystem)
λ0, λ0_prime = sp.symbols(R"lambda \lambda^{\prime}", rational=True)
λ = {
sp.Symbol(f"lambda{i}", rational=True): create_spin_range(state.spin)
for i, state in builder.decay.final_state.items()
def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]:
return {
sp.Symbol(f"m{i}"): decay.states[i].mass
for i in sorted(decay.states) # ensure that dict keys are sorted by state ID
}
ref = reference_subsystem
return tuple(
PoolSum(
builder.formulate_aligned_amplitude(λ0, *λ, ref)[0].conjugate()
* pauli_matrix[_to_index(λ0), _to_index(λ0_prime)]
* builder.formulate_aligned_amplitude(λ0_prime, *λ, ref)[0],
(λ0, [-half, +half]),
(λ0_prime, [-half, +half]),
*λ.items(),
).cleanup()
/ model.intensity
for pauli_matrix in map(msigma, [1, 2, 3])
)


def _to_index(helicity):
"""Symbolic conversion of half-value helicities to Pauli matrix indices."""
return sp.Piecewise(
(1, sp.LessThan(helicity, 0)),
(0, True),
)


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
helicities: tuple[sp.Basic, sp.Basic],
interaction: LSCoupling,
typ: Literal["production", "decay"],
) -> sp.Indexed:
H = _get_coupling_base(helicity_coupling, typ)
if helicity_coupling:
λi, λj = helicities
return H[resonance, λi, λj]
return H[resonance, interaction.L, interaction.S]
def formulate_invariants(decay: ThreeBodyDecay) -> dict[sp.Symbol, sp.Expr]:
s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
return {
s1: formulate_third_mandelstam(decay, 2, 3),
s2: formulate_third_mandelstam(decay, 3, 1),
s3: formulate_third_mandelstam(decay, 1, 2),
}


@lru_cache(maxsize=None)
def _get_coupling_base(
helicity_coupling: bool, typ: Literal["production", "decay"]
) -> sp.IndexedBase:
if helicity_coupling:
return sp.IndexedBase(Rf"\mathcal{{H}}^\mathrm{{{typ}}}")
return sp.IndexedBase(Rf"\mathcal{{H}}^\mathrm{{LS,{typ}}}")
def formulate_third_mandelstam(
decay: ThreeBodyDecay,
x_mandelstam: Literal[1, 2, 3] = 1,
y_mandelstam: Literal[1, 2, 3] = 2,
) -> sp.Add:
m0, m1, m2, m3 = create_mass_symbol_mapping(decay)
sigma_x = sp.Symbol(f"sigma{x_mandelstam}", nonnegative=True)
sigma_y = sp.Symbol(f"sigma{y_mandelstam}", nonnegative=True)
return compute_third_mandelstam(sigma_x, sigma_y, m0, m1, m2, m3)
9 changes: 7 additions & 2 deletions src/ampform_dpd/angles.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Formulate expressions for scattering and alignment angles."""

from __future__ import annotations

import sympy as sp
Expand All @@ -17,8 +19,11 @@ def formulate_scattering_angle(
if not {state_id, sibling_id} <= {1, 2, 3}:
msg = "Child IDs need to be one of 1, 2, 3"
raise ValueError(msg)
# pyright: ignore[reportUnnecessaryContains]
if {state_id, sibling_id} in {(2, 1), (3, 2), (1, 3)}:
if {state_id, sibling_id} in { # pyright: ignore[reportUnnecessaryContains]
(2, 1),
(3, 2),
(1, 3),
}:
msg = f"Cannot compute scattering angle θ{state_id}{sibling_id}"
raise NotImplementedError(msg)
if state_id == sibling_id:
Expand Down
11 changes: 11 additions & 0 deletions src/ampform_dpd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,14 @@ def _warn_about_unsafe_hash():
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)


def simplify_latex_rendering() -> None:
"""Improve LaTeX rendering of an `~sympy.tensor.indexed.Indexed` object."""

def _print_Indexed_latex(self, printer, *args): # noqa: N802
base = printer._print(self.base)
indices = ", ".join(map(printer._print, self.indices))
return f"{base}_{{{indices}}}"

sp.Indexed._latex = _print_Indexed_latex
14 changes: 10 additions & 4 deletions src/ampform_dpd/spin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Functions for generating spin projections and LS couplings."""

from __future__ import annotations

from decimal import Decimal
Expand All @@ -12,7 +14,8 @@ def generate_ls_couplings(
child2_spin: SupportsFloat,
max_L: int = 3, # noqa: N803
) -> list[tuple[int, sp.Rational]]:
r"""
"""Generate a list of allowed LS couplings.
>>> generate_ls_couplings(1.5, 0.5, 0)
[(1, 1/2), (2, 1/2)]
"""
Expand All @@ -35,7 +38,8 @@ def filter_parity_violating_ls(
child1_parity: SupportsInt,
child2_parity: SupportsInt,
) -> list[tuple[int, sp.Rational]]:
r"""
"""Filter parity-violating LS combinations from a list of LS couplings.
>>> LS = generate_ls_couplings(0.5, 1.5, 0) # Λc → Λ(1520)π
>>> LS
[(1, 3/2), (2, 3/2)]
Expand All @@ -51,7 +55,8 @@ def filter_parity_violating_ls(


def create_spin_range(spin: SupportsFloat) -> list[sp.Rational]:
"""
"""Create a range of allowed spin projections.
>>> create_spin_range(1.5)
[-3/2, -1/2, 1/2, 3/2]
"""
Expand All @@ -61,7 +66,8 @@ def create_spin_range(spin: SupportsFloat) -> list[sp.Rational]:
def create_rational_range(
__from: SupportsFloat, __to: SupportsFloat
) -> list[sp.Rational]:
"""
"""Create a range of rational numbers, especially useful for spin projections.
>>> create_rational_range(-0.5, +1.5)
[-1/2, 1/2, 3/2]
"""
Expand Down

0 comments on commit 31d573e

Please sign in to comment.