Skip to content

Commit

Permalink
MAINT: switch to unevaluated decorator in implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Jan 11, 2024
1 parent 436a7c7 commit c4e1db3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 129 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
"pycache",
"pygments",
"redeboer",
"repr",
"savefig",
"sdist",
"seealso",
Expand Down
55 changes: 26 additions & 29 deletions docs/jpsi2ksp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,15 @@
"import itertools\n",
"import logging\n",
"import os\n",
"from typing import TYPE_CHECKING, Iterable\n",
"from typing import TYPE_CHECKING, Any, Iterable\n",
"\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.dynamics import EnergyDependentWidth, formulate_form_factor\n",
"from ampform.kinematics.phasespace import compute_third_mandelstam\n",
"from ampform.sympy import (\n",
" UnevaluatedExpression,\n",
" create_expression,\n",
" implement_doit_method,\n",
" make_commutative,\n",
" perform_cached_doit,\n",
")\n",
"from ampform.sympy import perform_cached_doit, unevaluated\n",
"from IPython.display import Latex, Markdown\n",
"from tensorwaves.data.transform import SympyDataTransformer\n",
"from tqdm.auto import tqdm\n",
Expand Down Expand Up @@ -279,13 +273,18 @@
},
"outputs": [],
"source": [
"@make_commutative\n",
"@implement_doit_method\n",
"class RelativisticBreitWigner(UnevaluatedExpression):\n",
" def __new__(cls, s, mass0, gamma0, m1, m2, angular_momentum, meson_radius):\n",
" return create_expression(\n",
" cls, s, mass0, gamma0, m1, m2, angular_momentum, meson_radius\n",
" )\n",
"@unevaluated\n",
"class RelativisticBreitWigner(sp.Expr):\n",
" s: Any\n",
" mass0: Any\n",
" gamma0: Any\n",
" m1: Any\n",
" m2: Any\n",
" angular_momentum: Any\n",
" meson_radius: Any\n",
" _latex_repr_ = (\n",
" R\"\\mathcal{{R}}_{{{angular_momentum}}}\\left({s}, {mass0}, {gamma0}\\right)\"\n",
" )\n",
"\n",
" def evaluate(self):\n",
" s, m0, w0, m1, m2, angular_momentum, meson_radius = self.args\n",
Expand All @@ -301,10 +300,6 @@
" )\n",
" return (m0 * w0) / (m0**2 - s - width * m0 * sp.I)\n",
"\n",
" def _latex(self, printer, *args) -> str:\n",
" s, m0, w0, _, _, L, *_ = map(printer._print, self.args)\n",
" return Rf\"\\mathcal{{R}}_{{{L}}}\\left({s}, {m0}, {w0}\\right)\"\n",
"\n",
"\n",
"bw = RelativisticBreitWigner(*sp.symbols(\"s m0 Gamma0 m1 m2 L R\"))\n",
"Latex(aslatex({bw: bw.doit(deep=False)}))"
Expand All @@ -330,11 +325,17 @@
},
"outputs": [],
"source": [
"@make_commutative\n",
"@implement_doit_method\n",
"class FormFactor(UnevaluatedExpression):\n",
" def __new__(cls, s, m1, m2, angular_momentum, meson_radius):\n",
" return create_expression(cls, s, m1, m2, angular_momentum, meson_radius)\n",
"@unevaluated\n",
"class FormFactor(sp.Expr):\n",
" s: Any\n",
" m1: Any\n",
" m2: Any\n",
" angular_momentum: Any\n",
" meson_radius: Any\n",
"\n",
" _latex_repr_ = (\n",
" R\"\\mathcal{{F}}_{{{angular_momentum}}}\\left({s}, {m1}, {m1}\\right)\"\n",
" )\n",
"\n",
" def evaluate(self):\n",
" s, m1, m2, angular_momentum, meson_radius = self.args\n",
Expand All @@ -346,10 +347,6 @@
" meson_radius=meson_radius,\n",
" )\n",
"\n",
" def _latex(self, printer, *args) -> str:\n",
" s, m1, m2, L, *_ = map(printer._print, self.args)\n",
" return Rf\"\\mathcal{{F}}_{{{L}}}\\left({s}, {m1}, {m1}\\right)\"\n",
"\n",
"\n",
"ff = FormFactor(*sp.symbols(\"s m1 m2 L R\"))\n",
"Latex(aslatex({ff: ff.doit(deep=False)}))"
Expand Down Expand Up @@ -973,7 +970,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.8.18"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"ampform ~=0.14.4", # Kibble and Kallen functions, perform_cached_doit
"ampform >=0.14.8", # Kibble and Kallen functions, perform_cached_doit, @unevaluated
"attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen
"cloudpickle",
"sympy >=1.10", # module sympy.printing.numpy and array expressions with shape kwarg
Expand Down
161 changes: 62 additions & 99 deletions src/ampform_dpd/dynamics.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,52 @@
"""Functions for dynamics lineshapes and kinematics."""

# pyright: reportPrivateUsage=false
from __future__ import annotations

from typing import Any

import sympy as sp
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import (
UnevaluatedExpression,
create_expression,
implement_doit_method,
make_commutative,
)
from ampform.sympy import unevaluated


@make_commutative
@implement_doit_method
class P(UnevaluatedExpression):
def __new__(cls, s, mi, mj, **hints):
return create_expression(cls, s, mi, mj, **hints)
@unevaluated
class P(sp.Expr):
s: Any
mi: Any
mj: Any
_latex_repr_ = R"p_{{{mi},{mj}}}\left({s}\right)"

def evaluate(self):
s, mi, mj = self.args
return sp.sqrt(Kallen(s, mi**2, mj**2)) / (2 * sp.sqrt(s))

def _latex(self, printer, *args):
s, mi, mj = map(printer._print, self.args)
return Rf"p_{{{mi},{mj}}}\left({s}\right)"


@make_commutative
@implement_doit_method
class Q(UnevaluatedExpression):
def __new__(cls, s, m0, mk, **hints):
return create_expression(cls, s, m0, mk, **hints)
@unevaluated
class Q(sp.Expr):
s: Any
m0: Any
mk: Any
_latex_repr_ = R"q_{{{m0},{mk}}}\left({s}\right)"

def evaluate(self):
s, m0, mk = self.args
return sp.sqrt(Kallen(s, m0**2, mk**2)) / (2 * m0) # <-- not s!

def _latex(self, printer, *args):
s, m0, mk = map(printer._print, self.args)
return Rf"q_{{{m0},{mk}}}\left({s}\right)"


@make_commutative
@implement_doit_method
class BreitWignerMinL(UnevaluatedExpression):
def __new__(
cls,
s,
decaying_mass,
spectator_mass,
resonance_mass,
resonance_width,
child2_mass,
child1_mass,
l_dec,
l_prod,
R_dec,
R_prod,
):
return create_expression(
cls,
s,
decaying_mass,
spectator_mass,
resonance_mass,
resonance_width,
child2_mass,
child1_mass,
l_dec,
l_prod,
R_dec,
R_prod,
)

@unevaluated
class BreitWignerMinL(sp.Expr):
s: Any
decaying_mass: Any
spectator_mass: Any
resonance_mass: Any
resonance_width: Any
child2_mass: Any
child1_mass: Any
l_dec: Any
l_prod: Any
R_dec: Any
R_prod: Any
_latex_repr_ = R"\mathcal{{R}}^\mathrm{{BW}}_{{{l_dec},{l_prod}}}\left({s}\right)"

def evaluate(self):
s, m_top, m_spec, m0, Γ0, m1, m2, l_dec, l_prod, R_dec, R_prod = self.args
Expand All @@ -91,18 +64,16 @@ def evaluate(self):
evaluate=False,
)

