diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8104668..5c78f7c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,5 +1,9 @@ --- name: main +# aca-model is a git submodule of the aca-dev workspace and has no pixi config +# of its own — the pixi environments live in the parent workspace, whose +# `tests-cpu` env has editable path-dependencies on private sibling repos that a +# standalone CI runner cannot clone. CI therefore installs with pip directly. concurrency: group: ${{ github.head_ref || github.run_id }} cancel-in-progress: true @@ -26,10 +30,10 @@ jobs: - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - - name: Install pylcm (feature branch — revert to @main once pylcm#348/#350 merge) + - name: Install pylcm (pinned to feat/distributed-V-arrays / PR #364 until it merges to main) run: >- pip install "pylcm @ - git+https://github.com/OpenSourceEconomics/pylcm.git@feat/runtime-grid-extra-params" + git+https://github.com/OpenSourceEconomics/pylcm.git@feat/distributed-V-arrays" - name: Install aca-model with test deps run: pip install -e . pytest pdbp - name: Run pytest diff --git a/.gitignore b/.gitignore index e2bab43..70259d4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,59 @@ -__pycache__/ -*.py[cod] +# Claude Code +.claude/ + +# Distribution / packaging +*.egg *.egg-info/ -dist/ +*.manifest +*.spec +.eggs/ +.installed.cfg build/ -bld/ +dist/ +MANIFEST +sdist/ +wheels/ + +# IDE +.idea/ +.vscode/ + +# Jupyter / Jupyter Book +.ipynb_checkpoints/ +_build + +# macOS +.DS_Store + +# pixi .pixi/ +node_modules/ + +# pytask .pytask/ +.pytask.sqlite3 +bld/ +out/ +pytask.lock +pytask.lock.journal + +# Python +__pycache__/ +*.py[cod] +*.so +*$py.class + +# Ruff +.ruff_cache/ + +# Testing +.cache/ .coverage +.coverage.* +.hypothesis/ +.pytest_cache/ +coverage.xml htmlcov/ + +# Version file (generated by hatch-vcs) +src/*/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3188ab..50b0b1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,9 @@ repos: - repo: meta hooks: - - id: check-hooks-apply + # check-hooks-apply is omitted: aca-model ships no notebooks yet, so the + # boilerplate nbstripout hook matches nothing and that meta check would + # fail. Re-add it once the repo gains a notebook. - id: check-useless-excludes - repo: https://github.com/tox-dev/pyproject-fmt rev: v2.21.1 @@ -37,6 +39,7 @@ repos: - id: name-tests-test args: - --pytest-test-first + exclude: ^tests/helpers/ - id: no-commit-to-branch args: - --branch @@ -46,6 +49,10 @@ repos: rev: v1.38.0 hooks: - id: yamllint + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.37.2 + hooks: + - id: check-github-workflows - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.12 hooks: @@ -61,6 +68,13 @@ repos: - jupyter - pyi - python + - repo: https://github.com/kynan/nbstripout + rev: 0.9.1 + hooks: + - id: nbstripout + args: + - --extra-keys + - metadata.kernelspec metadata.language_info.version metadata.vscode - repo: https://github.com/executablebooks/mdformat rev: 1.0.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 7bfbd0b..9df7876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,9 @@ [build-system] build-backend = "hatchling.build" -requires = [ "hatchling" ] +requires = [ "hatch-vcs", "hatchling" ] [project] name = "aca-model" -version = "0.0.0" description = "Core lifecycle model for the ACA structural retirement project." readme = { file = "README.md", content-type = "text/markdown" } keywords = [ @@ -23,11 +22,12 @@ classifiers = [ "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.14", ] +dynamic = [ "version" ] dependencies = [ "attrs", + "beartype", "cloudpickle", "dags", - "estimagic", "jax>=0.9", "jaxtyping", "numpy>=2.2", @@ -43,13 +43,19 @@ email = "hmgaudecker@uni-bonn.de" [[project.maintainers]] name = "Hans-Martin von Gaudecker" email = "hmgaudecker@uni-bonn.de" +[project.urls] +Github = "https://github.com/OpenSourceEconomics/aca-model" +Repository = "https://github.com/OpenSourceEconomics/aca-model" +Tracker = "https://github.com/OpenSourceEconomics/aca-model/issues" [tool.hatch] +build.hooks.vcs.version-file = "src/aca_model/_version.py" build.targets.sdist.exclude = [ "tests" ] build.targets.sdist.only-packages = true build.targets.wheel.only-include = [ "src" ] build.targets.wheel.sources = [ "src" ] metadata.allow-direct-references = true +version.source = "vcs" [tool.ruff] fix = true @@ -84,9 +90,21 @@ extend-ignore = [ "RUF002", # Ambiguous Unicode in docstrings (Greek letters in math) "RUF003", # Ambiguous Unicode in comments (Greek letters in math) ] -per-file-ignores."src/aca_model/models/*" = [ "E501" ] -per-file-ignores."task_*.py" = [ "ANN", "ARG001" ] -per-file-ignores."tests/*" = [ "D", "E501", "INP001", "PD011", "PLR2004", "S101" ] +per-file-ignores."src/aca_model/models/*" = [ + "E501", # Line too long (generated model files) +] +per-file-ignores."task_*.py" = [ + "ANN", # Type annotations (use ty instead) + "ARG001", # Unused function argument (pytask signatures) +] +per-file-ignores."tests/*" = [ + "D", # Docstrings + "E501", # Line too long + "INP001", # Implicit namespace package + "PD011", # Use of .values (false positives on non-pandas objects) + "PLR2004", # Magic value used in comparison + "S101", # Use of assert +] pydocstyle.convention = "google" [tool.pyproject-fmt] diff --git a/src/aca_model/__init__.py b/src/aca_model/__init__.py index ac63b78..1fc43f7 100644 --- a/src/aca_model/__init__.py +++ b/src/aca_model/__init__.py @@ -1,3 +1,22 @@ import jax jax.config.update("jax_enable_x64", True) + +# Import lcm before installing the claw so its `_jaxtyping_patch` (picklable +# jaxtyping sentinel) and `MappingProxyType` pytree registration are in place. +import lcm # noqa: E402, F401 + +# Install beartype's AST-rewriting claw on the whole `aca_model` package before +# any submodule is imported. The claw transforms each module's AST at first +# import to insert runtime type checks against its annotations; aca_model's +# numerical DAG/transition/utility functions are otherwise unchecked, since +# pylcm's own claw is scoped to `lcm.*`. Violations surface as beartype's +# `BeartypeCallHintViolation` — aca_model is an application, not a library with +# a documented exception contract. +from beartype import BeartypeConf, BeartypeStrategy # noqa: E402 +from beartype.claw import beartype_package # noqa: E402 + +beartype_package( + "aca_model", + conf=BeartypeConf(strategy=BeartypeStrategy.On, is_pep484_tower=True), +) diff --git a/src/aca_model/_benchmark_data/benchmark_params.pkl b/src/aca_model/_benchmark_data/benchmark_params.pkl index f7c505e..24c5138 100644 Binary files a/src/aca_model/_benchmark_data/benchmark_params.pkl and b/src/aca_model/_benchmark_data/benchmark_params.pkl differ diff --git a/src/aca_model/_version.py b/src/aca_model/_version.py deleted file mode 100644 index 6c8e6b9..0000000 --- a/src/aca_model/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.0" diff --git a/src/aca_model/aca/model.py b/src/aca_model/aca/model.py index 6d39ac6..22437dd 100644 --- a/src/aca_model/aca/model.py +++ b/src/aca_model/aca/model.py @@ -25,6 +25,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, + subjects_batch_size: int = 0, ) -> Model: """Create an ACA policy variant model. @@ -45,6 +46,10 @@ def create_model( `pref_type`. grid_config: Continuous-grid point counts. pref_type_grid: Pref-type `DiscreteGrid`. + subjects_batch_size: Per-device chunk size for the simulate-side + per-subject dispatch. `0` (default) keeps a single vmap over + all subjects; `>0` chunks each device's local shard via + `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: pylcm Model. @@ -71,4 +76,5 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, + subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/agent/preferences.py b/src/aca_model/agent/preferences.py index 612896b..97fff2c 100644 --- a/src/aca_model/agent/preferences.py +++ b/src/aca_model/agent/preferences.py @@ -140,9 +140,10 @@ def u_alive( coefficient_rra: FloatND, utility_scale_factor: FloatND, ) -> FloatND: - """Within-period utility for every non-dead regime: CES over consumption and leisure. + """Within-period utility for every non-dead regime. - `leisure` is a DAG input — supplied per-regime by `leisure_canwork_retiree_or_nongroup`, + CES over consumption and leisure. `leisure` is a DAG input — supplied + per-regime by `leisure_canwork_retiree_or_nongroup`, `leisure_canwork_tied`, or `leisure_forcedout`. """ composite = consumption_equiv**consumption_weight * leisure ** ( diff --git a/src/aca_model/baseline/model.py b/src/aca_model/baseline/model.py index 98416ce..85fd8d9 100644 --- a/src/aca_model/baseline/model.py +++ b/src/aca_model/baseline/model.py @@ -28,6 +28,7 @@ def create_model( derived_categoricals: Mapping[str, DiscreteGrid], grid_config: GridConfig, pref_type_grid: DiscreteGrid, + subjects_batch_size: int = 0, ) -> Model: """Create the baseline structural retirement model. @@ -52,6 +53,10 @@ def create_model( pref_type_grid: Pref-type `DiscreteGrid`. Pass `DiscreteGrid(PrefType)` for the production 3-type layout, or a compact variant (e.g. `DiscreteGrid(BenchmarkPrefType)`). + subjects_batch_size: Per-device chunk size for the simulate-side + per-subject dispatch. `0` (default) keeps a single vmap over + all subjects; `>0` chunks each device's local shard via + `jax.lax.map`. Tune via `grid_config.get_subjects_batch_size(log_level)`. Returns: A pylcm Model with 19 regimes (18 non-terminal + dead) spanning @@ -78,4 +83,5 @@ def create_model( fixed_params=fixed_params, derived_categoricals=derived_categoricals, n_subjects=n_subjects, + subjects_batch_size=subjects_batch_size, ) diff --git a/src/aca_model/baseline/regimes/_common.py b/src/aca_model/baseline/regimes/_common.py index b4d1e26..b783ecb 100644 --- a/src/aca_model/baseline/regimes/_common.py +++ b/src/aca_model/baseline/regimes/_common.py @@ -10,20 +10,21 @@ from typing import Any, Literal, TypedDict import jax.numpy as jnp -import lcm.shocks.ar1 -import lcm.shocks.iid import numpy as np +from _lcm.grids.continuous import ContinuousGrid from lcm import ( DiscreteGrid, IrregSpacedGrid, LinSpacedGrid, MarkovTransition, + NormalIIDProcess, + PiecewiseGridSegment, + PiecewiseLinSpacedGrid, Regime, + RouwenhorstAR1Process, categorical, ) -from lcm.grids.continuous import ContinuousGrid -from lcm.grids.piecewise import Piece, PiecewiseLinSpacedGrid -from lcm.typing import BoolND, FloatND, RegimeName, ScalarInt, UserParams +from lcm.typing import BoolND, FloatND, IntND, RegimeName, ScalarInt, UserParams from aca_model.agent import ( assets_and_income, @@ -34,7 +35,7 @@ from aca_model.agent.health import Health, HealthWithDisability from aca_model.agent.labor_market import LaborSupply, LaggedLaborSupply, SpousalIncome from aca_model.baseline import health_insurance -from aca_model.baseline.health_insurance import BuyPrivate, HealthInsuranceState +from aca_model.baseline.health_insurance import BuyPrivate from aca_model.config import MODEL_CONFIG, GridConfig from aca_model.environment import social_security, taxes from aca_model.environment.social_security import ClaimedSS @@ -237,14 +238,14 @@ def build_grids( # grid to have unconditional variance 1, the Rouwenhorst innovation # std must be √(1 − ρ²). Passing the σ_y itself (≈0.577 for hcc, # 0.5627 for wage) would mis-scale the grid. - wage_res = lcm.shocks.ar1.Rouwenhorst( + wage_res = RouwenhorstAR1Process( n_points=grid_config.n_wage_res_gridpoints, rho=_WAGE_RHO, sigma=(1.0 - _WAGE_RHO**2) ** 0.5, mu=0.0, ) hcc_persistent = get_hcc_persistent_shock(grid_config=grid_config) - hcc_transitory = lcm.shocks.iid.Normal( + hcc_transitory = NormalIIDProcess( n_points=grid_config.n_hcc_transitory_gridpoints, gauss_hermite=True, mu=0.0, @@ -261,11 +262,11 @@ def build_grids( stop=500_000.0, n_points=grid_config.n_assets_gridpoints, batch_size=grid_config.n_assets_batch_size, + distributed=True, ), aime=_build_aime_grid(grid_config=grid_config, fixed_params=fixed_params), consumption_dollars=IrregSpacedGrid( n_points=grid_config.n_consumption_dollars_gridpoints, - extra_param_names=("max_consumption_dollars",), ), wage_res=wage_res, hcc_persistent=hcc_persistent, @@ -274,7 +275,7 @@ def build_grids( ) -def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwenhorst: +def get_hcc_persistent_shock(*, grid_config: GridConfig) -> RouwenhorstAR1Process: """Return the persistent-HCC AR(1) shock grid for a given `grid_config`. Exposed so callers that need the shock's gridpoints / transition @@ -282,7 +283,7 @@ def get_hcc_persistent_shock(*, grid_config: GridConfig) -> lcm.shocks.ar1.Rouwe can derive them from `grid_config` alone without instantiating a full `Model`. """ - return lcm.shocks.ar1.Rouwenhorst( + return RouwenhorstAR1Process( n_points=grid_config.n_hcc_persistent_gridpoints, rho=_HCC_RHO, sigma=(1.0 - _HCC_RHO**2) ** 0.5, @@ -306,20 +307,26 @@ def _build_aime_grid( this path; the total is fixed by the PIA structure (32 points). """ kinks = [float(k) for k in np.asarray(fixed_params["pia_aime_grid"])] - pieces = ( - Piece(interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0]), - Piece(interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1]), - Piece(interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2]), + segments = ( + PiecewiseGridSegment( + interval=f"[{kinks[0]}, {kinks[1]})", n_points=_AIME_PIECE_N_POINTS[0] + ), + PiecewiseGridSegment( + interval=f"[{kinks[1]}, {kinks[2]})", n_points=_AIME_PIECE_N_POINTS[1] + ), + PiecewiseGridSegment( + interval=f"[{kinks[2]}, {kinks[3]}]", n_points=_AIME_PIECE_N_POINTS[2] + ), ) return PiecewiseLinSpacedGrid( - pieces=pieces, batch_size=grid_config.n_aime_batch_size + segments=segments, batch_size=grid_config.n_aime_batch_size ) def _compute_max_annual_labor_income( *, wage_params: Mapping[str, Any], - wage_res_grid: lcm.shocks.ar1.Rouwenhorst, + wage_res_grid: RouwenhorstAR1Process, ) -> float: """Return the annual labor income at the top of the wage grid. @@ -418,7 +425,7 @@ def build_actions(spec: RegimeSpec, grids: Grids) -> dict: return actions -def build_regime_probs(target: FloatND, survival: FloatND) -> FloatND: +def build_regime_probs(target: IntND, survival: FloatND) -> FloatND: """Build regime transition probability vector.""" probs = jnp.zeros(19) probs = probs.at[RegimeId.dead].set(1.0 - survival) @@ -603,10 +610,10 @@ def make_targets(name: str) -> tuple[dict[str, int], dict[str, int]]: def select_target_for_age( - next_age: int | FloatND, + next_age: int | IntND | FloatND, mc_next: bool | BoolND, tgts: dict[str, int], -) -> FloatND: +) -> IntND: """Select target regime ID based on next-period age bracket.""" ss_choose = jnp.where( jnp.array(mc_next), diff --git a/src/aca_model/benchmark.py b/src/aca_model/benchmark.py index 5b519d5..8327cd8 100644 --- a/src/aca_model/benchmark.py +++ b/src/aca_model/benchmark.py @@ -12,7 +12,7 @@ Parameters (`fixed_params` + `params`) are a committed snapshot at `src/aca_model/_benchmark_data/benchmark_params.pkl`, generated by `scripts/regen_benchmark_params.py` against the current aca-data + -aca-estimation + aca-model code. Pref-type-indexed Series in `params` +aca-slurm + aca-model code. Pref-type-indexed Series in `params` are pre-truncated to two rows so the snapshot loads with no further reshaping; regenerate after any change that affects `fixed_params` shape (regime DAGs, aca-data outputs, key renames). @@ -100,19 +100,29 @@ def get_benchmark_params( ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: """Load the frozen `(fixed_params, wage_params, params)` snapshot. + `max_consumption_dollars` is popped out of `fixed_params` before + return — it's a grid-construction input read by + `inject_consumption_dollars_points`, not by any DAG function, so + leaving it in `fixed_params` would trip pylcm's unknown-keys check. + When `model` is provided, consumption_dollars gridpoints are injected into `params` for each regime that declares `consumption_dollars` as - an `IrregSpacedGrid` with runtime-supplied points. The lower bound is - read from `params["consumption_dollars_floor"]`. Pass `model=None` to - skip injection (e.g. when constructing the model with `fixed_params`). + an `IrregSpacedGrid` with runtime-supplied points. Pass `model=None` + to skip injection (e.g. when constructing the model with + `fixed_params`). """ with _PARAMS_FILE.open("rb") as fh: data = cloudpickle.load(fh) fixed_params = data["fixed_params"] wage_params = data["wage_params"] params = data["params"] + max_consumption_dollars = float(fixed_params.pop("max_consumption_dollars")) if model is not None: - params = inject_consumption_dollars_points(params=params, model=model) + params = inject_consumption_dollars_points( + params=params, + model=model, + max_consumption_dollars=max_consumption_dollars, + ) return fixed_params, wage_params, params @@ -133,7 +143,7 @@ def get_benchmark_initial_conditions( # Grid ranges come from any of the five regimes (shared structure). # Use to_jax() so the helper handles both LinSpacedGrid and # PiecewiseLinSpacedGrid (the latter has no `.start` / `.stop`). - ref_regime = model.regimes[_INITIAL_REGIMES[0]] + ref_regime = model.user_regimes[_INITIAL_REGIMES[0]] grids = ref_regime.states assets_pts = np.asarray(grids["assets"].to_jax()) aime_pts = np.asarray(grids["aime"].to_jax()) @@ -151,7 +161,7 @@ def get_benchmark_initial_conditions( ) return { - "regime": jnp.asarray(regime), + "regime_id": jnp.asarray(regime), "age": jnp.full(n_subjects, 51.0), "assets": jnp.asarray(rng.uniform(assets_lo, assets_hi, n_subjects)), "aime": jnp.asarray(rng.uniform(aime_lo, aime_hi, n_subjects)), diff --git a/src/aca_model/config.py b/src/aca_model/config.py index 101ef2d..e240841 100644 --- a/src/aca_model/config.py +++ b/src/aca_model/config.py @@ -1,7 +1,8 @@ """Configuration for the aca_model package.""" -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path +from types import MappingProxyType import plotly.io as pio from pytask import DataCatalog @@ -34,11 +35,36 @@ class GridConfig: n_hcc_persistent_gridpoints: int = 3 n_hcc_transitory_gridpoints: int = 5 # `batch_size` on the assets / AIME grids: chunked vmap stride for the - # outer state loop. Both partition the per-period Q intermediate so it - # fits in V100 16 GB once we splay across `pref_type`. Set to 0 in - # BENCHMARK_GRID_CONFIG to skip the Python-loop overhead. - n_assets_batch_size: int = 1 + # outer state loop. The assets axis is hardcoded `distributed=True` + # in regimes, so `n_assets_batch_size` must stay `0` — `>0 + distributed` + # is rejected by pylcm's grid-init guard. `n_aime_batch_size` is free + # to splay; `1` shrinks the per-period Q intermediate by 12x on hosts + # where the unsplayed kernel doesn't fit. + n_assets_batch_size: int = 0 n_aime_batch_size: int = 1 + # `batch_size` on the `pref_type` discrete grid: chunked vmap stride + # for the pref-type axis during solve. `1` (one pref-type per Python + # dispatch) shrinks the per-period Q intermediate by `n_pref_types` + # at the cost of an outer Python loop; `0` lets a single kernel span + # all pref-types. Defaults to `0` — the production overrides set it + # to `1` on hardware where the unsplayed kernel doesn't fit. + n_pref_type_batch_size: int = 0 + # Per-device chunk size for the simulate-side per-subject dispatch, + # keyed by `log_level`. Empty → 0 (no chunking) for every level. + # `log_level="off"` skips `validate_V` and its forced host-sync, which + # lets XLA pipeline across periods and reuse scratch — affordable + # chunk size grows. Use `get_subjects_batch_size(log_level)`. + subjects_batch_size_by_log_level: MappingProxyType[str, int] = field( + default_factory=lambda: MappingProxyType({}) + ) + + def get_subjects_batch_size(self, log_level: str) -> int: + """Return the per-device simulate chunk size for `log_level`. + + Returns 0 (no chunking) when this `GridConfig` defines no entry for + the given log level. + """ + return self.subjects_batch_size_by_log_level.get(log_level, 0) MODEL_CONFIG = ModelConfig() diff --git a/src/aca_model/consumption_dollars_grid.py b/src/aca_model/consumption_dollars_grid.py index d99f8a9..362498c 100644 --- a/src/aca_model/consumption_dollars_grid.py +++ b/src/aca_model/consumption_dollars_grid.py @@ -1,14 +1,13 @@ """Runtime-supplied gridpoints for the consumption_dollars action. -Consumption is declared as `IrregSpacedGrid(n_points=N, -extra_param_names=("max_consumption_dollars",))` in +Consumption is declared as `IrregSpacedGrid(n_points=N)` in `baseline.regimes._common.build_grids` so the bounds can track runtime parameters: the lower bound from the per-iteration `consumption_equiv_floor` parameter (and its couples-scaled twin), -the upper bound from `max_consumption_dollars` carried through -`fixed_params` (per pylcm#348). Callers must inject the actual -gridpoints into `params` via `inject_consumption_dollars_points` -before calling `model.solve()` / `model.simulate()`. +the upper bound from `max_consumption_dollars` supplied directly +by the caller. Callers must inject the actual gridpoints into +`params` via `inject_consumption_dollars_points` before calling +`model.solve()` / `model.simulate()`. The grid pins the two regime-relevant transfer-floor levels exactly on the action grid so the borrowing constraint's @@ -32,6 +31,7 @@ def inject_consumption_dollars_points( *, params: Mapping[str, Any], model: Model, + max_consumption_dollars: float, ) -> dict[str, Any]: """Inject consumption_dollars gridpoints into per-regime params. @@ -40,7 +40,7 @@ def inject_consumption_dollars_points( The lower two gridpoints are the single and married Dollar-valued transfer floors; the rest are geomspaced from the married floor up - to `model.fixed_params["max_consumption_dollars"]`. + to `max_consumption_dollars`. Args: params: Existing params mapping with `consumption_equiv_floor` @@ -48,8 +48,11 @@ def inject_consumption_dollars_points( new dict; the input is not mutated. model: Model whose regimes carry the runtime-points grid and whose `fixed_params` supplies `exponent` (married - equivalence-scale exponent) and `max_consumption_dollars` - (grid upper bound). + equivalence-scale exponent). + max_consumption_dollars: Grid upper bound. Sourced from the + caller (e.g. aca-data's `environment_constants.pkl`); not + routed through pylcm's params machinery because no DAG + function consumes it. Returns: New params dict with consumption_dollars points injected. @@ -61,11 +64,9 @@ def inject_consumption_dollars_points( """ consumption_equiv_floor = jnp.asarray(params["consumption_equiv_floor"]) exponent = jnp.asarray(model.fixed_params["exponent"]) - max_consumption_dollars = jnp.asarray( - model.fixed_params["max_consumption_dollars"] - ) + max_consumption_dollars_arr = jnp.asarray(max_consumption_dollars) out: dict[str, Any] = dict(params) - for regime_name, regime in model.regimes.items(): + for regime_name, regime in model.user_regimes.items(): if regime.terminal: continue grid = regime.actions.get("consumption_dollars") @@ -88,7 +89,7 @@ def inject_consumption_dollars_points( points = _compute_consumption_dollars_points( consumption_equiv_floor=consumption_equiv_floor, exponent=exponent, - max_consumption_dollars=max_consumption_dollars, + max_consumption_dollars=max_consumption_dollars_arr, n_points=grid.n_points, ) regime_entry = dict(out.get(regime_name, {})) diff --git a/src/aca_model/environment/social_security.py b/src/aca_model/environment/social_security.py index 8b655d1..837ff5d 100644 --- a/src/aca_model/environment/social_security.py +++ b/src/aca_model/environment/social_security.py @@ -11,10 +11,12 @@ from lcm import categorical from lcm.typing import ( Age, + BoolND, ContinuousState, DiscreteAction, DiscreteState, FloatND, + IntND, Period, ScalarFloat, ScalarInt, @@ -211,8 +213,8 @@ def _apply_benefit_rules( pia: FloatND, age: Age, period: Period, - ss: FloatND, - work: FloatND, + ss: IntND, + work: BoolND, labor_income: FloatND, early_ret_adjustment: FloatND, normal_retirement_age: ScalarInt, diff --git a/src/aca_model/environment/taxes.py b/src/aca_model/environment/taxes.py index 8ed25e8..6d9e1d3 100644 --- a/src/aca_model/environment/taxes.py +++ b/src/aca_model/environment/taxes.py @@ -8,7 +8,7 @@ import jax.numpy as jnp from lcm.params import MappingLeaf -from lcm.typing import DiscreteState, FloatND +from lcm.typing import DiscreteState, FloatND, IntND def gross_income( @@ -136,7 +136,7 @@ def marginal_rate( return sched["marginal_rates"][spousal_income, bracket_id] -def _find_bracket(income: FloatND, upper_bounds: FloatND) -> FloatND: +def _find_bracket(income: FloatND, upper_bounds: FloatND) -> IntND: """Find the tax bracket index for a given income level.""" return jnp.searchsorted(upper_bounds, income, side="right") diff --git a/tests/helpers/model.py b/tests/helpers/model.py index be778b4..55571e9 100644 --- a/tests/helpers/model.py +++ b/tests/helpers/model.py @@ -2,7 +2,7 @@ Used by tests that need a structurally faithful model without spelling out fixed_params, wage_params, and a pref-type grid at every call site. -Production callers (aca-estimation, scripts) assemble these explicitly. +Production callers (aca-slurm, scripts) assemble these explicitly. """ from lcm import DiscreteGrid, Model diff --git a/tests/test_aca_policies.py b/tests/test_aca_policies.py index 8ecc024..7ff22ae 100644 --- a/tests/test_aca_policies.py +++ b/tests/test_aca_policies.py @@ -63,7 +63,7 @@ def test_mandate_penalty_uninsured_above_exempt() -> None: income = jnp.array(40000.0) # 40000 * 0.025 = 1000, within [695, 2085] result = aca_hi.mandate_penalty( gross_income=income, - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -74,7 +74,7 @@ def test_mandate_penalty_insured_zero() -> None: """buy_private=yes produces no penalty.""" result = aca_hi.mandate_penalty( gross_income=jnp.array(40000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), mandate_schedule=MANDATE_SCHEDULE, ) @@ -85,7 +85,7 @@ def test_mandate_penalty_below_exempt_zero() -> None: """Income below exemption produces no penalty.""" result = aca_hi.mandate_penalty( gross_income=jnp.array(5000.0), # below 10350 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -97,7 +97,7 @@ def test_mandate_penalty_clips_to_min() -> None: # 12000 * 0.025 = 300, below min of 695 result = aca_hi.mandate_penalty( gross_income=jnp.array(12000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -109,7 +109,7 @@ def test_mandate_penalty_clips_to_max() -> None: # 200000 * 0.025 = 5000, above max of 2085 result = aca_hi.mandate_penalty( gross_income=jnp.array(200000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), mandate_schedule=MANDATE_SCHEDULE, ) @@ -121,7 +121,7 @@ def test_hic_premium_subsidy_below_fpl_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(10000.0), # below FPL_SINGLE - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -133,7 +133,7 @@ def test_hic_premium_subsidy_above_400_fpl_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(50000.0), # above 4 * FPL_SINGLE = 47080 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -148,7 +148,7 @@ def test_hic_premium_subsidy_at_200_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(premium), gross_income=jnp.array(income), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -160,7 +160,7 @@ def test_hic_premium_subsidy_uninsured_zero() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(2.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -183,7 +183,7 @@ def test_cost_sharing_scale_brackets( """Verify each cost-sharing bracket produces the correct factor.""" result = aca_hi.cost_sharing( gross_income=jnp.array(income_fpl_frac * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), cost_sharing_schedule=COST_SHARING_SCHEDULE, ) @@ -194,7 +194,7 @@ def test_cost_sharing_scale_uninsured_one() -> None: """buy_private=no produces scale=1.0 (no reduction).""" result = aca_hi.cost_sharing( gross_income=jnp.array(1.2 * FPL_SINGLE), # would be 0.1721 if insured - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.no), cost_sharing_schedule=COST_SHARING_SCHEDULE, ) @@ -205,7 +205,7 @@ def test_medicaid_eligible_aca_below_threshold() -> None: """Income below 133% FPL produces eligible.""" result = aca_hi.is_medicaid_eligible( countable_income=jnp.array(10000.0), # below 15580 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), medicaid_schedule=MEDICAID_SCHEDULE, ) assert result @@ -215,7 +215,7 @@ def test_medicaid_eligible_aca_above_threshold() -> None: """Income above 133% FPL produces not eligible.""" result = aca_hi.is_medicaid_eligible( countable_income=jnp.array(20000.0), # above 15580 - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), medicaid_schedule=MEDICAID_SCHEDULE, ) assert not result @@ -235,7 +235,7 @@ def test_premium_subsidy_exactly_at_100_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(1.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -247,7 +247,7 @@ def test_premium_subsidy_exactly_at_400_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(4.0 * FPL_SINGLE), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) @@ -259,7 +259,7 @@ def test_premium_subsidy_just_below_400_fpl() -> None: result = aca_hi.premium_subsidy( hic_premium=jnp.array(5000.0), gross_income=jnp.array(4.0 * FPL_SINGLE - 1.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), buy_private=jnp.array(BuyPrivate.yes), premium_credit_schedule=PREMIUM_CREDIT_SCHEDULE, ) diff --git a/tests/test_beartype_claw.py b/tests/test_beartype_claw.py new file mode 100644 index 0000000..2fb7873 --- /dev/null +++ b/tests/test_beartype_claw.py @@ -0,0 +1,25 @@ +"""The beartype claw is live on the `aca_model` package. + +Registering `beartype_package("aca_model", ...)` in `aca_model/__init__.py` +instruments every `aca_model` module at import time, so a type violation in +any aca_model function — including the numerical DAG leaf functions fed into +pylcm — is caught at the call boundary rather than slipping through against +a dishonest annotation. + +The test calls a real model-builder with one argument of the wrong type; the +`BeartypeCallHintViolation` is what proves the claw is installed. +""" + +import pytest +from beartype.roar import BeartypeCallHintViolation +from helpers.model import make_baseline_model + + +def test_claw_checks_aca_model() -> None: + """An ill-typed argument to an `aca_model` function is rejected by beartype. + + `create_model` annotates `n_subjects` as `int`; passing a string is caught + by the claw before the value reaches pylcm's own `Model` perimeter. + """ + with pytest.raises(BeartypeCallHintViolation): + make_baseline_model(n_subjects="not an int") # ty: ignore[invalid-argument-type] diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index b1be815..5d7ddbd 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -29,7 +29,6 @@ def test_benchmark_model_simulates_end_to_end() -> None: initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) df = result.to_dataframe() @@ -69,13 +68,12 @@ def test_benchmark_simulate_obeys_borrowing_constraint() -> None: initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) df = result.to_dataframe(additional_targets=["cash_on_hand", "equivalence_scale"]) - alive = df.loc[df["regime"] != "dead"].copy() - consumption_dollars_floor = float(params["consumption_dollars_floor"]) - floor = consumption_dollars_floor * alive["equivalence_scale"].to_numpy() + alive = df.loc[df["regime_name"] != "dead"].copy() + consumption_equiv_floor = float(params["consumption_equiv_floor"]) + floor = consumption_equiv_floor * alive["equivalence_scale"].to_numpy() rhs = np.maximum(alive["cash_on_hand"].to_numpy(), floor) slack = rhs - alive["consumption_dollars"].to_numpy() assert (slack >= 0).all(), ( diff --git a/tests/test_budget_chain_integration.py b/tests/test_budget_chain_integration.py index f087d16..03e5fbe 100644 --- a/tests/test_budget_chain_integration.py +++ b/tests/test_budget_chain_integration.py @@ -53,12 +53,12 @@ def test_working_agent_cash_on_hand() -> None: result = combined( assets=jnp.array(50000.0), - rate_of_return=0.03, + rate_of_return=jnp.asarray(0.03), labor_income=jnp.array(40000.0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ss_tax_schedule=SS_TAX_SCHEDULE, @@ -87,12 +87,12 @@ def test_retired_agent_with_pension() -> None: result = combined( assets=jnp.array(200000.0), - rate_of_return=0.03, + rate_of_return=jnp.asarray(0.03), labor_income=jnp.array(0.0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(15000.0), pension_benefit=jnp.array(10000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ss_tax_schedule=SS_TAX_SCHEDULE, diff --git a/tests/test_health_insurance.py b/tests/test_health_insurance.py index 06a23b7..56dd7af 100644 --- a/tests/test_health_insurance.py +++ b/tests/test_health_insurance.py @@ -17,7 +17,7 @@ def test_ssi_eligible_assets_too_high() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(5000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -29,7 +29,7 @@ def test_ssi_eligible_income_too_high() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(9000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -41,7 +41,7 @@ def test_ssi_eligible_no_medicare() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -53,7 +53,7 @@ def test_ssi_eligible_all_pass() -> None: result = health_insurance.is_ssi_eligible( assets=jnp.array(1000.0), countable_income=jnp.array(1000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, @@ -64,7 +64,7 @@ def test_ssi_eligible_all_pass() -> None: def test_ssi_benefit_eligible() -> None: result = health_insurance.ssi_benefit( countable_income=jnp.array(3000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), is_ssi_eligible=jnp.array(True), ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -74,7 +74,7 @@ def test_ssi_benefit_eligible() -> None: def test_ssi_benefit_not_eligible() -> None: result = health_insurance.ssi_benefit( countable_income=jnp.array(3000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), is_ssi_eligible=jnp.array(False), ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -149,8 +149,8 @@ def test_compute_table_uniform_transition(table_inputs: dict) -> None: _PREMIUM_KWARGS: dict = { "age": jnp.int32(60), - "good_health": jnp.array(True), - "is_married": jnp.array(False), + "good_health": jnp.int32(1), + "is_married": jnp.int32(0), "labor_supply": jnp.array(LaborSupply.h2000), "premium_intercept": jnp.asarray(1000.0), "premium_age": jnp.asarray(0.0), diff --git a/tests/test_initial_conditions_extreme_assets.py b/tests/test_initial_conditions_extreme_assets.py index 3b16522..17326fd 100644 --- a/tests/test_initial_conditions_extreme_assets.py +++ b/tests/test_initial_conditions_extreme_assets.py @@ -8,8 +8,8 @@ """ import jax.numpy as jnp +from _lcm.simulation.initial_conditions import validate_initial_conditions from lcm import DiscreteGrid -from lcm.simulation.initial_conditions import validate_initial_conditions from aca_model.agent.assets_and_income import borrowing_constraint from aca_model.agent.preferences import BenchmarkPrefType @@ -112,17 +112,17 @@ def test_extreme_negative_assets_subject_passes_validation() -> None: initial_conditions = { **initial_conditions, "assets": jnp.asarray([-1_000_000.0]), - "regime": jnp.asarray( + "regime_id": jnp.asarray( [model.regime_names_to_ids["retiree_nomc_inelig_canwork"]], dtype=jnp.int32, ), } - internal_params = model._process_params(params) # noqa: SLF001 + flat_params = model._process_params(params) # noqa: SLF001 validate_initial_conditions( initial_conditions=initial_conditions, - internal_regimes=model.internal_regimes, + regimes=model._regimes, # noqa: SLF001 regime_names_to_ids=model.regime_names_to_ids, - internal_params=internal_params, + flat_params=flat_params, ages=model.ages, ) diff --git a/tests/test_model_components.py b/tests/test_model_components.py index b3569c5..260c203 100644 --- a/tests/test_model_components.py +++ b/tests/test_model_components.py @@ -7,20 +7,20 @@ def test_equivalence_scale_single() -> None: - result = preferences.equivalence_scale(jnp.array(False), jnp.asarray(0.7)) + result = preferences.equivalence_scale(jnp.int32(0), jnp.asarray(0.7)) assert jnp.isclose(result, 1.0) def test_equivalence_scale_married() -> None: - result = preferences.equivalence_scale(jnp.array(True), jnp.asarray(0.7)) + result = preferences.equivalence_scale(jnp.int32(1), jnp.asarray(0.7)) assert jnp.isclose(result, 2.0**0.7) def test_leisure_not_working() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(0.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -32,8 +32,8 @@ def test_leisure_not_working() -> None: def test_leisure_working_good_health() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(1), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(1), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -47,8 +47,8 @@ def test_leisure_working_good_health() -> None: def test_leisure_reentry_cost() -> None: result = preferences.leisure_canwork_retiree_or_nongroup( working_hours_value=jnp.array(2000.0), - good_health=jnp.array(1.0), - lagged_labor_supply=jnp.array(0), + good_health=jnp.int32(1), + lagged_labor_supply=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), fixed_cost_of_work=jnp.asarray(150.0), @@ -60,7 +60,7 @@ def test_leisure_reentry_cost() -> None: def test_leisure_bad_health() -> None: result = preferences.leisure_forcedout( - good_health=jnp.array(0.0), + good_health=jnp.int32(0), time_endowment=jnp.asarray(5000.0), leisure_cost_of_bad_health=jnp.asarray(500.0), ) diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 7ae6e36..d19758c 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -43,7 +43,7 @@ def build_regime(name: str): def test_model_creates_successfully() -> None: model = make_baseline_model(n_subjects=1) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 assert model.n_periods == 45 @@ -55,13 +55,13 @@ def test_model_age_range() -> None: def test_dead_regime_is_terminal() -> None: model = make_baseline_model(n_subjects=1) - assert model.regimes["dead"].terminal + assert model.user_regimes["dead"].terminal def test_non_terminal_regimes_not_terminal() -> None: model = make_baseline_model(n_subjects=1) for name in REGIME_SPECS: - assert not model.regimes[name].terminal + assert not model.user_regimes[name].terminal def test_regime_id_dead_is_last() -> None: @@ -192,7 +192,7 @@ def test_hcc_persistent_and_transitory_are_shock_grids() -> None: def test_aca_model_creates_successfully() -> None: model = make_aca_model(n_subjects=1, policy=PolicyVariant.ACA) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 assert model.n_periods == 45 @@ -233,7 +233,7 @@ def test_aca_other_regimes_have_no_aca_policy_keys() -> None: def test_all_policy_variants_create(policy: PolicyVariant) -> None: """All policy variants create valid models.""" model = make_aca_model(n_subjects=1, policy=policy) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 def test_aca_no_medicaid_expansion_keeps_baseline_medicaid() -> None: @@ -273,4 +273,4 @@ def test_aca_only_medicaid_expansion() -> None: def test_baseline_model_creates() -> None: """Baseline model creates successfully without PolicyVariant.""" model = make_baseline_model(n_subjects=1) - assert len(model.regimes) == 19 + assert len(model.user_regimes) == 19 diff --git a/tests/test_regime_transitions.py b/tests/test_regime_transitions.py index 87cdeea..9040fdc 100644 --- a/tests/test_regime_transitions.py +++ b/tests/test_regime_transitions.py @@ -45,8 +45,8 @@ def test_tied_stop_working_becomes_nongroup() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.do_not_work), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -62,8 +62,8 @@ def test_tied_keeps_working_stays_tied() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -81,8 +81,8 @@ def test_retiree_medicaid_override_to_nongroup() -> None: transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, @@ -97,8 +97,8 @@ def test_retiree_not_medicaid_stays_retiree() -> None: transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=SURVIVAL, @@ -113,8 +113,8 @@ def test_retiree_forcedout_medicaid_override() -> None: transition = retiree_forcedout(gets_medicare=True, own=own, ng=ng) probs = transition( - age=80, - period=29, + age=jnp.int32(80), + period=jnp.int32(29), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, ) @@ -150,9 +150,9 @@ def test_retiree_age_bracket_transitions( own, ng = make_targets("retiree_nomc_inelig_canwork") transition = retiree_canwork(gets_medicare=False, own=own, ng=ng) - period = int(age - MODEL_CONFIG.start_age) + period = jnp.int32(age - MODEL_CONFIG.start_age) probs = transition( - age=age, + age=jnp.asarray(age), period=period, labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), @@ -171,8 +171,8 @@ def test_nongroup_canwork_valid_probs() -> None: transition = nongroup_canwork(gets_medicare=False, own=own) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), survival_probs=SURVIVAL, ) @@ -186,8 +186,8 @@ def test_nongroup_forcedout_valid_probs() -> None: transition = nongroup_forcedout(gets_medicare=True, own=own) probs = transition( - age=80, - period=29, + age=jnp.int32(80), + period=jnp.int32(29), survival_probs=SURVIVAL, ) assert jnp.isclose(jnp.sum(probs), 1.0, atol=1e-6) @@ -203,8 +203,8 @@ def test_tied_medicaid_override_to_nongroup() -> None: transition = tied_canwork(gets_medicare=False, own=own, ng=ng) probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), survival_probs=SURVIVAL, @@ -218,9 +218,9 @@ def test_tied_at_medicare_age_with_medicaid() -> None: own, ng = make_targets("tied_nomc_choose_canwork") transition = tied_canwork(gets_medicare=False, own=own, ng=ng) - period = int(64 - MODEL_CONFIG.start_age) + period = jnp.int32(64 - MODEL_CONFIG.start_age) probs = transition( - age=64, + age=jnp.int32(64), period=period, labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(True), @@ -238,8 +238,8 @@ def test_survival_prob_determines_death_weight() -> None: survival = jnp.ones(N_PERIODS) * 0.85 probs = transition( - age=55, - period=4, + age=jnp.int32(55), + period=jnp.int32(4), labor_supply=jnp.array(LaborSupply.h2000), is_medicaid_eligible=jnp.array(False), survival_probs=survival, diff --git a/tests/test_social_security.py b/tests/test_social_security.py index c4c704f..e06912f 100644 --- a/tests/test_social_security.py +++ b/tests/test_social_security.py @@ -350,7 +350,7 @@ def test_benefit_inelig_pre65_disabled_below_sga() -> None: ) result = social_security.benefit_inelig_pre65( ssdi_pia=ssdi_val, - health=jnp.array(0), # disabled + health=jnp.int32(0), # disabled labor_income=jnp.array(0.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) @@ -368,7 +368,7 @@ def test_benefit_inelig_pre65_disabled_above_sga() -> None: ) result = social_security.benefit_inelig_pre65( ssdi_pia=ssdi_val, - health=jnp.array(0), # disabled + health=jnp.int32(0), # disabled labor_income=jnp.array(20000.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) @@ -379,7 +379,7 @@ def test_benefit_inelig_pre65_not_disabled() -> None: """Non-disabled agent: benefit = 0.""" result = social_security.benefit_inelig_pre65( ssdi_pia=jnp.array(1000.0), - health=jnp.array(2), # good health + health=jnp.int32(2), # good health labor_income=jnp.array(0.0), ssdi_substantial_gainful_activity=SSDI_SGA, ) diff --git a/tests/test_ss_benefit_integration.py b/tests/test_ss_benefit_integration.py index 81a1b61..bd0e9b2 100644 --- a/tests/test_ss_benefit_integration.py +++ b/tests/test_ss_benefit_integration.py @@ -56,7 +56,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), - health=jnp.array(2), + health=jnp.int32(2), labor_supply=jnp.array(LaborSupply.h2000), labor_income=jnp.array(30000.0), early_ret_adjustment=jnp.array([0.75]), @@ -74,7 +74,7 @@ def test_earnings_test_reduces_benefit_before_fra() -> None: period=jnp.int32(0), claim_ss=jnp.array(ClaimedSS.yes), claimed_ss=jnp.array(ClaimedSS.no), - health=jnp.array(2), + health=jnp.int32(2), labor_supply=jnp.array(LaborSupply.do_not_work), labor_income=jnp.array(0.0), early_ret_adjustment=jnp.array([0.75]), diff --git a/tests/test_ssi_medicaid_integration.py b/tests/test_ssi_medicaid_integration.py index 345d751..3bc6d92 100644 --- a/tests/test_ssi_medicaid_integration.py +++ b/tests/test_ssi_medicaid_integration.py @@ -35,11 +35,11 @@ def test_low_income_qualifies_for_ssi_and_medicaid() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(500.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(1000.0), - spousal_income=jnp.array(0), - gets_medicare=True, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -66,11 +66,11 @@ def test_high_income_ineligible_for_ssi() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(2000.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(1000.0), - spousal_income=jnp.array(0), - gets_medicare=True, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(True), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -95,11 +95,11 @@ def test_no_medicare_blocks_ssi_under_baseline() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), - ssi_ignored_overall=20.0, - ssi_ignored_earned=65.0, + ssi_ignored_overall=jnp.asarray(20.0), + ssi_ignored_earned=jnp.asarray(65.0), assets=jnp.array(100.0), - spousal_income=jnp.array(0), - gets_medicare=False, + spousal_income=jnp.int32(0), + gets_medicare=jnp.asarray(False), ssi_assets_test=SSI_ASSETS_TEST, ssi_maximum_benefit=SSI_MAX_BENEFIT, ) @@ -119,26 +119,26 @@ def test_medicaid_reduces_oop() -> None: oop_medicaid = combined( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), is_ssi_eligible=jnp.array(True), - deductible_medicaid=100.0, - coinsurance_rate_medicaid=0.05, - oop_max_medicaid=1000.0, + deductible_medicaid=jnp.asarray(100.0), + coinsurance_rate_medicaid=jnp.asarray(0.05), + oop_max_medicaid=jnp.asarray(1000.0), ) # Not Medicaid-eligible: primary OOP only oop_no_medicaid = combined( total_health_costs=jnp.array(10000.0), buy_private=jnp.array(BuyPrivate.yes), - deductible=500.0, - coinsurance_rate=0.2, - oop_max=5000.0, + deductible=jnp.asarray(500.0), + coinsurance_rate=jnp.asarray(0.2), + oop_max=jnp.asarray(5000.0), is_ssi_eligible=jnp.array(False), - deductible_medicaid=100.0, - coinsurance_rate_medicaid=0.05, - oop_max_medicaid=1000.0, + deductible_medicaid=jnp.asarray(100.0), + coinsurance_rate_medicaid=jnp.asarray(0.05), + oop_max_medicaid=jnp.asarray(1000.0), ) assert oop_medicaid < oop_no_medicaid diff --git a/tests/test_taxes.py b/tests/test_taxes.py index 958eb0b..862c024 100644 --- a/tests/test_taxes.py +++ b/tests/test_taxes.py @@ -224,7 +224,7 @@ def test_taxable_ss_benefit_below_threshold() -> None: spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), ss_benefit=jnp.array(5000.0), pension_benefit=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), ss_tax_schedule=SS_TAX_SCHEDULE, ) # Provisional income = 10000 + 0.5*5000 = 12500, below 25000 threshold @@ -235,7 +235,7 @@ def test_gross_income_basic() -> None: result = taxes.gross_income( capital_income=jnp.array(1000.0), labor_income=jnp.array(5000.0), - spousal_income=jnp.array(1), + spousal_income=jnp.int32(1), spousal_income_amounts=jnp.array([0.0, 2000.0, 20000.0]), taxable_ss_benefit=jnp.array(500.0), pension_benefit=jnp.array(300.0), @@ -247,7 +247,7 @@ def test_after_tax_income_zero() -> None: gi = taxes.gross_income( capital_income=jnp.array(0.0), labor_income=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), taxable_ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -257,7 +257,7 @@ def test_after_tax_income_zero() -> None: ss_benefit=jnp.array(0.0), taxable_ss_benefit=jnp.array(0.0), labor_income=jnp.array(0.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ) @@ -269,7 +269,7 @@ def test_after_tax_income_low_bracket() -> None: gi = taxes.gross_income( capital_income=jnp.array(0.0), labor_income=jnp.array(gross), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), spousal_income_amounts=jnp.array([0.0, 0.0, 20000.0]), taxable_ss_benefit=jnp.array(0.0), pension_benefit=jnp.array(0.0), @@ -279,7 +279,7 @@ def test_after_tax_income_low_bracket() -> None: ss_benefit=jnp.array(0.0), taxable_ss_benefit=jnp.array(0.0), labor_income=jnp.array(gross), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, payroll_tax_schedule=PAYROLL_TAX_SCHEDULE, ) @@ -293,7 +293,7 @@ def test_after_tax_income_low_bracket() -> None: def test_marginal_tax_rate_low_bracket() -> None: result = taxes.marginal_rate( gross_income=jnp.array(5000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, ) # 5000 is in bracket 1 (0-6200), rate = 0.0765 @@ -303,7 +303,7 @@ def test_marginal_tax_rate_low_bracket() -> None: def test_marginal_tax_rate_mid_bracket() -> None: result = taxes.marginal_rate( gross_income=jnp.array(10000.0), - spousal_income=jnp.array(0), + spousal_income=jnp.int32(0), income_tax_schedule=INCOME_TAX_SCHEDULE, ) # 10000 is in bracket 2 (6200-15275), rate = 0.199