IrregSpacedGrid: extra_param_names for user-side runtime scalars#351
IrregSpacedGrid: extra_param_names for user-side runtime scalars#351hmgaudecker wants to merge 11 commits into
Conversation
… values (#349) The decorator used to read `__annotations__` only for field names and never validate them — every consumer wrote `field: int` even though codes flowed through JAX as `jnp.int32` everywhere downstream. The mismatch was a consistent lie ty couldn't catch, and downstream `jnp.int32(...)` wraps papered over the dtype gap at runtime. This change closes the gap at all three layers in lockstep: * **Annotation gate.** `@categorical` requires every field to be annotated `ScalarInt` (from `lcm.typing`); other annotations raise the new `CategoricalDefinitionError` at decoration time, naming the offending fields and pointing at the import. * **Runtime values.** Class- and instance-level attribute access return 0-d `jnp.int32` scalars. The decorator assigns `field(default=i, init=False)` (Python int placeholders for `dataclass(frozen=True)`'s mutable-default check) then overrides the class attributes with `jnp.int32(i)` via `type.__setattr__` post-decoration; `init=False` keeps instance `__dict__` empty so attribute lookup falls through to the class scalar. * **Validator.** `validate_category_class` now checks `isinstance(value, jax.Array) and value.shape == () and jnp.issubdtype(value.dtype, jnp.integer)`; the consecutiveness check coerces via `int(v)`. Hashability fix: JAX 0-d arrays aren't hashable, so the two pylcm dict-inversion sites (`simulation/simulate.py`, `initial_conditions.py`) and `DiscreteGrid.codes` now coerce to Python `int` at the boundary where hashable keys are needed. New helper `lcm.invert_regime_ids` lives in `lcm.utils.containers` and centralises the inversion pattern; downstream consumers can call it instead of hand-rolling `{int(v): k for k, v in ...}`. `RegimeNamesToIds` tightens to `MappingProxyType[RegimeName, ScalarInt]`. Sweep: every `@categorical`-decorated class in `src/lcm_examples`, `tests/`, `tests/test_models/`, and the user-guide notebooks/docs moves from `field: int` to `field: ScalarInt`. The `Effort` example in `mahler_yum_2024/_model.py` switches to the `make_dataclass + type.__setattr__` shim so its 40 `ScalarInt` defaults bypass `dataclass`'s mutable-default check. `tests/conftest.py` and `tests/test_grids.py` gain a `_make_dc` helper for hand-crafted test dataclasses with `ScalarInt` values. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #350 tightens `@categorical` to require `ScalarInt` field annotations. The benchmarks-cuda12 env still pulls aca-model from the pre-#350 sha (`9ac2043`), which annotates fields as `int` and fails the new decoration-time gate during `setup_cache`. Point the pin at the aca-model #350 cascade branch (`feature/categorical-scalarint`, head `b807b28`) until both PRs land. After aca-model PR#10 and pylcm#350 merge, revert to a `main`-tracking sha. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model's `get_benchmark_params` returns three values (likely a moments tuple alongside the two pre-existing entries); the benchmark file was stuck on the older 2-tuple shape. aca-model's own `test_benchmark.py` runs against the up-to-date in-tree helper, so the drift only surfaces here once pylcm's benchmark CI pulls a fresh aca-model. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`broadcast_to_template` rejects any `fixed_params` key that no DAG function lists in its signature. That assumption breaks for runtime-supplied grids whose points are computed by user-side code from per-iteration scalars (e.g., a grid upper bound) — pylcm itself never reads those scalars, but they still need to flow through the params machinery instead of being funneled through `simulate(params=...)` where they compete with estimation parameters. This adds an `extra_param_names: tuple[str, ...]` keyword to `IrregSpacedGrid`. When `pass_points_at_runtime=True`, each name in the tuple surfaces as a `ScalarFloat` slot in the action/state's pseudo-function template entry alongside `points`. The user's injection code reads the resolved values from `model.fixed_params` (or per-iteration params) before computing the points; pylcm just threads them through `broadcast_to_template` unchanged. Rejected on fixed-points grids — a baked-in grid has no user-side runtime computation to feed. Closes #348. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Benchmark comparison (main → HEAD)Comparing
|
Class-attr access already documents the `Array(0, dtype=int32)` shape; the instance fall-through line repeats the same value without adding information about the decorator contract. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_to_names Mirrors the forward direction `regime_names_to_ids` so the inverse is read symmetrically. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Matches the inverse direction's symmetry with `regime_names_to_ids` and the recent rename in `simulation.initial_conditions`. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- bench: simplify get_benchmark_params unpacking via index access. - categorical: tighten validate_category_class docstring to refer to `ScalarInt`s directly without restating the type. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a runtime-points subsection under the IrregSpacedGrid docs in `docs/user_guide/grids.md`: explains what `extra_param_names` declares, why it's needed (carries scalars through the params template without `Unknown keys` errors), and shows a minimal inject-points helper plus the call site against `solve()`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cascades #350's squash merge (RegimeIdsToNames alias + renames). Conflicts in typing.py/logging.py/initial_conditions.py resolved in favor of main — feat/runtime-grid-extra-params had the older un-squashed history of the same content, so main's squash is the canonical form. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Closing without merging. The motivating use case (aca-model's The user-side fix is to pass that scalar as an explicit Python argument to the injector, sourced from wherever it naturally lives (an aca-data pickle, a config module, whatever). pylcm's Runtime points themselves ( Closing #348 along with this PR. |
Summary
IrregSpacedGrid(extra_param_names=("foo", "bar"))so scalar params consumed by user-side code that builds the runtime points (e.g., an upper bound that changes per optimizer iteration) can flow throughfixed_params/ per-callparamswithout trippingbroadcast_to_template's unknown-keys check.ScalarFloatslot in the grid's pseudo-function template entry alongsidepoints. Pylcm itself never reads these values — they're carried through so user-side injection code can pick them up viamodel.fixed_params[name](or the resolved internal params) before computing the points.Closes #348. Stacks on #350 (uses the
feat/categorical-scalarintbranch as base).Test plan
pixi run -e tests-cpu tests— 961 passed (4 new tests intest_runtime_params.py).pixi run -e type-checking ty— clean.prek run --all-files— clean.🤖 Generated with Claude Code