From 5a3bf729ed06718b25459adad290bd7a00224c67 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 13 Apr 2023 18:31:38 +0200 Subject: [PATCH] FIX: permutate topology edges, not property mappings (#218) * FIX: permutate topologies instead of outer edges * MAINT: import entire `itertools` module * MAINT: improve order of function definitions * MAINT: improve spin-0 removal syntax * MAINT: simplify `_KinematicRepresentation` class * MAINT: switch to absolute imports * MAINT: test whether final state order matters --------- Co-authored-by: Remco de Boer --- src/qrules/__init__.py | 5 +- src/qrules/combinatorics.py | 333 +++++++++----------- src/qrules/transition.py | 52 ++- tests/unit/conftest.py | 28 ++ tests/unit/test_combinatorics.py | 204 ++++++------ tests/unit/test_final_state_permutations.py | 62 ++++ tests/unit/test_topology.py | 27 -- 7 files changed, 372 insertions(+), 339 deletions(-) create mode 100644 tests/unit/test_final_state_permutations.py diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index d3c971f9..a28b15d6 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -188,10 +188,7 @@ def check_edge_qn_conservation() -> Set[FrozenSet[str]]: node_id = next(iter(topology.nodes)) initial_facts = create_initial_facts( - topology=topology, - particle_db=particle_db, - initial_state=initial_state, - final_state=final_state, + topology, initial_state, final_state, particle_db ) check_pure_edge_rules() diff --git a/src/qrules/combinatorics.py b/src/qrules/combinatorics.py index 5b688e8b..261a6541 100644 --- a/src/qrules/combinatorics.py +++ b/src/qrules/combinatorics.py @@ -5,31 +5,27 @@ node properties. """ +import itertools import sys from collections import OrderedDict from copy import deepcopy -from itertools import permutations from typing import ( Any, Callable, Dict, - Generator, Iterable, List, Mapping, Optional, Sequence, Set, - Sized, Tuple, Union, ) -from qrules.particle import Particle, ParticleCollection - -from .particle import ParticleWithSpin -from .quantum_numbers import InteractionProperties, arange -from .topology import MutableTransition, Topology, get_originating_node_list +from qrules.particle import ParticleCollection, ParticleWithSpin +from qrules.quantum_numbers import InteractionProperties, arange +from qrules.topology import MutableTransition, Topology, get_originating_node_list if sys.version_info >= (3, 10): from typing import TypeAlias @@ -45,22 +41,22 @@ class _KinematicRepresentation: def __init__( self, - final_state: Optional[Union[Sequence[List[Any]], List[Any]]] = None, - initial_state: Optional[Union[Sequence[List[Any]], List[Any]]] = None, + final_state: Optional[Union[List[List[str]], List[str]]] = None, + initial_state: Optional[Union[List[List[str]], List[str]]] = None, ) -> None: - self.__initial_state: Optional[List[List[Any]]] = None - self.__final_state: Optional[List[List[Any]]] = None + self.__initial_state: Optional[List[List[str]]] = None + self.__final_state: Optional[List[List[str]]] = None if initial_state is not None: - self.__initial_state = self.__import(initial_state) + self.__initial_state = _sort_nested(ensure_nested_list(initial_state)) if final_state is not None: - self.__final_state = self.__import(final_state) + self.__final_state = _sort_nested(ensure_nested_list(final_state)) @property - def initial_state(self) -> Optional[List[List[Any]]]: + def initial_state(self) -> Optional[List[List[str]]]: return self.__initial_state @property - def final_state(self) -> Optional[List[List[Any]]]: + def final_state(self) -> Optional[List[List[str]]]: return self.__final_state def __eq__(self, other: object) -> bool: @@ -94,8 +90,8 @@ def __contains__(self, other: object) -> bool: """ def is_sublist( - sub_representation: Optional[List[List[Any]]], - main_representation: Optional[List[List[Any]]], + sub_representation: Optional[List[List[str]]], + main_representation: Optional[List[List[str]]], ) -> bool: if main_representation is None: if sub_representation is None: @@ -123,41 +119,23 @@ def is_sublist( f"Cannot compare {type(self).__name__} with {type(other).__name__}" ) - def __import( - self, nested_list: Union[Sequence[Sequence[Any]], Sequence[Any]] - ) -> List[List[Any]]: - return self.__sort(self.__prepare(nested_list)) - def __prepare( - self, nested_list: Union[Sequence[Sequence[Any]], Sequence[Any]] - ) -> List[List[Any]]: - if len(nested_list) == 0 or not isinstance(nested_list[0], list): - nested_list = [nested_list] - return [ - [self.__extract_particle_name(item) for item in sub_list] - for sub_list in nested_list - ] +def _sort_nested(nested_list: List[List[str]]) -> List[List[str]]: + return sorted(sorted(sub_list) for sub_list in nested_list) - @staticmethod - def __sort(nested_list: List[List[Any]]) -> List[List[Any]]: - return sorted(sorted(sub_list) for sub_list in nested_list) - @staticmethod - def __extract_particle_name(item: object) -> str: - if isinstance(item, str): - return item - if isinstance(item, (tuple, list)) and isinstance(item[0], str): - return item[0] - if isinstance(item, Particle): - return item.name - if isinstance(item, dict) and "Name" in item: - return str(item["Name"]) - raise ValueError(f"Cannot extract particle name from {type(item).__name__}") +def ensure_nested_list( + nested_list: Union[List[str], List[List[str]]] +) -> List[List[str]]: + if any(not isinstance(item, list) for item in nested_list): + nested_list = [nested_list] # type: ignore[assignment] + if any(not isinstance(i, str) for lst in nested_list for i in lst): + raise ValueError("Not all grouping items are particle names") + return nested_list # type: ignore[return-value] def _get_kinematic_representation( - topology: Topology, - initial_facts: Mapping[int, StateWithSpins], + topology: Topology, particle_names: Mapping[int, str] ) -> _KinematicRepresentation: r"""Group final or initial states by node, sorted by length of the group. @@ -194,22 +172,17 @@ def _get_kinematic_representation( and are therefore kinematically identical. The nested lists are sorted (by `list` length and element content) for comparisons. - - Note: more precisely, the states represented here by a `str` only also have a list - of allowed spin projections, for instance, :code:`("J/psi", [-1, +1])`. Note that a - `tuple` is also sortable. """ - def get_state_groupings( - edge_per_node_getter: Callable[[int], Iterable[int]] - ) -> List[Iterable[int]]: - return [edge_per_node_getter(i) for i in topology.nodes] + def get_state_groupings(get_edge: Callable[[int], Set[int]]) -> List[List[int]]: + return [sorted(get_edge(i)) for i in topology.nodes] def fill_groupings( - grouping_with_ids: Iterable[Iterable[int]], - ) -> List[List[StateWithSpins]]: + edge_id_groupings: Iterable[Iterable[int]], + ) -> List[List[str]]: return [ - [initial_facts[edge_id] for edge_id in group] for group in grouping_with_ids + [particle_names[edge_id] for edge_id in group] + for group in edge_id_groupings ] initial_state_edge_groups = fill_groupings( @@ -224,136 +197,54 @@ def fill_groupings( ) -def create_initial_facts( # pylint: disable=too-many-locals +def create_initial_facts( topology: Topology, - particle_db: ParticleCollection, initial_state: Sequence[StateDefinition], final_state: Sequence[StateDefinition], - final_state_groupings: Optional[ - Union[List[List[List[str]]], List[List[str]], List[str]] - ] = None, + particle_db: ParticleCollection, ) -> List[InitialFacts]: - def embed_in_list(some_list: List[Any]) -> List[List[Any]]: - if not isinstance(some_list[0], list): - return [some_list] - return some_list - - allowed_kinematic_groupings = None - if final_state_groupings is not None: - final_state_groupings = embed_in_list(final_state_groupings) - final_state_groupings = embed_in_list(final_state_groupings) - allowed_kinematic_groupings = [ - _KinematicRepresentation(final_state=grouping) - for grouping in final_state_groupings - ] - - kinematic_permutation_graphs = _generate_kinematic_permutations( - topology=topology, - particle_db=particle_db, - initial_state=initial_state, - final_state=final_state, - allowed_kinematic_groupings=allowed_kinematic_groupings, + states = __create_states_with_spin_projections( + list(topology.incoming_edge_ids) + list(topology.outgoing_edge_ids), + list(initial_state) + list(final_state), + particle_db, ) - edge_initial_facts: List[InitialFacts] = [] - for kinematic_permutation in kinematic_permutation_graphs: - spin_permutations = _generate_spin_permutations( - kinematic_permutation, particle_db - ) - edge_initial_facts.extend( - [MutableTransition(topology, states=x) for x in spin_permutations] # type: ignore[arg-type] - ) - return edge_initial_facts + spin_states = __generate_spin_combinations(states, particle_db) + return [MutableTransition(topology, state) for state in spin_states] # type: ignore[arg-type] -def _generate_kinematic_permutations( - topology: Topology, +def __create_states_with_spin_projections( + edge_ids: Sequence[int], + state_definitions: Sequence[StateDefinition], particle_db: ParticleCollection, - initial_state: Sequence[StateDefinition], - final_state: Sequence[StateDefinition], - allowed_kinematic_groupings: Optional[List[_KinematicRepresentation]] = None, -) -> List[Dict[int, StateWithSpins]]: - def assert_number_of_states(state_definitions: Sized, edge_ids: Sized) -> None: - if len(state_definitions) != len(edge_ids): - raise ValueError( - "Number of state definitions is not same as number of edge" - f" IDs:(len({state_definitions}) != len({edge_ids})" - ) - - assert_number_of_states(initial_state, topology.incoming_edge_ids) - assert_number_of_states(final_state, topology.outgoing_edge_ids) - - def is_allowed_grouping( - kinematic_representation: _KinematicRepresentation, - ) -> bool: - if allowed_kinematic_groupings is None: - return True - for allowed_kinematic_grouping in allowed_kinematic_groupings: - if allowed_kinematic_grouping in kinematic_representation: - return True - return False - - initial_state_with_projections = _safe_set_spin_projections( - initial_state, particle_db - ) - final_state_with_projections = _safe_set_spin_projections(final_state, particle_db) - - initial_facts_combinations: List[Dict[int, StateWithSpins]] = [] - kinematic_representations: List[_KinematicRepresentation] = [] - for permutation in _generate_outer_edge_permutations( - topology, - initial_state_with_projections, - final_state_with_projections, - ): - kinematic_representation = _get_kinematic_representation(topology, permutation) - if kinematic_representation in kinematic_representations: - continue - if not is_allowed_grouping(kinematic_representation): - continue - kinematic_representations.append(kinematic_representation) - initial_facts_combinations.append(permutation) - - return initial_facts_combinations +) -> Dict[int, StateWithSpins]: + if len(edge_ids) != len(state_definitions): + raise ValueError( + "Number of state definitions is not same as number of edge IDs" + ) + states = __safe_set_spin_projections(state_definitions, particle_db) + return dict(zip(edge_ids, states)) -def _safe_set_spin_projections( - list_of_states: Sequence[StateDefinition], +def __safe_set_spin_projections( + state_definitions: Sequence[StateDefinition], particle_db: ParticleCollection, ) -> Sequence[StateWithSpins]: - def safe_set_spin_projections( - state: StateDefinition, particle_db: ParticleCollection - ) -> StateWithSpins: + def fill_spin_projections(state: StateDefinition) -> StateWithSpins: if isinstance(state, str): particle_name = state - particle = particle_db[state] - spin_projections = list(arange(-particle.spin, particle.spin + 1, 1.0)) + particle = particle_db[particle_name] + spin_projections = set(arange(-particle.spin, particle.spin + 1, 1.0)) if particle.mass == 0.0: if 0.0 in spin_projections: - del spin_projections[spin_projections.index(0.0)] - state = (particle_name, spin_projections) + spin_projections.remove(0.0) + return particle_name, sorted(spin_projections) return state - return [safe_set_spin_projections(state, particle_db) for state in list_of_states] - - -def _generate_outer_edge_permutations( - topology: Topology, - initial_state: Sequence[StateWithSpins], - final_state: Sequence[StateWithSpins], -) -> Generator[Dict[int, StateWithSpins], None, None]: - initial_state_ids = list(topology.incoming_edge_ids) - final_state_ids = list(topology.outgoing_edge_ids) - for initial_state_permutation in permutations(initial_state): - for final_state_permutation in permutations(final_state): - yield dict( - zip( - initial_state_ids + final_state_ids, - initial_state_permutation + final_state_permutation, - ) - ) + return [fill_spin_projections(state) for state in state_definitions] -def _generate_spin_permutations( - initial_facts: Dict[int, StateWithSpins], +def __generate_spin_combinations( + states_with_spin_projections: Dict[int, StateWithSpins], particle_db: ParticleCollection, ) -> List[Dict[int, ParticleWithSpin]]: def populate_edge_with_spin_projections( @@ -371,7 +262,7 @@ def populate_edge_with_spin_projections( return new_permutations initial_facts_permutations: List[Dict[int, ParticleWithSpin]] = [{}] - for edge_id, state in initial_facts.items(): + for edge_id, state in states_with_spin_projections.items(): temp_permutations = initial_facts_permutations initial_facts_permutations = [] for temp_permutation in temp_permutations: @@ -382,16 +273,90 @@ def populate_edge_with_spin_projections( return initial_facts_permutations -def __get_initial_state_edge_ids( - graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", -) -> Iterable[int]: - return graph.topology.incoming_edge_ids +def permutate_topology_kinematically( + topology: Topology, + initial_state: List[StateDefinition], + final_state: List[StateDefinition], + final_state_groupings: Optional[ + Union[List[List[List[str]]], List[List[str]], List[str]] + ] = None, +) -> List[Topology]: + def strip_spin(state: StateDefinition) -> str: + if isinstance(state, tuple): + return state[0] + return state + edge_ids = sorted(topology.incoming_edge_ids) + sorted(topology.outgoing_edge_ids) + states = initial_state + final_state + return _generate_kinematic_permutations( + topology, + particle_names={i: strip_spin(s) for i, s in zip(edge_ids, states)}, + allowed_kinematic_groupings=__get_kinematic_groupings(final_state_groupings), + ) -def __get_final_state_edge_ids( - graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", -) -> Iterable[int]: - return graph.topology.outgoing_edge_ids + +def _generate_kinematic_permutations( + topology: Topology, + particle_names: Dict[int, str], + allowed_kinematic_groupings: Optional[List[_KinematicRepresentation]] = None, +) -> List[Topology]: + def is_allowed_grouping(kinematic_representation: _KinematicRepresentation) -> bool: + if allowed_kinematic_groupings is None: + return True + for allowed_kinematic_grouping in allowed_kinematic_groupings: + if allowed_kinematic_grouping in kinematic_representation: + return True + return False + + permuted_topologies: List[Topology] = [] + kinematic_representations: List[_KinematicRepresentation] = [] + for permutation in _permutate_outer_edges(topology): + kinematic_representation = _get_kinematic_representation( + permutation, particle_names + ) + if kinematic_representation in kinematic_representations: + continue + if not is_allowed_grouping(kinematic_representation): + continue + kinematic_representations.append(kinematic_representation) + permuted_topologies.append(permutation) + return permuted_topologies + + +def _permutate_outer_edges(topology: Topology) -> List[Topology]: + initial_state_ids = sorted(topology.incoming_edge_ids) + final_state_ids = sorted(topology.outgoing_edge_ids) + topologies = set() + for initial_state_permutation in itertools.permutations(initial_state_ids): + for final_state_permutation in itertools.permutations(final_state_ids): + permutation = zip( + initial_state_ids + final_state_ids, + initial_state_permutation + final_state_permutation, + ) + new_topology = topology.relabel_edges(dict(permutation)) + topologies.add(new_topology) + return sorted(topologies) + + +def __get_kinematic_groupings( + final_state_groupings: Union[ + List[List[List[str]]], + List[List[str]], + List[str], + None, + ] +) -> Optional[List[_KinematicRepresentation]]: + if final_state_groupings is None: + return None + + def embed_in_list(some_list: List[Any]) -> List[List[Any]]: + if not isinstance(some_list[0], list): + return [some_list] + return some_list + + final_state_groupings = embed_in_list(final_state_groupings) + final_state_groupings = embed_in_list(final_state_groupings) + return [_KinematicRepresentation(grouping) for grouping in final_state_groupings] def match_external_edges( @@ -442,6 +407,18 @@ def _match_external_edge_ids( # pylint: disable=too-many-locals graph.swap_edges(edge_id1, edge_id2) +def __get_initial_state_edge_ids( + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", +) -> Iterable[int]: + return graph.topology.incoming_edge_ids + + +def __get_final_state_edge_ids( + graph: "MutableTransition[ParticleWithSpin, InteractionProperties]", +) -> Iterable[int]: + return graph.topology.outgoing_edge_ids + + def perform_external_edge_identical_particle_combinatorics( graph: MutableTransition, ) -> List[MutableTransition]: @@ -484,7 +461,7 @@ def _external_edge_identical_particle_combinatorics( } # now for each identical particle group perform all permutations for edge_group in identical_particle_groups.values(): - combinations = permutations(edge_group) + combinations = itertools.permutations(edge_group) graph_combinations = set() ext_edge_combinations = [] ref_node_origin = get_originating_node_list(graph.topology, edge_group) diff --git a/src/qrules/transition.py b/src/qrules/transition.py index a7f8f470..033a874e 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -43,7 +43,9 @@ InitialFacts, StateDefinition, create_initial_facts, + ensure_nested_list, match_external_edges, + permutate_topology_kinematically, ) from .particle import ( Particle, @@ -270,8 +272,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-branches, too-many if particle_db is not None: self.__particles = particle_db self.reaction_mode = str(solving_mode) - self.initial_state = initial_state - self.final_state = final_state + self.initial_state = list(initial_state) + self.final_state = list(final_state) self.interaction_type_settings = interaction_type_settings self.interaction_determinators: List[InteractionDeterminator] = [ @@ -365,7 +367,7 @@ def add_final_state_grouping( if len(fs_group) > 0: if self.final_state_groupings is None: self.final_state_groupings = [] - nested_list = _safe_wrap_list(fs_group) + nested_list = ensure_nested_list(fs_group) self.final_state_groupings.append(nested_list) @overload @@ -408,28 +410,20 @@ def set_allowed_interaction_types( self.__allowed_interaction_types[node_id] = allowed_interaction_types def create_problem_sets(self) -> Dict[float, List[ProblemSet]]: - problem_sets = [] - for topology in self.topologies: + problem_sets = [ + ProblemSet(permutation, initial_facts, settings) + for topology in self.topologies + for permutation in permutate_topology_kinematically( + topology, + self.initial_state, + self.final_state, + self.final_state_groupings, + ) for initial_facts in create_initial_facts( - topology=topology, - particle_db=self.__particles, - initial_state=self.initial_state, - final_state=self.final_state, - final_state_groupings=self.final_state_groupings, - ): - problem_sets.extend( - [ - ProblemSet( - topology=topology, - initial_facts=initial_facts, - solving_settings=x, - ) - for x in self.__determine_graph_settings( - topology, initial_facts - ) - ] - ) - # create groups of settings ordered by "probability" + permutation, self.initial_state, self.final_state, self.__particles + ) + for settings in self.__determine_graph_settings(permutation, initial_facts) + ] return _group_by_strength(problem_sets) def __determine_graph_settings( @@ -698,16 +692,6 @@ def __convert_to_particle_definitions( ) -def _safe_wrap_list(nested_list: Union[List[str], List[List[str]]]) -> List[List[str]]: - if all(isinstance(i, list) for i in nested_list): - return nested_list # type: ignore[return-value] - if all(isinstance(i, str) for i in nested_list): - return [nested_list] # type: ignore[list-item] - raise TypeError( - f"Input final state grouping {nested_list} is not a list of lists of strings" - ) - - def _filter_by_name_pattern( particles: ParticleCollection, pattern: str, regex: bool ) -> ParticleCollection: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2766f620..45ef8c92 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -6,6 +6,7 @@ import qrules from qrules import ReactionInfo +from qrules.topology import Edge, Topology logging.basicConfig(level=logging.ERROR) @@ -20,3 +21,30 @@ def reaction(request: SubRequest) -> ReactionInfo: allowed_interaction_types="strong", formalism=formalism, ) + + +@pytest.fixture(scope="session") +def two_to_three_decay() -> Topology: + r"""Create a dummy `Topology`. + + Has the following shape: + + .. code-block:: + + e-1 -- (N0) -- e3 -- (N1) -- e4 -- (N2) -- e2 + / \ \ + e-2 e0 e1 + """ + topology = Topology( + nodes={0, 1, 2}, + edges={ + -2: Edge(None, 0), + -1: Edge(None, 0), + 0: Edge(1, None), + 1: Edge(2, None), + 2: Edge(2, None), + 3: Edge(0, 1), + 4: Edge(1, 2), + }, + ) + return topology diff --git a/tests/unit/test_combinatorics.py b/tests/unit/test_combinatorics.py index 548dd876..098f19a7 100644 --- a/tests/unit/test_combinatorics.py +++ b/tests/unit/test_combinatorics.py @@ -1,18 +1,16 @@ # pylint: disable=redefined-outer-name -from math import factorial +from typing import List import pytest from qrules.combinatorics import ( _generate_kinematic_permutations, - _generate_outer_edge_permutations, - _generate_spin_permutations, _get_kinematic_representation, _KinematicRepresentation, - _safe_set_spin_projections, + _permutate_outer_edges, create_initial_facts, + permutate_topology_kinematically, ) -from qrules.particle import ParticleCollection from qrules.topology import Topology, create_isobar_topologies @@ -23,6 +21,67 @@ def three_body_decay() -> Topology: return topology +def test_create_initial_facts(three_body_decay, particle_database): + initial_facts = create_initial_facts( + three_body_decay, + initial_state=[("J/psi(1S)", [-1, +1])], + final_state=["gamma", "pi0", "pi0"], + particle_db=particle_database, + ) + assert len(initial_facts) == 4 + + for fact in initial_facts: + edge_ids = sorted(fact.states) + assert edge_ids == [-1, 0, 1, 2] + particle_names = [fact.states[i][0].name for i in edge_ids] + assert particle_names == ["J/psi(1S)", "gamma", "pi0", "pi0"] + _, initial_polarization = fact.states[-1] + assert initial_polarization in {-1, +1} + + +def test_generate_kinematic_permutations_groupings(three_body_decay: Topology): + topology = three_body_decay + particle_names = { + -1: "J/psi(1S)", + 0: "gamma", + 1: "pi0", + 2: "pi0", + } + allowed_kinematic_groupings = [_KinematicRepresentation(["pi0", "pi0"])] + permutations = _generate_kinematic_permutations( + topology, particle_names, allowed_kinematic_groupings + ) + assert len(permutations) == 1 + + permutations = _generate_kinematic_permutations(topology, particle_names) + assert len(permutations) == 2 + assert permutations[0].get_originating_final_state_edge_ids(1) == {1, 2} + assert permutations[1].get_originating_final_state_edge_ids(1) == {0, 2} + + +@pytest.mark.parametrize( + ("n_permutations", "decay_type"), + [ + (3, "two_to_three_decay"), + (3, "three_body_decay"), + ], +) +def test_permutate_outer_edges( + n_permutations: int, + decay_type: str, + three_body_decay: Topology, + two_to_three_decay: Topology, +): + if decay_type == "two_to_three_decay": + topology = two_to_three_decay + elif decay_type == "three_body_decay": + topology = three_body_decay + else: + raise NotImplementedError(decay_type) + permutations = _permutate_outer_edges(topology) + assert len(permutations) == n_permutations + + @pytest.mark.parametrize( "final_state_groupings", [ @@ -32,53 +91,55 @@ def three_body_decay() -> Topology: ["gamma", "pi0"], [["gamma", "pi0"]], [[["gamma", "pi0"]]], + None, ], ) -def test_initialize_graph( +def test_permutate_topology_kinematically( final_state_groupings, three_body_decay: Topology, - particle_database: ParticleCollection, ): - initial_facts = create_initial_facts( - three_body_decay, + permutations = permutate_topology_kinematically( + topology=three_body_decay, initial_state=[("J/psi(1S)", [-1, +1])], final_state=["gamma", "pi0", "pi0"], - particle_db=particle_database, final_state_groupings=final_state_groupings, ) - assert len(initial_facts) == 4 + if final_state_groupings is None: + assert len(permutations) == 2 + else: + assert len(permutations) == 1 @pytest.mark.parametrize( - ("initial_state", "final_state"), + ("n_permutations", "initial_state", "final_state"), [ - (["J/psi(1S)"], ["gamma", "pi0", "pi0"]), - (["J/psi(1S)"], ["K+", "K-", "pi+", "pi-"]), - (["e+", "e-"], ["gamma", "pi-", "pi+"]), - (["e+", "e-"], ["K+", "K-", "pi+", "pi-"]), + (2, ["J/psi(1S)"], ["gamma", "pi0", "pi0"]), + (3, ["J/psi(1S)"], ["gamma", "pi-", "pi+"]), + (2, ["e+", "e-"], ["gamma", "pi0", "pi0"]), + (3, ["e+", "e-"], ["gamma", "pi-", "pi+"]), ], ) -def test_generate_outer_edge_permutations( - initial_state, - final_state, +def test_generate_kinematic_permutations( + n_permutations: int, + initial_state: List[str], + final_state: List[str], three_body_decay: Topology, - particle_database: ParticleCollection, + two_to_three_decay: Topology, ): - initial_state_with_spins = _safe_set_spin_projections( - initial_state, particle_database - ) - final_state_with_spins = _safe_set_spin_projections(final_state, particle_database) - list_of_permutations = list( - _generate_outer_edge_permutations( - three_body_decay, - initial_state_with_spins, - final_state_with_spins, + if len(initial_state) == 1: + topology = three_body_decay + elif len(initial_state) == 2: + topology = two_to_three_decay + else: + raise NotImplementedError + particle_names = dict( + zip( + sorted(topology.incoming_edge_ids) + sorted(topology.outgoing_edge_ids), + list(initial_state) + list(final_state), ) ) - n_permutations_final_state = factorial(len(final_state)) - n_permutations_initial_state = factorial(len(initial_state)) - n_permutations = n_permutations_final_state * n_permutations_initial_state - assert len(list_of_permutations) == n_permutations + permutations = _generate_kinematic_permutations(topology, particle_names) + assert len(permutations) == n_permutations class TestKinematicRepresentation: @@ -94,13 +155,11 @@ def test_constructor(self): assert representation.final_state == [["gamma", "pi0"]] def test_from_topology(self, three_body_decay: Topology): - pi0 = ("pi0", [0]) - gamma = ("gamma", [-1, 1]) states = { - -1: ("J/psi", [-1, +1]), - 0: pi0, - 1: pi0, - 2: gamma, + -1: "J/psi", + 0: "pi0", + 1: "pi0", + 2: "gamma", } kinematic_representation1 = _get_kinematic_representation( three_body_decay, states @@ -116,22 +175,22 @@ def test_from_topology(self, three_body_decay: Topology): kinematic_representation2 = _get_kinematic_representation( topology=three_body_decay, - initial_facts={ - -1: ("J/psi", [-1, +1]), - 0: pi0, - 1: gamma, - 2: pi0, + particle_names={ + -1: "J/psi", + 0: "pi0", + 1: "gamma", + 2: "pi0", }, ) assert kinematic_representation1 == kinematic_representation2 kinematic_representation3 = _get_kinematic_representation( topology=three_body_decay, - initial_facts={ - -1: ("J/psi", [-1, +1]), - 0: pi0, - 1: gamma, - 2: gamma, + particle_names={ + -1: "J/psi", + 0: "pi0", + 1: "gamma", + 2: "gamma", }, ) assert kinematic_representation2 != kinematic_representation3 @@ -163,50 +222,3 @@ def test_in_operator(self): match=r"Comparison representation needs to be a list of lists", ): assert ["should be nested list"] in kinematic_representation - - -def test_generate_permutations( - three_body_decay: Topology, particle_database: ParticleCollection -): - permutations = _generate_kinematic_permutations( - three_body_decay, - initial_state=[("J/psi(1S)", [-1, +1])], - final_state=["gamma", "pi0", "pi0"], - particle_db=particle_database, - allowed_kinematic_groupings=[_KinematicRepresentation(["pi0", "pi0"])], - ) - assert len(permutations) == 1 - - permutations = _generate_kinematic_permutations( - three_body_decay, - initial_state=[("J/psi(1S)", [-1, +1])], - final_state=["gamma", "pi0", "pi0"], - particle_db=particle_database, - ) - assert len(permutations) == 2 - graph0_final_state_node1 = [ - permutations[0][edge_id] - for edge_id in three_body_decay.get_originating_final_state_edge_ids(1) - ] - graph1_final_state_node1 = [ - permutations[1][edge_id] - for edge_id in three_body_decay.get_originating_final_state_edge_ids(1) - ] - assert graph0_final_state_node1 == [ - ("pi0", [0]), - ("pi0", [0]), - ] - assert graph1_final_state_node1 == [ - ("gamma", [-1, 1]), - ("pi0", [0]), - ] - - permutation0 = permutations[0] - spin_permutations = _generate_spin_permutations(permutation0, particle_database) - assert len(spin_permutations) == 4 - assert spin_permutations[0][-1][1] == -1 - assert spin_permutations[0][0][1] == -1 - assert spin_permutations[1][-1][1] == -1 - assert spin_permutations[1][0][1] == +1 - assert spin_permutations[2][-1][1] == +1 - assert spin_permutations[3][-1][1] == +1 diff --git a/tests/unit/test_final_state_permutations.py b/tests/unit/test_final_state_permutations.py new file mode 100644 index 00000000..8dedb332 --- /dev/null +++ b/tests/unit/test_final_state_permutations.py @@ -0,0 +1,62 @@ +# cspell:ignore pbar +import itertools + +import pytest + +import qrules +from qrules.settings import InteractionType +from qrules.transition import StateTransitionManager + + +@pytest.mark.parametrize( + "final_state_description", + sorted({" ".join(p) for p in itertools.permutations(["p~", "Sigma+", "K0"])}), +) +def test_create_problem_sets(final_state_description: str): + input_final_state = final_state_description.split(" ") + stm = StateTransitionManager( + initial_state=["J/psi(1S)"], + final_state=input_final_state, + allowed_intermediate_particles=["N(1440)"], + ) + stm.set_allowed_interaction_types([InteractionType.STRONG]) + problem_sets = stm.create_problem_sets() + for problem_set in problem_sets.values(): + for problem in problem_set: + problem_final_state = [ + problem.initial_facts.states[i][0].name for i in range(3) + ] + assert problem_final_state == input_final_state + + +@pytest.mark.parametrize( + "final_state_description", + sorted({" ".join(p) for p in itertools.permutations(["gamma", "pi0", "pi0"], 3)}), +) +def test_generate_transitions(final_state_description: str): + final_state = final_state_description.split(" ") + reaction = qrules.generate_transitions( + initial_state=("J/psi(1S)", [-1, +1]), + final_state=final_state, + allowed_intermediate_particles=["omega(782)"], + allowed_interaction_types=["strong", "EM"], + ) + ordered_final_state = [ + reaction.final_state[i].name for i in sorted(reaction.final_state) + ] + assert final_state == ordered_final_state + + assert len(reaction.transitions) == 8 + for transition in reaction.transitions: + ordered_final_state = [ + transition.final_states[i].particle.name + for i in sorted(transition.final_states) + ] + assert final_state == ordered_final_state + + topology = transition.topology + decay_products = { + transition.states[i].particle.name + for i in topology.get_edge_ids_outgoing_from_node(1) + } + assert decay_products == {"gamma", "pi0"} diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 2fa242b5..3e637d69 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -19,33 +19,6 @@ ) -@pytest.fixture(scope="session") -def two_to_three_decay() -> Topology: - r"""Create a dummy `Topology`. - - Has the following shape: - - .. code-block:: - - e-1 -- (N0) -- e3 -- (N1) -- e4 -- (N2) -- e2 - / \ \ - e-2 e0 e1 - """ - topology = Topology( - nodes={0, 1, 2}, - edges={ - -2: Edge(None, 0), - -1: Edge(None, 0), - 0: Edge(1, None), - 1: Edge(2, None), - 2: Edge(2, None), - 3: Edge(0, 1), - 4: Edge(1, 2), - }, - ) - return topology - - class TestEdge: def test_get_connected_nodes(self): edge = Edge(1, 2)