Skip to content

IrregSpacedGrid: extra_param_names for user-side runtime scalars#351

Closed
hmgaudecker wants to merge 11 commits into
mainfrom
feat/runtime-grid-extra-params
Closed

IrregSpacedGrid: extra_param_names for user-side runtime scalars#351
hmgaudecker wants to merge 11 commits into
mainfrom
feat/runtime-grid-extra-params

Conversation

@hmgaudecker
Copy link
Copy Markdown
Member

@hmgaudecker hmgaudecker commented May 11, 2026

Summary

  • Adds IrregSpacedGrid(extra_param_names=("foo", "bar")) so scalar params consumed by user-side code that builds the runtime points (e.g., an upper bound that changes per optimizer iteration) can flow through fixed_params / per-call params without tripping broadcast_to_template's unknown-keys check.
  • Each name surfaces as a ScalarFloat slot in the grid's pseudo-function template entry alongside points. Pylcm itself never reads these values — they're carried through so user-side injection code can pick them up via model.fixed_params[name] (or the resolved internal params) before computing the points.
  • Rejected on fixed-points grids — a baked-in grid has no user-side runtime computation to feed.

Closes #348. Stacks on #350 (uses the feat/categorical-scalarint branch as base).

Test plan

  • pixi run -e tests-cpu tests — 961 passed (4 new tests in test_runtime_params.py).
  • pixi run -e type-checking ty — clean.
  • prek run --all-files — clean.
  • Downstream cascade verified in aca-model (216 + workspace 407 green; aca-model PR cascading off PR [pre-commit.ci] pre-commit autoupdate #10 follows).

🤖 Generated with Claude Code

hmgaudecker and others added 4 commits May 11, 2026 14:32
… values (#349)

The decorator used to read `__annotations__` only for field names and
never validate them — every consumer wrote `field: int` even though
codes flowed through JAX as `jnp.int32` everywhere downstream. The
mismatch was a consistent lie ty couldn't catch, and downstream
`jnp.int32(...)` wraps papered over the dtype gap at runtime.

This change closes the gap at all three layers in lockstep:

* **Annotation gate.** `@categorical` requires every field to be
  annotated `ScalarInt` (from `lcm.typing`); other annotations raise
  the new `CategoricalDefinitionError` at decoration time, naming the
  offending fields and pointing at the import.
* **Runtime values.** Class- and instance-level attribute access
  return 0-d `jnp.int32` scalars. The decorator assigns
  `field(default=i, init=False)` (Python int placeholders for
  `dataclass(frozen=True)`'s mutable-default check) then overrides
  the class attributes with `jnp.int32(i)` via `type.__setattr__`
  post-decoration; `init=False` keeps instance `__dict__` empty so
  attribute lookup falls through to the class scalar.
* **Validator.** `validate_category_class` now checks
  `isinstance(value, jax.Array) and value.shape == () and
  jnp.issubdtype(value.dtype, jnp.integer)`; the consecutiveness
  check coerces via `int(v)`.

Hashability fix: JAX 0-d arrays aren't hashable, so the two pylcm
dict-inversion sites (`simulation/simulate.py`, `initial_conditions.py`)
and `DiscreteGrid.codes` now coerce to Python `int` at the boundary
where hashable keys are needed. New helper
`lcm.invert_regime_ids` lives in `lcm.utils.containers` and centralises
the inversion pattern; downstream consumers can call it instead of
hand-rolling `{int(v): k for k, v in ...}`. `RegimeNamesToIds`
tightens to `MappingProxyType[RegimeName, ScalarInt]`.

Sweep: every `@categorical`-decorated class in `src/lcm_examples`,
`tests/`, `tests/test_models/`, and the user-guide notebooks/docs
moves from `field: int` to `field: ScalarInt`. The `Effort` example
in `mahler_yum_2024/_model.py` switches to the `make_dataclass +
type.__setattr__` shim so its 40 `ScalarInt` defaults bypass
`dataclass`'s mutable-default check. `tests/conftest.py` and
`tests/test_grids.py` gain a `_make_dc` helper for hand-crafted
test dataclasses with `ScalarInt` values.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PR #350 tightens `@categorical` to require `ScalarInt` field
annotations. The benchmarks-cuda12 env still pulls aca-model from the
pre-#350 sha (`9ac2043`), which annotates fields as `int` and fails
the new decoration-time gate during `setup_cache`.

Point the pin at the aca-model #350 cascade branch
(`feature/categorical-scalarint`, head `b807b28`) until both PRs land.
After aca-model PR#10 and pylcm#350 merge, revert to a `main`-tracking
sha.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
aca-model's `get_benchmark_params` returns three values (likely a
moments tuple alongside the two pre-existing entries); the benchmark
file was stuck on the older 2-tuple shape. aca-model's own
`test_benchmark.py` runs against the up-to-date in-tree helper, so the
drift only surfaces here once pylcm's benchmark CI pulls a fresh
aca-model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`broadcast_to_template` rejects any `fixed_params` key that no DAG
function lists in its signature. That assumption breaks for
runtime-supplied grids whose points are computed by user-side code
from per-iteration scalars (e.g., a grid upper bound) — pylcm itself
never reads those scalars, but they still need to flow through the
params machinery instead of being funneled through
`simulate(params=...)` where they compete with estimation parameters.

This adds an `extra_param_names: tuple[str, ...]` keyword to
`IrregSpacedGrid`. When `pass_points_at_runtime=True`, each name in
the tuple surfaces as a `ScalarFloat` slot in the action/state's
pseudo-function template entry alongside `points`. The user's
injection code reads the resolved values from `model.fixed_params`
(or per-iteration params) before computing the points; pylcm just
threads them through `broadcast_to_template` unchanged.

Rejected on fixed-points grids — a baked-in grid has no user-side
runtime computation to feed.

Closes #348.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@read-the-docs-community
Copy link
Copy Markdown

read-the-docs-community Bot commented May 11, 2026

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 11, 2026

Benchmark comparison (main → HEAD)

Comparing 3eec679e (main) → a02158f8 (HEAD)

Benchmark Statistic before after Ratio Alert
aca-baseline execution time 28.052 s 27.750 s 0.99
peak GPU mem 579 MB 2.56 GB 4.43
compilation time 295.82 s 292.99 s 0.99
peak CPU mem 7.44 GB 7.63 GB 1.03
Mahler-Yum execution time 4.803 s 4.713 s 0.98
peak GPU mem 529 MB 529 MB 1.00
compilation time 14.32 s 14.51 s 1.01
peak CPU mem 1.68 GB 1.68 GB 1.00
Precautionary Savings - Solve execution time 53.9 ms 51.5 ms 0.95
peak GPU mem 101 MB 101 MB 1.00
compilation time 2.70 s 2.78 s 1.03
peak CPU mem 1.13 GB 1.14 GB 1.00
Precautionary Savings - Simulate execution time 120.6 ms 115.5 ms 0.96
peak GPU mem 344 MB 344 MB 1.00
compilation time 5.00 s 4.83 s 0.97
peak CPU mem 1.31 GB 1.31 GB 1.00
Precautionary Savings - Solve & Simulate execution time 145.1 ms 149.1 ms 1.03
peak GPU mem 578 MB 578 MB 1.00
compilation time 6.99 s 7.19 s 1.03
peak CPU mem 1.28 GB 1.29 GB 1.00
Precautionary Savings - Solve & Simulate (irreg) execution time 292.7 ms 283.1 ms 0.97
peak GPU mem 2.19 GB 2.19 GB 1.00
compilation time 7.35 s 7.23 s 0.98
peak CPU mem 1.34 GB 1.34 GB 1.01

hmgaudecker and others added 6 commits May 11, 2026 19:52
Class-attr access already documents the `Array(0, dtype=int32)` shape;
the instance fall-through line repeats the same value without adding
information about the decorator contract.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_to_names

Mirrors the forward direction `regime_names_to_ids` so the inverse is
read symmetrically. No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Matches the inverse direction's symmetry with `regime_names_to_ids`
and the recent rename in `simulation.initial_conditions`. No behavior
change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- bench: simplify get_benchmark_params unpacking via index access.
- categorical: tighten validate_category_class docstring to refer
  to `ScalarInt`s directly without restating the type.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a runtime-points subsection under the IrregSpacedGrid docs in
`docs/user_guide/grids.md`: explains what `extra_param_names`
declares, why it's needed (carries scalars through the params
template without `Unknown keys` errors), and shows a minimal
inject-points helper plus the call site against `solve()`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Base automatically changed from feat/categorical-scalarint to main May 12, 2026 05:32
Cascades #350's squash merge (RegimeIdsToNames alias + renames).
Conflicts in typing.py/logging.py/initial_conditions.py resolved
in favor of main — feat/runtime-grid-extra-params had the older
un-squashed history of the same content, so main's squash is the
canonical form.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker
Copy link
Copy Markdown
Member Author

Closing without merging.

The motivating use case (aca-model's consumption_dollars grid pinned to max_consumption_dollars) doesn't actually need the scalar to live in pylcm's params machinery — no DAG function consumes it; only the user-side injector that constructs the gridpoints reads it.

The user-side fix is to pass that scalar as an explicit Python argument to the injector, sourced from wherever it naturally lives (an aca-data pickle, a config module, whatever). pylcm's fixed_params then doesn't have to open a per-key hole in broadcast_to_template's unknown-keys check.

Runtime points themselves (IrregSpacedGrid(n_points=N) with params[regime][slot]["points"] injection) already work as of #338 — that's still the load-bearing mechanism. Just don't carry the construction inputs through fixed_params.

Closing #348 along with this PR.

@hmgaudecker hmgaudecker deleted the feat/runtime-grid-extra-params branch May 12, 2026 18:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fixed_params should accept keys required for runtime-supplied grids

1 participant