Skip to content

Commit

Permalink
now preserves stm-API, explicit coercion in test_settings-arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
grayson-helmholz committed Nov 11, 2024
1 parent 274284b commit 881689f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
17 changes: 14 additions & 3 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ def calculate_strength(node_interaction_settings: dict[int, NodeSettings]) -> fl
return strength_sorted_problem_sets


def _fractionalize_statedefinitions(definition: StateDefinition) -> StateDefinition:
if type(definition) is str:
return definition
if type(definition) is tuple:
name = definition[0]
state = definition[1]
return name, list(map(Fraction, state))
msg = f"value has to be of type {StateDefinition}, got {type(definition)}"
raise ValueError(msg)


class StateTransitionManager:
"""Main handler for decay topologies.
Expand Down Expand Up @@ -263,8 +274,8 @@ def __init__( # noqa: C901, PLR0912, PLR0917
if particle_db is not None:
self.__particles = particle_db
self.reaction_mode = str(solving_mode)
self.initial_state = list(initial_state)
self.final_state = list(final_state)
self.initial_state = list(map(_fractionalize_statedefinitions, initial_state))
self.final_state = list(map(_fractionalize_statedefinitions, final_state))
self.interaction_type_settings = interaction_type_settings

self.interaction_determinators: list[InteractionDeterminator] = [
Expand Down Expand Up @@ -732,7 +743,7 @@ def _strip_spin(state_definition: Sequence[StateDefinition]) -> list[str]:
@frozen(order=True)
class State:
particle: Particle = field(validator=instance_of(Particle))
spin_projection: float = field(converter=_to_fraction)
spin_projection: Fraction = field(converter=_to_fraction)


StateTransition = FrozenTransition[State, InteractionProperties]
Expand Down
29 changes: 18 additions & 11 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

from fractions import Fraction
from typing import TYPE_CHECKING

import pytest

from qrules.particle import ParticleCollection
from qrules.quantum_numbers import EdgeQuantumNumbers as EdgeQN
from qrules.settings import (
InteractionType,
Expand All @@ -9,7 +13,10 @@
_int_domain,
create_interaction_settings,
)
from qrules.transition import SpinFormalism

if TYPE_CHECKING:
from qrules.particle import ParticleCollection
from qrules.transition import SpinFormalism


class TestInteractionType:
Expand Down Expand Up @@ -82,23 +89,23 @@ def test_create_interaction_settings(
"parity": [-1, +1],
"c_parity": [-1, +1, None],
"g_parity": [-1, +1, None],
"spin_magnitude": _halves_domain(0, 4),
"spin_projection": _halves_domain(-4, +4),
"spin_magnitude": _halves_domain(*tuple(map(Fraction, (0, 4)))),
"spin_projection": _halves_domain(*tuple(map(Fraction, (-4, +4)))),
"charge": _int_domain(-2, 2),
"isospin_magnitude": _halves_domain(0, 1.5),
"isospin_projection": _halves_domain(-1.5, +1.5),
"isospin_magnitude": _halves_domain(*tuple(map(Fraction, (0, 1.5)))),
"isospin_projection": _halves_domain(*tuple(map(Fraction, (-1.5, +1.5)))),
"strangeness": _int_domain(-3, +3),
"charmness": _int_domain(-1, 1),
"bottomness": _int_domain(-1, 1),
}

expected = {
"l_magnitude": _int_domain(0, 2),
"s_magnitude": _halves_domain(0, 2),
"s_magnitude": _halves_domain(*tuple(map(Fraction, (0, 2)))),
}
if "canonical" in formalism:
expected["l_projection"] = [-2, -1, 0, 1, 2]
expected["s_projection"] = _halves_domain(-2, 2)
expected["s_projection"] = _halves_domain(*tuple(map(Fraction, (-2, 2))))
if formalism == "canonical-helicity":
expected["l_projection"] = [0]
if "helicity" in formalism and interaction_type != InteractionType.WEAK:
Expand All @@ -124,9 +131,9 @@ def test_create_interaction_settings(
(-1, +1, [-1, -0.5, 0, 0.5, +1]),
],
)
def test_halves_range(start: float, stop: float, expected: list):
def test_halves_range(start: float, stop: float, expected: list | None):
if expected is None:
with pytest.raises(ValueError, match=r"needs to be multiple of 0.5"):
_halves_domain(start, stop)
_halves_domain(Fraction(start), Fraction(stop))
else:
assert _halves_domain(start, stop) == expected
assert _halves_domain(Fraction(start), Fraction(stop)) == expected

0 comments on commit 881689f

Please sign in to comment.