diff --git a/docs/conf.py b/docs/conf.py index fd54522f..009e4d43 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -36,10 +36,14 @@ add_module_names = False api_github_repo = f"{ORGANIZATION}/{REPO_NAME}" api_target_substitutions: dict[str, str | tuple[str, str]] = { + "ampform_dpd.decay.StateIDTemplate": ("obj", "ampform_dpd.decay.StateID"), + "DecayNode": ("obj", "ampform_dpd.decay.DecayNode"), + "FinalState": ("obj", "ampform_dpd.decay.FinalState"), + "FinalStateID": ("obj", "ampform_dpd.decay.FinalStateID"), "FrozenTransition": "qrules.topology.FrozenTransition", + "InitialStateID": ("obj", "ampform_dpd.decay.InitialStateID"), "Literal[-1, 1]": "typing.Literal", "Literal[(-1, 1)]": "typing.Literal", - "OuterStates": ("obj", "ampform_dpd.decay.OuterStates"), "ParameterValue": ("obj", "tensorwaves.interface.ParameterValue"), "ParametrizedBackendFunction": "tensorwaves.function.ParametrizedBackendFunction", "PoolSum": "ampform.sympy.PoolSum", @@ -51,6 +55,8 @@ "sp.Indexed": "sympy.tensor.indexed.Indexed", "sp.Rational": "sympy.core.numbers.Rational", "sp.Symbol": "sympy.core.symbol.Symbol", + "StateID": ("obj", "ampform_dpd.decay.StateID"), + "StateIDTemplate": ("obj", "ampform_dpd.decay.StateID"), } api_target_types: dict[str, str] = {} author = "Common Partial Wave Analysis" diff --git a/docs/jpsi2ksp.ipynb b/docs/jpsi2ksp.ipynb index 067cd192..a968a98a 100644 --- a/docs/jpsi2ksp.ipynb +++ b/docs/jpsi2ksp.ipynb @@ -44,10 +44,11 @@ "from tensorwaves.data.transform import SympyDataTransformer\n", "from tqdm.auto import tqdm\n", "\n", - "from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder\n", "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", - "from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n", + "from ampform_dpd.decay import State\n", "from ampform_dpd.dynamics import FormFactor, RelativisticBreitWigner\n", + "from ampform_dpd.dynamics.builder import formulate_breit_wigner_with_form_factor\n", "from ampform_dpd.io import (\n", " as_markdown_table,\n", " aslatex,\n", @@ -69,7 +70,8 @@ { "cell_type": "markdown", "metadata": { - "jp-MarkdownHeadingCollapsed": true + "jp-MarkdownHeadingCollapsed": true, + "tags": [] }, "source": [ "## Decay definition" @@ -158,173 +160,6 @@ "Latex(aslatex(DECAY, with_jp=True))" ] }, - { - "cell_type": "markdown", - "metadata": { - "jp-MarkdownHeadingCollapsed": true - }, - "source": [ - "## Lineshapes for dynamics" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - ":::{note}\n", - "As opposed to [AmpForm](https://ampform.rtfd.io), AmpForm-DPD defines dynamics over the **entire decay chain**, not a single isobar node. 
The dynamics classes and the corresponding builders would have to be extended to implement other dynamics lineshapes.\n", - ":::" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-input", - "full-width" - ] - }, - "outputs": [], - "source": [ - "s, m0, w0, m1, m2, L, R, z = sp.symbols(\"s m0 Gamma0 m1 m2 L R z\")\n", - "exprs = [\n", - " RelativisticBreitWigner(s, m0, w0, m1, m2, L, R),\n", - " EnergyDependentWidth(s, m0, w0, m1, m2, L, R),\n", - " FormFactor(s, m1, m2, L, R),\n", - " BlattWeisskopfSquared(z, L),\n", - "]\n", - "Latex(aslatex({e: e.doit(deep=False) for e in exprs}))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "mystnb": { - "code_prompt_show": "Define dynamics builder functions" - }, - "tags": [ - "hide-input", - "scroll-input" - ] - }, - "outputs": [], - "source": [ - "def formulate_breit_wigner_with_ff(\n", - " decay_chain: ThreeBodyDecayChain,\n", - ") -> tuple[sp.Expr, dict[sp.Symbol, float]]:\n", - " production_node = decay_chain.decay\n", - " assert isinstance(production_node.child1, IsobarNode), \"Not a 3-body isobar decay\"\n", - " decay_node = production_node.child1\n", - " s = _get_mandelstam_s(decay_chain)\n", - " parameter_defaults = {}\n", - " production_ff, new_pars = _create_form_factor(s, production_node)\n", - " parameter_defaults.update(new_pars)\n", - " decay_ff, new_pars = _create_form_factor(s, decay_node)\n", - " parameter_defaults.update(new_pars)\n", - " breit_wigner, new_pars = _create_breit_wigner(s, decay_node)\n", - " parameter_defaults.update(new_pars)\n", - " return (\n", - " production_ff * decay_ff * breit_wigner,\n", - " parameter_defaults,\n", - " )\n", - "\n", - "\n", - "def _create_form_factor(\n", - " s: sp.Symbol, isobar: IsobarNode\n", - ") -> tuple[sp.Expr, dict[sp.Symbol, float]]:\n", - " assert isobar.interaction is not None, \"Need LS-couplings\"\n", - " if isobar.parent.name == \"J/psi(1S)\":\n", - " inv_mass = sp.Symbol(\"m0\", nonnegative=True)\n", - " else:\n", - " inv_mass = _get_mandelstam_s(isobar)\n", - " outgoing_state_mass1 = _create_mass_symbol(isobar.child1)\n", - " outgoing_state_mass2 = _create_mass_symbol(isobar.child2)\n", - " meson_radius = _create_meson_radius_symbol(isobar.parent)\n", - " form_factor = FormFactor(\n", - " s=inv_mass**2,\n", - " m1=outgoing_state_mass1,\n", - " m2=outgoing_state_mass2,\n", - " angular_momentum=isobar.interaction.L,\n", - " meson_radius=meson_radius,\n", - " )\n", - " parameter_defaults = {\n", - " meson_radius: 1,\n", - " outgoing_state_mass1: get_particle(isobar.child1).mass,\n", - " outgoing_state_mass2: get_particle(isobar.child2).mass,\n", - " }\n", - " if not inv_mass.name.startswith(\"s\"):\n", - " parameter_defaults[inv_mass] = get_particle(isobar).mass\n", - " return form_factor, parameter_defaults\n", - "\n", - "\n", - "def _create_breit_wigner(\n", - " s: sp.Symbol, isobar: IsobarNode\n", - ") -> tuple[sp.Expr, dict[sp.Symbol, float]]:\n", - " assert isobar.interaction is not None, \"Need LS-couplings\"\n", - " outgoing_state_mass1 = _create_mass_symbol(isobar.child1)\n", - " outgoing_state_mass2 = _create_mass_symbol(isobar.child2)\n", - " angular_momentum = isobar.interaction.L\n", - " res_mass = _create_mass_symbol(isobar.parent)\n", - " res_width = sp.Symbol(Rf\"\\Gamma_{{{isobar.parent.latex}}}\", nonnegative=True)\n", - " meson_radius = _create_meson_radius_symbol(isobar.parent)\n", - "\n", - 
" breit_wigner_expr = RelativisticBreitWigner(\n", - " s=s,\n", - " mass0=res_mass,\n", - " gamma0=res_width,\n", - " m1=outgoing_state_mass1,\n", - " m2=outgoing_state_mass2,\n", - " angular_momentum=angular_momentum,\n", - " meson_radius=meson_radius,\n", - " )\n", - " parameter_defaults = {\n", - " res_mass: isobar.parent.mass,\n", - " res_width: isobar.parent.width,\n", - " meson_radius: 1,\n", - " }\n", - " return breit_wigner_expr, parameter_defaults\n", - "\n", - "\n", - "def _create_meson_radius_symbol(isobar: IsobarNode) -> sp.Symbol:\n", - " if get_particle(isobar).name == \"J/psi(1S)\":\n", - " return sp.Symbol(R\"R_{J/\\psi}\")\n", - " return sp.Symbol(R\"R_\\mathrm{res}\")\n", - "\n", - "\n", - "def _create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol:\n", - " particle = get_particle(particle)\n", - " return sp.Symbol(f\"m_{{{particle.latex}}}\", nonnegative=True)\n", - "\n", - "\n", - "def _get_mandelstam_s(decay: ThreeBodyDecayChain | IsobarNode) -> sp.Symbol:\n", - " s1, s2, s3 = sp.symbols(\"sigma1:4\", nonnegative=True)\n", - " decay_products = {p.name for p in _get_decay_products(decay)}\n", - " if decay_products == {\"Sigma+\", \"p~\"}:\n", - " return s1\n", - " if decay_products == {\"K0\", \"p~\"}:\n", - " return s2\n", - " if decay_products == {\"K0\", \"Sigma+\"}:\n", - " return s3\n", - " msg = f\"Cannot find Mandelstam variable for {', '.join(decay_products)}\"\n", - " raise NotImplementedError(msg)\n", - "\n", - "\n", - "def _get_decay_products(\n", - " decay: ThreeBodyDecayChain | IsobarNode,\n", - ") -> tuple[Particle, Particle]:\n", - " if isinstance(decay, ThreeBodyDecayChain):\n", - " return decay.decay_products\n", - " return decay.children" - ] - }, { "cell_type": "markdown", "metadata": { @@ -354,19 +189,12 @@ "model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=False)\n", "for chain in model_builder.decay.chains:\n", " model_builder.dynamics_choices.register_builder(\n", - " chain, formulate_breit_wigner_with_ff\n", + " chain, formulate_breit_wigner_with_form_factor\n", " )\n", "model = model_builder.formulate(reference_subsystem=1)\n", "model.intensity" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "where the angles can be computed from initial and final state masses $m_0$, $m_1$, $m_2$, and $m_3$:" - ] - }, { "cell_type": "code", "execution_count": null, @@ -405,67 +233,37 @@ "Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))" ] }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "## Preparing for input data" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The {meth}`~sympy.core.basic.Basic.doit` operation can be cached to disk with {func}`~ampform.sympy.perform_cached_doit`. We do this twice, once for the unfolding of the {attr}`~.AmplitudeModel.intensity` expression and second for the substitution and unfolding of the {attr}`~.AmplitudeModel.amplitudes`. Note that we could also have unfolded the intensity and substituted the amplitudes with {attr}`~.AmplitudeModel.full_expression`, but then the unfolded {attr}`~.AmplitudeModel.intensity` expression is not cached." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "unfolded_intensity_expr = perform_cached_doit(model.intensity)\n", - "full_intensity_expr = perform_cached_doit(\n", - " unfolded_intensity_expr.xreplace(model.amplitudes)\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "With this, the remaining {class}`~sympy.core.symbol.Symbol`s in the full expression are kinematic variables.[^1]\n", - "\n", - "[^1]: Yes, there are still $\\mathcal{H}^\\mathrm{production}$ and $\\mathcal{H}^\\mathrm{decay}$, but these are the {attr}`~sympy.tensor.indexed.Indexed.base`s of the {class}`~sympy.tensor.indexed.Indexed` coupling symbols. They should **NOT** be substituted." - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ - "hide-input" + "hide-input", + "full-width" ] }, "outputs": [], "source": [ - "sp.Array(\n", - " sorted(full_intensity_expr.free_symbols - set(model.parameter_defaults), key=str)\n", - ")" + "s, m0, w0, m1, m2, L, R, z = sp.symbols(\"s m0 Gamma0 m1 m2 L R z\")\n", + "exprs = [\n", + " RelativisticBreitWigner(s, m0, w0, m1, m2, L, R),\n", + " EnergyDependentWidth(s, m0, w0, m1, m2, L, R),\n", + " FormFactor(s, m1, m2, L, R),\n", + " BlattWeisskopfSquared(z, L),\n", + "]\n", + "Latex(aslatex({e: e.doit(deep=False) for e in exprs}))" ] }, { "cell_type": "markdown", "metadata": { - "tags": [] + "jp-MarkdownHeadingCollapsed": true }, "source": [ - "The $\\theta$ and $\\zeta$ angles are defined by the {attr}`~.AmplitudeModel.variables` attribute (they are shown under {ref}`jpsi2ksp:Model formulation`). Those definitions allow us to create a converter that computes kinematic variables from masses and Mandelstam variables:" + "## Preparing for input data" ] }, { @@ -473,170 +271,159 @@ "execution_count": null, "metadata": { "tags": [ - "hide-output" + "hide-input" ] }, "outputs": [], "source": [ - "masses_to_angles = SympyDataTransformer.from_sympy(model.variables, backend=\"jax\")\n", - "masses_to_angles.functions" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Dalitz plot" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "The data input for this data transformer can be several things. One can compute them from a (generated) data sample of four-momenta. Or one can compute them for a Dalitz plane. We do the latter in this section.\n", - "\n", - "First, the data transformer defined above expects values for the masses. 
We have already defined these values above, but we need to convert them from {mod}`sympy` objects to numerical data:" + "i, j = (3, 2)\n", + "k, *_ = {1, 2, 3} - {i, j}\n", + "σk, σk_expr = list(model.invariants.items())[k - 1]\n", + "Latex(aslatex({σk: σk_expr}))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "jupyter": { + "source_hidden": true + }, + "mystnb": { + "code_prompt_show": "Define meshgrid for Dalitz plot" + }, + "tags": [ + "hide-input", + "remove-output" + ] }, "outputs": [], "source": [ - "dalitz_data = {str(s): float(v) for s, v in model.masses.items()}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, we define a grid of data points over Mandelstam (Dalitz) variables $\\sigma_2=m_{13}, \\sigma_3=m_{12}$:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "resolution = 500\n", + "resolution = 1_000\n", + "m = sorted(model.masses, key=str)\n", + "x_min = float(((m[j] + m[k]) ** 2).xreplace(model.masses))\n", + "x_max = float(((m[0] - m[i]) ** 2).xreplace(model.masses))\n", + "y_min = float(((m[i] + m[k]) ** 2).xreplace(model.masses))\n", + "y_max = float(((m[0] - m[j]) ** 2).xreplace(model.masses))\n", + "x_diff = x_max - x_min\n", + "y_diff = y_max - y_min\n", + "x_min -= 0.05 * x_diff\n", + "x_max += 0.05 * x_diff\n", + "y_min -= 0.05 * y_diff\n", + "y_max += 0.05 * y_diff\n", "X, Y = jnp.meshgrid(\n", - " jnp.linspace(1.66**2, 2.18**2, num=resolution),\n", - " jnp.linspace(1.4**2, 1.93**2, num=resolution),\n", - ")\n", - "dalitz_data[\"sigma3\"] = X\n", - "dalitz_data[\"sigma2\"] = Y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "The remaining Mandelstam variable can be expressed in terms of the others as follows:" + " jnp.linspace(x_min, x_max, num=resolution),\n", + " jnp.linspace(y_min, y_max, num=resolution),\n", + ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, + "mystnb": { + "code_prompt_show": "Create data converter for Dalitz coordinates" + }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ - "(s1, s1_expr), *_ = model.invariants.items()\n", - "Latex(aslatex({s1: s1_expr}))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "That completes the data sample over which we want to evaluate the intensity model defined above:" + "definitions = dict(model.variables)\n", + "definitions[σk] = σk_expr\n", + "definitions = {\n", + " symbol: expr.xreplace(definitions).xreplace(model.masses)\n", + " for symbol, expr in definitions.items()\n", + "}\n", + "data_transformer = SympyDataTransformer.from_sympy(definitions, backend=\"jax\")\n", + "dalitz_data = {\n", + " f\"sigma{i}\": X,\n", + " f\"sigma{j}\": Y,\n", + "}\n", + "dalitz_data.update(data_transformer(dalitz_data))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ - "hide-output" + "remove-input" ] }, "outputs": [], "source": [ - "sigma1_func = perform_cached_lambdify(s1_expr, backend=\"jax\")\n", - "dalitz_data[\"sigma1\"] = sigma1_func(dalitz_data)\n", - "dalitz_data" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "We can now extend the sample with angle definitions so that we have a data sample over which the intensity can be evaluated." 
+ "for key, array in dalitz_data.items():\n", + " assert not jnp.all(jnp.isnan(array)), f\"All values for {key} are NaN\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "tags": [] + "jupyter": { + "source_hidden": true + }, + "mystnb": { + "code_prompt_show": "Prepare parametrized numerical function" + }, + "tags": [ + "hide-input" + ] }, "outputs": [], "source": [ - "angle_data = masses_to_angles(dalitz_data)\n", - "dalitz_data.update(angle_data)" + "unfolded_intensity_expr = perform_cached_doit(model.intensity)\n", + "full_intensity_expr = perform_cached_doit(\n", + " unfolded_intensity_expr.xreplace(model.amplitudes)\n", + ")\n", + "free_parameters = {\n", + " k: v\n", + " for k, v in model.parameter_defaults.items()\n", + " if isinstance(k, sp.Indexed)\n", + " if \"production\" in str(k) or \"decay\" in str(k)\n", + "}\n", + "fixed_parameters = {\n", + " k: v for k, v in model.parameter_defaults.items() if k not in free_parameters\n", + "}\n", + "intensity_func = perform_cached_lambdify(\n", + " full_intensity_expr.xreplace(fixed_parameters),\n", + " parameters=free_parameters,\n", + " backend=\"jax\",\n", + ")\n", + "intensities = intensity_func(dalitz_data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ "remove-input" ] }, "outputs": [], "source": [ - "for k, v in dalitz_data.items():\n", - " assert not jnp.all(jnp.isnan(v)), f\"All values for {k} are NaN\"" + "assert not jnp.all(jnp.isnan(intensities)), \"All intensities are NaN\"" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "tags": [] }, - "outputs": [], "source": [ - "free_parameters = {\n", - " k: v\n", - " for k, v in model.parameter_defaults.items()\n", - " if isinstance(k, sp.Indexed)\n", - " if \"production\" in str(k) or \"decay\" in str(k)\n", - "}\n", - "fixed_parameters = {\n", - " k: v for k, v in model.parameter_defaults.items() if k not in free_parameters\n", - "}\n", - "intensity_func = perform_cached_lambdify(\n", - " full_intensity_expr.xreplace(fixed_parameters),\n", - " parameters=free_parameters,\n", - " backend=\"jax\",\n", - ")" + "## Dalitz plot" ] }, { @@ -665,19 +452,25 @@ }, "outputs": [], "source": [ + "def get_decay_products(subsystem_id: int) -> tuple[State, State]:\n", + " return tuple(s for s in DECAY.final_state.values() if s.index != subsystem_id)\n", + "\n", + "\n", "plt.rc(\"font\", size=18)\n", - "intensities = intensity_func(dalitz_data)\n", "I_tot = jnp.nansum(intensities)\n", "normalized_intensities = intensities / I_tot\n", - "assert not jnp.all(jnp.isnan(normalized_intensities)), \"All intensities are NaN\"\n", "\n", "fig, ax = plt.subplots(figsize=(14, 10))\n", "mesh = ax.pcolormesh(X, Y, normalized_intensities)\n", "ax.set_aspect(\"equal\")\n", "c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)\n", "c_bar.ax.set_ylabel(\"Normalized intensity (a.u.)\")\n", - "ax.set_xlabel(R\"$\\sigma_3 = M^2\\left(K^0\\Sigma^+\\right)$\")\n", - "ax.set_ylabel(R\"$\\sigma_2 = M^2\\left(K^0\\bar{p}\\right)$\")\n", + "sigma_labels = {\n", + " i: Rf\"$\\sigma_{i} = M^2\\left({' '.join(p.latex for p in get_decay_products(i))}\\right)$\"\n", + " for i in (1, 2, 3)\n", + "}\n", + "ax.set_xlabel(sigma_labels[i])\n", + "ax.set_ylabel(sigma_labels[j])\n", "plt.show()" ] }, @@ -738,33 +531,30 @@ " ax.set_ylim(0, y_max)\n", " ax.autoscale(enable=False, axis=\"x\")\n", "ax1.set_ylabel(\"Normalized intensity (a.u.)\")\n", - 
"ax1.set_xlabel(R\"$M\\left(K^0\\Sigma^+\\right)$\")\n", - "ax2.set_xlabel(R\"$M\\left(K^0\\bar{p}\\right)$\")\n", - "i1, i2 = 0, 0\n", + "ax1.set_xlabel(sigma_labels[i])\n", + "ax2.set_xlabel(sigma_labels[j])\n", + "resonance_counter1, resonance_counter2 = 0, 0\n", "for chain in tqdm(model.decay.chains, disable=NO_TQDM):\n", " resonance = chain.resonance\n", - " decay_product = {p.name for p in chain.decay_products}\n", - " if decay_product == {\"K0\", \"Sigma+\"}:\n", + " if set(chain.decay_products) == set(get_decay_products(subsystem_id=i)):\n", " ax = ax1\n", - " i1 += 1\n", - " i = i1\n", + " resonance_counter1 += 1\n", + " color = f\"C{resonance_counter1}\"\n", " projection_axis = 0\n", " x_data = x\n", - " elif decay_product == {\"K0\", \"p~\"}:\n", + " elif set(chain.decay_products) == set(get_decay_products(subsystem_id=j)):\n", " ax = ax2\n", - " i2 += 1\n", - " i = i2\n", + " resonance_counter2 += 1\n", + " color = f\"C{resonance_counter2}\"\n", " projection_axis = 1\n", " x_data = y\n", " else:\n", - " continue\n", + " raise NotImplementedError\n", " sub_intensities = compute_sub_intensity(\n", " intensity_func, dalitz_data, resonance.latex\n", " )\n", - " ax.plot(\n", - " x_data, jnp.nansum(sub_intensities / I_tot, axis=projection_axis), c=f\"C{i}\"\n", - " )\n", - " ax.axvline(resonance.mass, label=f\"${resonance.latex}$\", c=f\"C{i}\", ls=\"dashed\")\n", + " ax.plot(x_data, jnp.nansum(sub_intensities / I_tot, axis=projection_axis), c=color)\n", + " ax.axvline(resonance.mass, label=f\"${resonance.latex}$\", c=color, ls=\"dashed\")\n", "for ax in axes:\n", " ax.legend(fontsize=12)\n", "plt.show()" diff --git a/docs/lc2pkpi.ipynb b/docs/lc2pkpi.ipynb index e9c7845e..569bcde0 100644 --- a/docs/lc2pkpi.ipynb +++ b/docs/lc2pkpi.ipynb @@ -33,14 +33,15 @@ "import sympy as sp\n", "from IPython.display import Latex, Markdown\n", "\n", - "from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle\n", + "from ampform_dpd import DalitzPlotDecompositionBuilder\n", "from ampform_dpd.adapter.qrules import (\n", " load_particles,\n", " normalize_state_ids,\n", " to_three_body_decay,\n", ")\n", - "from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain\n", + "from ampform_dpd.decay import ThreeBodyDecayChain\n", "from ampform_dpd.dynamics import BreitWignerMinL\n", + "from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n", "from ampform_dpd.io import as_markdown_table, aslatex, simplify_latex_rendering\n", "\n", "simplify_latex_rendering()\n", @@ -185,6 +186,9 @@ "jupyter": { "source_hidden": true }, + "mystnb": { + "code_prompt_show": "Dynamics builder function" + }, "tags": [ "hide-input" ] @@ -194,8 +198,8 @@ "def formulate_breit_wigner(\n", " decay_chain: ThreeBodyDecayChain,\n", ") -> tuple[BreitWignerMinL, dict[sp.Symbol, float]]:\n", - " s = _get_mandelstam_s(decay_chain)\n", - " child1_mass, child2_mass = map(_create_mass_symbol, decay_chain.decay_products)\n", + " s = get_mandelstam_s(decay_chain.decay_node)\n", + " child1_mass, child2_mass = map(create_mass_symbol, decay_chain.decay_products)\n", " l_dec = sp.Rational(decay_chain.outgoing_ls.L)\n", " l_prod = sp.Rational(decay_chain.incoming_ls.L)\n", " parent_mass = sp.Symbol(f\"m_{{{decay_chain.parent.latex}}}\", nonnegative=True)\n", @@ -230,36 +234,14 @@ " R_dec,\n", " R_prod,\n", " )\n", - " return dynamics, parameter_defaults\n", - "\n", - "\n", - "def _create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol:\n", - " particle = get_particle(particle)\n", - " return 
sp.Symbol(f\"m_{{{particle.latex}}}\", nonnegative=True)\n", - "\n", - "\n", - "def _get_mandelstam_s(decay: ThreeBodyDecayChain | IsobarNode) -> sp.Symbol:\n", - " s1, s2, s3 = sp.symbols(\"sigma1:4\", nonnegative=True)\n", - " decay_products = {p.name for p in _get_decay_products(decay)}\n", - " if decay_products == {\"p\", \"pi+\"}:\n", - " return s1\n", - " if decay_products == {\"pi+\", \"K-\"}:\n", - " return s2\n", - " if decay_products == {\"K-\", \"p\"}:\n", - " return s3\n", - " msg = f\"Cannot find Mandelstam variable for {', '.join(decay_products)}\"\n", - " raise NotImplementedError(msg)\n", - "\n", - "\n", - "def _get_decay_products(decay: ThreeBodyDecayChain | IsobarNode) -> list[Particle]:\n", - " if isinstance(decay, ThreeBodyDecayChain):\n", - " return decay.decay_products\n", - " return decay.children" + " return dynamics, parameter_defaults" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, "source": [ "## Model formulation" ] diff --git a/src/ampform_dpd/__init__.py b/src/ampform_dpd/__init__.py index 4c5af6cb..4c6b5cfb 100644 --- a/src/ampform_dpd/__init__.py +++ b/src/ampform_dpd/__init__.py @@ -16,13 +16,14 @@ from ampform_dpd.angles import formulate_scattering_angle, formulate_zeta_angle from ampform_dpd.decay import ( + FinalStateID, IsobarNode, LSCoupling, Particle, ThreeBodyDecay, ThreeBodyDecayChain, get_decay_product_ids, - get_particle, + to_particle, ) from ampform_dpd.io import ( simplify_latex_rendering, # noqa: F401 # pyright:ignore[reportUnusedImport] @@ -77,7 +78,7 @@ def __init__( def formulate( self, - reference_subsystem: Literal[1, 2, 3] = 1, + reference_subsystem: FinalStateID = 1, cleanup_summations: bool = False, ) -> AmplitudeModel: helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = ( @@ -130,7 +131,7 @@ def formulate_subsystem_amplitude( # noqa: PLR0914 λ1: sp.Rational, λ2: sp.Rational, λ3: sp.Rational, - subsystem_id: Literal[1, 2, 3], + subsystem_id: FinalStateID, ) -> AmplitudeModel: k = subsystem_id i, j = get_decay_product_ids(subsystem_id) @@ -231,7 +232,7 @@ def formulate_aligned_amplitude( λ1: sp.Rational | sp.Symbol, λ2: sp.Rational | sp.Symbol, λ3: sp.Rational | sp.Symbol, - reference_subsystem: Literal[1, 2, 3] = 1, + reference_subsystem: FinalStateID = 1, ) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]: wigner_generator = _AlignmentWignerGenerator(reference_subsystem) _λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(0:4)^{\prime}", rational=True) @@ -296,8 +297,8 @@ def _formulate_clebsch_gordan_factors( raise ValueError(msg) # https://github.com/ComPWA/ampform/blob/65b4efa/src/ampform/helicity/__init__.py#L785-L802 # and supplementary material p.1 (https://cds.cern.ch/record/2824328/files) - child1 = get_particle(isobar.child1) - child2 = get_particle(isobar.child2) + child1 = to_particle(isobar.child1) + child2 = to_particle(isobar.child2) child1_helicity = helicities[child1] child2_helicity = helicities[child2] cg_ss = CG( @@ -321,12 +322,12 @@ def _formulate_clebsch_gordan_factors( @lru_cache(maxsize=None) -def _generate_amplitude_index_bases() -> dict[Literal[1, 2, 3], sp.IndexedBase]: +def _generate_amplitude_index_bases() -> dict[FinalStateID, sp.IndexedBase]: return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1)) # type:ignore[arg-type] class _AlignmentWignerGenerator: - def __init__(self, reference_subsystem: Literal[1, 2, 3] = 1) -> None: + def __init__(self, reference_subsystem: FinalStateID = 1) -> None: self.angle_definitions: dict[sp.Symbol, sp.acos] = {} 
self.reference_subsystem = reference_subsystem @@ -409,8 +410,8 @@ def formulate_invariants(decay: ThreeBodyDecay) -> dict[sp.Symbol, sp.Expr]: def formulate_third_mandelstam( decay: ThreeBodyDecay, - x_mandelstam: Literal[1, 2, 3] = 1, - y_mandelstam: Literal[1, 2, 3] = 2, + x_mandelstam: FinalStateID = 1, + y_mandelstam: FinalStateID = 2, ) -> sp.Add: m0, m1, m2, m3 = create_mass_symbol_mapping(decay) sigma_x = sp.Symbol(f"sigma{x_mandelstam}", nonnegative=True) diff --git a/src/ampform_dpd/adapter/qrules.py b/src/ampform_dpd/adapter/qrules.py index 5b749e35..8b588dcf 100644 --- a/src/ampform_dpd/adapter/qrules.py +++ b/src/ampform_dpd/adapter/qrules.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools import logging from collections import abc, defaultdict from functools import singledispatch @@ -10,12 +11,15 @@ import qrules from qrules.quantum_numbers import InteractionProperties from qrules.topology import EdgeType, FrozenTransition, NodeType -from qrules.transition import ReactionInfo, State, StateTransition, Topology +from qrules.transition import ReactionInfo, StateTransition, Topology from ampform_dpd.decay import ( + FinalStateID, IsobarNode, LSCoupling, Particle, + State, + StateIDTemplate, ThreeBodyDecay, ThreeBodyDecayChain, ) @@ -36,23 +40,26 @@ def to_three_body_decay( ) != {1, 2, 3}: transitions = normalize_state_ids(transitions) _LOGGER.warning("Relabeled initial state to 0 and final states to 1, 2, 3") - transitions = convert_edges_and_nodes(transitions) + transitions = convert_transitions(transitions) if min_ls: transitions = filter_min_ls(transitions) some_transition = transitions[0] - initial_state, *_ = some_transition.initial_states.values() - final_states = { - i: some_transition.final_states[idx] - for i, idx in enumerate(sorted(some_transition.final_states), 1) - } + (initial_state_id, initial_state), *_ = some_transition.initial_states.items() + outer_states = ( + _to_state(initial_state, index=initial_state_id), # type:ignore[type-var] + *[ + _to_state(particle, index=idx) # type:ignore[type-var] + for idx, particle in some_transition.final_states.items() + ], + ) return ThreeBodyDecay( - states={0: initial_state, **final_states}, # type:ignore[dict-item] - chains=tuple(sorted(to_decay_chain(t) for t in transitions)), + states={state.index: state for state in outer_states}, # type:ignore[misc] + chains=tuple(sorted(_to_decay_chain(t) for t in transitions)), ) -def to_decay_chain( - transition: FrozenTransition[Particle, LSCoupling | None], +def _to_decay_chain( + transition: FrozenTransition[Particle | State, LSCoupling | None], ) -> ThreeBodyDecayChain: if len(transition.initial_states) != 1: msg = f"Can only handle one initial state, but got {len(transition.initial_states)}" @@ -64,48 +71,56 @@ def to_decay_chain( msg = f"There are {len(transition.interactions)} interaction nodes, so this can't be a three-body decay" raise ValueError(msg) topology = transition.topology + parent, *_ = transition.initial_states.values() spectator_id, resonance_id = sorted(topology.get_edge_ids_outgoing_from_node(0)) - resonance_id, *_ = sorted(topology.get_edge_ids_ingoing_to_node(1)) child1_id, child2_id = sorted(topology.get_edge_ids_outgoing_from_node(1)) - parent, *_ = transition.initial_states.values() + resonance_id, *_ = sorted(topology.get_edge_ids_ingoing_to_node(1)) production_node, decay_node = transition.interactions.values() - isobar = IsobarNode( - parent=parent, - child1=IsobarNode( - parent=transition.states[resonance_id], - 
child1=transition.states[child1_id], - child2=transition.states[child2_id], - interaction=decay_node, - ), - child2=transition.states[spectator_id], - interaction=production_node, + return ThreeBodyDecayChain( + decay=IsobarNode( + parent=parent, + child1=IsobarNode( + parent=transition.states[resonance_id], + child1=transition.states[child1_id], # type:ignore[arg-type] + child2=transition.states[child2_id], # type:ignore[arg-type] + interaction=decay_node, + ), + child2=transition.states[spectator_id], # type:ignore[arg-type] + interaction=production_node, + ) ) - return ThreeBodyDecayChain(decay=isobar) -def convert_edges_and_nodes( +def convert_transitions( transitions: Iterable[FrozenTransition], -) -> tuple[FrozenTransition[Particle, LSCoupling | None], ...]: - unique_transitions = { - transition.convert( - state_converter=_convert_edge, - interaction_converter=_convert_node, - ) - for transition in transitions - } +) -> tuple[FrozenTransition[Particle | State, LSCoupling | None], ...]: + unique_transitions = {_convert_transition(t) for t in transitions} return tuple(sorted(unique_transitions)) -def _convert_edge(state: Any) -> Particle: - if isinstance(state, Particle): - return state - if not isinstance(state, State): - msg = f"Cannot convert state of type {type(state)}" - raise NotImplementedError(msg) - particle = state.particle - if particle.parity is None: - msg = f"Cannot convert particle {particle.name} with undefined parity" - raise NotImplementedError(msg) +def _convert_transition( + transition: FrozenTransition, +) -> FrozenTransition[Particle | State, LSCoupling | None]: + return FrozenTransition( + transition.topology, + states={ + index: _to_particle(state) + if index in transition.intermediate_states + else _to_state(state, index=index) # type:ignore[type-var] + for index, state in transition.states.items() + }, + interactions={ + i: _to_ls_coupling(interaction) + for i, interaction in transition.interactions.items() + }, + ) + + +def _to_particle( + particle: qrules.particle.Particle | qrules.transition.State, +) -> Particle: + if isinstance(particle, qrules.transition.State): + particle = particle.particle return Particle( name=particle.name, latex=particle.name if particle.latex is None else particle.latex, @@ -116,7 +131,29 @@ def _convert_edge(state: Any) -> Particle: ) -def _convert_node(node: Any) -> LSCoupling | None: +def _to_state(obj: Any, index: StateIDTemplate | None = None): + if isinstance(obj, qrules.transition.State): + obj = obj.particle + if isinstance(obj, State): + index = obj.index + if index is None: + msg = f"Cannot create a {State} from a {type(obj)} without an index" + raise ValueError(msg) + if not isinstance(obj, Particle) and not isinstance(obj, qrules.particle.Particle): + msg = f"Cannot convert object of type {type(obj)} to a {State}" + raise NotImplementedError(msg) + return State( + name=obj.name, + latex=obj.name if obj.latex is None else obj.latex, # pyright:ignore[reportUnnecessaryComparison] + spin=obj.spin, + parity=int(obj.parity), # type:ignore[arg-type] + mass=obj.mass, + width=obj.width, + index=index, + ) + + +def _to_ls_coupling(node: Any) -> LSCoupling | None: if node is None: return None if isinstance(node, LSCoupling): @@ -137,8 +174,11 @@ def filter_min_ls( ) -> tuple[FrozenTransition[EdgeType, NodeType], ...]: grouped_transitions = defaultdict(list) for transition in transitions: - resonances = tuple(transition.intermediate_states.values()) - grouped_transitions[resonances].append(transition) + key = tuple( + (state, 
_get_decay_product_ids(transition.topology, resonance_id)) + for resonance_id, state in transition.intermediate_states.items() + ) + grouped_transitions[key].append(transition) min_transitions = [] for group in grouped_transitions.values(): transition, *_ = group @@ -146,7 +186,9 @@ def filter_min_ls( topology=transition.topology, states=transition.states, interactions={ - i: min(t.interactions[i] for t in group) # type:ignore[type-var] + i: None + if any(t.interactions[i] is None for t in group) + else min(t.interactions[i] for t in group) # type:ignore[type-var] for i in transition.interactions }, ) @@ -154,6 +196,14 @@ def filter_min_ls( return tuple(min_transitions) +def _get_decay_product_ids(topology: Topology, resonance_id: int) -> tuple[int, ...]: + node_id = topology.edges[resonance_id].ending_node_id + if node_id is None: + msg = f"Resonance graph edge {resonance_id} has no ending node" + raise ValueError(msg) + return tuple(sorted(topology.get_originating_final_state_edge_ids(node_id))) + + def load_particles() -> qrules.particle.ParticleCollection: src_dir = Path(__file__).parent.parent particle_database = qrules.load_default_particles() @@ -209,3 +259,82 @@ def _(obj: abc.Iterable[T]) -> list[T]: T = TypeVar("T", ReactionInfo, StateTransition, Topology) """Type variable for the input and output of :func:`normalize_state_ids`.""" + + +@overload +def permute_equal_final_states(obj: ReactionInfo) -> ReactionInfo: ... +@overload +def permute_equal_final_states( + obj: Iterable[FrozenTransition[EdgeType, NodeType]], +) -> list[FrozenTransition[EdgeType, NodeType]]: ... +@overload +def permute_equal_final_states( + obj: FrozenTransition[EdgeType, NodeType], +) -> list[FrozenTransition[EdgeType, NodeType]]: ... +def permute_equal_final_states(obj: T) -> T: # type:ignore[misc] # pyright:ignore[reportInconsistentOverload] + return _impl_permute_equal_final_states(obj) + + +@singledispatch +def _impl_permute_equal_final_states(obj): + msg = f"Cannot permute equal final states of a {type(obj)}" + raise NotImplementedError(msg) + + +@_impl_permute_equal_final_states.register(ReactionInfo) +def _(obj: ReactionInfo) -> ReactionInfo: + return ReactionInfo( + transitions=permute_equal_final_states(obj.transitions), + formalism=obj.formalism, + ) + + +@_impl_permute_equal_final_states.register(abc.Iterable) +def _( + obj: Iterable[FrozenTransition[EdgeType, NodeType]], +) -> list[FrozenTransition[EdgeType, NodeType]]: + permuted_transitions = [] + for transition in obj: + permuted_transitions.extend(permute_equal_final_states(transition)) + return permuted_transitions + + +@_impl_permute_equal_final_states.register(FrozenTransition) +def _( + obj: FrozenTransition[EdgeType, NodeType], +) -> list[FrozenTransition[EdgeType, NodeType]]: + transition = obj + equal_state_ids = _get_equal_final_state_ids(transition) + if not equal_state_ids: + return [transition] + unique_permutations = {transition} | { + attrs.evolve(transition, topology=transition.topology.swap_edges(i, j)) + for i, j in itertools.combinations(equal_state_ids, 2) + } + return sorted(unique_permutations) + + +def _get_equal_final_state_ids( + transition: FrozenTransition, +) -> ( + tuple[()] + | tuple[FinalStateID, FinalStateID] + | tuple[FinalStateID, FinalStateID, FinalStateID] +): + particle_to_id = defaultdict(list) + for idx, state in transition.final_states.items(): + key = _uniqueness_repr(state) + particle_to_id[key].append(idx) + all_equal_state_ids = [set(ids) for ids in particle_to_id.values() if len(ids) > 1] + if not 
all_equal_state_ids: + return tuple() # type:ignore[return-value] + return tuple(sorted(all_equal_state_ids[0])) # type:ignore[return-value] + + +def _uniqueness_repr(obj: Any) -> str: + if isinstance(obj, qrules.transition.State): + return _uniqueness_repr(obj.particle) + if isinstance(obj, (Particle, State, qrules.particle.Particle)): + return obj.name + msg = f"Cannot create a uniqueness key for {type(obj)}" + raise NotImplementedError(msg) diff --git a/src/ampform_dpd/decay.py b/src/ampform_dpd/decay.py index 1cbbc2fb..0141f9ef 100644 --- a/src/ampform_dpd/decay.py +++ b/src/ampform_dpd/decay.py @@ -3,7 +3,8 @@ from __future__ import annotations from functools import lru_cache -from typing import TYPE_CHECKING, Dict, Literal, TypeVar +from textwrap import dedent +from typing import TYPE_CHECKING, Generic, Literal, TypeVar, overload from attrs import field, frozen from attrs.validators import instance_of @@ -14,6 +15,16 @@ import sympy as sp +InitialStateID = Literal[0] +"""ID for the initial state particle in a three-body decay.""" +FinalStateID = Literal[1, 2, 3] +"""ID for a particle in the final state of a three-body decay.""" +StateID = Literal[0, 1, 2, 3] +"""ID for any of the initial state or final state particles in a three-body decay.""" +StateIDTemplate = TypeVar("StateIDTemplate", InitialStateID, FinalStateID, StateID) +"""Generic template for the ID of a particle in a three-body decay.""" + + @frozen(order=True) class Particle: name: str @@ -25,25 +36,41 @@ class Particle: @frozen(order=True) -class IsobarNode: - parent: Particle - child1: Particle | IsobarNode - child2: Particle | IsobarNode +class State(Particle, Generic[StateIDTemplate]): + """Initial or final state `.Particle` in a `ThreeBodyDecay`, carrying an index.""" + + index: StateIDTemplate + + +InitialState = State[InitialStateID] +"""The initial state particle.""" +FinalState = State[FinalStateID] +"""One of the final state particles.""" +ParentType = TypeVar("ParentType", Particle, InitialState) +"""Type of the parent of an `IsobarNode`.""" + + +@frozen(order=True) +class IsobarNode(Generic[ParentType]): + parent: ParentType + child1: IsobarNode[Particle] | FinalState + child2: IsobarNode[Particle] | FinalState interaction: LSCoupling | None = field(default=None, converter=to_ls) @property - def children( - self, - ) -> tuple[ - Particle | IsobarNode, - Particle | IsobarNode, - ]: + def children(self) -> tuple[DecayNode | FinalState, DecayNode | FinalState]: return self.child1, self.child2 +ProductionNode = IsobarNode[InitialState] +"""The first `IsobarNode` in a `ThreeBodyDecayChain`.""" +DecayNode = IsobarNode[Particle] +"""The second `IsobarNode` in a `ThreeBodyDecayChain`.""" + + @frozen class ThreeBodyDecay: - states: OuterStates + states: dict[StateID, State[StateID]] chains: tuple[ThreeBodyDecayChain, ...] 
= field(converter=to_chains) def __attrs_post_init__(self) -> None: @@ -51,30 +78,36 @@ def __attrs_post_init__(self) -> None: expected_final_state = set(self.final_state.values()) for i, chain in enumerate(self.chains): if chain.parent != expected_initial_state: - msg = ( - f"Chain {i} has initial state {chain.parent.name}, but should have" - f" {expected_initial_state.name}" - ) + msg = dedent(f""" + Chain {i} has initial state + {chain.parent.index}: {chain.parent.name} + but should have + {expected_initial_state.index}: {expected_initial_state.name} + """).strip() raise ValueError(msg) final_state = {chain.spectator, *chain.decay_products} if final_state != expected_final_state: - def to_str(s): - return ", ".join(p.name for p in s) - - msg = ( - f"Chain {i} has final state {to_str(final_state)}, but should have" - f" {to_str(expected_final_state)}" - ) + def to_str(s: set[FinalState]) -> str: + return ", ".join( + f"{p.index}: {p.name}" for p in sorted(s, key=lambda x: x.index) + ) + + msg = dedent(f""" + Chain {i} has final state + {to_str(final_state)} + but should have + {to_str(expected_final_state)} + """).strip() raise ValueError(msg) @property - def initial_state(self) -> Particle: - return self.states[0] + def initial_state(self) -> InitialState: + return self.states[0] # type:ignore[return-value] @property - def final_state(self) -> dict[Literal[1, 2, 3], Particle]: - return {k: v for k, v in self.states.items() if k != 0} + def final_state(self) -> dict[FinalStateID, FinalState]: + return {s.index: s for s in self.states.values() if s.index != 0} # type:ignore[misc] def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain: for chain in self.chains: @@ -83,7 +116,7 @@ def find_chain(self, resonance_name: str) -> ThreeBodyDecayChain: msg = f"No decay chain found for resonance {resonance_name}" raise KeyError(msg) - def get_subsystem(self, subsystem_id: Literal[1, 2, 3]) -> ThreeBodyDecay: + def get_subsystem(self, subsystem_id: FinalStateID) -> ThreeBodyDecay: child1_id, child2_id = get_decay_product_ids(subsystem_id) child1 = self.final_state[child1_id] child2 = self.final_state[child2_id] @@ -96,8 +129,8 @@ def get_subsystem(self, subsystem_id: Literal[1, 2, 3]) -> ThreeBodyDecay: def get_decay_product_ids( - spectator_id: Literal[1, 2, 3], -) -> tuple[Literal[1, 2, 3], Literal[1, 2, 3]]: + spectator_id: FinalStateID, +) -> tuple[FinalStateID, FinalStateID]: if spectator_id == 1: return 2, 3 if spectator_id == 2: # noqa: PLR2004 @@ -108,37 +141,58 @@ def get_decay_product_ids( raise ValueError(msg) -OuterStates = Dict[Literal[0, 1, 2, 3], Particle] -"""Mapping of the initial and final state IDs to their `.Particle` definition.""" - - @frozen(order=True) class ThreeBodyDecayChain: - decay: IsobarNode = field(validator=instance_of(IsobarNode)) + decay: ProductionNode = field(validator=instance_of(IsobarNode)) + + def __attrs_post_init__(self) -> None: + outer_states: list[State[StateID]] = [self.initial_state, *self.final_state] # type:ignore[list-item] + for state in outer_states: + if not isinstance(state, State): + msg = f"Not all particles in the initial or final state are of type {State.__name__}" + raise TypeError(msg) + if len({state.index for state in outer_states}) != 4: # noqa: PLR2004 + msg = "The initial and/or final state contains particles with the same ID:" + for state in outer_states: + msg += f"\n {state.index}: {state.name}" + raise ValueError(msg) @property - def parent(self) -> Particle: - return self.decay.parent + def initial_state(self) -> 
InitialState: + return self.parent + + @property + @lru_cache(maxsize=None) # noqa: B019 + def final_state(self) -> tuple[FinalState, FinalState, FinalState]: + final_state = (*self.decay_products, self.spectator) + return tuple(sorted(final_state, key=lambda x: x.index)) # type:ignore[return-value] + + @property + def parent(self) -> InitialState: + return self.decay.parent # type:ignore[return-value] @property def resonance(self) -> Particle: - decay_node: IsobarNode = self._get_child_of_type(IsobarNode) - return get_particle(decay_node) + return to_particle(self.decay_node) + + @property + def production_node(self) -> ProductionNode: + return self.decay @property - def decay_node(self) -> IsobarNode: + def decay_node(self) -> DecayNode: return self._get_child_of_type(IsobarNode) @property - def decay_products(self) -> tuple[Particle, Particle]: - return ( - get_particle(self.decay_node.child1), - get_particle(self.decay_node.child2), + def decay_products(self) -> tuple[FinalState, FinalState]: + return ( # type:ignore[return-value] + to_particle(self.decay_node.child1), + to_particle(self.decay_node.child2), ) @property - def spectator(self) -> Particle: - return self._get_child_of_type(Particle) + def spectator(self) -> FinalState: + return self._get_child_of_type(State) @lru_cache(maxsize=None) # noqa: B019 def _get_child_of_type(self, typ: type[T]) -> T: @@ -154,11 +208,10 @@ def incoming_ls(self) -> LSCoupling | None: @property def outgoing_ls(self) -> LSCoupling | None: - decay_node: IsobarNode = self._get_child_of_type(IsobarNode) - return decay_node.interaction + return self.decay_node.interaction -T = TypeVar("T", Particle, IsobarNode) +T = TypeVar("T", IsobarNode, Particle, InitialState, FinalState) @frozen(order=True) @@ -167,7 +220,13 @@ class LSCoupling: S: sp.Rational = field(converter=to_rational, validator=assert_spin_value) -def get_particle(isobar: IsobarNode | Particle) -> Particle: +@overload +def to_particle(isobar: IsobarNode[ParentType]) -> ParentType: ... +@overload +def to_particle(isobar: State[StateIDTemplate]) -> State[StateIDTemplate]: ... +@overload +def to_particle(isobar: Particle) -> Particle: ... +def to_particle(isobar): if isinstance(isobar, IsobarNode): return isobar.parent return isobar diff --git a/src/ampform_dpd/dynamics/builder.py b/src/ampform_dpd/dynamics/builder.py new file mode 100644 index 00000000..15c96c76 --- /dev/null +++ b/src/ampform_dpd/dynamics/builder.py @@ -0,0 +1,116 @@ +"""Dynamics builder functions for :meth:`.register_builder`. + +.. note:: As opposed to `AmpForm `_, AmpForm-DPD defines + dynamics over the **entire decay chain**, not a single isobar node. The dynamics + classes and the corresponding builders would have to be extended to implement other + dynamics lineshapes. 
+""" + +from __future__ import annotations + +import sympy as sp + +from ampform_dpd import to_particle +from ampform_dpd.decay import ( + DecayNode, + IsobarNode, + Particle, + State, + ThreeBodyDecayChain, +) +from ampform_dpd.dynamics import FormFactor, RelativisticBreitWigner + + +def formulate_breit_wigner_with_form_factor( + decay: ThreeBodyDecayChain, +) -> tuple[sp.Expr, dict[sp.Symbol, float]]: + decay_node = decay.decay_node + s = get_mandelstam_s(decay_node) + parameter_defaults = {} + production_ff, new_pars = _create_form_factor(s, decay.production_node) + parameter_defaults.update(new_pars) + decay_ff, new_pars = _create_form_factor(s, decay_node) + parameter_defaults.update(new_pars) + breit_wigner, new_pars = _create_breit_wigner(s, decay_node) + parameter_defaults.update(new_pars) + return ( + production_ff * decay_ff * breit_wigner, + parameter_defaults, + ) + + +def _create_form_factor( + s: sp.Symbol, isobar: IsobarNode +) -> tuple[sp.Expr, dict[sp.Symbol, float]]: + if isinstance(isobar.parent, State): + inv_mass = sp.Symbol("m0", nonnegative=True) + else: + inv_mass = get_mandelstam_s(isobar) + outgoing_state_mass1 = create_mass_symbol(isobar.child1) + outgoing_state_mass2 = create_mass_symbol(isobar.child2) + meson_radius = _create_meson_radius_symbol(isobar) + form_factor = FormFactor( + s=inv_mass**2, + m1=outgoing_state_mass1, + m2=outgoing_state_mass2, + angular_momentum=_get_angular_momentum(isobar), + meson_radius=meson_radius, + ) + parameter_defaults = { + meson_radius: 1, + outgoing_state_mass1: to_particle(isobar.child1).mass, + outgoing_state_mass2: to_particle(isobar.child2).mass, + } + if not inv_mass.name.startswith("s"): + parameter_defaults[inv_mass] = to_particle(isobar).mass + return form_factor, parameter_defaults + + +def _create_breit_wigner( + s: sp.Symbol, isobar: DecayNode +) -> tuple[sp.Expr, dict[sp.Symbol, float]]: + outgoing_state_mass1 = create_mass_symbol(isobar.child1) + outgoing_state_mass2 = create_mass_symbol(isobar.child2) + angular_momentum = _get_angular_momentum(isobar) + res_mass = create_mass_symbol(isobar.parent) + res_width = sp.Symbol(Rf"\Gamma_{{{isobar.parent.latex}}}", nonnegative=True) + meson_radius = _create_meson_radius_symbol(isobar) + + breit_wigner_expr = RelativisticBreitWigner( + s=s, + mass0=res_mass, + gamma0=res_width, + m1=outgoing_state_mass1, + m2=outgoing_state_mass2, + angular_momentum=angular_momentum, + meson_radius=meson_radius, + ) + parameter_defaults = { + res_mass: isobar.parent.mass, + res_width: isobar.parent.width, + meson_radius: 1, + } + return breit_wigner_expr, parameter_defaults + + +def _get_angular_momentum(isobar: IsobarNode) -> int: + if isobar.interaction is None: + msg = "Need LS couplings to formulate a form factor" + raise ValueError(msg) + return isobar.interaction.L + + +def _create_meson_radius_symbol(isobar: IsobarNode) -> sp.Symbol: + if isinstance(isobar.parent, State): + return sp.Symbol(Rf"R_{{{isobar.parent.latex}}}") + return sp.Symbol(R"R_\mathrm{res}") + + +def create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol: + particle = to_particle(particle) + return sp.Symbol(f"m_{{{particle.latex}}}", nonnegative=True) + + +def get_mandelstam_s(decay: DecayNode) -> sp.Symbol: + subsystem_id, *_ = {1, 2, 3} - {s.index for s in decay.children} # type:ignore[union-attr] + return sp.Symbol(f"sigma{subsystem_id}", nonnegative=True) diff --git a/src/ampform_dpd/io.py b/src/ampform_dpd/io.py index 5fce8fc9..65149f3c 100644 --- a/src/ampform_dpd/io.py +++ 
b/src/ampform_dpd/io.py @@ -34,7 +34,13 @@ from tensorwaves.function.sympy import create_function, create_parametrized_function from ampform_dpd._cache import get_readable_hash, get_system_cache_directory -from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecay, ThreeBodyDecayChain +from ampform_dpd.decay import ( + IsobarNode, + Particle, + State, + ThreeBodyDecay, + ThreeBodyDecayChain, +) if TYPE_CHECKING: from tensorwaves.function import ( @@ -139,7 +145,7 @@ def as_markdown_table(obj: Sequence) -> str: if isinstance(obj, ThreeBodyDecay): return _as_decay_markdown_table(obj.chains) item_type = _determine_item_type(obj) - if item_type is Particle: + if item_type in {Particle, State}: return _as_resonance_markdown_table(obj) if item_type is ThreeBodyDecayChain: return _as_decay_markdown_table(obj) @@ -162,7 +168,7 @@ def _determine_item_type(obj) -> type: return item_type -def _as_resonance_markdown_table(items: Sequence[Particle]) -> str: +def _as_resonance_markdown_table(items: Sequence[Particle | State]) -> str: column_names = [ "name", "LaTeX", @@ -170,6 +176,9 @@ def _as_resonance_markdown_table(items: Sequence[Particle]) -> str: "mass (MeV)", "width (MeV)", ] + render_index = any(isinstance(i, State) for i in items) + if render_index: + column_names.insert(0, "index") src = _create_markdown_table_header(column_names) for particle in items: row_items = [ @@ -179,6 +188,8 @@ def _as_resonance_markdown_table(items: Sequence[Particle]) -> str: f"{int(1e3 * particle.mass):,.0f}", f"{int(1e3 * particle.width):,.0f}", ] + if render_index and isinstance(particle, State): + row_items.insert(0, particle.index) src += _create_markdown_table_row(row_items) return src diff --git a/tests/adapter/test_qrules.py b/tests/adapter/test_qrules.py index 9d691cff..7b95e9b4 100644 --- a/tests/adapter/test_qrules.py +++ b/tests/adapter/test_qrules.py @@ -1,61 +1,136 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable +import attrs import pytest import qrules from ampform_dpd.adapter.qrules import ( + _convert_transition, + _get_equal_final_state_ids, + convert_transitions, filter_min_ls, normalize_state_ids, + permute_equal_final_states, to_three_body_decay, ) from ampform_dpd.decay import LSCoupling, Particle if TYPE_CHECKING: + from _pytest.fixtures import SubRequest + from qrules.topology import FrozenTransition from qrules.transition import ReactionInfo, StateTransition @pytest.fixture(scope="session") -def reaction() -> ReactionInfo: +def a2pipipi_reaction() -> ReactionInfo: + return qrules.generate_transitions( + initial_state="a(1)(1260)0", + final_state=["pi0", "pi0", "pi0"], + allowed_intermediate_particles=["a(0)(980)0"], + formalism="helicity", + ) + + +@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"]) +def jpsi2pksigma_reaction(request: SubRequest) -> ReactionInfo: # cspell:ignore pksigma return qrules.generate_transitions( initial_state=[("J/psi(1S)", [+1])], final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])], allowed_interaction_types="strong", allowed_intermediate_particles=["Sigma(1660)"], - formalism="canonical-helicity", + formalism=request.param, + ) + + +@pytest.fixture(scope="session") +def xib2pkk_reaction() -> ReactionInfo: + reaction = qrules.generate_transitions( + initial_state="Xi(b)-", + final_state=["p", "K-", "K-"], + allowed_intermediate_particles=["Lambda(1520)"], + formalism="helicity", + ) + swapped_transitions = tuple( + 
attrs.evolve(t, topology=t.topology.swap_edges(1, 2)) + for t in reaction.transitions + ) + return qrules.transition.ReactionInfo( + transitions=reaction.transitions + swapped_transitions, + formalism=reaction.formalism, ) -def test_filter_min_ls(reaction: ReactionInfo): +def test_convert_transitions(xib2pkk_reaction: ReactionInfo): + reaction = normalize_state_ids(xib2pkk_reaction) + assert reaction.get_intermediate_particles().names == ["Lambda(1520)"] + assert len(reaction.transitions) == 16 + transitions = convert_transitions(reaction.transitions) + assert len(transitions) == 2 + decay = to_three_body_decay(transitions, min_ls=True) + assert len(decay.chains) == 2 + + +def test_filter_min_ls(jpsi2pksigma_reaction: ReactionInfo): + reaction = jpsi2pksigma_reaction transitions = tuple( t for t in reaction.transitions if t.states[3].spin_projection == +0.5 ) ls_couplings = [_get_couplings(t) for t in transitions] - assert ls_couplings == [ - ( - {"L": 0, "S": 1.0}, - {"L": 1, "S": 0.5}, - ), - ( - {"L": 2, "S": 1.0}, - {"L": 1, "S": 0.5}, - ), - ] + if reaction.formalism == "canonical-helicity": + assert len(ls_couplings) == 2 + assert ls_couplings == [ + ( + {"L": 0, "S": 1.0}, + {"L": 1, "S": 0.5}, + ), + ( + {"L": 2, "S": 1.0}, + {"L": 1, "S": 0.5}, + ), + ] + else: + assert len(ls_couplings) == 1 + for ls_coupling in ls_couplings: + for ls in ls_coupling: + assert ls == {"L": None, "S": None} min_ls_transitions = filter_min_ls(transitions) ls_couplings = [_get_couplings(t) for t in min_ls_transitions] - assert ls_couplings == [ - ( - {"L": 0, "S": 1.0}, - {"L": 1, "S": 0.5}, - ), + assert len(ls_couplings) == 1 + if reaction.formalism == "canonical-helicity": + assert ls_couplings == [ + ( + {"L": 0, "S": 1.0}, + {"L": 1, "S": 0.5}, + ), + ] + + +@pytest.mark.parametrize("converter", [lambda x: x, _convert_transition]) +def test_get_equal_final_state_ids( + a2pipipi_reaction: ReactionInfo, + jpsi2pksigma_reaction: ReactionInfo, + xib2pkk_reaction: ReactionInfo, + converter: Callable[[FrozenTransition], FrozenTransition], +): + test_cases = [ + (a2pipipi_reaction, (1, 2, 3)), + (jpsi2pksigma_reaction, tuple()), + (xib2pkk_reaction, (2, 3)), ] + for reaction012, expected in test_cases: + reaction = normalize_state_ids(reaction012) + transition = converter(reaction.transitions[0]) + equal_ids = _get_equal_final_state_ids(transition) + assert equal_ids == expected -def test_normalize_state_ids_reaction(reaction: ReactionInfo): - reaction012 = reaction +def test_normalize_state_ids_reaction(jpsi2pksigma_reaction: ReactionInfo): + reaction012 = jpsi2pksigma_reaction reaction123 = normalize_state_ids(reaction012) assert set(reaction123.initial_state) == {0} assert set(reaction123.final_state) == {1, 2, 3} @@ -75,8 +150,30 @@ def test_normalize_state_ids_reaction(reaction: ReactionInfo): assert transition012.states[i] == transition123.states[i + 1] +def test_permute_equal_final_states( + a2pipipi_reaction: ReactionInfo, + jpsi2pksigma_reaction: ReactionInfo, + xib2pkk_reaction: ReactionInfo, +): + test_cases = [ + (1, jpsi2pksigma_reaction), + (2, xib2pkk_reaction), + (3, a2pipipi_reaction), + ] + for n_permutations, reaction012 in test_cases: + reaction = normalize_state_ids(reaction012) + transition = reaction.transitions[0] + permutations = permute_equal_final_states(transition) + assert len(permutations) == n_permutations + + permuted_reaction = permute_equal_final_states(reaction) + n_transitions = len(permuted_reaction.transitions) + assert n_transitions == n_permutations * 
len(reaction.transitions) + + @pytest.mark.parametrize("min_ls", [False, True]) -def test_to_three_body_decay(reaction: ReactionInfo, min_ls: bool): +def test_to_three_body_decay(jpsi2pksigma_reaction: ReactionInfo, min_ls: bool): + reaction = normalize_state_ids(jpsi2pksigma_reaction) decay = to_three_body_decay(reaction.transitions, min_ls) assert decay.initial_state.name == "J/psi(1S)" assert {i: p.name for i, p in decay.final_state.items()} == { @@ -84,14 +181,19 @@ def test_to_three_body_decay(reaction: ReactionInfo, min_ls: bool): 2: "Sigma+", 3: "p~", } - if min_ls: + if reaction.formalism == "canonical-helicity": + if min_ls: + assert len(decay.chains) == 1 + assert decay.chains[0].incoming_ls == LSCoupling(L=0, S=1) + assert decay.chains[0].outgoing_ls == LSCoupling(L=1, S=0.5) + else: + assert len(decay.chains) == 2 + assert decay.chains[1].incoming_ls == LSCoupling(L=2, S=1) + assert decay.chains[1].outgoing_ls == LSCoupling(L=1, S=0.5) + elif reaction.formalism == "helicity": assert len(decay.chains) == 1 - assert decay.chains[0].incoming_ls == LSCoupling(L=0, S=1) - assert decay.chains[0].outgoing_ls == LSCoupling(L=1, S=0.5) - else: - assert len(decay.chains) == 2 - assert decay.chains[1].incoming_ls == LSCoupling(L=2, S=1) - assert decay.chains[1].outgoing_ls == LSCoupling(L=1, S=0.5) + assert decay.chains[0].incoming_ls is None + assert decay.chains[0].outgoing_ls is None for chain in decay.chains: assert isinstance(chain.resonance, Particle) assert chain.resonance.name == "Sigma(1660)~-"
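
A minimal usage sketch of the relocated dynamics builder, assembled only from calls that appear in this diff (the updated `docs/jpsi2ksp.ipynb` cell and the `jpsi2pksigma_reaction` test fixture). The `qrules.generate_transitions()` settings are copied from that fixture and the variable names are illustrative, so read this as a sketch of the intended workflow, not as part of the patch.

```python
import qrules

from ampform_dpd import DalitzPlotDecompositionBuilder
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.dynamics.builder import formulate_breit_wigner_with_form_factor

# Generate transitions with qrules (settings copied from the jpsi2pksigma_reaction fixture)
reaction = qrules.generate_transitions(
    initial_state=[("J/psi(1S)", [+1])],
    final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
    allowed_interaction_types="strong",
    allowed_intermediate_particles=["Sigma(1660)"],
    formalism="canonical-helicity",
)

# Relabel the initial state to 0 and the final states to 1, 2, 3, then build the decay object
reaction = normalize_state_ids(reaction)
decay = to_three_body_decay(reaction.transitions, min_ls=True)

# Register the chain-wise Breit-Wigner x form-factor lineshape and formulate the model
model_builder = DalitzPlotDecompositionBuilder(decay, min_ls=True)
for chain in model_builder.decay.chains:
    model_builder.dynamics_choices.register_builder(
        chain, formulate_breit_wigner_with_form_factor
    )
model = model_builder.formulate(reference_subsystem=1)
```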