def _latex(self, printer, *args) -> str:
s = printer._print(self.args[0])
l_dec = printer._print(self.args[7])
l_prod = printer._print(self.args[8])
return Rf"\mathcal{{R}}^\mathrm{{BW}}_{{{l_dec},{l_prod}}}\left({s}\right)"


@make_commutative
@implement_doit_method
class BuggBreitWigner(UnevaluatedExpression):
def __new__(cls, s, m0, Γ0, m1, m2, γ):
return create_expression(cls, s, m0, Γ0, m1, m2, γ)
@unevaluated
class BuggBreitWigner(sp.Expr):
s: Any
m0: Any
Γ0: Any
m1: Any
m2: Any
γ: Any
_latex_repr_ = R"\mathcal{{R}}^\mathrm{{Bugg}}\left({s}\right)"

def evaluate(self):
s, m0, Γ0, m1, m2, γ = self.args
Expand All @@ -114,17 +85,16 @@ def evaluate(self):
)
return 1 / (m0**2 - s - sp.I * g_squared)

def _latex(self, printer, *args) -> str:
s = printer._print(self.args[0], *args)
return Rf"\mathcal{{R}}^\mathrm{{Bugg}}\left({s}\right)"


@make_commutative
@implement_doit_method
class FlattéSWave(UnevaluatedExpression):
@unevaluated
class FlattéSWave(sp.Expr):
# https://github.com/ComPWA/polarimetry/blob/34f5330/julia/notebooks/model0.jl#L151-L161
def __new__(cls, s, m0, widths, masses1, masses2):
return create_expression(cls, s, m0, widths, masses1, masses2)
s: Any
m0: Any
widths: tuple[Any, Any]
masses1: tuple[Any, Any]
masses2: tuple[Any, Any]
_latex_repr_ = R"\mathcal{{R}}^\mathrm{{Flatté}}\left({s}\right)"

