Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: lint type hints with MyPy #109

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ vscode:
- github.vscode-github-actions
- github.vscode-pull-request-github
- ms-python.python
- ms-python.mypy-type-checker
- ms-python.vscode-pylance
- ms-toolsai.vscode-jupyter-cell-tags
- ms-vscode.live-server
Expand Down
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ ci:
autoupdate_commit_msg: "MAINT: update pip constraints and pre-commit"
autoupdate_schedule: quarterly # already done by requirements-cron.yml
skip:
- mypy
- prettier
- pyright
- taplo
Expand Down Expand Up @@ -108,6 +109,16 @@ repos:
.*\.py
)$

- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy
language: system
require_serial: true
types:
- python

- repo: https://github.com/ComPWA/mirrors-pyright
rev: v1.1.359
hooks:
Expand Down
2 changes: 1 addition & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"github.vscode-github-actions",
"github.vscode-pull-request-github",
"ms-python.python",
"ms-python.mypy-type-checker",
"ms-python.vscode-pylance",
"ms-toolsai.vscode-jupyter-cell-tags",
"ms-vscode.live-server",
Expand All @@ -27,7 +28,6 @@
"ms-python.black-formatter",
"ms-python.flake8",
"ms-python.isort",
"ms-python.mypy-type-checker",
"ms-python.pylint",
"travisillig.vscode-json-stable-stringify",
"tyriar.sort-lines"
Expand Down
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"gitlens.telemetry.enabled": false,
"livePreview.defaultPreviewPath": "docs/_build/html",
"multiDiffEditor.experimental.enabled": true,
"mypy-type-checker.args": ["--config-file=${workspaceFolder}/pyproject.toml"],
"mypy-type-checker.importStrategy": "fromEnvironment",
"notebook.codeActionsOnSave": {
"notebook.source.organizeImports": "explicit"
},
Expand All @@ -54,6 +56,11 @@
"ruff.enable": true,
"ruff.importStrategy": "fromEnvironment",
"ruff.organizeImports": true,
"search.exclude": {
"**/tests/**/__init__.py": true,
".constraints/*.txt": true,
"typings/**": true
},
"telemetry.telemetryLevel": "off",
"yaml.schemas": {
"https://raw.githubusercontent.com/ComPWA/qrules/0.10.x/src/qrules/particle-validation.json": [
Expand Down
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
"show-inheritance": True,
}
autodoc_member_order = "bysource"
autodoc_type_aliases = {}
autodoc_typehints_format = "short"
autosectionlabel_prefix_document = True
autosectionlabel_maxdepth = 2
Expand Down
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ numba = [
]
sty = [
"ampform-dpd[types]",
"mypy",
"pre-commit >=1.4.0",
"ruff",
]
Expand Down Expand Up @@ -134,6 +135,21 @@ where = ["src"]
[tool.setuptools_scm]
write_to = "src/ampform_dpd/version.py"

[tool.mypy]
exclude = "_build"
show_error_codes = true
warn_unused_configs = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["cloudpickle.*"]

[[tool.mypy.overrides]]
check_untyped_defs = true
disallow_incomplete_defs = false
disallow_untyped_defs = false
module = ["tests.*"]

[tool.pyright]
reportArgumentType = false
reportAssignmentType = false
Expand Down
58 changes: 30 additions & 28 deletions src/ampform_dpd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ThreeBodyDecay,
ThreeBodyDecayChain,
get_decay_product_ids,
get_particle,
)
from ampform_dpd.io import (
simplify_latex_rendering, # noqa: F401 # pyright:ignore[reportUnusedImport]
Expand All @@ -35,9 +36,9 @@ class AmplitudeModel:
intensity: sp.Expr = sp.S.One
amplitudes: dict[sp.Indexed, sp.Expr] = field(factory=dict)
variables: dict[sp.Symbol, sp.Expr] = field(factory=dict)
parameter_defaults: dict[sp.Symbol, float] = field(factory=dict)
masses: dict[sp.Symbol, float] = field(factory=dict)
invariants: dict[sp.Symbol, float] = field(factory=dict)
parameter_defaults: dict[sp.Symbol, float | complex] = field(factory=dict)
masses: dict[sp.Symbol, float | complex] = field(factory=dict)
invariants: dict[sp.Symbol, sp.Expr] = field(factory=dict)

@property
def full_expression(self) -> sp.Expr:
Expand Down Expand Up @@ -79,17 +80,20 @@ def formulate(
reference_subsystem: Literal[1, 2, 3] = 1,
cleanup_summations: bool = False,
) -> AmplitudeModel:
helicity_symbols = sp.symbols("lambda:4", rational=True)
helicity_symbols: tuple[sp.Symbol, sp.Symbol, sp.Symbol, sp.Symbol] = (
sp.symbols("lambda:4", rational=True)
)
allowed_helicities = {
symbol: create_spin_range(self.decay.states[i].spin)
symbol: create_spin_range(self.decay.states[i].spin) # type:ignore[index]
for i, symbol in enumerate(helicity_symbols)
}
amplitude_definitions = {}
angle_definitions = {}
parameter_defaults = {}
for args in product(*allowed_helicities.values()):
for sub_system in [1, 2, 3]:
chain_model = self.formulate_subsystem_amplitude(*args, sub_system)
args: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
for args in product(*allowed_helicities.values()): # type:ignore[assignment]
for sub_system in (1, 2, 3):
chain_model = self.formulate_subsystem_amplitude(*args, sub_system) # type:ignore[arg-type]
amplitude_definitions.update(chain_model.amplitudes)
angle_definitions.update(chain_model.variables)
parameter_defaults.update(chain_model.parameter_defaults)
Expand All @@ -100,13 +104,13 @@ def formulate(
masses = create_mass_symbol_mapping(self.decay)
parameter_defaults.update(masses)
if cleanup_summations:
aligned_amp = aligned_amp.cleanup()
aligned_amp = aligned_amp.cleanup() # type:ignore[assignment]
intensity = PoolSum(
sp.Abs(aligned_amp) ** 2,
*allowed_helicities.items(),
)
if cleanup_summations:
intensity = intensity.cleanup()
intensity = intensity.cleanup() # type:ignore[assignment]
return AmplitudeModel(
decay=self.decay,
intensity=PoolSum(
Expand Down Expand Up @@ -165,8 +169,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
interaction=chain.outgoing_ls,
typ="decay",
)
parameter_defaults[h_prod] = 1 + 0j
parameter_defaults[h_dec] = 1
parameter_defaults[h_prod] = 1 + 0j # type:ignore[index]
parameter_defaults[h_dec] = 1 # type:ignore[index]
sub_amp_expr = (
sp.KroneckerDelta(λ[0], λR - λ[k])
* (-1) ** (spin[k] - λ[k])
Expand All @@ -189,9 +193,8 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
* (-1) ** (spin[j] - λ[j])
)
if not self.use_decay_helicity_couplings:
resonance_isobar = chain.decay.child1
sub_amp_expr *= _formulate_clebsch_gordan_factors(
resonance_isobar,
chain.decay_node,
helicities={
self.decay.final_state[i]: λ[i],
self.decay.final_state[j]: λ[j],
Expand Down Expand Up @@ -255,20 +258,23 @@ def formulate_aligned_amplitude(
(_λ2, create_spin_range(j2)),
(_λ3, create_spin_range(j3)),
)
return amp_expr, wigner_generator.angle_definitions
return amp_expr, wigner_generator.angle_definitions # type:ignore[return-value]


def _create_coupling_symbol(
helicity_coupling: bool,
resonance: Str,
helicities: tuple[sp.Basic, sp.Basic],
interaction: LSCoupling,
interaction: LSCoupling | None,
typ: Literal["production", "decay"],
) -> sp.Indexed:
H = _get_coupling_base(helicity_coupling, typ)
if helicity_coupling:
λi, λj = helicities
return H[resonance, λi, λj]
if interaction is None:
msg = "Cannot formulate LS-coupling without LS combinations"
raise ValueError(msg)
return H[resonance, interaction.L, interaction.S]


Expand Down Expand Up @@ -314,15 +320,9 @@ def _formulate_clebsch_gordan_factors(
return sqrt_factor * cg_ll * cg_ss


def get_particle(isobar: IsobarNode | Particle) -> Particle:
if isinstance(isobar, IsobarNode):
return isobar.parent
return isobar


@lru_cache(maxsize=None)
def _generate_amplitude_index_bases() -> dict[Literal[1, 2, 3], sp.IndexedBase]:
return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1))
return dict(enumerate(sp.symbols(R"A^(1:4)", cls=sp.IndexedBase), 1)) # type:ignore[arg-type]


class _AlignmentWignerGenerator:
Expand All @@ -333,8 +333,8 @@ def __init__(self, reference_subsystem: Literal[1, 2, 3] = 1) -> None:
def __call__(
self,
j: sp.Rational,
m: sp.Rational,
m_prime: sp.Rational,
m: sp.Rational | sp.Symbol,
m_prime: sp.Rational | sp.Symbol,
rotated_state: int,
aligned_subsystem: int,
) -> sp.Rational | WignerD:
Expand Down Expand Up @@ -380,16 +380,18 @@ def decay(self) -> ThreeBodyDecay:
class DynamicsBuilder(Protocol):
def __call__(
self, decay_chain: ThreeBodyDecayChain
) -> tuple[sp.Expr, dict[sp.Symbol, float]]: ...
) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]: ...


def formulate_non_resonant(
decay_chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
) -> tuple[sp.Expr, dict[sp.Symbol, float | complex]]:
return sp.Rational(1), {}


def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]:
def create_mass_symbol_mapping(
decay: ThreeBodyDecay,
) -> dict[sp.Symbol, float | complex]:
return {
sp.Symbol(f"m{i}", nonnegative=True): decay.states[i].mass
for i in sorted(decay.states) # ensure that dict keys are sorted by state ID
Expand Down
10 changes: 7 additions & 3 deletions src/ampform_dpd/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from __future__ import annotations

from typing import TYPE_CHECKING, SupportsFloat
from typing import TYPE_CHECKING, Iterable, SupportsFloat

import sympy as sp

if TYPE_CHECKING:
from attrs import Attribute

from ampform_dpd.decay import LSCoupling
from ampform_dpd.decay import LSCoupling, ThreeBodyDecayChain


def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> None:
Expand All @@ -18,7 +18,7 @@ def assert_spin_value(instance, attribute: Attribute, value: sp.Rational) -> Non
raise ValueError(msg)


def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling:
def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling | None:
from ampform_dpd.decay import LSCoupling # noqa: PLC0415

if obj is None:
Expand All @@ -32,5 +32,9 @@ def to_ls(obj: LSCoupling | tuple[int, SupportsFloat] | None) -> LSCoupling:
raise TypeError(msg)


def to_chains(obj: Iterable[ThreeBodyDecayChain]) -> tuple[ThreeBodyDecayChain, ...]:
return tuple(obj)


def to_rational(obj: SupportsFloat) -> sp.Rational:
return sp.Rational(obj)
14 changes: 7 additions & 7 deletions src/ampform_dpd/adapter/qrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def to_three_body_decay(
for i, idx in enumerate(sorted(some_transition.final_states), 1)
}
return ThreeBodyDecay(
states={0: initial_state, **final_states},
states={0: initial_state, **final_states}, # type:ignore[dict-item]
chains=tuple(sorted(to_decay_chain(t) for t in transitions)),
)

Expand Down Expand Up @@ -97,9 +97,9 @@ def _convert_edge(state: Any) -> Particle:
raise NotImplementedError(msg)
return Particle(
name=particle.name,
latex=particle.latex,
latex=particle.name if particle.latex is None else particle.latex,
spin=particle.spin,
parity=particle.parity,
parity=int(particle.parity), # type:ignore[arg-type]
mass=particle.mass,
width=particle.width,
)
Expand Down Expand Up @@ -131,11 +131,11 @@ def filter_min_ls(
min_transitions = []
for group in grouped_transitions.values():
transition, *_ = group
min_transition = FrozenTransition(
min_transition: FrozenTransition[EdgeType, NodeType] = FrozenTransition(
topology=transition.topology,
states=transition.states,
interactions={
i: min(t.interactions[i] for t in group)
i: min(t.interactions[i] for t in group) # type:ignore[type-var]
for i in transition.interactions
},
)
Expand All @@ -146,6 +146,6 @@ def filter_min_ls(
def load_particles() -> qrules.particle.ParticleCollection:
src_dir = Path(__file__).parent.parent
particle_database = qrules.load_default_particles()
additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml")
particle_database.update(additional_definitions)
additional_definitions = qrules.io.load(src_dir / "particle-definitions.yml") # type:ignore[arg-type]
particle_database.update(additional_definitions) # type:ignore[arg-type]
return particle_database
Loading
Loading