diff --git a/.github/workflows/benchmark-pr.yml b/.github/workflows/benchmark-pr.yml index e63344ee6..9181ce046 100644 --- a/.github/workflows/benchmark-pr.yml +++ b/.github/workflows/benchmark-pr.yml @@ -7,6 +7,10 @@ on: pull_request: branches: - '**' + paths-ignore: + - docs/** + - '**.md' + - '**.ipynb' workflow_dispatch: null jobs: run-benchmarks: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d26a89c51..197f9bb1d 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -11,6 +11,11 @@ on: pull_request: branches: - '**' + paths-ignore: + - docs/** + - benchmarks/** + - '**.md' + - '**.ipynb' jobs: run-tests: name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} diff --git a/AGENTS.md b/AGENTS.md index 6b72b9b74..d991435df 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -221,11 +221,15 @@ Model( ### Core Methods -- `model.solve(params=params)` - Solve the model and return value function arrays per - period and regime -- `model.simulate(params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=period_to_regime_to_V_arr)` +- `model.solve(params=params, log_level="debug")` - Solve the model and return value + function arrays per period and regime +- `model.simulate(params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=period_to_regime_to_V_arr, log_level="debug")` \- Simulate forward given solution. `period_to_regime_to_V_arr` is optional; when `None`, the model is solved automatically before simulating. +- `log_level` is **required** on both `solve()` and `simulate()` + (`off < warning < progress < debug`). It governs all runtime validation: `"off"` skips + it, `"warning"` / `"progress"` warn and continue, `"debug"` raises. Start projects at + `"debug"`. ### Derived Categoricals @@ -245,6 +249,7 @@ result = model.simulate( params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, + log_level="debug", ) # Convert to DataFrame (deferred computation) @@ -562,6 +567,16 @@ Code structure should be self-evident from function names and ordering. - Always use **plotly** for visualizations, never matplotlib. Use `plotly.graph_objects` and `plotly.subplots.make_subplots`. +### Notebooks + +Explanation notebooks live in `docs/explanations/*.ipynb`. After editing one, verify: + +- Each cell's `source` is a JSON array of lines (one array element per line), never a + single multi-line string — a one-string `source` produces an unreadable diff. +- Outputs and execution counts are stripped (`pixi run nbstripout `). +- Markdown and code use literal UTF-8 characters (`—`, `→`, `μ`), never `\u`-style + escape sequences. + ### Key Dependencies - **jax**: Numerical computation diff --git a/benchmarks/bench_aca_baseline.py b/benchmarks/bench_aca_baseline.py index 91ea8ecdc..a985a8a94 100644 --- a/benchmarks/bench_aca_baseline.py +++ b/benchmarks/bench_aca_baseline.py @@ -8,6 +8,15 @@ over 19 regimes, DAG resolution, pref_type batching) while shrinking per-call numerical work so the benchmark fits in an asv invocation. +Two simulate variants run as separate benchmark classes: + +- `AcaBaseline` — `log_level="off"`, `log_path=None`: runtime validation + and diagnostic logging disabled. +- `AcaBaselineDebugLog` — `log_level="debug"` with snapshots written to a + temporary directory: the slow path that runs every validation check + and persists diagnostic snapshots. The gap to `AcaBaseline` is the + validation + logging overhead. + Requires the `aca_model` package to be importable. Use the `benchmarks-cuda12` pixi environment, which pulls aca-model from its public git URL. Inside the aca-dev monorepo the editable path install @@ -28,15 +37,20 @@ compilation is still per-method — the JIT cache is process-local — but the persistent XLA disk cache keeps second and third compiles fast. -- `AcaBaselineGpuPeakMem` runs in a separate subprocess via `_gpu_mem` - that does not go through ASV's `setup_cache` pipeline. It calls - `setup_for_gpu_measurement()` (rebuild fresh, no warm-up) then - `time_execution()` to measure cold peak memory. Both methods - accept `cache=None` so the same callable serves ASV (cache passed - in) and the subprocess (cache omitted). +- `AcaBaselineDebugLog` subclasses `AcaBaseline`, overriding only the + `log_level` and the per-run temporary `log_path`; it reuses the same + `setup_cache` / metric methods. +- `AcaBaselineGpuPeakMem` and `AcaBaselineDebugLogGpuPeakMem` run in a + separate subprocess via `_gpu_mem` that does not go through ASV's + `setup_cache` pipeline. They call `setup_for_gpu_measurement()` + (rebuild fresh, no warm-up) then `time_execution()` to measure cold + peak memory. Both methods accept `cache=None` so the same callable + serves ASV (cache passed in) and the subprocess (cache omitted). """ import gc +import shutil +import tempfile import time import cloudpickle @@ -79,6 +93,8 @@ def _build() -> tuple[object, object, object]: class AcaBaseline: + """aca-baseline simulate with runtime validation and logging off.""" + timeout = 3600 # Pin every ASV sample knob to 1 so setup runs once per subprocess # and one warm call is timed. `timeout=3600` gives headroom for the @@ -88,6 +104,10 @@ class AcaBaseline: number = 1 warmup_time = 0 + # Simulate logging configuration; `AcaBaselineDebugLog` overrides both. + log_level = "off" + log_path: str | None = None + def setup_cache(self) -> bytes: # Build once per ASV benchmark class run and hand the result to # every method via ASV's setup_cache mechanism. ASV pickles the @@ -103,13 +123,7 @@ def setup(self, cache: bytes) -> None: ) # Warm-trigger compilation so time_execution runs on a hot kernel. start = time.perf_counter() - self.model.simulate( - params=self.model_params, - initial_conditions=self.initial_conditions, - period_to_regime_to_V_arr=None, - log_level="off", - check_initial_conditions=False, - ) + self._simulate() self._compile_time = time.perf_counter() - start def setup_for_gpu_measurement(self) -> None: @@ -118,23 +132,20 @@ def setup_for_gpu_measurement(self) -> None: # (build + compile + run, no warm-up). self.model, self.model_params, self.initial_conditions = _build() - def time_execution(self, cache: bytes | None = None) -> None: + def _simulate(self) -> None: self.model.simulate( params=self.model_params, initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, - log_level="off", - check_initial_conditions=False, + log_level=self.log_level, + log_path=self.log_path, ) + def time_execution(self, cache: bytes | None = None) -> None: + self._simulate() + def peakmem_execution(self, cache: bytes | None = None) -> None: - self.model.simulate( - params=self.model_params, - initial_conditions=self.initial_conditions, - period_to_regime_to_V_arr=None, - log_level="off", - check_initial_conditions=False, - ) + self._simulate() def teardown(self, cache: bytes | None = None) -> None: import jax @@ -148,7 +159,42 @@ def track_compilation_time(self, cache: bytes | None = None) -> float: track_compilation_time.unit = "seconds" +class AcaBaselineDebugLog(AcaBaseline): + """aca-baseline simulate at `log_level="debug"` with snapshot logging. + + Runs every runtime-validation check and persists diagnostic + snapshots to a temporary directory. Measured against `AcaBaseline` + (`log_level="off"`), the difference is the validation + logging + overhead. + """ + + log_level = "debug" + + def setup(self, cache: bytes) -> None: + self.log_path = tempfile.mkdtemp(prefix="aca-bench-debug-log-") + super().setup(cache) + + def setup_for_gpu_measurement(self) -> None: + # Mirror `setup`'s log_path setup so the cold-measurement + # subprocess exercises snapshot writing too. The tmpdir leaks + # when the subprocess exits — acceptable since /tmp is OS-cleaned. + self.log_path = tempfile.mkdtemp(prefix="aca-bench-debug-log-") + super().setup_for_gpu_measurement() + + def teardown(self, cache: bytes | None = None) -> None: + super().teardown(cache) + if self.log_path is not None: + shutil.rmtree(self.log_path, ignore_errors=True) + self.log_path = None + + class AcaBaselineGpuPeakMem(_gpu_mem.GpuPeakMem): bench_module = "benchmarks.bench_aca_baseline" bench_class = "AcaBaseline" timeout = 3600 + + +class AcaBaselineDebugLogGpuPeakMem(_gpu_mem.GpuPeakMem): + bench_module = "benchmarks.bench_aca_baseline" + bench_class = "AcaBaselineDebugLog" + timeout = 3600 diff --git a/benchmarks/bench_mahler_yum.py b/benchmarks/bench_mahler_yum.py index 046757a82..ead4a5d86 100644 --- a/benchmarks/bench_mahler_yum.py +++ b/benchmarks/bench_mahler_yum.py @@ -44,7 +44,6 @@ def setup(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) self._compile_time = time.perf_counter() - start @@ -57,7 +56,6 @@ def time_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def peakmem_execution(self): @@ -66,7 +64,6 @@ def peakmem_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def teardown(self): diff --git a/benchmarks/bench_precautionary_savings.py b/benchmarks/bench_precautionary_savings.py index 822d3747d..056077394 100644 --- a/benchmarks/bench_precautionary_savings.py +++ b/benchmarks/bench_precautionary_savings.py @@ -100,7 +100,6 @@ def setup(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=self.period_to_regime_to_V_arr, log_level="off", - check_initial_conditions=False, ) self._compile_time = time.perf_counter() - start @@ -113,7 +112,6 @@ def time_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=self.period_to_regime_to_V_arr, log_level="off", - check_initial_conditions=False, ) def peakmem_execution(self): @@ -122,7 +120,6 @@ def peakmem_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=self.period_to_regime_to_V_arr, log_level="off", - check_initial_conditions=False, ) def teardown(self): @@ -157,7 +154,6 @@ def setup(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) self._compile_time = time.perf_counter() - start @@ -170,7 +166,6 @@ def time_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def peakmem_execution(self): @@ -179,7 +174,6 @@ def peakmem_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def teardown(self): @@ -215,7 +209,6 @@ def setup(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) self._compile_time = time.perf_counter() - start @@ -228,7 +221,6 @@ def time_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def peakmem_execution(self): @@ -237,7 +229,6 @@ def peakmem_execution(self): initial_conditions=self.initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) def teardown(self): diff --git a/benchmarks/pr_comment.py b/benchmarks/pr_comment.py index 35b0df766..55ee93e9e 100644 --- a/benchmarks/pr_comment.py +++ b/benchmarks/pr_comment.py @@ -35,6 +35,8 @@ _CLASS_DISPLAY = { "AcaBaseline": "aca-baseline", "AcaBaselineGpuPeakMem": "aca-baseline", + "AcaBaselineDebugLog": "aca-baseline-debug", + "AcaBaselineDebugLogGpuPeakMem": "aca-baseline-debug", "MahlerYum": "Mahler-Yum", "MahlerYumGpuPeakMem": "Mahler-Yum", "PrecautionarySavingsSolve": "Precautionary Savings - Solve", diff --git a/docs/examples/mahler_yum_2024.md b/docs/examples/mahler_yum_2024.md index 4ce0df0a8..7413fcc26 100644 --- a/docs/examples/mahler_yum_2024.md +++ b/docs/examples/mahler_yum_2024.md @@ -46,6 +46,7 @@ result = MAHLER_YUM_MODEL.simulate( ), }, period_to_regime_to_V_arr=None, + log_level="debug", seed=8295, ) ``` diff --git a/docs/examples/mortality.md b/docs/examples/mortality.md index 687fb528b..223ca212f 100644 --- a/docs/examples/mortality.md +++ b/docs/examples/mortality.md @@ -29,6 +29,7 @@ result = model.simulate( "regime_id": jnp.full(100, model.regime_names_to_ids["working_life"]), }, period_to_regime_to_V_arr=None, + log_level="debug", seed=1234, ) diff --git a/docs/examples/precautionary_savings.md b/docs/examples/precautionary_savings.md index 6e92d327a..5f932fe42 100644 --- a/docs/examples/precautionary_savings.md +++ b/docs/examples/precautionary_savings.md @@ -30,6 +30,7 @@ result = model.simulate( "regime_id": jnp.full(100, model.regime_names_to_ids["alive"]), }, period_to_regime_to_V_arr=None, + log_level="debug", ) df = result.to_dataframe(additional_targets="all") diff --git a/docs/examples/precautionary_savings_health.md b/docs/examples/precautionary_savings_health.md index 38a244125..bc8c7053b 100644 --- a/docs/examples/precautionary_savings_health.md +++ b/docs/examples/precautionary_savings_health.md @@ -35,6 +35,7 @@ result = model.simulate( "regime_id": jnp.full(1_000, model.regime_names_to_ids["working_life"]), }, period_to_regime_to_V_arr=None, + log_level="debug", ) df = result.to_dataframe(additional_targets="all") diff --git a/docs/examples/tiny.md b/docs/examples/tiny.md index aa943d395..27e3dca91 100644 --- a/docs/examples/tiny.md +++ b/docs/examples/tiny.md @@ -35,6 +35,7 @@ result = model.simulate( params=params, initial_conditions=initial_df, period_to_regime_to_V_arr=None, + log_level="debug", ) df = result.to_dataframe(additional_targets="all") diff --git a/docs/explanations/beta_delta.ipynb b/docs/explanations/beta_delta.ipynb index e9a11d0c6..8638c5a1c 100644 --- a/docs/explanations/beta_delta.ipynb +++ b/docs/explanations/beta_delta.ipynb @@ -319,6 +319,7 @@ " params={\"working\": {\"H\": {\"discount_factor\": DELTA}}},\n", " initial_conditions=initial_conditions,\n", " period_to_regime_to_V_arr=None,\n", + " log_level=\"debug\",\n", ")\n", "\n", "df_exp = result_exp.to_dataframe().query('regime == \"working\"')\n", @@ -363,6 +364,7 @@ " },\n", " initial_conditions=initial_conditions,\n", " period_to_regime_to_V_arr=None,\n", + " log_level=\"debug\",\n", ")\n", "\n", "df_naive = result_naive.to_dataframe().query('regime == \"working\"')\n", diff --git a/docs/explanations/stochastic_transitions.ipynb b/docs/explanations/stochastic_transitions.ipynb index ebfb018b6..feb800bf0 100644 --- a/docs/explanations/stochastic_transitions.ipynb +++ b/docs/explanations/stochastic_transitions.ipynb @@ -21,12 +21,103 @@ { "cell_type": "markdown", "metadata": {}, - "source": "## The Computation Hierarchy\n\nThe action-value function $Q$ decomposes into nested expectations:\n\n$$\nQ(s, a) = U(s, a) + \\beta \\sum_{r'} P(r' \\mid s, a)\n \\sum_{s'} P(s' \\mid s, a, r') \\, V(s', r')\n$$\n\nThe **outer sum** runs over target regimes $r'$, weighted by regime transition\nprobabilities. The **inner sum** runs over next-period states $s'$, weighted by\nstate transition probabilities.\n\nThis hierarchy is reflected directly in the solve loop\n(`src/lcm/regime_building/Q_and_F.py`):\n\n```python\nfor target_regime_name in active_regimes_next_period:\n next_states = state_transitions[target_regime_name](...)\n weights = next_stochastic_states_weights[target_regime_name](...)\n next_V_expected = jnp.average(next_V_arr, weights=joint_weights)\n continuation_value += regime_probs[target_regime_name] * next_V_expected\n```\n\n- The **outer sum** is a Python `for` loop over regime names.\n- The **inner sum** is a vectorised `jnp.average` over state grid points.\n\nThis structural difference — loop-by-name vs array operation — is the fundamental\nreason why regime and state transition probabilities use different data structures." + "source": [ + "## The Computation Hierarchy\n", + "\n", + "The action-value function $Q$ decomposes into nested expectations:\n", + "\n", + "$$\n", + "Q(s, a) = U(s, a) + \\beta \\sum_{r'} P(r' \\mid s, a)\n", + " \\sum_{s'} P(s' \\mid s, a, r') \\, V(s', r')\n", + "$$\n", + "\n", + "The **outer sum** runs over target regimes $r'$, weighted by regime transition\n", + "probabilities. The **inner sum** runs over next-period states $s'$, weighted by\n", + "state transition probabilities.\n", + "\n", + "This hierarchy is reflected directly in the solve loop\n", + "(`src/lcm/regime_building/Q_and_F.py`):\n", + "\n", + "```python\n", + "for target_regime_name in active_regimes_next_period:\n", + " next_states = state_transitions[target_regime_name](...)\n", + " weights = next_stochastic_states_weights[target_regime_name](...)\n", + " next_V_expected = jnp.average(next_V_arr, weights=joint_weights)\n", + " continuation_value += regime_probs[target_regime_name] * next_V_expected\n", + "```\n", + "\n", + "- The **outer sum** is a Python `for` loop over regime names.\n", + "- The **inner sum** is a vectorised `jnp.average` over state grid points.\n", + "\n", + "This structural difference — loop-by-name vs array operation — is the fundamental\n", + "reason why regime and state transition probabilities use different data structures." + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "## Regime Transition Probabilities\n\n### Runtime format\n\n`MappingProxyType[str, Array]` — an immutable dict mapping regime names to\nper-subject probability arrays.\n\n```python\n# Example: 3 subjects, 2 regimes\n{\n \"working\": Array([0.8, 0.65, 0.9]),\n \"retired\": Array([0.2, 0.35, 0.1]),\n}\n```\n\n### Why a dict?\n\nThe solve loop iterates over target regimes by name, and every co-indexed data\nstructure — `state_transitions`, `next_V`, `next_stochastic_states_weights` — is\nalso keyed by regime name. The dict format fits naturally:\n\n```python\ncontinuation_value += regime_probs[target_regime_name] * next_V_expected\n```\n\nSwitching to a plain array indexed by integer regime ID would require:\n\n```python\ncontinuation_value += regime_probs_arr[regime_names_to_ids[target_regime_name]] * next_V_expected\n```\n\nThis gains nothing — the loop must still iterate by name because the other dicts\nforce it — and adds noise.\n\n### The array-to-dict-to-array round-trip\n\nThe user's transition function returns a plain array indexed by integer regime ID.\nThe processing pipeline immediately wraps it into a dict\n(`_wrap_regime_transition` in `src/lcm/regime_building/processing.py`):\n\n```python\ndef wrapped(*args, **kwargs):\n result = func(*args, **kwargs) # Array[n_regimes]\n return MappingProxyType(\n {name: result[idx] for idx, name in enumerate(regime_names)}\n )\n```\n\nDuring simulation, `draw_key_from_dict` converts it back to a matrix for\nsampling (`src/lcm/simulation/transitions.py`):\n\n```python\nregime_transition_probs = jnp.array(list(d.values())).T # dict → matrix\nregime_ids = jnp.array([regime_names_to_ids[name] for name in d])\n# ... jax.random.choice(key, regime_ids, p=p)\n```\n\nThis round-trip is the cost of keeping the solve path readable. The alternative —\nusing plain arrays everywhere — would require restructuring the solve loop and\nall co-indexed data structures, a large change with no functional benefit." + "source": [ + "## Regime Transition Probabilities\n", + "\n", + "### Runtime format\n", + "\n", + "`MappingProxyType[str, Array]` — an immutable dict mapping regime names to\n", + "per-subject probability arrays.\n", + "\n", + "```python\n", + "# Example: 3 subjects, 2 regimes\n", + "{\n", + " \"working\": Array([0.8, 0.65, 0.9]),\n", + " \"retired\": Array([0.2, 0.35, 0.1]),\n", + "}\n", + "```\n", + "\n", + "### Why a dict?\n", + "\n", + "The solve loop iterates over target regimes by name, and every co-indexed data\n", + "structure — `state_transitions`, `next_V`, `next_stochastic_states_weights` — is\n", + "also keyed by regime name. The dict format fits naturally:\n", + "\n", + "```python\n", + "continuation_value += regime_probs[target_regime_name] * next_V_expected\n", + "```\n", + "\n", + "Switching to a plain array indexed by integer regime ID would require:\n", + "\n", + "```python\n", + "continuation_value += regime_probs_arr[regime_names_to_ids[target_regime_name]] * next_V_expected\n", + "```\n", + "\n", + "This gains nothing — the loop must still iterate by name because the other dicts\n", + "force it — and adds noise.\n", + "\n", + "### The array-to-dict-to-array round-trip\n", + "\n", + "The user's transition function returns a plain array indexed by integer regime ID.\n", + "The processing pipeline immediately wraps it into a dict\n", + "(`_wrap_regime_transition` in `src/lcm/regime_building/processing.py`):\n", + "\n", + "```python\n", + "def wrapped(*args, **kwargs):\n", + " result = func(*args, **kwargs) # Array[n_regimes]\n", + " return MappingProxyType(\n", + " {name: result[idx] for idx, name in enumerate(regime_names)}\n", + " )\n", + "```\n", + "\n", + "During simulation, `draw_key_from_dict` converts it back to a matrix for\n", + "sampling (`src/lcm/simulation/transitions.py`):\n", + "\n", + "```python\n", + "regime_transition_probs = jnp.array(list(d.values())).T # dict → matrix\n", + "regime_ids = jnp.array([regime_names_to_ids[name] for name in d])\n", + "# ... jax.random.choice(key, regime_ids, p=p)\n", + "```\n", + "\n", + "This round-trip is the cost of keeping the solve path readable. The alternative —\n", + "using plain arrays everywhere — would require restructuring the solve loop and\n", + "all co-indexed data structures, a large change with no functional benefit." + ] }, { "cell_type": "markdown", @@ -186,8 +277,6 @@ "2. **User-facing pandas integration** — `solve()` and `simulate()` accept `pd.Series`\n", " values in params directly. Series with labeled MultiIndex levels are converted to\n", " correctly shaped JAX arrays automatically (`array_from_series` under the hood).\n", - " `validate_transition_probs` checks the resulting probability arrays for both regime\n", - " and state transitions.\n", "\n", "Full harmonisation of the runtime formats is blocked by the solve loop structure:\n", "regime probabilities are consumed in a Python loop over named regimes (because\n", diff --git a/docs/user_guide/debugging.md b/docs/user_guide/debugging.md index 5dcc0b4fc..555a808f9 100644 --- a/docs/user_guide/debugging.md +++ b/docs/user_guide/debugging.md @@ -27,38 +27,36 @@ Re-enable JIT once the issue is resolved. ## Log levels -The `log_level` parameter controls both console output and disk persistence: - -| Level | Output | Persistence | -| ---------------------- | ----------------------------------------------------------------- | ------------------------ | -| `"off"` | Nothing (good for HPC batch jobs) | No | -| `"warning"` | NaN/Inf warnings in value functions | No | -| `"progress"` (default) | Progress and timing per period, total elapsed time | No | -| `"debug"` | All above + V_arr statistics per regime, regime transition counts | Yes, requires `log_path` | +`log_level` controls console verbosity *and* the runtime-validation policy — how +`solve()` / `simulate()` react to an invalid transition-probability ensemble or a NaN +value function. See [Solving and Simulating](solving_and_simulating.md) for the full +`log_level` × `log_path` behaviour table. ```python -# Silent — no console output at all +# Silent — no console output, no validation period_to_regime_to_V_arr = model.solve(params=params, log_level="off") -# Warnings only — alerts on NaN/Inf but no progress output +# Warnings only — invalid input is logged, the run continues period_to_regime_to_V_arr = model.solve(params=params, log_level="warning") -# Progress (default) — timing per period -period_to_regime_to_V_arr = model.solve(params=params) # log_level="progress" +# Debug — validation raises, full diagnostics +period_to_regime_to_V_arr = model.solve(params=params, log_level="debug") -# Debug — full diagnostics + snapshot persistence +# Debug + snapshot persistence period_to_regime_to_V_arr = model.solve( params=params, log_level="debug", log_path="./debug/" ) ``` -Using `log_level="debug"` without providing `log_path` raises a `ValueError`. +`log_path` is optional at every level — including `"debug"`. ## Debug snapshots -When `log_level="debug"` and `log_path` is provided, pylcm saves a **snapshot -directory** containing all inputs and outputs. This lets you reconstruct a failed run on -a different machine. +When `log_path` is provided, pylcm saves a **snapshot directory** containing all inputs +and outputs, so you can reconstruct a failed run on a different machine. In `"debug"` +mode a snapshot is written on every solve and on a raised failure; in `"warning"` / +`"progress"` mode one is written whenever a warned failure leaves NaN in the value +function. ### What's saved @@ -119,7 +117,9 @@ snapshot.params # the user parameters snapshot.period_to_regime_to_V_arr # value function arrays (loaded from HDF5) # Re-run the solve to reproduce the result -period_to_regime_to_V_arr = snapshot.model.solve(params=snapshot.params) +period_to_regime_to_V_arr = snapshot.model.solve( + params=snapshot.params, log_level="debug" +) ``` For large snapshots, skip fields you don't need: @@ -212,7 +212,7 @@ model = Model( ) # Call solve with the bad parameters --- the traceback will be readable -period_to_regime_to_V_arr = model.solve(params=bad_params) +period_to_regime_to_V_arr = model.solve(params=bad_params, log_level="debug") ``` The traceback now points to the exact line in your user-defined functions where the NaN @@ -229,7 +229,7 @@ import jax.numpy as jnp import plotly.graph_objects as go from plotly.subplots import make_subplots -period_to_regime_to_V_arr = model.solve(params=params) +period_to_regime_to_V_arr = model.solve(params=params, log_level="debug") # Check for issues for period, regimes in period_to_regime_to_V_arr.items(): @@ -255,10 +255,9 @@ fig.show() ## Failure snapshots -When `log_path` is set and `solve()` raises `InvalidValueFunctionError`, a snapshot is -saved automatically --- even without `log_level="debug"`. This lets you inspect the -partial solution (value functions for periods that completed before the error) on -another machine. +When `log_path` is set and `solve()` raises `InvalidValueFunctionError` (in `"debug"` +mode), a snapshot is saved automatically. This lets you inspect the partial solution +(value functions for periods that completed before the error) on another machine. ```python # log_path is enough to get a failure snapshot @@ -266,6 +265,7 @@ result = model.simulate( params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, + log_level="debug", log_path="./debug/", ) ``` diff --git a/docs/user_guide/pandas_interop.md b/docs/user_guide/pandas_interop.md index 221604cde..e366b570c 100644 --- a/docs/user_guide/pandas_interop.md +++ b/docs/user_guide/pandas_interop.md @@ -28,6 +28,7 @@ result = model.simulate( params=params, initial_conditions=df, period_to_regime_to_V_arr=None, + log_level="debug", ) ``` @@ -63,6 +64,7 @@ result = model.simulate( params=params, initial_conditions=df, period_to_regime_to_V_arr=None, + log_level="debug", ) ``` @@ -174,31 +176,10 @@ def is_good_health(health: DiscreteState) -> IntND: ## Validating State Transition Probabilities -Check that a state transition probability array has the correct shape, values in -$[0, 1]$, and rows that sum to 1: - -```python -from lcm import validate_transition_probs - -validate_transition_probs( - probs=health_probs, - model=model, - regime_name="working", - state_name="health", -) -``` - -Raises `ValueError` if: - -- The array shape doesn't match the expected dimensions (indexing parameters + outcome - axis) -- Any value is outside $[0, 1]$ -- Any row (slice along the last axis) doesn't sum to 1 - -Call this after building the array to catch mistakes early. - -Regime transition probabilities are validated automatically during `model.solve()` and -`model.simulate()`, so this helper covers only state transitions. +State (and regime) transition probabilities are validated automatically during +`model.solve()` and `model.simulate()`: their shape, $[0, 1]$ range, and rows summing to +1 are checked on a sweep over the regime's grids. Whether a failure raises or only warns +is controlled by `log_level` — see [Solving and Simulating](solving_and_simulating.md). ## Under the Hood diff --git a/docs/user_guide/parameters.md b/docs/user_guide/parameters.md index 3e22472aa..b092e0b61 100644 --- a/docs/user_guide/parameters.md +++ b/docs/user_guide/parameters.md @@ -164,7 +164,10 @@ model = Model( # The solve/simulate calls only need the estimated parameters params = {"risk_aversion": 1.5, "disutility_of_work": 1.0} result = model.simulate( - params=params, initial_conditions=..., period_to_regime_to_V_arr=None + params=params, + initial_conditions=..., + period_to_regime_to_V_arr=None, + log_level="debug", ) ``` diff --git a/docs/user_guide/solving_and_simulating.md b/docs/user_guide/solving_and_simulating.md index c217635f6..b91d8fa4b 100644 --- a/docs/user_guide/solving_and_simulating.md +++ b/docs/user_guide/solving_and_simulating.md @@ -10,30 +10,54 @@ induction and simulates forward. ## Solving ```python -period_to_regime_to_V_arr = model.solve(params=params) +period_to_regime_to_V_arr = model.solve(params=params, log_level="debug") ``` Performs backward induction using dynamic programming. Returns an immutable mapping of `period -> regime_name -> value_function_array`. -### Log levels +### Log levels and runtime validation -Control console output and snapshot persistence with `log_level`: +`log_level` is a required argument: it controls both console verbosity *and* the +runtime-validation policy — how `solve()` / `simulate()` react to an invalid +transition-probability ensemble or a NaN value function. Start every project at +`"debug"` (validation runs and raises); ease to `"warning"` / `"off"` once the model is +trusted. ```python -# Default: progress + timing -period_to_regime_to_V_arr = model.solve(params=params) +# Debug — validation runs and raises on the first failure +period_to_regime_to_V_arr = model.solve(params=params, log_level="debug") -# Silent +# Silent — no logging, no validation period_to_regime_to_V_arr = model.solve(params=params, log_level="off") -# Full diagnostics + disk snapshots +# Validation runs but only warns; the run continues +period_to_regime_to_V_arr = model.solve(params=params, log_level="warning") + +# Diagnostics + disk snapshots period_to_regime_to_V_arr = model.solve( params=params, log_level="debug", log_path="./debug/" ) ``` -See [Debugging](debugging.md) for details on log levels and debug snapshots. +The full behaviour of every `log_level` × `log_path` combination: + +| `log_level` | `log_path` | Runtime validation | Console output | Snapshots to disk | +| --------------------- | ---------- | ------------------------- | ------------------------------- | --------------------------------------------------------- | +| `"off"` | (ignored) | not run | silent | none | +| `"warning"` | `None` | runs → failures **warn** | warnings | none | +| `"warning"` | set | runs → failures **warn** | warnings | one per warned failure, capped at `log_keep_n_latest` | +| `"progress"` | `None` | runs → failures **warn** | warnings + timing | none | +| `"progress"` | set | runs → failures **warn** | warnings + timing | one per warned failure, capped at `log_keep_n_latest` | +| `"debug"` *(default)* | `None` | runs → failures **raise** | warnings + timing + V_arr stats | none | +| `"debug"` *(default)* | set | runs → failures **raise** | warnings + timing + V_arr stats | one per solve and on raise, capped at `log_keep_n_latest` | + +`log_path` is optional at every level — snapshots are written only when it is set. In +`"warning"` / `"progress"` mode, an invalid model produces warnings and a numerically +meaningless result rather than an exception; use this to keep an estimation loop +running, but read the warnings. + +See [Debugging](debugging.md) for details on snapshots. ## Simulating @@ -42,6 +66,7 @@ result = model.simulate( params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=period_to_regime_to_V_arr, + log_level="debug", ) ``` @@ -59,6 +84,7 @@ result = model.simulate( params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, + log_level="debug", ) ``` @@ -85,6 +111,7 @@ result = model.simulate( params=params, initial_conditions=df, period_to_regime_to_V_arr=None, + log_level="debug", ) ``` @@ -118,13 +145,13 @@ initial_conditions = { - All arrays must have the same length (= number of agents). - Shock states are drawn automatically. -### Optional arguments +### Further arguments -- `check_initial_conditions=True`: Validates that initial states are on-grid and regimes - are valid. Set to `False` to skip validation. +- `log_level`: Required. Console verbosity and runtime-validation policy (same options + and table as `solve()`); start at `"debug"`. Initial-condition validation (states + on-grid, regimes valid) follows this policy too — `"off"` skips it. - `seed=None`: Random seed for stochastic simulations (int). -- `log_level="progress"`: Controls logging verbosity (same options as `solve()`). -- `log_path=None`: Directory for debug snapshots (when `log_level="debug"`). +- `log_path=None`: Directory for diagnostic snapshots; optional at every level. - `log_keep_n_latest=3`: Maximum snapshot directories to retain. ### Heterogeneous initial ages @@ -250,6 +277,7 @@ result = model.simulate( params=params, initial_conditions=initial_df, period_to_regime_to_V_arr=None, + log_level="debug", ) # 5. Analyze diff --git a/docs/user_guide/tiny_example.ipynb b/docs/user_guide/tiny_example.ipynb index 7e2bafb52..bb15006db 100644 --- a/docs/user_guide/tiny_example.ipynb +++ b/docs/user_guide/tiny_example.ipynb @@ -354,6 +354,7 @@ " params=params,\n", " initial_conditions=initial_df,\n", " period_to_regime_to_V_arr=None,\n", + " log_level=\"debug\",\n", ")" ] }, diff --git a/src/lcm/__init__.py b/src/lcm/__init__.py index 5eaac6882..4d6e8e167 100644 --- a/src/lcm/__init__.py +++ b/src/lcm/__init__.py @@ -75,7 +75,6 @@ from lcm.user_regime import ( # noqa: E402 MarkovTransition, Regime, - validate_transition_probs, ) from lcm.utils.containers import invert_regime_ids # noqa: E402 from lcm.variables import ( # noqa: E402 @@ -117,5 +116,4 @@ "load_solution", "save_solution", "shocks", - "validate_transition_probs", ] diff --git a/src/lcm/_transition_checks.py b/src/lcm/_transition_checks.py index 7d65385f3..9b6b5b927 100644 --- a/src/lcm/_transition_checks.py +++ b/src/lcm/_transition_checks.py @@ -1,11 +1,23 @@ """Pre-flight numerical checks on user-supplied transition functions. Called from `Model.solve()` and `Model.simulate()` before backward induction -runs. For each (regime, active period) pair, evaluates the regime transition -function on the Cartesian product of the grid variables it accepts and -verifies finiteness, [0, 1] range, sum-to-1, no probability mass to inactive -regimes, and no positive probability to a target with incomplete stochastic -transitions. +runs. Two families: + +- **Regime transition probability check** keyed on + `validate_regime_transitions_all_periods`. Iterates active non-terminal + regimes across periods, evaluates the regime transition function on the + Cartesian product of its accepted grid variables, and verifies finiteness, + [0, 1] range, sum-to-1, no probability mass to inactive regimes, and no + positive probability to a target with incomplete stochastic transitions. +- **State transition probability check** keyed on + `validate_state_transitions_all_periods`. Sweeps every `MarkovTransition` + state transition (incl. per-target dict entries), evaluates the user + function on the Cartesian product of the function's accepted grid + variables, and verifies outcome-axis size, [0, 1] range, and sum-to-1. + +Both checks read their policy off the `logger`: `log_level="off"` skips the +check, `"warning"` / `"progress"` log each failure and let the run continue, +`"debug"` raises on the first failure. These are runtime checks: they need a fully-built `Regime` plus user `flat_params` and evaluate the transition functions numerically. The @@ -15,6 +27,7 @@ """ import inspect +import logging from types import MappingProxyType import jax @@ -22,8 +35,11 @@ import pandas as pd from lcm.ages import AgeGrid -from lcm.exceptions import InvalidRegimeTransitionProbabilitiesError -from lcm.interfaces import Regime +from lcm.exceptions import ( + InvalidRegimeTransitionProbabilitiesError, + InvalidStateTransitionProbabilitiesError, +) +from lcm.interfaces import Regime, StateActionSpace, _StochasticStateTransition from lcm.typing import ( FlatParams, FlatRegimeParams, @@ -34,6 +50,40 @@ ScalarInt, StateOrActionName, ) +from lcm.utils.logging import raise_or_warn, validation_enabled + + +def _params_callable_for_state_transition( + *, + regime: Regime, + flat_params_for_regime: FlatRegimeParams, + transition: _StochasticStateTransition, +) -> FlatRegimeParams: + """Return un-qualified params for calling a state-transition function. + + Both `regime.resolved_fixed_params` and `flat_params_for_regime` key + their entries by qualified names like `next___` (or + `next_____` for per-target dicts). The + `MarkovTransition`'s user function is called with the raw parameter + names from its signature, so the validator must strip that qualifier + before lookup. Without the strip, every transition-function + parameter that isn't a grid axis falls through to the "not + numerically validated" skip branch and the per-transition numerical + check never runs. + """ + prefix = f"next_{transition.state_name}" + if transition.target_regime_name is not None: + prefix = f"{prefix}__{transition.target_regime_name}" + prefix = f"{prefix}__" + + merged = {**regime.resolved_fixed_params, **flat_params_for_regime} + return MappingProxyType( + { + name.removeprefix(prefix): value + for name, value in merged.items() + if name.startswith(prefix) + } + ) def validate_regime_transitions_all_periods( @@ -41,6 +91,7 @@ def validate_regime_transitions_all_periods( regimes: MappingProxyType[RegimeName, Regime], flat_params: FlatParams, ages: AgeGrid, + logger: logging.Logger, ) -> None: """Validate regime transition probabilities for all periods before solve. @@ -52,12 +103,18 @@ def validate_regime_transitions_all_periods( regimes: Immutable mapping of regime names to regimes. flat_params: Immutable mapping of regime names to flat parameter mappings. ages: Age grid for the model. + logger: Logger carrying the runtime-validation policy. `log_level="off"` + returns immediately; `"warning"` / `"progress"` log each failure and + continue; `"debug"` raises on the first failure. Raises: - InvalidRegimeTransitionProbabilitiesError: If any inactive regime receives - positive transition probability. + InvalidRegimeTransitionProbabilitiesError: If a regime transition produces + invalid probabilities and the logger implies raise mode. """ + if not validation_enabled(logger): + return + last_period = ages.n_periods - 1 non_terminal_active_at_last = [ regime_name @@ -65,12 +122,15 @@ def validate_regime_transitions_all_periods( if not regime.terminal and last_period in regime.active_periods ] if non_terminal_active_at_last: - raise InvalidRegimeTransitionProbabilitiesError( - f"Non-terminal regime(s) {non_terminal_active_at_last} are active at the " - f"last period (age {ages.exact_values[last_period]}). Non-terminal regimes " - "must not be active at the last period because there is no next period to " - "transition to. Adjust the 'active' function on these regimes to exclude " - "the last age." + raise_or_warn( + logger=logger, + error=InvalidRegimeTransitionProbabilitiesError( + f"Non-terminal regime(s) {non_terminal_active_at_last} are active at " + f"the last period (age {ages.exact_values[last_period]}). Non-terminal " + "regimes must not be active at the last period because there is no " + "next period to transition to. Adjust the 'active' function on these " + "regimes to exclude the last age." + ), ) for period in range(ages.n_periods - 1): @@ -86,14 +146,17 @@ def validate_regime_transitions_all_periods( if regime.terminal: continue - _validate_regime_transition_single( - regimes=regimes, - regime_params=flat_params[regime_name], - active_regimes_next_period=active_regimes_next_period, - regime_name=regime_name, - period=period, - ages=ages, - ) + try: + _validate_regime_transition_single( + regimes=regimes, + regime_params=flat_params[regime_name], + active_regimes_next_period=active_regimes_next_period, + regime_name=regime_name, + period=period, + ages=ages, + ) + except InvalidRegimeTransitionProbabilitiesError as error: + raise_or_warn(logger=logger, error=error) def _validate_regime_transition_single( @@ -342,3 +405,189 @@ def _validate_no_reachable_incomplete_targets( f"transition differs by target), or ensure " f"'{target_regime_name}' is unreachable." ) + + +def validate_state_transitions_all_periods( + *, + regimes: MappingProxyType[RegimeName, Regime], + flat_params: FlatParams, + ages: AgeGrid, + logger: logging.Logger, +) -> None: + """Validate every `MarkovTransition` state transition before solve. + + For each non-terminal active period of each active regime, iterate the + regime's `stochastic_state_transitions` and evaluate each + `MarkovTransition` function on the Cartesian product of its accepted + grid variables. Check: + + - The output's last-axis size matches the state's outcome count. + - All values lie in [0, 1]. + - Rows along the last axis sum to 1. + + Fast-exits when no regime in the model has any stochastic state + transitions, so models without `MarkovTransition` states pay no cost. + + Args: + regimes: Immutable mapping of regime names to canonical regimes. + flat_params: Immutable mapping of regime names to flat parameter + mappings. + ages: Age grid for the model. + logger: Logger carrying the runtime-validation policy. `log_level="off"` + returns immediately; `"warning"` / `"progress"` log each failure and + continue; `"debug"` raises on the first failure. + + Raises: + InvalidStateTransitionProbabilitiesError: If a `MarkovTransition` + function returns the wrong outcome-axis size, values outside + [0, 1], or rows that don't sum to 1, and the logger implies raise + mode. + + """ + if not validation_enabled(logger): + return + if not any(r.stochastic_state_transitions for r in regimes.values()): + return + + for period in range(ages.n_periods - 1): + for regime_name, regime in regimes.items(): + if period not in regime.active_periods: + continue + if regime.terminal: + continue + if not regime.stochastic_state_transitions: + continue + + state_action_space = regime.state_action_space( + regime_params=flat_params[regime_name], + ) + age = ages.values[period] # noqa: PD011 + for transition in regime.stochastic_state_transitions.values(): + try: + _validate_state_transition_single( + transition=transition, + regime_params=_params_callable_for_state_transition( + regime=regime, + flat_params_for_regime=flat_params[regime_name], + transition=transition, + ), + state_action_space=state_action_space, + regime_name=regime_name, + age=age, + period=period, + logger=logger, + ) + except InvalidStateTransitionProbabilitiesError as error: + raise_or_warn(logger=logger, error=error) + + +def _validate_state_transition_single( + *, + transition: _StochasticStateTransition, + regime_params: FlatRegimeParams, + state_action_space: StateActionSpace, + regime_name: RegimeName, + age: float | ScalarInt | ScalarFloat, + period: int, + logger: logging.Logger, +) -> None: + """Evaluate one MarkovTransition on its grid args and validate the output.""" + func = transition.func + sig_params = tuple(inspect.signature(func).parameters) + + grid_args: dict[StateOrActionName, FloatND | IntND] = {} + scalar_kwargs: dict[str, object] = {} + period_int32 = jnp.int32(period) + + for name in sig_params: + if name == "period": + scalar_kwargs["period"] = period_int32 + elif name == "age": + scalar_kwargs["age"] = age + elif name in state_action_space.states: + grid_args[name] = state_action_space.states[name] + elif name in state_action_space.actions: + grid_args[name] = state_action_space.actions[name] + elif name in regime_params: + scalar_kwargs[name] = regime_params[name] + else: + # An indexing param the function expects is neither a regime + # grid nor a param. Skip numerical validation for this + # transition rather than raising — a raise here would conceal + # the real error the solve step surfaces. Warn so the skip is + # not silent. + logger.warning( + "MarkovTransition for state '%s' in regime '%s' not numerically " + "validated: parameter '%s' is not a recognized grid or model " + "parameter.", + transition.state_name, + regime_name, + name, + ) + return + + if grid_args: + grid_var_names = list(grid_args.keys()) + grid_arrays = list(grid_args.values()) + mesh = jnp.meshgrid(*grid_arrays, indexing="ij") + flat_arrays = [m.ravel() for m in mesh] + + def _call( + *args: FloatND | IntND, + _names: list[str] = grid_var_names, + _scalar: dict[str, object] = scalar_kwargs, + _func: object = func, + ) -> FloatND: + kwargs = dict(zip(_names, args, strict=True)) + return _func(**kwargs, **_scalar) # ty: ignore[call-non-callable] + + probs = jax.vmap(_call)(*flat_arrays) + else: + probs = func(**scalar_kwargs) + + _check_state_probs( + probs=probs, + transition=transition, + regime_name=regime_name, + age=age, + ) + + +def _check_state_probs( + *, + probs: FloatND, + transition: _StochasticStateTransition, + regime_name: RegimeName, + age: float | ScalarInt | ScalarFloat, +) -> None: + """Assert outcome-axis size, [0, 1] range, and sum-to-1 on a probs array.""" + state_label = ( + f"state '{transition.state_name}'" + if transition.target_regime_name is None + else ( + f"state '{transition.state_name}' (target regime " + f"'{transition.target_regime_name}')" + ) + ) + + if probs.shape[-1] != transition.n_outcomes: + raise InvalidStateTransitionProbabilitiesError( + f"MarkovTransition for {state_label} in regime '{regime_name}' " + f"at age {age} returned an outcome axis of size " + f"{probs.shape[-1]}; expected {transition.n_outcomes} from the " + f"state's DiscreteGrid." + ) + + if jnp.any(probs < 0) or jnp.any(probs > 1): + raise InvalidStateTransitionProbabilitiesError( + f"MarkovTransition for {state_label} in regime '{regime_name}' " + f"at age {age} returned values outside [0, 1]." + ) + + row_sums = jnp.sum(probs, axis=-1) + if not jnp.allclose(row_sums, 1.0, atol=1e-6): + raise InvalidStateTransitionProbabilitiesError( + f"MarkovTransition for {state_label} in regime '{regime_name}' " + f"at age {age} returned rows that do not sum to 1 along the " + f"outcome axis." + ) diff --git a/src/lcm/exceptions.py b/src/lcm/exceptions.py index c00e55aff..aaf6f74f1 100644 --- a/src/lcm/exceptions.py +++ b/src/lcm/exceptions.py @@ -22,6 +22,15 @@ class InvalidRegimeTransitionProbabilitiesError(PyLCMError): """Raised when the regime transition probabilities are invalid.""" +class InvalidStateTransitionProbabilitiesError(PyLCMError): + """Raised when a stochastic state transition produces invalid probabilities. + + Covers a `MarkovTransition` function whose output has the wrong outcome-axis + size, values outside [0, 1], rows that don't sum to 1, or `probs_array[…]` + subscripts that don't match the signature parameter order. + """ + + class InvalidInitialConditionsError(PyLCMError): """Raised when the initial conditions (states or regimes) are invalid.""" diff --git a/src/lcm/grids/base.py b/src/lcm/grids/base.py index 2c7dde502..f8a6559c3 100644 --- a/src/lcm/grids/base.py +++ b/src/lcm/grids/base.py @@ -1,8 +1,38 @@ from abc import ABC, abstractmethod +from lcm.exceptions import GridInitializationError from lcm.typing import Float1D, Int1D +def _fail_if_batch_size_combined_with_distributed( + *, + batch_size: int, + distributed: bool, +) -> None: + """Reject `batch_size > 0` paired with `distributed=True` on one axis. + + Each Python-level batch is its own `jax.jit` dispatch in the solve + loop, and on a distributed axis every dispatch carries a cross-device + collective. Batching therefore multiplies the per-period collective + count by `ceil(n_per_device / batch_size)`; for small `batch_size` + the collective overhead per kernel dwarfs the compute per kernel and + sharding becomes a regression rather than a speedup. Reject the + combination at construction time so the foot-gun never reaches the + solve loop. + """ + if batch_size > 0 and distributed: + raise GridInitializationError( + f"`batch_size={batch_size}` is incompatible with " + "`distributed=True` on a single grid axis: every batch " + "triggers a per-period cross-device collective, multiplying " + "the synchronisation count by ceil(n_per_device / batch_size) " + "and inverting the compute/communication ratio. Use " + "`batch_size=0` (one kernel per period over the full " + "per-device chunk) and, if memory is tight, reduce the chunk " + "by adding devices or another distributed axis." + ) + + class Grid(ABC): """LCM Grid base class.""" diff --git a/src/lcm/grids/continuous.py b/src/lcm/grids/continuous.py index f3bff0c31..cb7981619 100644 --- a/src/lcm/grids/continuous.py +++ b/src/lcm/grids/continuous.py @@ -11,7 +11,7 @@ from lcm.dtypes import canonical_float_dtype from lcm.exceptions import GridInitializationError, format_messages from lcm.grids import coordinates as grid_coordinates -from lcm.grids.base import Grid +from lcm.grids.base import Grid, _fail_if_batch_size_combined_with_distributed from lcm.typing import ( Float1D, FloatND, @@ -194,6 +194,9 @@ def _init_uniform_grid( `ScalarInt` types and only check value invariants (finiteness, ordering, positivity). """ + _fail_if_batch_size_combined_with_distributed( + batch_size=batch_size, distributed=distributed + ) dtype = canonical_float_dtype() start_jax = jnp.asarray(start, dtype=dtype) stop_jax = jnp.asarray(stop, dtype=dtype) @@ -245,6 +248,9 @@ def __init__( batch_size: int = 0, distributed: bool = False, ) -> None: + _fail_if_batch_size_combined_with_distributed( + batch_size=batch_size, distributed=distributed + ) if points is not None: _validate_irreg_spaced_grid(points) derived_n = len(points) diff --git a/src/lcm/grids/discrete.py b/src/lcm/grids/discrete.py index 0852867a0..5856c9981 100644 --- a/src/lcm/grids/discrete.py +++ b/src/lcm/grids/discrete.py @@ -2,7 +2,7 @@ from beartype import beartype from lcm._beartype_conf import GRID_CONF -from lcm.grids.base import Grid +from lcm.grids.base import Grid, _fail_if_batch_size_combined_with_distributed from lcm.grids.categorical import _validate_discrete_grid from lcm.typing import Int1D from lcm.utils.containers import get_field_names_and_values @@ -25,6 +25,9 @@ class DiscreteGrid(Grid): def __init__( self, category_class: type, batch_size: int = 0, *, distributed: bool = False ) -> None: + _fail_if_batch_size_combined_with_distributed( + batch_size=batch_size, distributed=distributed + ) _validate_discrete_grid(category_class) names_and_values = get_field_names_and_values(category_class) self.__categories = tuple(names_and_values.keys()) diff --git a/src/lcm/interfaces.py b/src/lcm/interfaces.py index fe0730024..9295f8a99 100644 --- a/src/lcm/interfaces.py +++ b/src/lcm/interfaces.py @@ -217,6 +217,38 @@ class SimulateFunctions: """Compiled function to compute next-period states.""" +@dataclasses.dataclass(frozen=True) +class _StochasticStateTransition: + """Metadata for a stochastic state transition, used by automatic validation. + + One entry exists for every `MarkovTransition` state — and for each target + of a per-target dict. The pre-solve state-transition validator consumes + these to evaluate the function on the regime's grid Cartesian product and + check that the output has the expected outcome-axis size, lies in [0, 1], + and has rows summing to 1. + """ + + func: Callable[..., FloatND] + """The `MarkovTransition`'s wrapped function.""" + + state_name: StateName + """Name of the state being transitioned.""" + + target_regime_name: RegimeName | None + """Target regime for per-target dicts; `None` for a plain `MarkovTransition`.""" + + n_outcomes: int + """Size of the outcome axis (always the last axis of the function output).""" + + indexing_params: tuple[str, ...] + """Parameters used to index `probs_array`, in subscript order. + + Derived statically at process time from the function's AST. Empty + when the function doesn't use the `probs_array[...]` pattern, in + which case the AST subscript-order check is permissively skipped. + """ + + @dataclasses.dataclass(frozen=True, kw_only=True) class Regime: """Canonical regime produced by `process_regimes` from a user-facing `Regime`. @@ -249,6 +281,17 @@ class Regime: simulate_functions: SimulateFunctions """Compiled functions for the forward-simulation phase.""" + stochastic_state_transitions: MappingProxyType[ + TransitionFunctionName, _StochasticStateTransition + ] + """Immutable mapping of qualified transition name to validation metadata. + + Populated for every `MarkovTransition` state transition. Per-target + dict entries appear under qualified names like `next_health__working`. + Empty for terminal regimes and for regimes whose state transitions + are all deterministic. + """ + _base_state_action_space: StateActionSpace = dataclasses.field(repr=False) """Base state-action space before runtime grid substitution.""" diff --git a/src/lcm/model.py b/src/lcm/model.py index 2a96e5901..da6dd7721 100644 --- a/src/lcm/model.py +++ b/src/lcm/model.py @@ -7,13 +7,21 @@ from pathlib import Path from types import MappingProxyType +import jax.numpy as jnp import pandas as pd from beartype import beartype from lcm._beartype_conf import MODEL_CONF, PARAMS_CONF -from lcm._transition_checks import validate_regime_transitions_all_periods +from lcm._transition_checks import ( + validate_regime_transitions_all_periods, + validate_state_transitions_all_periods, +) from lcm.ages import AgeGrid -from lcm.exceptions import InvalidValueFunctionError, ModelInitializationError +from lcm.exceptions import ( + InvalidInitialConditionsError, + InvalidValueFunctionError, + ModelInitializationError, +) from lcm.grids import DiscreteGrid from lcm.model_processing import ( _validate_param_types, @@ -59,7 +67,13 @@ ensure_containers_are_mutable, get_field_names_and_values, ) -from lcm.utils.logging import LogLevel, get_logger +from lcm.utils.logging import ( + LogLevel, + get_logger, + raise_or_warn, + validation_enabled, + validation_raises, +) class Model: @@ -256,8 +270,8 @@ def solve( self, *, params: UserParams, + log_level: LogLevel, max_compilation_workers: int | None = None, - log_level: LogLevel = "progress", log_path: str | Path | None = None, log_keep_n_latest: int = 3, ) -> PeriodToRegimeToVArr: @@ -274,31 +288,46 @@ def solve( specification Values may be `pd.Series` with labeled indices; they are auto-converted to JAX arrays. + log_level: Verbosity, and the runtime-validation policy it implies. + Required — pick deliberately for the situation: + - `"off"` — silent; transition-probability and NaN checks skipped. + - `"warning"` — validation runs, failures logged as warnings, + the run continues. + - `"progress"` — as `"warning"`, plus timing. + - `"debug"` — validation runs and **raises** on the first + failure; adds value-function stats. + Start every project at `"debug"`: fail early and gather maximum + diagnostics. Ease to `"warning"` / `"off"` only once the model + is trusted and you need the speed or the non-raising behaviour + for an estimation loop. max_compilation_workers: Maximum number of threads for parallel XLA compilation. Defaults to the number of physical CPU cores. - log_level: Logging verbosity. `"off"` suppresses output, `"warning"` shows - NaN/Inf warnings, `"progress"` adds timing, `"debug"` adds stats and - requires `log_path`. - log_path: Directory for persisting debug snapshots. Required when - `log_level="debug"`. - log_keep_n_latest: Maximum number of debug snapshots to keep on disk. + log_path: Directory for persisting diagnostic snapshots. Optional at + every level; snapshots are written only when it is set. + log_keep_n_latest: Maximum number of snapshots to retain on disk. Returns: Immutable mapping of period to a value function array for each regime. """ - _validate_log_args(log_level=log_level, log_path=log_path) + log = get_logger(log_level=log_level) flat_params = self._process_params(params) validate_regime_transitions_all_periods( regimes=self.regimes, flat_params=flat_params, ages=self.ages, + logger=log, + ) + validate_state_transitions_all_periods( + regimes=self.regimes, + flat_params=flat_params, + ages=self.ages, + logger=log, ) return self._solve_compiled( flat_params=flat_params, params=params, - log=get_logger(log_level=log_level), - log_level=log_level, + log=log, log_path=log_path, log_keep_n_latest=log_keep_n_latest, max_compilation_workers=max_compilation_workers, @@ -310,12 +339,17 @@ def _solve_compiled( flat_params: FlatParams, params: UserParams, log: logging.Logger, - log_level: LogLevel, log_path: str | Path | None, log_keep_n_latest: int, max_compilation_workers: int | None, ) -> PeriodToRegimeToVArr: - """Run backward induction, persisting a snapshot on debug or NaN failure.""" + """Run backward induction, persisting a diagnostic snapshot when warranted. + + With `log_path` set, a snapshot is written at `log_level="debug"` + (every solve) and at `"warning"` / `"progress"` whenever the returned + solution contains NaN. `_enforce_retention` caps the snapshot count at + `log_keep_n_latest`. + """ try: period_to_regime_to_V_arr = solve( flat_params=flat_params, @@ -336,7 +370,11 @@ def _solve_compiled( ) exc.add_note(f"Snapshot saved to {snap_dir}") raise - if log_level == "debug" and log_path is not None: + if ( + log_path is not None + and validation_enabled(log) + and (validation_raises(log) or _contains_nan(period_to_regime_to_V_arr)) + ): save_solve_snapshot( model=self, params=params, @@ -388,9 +426,8 @@ def simulate( params: UserParams, initial_conditions: UserInitialConditions | pd.DataFrame, period_to_regime_to_V_arr: PeriodToRegimeToVArr | None, - check_initial_conditions: bool = True, + log_level: LogLevel, seed: int | None = None, - log_level: LogLevel = "progress", log_path: str | Path | None = None, log_keep_n_latest: int = 3, max_compilation_workers: int | None = None, @@ -420,14 +457,23 @@ def simulate( (auto-converted via `initial_conditions_from_dataframe`). period_to_regime_to_V_arr: Value function arrays from `solve()`. When `None`, the model is solved automatically before simulating. - check_initial_conditions: Whether to validate initial conditions. seed: Random seed. - log_level: Logging verbosity. `"off"` suppresses output, `"warning"` shows - NaN/Inf warnings, `"progress"` adds timing, `"debug"` adds stats and - requires `log_path`. - log_path: Directory for persisting debug snapshots. Required when - `log_level="debug"`. - log_keep_n_latest: Maximum number of debug snapshots to keep on disk. + log_level: Verbosity, and the runtime-validation policy it implies. + Required — pick deliberately for the situation: + - `"off"` — silent; initial-condition, transition-probability, + and NaN checks skipped. + - `"warning"` — validation runs, failures logged as warnings, + the run continues. + - `"progress"` — as `"warning"`, plus timing. + - `"debug"` — validation runs and **raises** on the first + failure; adds value-function stats. + Start every project at `"debug"`: fail early and gather maximum + diagnostics. Ease to `"warning"` / `"off"` only once the model + is trusted and you need the speed or the non-raising behaviour + for an estimation loop. + log_path: Directory for persisting diagnostic snapshots. Optional at + every level; snapshots are written only when it is set. + log_keep_n_latest: Maximum number of snapshots to retain on disk. max_compilation_workers: Maximum number of threads for parallel XLA compilation. Only used when `period_to_regime_to_V_arr` is `None` (i.e. when solve runs automatically). Defaults to the number of @@ -438,7 +484,7 @@ def simulate( optionally with additional_targets. """ - _validate_log_args(log_level=log_level, log_path=log_path) + log = get_logger(log_level=log_level) if isinstance(initial_conditions, pd.DataFrame): initial_conditions = initial_conditions_from_dataframe( df=initial_conditions, @@ -450,20 +496,29 @@ def simulate( regimes=self.regimes, ) flat_params = self._process_params(params) - if check_initial_conditions: - validate_initial_conditions( - initial_conditions=initial_conditions, - regimes=self.regimes, - regime_names_to_ids=self.regime_names_to_ids, - flat_params=flat_params, - ages=self.ages, - ) + if validation_enabled(log): + try: + validate_initial_conditions( + initial_conditions=initial_conditions, + regimes=self.regimes, + regime_names_to_ids=self.regime_names_to_ids, + flat_params=flat_params, + ages=self.ages, + ) + except InvalidInitialConditionsError as error: + raise_or_warn(logger=log, error=error) validate_regime_transitions_all_periods( regimes=self.regimes, flat_params=flat_params, ages=self.ages, + logger=log, + ) + validate_state_transitions_all_periods( + regimes=self.regimes, + flat_params=flat_params, + ages=self.ages, + logger=log, ) - log = get_logger(log_level=log_level) actual_n_subjects = len(next(iter(initial_conditions.values()))) n_subjects = self.n_subjects if n_subjects is not None and n_subjects == actual_n_subjects: @@ -485,7 +540,6 @@ def simulate( flat_params=flat_params, params=params, log=log, - log_level=log_level, log_path=log_path, log_keep_n_latest=log_keep_n_latest, max_compilation_workers=max_compilation_workers, @@ -512,7 +566,7 @@ def simulate( # the lazy regimes to keep the result cloudpickle-safe. if simulate_regimes is not self.regimes: result._regimes = self.regimes # noqa: SLF001 - if log_level == "debug" and log_path is not None: + if log_path is not None and validation_raises(log): save_simulate_snapshot( model=self, params=params, @@ -588,8 +642,10 @@ def _merge_derived_categoricals( return MappingProxyType(result) -def _validate_log_args(*, log_level: LogLevel, log_path: str | Path | None) -> None: - """Raise ValueError if log_level='debug' but log_path is not set.""" - if log_level == "debug" and log_path is None: - msg = "log_path is required when log_level='debug'" - raise ValueError(msg) +def _contains_nan(period_to_regime_to_V_arr: PeriodToRegimeToVArr) -> bool: + """Return whether any value function array holds a NaN.""" + return any( + bool(jnp.any(jnp.isnan(V_arr))) + for regime_to_V_arr in period_to_regime_to_V_arr.values() + for V_arr in regime_to_V_arr.values() + ) diff --git a/src/lcm/persistence.py b/src/lcm/persistence.py index 818185df9..e3534be15 100644 --- a/src/lcm/persistence.py +++ b/src/lcm/persistence.py @@ -435,26 +435,45 @@ def _write_environment_files(snap_dir: Path) -> None: (snap_dir / "REPRODUCE.md").write_text(reproduce_md) +def _snapshot_counter(entry: Path, prefix: str) -> int: + """Parse the numeric counter suffix of a snapshot directory. + + Returns `-1` for a name that does not end in an integer, so callers can + skip foreign directories rather than mis-order them. + """ + try: + return int(entry.name.removeprefix(f"{prefix}_")) + except ValueError: + return -1 + + def _next_counter(parent_path: Path, prefix: str) -> int: """Compute the next monotonic counter for snapshot directories with given prefix.""" - existing = sorted(parent_path.glob(f"{prefix}_*/")) - if not existing: - return 1 - counters: list[int] = [] - for entry in existing: - try: - counters.append(int(entry.name.rsplit("_", 1)[1])) - except IndexError, ValueError: - continue + counters = [ + counter + for entry in parent_path.glob(f"{prefix}_*/") + if (counter := _snapshot_counter(entry, prefix)) >= 0 + ] return max(counters, default=0) + 1 def _enforce_retention(parent_path: Path, prefix: str, *, keep_n_latest: int) -> None: - """Delete oldest snapshot directories so that at most keep_n_latest remain.""" - existing = sorted(parent_path.glob(f"{prefix}_*/")) - if len(existing) > keep_n_latest: - for snap_dir in existing[: len(existing) - keep_n_latest]: - shutil.rmtree(snap_dir) + """Delete oldest snapshot directories so that at most keep_n_latest remain. + + Directories are ordered by their parsed integer counter, not by name, so + retention stays correct once the counter grows past the zero-padded width + (e.g. `snapshot_1000` is newer than `snapshot_999`). + """ + existing = sorted( + ( + entry + for entry in parent_path.glob(f"{prefix}_*/") + if _snapshot_counter(entry, prefix) >= 0 + ), + key=lambda entry: _snapshot_counter(entry, prefix), + ) + for snap_dir in existing[: max(0, len(existing) - keep_n_latest)]: + shutil.rmtree(snap_dir) def atomic_dump(obj: object, path: str | Path, *, protocol: int) -> Path: diff --git a/src/lcm/regime_building/processing.py b/src/lcm/regime_building/processing.py index b1179ea1e..09c17732d 100644 --- a/src/lcm/regime_building/processing.py +++ b/src/lcm/regime_building/processing.py @@ -36,7 +36,10 @@ get_Q_and_F, get_Q_and_F_terminal, ) -from lcm.regime_building.transitions import collect_state_transitions +from lcm.regime_building.transitions import ( + collect_state_transitions, + collect_stochastic_state_transitions, +) from lcm.regime_building.V import VInterpolationInfo, create_v_interpolation_info from lcm.shocks import _ShockGrid from lcm.state_action_space import create_state_action_space @@ -191,6 +194,11 @@ def process_regimes( solve_compute_regime_transition_probs=solve_functions.compute_regime_transition_probs, ) + stochastic_state_transitions = collect_stochastic_state_transitions( + user_regime=user_regime, + user_regimes=user_regimes, + ) + canonical_regimes[regime_name] = Regime( name=regime_name, terminal=user_regime.terminal, @@ -200,6 +208,7 @@ def process_regimes( regime_params_template=regime_params_template, solve_functions=solve_functions, simulate_functions=simulate_functions, + stochastic_state_transitions=stochastic_state_transitions, _base_state_action_space=state_action_spaces[regime_name], ) diff --git a/src/lcm/regime_building/transitions.py b/src/lcm/regime_building/transitions.py index f1825edcd..a5a77573e 100644 --- a/src/lcm/regime_building/transitions.py +++ b/src/lcm/regime_building/transitions.py @@ -1,12 +1,28 @@ -"""State transition collection from user-facing `state_transitions` mappings.""" +"""Collect state transitions from user-facing `state_transitions` mappings. +Two collectors, both walking a regime's `state_transitions`: + +- `collect_state_transitions` — every state's transition *function* (bare + callable, `MarkovTransition`, an identity for `None`, per-target dict). +- `collect_stochastic_state_transitions` — validation *metadata* for the + `MarkovTransition` entries only: a process-time AST subscript-order check + plus `n_outcomes` derivation, cached on the canonical `Regime` and consumed + by the pre-solve state-transition validator. +""" + +import inspect from collections.abc import Callable, Mapping -from typing import TypeAliasType +from types import MappingProxyType +from typing import TypeAliasType, cast from dags.tree import QNAME_DELIMITER -from lcm.exceptions import RegimeInitializationError +from lcm.exceptions import ( + InvalidStateTransitionProbabilitiesError, + RegimeInitializationError, +) from lcm.grids import DiscreteGrid, Grid +from lcm.interfaces import _StochasticStateTransition from lcm.shocks._base import _ShockGrid from lcm.typing import ( ContinuousState, @@ -16,7 +32,9 @@ TransitionFunctionName, UserFunction, ) -from lcm.user_regime import _IdentityTransition +from lcm.user_regime import MarkovTransition, _IdentityTransition +from lcm.user_regime import Regime as UserRegime +from lcm.utils.ast_inspection import _get_func_indexing_params def collect_state_transitions( @@ -72,6 +90,67 @@ def collect_state_transitions( return transitions +def collect_stochastic_state_transitions( + *, + user_regime: UserRegime, + user_regimes: Mapping[RegimeName, UserRegime], +) -> MappingProxyType[TransitionFunctionName, _StochasticStateTransition]: + """Collect validation metadata for every `MarkovTransition` state transition. + + Walks `user_regime.state_transitions` and yields one entry per + `MarkovTransition`. Per-target dict entries are flattened into + `next_{state}__{target}` keys, mirroring the qname pattern used by + `collect_state_transitions`. Returns an empty mapping for regimes with + no stochastic state transitions (incl. terminal regimes). + + Args: + user_regime: User-facing regime to inspect. + user_regimes: All user regimes in the model. Needed to look up + `n_outcomes` for target-only states whose `DiscreteGrid` lives + on the target regime, not the source. + + Returns: + Immutable mapping of qualified transition name to validation + metadata. + + Raises: + InvalidStateTransitionProbabilitiesError: If a `MarkovTransition`'s + `probs_array` subscript order does not match the function's + signature parameter order. Permissively skipped when the + function does not use the `probs_array[...]` pattern. + + """ + entries: dict[TransitionFunctionName, _StochasticStateTransition] = {} + + for state_name, raw in user_regime.state_transitions.items(): + if isinstance(raw, MarkovTransition): + _add_stochastic_entry( + entries=entries, + key=f"next_{state_name}", + markov=raw, + state_name=state_name, + target_regime_name=None, + user_regime=user_regime, + user_regimes=user_regimes, + ) + elif isinstance(raw, Mapping): + for raw_target_name, target_value in raw.items(): + if not isinstance(target_value, MarkovTransition): + continue + target_regime_name: RegimeName = cast("RegimeName", raw_target_name) + _add_stochastic_entry( + entries=entries, + key=f"next_{state_name}__{target_regime_name}", + markov=target_value, + state_name=state_name, + target_regime_name=target_regime_name, + user_regime=user_regime, + user_regimes=user_regimes, + ) + + return MappingProxyType(entries) + + def _make_identity_fn( *, state_name: StateName, annotation: TypeAliasType ) -> _IdentityTransition: @@ -92,3 +171,109 @@ def _add_raw_transition( for target_name, target_value in raw.items(): key = f"next_{name}{QNAME_DELIMITER}{target_name}" transitions[key] = target_value + + +def _add_stochastic_entry( + *, + entries: dict[TransitionFunctionName, _StochasticStateTransition], + key: TransitionFunctionName, + markov: MarkovTransition, + state_name: str, + target_regime_name: RegimeName | None, + user_regime: UserRegime, + user_regimes: Mapping[RegimeName, UserRegime], +) -> None: + """Static-check one MarkovTransition and append its metadata.""" + func = markov.func + + state_grid = _find_state_grid( + state_name=state_name, + target_regime_name=target_regime_name, + user_regime=user_regime, + user_regimes=user_regimes, + ) + if not isinstance(state_grid, DiscreteGrid): + # `MarkovTransition` on a continuous state is not a supported + # pattern for the automatic validator. Static phase tolerates + # the omission; the runtime phase skips it by absence from the + # metadata. The subscript-order check is skipped too — it applies + # only to the discrete `probs_array[...]` pattern this validator + # covers. + return + + indexing_params = tuple( + _get_func_indexing_params(func=func, array_param_name="probs_array") + ) + _check_subscript_order( + func=func, indexing_params=indexing_params, state_name=state_name + ) + n_outcomes = len(state_grid.categories) + + entries[key] = _StochasticStateTransition( + func=func, + state_name=state_name, + target_regime_name=target_regime_name, + n_outcomes=n_outcomes, + indexing_params=indexing_params, + ) + + +def _find_state_grid( + *, + state_name: str, + target_regime_name: RegimeName | None, + user_regime: UserRegime, + user_regimes: Mapping[RegimeName, UserRegime], +) -> object: + """Look up the state's grid for outcome-axis sizing. + + For a per-target dict entry the **target** regime's grid is authoritative: + the `MarkovTransition` returns a distribution over the target's state + space, which may differ in size from the source's (cross-grid + transitions). The source grid is never substituted in that case — if the + target regime does not declare the state, `None` is returned so the + caller skips metadata creation rather than sizing off a wrong grid. + + A plain `MarkovTransition` (no per-target dict) sizes off the source + regime's grid. + + Returns `None` when no authoritative grid is found. + """ + if target_regime_name is not None: + target = user_regimes.get(target_regime_name) + if target is not None and state_name in target.states: + return target.states[state_name] + return None + if state_name in user_regime.states: + return user_regime.states[state_name] + return None + + +def _check_subscript_order( + *, + func: object, + indexing_params: tuple[str, ...], + state_name: str, +) -> None: + """Raise if `probs_array[…]` subscripts don't match signature order. + + Permissive: when the function doesn't use the `probs_array[...]` + pattern (`indexing_params` is empty), the check silently skips. The + runtime numerical checks still cover such functions. + """ + if not indexing_params: + return + sig = inspect.signature(func) # ty: ignore[invalid-argument-type] + sig_order = tuple( + p for p in sig.parameters if p != "probs_array" and p in indexing_params + ) + if indexing_params != sig_order: + func_name = getattr(func, "__name__", "") + msg = ( + f"In MarkovTransition for state '{state_name}', function " + f"'{func_name}' indexes `probs_array` as " + f"`probs_array[{', '.join(indexing_params)}]` but the signature " + f"order is `probs_array[{', '.join(sig_order)}]`. Swap the " + f"subscript order or the signature so they match." + ) + raise InvalidStateTransitionProbabilitiesError(msg) diff --git a/src/lcm/solution/solve_brute.py b/src/lcm/solution/solve_brute.py index c4dcbb9ab..d8651c495 100644 --- a/src/lcm/solution/solve_brute.py +++ b/src/lcm/solution/solve_brute.py @@ -11,6 +11,7 @@ import jax.numpy as jnp from lcm.ages import AgeGrid +from lcm.exceptions import InvalidValueFunctionError from lcm.interfaces import Regime, _build_regime_sharding from lcm.solution.validate_V import validate_V from lcm.typing import BoolND, FlatParams, FloatND, RegimeName, StateName @@ -18,6 +19,9 @@ format_duration, log_period_header, log_period_timing, + raise_or_warn, + validation_enabled, + validation_raises, ) @@ -37,7 +41,11 @@ def solve( ages: Age grid for the model. regimes: The internal regimes, that contain all necessary functions to solve the model. - logger: Logger that logs to stdout. + logger: Logger that logs to stdout, and carries the runtime-validation + policy. `log_level="debug"` stops backward induction at the first + NaN period and raises; `"warning"` / `"progress"` let induction run + to completion and log a warning, so `solve` returns a complete + (NaN-bearing) solution; `"off"` skips the NaN check. enable_jit: Whether to JIT-compile the functions of the internal regimes. max_compilation_workers: Maximum number of threads for parallel XLA compilation. Defaults to `os.cpu_count()`. @@ -97,11 +105,13 @@ def solve( # localisation. On a healthy solve no per-row materialisation # happens. # - # Gate falls out of the public log level: - # - `"off"` ⇒ nothing (skips even the NaN fail-fast) - # - `"warning"` / `"progress"` ⇒ NaN/Inf only - # - `"debug"` ⇒ adds the min/max/mean trio - diagnostics_enabled = logger.isEnabledFor(logging.WARNING) + # Two gates, both falling out of the public log level: + # - NaN/Inf tracking feeds runtime validation, so it runs whenever + # validation is not `"off"` (log levels `"warning"`/`"progress"`/ + # `"debug"`). It skips even the NaN fail-fast when validation is off. + # - The min/max/mean trio is a pure logging extra, gated on the + # logger's debug level. + diagnostics_enabled = validation_enabled(logger) stats_enabled = logger.isEnabledFor(logging.DEBUG) diagnostic_rows: list[_DiagnosticRow] = [] diagnostic_min: list[FloatND] = [] @@ -206,26 +216,33 @@ def solve( # Fail-fast on NaN: surface the offending period immediately # instead of finishing the whole backward induction. Costs one # host transfer of a scalar bool per period — negligible next - # to the per-period `max_Q_over_a` kernel, and only paid when - # diagnostics are on. Inf is non-fatal so we don't break on - # it; the post-loop emitter still raises a warning if any - # period flagged Inf. - if diagnostics_enabled and running_any_nan.item(): + # to the per-period `max_Q_over_a` kernel. Inf is non-fatal so + # we don't break on it; the post-loop emitter still raises a + # warning if any period flagged Inf. + # + # Only raise mode fails fast. Raise mode is the loudest level, so + # diagnostics are on and `running_any_nan` has been tracked. In warn + # mode induction runs to completion so `solve` returns a complete + # (NaN-bearing) solution rather than a truncated one. + if validation_raises(logger) and running_any_nan.item(): break if diagnostics_enabled: - _emit_post_loop_diagnostics( - logger=logger, - diagnostic_rows=diagnostic_rows, - solution=MappingProxyType(solution), - regimes=regimes, - flat_params=flat_params, - running_any_nan=running_any_nan, - running_any_inf=running_any_inf, - diagnostic_min=diagnostic_min if stats_enabled else None, - diagnostic_max=diagnostic_max if stats_enabled else None, - diagnostic_mean=diagnostic_mean if stats_enabled else None, - ) + try: + _emit_post_loop_diagnostics( + logger=logger, + diagnostic_rows=diagnostic_rows, + solution=MappingProxyType(solution), + regimes=regimes, + flat_params=flat_params, + running_any_nan=running_any_nan, + running_any_inf=running_any_inf, + diagnostic_min=diagnostic_min if stats_enabled else None, + diagnostic_max=diagnostic_max if stats_enabled else None, + diagnostic_mean=diagnostic_mean if stats_enabled else None, + ) + except InvalidValueFunctionError as error: + raise_or_warn(logger=logger, error=error) total_elapsed = time.monotonic() - total_start logger.info("Solution complete (%s)", format_duration(seconds=total_elapsed)) diff --git a/src/lcm/user_regime.py b/src/lcm/user_regime.py index 619dce60d..4a0e75ec6 100644 --- a/src/lcm/user_regime.py +++ b/src/lcm/user_regime.py @@ -5,9 +5,8 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Literal, TypeAliasType, cast, overload +from typing import Any, Literal, TypeAliasType, cast, overload -import jax.numpy as jnp from beartype import beartype from dags.tree import QNAME_DELIMITER @@ -28,20 +27,10 @@ StateName, UserFunction, ) -from lcm.utils.ast_inspection import _get_func_indexing_params from lcm.utils.containers import ( ensure_containers_are_immutable, ) -# Genuine circular import: model.py imports from this module at module level. -# The `model` parameter of `validate_transition_probs` is annotated with the -# fully-qualified `lcm.model.Model` so the beartype claw resolves it by -# importing `lcm.model` at first call — long after the import cycle settles — -# rather than at module-init time. Importing `lcm.model` here keeps `lcm` a -# bound name for the type checker. -if TYPE_CHECKING: - import lcm.model - @beartype(conf=REGIME_CONF) @dataclass(frozen=True) @@ -584,166 +573,3 @@ def _validate_per_target_dict( f"MarkovTransition or none are.", ) return error_messages - - -def validate_transition_probs( - *, - probs: FloatND, - model: lcm.model.Model, - regime_name: RegimeName, - state_name: StateName, - target_regime_name: RegimeName | None = None, -) -> None: - """Validate a state transition probability array. - - Check that the array has the shape expected from the function signature, - that all values are in [0, 1], that rows sum to 1, and that the function's - `probs_array[…]` subscripts match the signature parameter order. - - For per-target state transitions (where `state_transitions[state_name]` is - a dict mapping target regime names to `MarkovTransition` instances), pass - `target_regime_name` to select the specific transition to validate. - - Regime transition probabilities are validated automatically before solve - via `validate_regime_transitions_all_periods` in - `lcm/_transition_checks.py`; this helper covers only state transitions. - - Args: - probs: The transition probability array to validate. - model: The LCM Model instance. - regime_name: Name of the regime. - state_name: Name of the state with a `MarkovTransition`. - target_regime_name: Target regime name for per-target state - transitions. Required when the state transition is a per-target - dict. - - Raises: - TypeError: If the transition is not a `MarkovTransition`. - ValueError: If the shape is wrong, values are outside [0, 1], or rows - don't sum to 1. - - """ - regime = model.user_regimes[regime_name] - raw_transition = regime.state_transitions[state_name] - markov = _extract_markov_transition( - raw_transition=raw_transition, - state_name=state_name, - regime_name=regime_name, - target_regime_name=target_regime_name, - ) - func = markov.func - grids = _build_grids(regime) - n_outcomes = len(grids[state_name].categories) - - indexing_params = _get_func_indexing_params( - func=func, array_param_name="probs_array" - ) - - sig = inspect.signature(func) - sig_order = [ - p for p in sig.parameters if p != "probs_array" and p in indexing_params - ] - if indexing_params != sig_order: - func_name = getattr(func, "__name__", "") - msg = ( - f"In function '{func_name}', `probs_array` is indexed as " - f"`probs_array[{', '.join(indexing_params)}]` but the signature " - f"order is `probs_array[{', '.join(sig_order)}]`." - ) - raise ValueError(msg) - - expected_shape = _build_expected_shape( - indexing_params=indexing_params, - n_outcomes=n_outcomes, - grids=grids, - model=model, - ) - - if probs.shape != expected_shape: - msg = f"Expected shape {expected_shape} but got {probs.shape}." - raise ValueError(msg) - - if jnp.any(probs < 0) or jnp.any(probs > 1): - msg = "All values must be in [0, 1]." - raise ValueError(msg) - - row_sums = jnp.sum(probs, axis=-1) - if not jnp.allclose(row_sums, 1.0, atol=1e-6): - msg = "Rows must sum to 1 along the last axis." - raise ValueError(msg) - - -def _extract_markov_transition( - *, - raw_transition: object, - state_name: StateName, - regime_name: RegimeName, - target_regime_name: RegimeName | None, -) -> MarkovTransition: - """Extract a MarkovTransition from a raw transition, handling per-target dicts.""" - if isinstance(raw_transition, MarkovTransition): - return raw_transition - - if isinstance(raw_transition, Mapping): - if target_regime_name is None: - targets = sorted(raw_transition.keys()) - msg = ( - f"State '{state_name}' in regime '{regime_name}' uses per-target " - f"transitions. Pass target_regime_name to select one of: {targets}." - ) - raise TypeError(msg) - if target_regime_name not in raw_transition: - msg = ( - f"Target regime '{target_regime_name}' not found in per-target " - f"transitions for state '{state_name}' in regime '{regime_name}'. " - f"Available targets: {sorted(raw_transition.keys())}." - ) - raise ValueError(msg) - entry = raw_transition[target_regime_name] # ty: ignore[invalid-argument-type] - if not isinstance(entry, MarkovTransition): - msg = ( - f"Per-target transition for '{target_regime_name}' in state " - f"'{state_name}' of regime '{regime_name}' is not a " - f"MarkovTransition. Got {type(entry).__name__}." - ) - raise TypeError(msg) - return entry - - msg = ( - f"State '{state_name}' in regime '{regime_name}' is not a " - f"MarkovTransition. Got {type(raw_transition).__name__}." - ) - raise TypeError(msg) - - -def _build_grids(user_regime: Regime) -> dict[str, DiscreteGrid]: - """Collect all DiscreteGrid instances from regime states and actions.""" - return { - name: grid - for name, grid in (*user_regime.states.items(), *user_regime.actions.items()) - if isinstance(grid, DiscreteGrid) - } - - -def _build_expected_shape( - *, - indexing_params: list[str], - n_outcomes: int, - grids: dict[str, DiscreteGrid], - model: lcm.model.Model, -) -> tuple[int, ...]: - """Compute expected shape for a transition probability array.""" - shape: list[int] = [] - for param_name in indexing_params: - if param_name == "period": - shape.append(model.n_periods) - elif param_name in grids: - shape.append(len(grids[param_name].categories)) - else: - msg = ( - f"Cannot determine expected size for parameter '{param_name}'. " - f"It is not 'period' and not a DiscreteGrid state or action." - ) - raise ValueError(msg) - shape.append(n_outcomes) - return tuple(shape) diff --git a/src/lcm/utils/logging.py b/src/lcm/utils/logging.py index 37a33a40b..4d72d4dac 100644 --- a/src/lcm/utils/logging.py +++ b/src/lcm/utils/logging.py @@ -15,6 +15,46 @@ } +def validation_enabled(logger: logging.Logger) -> bool: + """Return whether runtime validation runs at all. + + Runtime validation runs unless `log_level="off"`. The logger's level is + the single source of truth for the runtime policy: `"off"` raises the + logger to `CRITICAL`, every other level keeps it at `WARNING` or lower. + """ + return logger.isEnabledFor(logging.WARNING) + + +def validation_raises(logger: logging.Logger) -> bool: + """Return whether a validation failure raises (vs. logs a warning). + + A failure raises at `log_level="debug"` and only warns at `"warning"` / + `"progress"`. `"debug"` is the one level that lowers the logger to + `DEBUG`, so `isEnabledFor(DEBUG)` is exactly the raise predicate. + """ + return logger.isEnabledFor(logging.DEBUG) + + +def raise_or_warn(*, logger: logging.Logger, error: Exception) -> None: + """Surface a validation failure according to the logger's policy. + + Raises the error when the logger implies raise mode (`log_level="debug"`); + otherwise logs it as a warning and returns so the run continues. Must not + be called when validation is disabled (`log_level="off"`). + + Args: + logger: Logger carrying the runtime-validation policy. + error: The validation error to raise or describe. + + Raises: + Exception: The passed `error`, in raise mode. + + """ + if validation_raises(logger): + raise error + logger.warning("%s", error) + + def get_logger(*, log_level: LogLevel) -> logging.Logger: """Get a logger that logs to stdout. diff --git a/tests/simulation/test_initial_conditions.py b/tests/simulation/test_initial_conditions.py index ccb8e874b..c78278ad2 100644 --- a/tests/simulation/test_initial_conditions.py +++ b/tests/simulation/test_initial_conditions.py @@ -310,6 +310,7 @@ def test_infeasible_initial_states_detected(): _working_life = model.regime_names_to_ids["working_life"] with pytest.raises(InvalidInitialConditionsError): model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0]), @@ -336,6 +337,7 @@ def test_on_grid_state_but_combination_infeasible(): _working_life = model.regime_names_to_ids["working_life"] with pytest.raises(InvalidInitialConditionsError): model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0]), @@ -357,6 +359,7 @@ def test_extrapolated_initial_states_accepted(): } _working_life = model.regime_names_to_ids["working_life"] model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0]), @@ -378,6 +381,7 @@ def test_on_grid_initial_states_accepted(): } _working_life = model.regime_names_to_ids["working_life"] model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0]), @@ -401,6 +405,7 @@ def test_irreg_spaced_grid_with_runtime_points(): _working_life = model.regime_names_to_ids["working_life"] with pytest.raises(InvalidInitialConditionsError): model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([0.3]), diff --git a/tests/simulation/test_simulate.py b/tests/simulation/test_simulate.py index 461357ed7..3f05782b2 100644 --- a/tests/simulation/test_simulate.py +++ b/tests/simulation/test_simulate.py @@ -126,7 +126,7 @@ def _model_solution(n_periods): ages=ages, regime_id_class=RegimeId, ) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) return period_to_regime_to_V_arr, params, model return _model_solution @@ -143,6 +143,7 @@ def test_simulate_using_model_methods( ) result = model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions={ @@ -197,6 +198,7 @@ def test_simulate_with_only_discrete_actions(): params = get_params(n_periods=3, wage=1.5, discount_factor=1.0, interest_rate=0.0) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([0.0, 2.0]), @@ -261,6 +263,7 @@ def test_effect_of_discount_factor_on_last_period(): df_low = ( model.simulate( + log_level="debug", params=params_low, initial_conditions={ "wealth": initial_wealth, @@ -275,6 +278,7 @@ def test_effect_of_discount_factor_on_last_period(): df_high = ( model.simulate( + log_level="debug", params=params_high, initial_conditions={ "wealth": initial_wealth, @@ -326,6 +330,7 @@ def test_effect_of_disutility_of_work(): df_low = ( model.simulate( + log_level="debug", params=params_low, initial_conditions={ "wealth": initial_wealth, @@ -340,6 +345,7 @@ def test_effect_of_disutility_of_work(): df_high = ( model.simulate( + log_level="debug", params=params_high, initial_conditions={ "wealth": initial_wealth, @@ -373,6 +379,7 @@ def test_to_dataframe_use_labels_parameter(): model = get_model(n_periods=3) params = get_params(n_periods=3) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([20.0, 50.0]), @@ -401,6 +408,7 @@ def regression_simulation_result(): model = get_model(n_periods=3) params = get_params(n_periods=3) return model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([20.0, 50.0]), @@ -444,6 +452,7 @@ def test_additional_targets_all_with_stochastic_transitions(): params = get_params(n_periods=3) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([20.0, 50.0]), @@ -484,6 +493,7 @@ def test_simulation_result_pickle_roundtrip(tmp_path: Path): model = get_model(n_periods=3) params = get_params(n_periods=3) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([20.0, 50.0]), diff --git a/tests/simulation/test_simulate_aot.py b/tests/simulation/test_simulate_aot.py index ee63aac17..d95d89703 100644 --- a/tests/simulation/test_simulate_aot.py +++ b/tests/simulation/test_simulate_aot.py @@ -76,6 +76,7 @@ def test_n_subjects_none_leaves_aot_cache_empty_after_simulate() -> None: params = get_params(n_periods=n_periods) model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=None, initial_conditions=_build_initial_conditions(n_subjects=4), @@ -91,6 +92,7 @@ def test_n_subjects_none_yields_simulate_result_sized_to_actual() -> None: params = get_params(n_periods=n_periods) result = model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=None, initial_conditions=_build_initial_conditions(n_subjects=4), @@ -107,7 +109,7 @@ def test_simulate_second_matching_call_does_not_invoke_compile( n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) counter = {"count": 0} original_compile = jax.stages.Lowered.compile @@ -123,6 +125,7 @@ def counting_compile( initial_conditions = _build_initial_conditions(n_subjects=n_subjects) model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, @@ -130,6 +133,7 @@ def counting_compile( counter["count"] = 0 model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, @@ -144,11 +148,12 @@ def test_simulate_first_matching_call_populates_aot_cache() -> None: n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) assert n_subjects not in model._simulate_compile_cache model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=_build_initial_conditions(n_subjects=n_subjects), @@ -168,7 +173,7 @@ def test_solve_does_not_populate_simulate_compile_cache() -> None: model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - model.solve(params=params) + model.solve(log_level="debug", params=params) assert dict(model._simulate_compile_cache) == {} @@ -193,10 +198,11 @@ def _mismatch_outcome( n_periods = 3 model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) params = get_params(n_periods=n_periods) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) with caplog.at_level(logging.WARNING, logger="lcm"): model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=_build_initial_conditions(n_subjects=_ACTUAL_N), @@ -247,16 +253,18 @@ def test_simulate_warns_only_once_per_mismatching_size( n_periods = 3 model = _build_test_model(n_periods=n_periods, n_subjects=_DECLARED_N) params = get_params(n_periods=n_periods) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) initial_conditions = _build_initial_conditions(n_subjects=_ACTUAL_N) with caplog.at_level(logging.WARNING, logger="lcm"): model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, ) model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, @@ -285,6 +293,7 @@ def test_simulate_result_pickles_when_n_subjects_matches() -> None: params = get_params(n_periods=n_periods) result = model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=None, initial_conditions=_build_initial_conditions(n_subjects=n_subjects), @@ -300,11 +309,12 @@ def test_unpickled_model_can_simulate_with_aot() -> None: n_subjects = 4 model = _build_test_model(n_periods=n_periods, n_subjects=n_subjects) params = get_params(n_periods=n_periods) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) initial_conditions = _build_initial_conditions(n_subjects=n_subjects) # Populate the AOT cache before pickling — confirms __getstate__ drops it. model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, @@ -320,6 +330,7 @@ def test_unpickled_model_can_simulate_with_aot() -> None: # Simulate works post-unpickle and re-populates the cache for that size. restored.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=period_to_regime_to_V_arr, initial_conditions=initial_conditions, diff --git a/tests/solution/test_beta_delta.py b/tests/solution/test_beta_delta.py index 7fc172646..f57cec5b9 100644 --- a/tests/solution/test_beta_delta.py +++ b/tests/solution/test_beta_delta.py @@ -196,6 +196,7 @@ def test_beta_delta_consumption(label, beta, delta): }, } result = model.simulate( + log_level="debug", params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, @@ -204,11 +205,12 @@ def test_beta_delta_consumption(label, beta, delta): model = _make_model() # Solve with exponential discounting (beta=1) solve_params = {"working": {"H": {"beta": 1.0, "delta": delta}}} - period_to_regime_to_V_arr = model.solve(params=solve_params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=solve_params) # Simulate with present-biased params sim_params = {"working": {"H": h_params}} result = model.simulate( + log_level="debug", params=sim_params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=period_to_regime_to_V_arr, @@ -216,6 +218,7 @@ def test_beta_delta_consumption(label, beta, delta): else: model = _make_model() result = model.simulate( + log_level="debug", params={"working": {"H": h_params}}, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, diff --git a/tests/solution/test_custom_aggregator.py b/tests/solution/test_custom_aggregator.py index db841ce04..309393637 100644 --- a/tests/solution/test_custom_aggregator.py +++ b/tests/solution/test_custom_aggregator.py @@ -195,8 +195,8 @@ def test_custom_ces_aggregator_differs_from_default(): "dead": {}, } - V_default = model_default.solve(params=params_default) - V_ces = model_ces.solve(params=params_ces) + V_default = model_default.solve(log_level="debug", params=params_default) + V_ces = model_ces.solve(log_level="debug", params=params_ces) # The value functions should differ because the aggregation rule differs has_difference = False @@ -285,8 +285,8 @@ def test_terminal_regime_value_unchanged_by_H(): "dead": {}, } - V_default = model_default.solve(params=params_default) - V_ces = model_ces.solve(params=params_ces) + V_default = model_default.solve(log_level="debug", params=params_default) + V_ces = model_ces.solve(log_level="debug", params=params_ces) # Last period is terminal — value functions should be identical last_period = max(V_default.keys()) @@ -332,7 +332,7 @@ def test_dag_output_feeds_default_h_monotone_in_discount_factor(): "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, }, } - V = model.solve(params=params) + V = model.solve(log_level="debug", params=params) # Pick a non-terminal period; slice each pref_type. non_terminal_periods = [p for p in V if p < max(V.keys())] @@ -381,6 +381,7 @@ def test_h_consumes_continuous_state(): "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, } V_zero = model.solve( + log_level="debug", params={ "working_life": { "H": {"discount_factor": 0.95, "wealth_weight": 0.0}, @@ -390,6 +391,7 @@ def test_h_consumes_continuous_state(): }, ) V_pos = model.solve( + log_level="debug", params={ "working_life": { "H": {"discount_factor": 0.95, "wealth_weight": 0.1}, @@ -429,6 +431,7 @@ def test_h_consumes_continuous_action(): "next_regime": {"final_age_alive": FINAL_AGE_ALIVE}, } V_zero = model.solve( + log_level="debug", params={ "working_life": { "H": {"discount_factor": 0.95, "action_weight": 0.0}, @@ -438,6 +441,7 @@ def test_h_consumes_continuous_action(): }, ) V_pos = model.solve( + log_level="debug", params={ "working_life": { "H": {"discount_factor": 0.95, "action_weight": 0.1}, @@ -475,6 +479,7 @@ def test_h_consumes_discrete_action(): """ model = _make_model(custom_H=labor_supply_H) V = model.solve( + log_level="debug", params={ "working_life": { "H": {"discount_factor": 0.95, "bonus": 0.1}, @@ -485,6 +490,7 @@ def test_h_consumes_discrete_action(): }, ) baseline = _make_model().solve( + log_level="debug", params={ "discount_factor": 0.95, "working_life": { @@ -518,6 +524,7 @@ def test_h_consumes_discrete_state(): """ model = _make_model(custom_H=pref_type_direct_H, with_pref_type=True) V = model.solve( + log_level="debug", params={ "discount_factor_by_type": jnp.array([0.70, 0.85, 0.99]), "working_life": { @@ -561,6 +568,7 @@ def test_h_consumes_flat_param_state_action_and_dag_output(): """ model = _make_model(custom_H=mixed_H, with_pref_type=True) V = model.solve( + log_level="debug", params={ "discount_factor_by_type": jnp.array([0.70, 0.85, 0.99]), "working_life": { diff --git a/tests/solution/test_diagnostics.py b/tests/solution/test_diagnostics.py index ab121ae1f..df9db506d 100644 --- a/tests/solution/test_diagnostics.py +++ b/tests/solution/test_diagnostics.py @@ -94,7 +94,18 @@ def test_nan_failure_raises_with_regime_and_age(): model = _make_model() params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} with pytest.raises(InvalidValueFunctionError, match=r"alive"): + model.solve(params=params, log_level="debug") + + +def test_nan_failure_warns_and_continues_at_warning_level( + caplog: pytest.LogCaptureFixture, +): + """At log_level="warning", a NaN value function logs a warning, no raise.""" + model = _make_model() + params = {**_HEALTHY_PARAMS, "discount_factor": float("nan")} + with caplog.at_level(logging.WARNING, logger="lcm"): model.solve(params=params, log_level="warning") + assert "NaN" in caplog.text def test_off_level_solves_without_diagnostics(caplog: pytest.LogCaptureFixture): diff --git a/tests/solution/test_solve_brute.py b/tests/solution/test_solve_brute.py index 9d8bf38be..780edfc8d 100644 --- a/tests/solution/test_solve_brute.py +++ b/tests/solution/test_solve_brute.py @@ -144,7 +144,7 @@ def _Q_and_F( flat_params=MappingProxyType({"default": flat_params}), ages=AgeGrid(start=0, stop=2, step="Y"), regimes=MappingProxyType({"default": regime}), - logger=get_logger(log_level="off"), + logger=get_logger(log_level="debug"), enable_jit=False, ) @@ -205,7 +205,7 @@ def _Q_and_F(a, c, b, d, next_regime_to_V_arr, period, age): # noqa: ARG001 flat_params=MappingProxyType({"default": MappingProxyType({})}), ages=AgeGrid(start=0, stop=2, step="Y"), regimes=MappingProxyType({"default": regime}), - logger=get_logger(log_level="off"), + logger=get_logger(log_level="debug"), enable_jit=False, ) diff --git a/tests/test_Q_and_F.py b/tests/test_Q_and_F.py index 648b5b0d8..4325f2463 100644 --- a/tests/test_Q_and_F.py +++ b/tests/test_Q_and_F.py @@ -358,7 +358,7 @@ def _next_regime(age: float) -> ScalarInt: ) model, params = _build_incomplete_target_model(next_regime_func=_next_regime) - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) for regime_to_V_arr in period_to_regime_to_V_arr.values(): for V_arr in regime_to_V_arr.values(): assert not jnp.any(jnp.isnan(V_arr)) diff --git a/tests/test_ages.py b/tests/test_ages.py index d80d1cc3f..246822208 100644 --- a/tests/test_ages.py +++ b/tests/test_ages.py @@ -146,6 +146,7 @@ def test_model_with_quarterly_steps(): # Solve and simulate result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([50.0, 100.0, 150.0]), @@ -289,6 +290,7 @@ def test_model_with_integer_ages(): } result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([50.0, 100.0, 150.0]), diff --git a/tests/test_analytical_solution.py b/tests/test_analytical_solution.py index 59e9a59ee..010e3861c 100644 --- a/tests/test_analytical_solution.py +++ b/tests/test_analytical_solution.py @@ -57,7 +57,7 @@ def test_analytical_solution(model_name, model_and_params): params = model_and_params["params"] period_to_regime_to_V_arr: dict[int, dict[str, FloatND]] = model.solve( - params=params + log_level="debug", params=params ) period_to_regime_to_V_arr_list = [ diff --git a/tests/test_beartype_claw.py b/tests/test_beartype_claw.py index 1efe400e7..4342b9660 100644 --- a/tests/test_beartype_claw.py +++ b/tests/test_beartype_claw.py @@ -32,7 +32,7 @@ RegimeInitializationError, ) from lcm.interfaces import _build_regime_sharding -from lcm.model import _validate_log_args +from lcm.model import _contains_nan from lcm.simulation.simulate import _compute_starting_periods from lcm.solution.solve_brute import _log_per_period_stats from lcm.state_action_space import _validate_all_states_present @@ -114,10 +114,7 @@ def test_regime_with_bad_arg_raises_project_exception() -> None: def test_claw_checks_lcm_model() -> None: """Type-violating arguments to internal `lcm.model` helpers raise.""" with pytest.raises(BeartypeCallHintViolation): - _validate_log_args( - log_level="progress", - log_path=123, # ty: ignore[invalid-argument-type] - ) + _contains_nan(123) # ty: ignore[invalid-argument-type] def test_model_with_bad_arg_raises_project_exception() -> None: diff --git a/tests/test_chained_state_transitions.py b/tests/test_chained_state_transitions.py index 461000fa9..15c7dce0d 100644 --- a/tests/test_chained_state_transitions.py +++ b/tests/test_chained_state_transitions.py @@ -86,7 +86,7 @@ def test_solve_with_chained_transitions_returns_finite_value_function() -> None: model = _build_model() params = {"discount_factor": 0.9, "final_age_alive": 1.0} - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) V_active = period_to_regime_to_V_arr[0]["active"] assert jnp.all(jnp.isfinite(V_active)) @@ -113,6 +113,7 @@ def test_simulate_with_chained_transitions_yields_expected_next_wealth() -> None df = ( model.simulate( + log_level="debug", params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 40b870b72..9d5186196 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -149,6 +149,7 @@ def test_solution_running_on_multiple_cpus(correct_distributed_model): """Test that distribution over multiple CPU's works for solution.""" period_to_regime_to_V_arr = correct_distributed_model.solve( + log_level="debug", params={"discount_factor": 0.95}, ) @@ -160,6 +161,7 @@ def test_simulation_running_on_multiple_cpus(correct_distributed_model): """Test that distribution over multiple CPU's works for simulation.""" res = correct_distributed_model.simulate( + log_level="debug", params={"discount_factor": 0.95}, initial_conditions={ "age": jnp.full(36, 0), @@ -185,6 +187,7 @@ def test_solution_error_if_grid_product_exceeds_devices(wrong_distributed_model) with pytest.raises(PyLCMError, match="must equal the number"): wrong_distributed_model.solve( + log_level="debug", params={"discount_factor": 0.95}, ) @@ -195,6 +198,7 @@ def test_simulation_error_if_not_multiple(correct_distributed_model): with pytest.raises(PyLCMError, match="multiple"): correct_distributed_model.simulate( + log_level="debug", params={"discount_factor": 0.95}, initial_conditions={ "age": jnp.full(5, 0), @@ -269,6 +273,7 @@ def test_solve_with_partial_distribution_returns_correct_shardings( undistributed regime's V-array carries no per-axis sharding (single device). """ period_to_regime_to_V_arr = partially_distributed_model.solve( + log_level="debug", params={"discount_factor": 0.95}, ) assert period_to_regime_to_V_arr[0]["working_life"].sharding.num_devices == 4 diff --git a/tests/test_economic_validation.py b/tests/test_economic_validation.py index 85ab313b1..d7c9f5c3b 100644 --- a/tests/test_economic_validation.py +++ b/tests/test_economic_validation.py @@ -24,6 +24,7 @@ def _simulate(shock_type, *, sigma, rho=0.0, mu=0.0): params = get_params(shock_type, sigma=sigma, mu=mu, rho=rho) unconditional_mean = mu / (1 - rho) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.full(_N_SUBJECTS, 5.0), diff --git a/tests/test_fgp_model.py b/tests/test_fgp_model.py index 301107819..a33d022f6 100644 --- a/tests/test_fgp_model.py +++ b/tests/test_fgp_model.py @@ -40,6 +40,7 @@ def _simulate(shock_type): discount_factor=BETA, ) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.full(_N_SUBJECTS, 5.0), @@ -72,7 +73,7 @@ def test_model_solves(shock_type): interest_rate=R, discount_factor=BETA, ) - V = model.solve(params=params) + V = model.solve(log_level="debug", params=params) assert V is not None # Value functions include all periods (n_periods + 1 ages from AgeGrid) assert len(V) == model.n_periods diff --git a/tests/test_float_dtype_invariants.py b/tests/test_float_dtype_invariants.py index 2aae6e6d8..cca7f9b7b 100644 --- a/tests/test_float_dtype_invariants.py +++ b/tests/test_float_dtype_invariants.py @@ -174,7 +174,10 @@ def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None): } result = model.simulate( - params=params, period_to_regime_to_V_arr=None, initial_conditions=initial + log_level="debug", + params=params, + period_to_regime_to_V_arr=None, + initial_conditions=initial, ) seen: dict[str, set] = {} @@ -189,7 +192,9 @@ def test_simulate_state_pool_dtype_stable_across_periods(x64_disabled: None): def test_solve_v_arrays_at_canonical_float_dtype(x64_disabled: None): """Every V-array returned by `model.solve()` is at `canonical_float_dtype()`.""" model = get_model(n_periods=3) - period_to_regime_to_V_arr = model.solve(params=get_params(n_periods=3)) + period_to_regime_to_V_arr = model.solve( + log_level="debug", params=get_params(n_periods=3) + ) target = canonical_float_dtype() wrong = { (period, regime_name): v_arr.dtype diff --git a/tests/test_grids.py b/tests/test_grids.py index 2db7c6e61..89c753257 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -812,3 +812,65 @@ def test_piecewise_log_spaced_grid_get_coordinate_with_array(): coords = grid.get_coordinate(values) expected = jnp.array([0.0, 1.0, 2.0]) aaae(coords, expected, decimal=DECIMAL_PRECISION) + + +@pytest.mark.parametrize( + "make_grid", + [ + pytest.param( + lambda **kw: LinSpacedGrid(start=1, stop=10, n_points=4, **kw), + id="LinSpacedGrid", + ), + pytest.param( + lambda **kw: LogSpacedGrid(start=1, stop=10, n_points=4, **kw), + id="LogSpacedGrid", + ), + pytest.param( + lambda **kw: IrregSpacedGrid(points=[1.0, 2.0, 3.0, 4.0], **kw), + id="IrregSpacedGrid", + ), + pytest.param( + lambda **kw: DiscreteGrid( + _make_dc("_BS", ("a", jnp.int32(0)), ("b", jnp.int32(1))), **kw + ), + id="DiscreteGrid", + ), + ], +) +def test_grid_rejects_batch_size_combined_with_distributed(make_grid): + """`batch_size > 0` and `distributed=True` on one axis is rejected at init. + + Each Python-level batch triggers its own per-period cross-device + collective in the sharded solve, so the combination multiplies the + sync count by `ceil(n_per_device / batch_size)` and inverts the + compute/communication ratio. Construction-time rejection prevents + the foot-gun. + """ + with pytest.raises(GridInitializationError, match="distributed=True"): + make_grid(batch_size=1, distributed=True) + + +@pytest.mark.parametrize( + "make_grid", + [ + pytest.param( + lambda: LinSpacedGrid( + start=1, stop=10, n_points=4, batch_size=0, distributed=True + ), + id="LinSpacedGrid", + ), + pytest.param( + lambda: DiscreteGrid( + _make_dc("_OK", ("a", jnp.int32(0)), ("b", jnp.int32(1))), + batch_size=0, + distributed=True, + ), + id="DiscreteGrid", + ), + ], +) +def test_grid_accepts_batch_size_zero_with_distributed(make_grid): + """`batch_size=0` with `distributed=True` is the canonical sharded setting.""" + grid = make_grid() + assert grid.distributed is True + assert grid.batch_size == 0 diff --git a/tests/test_heterogeneous_initial_ages.py b/tests/test_heterogeneous_initial_ages.py index 47cee4356..b5a451af4 100644 --- a/tests/test_heterogeneous_initial_ages.py +++ b/tests/test_heterogeneous_initial_ages.py @@ -18,6 +18,7 @@ def test_simulation_with_heterogeneous_initial_ages(): # Subject 0 starts at age 40, subject 1 starts at age 60 result = model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([40.0, 60.0]), diff --git a/tests/test_int_dtype_invariants.py b/tests/test_int_dtype_invariants.py index ee0e557c2..cbe773e6a 100644 --- a/tests/test_int_dtype_invariants.py +++ b/tests/test_int_dtype_invariants.py @@ -252,11 +252,13 @@ def test_simulate_accepts_int64_regime_initial_condition_and_round_trips() -> No } df_int32 = model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=None, initial_conditions=initial_conditions_int32, ).to_dataframe() df_int64 = model.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=None, initial_conditions=initial_conditions_int64, diff --git a/tests/test_invalid_vf.py b/tests/test_invalid_vf.py index 04854a1a2..e5673aa61 100644 --- a/tests/test_invalid_vf.py +++ b/tests/test_invalid_vf.py @@ -167,21 +167,21 @@ def test_solve_model_with_nan_value_function_array_raises_error( nan_value_model: Model, params: UserParams ) -> None: with pytest.raises(InvalidValueFunctionError): - nan_value_model.solve(params=params) + nan_value_model.solve(log_level="debug", params=params) def test_solve_model_with_inf_value_function_does_not_raise_error( inf_value_model: Model, params: UserParams ) -> None: # This should not raise an error - inf_value_model.solve(params=params) + inf_value_model.solve(log_level="debug", params=params) def test_simulate_model_with_nan_value_function_array_raises_error( nan_value_model: Model, params: UserParams ) -> None: initial_conditions = { - "wealth": jnp.array([0.9, 1.0]), + "wealth": jnp.array([1.5, 2.0]), "health": jnp.array([1.0, 1.0]), "age": jnp.array([0.0, 0.0]), "regime_id": jnp.array([RegimeId.non_terminal] * 2), @@ -189,10 +189,10 @@ def test_simulate_model_with_nan_value_function_array_raises_error( with pytest.raises(InvalidValueFunctionError): nan_value_model.simulate( + log_level="debug", params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, - check_initial_conditions=False, ) @@ -210,6 +210,7 @@ def test_simulate_model_with_inf_value_function_array_does_not_raise_error( # +inf utility term (wealth > 1.9 AND health > 0.9), but the simulation should # still complete without error. inf_value_model.simulate( + log_level="debug", params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, @@ -220,11 +221,11 @@ def test_nan_error_includes_regime_name( nan_value_model: Model, params: UserParams ) -> None: with pytest.raises(InvalidValueFunctionError, match="non_terminal"): - nan_value_model.solve(params=params) + nan_value_model.solve(log_level="debug", params=params) def test_nan_error_includes_nan_count( nan_value_model: Model, params: UserParams ) -> None: with pytest.raises(InvalidValueFunctionError, match=r"\d+ of \d+ values are NaN"): - nan_value_model.solve(params=params) + nan_value_model.solve(log_level="debug", params=params) diff --git a/tests/test_nan_diagnostics.py b/tests/test_nan_diagnostics.py index 2e8947d88..7b3610609 100644 --- a/tests/test_nan_diagnostics.py +++ b/tests/test_nan_diagnostics.py @@ -158,7 +158,7 @@ def test_nan_diagnostics_end_to_end() -> None: model, params = _build_nan_model() with pytest.raises(InvalidValueFunctionError) as exc_info: - model.solve(params=params) + model.solve(log_level="debug", params=params) exc = exc_info.value assert exc.diagnostics is not None, ( diff --git a/tests/test_pandas_utils.py b/tests/test_pandas_utils.py index c845a4c31..e0a5bbf70 100644 --- a/tests/test_pandas_utils.py +++ b/tests/test_pandas_utils.py @@ -391,6 +391,7 @@ def test_round_trip_with_discrete_model(): "regime_id": jnp.array([RegimeId.working_life, RegimeId.working_life]), } result_raw = model.simulate( + log_level="debug", params=params, initial_conditions=raw_conditions, period_to_regime_to_V_arr=None, @@ -410,6 +411,7 @@ def test_round_trip_with_discrete_model(): regime_names_to_ids=model.regime_names_to_ids, ) result_df = model.simulate( + log_level="debug", params=params, initial_conditions=df_conditions, period_to_regime_to_V_arr=None, @@ -732,6 +734,7 @@ def test_heterogeneous_health_solve_simulate() -> None: regime_names_to_ids=model.regime_names_to_ids, ) result = model.simulate( + log_level="debug", params={"bonus": 0.0, "discount_factor": 0.95}, initial_conditions=ic, period_to_regime_to_V_arr=None, @@ -771,6 +774,7 @@ def test_heterogeneous_health_simulate_use_labels_false() -> None: regime_names_to_ids=model.regime_names_to_ids, ) result = model.simulate( + log_level="debug", params={"bonus": 0.0, "discount_factor": 0.95}, initial_conditions=ic, period_to_regime_to_V_arr=None, diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 0c8de08b0..05057ff9b 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -235,10 +235,10 @@ def test_simulate_no_persistence_when_not_debug(tmp_path, model_and_params): assert len(list(tmp_path.iterdir())) == 0 -def test_debug_without_log_path_raises(model_and_params): +def test_debug_without_log_path_solves(model_and_params): + """`log_level="debug"` runs without `log_path` — it just writes no snapshot.""" model, params = model_and_params - with pytest.raises(ValueError, match="log_path is required"): - model.solve(params=params, log_level="debug") + model.solve(params=params, log_level="debug") def test_log_keep_n_latest_deletes_old_snapshots(tmp_path, model_and_params): diff --git a/tests/test_regime_state_mismatch.py b/tests/test_regime_state_mismatch.py index c5fadd564..f2f889479 100644 --- a/tests/test_regime_state_mismatch.py +++ b/tests/test_regime_state_mismatch.py @@ -184,8 +184,9 @@ def next_regime(age: float) -> ScalarInt: ) params = {"discount_factor": 0.95} - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0, 0.0]), @@ -276,8 +277,9 @@ def utility_dead(wealth: ContinuousState, heir_present: DiscreteState) -> FloatN ) params = {"discount_factor": 0.95} - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.array([0.0, 0.0]), @@ -350,7 +352,7 @@ def test_per_target_dict_transitions(): ) params = {"discount_factor": 0.95} - period_to_regime_to_V_arr = model.solve(params=params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=params) n_subjects = 4 # Use codes 0 (disabled) and 1 (bad) — valid in both regimes. @@ -365,6 +367,7 @@ def test_per_target_dict_transitions(): ) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "age": jnp.zeros(n_subjects), @@ -684,7 +687,7 @@ def next_regime_a(age: float) -> ScalarInt: InvalidRegimeTransitionProbabilitiesError, match=r"does not provide state transition", ): - model.solve(params={"discount_factor": 0.95}) + model.solve(log_level="debug", params={"discount_factor": 0.95}) def test_complete_per_target_stochastic_cross_grid() -> None: @@ -755,7 +758,7 @@ def next_regime_a(age: float) -> ScalarInt: ages=AgeGrid(start=0, stop=4, step="Y"), regime_id_class=_RegimeId, ) - model.solve(params={"discount_factor": 0.95}) + model.solve(log_level="debug", params={"discount_factor": 0.95}) def test_incomplete_per_target_unreachable_target() -> None: @@ -874,4 +877,4 @@ def next_regime_b(age: float) -> ScalarInt: ages=AgeGrid(start=0, stop=4, step="Y"), regime_id_class=_RegimeId, ) - model.solve(params={"discount_factor": 0.95}) + model.solve(log_level="debug", params={"discount_factor": 0.95}) diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 887f0edf2..532ebfa8f 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -53,8 +53,11 @@ def test_regression_test(): interest_rate=0.05, ) - got_solve: Mapping[int, Mapping[str, FloatND]] = model.solve(params=params) + got_solve: Mapping[int, Mapping[str, FloatND]] = model.solve( + log_level="debug", params=params + ) got_simulate = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([5.0, 20, 40, 70]), @@ -279,6 +282,7 @@ def test_model_with_different_grid_types(grid_type: str): # This should complete without error result = model.simulate( + log_level="debug", params=params, initial_conditions={ "wealth": jnp.array([5.0, 20, 40, 70]), diff --git a/tests/test_runtime_params.py b/tests/test_runtime_params.py index bbc8dc476..7e6e19b6d 100644 --- a/tests/test_runtime_params.py +++ b/tests/test_runtime_params.py @@ -310,7 +310,6 @@ def test_simulate_with_runtime_action_grid_no_nan() -> None: initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, log_level="off", - check_initial_conditions=False, ) df = result.to_dataframe() assert not df["value"].isna().any() diff --git a/tests/test_shock_grids.py b/tests/test_shock_grids.py index be478e049..2e475be83 100644 --- a/tests/test_shock_grids.py +++ b/tests/test_shock_grids.py @@ -29,10 +29,12 @@ def test_model_with_shock(distribution_type): params = get_params(distribution_type) got_solve = model.solve( + log_level="debug", params=params, ) got_simulate = model.simulate( + log_level="debug", params=params, initial_conditions={ "health": jnp.asarray([0, 0], dtype=jnp.int32), @@ -82,6 +84,7 @@ def test_model_with_cross_regime_shocks(distribution_type: str) -> None: params = get_multi_regime_params(distribution_type) # ty: ignore[invalid-argument-type] result = model.simulate( + log_level="debug", params=params, initial_conditions={ "health": jnp.zeros(2, dtype=jnp.int32), diff --git a/tests/test_single_feasible_action.py b/tests/test_single_feasible_action.py index 350bfb72e..adeab6500 100644 --- a/tests/test_single_feasible_action.py +++ b/tests/test_single_feasible_action.py @@ -241,7 +241,6 @@ def test_simulate_with_constrained_action_grid(wealth_lo, consumption_lo, label) params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, - check_initial_conditions=False, log_level="off", ) df = result.to_dataframe() @@ -529,13 +528,14 @@ def test_runtime_action_grid_passes_initial_conditions_validation(): [RegimeId.alive, RegimeId.alive, RegimeId.alive], dtype=jnp.int32 ), } - # `check_initial_conditions=True` (the default) must pass — the - # runtime-supplied consumption points are well-formed. + # `log_level="debug"` runs initial-conditions validation and raises on + # failure — the runtime-supplied consumption points are well-formed, so + # it must pass. result = model.simulate( params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, - log_level="off", + log_level="debug", ) assert result.n_subjects == 3 diff --git a/tests/test_solution_on_toy_model_deterministic.py b/tests/test_solution_on_toy_model_deterministic.py index 79a2e1bb4..1f1258208 100644 --- a/tests/test_solution_on_toy_model_deterministic.py +++ b/tests/test_solution_on_toy_model_deterministic.py @@ -235,7 +235,8 @@ def test_deterministic_solve(discount_factor, n_wealth_points): "next_regime": {"final_age_alive": model.n_periods - 2}, } got = model.solve( - params={"discount_factor": discount_factor, "alive": params_alive} + log_level="debug", + params={"discount_factor": discount_factor, "alive": params_alive}, ) wealth_grid_class = cast("LinSpacedGrid", new_states["wealth"]) @@ -282,6 +283,7 @@ def test_deterministic_simulate(discount_factor, n_wealth_points): "next_regime": {"final_age_alive": model.n_periods - 2}, } result = model.simulate( + log_level="debug", params={"discount_factor": discount_factor, "alive": params_alive}, initial_conditions={ "wealth": jnp.array([0.25, 0.75, 1.25, 1.75]), diff --git a/tests/test_solution_on_toy_model_stochastic.py b/tests/test_solution_on_toy_model_stochastic.py index d4a3ba58e..37e60f714 100644 --- a/tests/test_solution_on_toy_model_stochastic.py +++ b/tests/test_solution_on_toy_model_stochastic.py @@ -233,7 +233,9 @@ def test_stochastic_solve(discount_factor, n_wealth_points, probs_array): "next_health": {"probs_array": probs_array}, "next_regime": {"final_age_alive": model.n_periods - 2}, } - got = model.solve(params={"discount_factor": discount_factor, "alive": params}) + got = model.solve( + log_level="debug", params={"discount_factor": discount_factor, "alive": params} + ) wealth_grid_class = cast("LinSpacedGrid", new_states["wealth"]) _wealth_grid = np.linspace( @@ -299,6 +301,7 @@ def test_stochastic_simulate(discount_factor, n_wealth_points, probs_array): "regime_id": jnp.array([RegimeId.alive] * 5), } result = model.simulate( + log_level="debug", params={"discount_factor": discount_factor, "alive": params_alive}, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None, diff --git a/tests/test_stochastic.py b/tests/test_stochastic.py index 0bd0de541..cb3f2a6af 100644 --- a/tests/test_stochastic.py +++ b/tests/test_stochastic.py @@ -39,6 +39,7 @@ def test_model_simulate_with_stochastic_model(): params = get_params(n_periods=4) result = model.simulate( + log_level="debug", params=params, initial_conditions={ "health": jnp.array([1, 1, 0, 0], dtype=jnp.int32), @@ -80,7 +81,7 @@ def test_model_simulate_with_stochastic_model(): def test_model_solve_with_stochastic_model(): model = get_model(n_periods=4) - model.solve(params=get_params(n_periods=4)) + model.solve(log_level="debug", params=get_params(n_periods=4)) @pytest.fixture @@ -173,10 +174,10 @@ def test_compare_deterministic_and_stochastic_results_value_function( # Compare value function arrays # ================================================================================== solution_deterministic: Mapping[int, Mapping[str, FloatND]] = ( - model_deterministic.solve(params=params) + model_deterministic.solve(log_level="debug", params=params) ) solution_stochastic: Mapping[int, Mapping[str, FloatND]] = model_stochastic.solve( - params=params + log_level="debug", params=params ) for period in range(model_deterministic.n_periods - 1): @@ -198,11 +199,13 @@ def test_compare_deterministic_and_stochastic_results_value_function( } simulation_deterministic = model_deterministic.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=solution_deterministic, initial_conditions=initial_conditions, ) simulation_stochastic = model_stochastic.simulate( + log_level="debug", params=params, period_to_regime_to_V_arr=solution_stochastic, initial_conditions=initial_conditions, @@ -294,7 +297,7 @@ def next_draw_no_args() -> FloatND: "discount_factor": 0.95, "working_life": {"next_regime": {"final_age_alive": 1}}, } - V = model.solve(params=params) + V = model.solve(log_level="debug", params=params) assert all( jnp.all(jnp.isfinite(V[p]["working_life"])) for p in V if "working_life" in V[p] ) @@ -316,7 +319,7 @@ def next_draw_continuous(wealth: ContinuousState) -> FloatND: "discount_factor": 0.95, "working_life": {"next_regime": {"final_age_alive": 1}}, } - V = model.solve(params=params) + V = model.solve(log_level="debug", params=params) assert all( jnp.all(jnp.isfinite(V[p]["working_life"])) for p in V if "working_life" in V[p] ) @@ -344,4 +347,4 @@ def test_stochastic_regime_transition_active_at_last_period_raises(): InvalidRegimeTransitionProbabilitiesError, match=r"Non-terminal regime.*active at the last period", ): - model.solve(params=mortality.get_params(n_periods=4)) + model.solve(log_level="debug", params=mortality.get_params(n_periods=4)) diff --git a/tests/test_transition_checks.py b/tests/test_transition_checks.py new file mode 100644 index 000000000..c753a16b9 --- /dev/null +++ b/tests/test_transition_checks.py @@ -0,0 +1,504 @@ +"""Automatic state transition probability validation. + +Exercises the pre-solve numerical sweep over `MarkovTransition` state +transitions, the process-time AST subscript-order check, and the way the +`log_level` validation policy turns failures into warnings or raises. +""" + +import logging +from pathlib import Path + +import jax.numpy as jnp +import pytest + +from lcm import ( + AgeGrid, + DiscreteGrid, + LinSpacedGrid, + MarkovTransition, + Model, + categorical, +) +from lcm.exceptions import InvalidStateTransitionProbabilitiesError +from lcm.typing import ( + BoolND, + ContinuousAction, + ContinuousState, + DiscreteState, + FloatND, + ScalarInt, +) +from lcm.user_regime import Regime as UserRegime +from lcm.utils.logging import LogLevel + + +@categorical(ordered=False) +class _Health: + bad: ScalarInt + good: ScalarInt + + +@categorical(ordered=False) +class _RegimeId: + alive: ScalarInt + terminal: ScalarInt + + +WEALTH_GRID = LinSpacedGrid(start=1, stop=10, n_points=3) +CONSUMPTION_GRID = LinSpacedGrid(start=1, stop=5, n_points=3) + + +def _next_wealth(wealth: ContinuousState, consumption: ContinuousAction) -> FloatND: + return wealth - consumption + + +def _budget(wealth: ContinuousState, consumption: ContinuousAction) -> BoolND: + return consumption <= wealth + + +def _next_regime(age: float) -> ScalarInt: # noqa: ARG001 + # Alive is active only at age 0, so the next-period regime is always + # the terminal one — keeping this transition simple lets the tests + # focus on the state-transition validator rather than regime + # bookkeeping. + return jnp.asarray(_RegimeId.terminal) + + +def _utility_alive(consumption: ContinuousAction) -> FloatND: + return jnp.log(consumption) + + +def _utility_terminal(wealth: ContinuousState) -> FloatND: + return jnp.log(wealth) + + +def _terminal_regime() -> UserRegime: + return UserRegime( + transition=None, + functions={"utility": _utility_terminal}, + states={"wealth": WEALTH_GRID}, + active=lambda age: age >= 1, + ) + + +def _model_with_state_probs(next_health_func) -> Model: + alive = UserRegime( + states={"wealth": WEALTH_GRID, "health": DiscreteGrid(_Health)}, + actions={"consumption": CONSUMPTION_GRID}, + state_transitions={ + "wealth": _next_wealth, + "health": MarkovTransition(next_health_func), + }, + functions={"utility": _utility_alive}, + constraints={"budget": _budget}, + transition=_next_regime, + active=lambda age: age < 1, + ) + return Model( + regimes={"alive": alive, "terminal": _terminal_regime()}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_RegimeId, + ) + + +def test_valid_state_probs_at_boundary_pass() -> None: + """Inclusive [0, 1] bounds and a row sum within the 1e-6 tolerance pass. + + For `health == good` the row is exactly `[0.0, 1.0]` — values at the + inclusive bounds. For `health == bad` the row sums to `1 - 5e-7`, just + inside the `atol=1e-6` row-sum tolerance. Validation must accept both + without raising. + """ + + def boundary_health_probs(health: DiscreteState) -> FloatND: + return jnp.where( + health == _Health.good, + jnp.array([0.0, 1.0]), + jnp.array([0.5, 0.4999995]), + ) + + model = _model_with_state_probs(boundary_health_probs) + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_runtime_check_catches_invalidity_hidden_at_some_grid_points() -> None: + """An ensemble valid at some continuous-grid points and invalid at others raises. + + The `MarkovTransition` for `health` is conditioned on the continuous + `wealth` grid: it returns a valid row for `wealth <= 5` and a row summing + to 0.7 for `wealth > 5`. Only sweeping the full `wealth` grid surfaces the + failure — a spot check at the first grid point (`wealth == 1`) would pass. + """ + + def sneaky_health_probs( + wealth: ContinuousState, + health: DiscreteState, # noqa: ARG001 + ) -> FloatND: + return jnp.where( + wealth > 5.0, + jnp.array([0.5, 0.2]), + jnp.array([0.3, 0.7]), + ) + + model = _model_with_state_probs(sneaky_health_probs) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="sum to 1"): + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_runtime_check_raises_on_wrong_outcome_axis_size() -> None: + """Wrong outcome-axis size (length 3 instead of 2) surfaces at solve.""" + + def too_many_outcomes(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([0.4, 0.4, 0.2]) + + model = _model_with_state_probs(too_many_outcomes) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="outcome axis"): + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_runtime_check_raises_on_values_out_of_range() -> None: + """Negative or >1 probability values surface at solve.""" + + def negative_probs(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([1.2, -0.2]) + + model = _model_with_state_probs(negative_probs) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match=r"\[0, 1\]"): + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_runtime_check_raises_on_rows_not_summing_to_one() -> None: + """A row that sums to 0.7 surfaces at solve.""" + + def bad_sum_probs(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([0.5, 0.2]) + + model = _model_with_state_probs(bad_sum_probs) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="sum to 1"): + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_log_level_off_skips_runtime_check() -> None: + """A model whose state probs violate sum-to-1 still solves at log_level='off'.""" + + def bad_sum_probs(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([0.5, 0.2]) + + model = _model_with_state_probs(bad_sum_probs) + # With log_level='off' the runtime numerical check is skipped — solve + # returns a (numerically dubious) V_arr rather than raising. + model.solve(params={"discount_factor": 0.95}, log_level="off") + + +@pytest.mark.parametrize("log_level", ["warning", "progress"]) +def test_warn_levels_log_invalid_probs_and_continue( + log_level: LogLevel, caplog: pytest.LogCaptureFixture +) -> None: + """At 'warning'/'progress', invalid state probs log a warning; solve continues.""" + + def bad_sum_probs(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([0.5, 0.2]) + + model = _model_with_state_probs(bad_sum_probs) + with caplog.at_level(logging.WARNING, logger="lcm"): + model.solve(params={"discount_factor": 0.95}, log_level=log_level) + assert "sum to 1" in caplog.text + + +def test_debug_level_raises_on_invalid_probs() -> None: + """At log_level='debug', invalid state probs raise rather than warn.""" + + def bad_sum_probs(health: DiscreteState) -> FloatND: # noqa: ARG001 + return jnp.array([0.5, 0.2]) + + model = _model_with_state_probs(bad_sum_probs) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="sum to 1"): + model.solve(params={"discount_factor": 0.95}, log_level="debug") + + +def test_subscript_order_swap_raises_at_process_time() -> None: + """Mismatched subscript order vs signature order raises at process time.""" + + @categorical(ordered=False) + class _Local: + bad: ScalarInt + good: ScalarInt + + @categorical(ordered=False) + class _LocalRegimeId: + alive: ScalarInt + terminal: ScalarInt + + def swapped_probs( + period: ScalarInt, + health: DiscreteState, + probs_array: FloatND, + ) -> FloatND: + # Subscripts in wrong order: signature is (period, health) but + # body indexes as [health, period]. + return probs_array[health, period] + + def _local_next_regime(age: float) -> ScalarInt: # noqa: ARG001 + return jnp.asarray(_LocalRegimeId.terminal) + + alive = UserRegime( + states={"wealth": WEALTH_GRID, "health": DiscreteGrid(_Local)}, + actions={"consumption": CONSUMPTION_GRID}, + state_transitions={ + "wealth": _next_wealth, + "health": MarkovTransition(swapped_probs), + }, + functions={"utility": _utility_alive}, + constraints={"budget": _budget}, + transition=_local_next_regime, + active=lambda age: age < 1, + ) + terminal = UserRegime( + transition=None, + functions={"utility": _utility_terminal}, + states={"wealth": WEALTH_GRID}, + active=lambda age: age >= 1, + ) + + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="subscript"): + Model( + regimes={"alive": alive, "terminal": terminal}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_LocalRegimeId, + ) + + +def test_ast_check_is_permissive_when_no_probs_array_subscript() -> None: + """A function without `probs_array[...]` survives the static AST check.""" + + def no_subscript_probs(health: DiscreteState) -> FloatND: + # No `probs_array[...]` pattern — the AST check silently skips. + return jnp.where( + health == _Health.good, + jnp.array([0.1, 0.9]), + jnp.array([0.7, 0.3]), + ) + + # Model construction must not raise just because the function lacks + # the subscript pattern; runtime numerical checks still apply. + model = _model_with_state_probs(no_subscript_probs) + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def test_per_target_dict_validates_each_entry() -> None: + """Each MarkovTransition inside a per-target dict is validated independently.""" + + @categorical(ordered=False) + class _Heir: + no: ScalarInt + yes: ScalarInt + + @categorical(ordered=False) + class _RegId: + alive: ScalarInt + dead: ScalarInt + + def bad_heir_probs(wealth: ContinuousState) -> FloatND: # noqa: ARG001 + # Rows don't sum to 1 — should be caught even though heir_present + # lives in the target regime, not the source. + return jnp.array([0.5, 0.3]) + + def next_wealth_passthrough(wealth: ContinuousState) -> ContinuousState: + return wealth + + def _utility_alive(wealth: ContinuousState) -> FloatND: + return wealth + + def _utility_dead(wealth: ContinuousState, heir_present: DiscreteState) -> FloatND: + return wealth * heir_present + + def _to_dead(age: float) -> ScalarInt: # noqa: ARG001 + return jnp.asarray(_RegId.dead) + + alive = UserRegime( + functions={"utility": _utility_alive}, + states={"wealth": LinSpacedGrid(start=1, stop=10, n_points=3)}, + state_transitions={ + "wealth": next_wealth_passthrough, + "heir_present": {"dead": MarkovTransition(bad_heir_probs)}, + }, + transition=_to_dead, + active=lambda age: age < 1, + ) + dead = UserRegime( + transition=None, + functions={"utility": _utility_dead}, + states={ + "wealth": LinSpacedGrid(start=1, stop=10, n_points=3), + "heir_present": DiscreteGrid(_Heir), + }, + active=lambda age: age >= 1, + ) + model = Model( + regimes={"alive": alive, "dead": dead}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_RegId, + ) + with pytest.raises(InvalidStateTransitionProbabilitiesError, match="sum to 1"): + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def _good_health_probs(health: DiscreteState) -> FloatND: + return jnp.where( + health == _Health.good, + jnp.array([0.2, 0.8]), + jnp.array([0.6, 0.4]), + ) + + +@pytest.mark.parametrize( + ("log_level", "expect_snapshot"), + [ + ("off", False), + ("warning", False), + ("progress", False), + ("debug", True), + ], +) +def test_snapshot_written_only_at_debug_on_valid_solve( + log_level: LogLevel, + expect_snapshot: bool, # noqa: FBT001 + tmp_path: Path, +) -> None: + """With `log_path` set, a valid solve writes a snapshot only at `"debug"`. + + Pins the snapshot column of the `log_level` x `log_path` table for a solve + that produces no NaN: `"debug"` snapshots every call, the warn/off levels + do not. + """ + model = _model_with_state_probs(_good_health_probs) + model.solve( + params={"discount_factor": 0.95}, log_level=log_level, log_path=tmp_path + ) + snapshots = list(tmp_path.glob("solve_snapshot_*")) + assert bool(snapshots) is expect_snapshot + + +def test_warn_mode_writes_snapshot_on_nan_failure(tmp_path: Path) -> None: + """At `"warning"` with `log_path` set, a NaN solve writes a snapshot. + + Pins the "one per warned failure" snapshot-table cell: warn mode does not + raise, so the snapshot is the only on-disk record of the failed solve. + """ + model = _model_with_state_probs(_good_health_probs) + model.solve( + params={"discount_factor": float("nan")}, + log_level="warning", + log_path=tmp_path, + ) + assert list(tmp_path.glob("solve_snapshot_*")) + + +def test_model_with_no_markov_transitions_solves_normally() -> None: + """A purely deterministic model is unaffected (fast-exit in validator).""" + + alive = UserRegime( + states={"wealth": WEALTH_GRID}, + actions={"consumption": CONSUMPTION_GRID}, + state_transitions={"wealth": _next_wealth}, + functions={"utility": _utility_alive}, + constraints={"budget": _budget}, + transition=_next_regime, + active=lambda age: age < 1, + ) + model = Model( + regimes={"alive": alive, "terminal": _terminal_regime()}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_RegimeId, + ) + model.solve(log_level="debug", params={"discount_factor": 0.95}) + + +def _model_with_fixed_param_health_probs() -> Model: + """Build a model whose `health` MarkovTransition reads from `fixed_params`. + + `transition_bias` lives in `fixed_params`, not the per-iteration + `params` dict. Solve sees it via `regime.resolved_fixed_params`; the + pre-solve numerical validator must do the same merge. + """ + + def health_probs(health: DiscreteState, transition_bias: float) -> FloatND: + good_row = jnp.array([0.5 - transition_bias, 0.5 + transition_bias]) + bad_row = jnp.array([0.5, 0.5]) + return jnp.where(health == _Health.good, good_row, bad_row) + + alive = UserRegime( + states={"wealth": WEALTH_GRID, "health": DiscreteGrid(_Health)}, + actions={"consumption": CONSUMPTION_GRID}, + state_transitions={ + "wealth": _next_wealth, + "health": MarkovTransition(health_probs), + }, + functions={"utility": _utility_alive}, + constraints={"budget": _budget}, + transition=_next_regime, + active=lambda age: age < 1, + ) + return Model( + regimes={"alive": alive, "terminal": _terminal_regime()}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_RegimeId, + fixed_params={"transition_bias": 0.1}, + ) + + +def test_state_validator_resolves_params_from_fixed_params( + caplog: pytest.LogCaptureFixture, +) -> None: + """A MarkovTransition reading a `fixed_params` entry is numerically validated. + + The skip-and-warn branch must not fire just because the parameter + sits in `fixed_params` rather than the per-iteration `params` dict — + both belong to the namespace solve resolves against. + """ + model = _model_with_fixed_param_health_probs() + + with caplog.at_level(logging.WARNING, logger="lcm"): + model.solve(log_level="warning", params={"discount_factor": 0.95}) + + skips = [r for r in caplog.records if "not numerically validated" in r.message] + assert not skips, f"Validator skipped: {skips[0].message}" + + +def test_state_validator_catches_bad_probs_when_using_fixed_param() -> None: + """Invalid probs are still surfaced when the transition reads from `fixed_params`. + + Proves the merged-namespace fix doesn't just silence the skip-warning + but actually runs the numerical check. + """ + + def bad_health_probs(health: DiscreteState, transition_bias: float) -> FloatND: + # Bias is added to row 0 only, so `transition_bias=0.6` makes the + # `good` row sum to 1.6 — well outside the row-sum tolerance. + return jnp.where( + health == _Health.good, + jnp.array([0.5 + transition_bias, 0.5]), + jnp.array([0.5, 0.5]), + ) + + alive = UserRegime( + states={"wealth": WEALTH_GRID, "health": DiscreteGrid(_Health)}, + actions={"consumption": CONSUMPTION_GRID}, + state_transitions={ + "wealth": _next_wealth, + "health": MarkovTransition(bad_health_probs), + }, + functions={"utility": _utility_alive}, + constraints={"budget": _budget}, + transition=_next_regime, + active=lambda age: age < 1, + ) + model = Model( + regimes={"alive": alive, "terminal": _terminal_regime()}, + ages=AgeGrid(start=0, stop=2, step="Y"), + regime_id_class=_RegimeId, + fixed_params={"transition_bias": 0.6}, + ) + + with pytest.raises(InvalidStateTransitionProbabilitiesError): + model.solve(log_level="debug", params={"discount_factor": 0.95}) diff --git a/tests/test_regime.py b/tests/test_user_regime.py similarity index 86% rename from tests/test_regime.py rename to tests/test_user_regime.py index eba1025e2..2e474d4ac 100644 --- a/tests/test_regime.py +++ b/tests/test_user_regime.py @@ -23,10 +23,8 @@ from lcm.user_regime import ( MarkovTransition, _IdentityTransition, - validate_transition_probs, ) from lcm.user_regime import Regime as UserRegime -from tests.test_models.stochastic import get_model as get_stochastic_model def utility(consumption): @@ -418,70 +416,12 @@ def fixed_next_regime(age: float, final_age_alive: float) -> ScalarInt: regime_id_class=FixedRegimeId, ) V = model.solve( + log_level="debug", params={ "discount_factor": 0.95, "working_life": {"next_regime": {"final_age_alive": final_age}}, - } + }, ) assert all( jnp.all(jnp.isfinite(V[p]["working_life"])) for p in V if "working_life" in V[p] ) - - -def _make_partner_probs_array(): - """Build a (n_periods=3, n_work=2, n_partner=2, n_next_partner=2) array.""" - return jnp.array( - [ - [[[0.3, 0.7], [0.6, 0.4]], [[0.1, 0.9], [0.5, 0.5]]], - [[[0.4, 0.6], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]], - [[[0.5, 0.5], [0.9, 0.1]], [[0.3, 0.7], [0.6, 0.4]]], - ] - ) - - -def test_validate_transition_probs_accepts_boundary_inputs(): - """Inclusive [0, 1] bounds and row sums within the 1e-6 tolerance pass. - - The first row is exactly `[0.0, 1.0]` — values at the inclusive bounds. - The last row sums to `1 - 5e-7`, just inside the `atol=1e-6` row-sum - tolerance. The validator must accept both without raising. - """ - model = get_stochastic_model(3) - arr = jnp.array( - [ - [[[0.0, 1.0], [1.0, 0.0]], [[0.3, 0.7], [0.6, 0.4]]], - [[[0.4, 0.6], [0.8, 0.2]], [[0.2, 0.8], [0.7, 0.3]]], - [[[0.5, 0.5], [0.9, 0.1]], [[0.3, 0.7], [0.5, 0.4999995]]], - ] - ) - validate_transition_probs( - probs=arr, model=model, regime_name="working_life", state_name="partner" - ) - - -def test_validate_transition_probs_wrong_shape(): - model = get_stochastic_model(3) - arr = jnp.ones((2, 2, 2)) / 2 # wrong shape - with pytest.raises(ValueError, match="shape"): - validate_transition_probs( - probs=arr, model=model, regime_name="working_life", state_name="partner" - ) - - -def test_validate_transition_probs_values_out_of_range(): - model = get_stochastic_model(3) - arr = _make_partner_probs_array() - bad_arr = arr.at[0, 0, 0, 0].set(-0.1) - with pytest.raises(ValueError, match="\\[0, 1\\]"): - validate_transition_probs( - probs=bad_arr, model=model, regime_name="working_life", state_name="partner" - ) - - -def test_validate_transition_probs_rows_dont_sum_to_one(): - model = get_stochastic_model(3) - arr = jnp.ones((3, 2, 2, 2)) * 0.3 # rows sum to 0.6, not 1 - with pytest.raises(ValueError, match="sum to 1"): - validate_transition_probs( - probs=arr, model=model, regime_name="working_life", state_name="partner" - ) diff --git a/tests/test_validate_regime_transition_probs.py b/tests/test_validate_regime_transition_probs.py index 944d64658..b8fa80019 100644 --- a/tests/test_validate_regime_transition_probs.py +++ b/tests/test_validate_regime_transition_probs.py @@ -242,7 +242,7 @@ def test_solve_catches_transition_bug_hidden_at_first_grid_point(): """Pre-solve validation catches invalid probs even if first action value is ok.""" model, params = _build_action_dependent_model() with pytest.raises(InvalidRegimeTransitionProbabilitiesError, match="outside"): - model.solve(params=params) + model.solve(log_level="debug", params=params) def test_regime_transition_validation_passes_period_as_int32(): @@ -289,7 +289,7 @@ def _transition_recording_period( ages=AgeGrid(start=25, stop=27, step="Y"), regime_id_class=_RegimeId, ) - model.solve(params={"discount_factor": 0.95}) + model.solve(log_level="debug", params={"discount_factor": 0.95}) assert seen_period_dtypes assert all(dtype == jnp.int32 for dtype in seen_period_dtypes) @@ -308,14 +308,14 @@ def test_solve_raises_for_invalid_regime_transition_probs(): model = get_model(N_PERIODS) params = get_params(N_PERIODS, survival_probs=_invalid_survival_probs(N_PERIODS)) with pytest.raises(InvalidRegimeTransitionProbabilitiesError): - model.solve(params=params) + model.solve(log_level="debug", params=params) def test_simulate_raises_for_invalid_regime_transition_probs(): """model.simulate() raises for out-of-bounds regime transition probabilities.""" model = get_model(N_PERIODS) good_params = get_params(N_PERIODS) - period_to_regime_to_V_arr = model.solve(params=good_params) + period_to_regime_to_V_arr = model.solve(log_level="debug", params=good_params) bad_params = get_params( N_PERIODS, survival_probs=_invalid_survival_probs(N_PERIODS) @@ -327,6 +327,7 @@ def test_simulate_raises_for_invalid_regime_transition_probs(): } with pytest.raises(InvalidRegimeTransitionProbabilitiesError): model.simulate( + log_level="debug", params=bad_params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=period_to_regime_to_V_arr, @@ -344,6 +345,7 @@ def test_simulate_with_solve_raises_for_invalid_regime_transition_probs(): } with pytest.raises(InvalidRegimeTransitionProbabilitiesError): model.simulate( + log_level="debug", params=params, initial_conditions=initial_conditions, period_to_regime_to_V_arr=None,