def evaluate(self):
s, m0, (Γ1, Γ2), (ma1, mb1), (ma2, mb2) = self.args
Expand All @@ -137,16 +107,17 @@ def evaluate(self):
Γ = Γ1 + Γ2
return 1 / (m0**2 - s - sp.I * m0 * Γ)

def _latex(self, printer, *args) -> str:
s = printer._print(self.args[0])
return Rf"\mathcal{{R}}^\mathrm{{Flatté}}\left({s}\right)"


@make_commutative
@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
def __new__(cls, s, m0, Γ0, m1, m2, L, R):
return create_expression(cls, s, m0, Γ0, m1, m2, L, R)
@unevaluated
class EnergyDependentWidth(sp.Expr):
s: Any
m0: Any
Γ0: Any
m1: Any
m2: Any
L: Any
R: Any
_latex_repr_ = R"\Gamma\left({s}\right)"

def evaluate(self):
s, m0, Γ0, m1, m2, L, R = self.args
Expand All @@ -162,16 +133,12 @@ def evaluate(self):
evaluate=False,
)

def _latex(self, printer, *args) -> str:
s = printer._print(self.args[0])
return Rf"\Gamma\left({s}\right)"


@make_commutative
@implement_doit_method
class BlattWeisskopf(UnevaluatedExpression):
def __new__(cls, z, L, **hints):
return create_expression(cls, z, L, **hints)
@unevaluated
class BlattWeisskopf(sp.Expr):
z: Any
L: Any
_latex_repr_ = R"F_{{{L}}}\left({z}\right)"

def evaluate(self) -> sp.Piecewise:
z, L = self.args
Expand All @@ -183,7 +150,3 @@ def evaluate(self) -> sp.Piecewise:
return sp.Piecewise(

Check warning on line 150 in src/ampform_dpd/dynamics.py

View check run for this annotation

Codecov / codecov/patch

src/ampform_dpd/dynamics.py#L150

Added line #L150 was not covered by tests
*[(sp.sqrt(expr), sp.Eq(L, l_val)) for l_val, expr in cases.items()]
)

def _latex(self, printer, *args):
z, L = map(printer._print, self.args)
return Rf"F_{{{L}}}\left({z}\right)"

0 comments on commit c4e1db3

Please sign in to comment.