From b5d31950e5f99ae2e5be8d2468e24c54700ac507 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 25 Apr 2024 18:26:02 +0200 Subject: [PATCH] DX: lint type hints with MyPy (#109) --- .gitpod.yml | 1 + .pre-commit-config.yaml | 11 ++++++ .vscode/extensions.json | 2 +- .vscode/settings.json | 7 ++++ docs/conf.py | 1 - pyproject.toml | 16 +++++++++ src/ampform_dpd/__init__.py | 58 ++++++++++++++++--------------- src/ampform_dpd/_attrs.py | 10 ++++-- src/ampform_dpd/adapter/qrules.py | 14 ++++---- src/ampform_dpd/decay.py | 51 +++++++++++++++++++++------ src/ampform_dpd/dynamics.py | 4 +-- src/ampform_dpd/io.py | 19 +++++----- src/ampform_dpd/spin.py | 2 +- tests/adapter/test_qrules.py | 2 +- tests/test_angles.py | 10 +++--- tests/test_io.py | 2 +- 16 files changed, 141 insertions(+), 69 deletions(-) diff --git a/.gitpod.yml b/.gitpod.yml index 866fb318..2e3ec8ef 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -22,6 +22,7 @@ vscode: - github.vscode-github-actions - github.vscode-pull-request-github - ms-python.python + - ms-python.mypy-type-checker - ms-python.vscode-pylance - ms-toolsai.vscode-jupyter-cell-tags - ms-vscode.live-server diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2b4b76e2..0361bfbb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ ci: autoupdate_commit_msg: "MAINT: update pip constraints and pre-commit" autoupdate_schedule: quarterly # already done by requirements-cron.yml skip: + - mypy - prettier - pyright - taplo @@ -108,6 +109,16 @@ repos: .*\.py )$ + - repo: local + hooks: + - id: mypy + name: mypy + entry: mypy + language: system + require_serial: true + types: + - python + - repo: https://github.com/ComPWA/mirrors-pyright rev: v1.1.359 hooks: diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 44c0afda..d528453c 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -8,6 +8,7 @@ "github.vscode-github-actions", "github.vscode-pull-request-github", "ms-python.python", + "ms-python.mypy-type-checker", "ms-python.vscode-pylance", "ms-toolsai.vscode-jupyter-cell-tags", "ms-vscode.live-server", @@ -27,7 +28,6 @@ "ms-python.black-formatter", "ms-python.flake8", "ms-python.isort", - "ms-python.mypy-type-checker", "ms-python.pylint", "travisillig.vscode-json-stable-stringify", "tyriar.sort-lines" diff --git a/.vscode/settings.json b/.vscode/settings.json index 0408778f..04642116 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -38,6 +38,8 @@ "gitlens.telemetry.enabled": false, "livePreview.defaultPreviewPath": "docs/_build/html", "multiDiffEditor.experimental.enabled": true, + "mypy-type-checker.args": ["--config-file=${workspaceFolder}/pyproject.toml"], + "mypy-type-checker.importStrategy": "fromEnvironment", "notebook.codeActionsOnSave": { "notebook.source.organizeImports": "explicit" }, @@ -54,6 +56,11 @@ "ruff.enable": true, "ruff.importStrategy": "fromEnvironment", "ruff.organizeImports": true, + "search.exclude": { + "**/tests/**/__init__.py": true, + ".constraints/*.txt": true, + "typings/**": true + }, "telemetry.telemetryLevel": "off", "yaml.schemas": { "https://raw.githubusercontent.com/ComPWA/qrules/0.10.x/src/qrules/particle-validation.json": [ diff --git a/docs/conf.py b/docs/conf.py index 9e88c084..fd54522f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -67,7 +67,6 @@ "show-inheritance": True, } autodoc_member_order = "bysource" -autodoc_type_aliases = {} autodoc_typehints_format = "short" autosectionlabel_prefix_document = True autosectionlabel_maxdepth = 2 diff --git a/pyproject.toml b/pyproject.toml index a6174d0f..d8ed943a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ numba = [ ] sty = [ "ampform-dpd[types]", + "mypy", "pre-commit >=1.4.0", "ruff", ] @@ -134,6 +135,21 @@ where = ["src"] [tool.setuptools_scm] write_to = "src/ampform_dpd/version.py" +[tool.mypy] +exclude = "_build" +show_error_codes = true +warn_unused_configs = true + +[[tool.mypy.overrides]] +ignore_missing_imports = true +module = ["cloudpickle.*"] + +[[tool.mypy.overrides]] +check_untyped_defs = true +disallow_incomplete_defs = false +disallow_untyped_defs = false +module = ["tests.*"] + [tool.pyright] reportArgumentType = false reportAssignmentType = false diff --git a/src/ampform_dpd/__init__.py b/src/ampform_dpd/__init__.py index f316525a..4c5af6cb 100644 --- a/src/ampform_dpd/__init__.py +++ b/src/ampform_dpd/__init__.py @@ -22,6 +22,7 @@ ThreeBodyDecay, ThreeBodyDecayChain, get_decay_product_ids, + get_particle, ) from ampform_dpd.io import ( simplify_latex_rendering, # noqa: F401 # pyright:ignore[reportUnusedImport] @@ -35,9 +36,9 @@ class AmplitudeModel: intensity: sp.Expr = sp.S.One amplitudes: dict[sp.Indexed, sp.Expr] = field(factory=dict) variables: dict[sp.Symbol, sp.Expr] = field(factory=dict) - parameter_defaults: dict[sp.Symbol, float] = field(factory=dict) - masses: dict[sp.Symbol, float] = field(factory=dict) - invariants: dict[sp.Symbol, float] = field(factory=dict) + parameter_defaults: dict[sp.Symbol, float | complex] = field(factory=dict) + masses: dict[sp.Symbol, float | complex] = field(factory=dict) + invariants: dict[sp.Symbol, sp.Expr] = field(factory=dict) @property def full_expression(self) -> sp.Expr: @@ -79,17 +80,20 @@ def formulate( reference_subsystem: Literal[1, 2, 3] = 1, cleanup_summations: bool = False, ) -> AmplitudeModel: - helicity_symbols = sp.symbols("lambda:4", rational=True) + helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = ( + sp.symbols("lambda:4", rational=True) + ) allowed_helicities = { - symbol: create_spin_range(self.decay.states[i].spin) + symbol: create_spin_range(self.decay.states[i].spin) # type:ignore[index] for i, symbol in enumerate(helicity_symbols) } amplitude_definitions = {} angle_definitions = {} parameter_defaults = {} - for args in product(*allowed_helicities.values()): - for sub_system in [1, 2, 3]: - chain_model = self.formulate_subsystem_amplitude(*args, sub_system) + args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational] + for args in product(*allowed_helicities.values()): # type:ignore[assignment] + for sub_system in (1, 2, 3): + chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type] amplitude_definitions.update(chain_model.amplitudes) angle_definitions.update(chain_model.variables) parameter_defaults.update(chain_model.parameter_defaults) @@ -100,13 +104,13 @@ def formulate( masses = create_mass_symbol_mapping(self.decay) parameter_defaults.update(masses) if cleanup_summations: - aligned_amp = aligned_amp.cleanup() + aligned_amp = aligned_amp.cleanup() # type:ignore[assignment] intensity = PoolSum( sp.Abs(aligned_amp) ** 2, *allowed_helicities.items(), ) if cleanup_summations: - intensity = intensity.cleanup() + intensity = intensity.cleanup() # type:ignore[assignment] return AmplitudeModel( decay=self.decay, intensity=PoolSum( @@ -165,8 +169,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914 interaction=chain.outgoing_ls, typ="decay", ) - parameter_defaults[h_prod] = 1 + 0j - parameter_defaults[h_dec] = 1 + parameter_defaults[h_prod] = 1 + 0j # type:ignore[index] + parameter_defaults[h_dec] = 1 # type:ignore[index] sub_amp_expr = ( sp.KroneckerDelta(λ[0], λR - λ[k]) * (-1) ** (spin[k] - λ[k]) @@ -189,9 +193,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914 * (-1) ** (spin[j] - λ[j]) ) if not self.use_decay_helicity_couplings: - resonance_isobar = chain.decay.child1 sub_amp_expr *= _formulate_clebsch_gordan_factors( - resonance_isobar, + chain.decay_node, helicities={ self.decay.final_state[i]: λ[i], self.decay.final_state[j]: λ[j], @@ -255,20 +258,23 @@ def formulate_aligned_amplitude( (_λ2, create_spin_range(j2)), (_λ3, create_spin_range(j3)), ) - return amp_expr, wigner_generator.angle_definitions + return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value] def _create_coupling_symbol( helicity_coupling: bool, resonance: Str, helicities: tuple[sp.Basic, sp.Basic], - interaction: LSCoupling, + interaction: LSCoupling | None, typ: Literal["production", "decay"], ) -> sp.Indexed: H = _get_coupling_base(helicity_coupling, typ) if helicity_coupling: λi, λj = helicities return H[resonance, λi, λj] + if interaction is None: + msg = "Cannot formulate LS-coupling without LS combinations" + raise ValueError(msg) return H[resonance, interaction.L, interaction.S] @@ -314,15 +320,9 @@ def _formulate_clebsch_gordan_factors( return sqrt_factor * cg_ll * cg_ss -def get_particle(isobar: IsobarNode | Particle) -> Particle: - if isinstance(isobar, IsobarNode): - return isobar.parent - return isobar - - @lru_cache(maxsize=None) def _generate_amplitude_index_bases() -> dict[Literal[1, 2, 3], sp.IndexedBase]: - return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1)) + return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1)) # type:ignore[arg-type] class _AlignmentWignerGenerator: @@ -333,8 +333,8 @@ def __init__(self, reference_subsystem: Literal[1, 2, 3] = 1) -> None: def __call__( self, j: sp.Rational, - m: sp.Rational, - m_prime: sp.Rational, + m: sp.Rational | sp.Symbol, + m_prime: sp.Rational | sp.Symbol, rotated_state: int, aligned_subsystem: int, ) -> sp.Rational | WignerD: @@ -380,16 +380,18 @@ def decay(self) -> ThreeBodyDecay: class DynamicsBuilder(Protocol): def __call__( self, decay_chain: ThreeBodyDecayChain - ) -> tuple[sp.Expr, dict[sp.Symbol, float]]: ... + ) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]: ... def formulate_non_resonant( decay_chain: ThreeBodyDecayChain, -) -> tuple[sp.Expr, dict[sp.Symbol, float]]: +) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]: return sp.Rational(1), {} -def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]: +def create_mass_symbol_mapping( + decay: ThreeBodyDecay, +) -> dict[sp.Symbol, float | complex]: return { sp.Symbol(f"m{i}", nonnegative=True): decay.states[i].mass for i in sorted(decay.states) # ensure that dict keys are sorted by state ID diff --git a/src/ampform_dpd/_attrs.py b/src/ampform_dpd/_attrs.py index 19d5eb60..b85aaad3 100644 --- a/src/ampform_dpd/_attrs.py +++ b/src/ampform_dpd/_attrs.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, SupportsFloat +from typing import TYPE_CHECKING, Iterable, SupportsFloat import sympy as sp if TYPE_CHECKING: from attrs import Attribute - from ampform_dpd.decay import LSCoupling + from ampform_dpd.decay import LSCoupling, ThreeBodyDecayChain def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> None: @@ -18,7 +18,7 @@ def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> Non raise ValueError(msg) -def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling: +def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling | None: from ampform_dpd.decay import LSCoupling # noqa: PLC0415 if obj is None: @@ -32,5 +32,9 @@ def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling: raise TypeError(msg) +def to_chains(obj: Iterable[ThreeBodyDecayChain]) -> tuple[ThreeBodyDecayChain, ...]: + return tuple(obj) + + def to_rational(obj: SupportsFloat) -> sp.Rational: return sp.Rational(obj) diff --git a/src/ampform_dpd/adapter/qrules.py b/src/ampform_dpd/adapter/qrules.py index d2a83e4f..9605a5f3 100644 --- a/src/ampform_dpd/adapter/qrules.py +++ b/src/ampform_dpd/adapter/qrules.py @@ -35,7 +35,7 @@ def to_three_body_decay( for i, idx in enumerate(sorted(some_transition.final_states), 1) } return ThreeBodyDecay( - states={0: initial_state, **final_states}, + states={0: initial_state, **final_states}, # type:ignore[dict-item] chains=tuple(sorted(to_decay_chain(t) for t in transitions)), ) @@ -97,9 +97,9 @@ def _convert_edge(state: Any) -> Particle: raise NotImplementedError(msg) return Particle( name=particle.name, - latex=particle.latex, + latex=particle.name if particle.latex is None else particle.latex, spin=particle.spin, - parity=particle.parity, + parity=int(particle.parity), # type:ignore[arg-type] mass=particle.mass, width=particle.width, ) @@ -131,11 +131,11 @@ def filter_min_ls( min_transitions = [] for group in grouped_transitions.values(): transition, *_ = group - min_transition = FrozenTransition( + min_transition: FrozenTransition[EdgeType, NodeType] = FrozenTransition( topology=transition.topology, states=transition.states, interactions={ - i: min(t.interactions[i] for t in group) + i: min(t.interactions[i] for t in group) # type:ignore[type-var] for i in transition.interactions }, ) @@ -146,6 +146,6 @@ def filter_min_ls( def load_particles() -> qrules.particle.ParticleCollection: src_dir = Path(__file__).parent.parent particle_database = qrules.load_default_particles() - additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") - particle_database.update(additional_definitions) + additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") # type:ignore[arg-type] + particle_database.update(additional_definitions) # type:ignore[arg-type] return particle_database diff --git a/src/ampform_dpd/decay.py b/src/ampform_dpd/decay.py index 2d538f1c..a4380603 100644 --- a/src/ampform_dpd/decay.py +++ b/src/ampform_dpd/decay.py @@ -2,12 +2,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Literal +from functools import lru_cache +from typing import TYPE_CHECKING, Dict, Literal, TypeVar from attrs import field, frozen from attrs.validators import instance_of -from ampform_dpd._attrs import assert_spin_value, to_ls, to_rational +from ampform_dpd._attrs import assert_spin_value, to_chains, to_ls, to_rational if TYPE_CHECKING: import sympy as sp @@ -31,14 +32,19 @@ class IsobarNode: interaction: LSCoupling | None = field(default=None, converter=to_ls) @property - def children(self) -> tuple[Particle, Particle]: + def children( + self, + ) -> tuple[ + Particle | IsobarNode, + Particle | IsobarNode, + ]: return self.child1, self.child2 @frozen class ThreeBodyDecay: states: OuterStates - chains: tuple[ThreeBodyDecayChain, ...] + chains: tuple[ThreeBodyDecayChain, ...] = field(converter=to_chains) def __attrs_post_init__(self) -> None: expected_initial_state = self.initial_state @@ -89,7 +95,9 @@ def get_subsystem(self, subsystem_id: Literal[1, 2, 3]) -> ThreeBodyDecay: return ThreeBodyDecay(self.states, filtered_chains) -def get_decay_product_ids(spectator_id: Literal[1, 2, 3]) -> tuple[int, int]: +def get_decay_product_ids( + spectator_id: Literal[1, 2, 3], +) -> tuple[Literal[1, 2, 3], Literal[1, 2, 3]]: if spectator_id == 1: return 2, 3 if spectator_id == 2: # noqa: PLR2004 @@ -128,18 +136,31 @@ def parent(self) -> Particle: @property def resonance(self) -> Particle: - return self.decay.child1.parent + decay_node: IsobarNode = self._get_child_of_type(IsobarNode) + return get_particle(decay_node) + + @property + def decay_node(self) -> IsobarNode: + return self._get_child_of_type(IsobarNode) @property def decay_products(self) -> tuple[Particle, Particle]: return ( - self.decay.child1.child1, - self.decay.child1.child2, + get_particle(self.decay_node.child1), + get_particle(self.decay_node.child2), ) @property def spectator(self) -> Particle: - return self.decay.child2 + return self._get_child_of_type(Particle) + + @lru_cache(maxsize=None) # noqa: B019 + def _get_child_of_type(self, typ: type[T]) -> T: + for child in self.decay.children: + if isinstance(child, typ): + return child + msg = f"The production node does not have any children that are of type {typ.__name__}" + raise ValueError(msg) @property def incoming_ls(self) -> LSCoupling | None: @@ -147,10 +168,20 @@ def incoming_ls(self) -> LSCoupling | None: @property def outgoing_ls(self) -> LSCoupling | None: - return self.decay.child1.interaction + decay_node: IsobarNode = self._get_child_of_type(IsobarNode) + return decay_node.interaction + + +T = TypeVar("T", Particle, IsobarNode) @frozen(order=True) class LSCoupling: L: int S: sp.Rational = field(converter=to_rational, validator=assert_spin_value) + + +def get_particle(isobar: IsobarNode | Particle) -> Particle: + if isinstance(isobar, IsobarNode): + return isobar.parent + return isobar diff --git a/src/ampform_dpd/dynamics.py b/src/ampform_dpd/dynamics.py index 00f7e7d8..13a5a118 100644 --- a/src/ampform_dpd/dynamics.py +++ b/src/ampform_dpd/dynamics.py @@ -144,8 +144,8 @@ def evaluate(self) -> sp.Piecewise: z, L = self.args cases = { 0: 1, - 1: 1 / (1 + z**2), - 2: 1 / (9 + 3 * z**2 + z**4), + 1: 1 / (1 + z**2), # type:ignore[operator] + 2: 1 / (9 + 3 * z**2 + z**4), # type:ignore[operator] } return sp.Piecewise(*[ (sp.sqrt(expr), sp.Eq(L, l_val)) for l_val, expr in cases.items() diff --git a/src/ampform_dpd/io.py b/src/ampform_dpd/io.py index 6b4d1766..5fce8fc9 100644 --- a/src/ampform_dpd/io.py +++ b/src/ampform_dpd/io.py @@ -23,7 +23,7 @@ from collections import abc from importlib.metadata import version from pathlib import Path -from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, overload import cloudpickle import sympy as sp @@ -136,11 +136,11 @@ def _render_jp(particle: Particle) -> str: def as_markdown_table(obj: Sequence) -> str: """Render objects a `str` suitable for generating a table.""" + if isinstance(obj, ThreeBodyDecay): + return _as_decay_markdown_table(obj.chains) item_type = _determine_item_type(obj) if item_type is Particle: return _as_resonance_markdown_table(obj) - if item_type is ThreeBodyDecay: - return _as_decay_markdown_table(obj.chains) if item_type is ThreeBodyDecayChain: return _as_decay_markdown_table(obj) msg = ( @@ -175,7 +175,7 @@ def _as_resonance_markdown_table(items: Sequence[Particle]) -> str: row_items = [ f"`{particle.name}`", f"${particle.latex}$", - Rf"${aslatex(particle, only_jp=True)}$", + Rf"${aslatex(particle, only_jp=True)}$", # type:ignore[call-arg] f"{int(1e3 * particle.mass):,.0f}", f"{int(1e3 * particle.width):,.0f}", ] @@ -195,9 +195,9 @@ def _as_decay_markdown_table(decay_chains: Sequence[ThreeBodyDecayChain]) -> str src = _create_markdown_table_header(column_names) for chain in decay_chains: child1, child2 = map(aslatex, chain.decay_products) - row_items = [ + row_items: list = [ Rf"${chain.resonance.latex} \to {child1} {child2}$", - Rf"${aslatex(chain.resonance, only_jp=True)}$", + Rf"${aslatex(chain.resonance, only_jp=True)}$", # type:ignore[call-arg] f"{int(1e3 * chain.resonance.mass):,.0f}", f"{int(1e3 * chain.resonance.width):,.0f}", ] @@ -236,7 +236,7 @@ def perform_cached_lambdify( ) -> ParametrizedBackendFunction: ... -def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] +def perform_cached_lambdify( # type:ignore[misc] # pyright:ignore[reportInconsistentOverload] expr: sp.Expr, parameters: Mapping[sp.Symbol, ParameterValue] | None = None, backend: str = "jax", @@ -276,7 +276,7 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] cache_directory = Path(cache_directory) cache_directory.mkdir(exist_ok=True, parents=True) if parameters is None: - hash_obj = expr + hash_obj: Any = expr else: hash_obj = ( expr, @@ -288,6 +288,7 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] with open(filename, "rb") as f: return pickle.load(f) _LOGGER.warning(f"Cached function file {filename} not found, lambdifying...") + func: ParametrizedFunction | Function if parameters is None: func = create_function(expr, backend) else: @@ -305,4 +306,4 @@ def _print_Indexed_latex(self, printer, *args): # noqa: N802 indices = ", ".join(map(printer._print, self.indices)) return f"{base}_{{{indices}}}" - sp.Indexed._latex = _print_Indexed_latex + sp.Indexed._latex = _print_Indexed_latex # type:ignore[attr-defined] diff --git a/src/ampform_dpd/spin.py b/src/ampform_dpd/spin.py index f583f0ef..df8f8ad0 100644 --- a/src/ampform_dpd/spin.py +++ b/src/ampform_dpd/spin.py @@ -60,7 +60,7 @@ def create_spin_range(spin: SupportsFloat) -> list[sp.Rational]: >>> create_spin_range(1.5) [-3/2, -1/2, 1/2, 3/2] """ - return create_rational_range(-spin, spin) + return create_rational_range(-float(spin), +float(spin)) def create_rational_range( diff --git a/tests/adapter/test_qrules.py b/tests/adapter/test_qrules.py index b97b7c3e..f2d9b51e 100644 --- a/tests/adapter/test_qrules.py +++ b/tests/adapter/test_qrules.py @@ -73,7 +73,7 @@ def test_to_three_body_decay(reaction: ReactionInfo, min_ls: bool): def _get_couplings(transition: StateTransition) -> tuple[dict, dict]: - return tuple( + return tuple( # type:ignore[return-value] {"L": node.l_magnitude, "S": node.s_magnitude} for node in transition.interactions.values() ) diff --git a/tests/test_angles.py b/tests/test_angles.py index cbdbb708..7d4ffffe 100644 --- a/tests/test_angles.py +++ b/tests/test_angles.py @@ -49,7 +49,7 @@ def test_formulate_zeta_angle_equation_A6(): # noqa: N802 @pytest.mark.parametrize( - ("ζ1", "ζ2", "ζ3"), + ("ζ1_expr", "ζ2_expr", "ζ3_expr"), [ ( formulate_zeta_angle(1, 2, 3)[1], @@ -68,14 +68,14 @@ def test_formulate_zeta_angle_equation_A6(): # noqa: N802 ), ], ) -def test_formulate_zeta_angle_sum_rule(ζ1: sp.Expr, ζ2: sp.Expr, ζ3: sp.Expr): +def test_formulate_zeta_angle_sum(ζ1_expr: sp.Expr, ζ2_expr: sp.Expr, ζ3_expr: sp.Expr): """Test Eq. (A9), https://journals.aps.org/prd/pdf/10.1103/PhysRevD.101.034033#page=11. """ σ3_expr = compute_third_mandelstam(σ1, σ2, m0, m1, m2, m3) masses = {m0: 2.3, m1: 0.94, m2: 0.14, m3: 0.49, σ1: 1.2, σ2: 3.0, σ3: σ3_expr} - ζ1 = float(ζ1.doit().xreplace(masses).xreplace(masses)) - ζ2 = float(ζ2.doit().xreplace(masses).xreplace(masses)) - ζ3 = float(ζ3.doit().xreplace(masses).xreplace(masses)) + ζ1 = float(ζ1_expr.doit().xreplace(masses).xreplace(masses)) + ζ2 = float(ζ2_expr.doit().xreplace(masses).xreplace(masses)) + ζ3 = float(ζ3_expr.doit().xreplace(masses).xreplace(masses)) np.testing.assert_almost_equal(ζ1, ζ2 + ζ3, decimal=14) diff --git a/tests/test_io.py b/tests/test_io.py index c3700498..b5a883ba 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -75,7 +75,7 @@ def test_get_readable_hash( caplog.clear() elif python_hash_seed == "0": if sys.version_info < (3, 11): - expected_hash = expected_hash[0] + expected_hash = expected_hash[0] # type:ignore[assignment] else: expected_hash = expected_hash[1] expected = f"pythonhashseed-0{expected_hash:+d}"