BUG: model initial_point fails when pt.config.floatX = "float32" #7608

Closed
nataziel opened this issue Dec 6, 2024 · 27 comments · Fixed by #7610

@nataziel
Contributor

nataziel commented Dec 6, 2024

Describe the issue:

I have a hierarchical MMM set up with PyMC (not pymc-marketing) and had been using it successfully with float32 until the 5.19 release. It samples with numpyro/JAX.
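
Because this report hinges on float32, here is a short NumPy reminder of its precision. float32 carries roughly 7 significant decimal digits (machine epsilon 2**-23), and a value such as `sigma=0.1` is stored only approximately. That is why the debug output further down shows the float32 parameter `0.1` evaluating to `0.10000000149011612`. This sketch only illustrates the rounding; it does not reproduce the bug.

```python
import numpy as np

# float32 has a 24-bit significand, so its machine epsilon is 2**-23 (about 1.19e-07);
# float64 has a 53-bit significand, so its epsilon is 2**-52 (about 2.22e-16).
eps32 = float(np.finfo(np.float32).eps)
eps64 = float(np.finfo(np.float64).eps)

# 0.1 has no exact binary representation. Rounding it to float32 and widening the
# result to a Python float exposes the float32 rounding error.
sigma32_as_float = float(np.float32(0.1))

print(eps32 == 2.0**-23, eps64 == 2.0**-52)  # True True
print(sigma32_as_float)                       # 0.10000000149011612
```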

With 5.19 I get errors in `_init_jitter`. Something appears to go wrong when the generated initial points are passed to the compiled logp function. I suspect the `ZeroSumNormal` distributions are causing the problem, but I'm not sure whether it's the values returned by `ipfn(seed)` or the evaluation in the compiled `model_logp_fn`. I've included the verbose model debug output from my own model, but the reproducible example below uses the example radon model.
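
To show why `ZeroSumNormal` is a plausible suspect under float32, here is a NumPy sketch of a zero-sum extension. The sampler works with n-1 unconstrained values (for example the 6-element `z_a_state_zerosum__` in the error message below, for a 7-level dimension), and the transform appends one element so that the n values sum to zero. The formula is modeled on the kind of transform PyMC uses but should be treated as an assumption, and `extend_to_zero_sum` is a hypothetical name. Because all arithmetic stays in float32, the resulting mean is only approximately zero, so a constraint check that compares it to 0 with a very tight tolerance could fail even for a valid point.

```python
import numpy as np


def extend_to_zero_sum(values: np.ndarray) -> np.ndarray:
    """Map n-1 unconstrained values to n values summing to zero (last axis).

    Hypothetical sketch modeled on a zero-sum transform; not copied from PyMC.
    All arithmetic stays in values.dtype, so float32 in -> float32 out.
    """
    dtype = values.dtype.type
    n = values.shape[-1] + 1
    sqrt_n = np.sqrt(dtype(n))
    sum_vals = values.sum(axis=-1, keepdims=True)
    norm = sum_vals / (sqrt_n + dtype(n))
    fill = norm - sum_vals / sqrt_n  # the appended n-th element (before shifting)
    return np.concatenate([values, fill], axis=-1) - norm


# The 6 unconstrained values of z_a_state_zerosum__ from the starting values below.
raw = np.array(
    [0.19204606, 0.04281855, 0.5793297, -0.2614823, -0.60941106, -0.5645286],
    dtype=np.float32,
)
value = extend_to_zero_sum(raw)
mean = value.mean()

print(value.dtype, value.shape)  # float32 (7,)
print(abs(mean))  # tiny, but not guaranteed to be exactly 0.0 in float32
# The loose check passes; whether the very tight one passes depends on rounding.
print(bool(np.isclose(mean, 0.0, atol=1e-9)), bool(np.isclose(mean, 0.0, atol=1e-5)))
```

Whether 5.19's `mean(value, axis=n_zerosum_axes) = 0` check uses such a tight tolerance is an assumption here; the debug output below only shows that the check evaluated to `False` for a float32 point.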

Reproducible code example:

import pytensor as pt

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

import numpy as np

import pandas as pd
import pymc as pm

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

idata = pm.sample(
    model=model,
    chains=1,
    tune=500,
    draws=500,
    progressbar=False,
    compute_convergence_checks=False,
    return_inferencedata=False,
    # compile_kwargs=dict(mode="NUMBA")
)

Error message:

SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'mu_adstock_logodds__': array([-0.05812943, -1.3623805 , -0.11238763, -0.5489    , -1.0595012 ,
       -1.6993128 , -1.0827137 , -1.0727112 , -1.7406836 , -0.25875893,
       -1.0392814 , -0.94655335,  0.08540058, -0.5202998 ,  0.33476347,
       -0.5190295 , -1.0631133 , -0.57824194, -0.7651843 , -0.87049246],
      dtype=float32), 'mu_lambda_log__': array([0.80796826, 0.6380648 , 1.2703241 , 1.9139447 , 1.5351636 ,
       1.7267554 , 1.0714266 , 1.0567162 , 2.0499995 , 1.5639257 ,
       1.9248102 , 1.6098787 , 0.56195414, 0.22962454, 0.10311574,
       0.36890286, 0.88685906, 1.7213213 , 1.2186754 , 0.36423793],
      dtype=float32), 'mu_a': array(0.37686896, dtype=float32), 'z_a_state_zerosum__': array([ 0.19204606,  0.04281855,  0.5793297 , -0.2614823 , -0.60941106,
       -0.5645286 ], dtype=float32), 'z_a_age_zerosum__': array([-0.39018095,  0.82710207,  0.93642265, -0.30185702,  0.7038302 ],
      dtype=float32), 'z_a_brand_zerosum__': array([-0.19112755], dtype=float32), 'z_a_cohort_zerosum__': array([0.48081324], dtype=float32), 'roas_rv_log__': array([-2.0530562 , -1.1738657 , -2.1438377 , -0.5091131 , -1.7475982 ,
       -0.71832424,  0.7603461 , -0.6206605 , -0.06742464, -1.7866528 ,
       -0.8273122 ,  0.3449509 , -0.4612695 , -1.5151714 , -0.9103267 ,
        0.92829466,  0.05409501, -0.17818904,  0.4767271 ,  1.0659611 ],
      dtype=float32), 'z_b_state_zerosum__': array([[ 0.9404972 ,  0.5173597 ,  0.947904  ,  0.13234697,  0.91516215,
         0.8013973 ],
       [ 0.76825845, -0.00395296,  0.20068735,  0.89948386, -0.8211762 ,
        -0.85832894],
       [-0.23005143, -0.21612945,  0.38762948, -0.17819382, -0.12716204,
         0.5923743 ],
       [ 0.42042527, -0.8524518 ,  0.41847345, -0.27271616,  0.9246246 ,
         0.10817041],
       [ 0.5263319 ,  0.2471719 , -0.7485617 , -0.30133858,  0.48334756,
         0.3941832 ],
       [ 0.20216979,  0.85309595, -0.13996926,  0.13978173,  0.2398329 ,
        -0.6436887 ],
       [ 0.6737369 ,  0.56847805, -0.6305385 ,  0.11638433, -0.88574356,
        -0.28407305],
       [-0.88218284, -0.5086114 ,  0.5899615 ,  0.4378485 ,  0.1587185 ,
        -0.06953551],
       [-0.4915884 ,  0.5323545 ,  0.5476521 ,  0.03720003,  0.93754584,
         0.20691545],
       [ 0.52630943,  0.8224987 ,  0.38608572, -0.48465434, -0.43573558,
         0.8957741 ],
       [-0.36157402, -0.80625093, -0.35036987,  0.58878934, -0.01390497,
        -0.9381669 ],
       [ 0.84330225, -0.24921258, -0.5905653 ,  0.71469283,  0.04288541,
        -0.15973687],
       [ 0.16558522, -0.82091933,  0.6832055 ,  0.9772107 ,  0.05962521,
        -0.7585791 ],
       [-0.34562954, -0.14827304, -0.7131723 , -0.49016312,  0.92288506,
         0.5976615 ],
       [ 0.14006369, -0.78234476,  0.58030427,  0.7087098 ,  0.83111507,
        -0.4035646 ],
       [-0.89755225, -0.7397433 ,  0.87145287,  0.5502826 , -0.75381684,
        -0.08390825],
       [ 0.52683276,  0.2059658 , -0.65518147,  0.9072355 ,  0.02874351,
         0.8622515 ],
       [-0.31716648, -0.22920816, -0.4773452 ,  0.24883471, -0.71415335,
        -0.00118549],
       [ 0.67181724,  0.83293927,  0.06912381,  0.59113485,  0.15970528,
         0.15737481],
       [ 0.91586804, -0.00439743,  0.8568587 , -0.5599965 ,  0.9024629 ,
        -0.92843384]], dtype=float32), 'z_b_age_zerosum__': array([[ 2.76717365e-01, -5.32886267e-01, -4.37204897e-01,
         6.65345013e-01,  9.51409161e-01],
       [ 5.68430305e-01,  5.80501974e-01,  9.02473092e-01,
        -1.05089314e-01, -6.80017248e-02],
       [-7.40278482e-01, -6.54472530e-01,  5.73029280e-01,
        -2.49546096e-01, -9.83492434e-01],
       [ 4.39474136e-01,  4.55991507e-01,  2.91431248e-01,
         7.93459237e-01,  8.33085358e-01],
       [ 7.54141450e-01, -9.88980412e-01,  7.37549663e-01,
        -9.54164326e-01,  3.69425505e-01],
       [-4.45214868e-01,  1.18648775e-01,  4.35143918e-01,
         7.96567798e-01, -5.37025869e-01],
       [-8.19233358e-01, -3.28816384e-01, -4.24525887e-01,
        -4.72912073e-01, -2.51088679e-01],
       [ 4.16932464e-01,  1.86953232e-01, -2.34448180e-01,
        -5.28278828e-01, -7.83707380e-01],
       [-2.08375111e-01, -1.69877082e-01, -9.20472383e-01,
        -5.55105388e-01,  2.24135935e-01],
       [-7.18319640e-02, -8.23212624e-01, -1.14380375e-01,
         3.75080615e-01,  4.38587993e-01],
       [-7.84464240e-01,  9.64653268e-02,  1.33498237e-01,
         2.63148099e-01,  9.03292537e-01],
       [ 3.70964020e-01,  8.96655202e-01, -9.78391707e-01,
         6.00353360e-01, -3.29210430e-01],
       [ 9.80111659e-01,  6.26725018e-01,  8.71558905e-01,
        -7.08010912e-01,  3.21216695e-02],
       [-8.55567873e-01, -4.15038317e-01, -2.70858496e-01,
         7.64281690e-01,  1.69419169e-01],
       [-6.30245388e-01,  5.22969842e-01, -6.22790098e-01,
         8.40588808e-01,  5.42818129e-01],
       [-2.61249971e-02,  5.77672958e-01, -9.52823997e-01,
        -5.49517214e-01, -4.92883384e-01],
       [ 3.19638193e-01,  8.80902350e-01,  2.54505854e-02,
        -4.16665673e-01,  7.45047331e-01],
       [-1.37079775e-01, -9.72663925e-04,  2.88793862e-01,
         9.96275783e-01,  5.16300082e-01],
       [-2.26764768e-01, -9.21454072e-01, -2.66458213e-01,
        -2.89255470e-01, -5.44357836e-01],
       [-7.08415627e-01, -2.39693552e-01,  1.69611976e-01,
         6.88308775e-01, -2.90724158e-01]], dtype=float32), 'z_b_brand_zerosum__': array([[ 0.5283236 ],
       [-0.14984424],
       [ 0.5653758 ],
       [ 0.6604869 ],
       [ 0.5594151 ],
       [ 0.36363953],
       [-0.13847719],
       [ 0.2760732 ],
       [ 0.60931265],
       [ 0.39675766],
       [ 0.13061056],
       [-0.843226  ],
       [-0.24025336],
       [ 0.21590135],
       [ 0.38261482],
       [-0.9853659 ],
       [-0.89518636],
       [-0.73512644],
       [ 0.24093248],
       [ 0.53579485]], dtype=float32), 'z_b_cohort_zerosum__': array([[-0.89432555],
       [ 0.56194794],
       [ 0.8784441 ],
       [ 0.29107055],
       [-0.03464561],
       [-0.7969992 ],
       [ 0.3803715 ],
       [ 0.14672193],
       [-0.72870535],
       [-0.9154726 ],
       [-0.05452913],
       [ 0.3046088 ],
       [ 0.11125816],
       [ 0.47531176],
       [-0.26566276],
       [-0.97336334],
       [ 0.8083869 ],
       [ 0.10919274],
       [ 0.34626842],
       [-0.14968246]], dtype=float32), 'mu_b_pos_con': array([-0.55350554, -1.271216  , -2.4323664 , -2.2574682 , -1.7525793 ],
      dtype=float32), 'z_b_pos_con_state_zerosum__': array([[ 0.7088774 , -0.64973927,  0.01091878,  0.42494458, -0.5738007 ,
        -0.13388953],
       [ 0.2020803 , -0.4415184 ,  0.9188914 , -0.91647035,  0.20404132,
        -0.19665638],
       [-0.460112  ,  0.250589  ,  0.43331107, -0.31884784, -0.20407897,
        -0.85575414],
       [ 0.053673  , -0.31756902,  0.743054  , -0.83611   ,  0.24574716,
         0.31937018],
       [-0.24382843, -0.7048149 , -0.29253975, -0.2640002 , -0.7875447 ,
         0.9551751 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[-0.968712  , -0.18256   ,  0.54134816,  0.5418771 , -0.20875828],
       [ 0.07226439,  0.3872751 ,  0.78559935,  0.56196845, -0.21816622],
       [-0.07638574,  0.56413966, -0.29065445, -0.8537547 ,  0.26876295],
       [-0.03170992,  0.7177836 , -0.14550653,  0.9839035 ,  0.05719684],
       [-0.09502391,  0.7952581 , -0.3214874 ,  0.5325462 , -0.84900486]],
      dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[ 0.9145228 ],
       [-0.576082  ],
       [-0.9886196 ],
       [-0.76018703],
       [ 0.8999498 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.32495368],
       [-0.32046455],
       [ 0.20123298],
       [-0.7780437 ],
       [-0.78692883]], dtype=float32), 'mu_b_neg_con': array([-1.6263036], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[ 0.6405368 ,  0.02604678,  0.7836906 , -0.9528276 ,  0.65443355,
        -0.00763485]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[-0.88560593, -0.3676332 , -0.1298519 , -0.39139104,  0.05216115]],
      dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.06288974]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[-0.2893259]], dtype=float32), 'mu_b_lag': array([-3.4769938], dtype=float32), 'z_b_lag_state_zerosum__': array([[-0.5544768 ,  0.8843655 , -0.85776746, -0.54040647,  0.83116657,
        -0.91366553]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.27069455,  0.2308255 ,  0.02965806, -0.922125  ,  0.25014895]],
      dtype=float32), 'z_b_lag_brand_zerosum__': array([[-0.95589054]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.9150893]], dtype=float32), 'mu_b_fourier_year': array([ 0.99425113,  0.6344655 , -0.11898322,  0.8841637 ,  0.21039897,
        0.07808983,  0.8532348 ,  0.44000137,  0.98853546, -0.52862114,
        0.79177517, -0.01937965, -0.7798197 , -0.8342347 , -0.64850366,
        0.97407717,  0.95000124,  0.7092928 ,  0.6113028 , -0.27019796],
      dtype=float32), 'sd_y_log__': array(4.4781275, dtype=float32)}

Logp initial evaluation results:
{'mu_adstock': -20.54, 'mu_lambda': -213.9, 'mu_a': -1.3, 'z_a_state': -inf, 'z_a_age': -inf, 'z_a_brand': -0.44, 'z_a_cohort': -inf, 'roas_rv': -72.16, 'z_b_state': -inf, 'z_b_age': -inf, 'z_b_brand': -inf, 'z_b_cohort': -inf, 'mu_b_pos_con': -5.82, 'z_b_pos_con_state': -inf, 'z_b_pos_con_age': -inf, 'z_b_pos_con_brand': -inf, 'z_b_pos_con_cohort': -66.75, 'mu_b_neg_con': -1.12, 'z_b_neg_con_state': -inf, 'z_b_neg_con_age': -inf, 'z_b_neg_con_brand': 1.19, 'z_b_neg_con_cohort': -2.8, 'mu_b_lag': -1.03, 'z_b_lag_state': -inf, 'z_b_lag_age': -inf, 'z_b_lag_brand': -44.3, 'z_b_lag_cohort': -40.49, 'mu_b_fourier_year': -91.0, 'sd_y': -1.21, 'y_like': -56160.32}
You can call `model.debug()` for more details.

SamplingError                             Traceback (most recent call last)
File ~/.ipykernel/1917/command--1-2479103431:18
     15 entry = [ep for ep in metadata.distribution("mmm_v2").entry_points if ep.name == "mmm"]
     16 if entry:
     17   # Load and execute the entrypoint, assumes no parameters
---> 18   entry[0].load()()
     19 else:
     20   import importlib

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/main.py:16, in main()
     13 pt.config.floatX = "float32"  # pyright: ignore[reportPrivateImportUsage] // TODO: we can probably set this via env vars or pytensor.rc
     14 config = setup_mmm()
---> 16 run_mmm(config)

-- elided unimportant data-prep pipeline frames --

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/mmm_v2/model/ModelBuilder.py:634, in HLM_ModelBuilder.fit(self)
    631 """Fits the model with the pymc model `sample()` method."""
    632 self.logger.debug("Sampling model.")
    633 self.model_trace.extend(
--> 634     pm.sample(
    635         draws=self.config.draws,
    636         tune=self.config.tune,
    637         chains=self.config.chains,
    638         model=self.model,
    639         nuts_sampler="numpyro",
    640         idata_kwargs={"log_likelihood": False},
    641         var_names=[f"{x}" for x in self.model.free_RVs],
    642         target_accept=self.config.target_accept,
    643         random_seed=self.config.sampler_seed,
    644         progressbar=self.config.progress_bars,
    645     )
    646 )
    647 self.logger.debug("Finished sampling.")
    649 HLM_ModelBuilder._print_diagnostics(self.model_trace, self.logger)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:773, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    768         raise ValueError(
    769             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    770         )
    772     with joined_blas_limiter():
--> 773         return _sample_external_nuts(
    774             sampler=nuts_sampler,
    775             draws=draws,
    776             tune=tune,
    777             chains=chains,
    778             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    779             random_seed=random_seed,
    780             initvals=initvals,
    781             model=model,
    782             var_names=var_names,
    783             progressbar=progressbar,
    784             idata_kwargs=idata_kwargs,
    785             compute_convergence_checks=compute_convergence_checks,
    786             nuts_sampler_kwargs=nuts_sampler_kwargs,
    787             **kwargs,
    788         )
    790 if exclusive_nuts and not provided_steps:
    791     # Special path for NUTS initialization
    792     if "nuts" in kwargs:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:389, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    386 elif sampler in ("numpyro", "blackjax"):
    387     import pymc.sampling.jax as pymc_jax
--> 389     idata = pymc_jax.sample_jax_nuts(
    390         draws=draws,
    391         tune=tune,
    392         chains=chains,
    393         target_accept=target_accept,
    394         random_seed=random_seed,
    395         initvals=initvals,
    396         model=model,
    397         var_names=var_names,
    398         progressbar=progressbar,
    399         nuts_sampler=sampler,
    400         idata_kwargs=idata_kwargs,
    401         compute_convergence_checks=compute_convergence_checks,
    402         **nuts_sampler_kwargs,
    403     )
    404     return idata
    406 else:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:595, in sample_jax_nuts(draws, tune, chains, target_accept, random_seed, initvals, jitter, model, var_names, nuts_kwargs, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_vectorize, postprocessing_chunks, idata_kwargs, compute_convergence_checks, nuts_sampler)
    589 vars_to_sample = list(
    590     get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    591 )
    593 (random_seed,) = _get_seeds_per_chain(random_seed, 1)
--> 595 initial_points = _get_batched_jittered_initial_points(
    596     model=model,
    597     chains=chains,
    598     initvals=initvals,
    599     random_seed=random_seed,
    600     jitter=jitter,
    601 )
    603 if nuts_sampler == "numpyro":
    604     sampler_fn = _sample_numpyro_nuts

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/jax.py:225, in _get_batched_jittered_initial_points(model, chains, initvals, random_seed, jitter, jitter_max_retries)
    209 def _get_batched_jittered_initial_points(
    210     model: Model,
    211     chains: int,
   (...)
    215     jitter_max_retries: int = 10,
    216 ) -> np.ndarray | list[np.ndarray]:
    217     """Get jittered initial point in format expected by NumPyro MCMC kernel.
    218 
    219     Returns
   (...)
    223         Each item has shape `(chains, *var.shape)`
    224     """
--> 225     initial_points = _init_jitter(
    226         model,
    227         initvals,
    228         seeds=_get_seeds_per_chain(random_seed, chains),
    229         jitter=jitter,
    230         jitter_max_retries=jitter_max_retries,
    231     )
    232     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
    233     if chains == 1:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/sampling/mcmc.py:1382, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_dlogp_func)
   1379 if not np.isfinite(point_logp):
   1380     if i == jitter_max_retries:
   1381         # Print informative message on last attempted point
-> 1382         model.check_start_vals(point)
   1383     # Retry with a new seed
   1384     seed = rng.integers(2**30, dtype=np.int64)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/pymc/model/core.py:1769, in Model.check_start_vals(self, start, **kwargs)
   1766 initial_eval = self.point_logps(point=elem, **kwargs)
   1768 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1769     raise SamplingError(
   1770         "Initial evaluation of model at starting point failed!\n"
   1771         f"Starting values:\n{elem}\n\n"
   1772         f"Logp initial evaluation results:\n{initial_eval}\n"
   1773         "You can call `model.debug()` for more details."
   1774     )
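
The last two frames show what `_init_jitter` does. It evaluates the model logp at each jittered point and, when the result is not finite, retries with a new seed. On the final attempt it calls `model.check_start_vals(point)`, which raises the `SamplingError` listing the per-variable logps. Below is a simplified, hypothetical sketch of that retry pattern in plain Python (`init_with_retries`, `draw_point` and `logp_terms` are illustrative names, not PyMC API), plus a helper that picks out the non-finite terms. In the results reported above, every `-inf` term belongs to a variable with a `_zerosum__` transform.

```python
import math
import random
from typing import Callable, Dict, List

Point = Dict[str, float]


def nonfinite_terms(logps: Dict[str, float]) -> List[str]:
    """Return the names whose logp term is -inf, +inf or NaN."""
    return [name for name, value in logps.items() if not math.isfinite(value)]


def init_with_retries(
    draw_point: Callable[[int], Point],
    logp_terms: Callable[[Point], Dict[str, float]],
    seed: int,
    max_retries: int = 10,
) -> Point:
    """Draw starting points until the summed logp is finite (hypothetical sketch)."""
    rng = random.Random(seed)
    for attempt in range(max_retries + 1):
        point = draw_point(seed)
        logps = logp_terms(point)
        # Any -inf / +inf / NaN term makes the plain sum non-finite.
        if math.isfinite(sum(logps.values())):
            return point
        if attempt == max_retries:
            # Last attempt: report which terms failed instead of retrying again.
            raise RuntimeError(
                f"Initial evaluation failed; non-finite terms: {nonfinite_terms(logps)}"
            )
        seed = rng.randrange(2**30)  # retry with a fresh seed
    raise AssertionError("unreachable")


# A subset of the "Logp initial evaluation results" reported above.
reported = {"mu_a": -1.3, "z_a_state": -math.inf, "z_a_brand": -0.44, "z_a_cohort": -math.inf}
print(nonfinite_terms(reported))  # ['z_a_state', 'z_a_cohort']
```

With a real model, the per-variable dict would come from `model.point_logps(point)`, the call that `check_start_vals` makes in the traceback above.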

PyMC version information:

Running on a Windows machine in a Linux container.
PyMC installed via Poetry (PyPI).
Using libopenblas.

annotated-types 0.7.0
arviz 0.19.0
babel 2.16.0
blinker 1.4
build 1.2.2.post1
CacheControl 0.14.1
cachetools 5.5.0
certifi 2024.8.30
cfgv 3.4.0
charset-normalizer 3.4.0
cleo 2.1.0
click 8.1.7
cloudpickle 3.1.0
colorama 0.4.6
cons 0.4.6
contourpy 1.3.1
coverage 7.6.8
crashtest 0.4.1
cryptography 3.4.8
cycler 0.12.1
dbus-python 1.2.18
distlib 0.3.9
distro 1.7.0
distro-info 1.1+ubuntu0.2
dm-tree 0.1.8
dulwich 0.21.7
etuples 0.3.9
exceptiongroup 1.2.2
fastjsonschema 2.21.1
filelock 3.16.1
fonttools 4.55.1
ghp-import 2.1.0
graphviz 0.20.3
griffe 1.5.1
h5netcdf 1.4.1
h5py 3.12.1
httplib2 0.20.2
identify 2.6.3
idna 3.10
importlib_metadata 8.5.0
iniconfig 2.0.0
installer 0.7.0
jaraco.classes 3.4.0
jax 0.4.35
jaxlib 0.4.35
jeepney 0.7.1
Jinja2 3.1.4
joblib 1.4.2
keyring 24.3.1
kiwisolver 1.4.7
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
logical-unification 0.4.6
loguru 0.7.2
Markdown 3.7
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.9.3
mdurl 0.1.2
mergedeep 1.3.4
miniKanren 1.0.3
mkdocs 1.6.1
mkdocs-autorefs 1.2.0
mkdocs-gen-files 0.5.0
mkdocs-get-deps 0.2.0
mkdocs-glightbox 0.4.0
mkdocs-literate-nav 0.6.1
mkdocs-material 9.5.47
mkdocs-material-extensions 1.3.1
mkdocs-section-index 0.3.9
mkdocstrings 0.26.2
mkdocstrings-python 1.12.2
ml_dtypes 0.5.0
mmm_v2 0.0.1 /workspaces/mmm_v2
more-itertools 8.10.0
msgpack 1.1.0
multimethod 1.10
multipledispatch 1.0.0
mypy-extensions 1.0.0
nodeenv 1.9.1
numpy 1.26.4
numpyro 0.15.3
oauthlib 3.2.0
opt_einsum 3.4.0
packaging 24.2
paginate 0.5.7
pandas 2.2.3
pandera 0.20.4
pathspec 0.12.1
pexpect 4.9.0
pillow 11.0.0
pip 24.3.1
pkginfo 1.12.0
platformdirs 4.3.6
pluggy 1.5.0
poetry 1.8.4
poetry-core 1.9.1
poetry-plugin-export 1.8.0
pre_commit 4.0.1
ptyprocess 0.7.0
pyarrow 18.1.0
pydantic 2.10.3
pydantic_core 2.27.1
Pygments 2.18.0
PyGObject 3.42.1
PyJWT 2.3.0
pymc 5.19.1
pymc-marketing 0.6.0
pymdown-extensions 10.12
pyparsing 3.2.0
pyproject_hooks 1.2.0
pytensor 2.26.4
pytest 8.3.4
pytest-cov 6.0.0
python-apt 2.4.0+ubuntu3
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
pyyaml_env_tag 0.1
RapidFuzz 3.10.1
regex 2024.11.6
requests 2.32.3
requests-toolbelt 1.0.0
rich 13.9.4
ruff 0.8.1
scikit-learn 1.5.2
scipy 1.14.1
seaborn 0.13.2
SecretStorage 3.3.1
setuptools 75.6.0
shellingham 1.5.4
six 1.17.0
ssh-import-id 5.11
threadpoolctl 3.5.0
tomli 2.2.1
tomlkit 0.13.2
toolz 1.0.0
tqdm 4.67.1
trove-classifiers 2024.10.21.16
typeguard 4.4.1
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.2
unattended-upgrades 0.1
urllib3 2.2.3
virtualenv 20.28.0
wadllib 1.3.6
watchdog 6.0.0
wheel 0.38.4
wrapt 1.17.0
xarray 2024.9.0
xarray-einstats 0.8.0
zipp 3.21.0
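
The list above is a full environment dump. The versions most relevant to this report are pymc 5.19.1, pytensor 2.26.4, jax/jaxlib 0.4.35, numpyro 0.15.3 and numpy 1.26.4. A small standard-library sketch for pulling just those (the package selection and the `collect_versions` helper are our own, hypothetical choices):

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Packages whose versions matter most for this report (our selection).
RELEVANT = ("pymc", "pytensor", "jax", "jaxlib", "numpyro", "numpy")


def collect_versions(packages: Tuple[str, ...] = RELEVANT) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if it is not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in collect_versions().items():
    print(f"{name} {version or 'not installed'}")
```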

Context for the issue:

I want to use float32 because of memory constraints: I have a large model and dataset with a hierarchy that takes up 40+ GB of RAM with float32, so I'd need a huge machine to move to float64.

Maybe the optimisations in 5.19 make that problem moot? I've managed to get it running with float64 for now, but I'm not sure it will be a long-term solution.
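
A rough back-of-the-envelope check on the memory argument: float64 uses 8 bytes per element versus 4 for float32, so if essentially all of the 40+ GB were float arrays whose dtype follows `floatX` (an assumption; integer index arrays and fixed-dtype buffers would not grow), moving to float64 would roughly double the footprint. `scaled_memory_gb` is a hypothetical helper.

```python
import numpy as np


def scaled_memory_gb(gb_at_float32: float, target_dtype=np.float64) -> float:
    """Rescale a float32 memory footprint to another float dtype.

    Assumes every byte belongs to a float array whose dtype follows floatX,
    which overstates the growth if integer/index arrays take a large share.
    """
    ratio = np.dtype(target_dtype).itemsize / np.dtype(np.float32).itemsize
    return gb_at_float32 * ratio


print(np.dtype(np.float32).itemsize, np.dtype(np.float64).itemsize)  # 4 8
print(scaled_memory_gb(40.0))  # 80.0
```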

@nataziel nataziel added the bug label Dec 6, 2024

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

The debug trace wouldn't fit in the issue:

Model debug info, obtained by pausing in the debugger at mcmc.py L1386 and calling `model.debug(point, verbose=True)`:

point={'mu_adstock_logodds__': array([-0.4352023 , -0.2916657 , -0.22959015, -0.13479875, -0.30916345,
       -1.5283642 , -1.655654  , -0.48334104, -0.30794948,  0.33021173,
        0.0656831 , -1.6596988 , -0.47080922, -0.91784286, -0.15092045,
       -1.1535047 , -1.0793369 , -0.8457366 , -1.3219562 , -1.3434141 ],
      dtype=float32), 'mu_lambda_log__': array([2.2058203 , 1.453856  , 1.6157596 , 0.43113965, 1.6577175 ,
       0.84230906, 0.27920187, 0.899368  , 1.8292453 , 2.2519255 ,
       2.3814824 , 2.108992  , 1.4626849 , 1.5980046 , 0.8299414 ,
       0.3984745 , 1.8697526 , 1.4488894 , 2.1219566 , 0.18294996],
      dtype=float32), 'mu_a': array(0.16189706, dtype=float32), 'z_a_state_zerosum__': array([-0.28566697, -0.48986882, -0.24404879, -0.4398191 , -0.85556257,
        0.10748474], dtype=float32), 'z_a_age_zerosum__': array([-0.8873828 , -0.32871515,  0.57389575,  0.30648488, -0.6199051 ],
      dtype=float32), 'z_a_brand_zerosum__': array([-0.46795005], dtype=float32), 'z_a_cohort_zerosum__': array([0.4478183], dtype=float32), 'roas_rv_log__': array([-0.6555429 , -0.068955  , -0.9309296 , -0.47975117, -1.3743148 ,
        0.47711676,  0.14864558,  0.10893539,  0.38675046, -0.8370424 ,
       -0.66264176, -1.5664616 , -0.572411  , -1.4761899 , -1.0794916 ,
       -0.2966979 ,  0.05331168,  1.0579    ,  0.13020533,  0.8946314 ],
      dtype=float32), 'z_b_state_zerosum__': array([[-0.00928065,  0.22661608,  0.26506114,  0.10868097,  0.47448924,
         0.6638588 ],
       [ 0.39453208,  0.9125831 , -0.05330055,  0.16203904, -0.7428469 ,
        -0.2878859 ],
       [-0.22169496,  0.25026786,  0.7721265 , -0.16524933, -0.8161399 ,
         0.5124463 ],
       [ 0.4635996 ,  0.935813  ,  0.3664374 , -0.39854062, -0.11831979,
         0.23826346],
       [ 0.4144873 , -0.07588482, -0.4675184 ,  0.9954296 ,  0.44995347,
         0.6562674 ],
       [ 0.94758034,  0.2068893 ,  0.6966277 ,  0.31964955, -0.8013234 ,
        -0.59591883],
       [-0.17715056, -0.7038275 ,  0.18067661,  0.01431344, -0.9491178 ,
        -0.3321023 ],
       [ 0.3942007 ,  0.9996393 , -0.31270924, -0.08990093, -0.09300919,
        -0.16450764],
       [-0.00447197, -0.61609423,  0.8628801 ,  0.96006954,  0.7203218 ,
        -0.7518324 ],
       [ 0.6942957 , -0.44699988,  0.57910615,  0.8879041 ,  0.531556  ,
        -0.9510816 ],
       [ 0.78471005,  0.10752742,  0.4335172 , -0.58196217, -0.9966123 ,
        -0.17337854],
       [-0.06716231, -0.5351729 , -0.1103561 , -0.15798165, -0.15524508,
         0.8739795 ],
       [ 0.47066316,  0.03429028, -0.2272006 ,  0.57281727,  0.9989922 ,
        -0.26203355],
       [-0.59414744, -0.34866652, -0.58397436, -0.12034182,  0.16198853,
        -0.36454397],
       [ 0.12944746, -0.05762197,  0.99427617,  0.81767935,  0.5921547 ,
         0.9800794 ],
       [ 0.9717736 ,  0.9814946 ,  0.4856121 , -0.5534532 ,  0.11700594,
         0.9247631 ],
       [-0.2042932 , -0.411241  ,  0.27332023,  0.9046378 ,  0.6154953 ,
        -0.08056752],
       [ 0.9214974 , -0.65947914, -0.41038954,  0.54713   , -0.3560202 ,
        -0.9969325 ],
       [-0.08087815,  0.18727091,  0.84307253, -0.48801887,  0.29456693,
         0.5796735 ],
       [ 0.27902   ,  0.29730567, -0.40406513, -0.18478568, -0.61452186,
        -0.5549851 ]], dtype=float32), 'z_b_age_zerosum__': array([[-6.8486583e-01, -4.6815168e-02, -5.9378707e-01,  8.9602423e-01,
         8.5502052e-01],
       [ 7.8728044e-01, -3.6670679e-01, -3.3962426e-01,  3.0838227e-01,
         3.6406529e-01],
       [ 2.6335692e-02, -6.4281446e-01,  5.1187193e-01,  8.4743094e-01,
         2.2725777e-01],
       [ 1.9795492e-01,  9.2090023e-01, -8.8563585e-01, -2.8022802e-01,
        -2.3639840e-01],
       [ 2.9900335e-02,  7.1486712e-01,  6.7400551e-01, -7.0308822e-01,
        -6.9614536e-01],
       [ 1.3013636e-01,  1.0248652e-01, -5.7761997e-02, -8.4077924e-01,
        -4.7718164e-01],
       [-5.9512955e-01, -9.2812777e-02, -9.9525869e-02,  9.1666229e-02,
         7.0176089e-01],
       [ 4.7530079e-01,  2.4512438e-01, -4.2329890e-01, -6.4359361e-01,
         4.8717073e-01],
       [-4.3700787e-01, -3.6620468e-01, -5.4181212e-01, -7.7344787e-01,
        -5.0397778e-01],
       [ 5.0274485e-01,  8.3750290e-01, -1.0284081e-01,  6.3057953e-01,
         9.5303126e-02],
       [ 9.8040715e-02,  3.1213897e-01,  9.7941196e-01,  6.9000393e-01,
         1.3390434e-01],
       [ 5.9936064e-01, -1.6784413e-01,  9.4419844e-02, -5.2747607e-02,
         1.4664505e-01],
       [-8.8135488e-02, -1.6365480e-01, -2.4431337e-01,  2.9717276e-01,
         8.4692138e-01],
       [ 4.3495792e-01,  5.3633377e-02,  4.0893257e-01, -1.4952035e-01,
        -2.0427135e-01],
       [-4.2102399e-01,  7.0554173e-01, -7.1471161e-01, -6.7319351e-01,
         6.4274400e-01],
       [ 2.9349172e-01, -4.3267983e-01, -6.5261596e-01, -7.2232783e-01,
         8.7439209e-01],
       [ 3.5815242e-01,  9.3956250e-01,  7.0418483e-01,  6.0373771e-01,
         3.7868690e-02],
       [-8.4775686e-01, -4.1210197e-02, -8.3802587e-01, -7.5553125e-01,
         1.8493308e-01],
       [-9.3824327e-01,  4.9592870e-01,  7.3176724e-01,  1.4875463e-01,
         4.6959123e-01],
       [-9.1654237e-04, -3.3683771e-01,  2.9216877e-01, -1.3542533e-01,
         3.8034424e-02]], dtype=float32), 'z_b_brand_zerosum__': array([[ 0.1755699 ],
       [-0.64566314],
       [-0.164115  ],
       [-0.7532449 ],
       [ 0.8172519 ],
       [-0.07537084],
       [ 0.18247645],
       [-0.6951736 ],
       [-0.34786522],
       [ 0.13041314],
       [-0.26256517],
       [-0.75953436],
       [ 0.13290128],
       [ 0.7542543 ],
       [ 0.43538788],
       [ 0.2650157 ],
       [-0.22280337],
       [ 0.9436325 ],
       [ 0.88120073],
       [-0.2978969 ]], dtype=float32), 'z_b_cohort_zerosum__': array([[ 0.08365148],
       [ 0.4037446 ],
       [ 0.219702  ],
       [-0.04724727],
       [-0.39688712],
       [-0.05613985],
       [ 0.8531689 ],
       [ 0.21144761],
       [-0.7068673 ],
       [ 0.0072215 ],
       [-0.5336951 ],
       [ 0.34807727],
       [ 0.39438143],
       [ 0.7545098 ],
       [ 0.8487969 ],
       [-0.93471295],
       [ 0.36905304],
       [ 0.23143601],
       [-0.85752094],
       [-0.93718106]], dtype=float32), 'mu_b_pos_con': array([-2.4778757, -0.7670131, -2.4533374, -2.1609292, -1.4259497],
      dtype=float32), 'z_b_pos_con_state_zerosum__': array([[-0.12270445,  0.24486493,  0.39655864, -0.2905437 , -0.34443325,
         0.00207796],
       [-0.6199098 , -0.9983607 ,  0.17514546,  0.70688826, -0.32889152,
        -0.14889538],
       [-0.7383829 , -0.3004701 , -0.3937158 ,  0.05630853,  0.04613764,
        -0.7968154 ],
       [-0.83736265,  0.53274584,  0.00774734,  0.3415022 , -0.02279032,
        -0.80878764],
       [-0.29654893,  0.8825508 , -0.31042358, -0.4510256 , -0.8262907 ,
        -0.539443  ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[ 0.78295594,  0.05425891, -0.34404022,  0.54096305,  0.30847767],
       [-0.73004913, -0.9038148 , -0.16850933, -0.00236369, -0.6851207 ],
       [ 0.39292192,  0.9246333 ,  0.6826134 , -0.11111186,  0.9510223 ],
       [-0.83537763,  0.82692575, -0.73304725,  0.18374918, -0.9958687 ],
       [ 0.56813633,  0.7769569 , -0.94843704,  0.06147736, -0.7710584 ]],
      dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[-0.9075556 ],
       [-0.27464116],
       [ 0.7783447 ],
       [-0.8754568 ],
       [-0.9165574 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.37049973],
       [ 0.95798755],
       [ 0.71250516],
       [ 0.83639866],
       [-0.19618082]], dtype=float32), 'mu_b_neg_con': array([-1.0723019], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[ 0.9188992 , -0.9298911 ,  0.19285995, -0.25151858,  0.37471578,
        -0.33717418]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[-0.80139756,  0.6711084 ,  0.80932933, -0.03985522,  0.22147077]],
      dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.5175842]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[-0.61060673]], dtype=float32), 'mu_b_lag': array([-2.16711], dtype=float32), 'z_b_lag_state_zerosum__': array([[-0.79260015, -0.74764526, -0.82469386, -0.03429089, -0.9838945 ,
        -0.7090971 ]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.48504904,  0.87696105, -0.742052  ,  0.9109938 ,  0.04824264]],
      dtype=float32), 'z_b_lag_brand_zerosum__': array([[0.00786572]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.23193653]], dtype=float32), 'mu_b_fourier_year': array([-0.77838343,  0.6178618 ,  0.57258415,  0.7673871 ,  0.3929526 ,
        0.8627489 ,  0.4622131 , -0.769917  ,  0.13717452, -0.13153929,
       -0.83192456,  0.1140356 ,  0.05431605, -0.10188404,  0.9689602 ,
        0.05997486,  0.06146386, -0.9775481 , -0.47403213, -0.6081761 ],
      dtype=float32), 'sd_y_log__': array(3.3264434, dtype=float32)}

The variable z_a_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [] [id C] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-57.60402, dtype=float32), array(False)]
Outputs clients: [[output[0](z_a_state_zerosum___logprob)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 267, in eval_in_context
    result = eval(compiled, global_vars, local_vars)
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
    rv_logps = transformed_conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
    logprobs_jac.append(logp + log_jac_det)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_a_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [] [id C] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-78.23544, dtype=float32), array(False)]
Outputs clients: [[output[0](z_a_age_zerosum___logprob)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 267, in eval_in_context
    result = eval(compiled, global_vars, local_vars)
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
    rv_logps = transformed_conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
    logprobs_jac.append(logp + log_jac_det)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_state has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [7]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_age has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [6]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_brand has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [2]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_cohort has the following parameters:
0: 0.05 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [20] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.05000000074505806
1: [2]
2: [20]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(20,), ()]
Inputs strides: [(4,), ()]
Inputs values: ['not shown', array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-13.464529, -93.78406 , -63.23401 , -79.51486 , -98.71708 ],
      dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([ -49.188126,  -85.46398 , -112.68599 , -140.30869 , -114.29598 ],
      dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_brand has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-39.79921  ,  -2.3877418, -28.907381 , -36.937584 , -40.620224 ],
      dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_pos_con_cohort has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [5] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [5]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(5,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([ -5.4798565 , -44.503365  , -23.999535  , -33.594482  ,
        -0.54069936], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-94.879524], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-82.99558], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_neg_con_cohort has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [2] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [2]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-17.258383], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_lag_state has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [7] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [7]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-158.66568], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable z_b_lag_age has the following parameters:
0: 0.1 [id A] <Scalar(float32, shape=())>
1: [6] [id B] <Vector(int64, shape=(1,))>
2: [1] [id C] <Vector(int64, shape=(1,))>
The parameters evaluate to:
0: 0.10000000149011612
1: [6]
2: [1]
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axis=1}.0, All{axes=None}.0)
Toposort index: 23
Inputs types: [TensorType(float32, shape=(None,)), TensorType(bool, shape=())]
Inputs shapes: [(1,), ()]
Inputs strides: [(4,), ()]
Inputs values: [array([-112.442345], dtype=float32), array(False)]
Outputs clients: [[Sum{axes=None}(Check{mean(value, axis=n_zerosum_axes) = 0}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
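All of the failing checks above are on the ZeroSumNormal constraint `mean(value, axis=n_zerosum_axes) = 0`. The following NumPy sketch shows why a check like that can be fragile in float32. It assumes, without confirmation here, that the check compares the absolute mean against a small fixed tolerance. The 1e-9 default is a hypothetical value, and `zerosum_extend` is an illustrative reimplementation of a zero-sum transform, not PyMC's code.

```python
import numpy as np


def zerosum_extend(values: np.ndarray) -> np.ndarray:
    """Map n-1 free values to n values whose sum is zero in exact arithmetic.

    Illustrative sketch of a zero-sum backward transform; all arithmetic
    stays in values.dtype so float32 inputs produce float32 outputs.
    """
    n = values.shape[-1] + 1
    sqrt_n = values.dtype.type(np.sqrt(n))  # keep the scalar in the input precision
    total = values.sum(axis=-1, keepdims=True)
    norm = total / (sqrt_n + n)
    fill = norm - total / sqrt_n
    return np.concatenate([values, fill], axis=-1) - norm


def zero_mean_ok(x: np.ndarray, tol: float = 1e-9) -> bool:
    # Hypothetical fixed-tolerance check in the spirit of the
    # "mean(value, axis=n_zerosum_axes) = 0" constraint in the traceback.
    return bool(np.all(np.abs(np.mean(x, axis=-1)) <= tol))


rng = np.random.default_rng(0)
draws = rng.uniform(-1, 1, size=(200, 84))  # 84 free values -> 85 entries

fails64 = sum(not zero_mean_ok(zerosum_extend(d)) for d in draws)
fails32 = sum(not zero_mean_ok(zerosum_extend(d.astype(np.float32))) for d in draws)
print(f"float64 failures: {fails64}/200, float32 failures: {fails32}/200")
```

In float64 the rounding error in the mean is many orders of magnitude below 1e-9. In float32 the machine epsilon is about 1.2e-7, so a fixed 1e-9 tolerance is routinely exceeded even though the transform is exact on paper.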

@ricardoV94
Member

Your traceback doesn't look like pymc 5.19: the code no longer calls check_start_vals inside init_nuts. PyMC-Marketing is also explicitly incompatible with more recent versions of PyMC, so check which versions are actually running in your script. Try printing pymc.__version__.
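One way to check which versions are actually installed is shown below. It uses the standard library's `importlib.metadata`, so it also works when a package is missing. `installed_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions(["pymc", "pytensor", "pymc-marketing", "numpyro", "jax"]))
```

This reports what is installed in the environment. Inside the running script, `pymc.__version__` confirms which copy was actually imported.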

@ricardoV94
Member

ricardoV94 commented Dec 6, 2024

Your traceback doesn't correspond to the example code either: it goes through a JAX sampler, but I don't see one specified in pm.sample. Can you provide a fully reproducible snippet?

@ricardoV94
Member

I can reproduce it, though. Perhaps the logp/dlogp underflows in float32.
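The next example is a generic illustration of the underflow hypothesis and does not diagnose this model. It shows that the same log-density term can come out finite or as -inf, depending on precision and on whether `log(exp(x))` is simplified to `x` before evaluation.

```python
import numpy as np


def log_of_exp(x):
    # Naive evaluation: exponentiate first, then take the log (two separate ops).
    with np.errstate(divide="ignore", under="ignore"):
        return np.log(np.exp(x))


x32 = np.float32(-110.0)
x64 = np.float64(-110.0)

naive32 = log_of_exp(x32)  # exp(-110) underflows to 0 in float32, so log gives -inf
naive64 = log_of_exp(x64)  # exp(-110) is representable in float64, so this is about -110
simplified32 = x32         # what an algebraic rewrite log(exp(x)) -> x would return
print(naive32, naive64, simplified32)
```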

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

5.19 definitely uses check_start_vals:

model.check_start_vals(point)

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

As I mentioned in the original comment, the traceback I provided is from my own model, but the reproducible example is using the radon dataset.

The only code we use from pymc-marketing is https://github.com/pymc-labs/pymc-marketing/blob/c8121be81d2d2b73ed5d4bf3d32f6592e3ba2afa/pymc_marketing/mmm/transformers.py#L154 and https://github.com/pymc-labs/pymc-marketing/blob/c8121be81d2d2b73ed5d4bf3d32f6592e3ba2afa/pymc_marketing/mmm/transformers.py#L433, which we could easily vendor to remove the dependency. The reason we specify pymc-marketing==0.6.0 is specifically to avoid the restriction on using newer versions of pymc.

@ricardoV94
Member

ricardoV94 commented Dec 6, 2024

Ah, my bad: it does, after the jitter fails. Either way, when you sample with PyMC's NUTS it isn't failing in the jitter but only once it is already inside NUTS. The logp_dlogp_func returns -inf for the logp at the point of failure, whereas model.logp evaluated at that same point does not. So I suspect some optimization or numerical issue.
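The sketch below is a generic float32 illustration, not a diagnosis of this bug. It shows how an optimization that reorders arithmetic, as graph rewrites and fusion can, changes the result. This is one way two compiled functions of the same graph can disagree.

```python
import numpy as np

big = np.float32(1e8)
one = np.float32(1.0)

# float32 spacing near 1e8 is 8, so adding 1.0 first is rounded away.
left_to_right = (big + one) - big
# Same three terms, different order: the 1.0 survives.
reordered = (big - big) + one
# In float64 the spacing near 1e8 is far below 1, so the order does not matter here.
left_to_right_64 = (np.float64(1e8) + 1.0) - np.float64(1e8)

print(left_to_right, reordered, left_to_right_64)
```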

Pinning an older pymc-marketing gets around the PyMC restriction, but it isn't a solution. It only works because that older version did not pin PyMC as strictly as the newer one does. Either way, that's not the problem you're seeing here.

@ricardoV94
Member

Right, the pymc-marketing is irrelevant here.

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

Ah, I can see the confusion, though: the "reproducible" example doesn't reproduce exactly what I was reporting. I took my coworker's word for it at 4pm on a Friday afternoon, haha.

This, for example, doesn't throw the error:

import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

idata = pm.sample(
    model=model,
    chains=1,
    tune=500,
    draws=500,
    progressbar=False,
    compute_convergence_checks=False,
    return_inferencedata=False,
    nuts_sampler="numpyro",
    # compile_kwargs=dict(mode="NUMBA")
)

@ricardoV94
Member

ricardoV94 commented Dec 6, 2024

Here is a more direct reproducible example:

import pytensor

pytensor.config.floatX = "float32"

import numpy as np

import pandas as pd
import pymc as pm

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float32)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")


# Bad point fished from the interactive debugger
from numpy import array, float32
q = {'intercept': array(-0.8458743, dtype=float32),
 'county_raw_zerosum__': array([-0.8473211 ,  0.97756225,  0.5851473 , -0.8831246 ,  0.67874885,
         0.74649656,  0.40699005,  0.9938065 ,  0.90805703, -0.55194354,
         0.7369223 , -0.8693557 , -0.18068689,  0.34439757,  0.8696054 ,
        -0.90608346, -0.19901727,  0.18405294,  0.85029787,  0.69731015,
        -0.11369044, -0.45499414,  0.4499965 , -0.78362477, -0.42028612,
         0.33963433, -0.56401193,  0.45644552, -0.39769658, -0.00929202,
        -0.9610129 ,  0.40683702,  0.11690333, -0.21440822, -0.35790983,
        -0.72231764, -0.7358892 , -0.76221883, -0.44132066,  0.8106245 ,
        -0.01106247,  0.89837337,  0.15829656, -0.48148382, -0.07137716,
        -0.37613812, -0.36517394,  0.14016594, -0.63096076, -0.42230594,
         0.776719  , -0.3128489 ,  0.56846076,  0.11121392,  0.5724536 ,
        -0.46519637,  0.83556646, -0.3795832 , -0.24870592, -0.908497  ,
        -0.62978345, -0.23155476,  0.21914907,  0.5683378 ,  0.4083237 ,
         0.45315483, -0.06205622,  0.63755155,  0.97950894, -0.05648626,
        -0.16262522,  0.40750283, -0.9556285 , -0.42807412,  0.6204139 ,
         0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
         0.7004118 ,  0.09331694,  0.06100031,  0.10267377], dtype=float32),
 'county_sd_log__': array(0.45848975, dtype=float32),
 'floor_effect': array(0.43849692, dtype=float32),
 'county_floor_raw_zerosum__': array([ 0.68369645,  0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
         0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417  ,
         0.2297259 , -0.6770909 ,  0.46815294,  0.23881096,  0.41891697,
         0.6744159 , -0.8680713 ,  0.9475378 ,  0.36461526, -0.11404609,
        -0.2285417 , -0.52589136,  0.9446311 ,  0.5722908 ,  0.86332804,
        -0.42848182, -0.1902879 ,  0.95098126,  0.1297681 ,  0.51527834,
         0.7873266 , -0.5753548 ,  0.4216227 , -0.08488699, -0.3141113 ,
         0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
         0.48946136, -0.5839668 , -0.43648678, -0.20375745,  0.6134852 ,
         0.34660435,  0.2335634 , -0.30285057,  0.0682539 , -0.7834195 ,
        -0.54660916, -0.94278365,  0.5532979 , -0.76577055, -0.6490462 ,
        -0.492982  , -0.74057543, -0.7026031 ,  0.5502333 , -0.8355645 ,
        -0.16759473, -0.1209451 ,  0.5091448 , -0.76411086,  0.14865868,
        -0.71105725,  0.8838853 , -0.43318895, -0.8210448 ,  0.04136186,
        -0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
         0.03210128, -0.31010386, -0.5447676 , -0.79350907,  0.737303  ,
        -0.04805126, -0.7177033 , -0.77231514,  0.45744413], dtype=float32),
 'county_floor_sd_log__': array(-0.49195403, dtype=float32),
 'sigma_log__': array(0.16196507, dtype=float32)}

model.compile_logp()(q)  # array(-2813.50744694)
model.logp_dlogp_function(ravel_inputs=False)._pytensor_function(**q)[0]  # array(-inf)

If it was indeed working in older versions, I suspect the change is on the PyTensor side, not in PyMC.

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

try:
    idata = pm.sample(
        model=model,
        chains=1,
        tune=500,
        draws=500,
        progressbar=False,
        compute_convergence_checks=False,
        return_inferencedata=False,
        # nuts_sampler="numpyro",
        # compile_kwargs=dict(mode="NUMBA")
    )
except Exception as e:
    print(e)
    model.debug()

results in:

Initializing NUTS using jitter+adapt_diag...
Sequential sampling (1 chains in 1 job)
NUTS: [intercept, county_raw, county_sd, floor_effect, county_floor_raw, county_floor_sd, sigma]
Bad initial energy: SamplerWarning(kind=<WarningType.BAD_ENERGY: 8>, message='Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:\n[]\n.Try model.debug() to identify parametrization problems.', level='critical', step=0, exec_info=None, extra=None, divergence_point_source=None, divergence_point_dest=None, divergence_info=None)
point={'intercept': array(0., dtype=float32), 'county_raw_zerosum__': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32), 'county_sd_log__': array(0., dtype=float32), 'floor_effect': array(0., dtype=float32), 'county_floor_raw_zerosum__': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32), 'county_floor_sd_log__': array(0., dtype=float32), 'sigma_log__': array(0.4054651, dtype=float32)}

No problems found

@ricardoV94
Member

The problem is not the initial point, not even the jittered one. It is the point after the first NUTS step, which has already been mutated. It also only shows up with the logp_dlogp_function, which fuses the logp and dlogp. model.debug does not test that.

@ricardoV94
Member

This snippet (#7608 (comment)) fails in 5.18.0 too, so maybe the model just isn't numerically robust enough in float32. If so, it may be irrelevant to your issue?

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

import numpy as np
import pandas as pd
import pymc as pm
import pytensor as pt
from numpy import array, float32

pt.config.floatX = "float32"
pt.config.warn_float64 = "ignore"

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic("county_floor_effect", raw * sd, dims="county")

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal("log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id")

model_logp_fn = model.compile_logp()

q = {
    "intercept": array(-0.8458743, dtype=float32),
    "county_raw_zerosum__": array(
        [
            -0.8473211,
            0.97756225,
            0.5851473,
            -0.8831246,
            0.67874885,
            0.74649656,
            0.40699005,
            0.9938065,
            0.90805703,
            -0.55194354,
            0.7369223,
            -0.8693557,
            -0.18068689,
            0.34439757,
            0.8696054,
            -0.90608346,
            -0.19901727,
            0.18405294,
            0.85029787,
            0.69731015,
            -0.11369044,
            -0.45499414,
            0.4499965,
            -0.78362477,
            -0.42028612,
            0.33963433,
            -0.56401193,
            0.45644552,
            -0.39769658,
            -0.00929202,
            -0.9610129,
            0.40683702,
            0.11690333,
            -0.21440822,
            -0.35790983,
            -0.72231764,
            -0.7358892,
            -0.76221883,
            -0.44132066,
            0.8106245,
            -0.01106247,
            0.89837337,
            0.15829656,
            -0.48148382,
            -0.07137716,
            -0.37613812,
            -0.36517394,
            0.14016594,
            -0.63096076,
            -0.42230594,
            0.776719,
            -0.3128489,
            0.56846076,
            0.11121392,
            0.5724536,
            -0.46519637,
            0.83556646,
            -0.3795832,
            -0.24870592,
            -0.908497,
            -0.62978345,
            -0.23155476,
            0.21914907,
            0.5683378,
            0.4083237,
            0.45315483,
            -0.06205622,
            0.63755155,
            0.97950894,
            -0.05648626,
            -0.16262522,
            0.40750283,
            -0.9556285,
            -0.42807412,
            0.6204139,
            0.5904101,
            -0.7840837,
            -0.45694816,
            -0.6592951,
            -0.20405641,
            0.7004118,
            0.09331694,
            0.06100031,
            0.10267377,
        ],
        dtype=float32,
    ),
    "county_sd_log__": array(0.45848975, dtype=float32),
    "floor_effect": array(0.43849692, dtype=float32),
    "county_floor_raw_zerosum__": array(
        [
            0.68369645,
            0.6433043,
            -0.0029135,
            -0.49709547,
            -0.02687999,
            0.8271722,
            -0.10023019,
            -0.30813244,
            -0.4091758,
            -0.591417,
            0.2297259,
            -0.6770909,
            0.46815294,
            0.23881096,
            0.41891697,
            0.6744159,
            -0.8680713,
            0.9475378,
            0.36461526,
            -0.11404609,
            -0.2285417,
            -0.52589136,
            0.9446311,
            0.5722908,
            0.86332804,
            -0.42848182,
            -0.1902879,
            0.95098126,
            0.1297681,
            0.51527834,
            0.7873266,
            -0.5753548,
            0.4216227,
            -0.08488699,
            -0.3141113,
            0.6385347,
            -0.26448518,
            -0.0412051,
            -0.6691395,
            -0.8684154,
            0.48946136,
            -0.5839668,
            -0.43648678,
            -0.20375745,
            0.6134852,
            0.34660435,
            0.2335634,
            -0.30285057,
            0.0682539,
            -0.7834195,
            -0.54660916,
            -0.94278365,
            0.5532979,
            -0.76577055,
            -0.6490462,
            -0.492982,
            -0.74057543,
            -0.7026031,
            0.5502333,
            -0.8355645,
            -0.16759473,
            -0.1209451,
            0.5091448,
            -0.76411086,
            0.14865868,
            -0.71105725,
            0.8838853,
            -0.43318895,
            -0.8210448,
            0.04136186,
            -0.11312467,
            -0.92210877,
            -0.19974665,
            -0.87211764,
            -0.8225621,
            0.03210128,
            -0.31010386,
            -0.5447676,
            -0.79350907,
            0.737303,
            -0.04805126,
            -0.7177033,
            -0.77231514,
            0.45744413,
        ],
        dtype=float32,
    ),
    "county_floor_sd_log__": array(-0.49195403, dtype=float32),
    "sigma_log__": array(0.16196507, dtype=float32),
}

model.debug(q, verbose=True)

makes a nicer debug trace:

point={'intercept': array(-0.8458743, dtype=float32), 'county_raw_zerosum__': array([-0.8473211 ,  0.97756225,  0.5851473 , -0.8831246 ,  0.67874885,
        0.74649656,  0.40699005,  0.9938065 ,  0.90805703, -0.55194354,
        0.7369223 , -0.8693557 , -0.18068689,  0.34439757,  0.8696054 ,
       -0.90608346, -0.19901727,  0.18405294,  0.85029787,  0.69731015,
       -0.11369044, -0.45499414,  0.4499965 , -0.78362477, -0.42028612,
        0.33963433, -0.56401193,  0.45644552, -0.39769658, -0.00929202,
       -0.9610129 ,  0.40683702,  0.11690333, -0.21440822, -0.35790983,
       -0.72231764, -0.7358892 , -0.76221883, -0.44132066,  0.8106245 ,
       -0.01106247,  0.89837337,  0.15829656, -0.48148382, -0.07137716,
       -0.37613812, -0.36517394,  0.14016594, -0.63096076, -0.42230594,
        0.776719  , -0.3128489 ,  0.56846076,  0.11121392,  0.5724536 ,
       -0.46519637,  0.83556646, -0.3795832 , -0.24870592, -0.908497  ,
       -0.62978345, -0.23155476,  0.21914907,  0.5683378 ,  0.4083237 ,
        0.45315483, -0.06205622,  0.63755155,  0.97950894, -0.05648626,
       -0.16262522,  0.40750283, -0.9556285 , -0.42807412,  0.6204139 ,
        0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
        0.7004118 ,  0.09331694,  0.06100031,  0.10267377], dtype=float32), 'county_sd_log__': array(0.45848975, dtype=float32), 'floor_effect': array(0.43849692, dtype=float32), 'county_floor_raw_zerosum__': array([ 0.68369645,  0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
        0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417  ,
        0.2297259 , -0.6770909 ,  0.46815294,  0.23881096,  0.41891697,
        0.6744159 , -0.8680713 ,  0.9475378 ,  0.36461526, -0.11404609,
       -0.2285417 , -0.52589136,  0.9446311 ,  0.5722908 ,  0.86332804,
       -0.42848182, -0.1902879 ,  0.95098126,  0.1297681 ,  0.51527834,
        0.7873266 , -0.5753548 ,  0.4216227 , -0.08488699, -0.3141113 ,
        0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
        0.48946136, -0.5839668 , -0.43648678, -0.20375745,  0.6134852 ,
        0.34660435,  0.2335634 , -0.30285057,  0.0682539 , -0.7834195 ,
       -0.54660916, -0.94278365,  0.5532979 , -0.76577055, -0.6490462 ,
       -0.492982  , -0.74057543, -0.7026031 ,  0.5502333 , -0.8355645 ,
       -0.16759473, -0.1209451 ,  0.5091448 , -0.76411086,  0.14865868,
       -0.71105725,  0.8838853 , -0.43318895, -0.8210448 ,  0.04136186,
       -0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
        0.03210128, -0.31010386, -0.5447676 , -0.79350907,  0.737303  ,
       -0.04805126, -0.7177033 , -0.77231514,  0.45744413], dtype=float32), 'county_floor_sd_log__': array(-0.49195403, dtype=float32), 'sigma_log__': array(0.16196507, dtype=float32)}

The variable county_raw has the following parameters:
0: 1.0 [id A] <Scalar(float32, shape=())>
1: MakeVector{dtype='int64'} [id B] <Vector(int64, shape=(1,))>
 └─ county [id C] <Scalar(int64, shape=())>
2: [] [id D] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 1.0
1: [85]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-91.11339, dtype=float32), array(False)]
Outputs clients: [[output[0](county_raw_zerosum___logprob)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],
  File "/workspaces/mmm_v2/reproducible.py", line 231, in <module>
    model.debug(q, verbose=True)
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
    rv_logps = transformed_conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
    logprobs_jac.append(logp + log_jac_det)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
The variable county_floor_raw has the following parameters:
0: 1.0 [id A] <Scalar(float32, shape=())>
1: MakeVector{dtype='int64'} [id B] <Vector(int64, shape=(1,))>
 └─ county [id C] <Scalar(int64, shape=())>
2: [] [id D] <Vector(int64, shape=(0,))>
The parameters evaluate to:
0: 1.0
1: [85]
2: []
This does not respect one of the following constraints: mean(value, axis=n_zerosum_axes) = 0

mean(value, axis=n_zerosum_axes) = 0
Apply node that caused the error: Check{mean(value, axis=n_zerosum_axes) = 0}(Sum{axes=None}.0, Composite{...}.1)
Toposort index: 14
Inputs types: [TensorType(float32, shape=()), TensorType(bool, shape=())]
Inputs shapes: [(), ()]
Inputs strides: [(), ()]
Inputs values: [array(-90.89934, dtype=float32), array(False)]
Outputs clients: [[output[0](county_floor_raw_zerosum___logprob)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 84, in transformed_value_logprob
    logprobs = _logprob(rv_op, values, *rv_inputs, **kwargs)
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/multivariate.py", line 2841, in zerosumnormal_logp
    return check_parameters(out, *zerosums, msg="mean(value, axis=n_zerosum_axes) = 0")
  File "/usr/local/lib/python3.10/dist-packages/pymc/distributions/dist_math.py", line 74, in check_parameters
    return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/graph/op.py", line 293, in __call__
    node = self.make_node(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/raise_op.py", line 97, in make_node
    [value.type()],
  File "/workspaces/mmm_v2/reproducible.py", line 231, in <module>
    model.debug(q, verbose=True)
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 1886, in debug
    self.value_vars, self.logp(vars=rv, sum=False)[0], on_unused_input="ignore"
  File "/usr/local/lib/python3.10/dist-packages/pymc/model/core.py", line 696, in logp
    rv_logps = transformed_conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 595, in transformed_conditional_logp
    temp_logp_terms = conditional_logp(
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/basic.py", line 529, in conditional_logp
    node_logprobs = _logprob(
  File "/usr/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/usr/local/lib/python3.10/dist-packages/pymc/logprob/transform_value.py", line 122, in transformed_value_logprob
    logprobs_jac.append(logp + log_jac_det)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
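To repeat this debug call in double precision, set `pytensor.config.floatX` to `"float64"` before building the model, and cast the float32 point `q` as well. Below is a small NumPy-only helper for the cast. `to_float64` is a hypothetical name, and `q32` is a two-entry stand-in for `q`.

```python
import numpy as np


def to_float64(point):
    """Return a copy of a {name: array} point with every value cast to float64."""
    return {name: np.asarray(value, dtype=np.float64) for name, value in point.items()}


# Two entries taken from the point `q` above, as a small stand-in.
q32 = {
    "intercept": np.array(-0.8458743, dtype=np.float32),
    "sigma_log__": np.array(0.16196507, dtype=np.float32),
}
q64 = to_float64(q32)
print({name: value.dtype for name, value in q64.items()})
```

With PyMC installed, the intended use would be `model.debug(to_float64(q), verbose=True)` on a model built after switching `floatX`.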

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

If you do the above but change floatX to "float64" and the array dtypes to float64, you get:

point={'intercept': array(-0.8458743), 'county_raw_zerosum__': array([-0.8473211 ,  0.97756225,  0.5851473 , -0.8831246 ,  0.67874885,
        0.74649656,  0.40699005,  0.9938065 ,  0.90805703, -0.55194354,
        0.7369223 , -0.8693557 , -0.18068689,  0.34439757,  0.8696054 ,
       -0.90608346, -0.19901727,  0.18405294,  0.85029787,  0.69731015,
       -0.11369044, -0.45499414,  0.4499965 , -0.78362477, -0.42028612,
        0.33963433, -0.56401193,  0.45644552, -0.39769658, -0.00929202,
       -0.9610129 ,  0.40683702,  0.11690333, -0.21440822, -0.35790983,
       -0.72231764, -0.7358892 , -0.76221883, -0.44132066,  0.8106245 ,
       -0.01106247,  0.89837337,  0.15829656, -0.48148382, -0.07137716,
       -0.37613812, -0.36517394,  0.14016594, -0.63096076, -0.42230594,
        0.776719  , -0.3128489 ,  0.56846076,  0.11121392,  0.5724536 ,
       -0.46519637,  0.83556646, -0.3795832 , -0.24870592, -0.908497  ,
       -0.62978345, -0.23155476,  0.21914907,  0.5683378 ,  0.4083237 ,
        0.45315483, -0.06205622,  0.63755155,  0.97950894, -0.05648626,
       -0.16262522,  0.40750283, -0.9556285 , -0.42807412,  0.6204139 ,
        0.5904101 , -0.7840837 , -0.45694816, -0.6592951 , -0.20405641,
        0.7004118 ,  0.09331694,  0.06100031,  0.10267377]), 'county_sd_log__': array(0.45848975), 'floor_effect': array(0.43849692), 'county_floor_raw_zerosum__': array([ 0.68369645,  0.6433043 , -0.0029135 , -0.49709547, -0.02687999,
        0.8271722 , -0.10023019, -0.30813244, -0.4091758 , -0.591417  ,
        0.2297259 , -0.6770909 ,  0.46815294,  0.23881096,  0.41891697,
        0.6744159 , -0.8680713 ,  0.9475378 ,  0.36461526, -0.11404609,
       -0.2285417 , -0.52589136,  0.9446311 ,  0.5722908 ,  0.86332804,
       -0.42848182, -0.1902879 ,  0.95098126,  0.1297681 ,  0.51527834,
        0.7873266 , -0.5753548 ,  0.4216227 , -0.08488699, -0.3141113 ,
        0.6385347 , -0.26448518, -0.0412051 , -0.6691395 , -0.8684154 ,
        0.48946136, -0.5839668 , -0.43648678, -0.20375745,  0.6134852 ,
        0.34660435,  0.2335634 , -0.30285057,  0.0682539 , -0.7834195 ,
       -0.54660916, -0.94278365,  0.5532979 , -0.76577055, -0.6490462 ,
       -0.492982  , -0.74057543, -0.7026031 ,  0.5502333 , -0.8355645 ,
       -0.16759473, -0.1209451 ,  0.5091448 , -0.76411086,  0.14865868,
       -0.71105725,  0.8838853 , -0.43318895, -0.8210448 ,  0.04136186,
       -0.11312467, -0.92210877, -0.19974665, -0.87211764, -0.8225621 ,
        0.03210128, -0.31010386, -0.5447676 , -0.79350907,  0.737303  ,
       -0.04805126, -0.7177033 , -0.77231514,  0.45744413]), 'county_floor_sd_log__': array(-0.49195403), 'sigma_log__': array(0.16196507)}

No problems found

@ricardoV94
Member

Yes, it's a precision issue that sometimes shows up. debug is not the most useful tool here because it evaluates the logp of one variable at a time.

It also doesn't count as a bug just yet, because float32 is inherently less precise (that's the whole point of it). To mark it as a bug we need proof of a regression (it used to work when evaluated at the same point) or a justification for why the imprecision is unreasonable here.

@nataziel
Contributor Author

nataziel commented Dec 6, 2024

I'm out of time to check it tonight, but I'm 99% sure that the value returned by feeding the same bad point into the compiled logp function differs between 5.18 and 5.19.

I can test this by running the debugger to the first line of _init_jitter and calling model.compile_logp()(bad_point). I'll see if I can give it a go over the weekend.

@ricardoV94
Member

I tested the script I pasted above in 5.18.0 and it wasn't different. Both cases underflowed to -inf.
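As a generic illustration of what "underflowed to -inf" means (a NumPy sketch, not the actual ZeroSumNormal logp graph), a density evaluated on the natural scale before taking its log can flush to zero in float32 while staying finite in float64. The helper name log_of_tiny_density is hypothetical.

```python
import numpy as np


def log_of_tiny_density(log_density, dtype):
    """Round-trip a log density through exp() and log() in the given dtype.

    Evaluating the density on the natural scale first, then taking its log,
    mimics a graph that is not computed entirely in log space.
    """
    with np.errstate(divide="ignore", under="ignore"):
        density = np.exp(np.asarray(log_density, dtype=dtype))
        return np.log(density)


# exp(-200) is roughly 1.4e-87: representable in float64, but far below the
# smallest positive float32 (about 1.4e-45), so it flushes to 0 and log(0) = -inf.
print(log_of_tiny_density(-200.0, np.float64))
print(log_of_tiny_density(-200.0, np.float32))
```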

We also need to narrow the focus. Ignore model.debug and check_start_vals; they're too noisy and indirect. What matters for us (and for NUTS) is evaluating the logp_dlogp_function as in the snippet I shared and checking whether the first value is -inf.
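A minimal sketch of that check, assuming a logp_dlogp-style function that returns a (logp, gradient) pair. toy_logp_dlogp and is_usable_start are hypothetical stand-ins, not pymc functions; only the first element of the pair decides whether the starting point is usable.

```python
import numpy as np


def toy_logp_dlogp(q):
    """Hypothetical stand-in for a compiled logp_dlogp function.

    Returns (logp, gradient) for independent standard normals, mirroring the
    (value, gradient) pair that NUTS consumes.
    """
    q = np.asarray(q)
    logp = -0.5 * np.sum(q**2) - 0.5 * q.size * np.log(2 * np.pi)
    return logp, -q


def is_usable_start(logp_dlogp, q):
    # Only the first element (the logp) decides whether NUTS can start here.
    logp, _ = logp_dlogp(q)
    return bool(np.isfinite(logp))


print(is_usable_start(toy_logp_dlogp, np.zeros(3)))
print(is_usable_start(toy_logp_dlogp, np.array([np.inf, 0.0, 0.0])))
```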

@nataziel
Contributor Author

nataziel commented Dec 10, 2024

I think I've worked it out. Regardless of whether the compiled logp function or the generated points differ, _init_jitter would never have failed before 5.19.

Previously:

  def _init_jitter(model, initvals, seeds, jitter, jitter_max_retries):
    ...  # setup elided (builds ipfns, one initial-point function per chain)

    initial_points = []
    for ipfn, seed in zip(ipfns, seeds):
        rng = np.random.RandomState(seed)
        for i in range(jitter_max_retries + 1):
            point = ipfn(seed)
            if i < jitter_max_retries:
                try:
                    model.check_start_vals(point)
                except SamplingError:
                    # Retry with a new seed
                    seed = rng.randint(2**30, dtype=np.int64)
                else:
                    break
        initial_points.append(point)
    return initial_points

In pymc==5.18 it tries jitter_max_retries times and then appends the point to initial_points anyway, even if model.check_start_vals raises a SamplingError every single time.

From there it returns up the stack to sample_jax_nuts and continues just fine. I tried following the code down through _sample_numpyro_nuts, where it obtains the logp function via get_jaxified_logp, but I can't work out how to pass the initial points into that function to check whether they are feasible there, compared to the previous PyTensor-compiled logp function.

I've tried to follow the code through to where it initialises the numpyro NUTS kernel and MCMC sampler and runs the sampler with the initial points, but it goes deep into the JAX internals, and because it uses functools.partial on _single_chain_mcmc I can't step any deeper with the debugger. I think it's safe to assume that numpyro/JAX has no problem with the initial point passed to it for my model, because it continues just fine from here.

So I guess the problem for me is that pymc exits prematurely in 5.19 because it can't find a feasible starting point, even though numpyro may be able to sample just fine from the generated starting point.
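The behavioural difference can be sketched with two simplified loops. This is an illustration under stated assumptions, not pymc's code: SamplingError is a local stand-in class, the feasibility check is reduced to np.isfinite(logp(point)) instead of model.check_start_vals, and jittered/always_bad are hypothetical callables.

```python
import numpy as np


class SamplingError(RuntimeError):
    """Local stand-in for pymc's SamplingError, defined here to stay self-contained."""


def init_jitter_5_18_style(make_point, logp, seed, max_retries):
    # Check up to max_retries points, but keep the last one even if it is infeasible.
    rng = np.random.RandomState(seed)
    for i in range(max_retries + 1):
        point = make_point(seed)
        if i < max_retries:
            if np.isfinite(logp(point)):
                break
            seed = rng.randint(2**30, dtype=np.int64)  # retry with a new seed
    return point


def init_jitter_5_19_style(make_point, logp, seed, max_retries):
    # Same retries, but an infeasible final point is an error instead of a silent pass.
    rng = np.random.default_rng(seed)
    for i in range(max_retries + 1):
        point = make_point(seed)
        if np.isfinite(logp(point)):
            return point
        if i == max_retries:
            raise SamplingError(f"Initial evaluation failed at {point}")
        seed = rng.integers(2**30, dtype=np.int64)  # retry with a new seed


def jittered(seed):
    # Hypothetical jittered initial point: uniform noise in [-1, 1].
    return np.random.default_rng(seed).uniform(-1, 1, size=3)


def always_bad(point):
    # Hypothetical logp that always underflows.
    return -np.inf


print(init_jitter_5_18_style(jittered, always_bad, seed=1, max_retries=3))  # returns a point anyway
try:
    init_jitter_5_19_style(jittered, always_bad, seed=1, max_retries=3)
except SamplingError as err:
    print("5.19-style loop raised:", err)
```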

@ricardoV94
Member

ricardoV94 commented Dec 10, 2024

That sounds right @nataziel. Question: was numpyro actually sampling or did it just return 100% divergences for that chain? It usually does so silently.

If it was working fine, it could be that the initial point does not underflow with the JAX backend, in which case we should probably use the JAX logp function to evaluate the init jitter instead of the default model.compile_logp().

@nataziel
Contributor Author

I edited the code in _init_jitter on my local 5.19 installation to bypass the checks and it sampled all 4 chains with 0 divergences.

I don't quite understand what needs to be passed to the jaxified logp function; it seems the values returned by _get_batched_jittered_initial_points aren't in quite the right form. Does it need to be a dictionary instead of an array of arrays? If I understood the structure of what the jaxified logp function expects, I could write a separate _init_jitter_jax function or add some conditional checking within the current _init_jitter function.

@ricardoV94
Member

ricardoV94 commented Dec 10, 2024

I meant that we should pass the jaxified logp function to be used here instead of defaulting to model.compile_logp():

pymc/pymc/sampling/mcmc.py, lines 1373 to 1379 in a714b24:

    if logp_dlogp_func is None:
        model_logp_fn = model.compile_logp()
    else:
        def model_logp_fn(ip):
            q, _ = DictToArrayBijection.map(ip)
            return logp_dlogp_func([q], extra_vars={})[0]
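For context, the else branch flattens a point dict into a single vector before calling logp_dlogp_func. A rough NumPy sketch of that idea, assuming variables are concatenated in a fixed name order; dict_to_flat, make_dict_logp and toy_logp_dlogp are hypothetical helpers, not the DictToArrayBijection implementation.

```python
import numpy as np


def dict_to_flat(point, names):
    """Concatenate the raveled arrays of a point dict into one 1-D vector.

    Variables are taken in the order given by names, so the layout does not
    depend on the dict's insertion order.
    """
    return np.concatenate([np.ravel(point[name]) for name in names])


def make_dict_logp(flat_logp_dlogp, names):
    # Adapt f(q) -> (logp, grad) on a flat vector into g(point_dict) -> logp.
    def logp_from_point(point):
        q = dict_to_flat(point, names)
        return flat_logp_dlogp(q)[0]

    return logp_from_point


def toy_logp_dlogp(q):
    # Hypothetical stand-in: unnormalized standard-normal logp and its gradient.
    return -0.5 * np.sum(q**2), -q


names = ["intercept", "county_raw_zerosum__"]
point = {"county_raw_zerosum__": np.array([0.1, -0.2, 0.3]), "intercept": np.array(-0.8)}
print(dict_to_flat(point, names))
print(make_dict_logp(toy_logp_dlogp, names)(point))
```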

@nataziel
Contributor Author

I tried this:

    from pymc.sampling.jax import get_jaxified_logp
    model_logp_fn: Callable
    if logp_dlogp_func is None:
        model_logp_fn = get_jaxified_logp(model=model, negative_logp=False)
    else:

        def model_logp_fn(ip):
            q, _ = DictToArrayBijection.map(ip)
            return logp_dlogp_func([q], extra_vars={})[0]

and tried passing in this point to model_logp_fn:

{'mu_adstock_logodds__': array([-0.7406309 , -0.2676502 ,  0.72276473,  0.39711773, -1.4843715 ,
       -0.9303382 , -0.70388365,  0.12047255, -1.20291   , -1.0157162 ,
       -0.9788475 , -0.16161698, -0.5186577 ,  0.44525263,  0.50664103,
       -0.5502606 , -0.9929964 , -1.7073784 , -0.469832  , -0.7702299 ],
      dtype=float32), 'mu_lambda_log__': array([1.4899505 , 0.87828857, 0.33736384, 0.12322366, 1.6119224 ,
       2.1611521 , 0.19036102, 0.78274024, 1.9591975 , 0.86114645,
       1.00022   , 1.8908874 , 0.7735787 , 1.2600771 , 0.8721991 ,
       0.3919416 , 0.628017  , 0.5571408 , 2.2277155 , 0.701397  ],
      dtype=float32), 'mu_a': array(0.46919715, dtype=float32), 'z_a_state_zerosum__': array([ 0.96741545,  0.19968931,  0.55482584, -0.40800413, -0.783277  ,
        0.6665936 ], dtype=float32), 'z_a_age_zerosum__': array([ 0.97223747, -0.88604414, -0.60649115,  0.691295  , -0.17161931],
      dtype=float32), 'z_a_brand_zerosum__': array([-0.21275534], dtype=float32), 'z_a_cohort_zerosum__': array([-0.01881989], dtype=float32), 'roas_rv_log__': array([-1.8874438 , -1.3390365 , -2.1119297 , -0.30115628, -1.3759781 ,
       -0.9544507 ,  1.3654704 , -0.80472004,  1.3217607 , -1.6872417 ,
       -1.0485291 , -0.90976775, -1.1248429 , -1.5477487 , -0.30651912,
        0.51637214,  0.5301037 , -0.49982694,  1.757268  ,  1.03213   ],
      dtype=float32), 'z_b_state_zerosum__': array([[ 0.00446961, -0.85987175, -0.74123687, -0.46256822, -0.52106553,
         0.28104278],
       [ 0.05966987, -0.8486371 , -0.43098626, -0.12444586,  0.1801346 ,
        -0.37303272],
       [ 0.08682439,  0.53125477, -0.4337221 , -0.80694795, -0.41105202,
         0.8999604 ],
       [ 0.01910053,  0.2654662 , -0.07900505,  0.47407308,  0.7956779 ,
        -0.64507806],
       [ 0.10577084, -0.01806336, -0.4654986 ,  0.00858531,  0.4964019 ,
        -0.15452549],
       [-0.58119875, -0.533203  ,  0.8720117 , -0.9220113 , -0.08726341,
        -0.33014426],
       [ 0.5597552 ,  0.21657923, -0.6274215 , -0.00888674,  0.5606966 ,
        -0.6045255 ],
       [-0.49455065,  0.5478223 ,  0.9508188 , -0.7354254 ,  0.19366987,
         0.5816819 ],
       [-0.82646775, -0.5263257 ,  0.20099497,  0.88074464, -0.4345398 ,
         0.06769424],
       [ 0.26323393,  0.61359143,  0.01295813, -0.40680176, -0.3380146 ,
         0.3240754 ],
       [ 0.6390363 , -0.07461884,  0.17888807, -0.17294951, -0.8052904 ,
        -0.2960819 ],
       [-0.88565934,  0.13199767, -0.09011242, -0.57291055,  0.71278757,
        -0.06531783],
       [ 0.4843889 ,  0.9435816 ,  0.14761145,  0.2508237 , -0.02830961,
         0.40583134],
       [ 0.64028126, -0.09345473,  0.44015244, -0.18035695,  0.63984483,
         0.40306124],
       [ 0.85732955,  0.20738094, -0.77978706, -0.5081236 ,  0.25628823,
         0.9576838 ],
       [-0.43284124,  0.49378812, -0.7574774 , -0.9391033 ,  0.6099457 ,
        -0.83641356],
       [-0.6440243 ,  0.68688387,  0.4862265 ,  0.5263312 , -0.3289637 ,
         0.18450338],
       [-0.7553003 ,  0.8161998 , -0.88512534, -0.06603678,  0.24693777,
        -0.78690183],
       [ 0.1868632 , -0.21966957,  0.3369232 , -0.9996609 , -0.35670304,
         0.4821175 ],
       [ 0.3532054 , -0.5449791 , -0.00193312, -0.16562222,  0.51523185,
        -0.3292687 ]], dtype=float32), 'z_b_age_zerosum__': array([[ 0.4639124 ,  0.94854355, -0.32051557,  0.5695813 , -0.464497  ],
       [ 0.24478397, -0.38236296,  0.2442325 ,  0.6532343 , -0.5803767 ],
       [-0.43226814,  0.53636163,  0.3303343 , -0.42391777, -0.04977154],
       [-0.0449917 , -0.18323518,  0.09939765,  0.44787315,  0.21340491],
       [-0.48593655,  0.9875687 , -0.30522144, -0.24290714,  0.7979216 ],
       [-0.5306124 ,  0.43397802, -0.20600496,  0.8865641 ,  0.36890575],
       [ 0.16192637, -0.85434455,  0.14579847, -0.6387437 , -0.6332226 ],
       [-0.00474756, -0.4770844 , -0.80014896, -0.4984475 ,  0.08337943],
       [-0.38859433,  0.81244034, -0.15071645,  0.7578935 , -0.22230786],
       [ 0.21995819, -0.45969793, -0.05771023, -0.3626073 ,  0.8941617 ],
       [ 0.07187908, -0.25421968,  0.11764435, -0.01395176, -0.6094777 ],
       [-0.13571994,  0.9205862 , -0.6560107 ,  0.3603058 , -0.8363712 ],
       [ 0.78542286, -0.4191767 , -0.27891508,  0.2105725 ,  0.38422632],
       [-0.7178596 ,  0.49950433, -0.2591695 , -0.6500654 ,  0.78156734],
       [ 0.24467742, -0.09884497, -0.9059215 ,  0.69811964, -0.04913842],
       [-0.00168463, -0.1506732 , -0.8326015 ,  0.260028  , -0.318087  ],
       [ 0.40579167, -0.42483094,  0.15233344,  0.3852206 ,  0.26324713],
       [ 0.6354356 , -0.5003003 , -0.09142033,  0.80062026, -0.05573656],
       [-0.64219224,  0.75683314, -0.25206646,  0.9859022 , -0.7528035 ],
       [ 0.04350585, -0.21413967,  0.7432214 , -0.6038442 ,  0.219704  ]],
      dtype=float32), 'z_b_brand_zerosum__': array([[ 0.77682245],
       [-0.68846065],
       [ 0.15427889],
       [ 0.85911787],
       [ 0.7141093 ],
       [ 0.11126634],
       [ 0.8475281 ],
       [-0.018953  ],
       [-0.08016697],
       [-0.4876936 ],
       [ 0.92964715],
       [ 0.77537847],
       [ 0.23121522],
       [ 0.11847817],
       [-0.7639938 ],
       [-0.22309716],
       [-0.41808844],
       [ 0.23701279],
       [-0.04789526],
       [ 0.09624694]], dtype=float32), 'z_b_cohort_zerosum__': array([[-0.5643933 ],
       [ 0.9322058 ],
       [-0.3288698 ],
       [ 0.00913158],
       [-0.7385534 ],
       [-0.68776625],
       [-0.5444413 ],
       [ 0.42466304],
       [ 0.7728684 ],
       [-0.7700562 ],
       [-0.18284462],
       [ 0.6402906 ],
       [ 0.6988464 ],
       [ 0.8284381 ],
       [-0.7142591 ],
       [-0.12452263],
       [ 0.1419726 ],
       [ 0.20289187],
       [ 0.6634637 ],
       [ 0.77786607]], dtype=float32), 'mu_b_pos_con': array([-1.3094065, -1.7115856, -1.4250492, -2.2511344, -1.226164 ],
      dtype=float32), 'z_b_pos_con_state_zerosum__': array([[-0.39936054,  0.8805549 ,  0.97654635,  0.6494237 ,  0.5060455 ,
         0.8129397 ],
       [-0.7020755 ,  0.8573673 , -0.11473656, -0.81267875,  0.52816015,
         0.25964367],
       [-0.30995888, -0.909639  , -0.03129133, -0.83288676, -0.8827531 ,
         0.8252884 ],
       [-0.23201741,  0.5135355 , -0.8893724 , -0.00104977, -0.5592616 ,
         0.8351593 ],
       [-0.03384887,  0.25019094, -0.80081666,  0.45951134, -0.35681835,
        -0.8254566 ]], dtype=float32), 'z_b_pos_con_age_zerosum__': array([[-0.61742157, -0.09719887,  0.58104664, -0.92894936, -0.9795723 ],
       [-0.42654088,  0.64068526,  0.30092153,  0.24177577,  0.2526327 ],
       [-0.80097747,  0.9057477 ,  0.43585142, -0.85004056, -0.01753056],
       [-0.31914535,  0.14012223, -0.6530986 ,  0.7002828 ,  0.6456084 ],
       [-0.16960691,  0.26178694, -0.47111732, -0.3870159 ,  0.63950986]],
      dtype=float32), 'z_b_pos_con_brand_zerosum__': array([[-0.06096065],
       [ 0.35775375],
       [ 0.8893246 ],
       [ 0.14325647],
       [ 0.9434139 ]], dtype=float32), 'z_b_pos_con_cohort_zerosum__': array([[-0.954737  ],
       [-0.07808845],
       [ 0.56892526],
       [-0.37843582],
       [-0.66838884]], dtype=float32), 'mu_b_neg_con': array([-0.25086808], dtype=float32), 'z_b_neg_con_state_zerosum__': array([[-0.932691  , -0.56728685, -0.08727422,  0.06912095,  0.8635172 ,
        -0.2142895 ]], dtype=float32), 'z_b_neg_con_age_zerosum__': array([[ 0.90207416, -0.88017714, -0.83211   , -0.5490533 ,  0.6520192 ]],
      dtype=float32), 'z_b_neg_con_brand_zerosum__': array([[0.25788188]], dtype=float32), 'z_b_neg_con_cohort_zerosum__': array([[0.56738347]], dtype=float32), 'mu_b_lag': array([-2.4469194], dtype=float32), 'z_b_lag_state_zerosum__': array([[ 0.1700482 ,  0.5989304 ,  0.47253153, -0.75125474, -0.5838406 ,
         0.68338937]], dtype=float32), 'z_b_lag_age_zerosum__': array([[-0.33903205, -0.63544524,  0.03893599,  0.47806814, -0.04220857]],
      dtype=float32), 'z_b_lag_brand_zerosum__': array([[-0.08016187]], dtype=float32), 'z_b_lag_cohort_zerosum__': array([[-0.18082687]], dtype=float32), 'mu_b_fourier_year': array([ 0.8240624 ,  0.646912  , -0.452118  , -0.08140495, -0.96048754,
       -0.74027205, -0.9938018 ,  0.20062245, -0.28748137, -0.82254994,
       -0.4910437 , -0.322535  , -0.09896964, -0.30639052,  0.8899779 ,
        0.2462373 , -0.25278836,  0.16529965, -0.17628683,  0.96998924],
      dtype=float32), 'sd_y_log__': array(3.4784012, dtype=float32)}

and am getting this error:

model_logp_fn(point)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 703, in dtype
    dt = np.result_type(x)
TypeError: data type 'sd_y_log__' not understood

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5394, in array
    dtype = dtypes._lattice_result_type(*leaves)[0] if leaves else dtypes.float_
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 516, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 705, in dtype
    raise TypeError(f"Cannot determine dtype of {x}") from err
TypeError: Cannot determine dtype of sd_y_log__

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pymc/sampling/jax.py", line 154, in logp_fn_wrap
    return logp_fn(*x)[0]
  File "/tmp/tmpa31z87gr", line 3, in jax_funcified_fgraph
    tensor_variable = elemwise_fn(sd_y_log_)
  File "/usr/local/lib/python3.10/dist-packages/pytensor/link/jax/dispatch/elemwise.py", line 17, in elemwise_fn
    Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5592, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 5399, in array
    dtype = dtypes._lattice_result_type(*leaves)[0]
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in _lattice_result_type
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 713, in <genexpr>
    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 516, in _dtype_and_weaktype
    return dtype(value), any(value is typ for typ in _weak_types) or is_weakly_typed(value)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dtypes.py", line 707, in dtype
    raise TypeError(f"Value '{x}' with dtype {dt} is not a valid JAX array "
TypeError: Value 'sd_y_log__' with dtype <U10 is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

I tried changing the value of sd_y_log__ to array([3.4784012], dtype=float32), thinking the indexing or something similar was the problem, but that didn't help. It seems the string 'sd_y_log__' is being passed into jax_funcified_fgraph in place of its value (dtype <U10 is a 10-character Unicode string, literally sd_y_log__).

@nataziel
Contributor Author

Ah! The culprit is here:

pymc/pymc/sampling/jax.py, lines 153 to 156 in a714b24:

    def logp_fn_wrap(x):
        return logp_fn(*x)[0]

    return logp_fn_wrap

It's passing the keys instead of the values: unpacking a dict with *x yields its keys, not its values.
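The <U10 dtype in the error follows directly from Python's unpacking rules. A small sketch with a hypothetical positional_logp stand-in that only reports the dtype each positional input would be converted to:

```python
import numpy as np


def positional_logp(*inputs):
    # Hypothetical stand-in for the compiled function: report the dtype that
    # each positional input would be converted to.
    return [np.asarray(x).dtype for x in inputs]


point = {"sd_y_log__": np.array(3.4784012, dtype=np.float32)}

print(positional_logp(*point))           # unpacking the dict passes its keys (strings)
print(positional_logp(*point.values()))  # unpacking .values() passes the arrays
```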

So we can do:

    from pymc.sampling.jax import get_jaxified_logp

    model_logp_fn: Callable
    if logp_dlogp_func is None:
        model_logp_fn = get_jaxified_logp(model=model, negative_logp=False)
    else:

        def model_logp_fn(ip):
            q, _ = DictToArrayBijection.map(ip)
            return logp_dlogp_func([q], extra_vars={})[0]

    initial_points = []
    for ipfn, seed in zip(ipfns, seeds):
        rng = np.random.default_rng(seed)
        for i in range(jitter_max_retries + 1):
            point = ipfn(seed)
            # The jaxified logp unpacks its argument with *x, so pass the values
            # rather than the dict (this only suits the JAX branch above)
            point_logp = model_logp_fn(point.values())
            if not np.isfinite(point_logp):
                if i == jitter_max_retries:
                    # Print informative message on last attempted point
                    model.check_start_vals(point)
                # Retry with a new seed
                seed = rng.integers(2**30, dtype=np.int64)
            else:
                break
        initial_points.append(point)
    return initial_points

And with that, the initial point doesn't underflow.

@ricardoV94
Member

Instead of changing the logic inside init_nuts, pass the JAX logp function wrapped in a callable that handles the conversion.
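A minimal sketch of such a wrapper, assuming (as logp_fn_wrap above suggests) that the jaxified function takes one sequence of arrays ordered like the model's value variables. wrap_positional_logp and toy_sequence_logp are hypothetical; the actual fix in the PR may be structured differently.

```python
import numpy as np


def wrap_positional_logp(sequence_logp, value_var_names):
    """Turn f(sequence_of_arrays) into g(point_dict).

    Arrays are looked up by name, so each lands in the slot the compiled
    function expects regardless of the dict's insertion order.
    """

    def logp_from_point(point):
        return sequence_logp([point[name] for name in value_var_names])

    return logp_from_point


def toy_sequence_logp(x):
    # Hypothetical stand-in shaped like logp_fn_wrap(x): one sequence of arrays
    # in a fixed order. Here: Normal(0, sigma) logp of intercept, sigma = exp(sigma_log__).
    intercept, sigma_log = x
    sigma = np.exp(sigma_log)
    return float(-0.5 * (intercept / sigma) ** 2 - sigma_log - 0.5 * np.log(2 * np.pi))


logp = wrap_positional_logp(toy_sequence_logp, ["intercept", "sigma_log__"])
point = {"sigma_log__": np.array(0.0), "intercept": np.array(1.0)}  # keys deliberately out of order
print(logp(point))
```

Looking values up by name rather than relying on dict insertion order keeps the positional slots aligned even if a point dict is built in a different order than model.value_vars.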

@nataziel
Contributor Author

Yep, sounds good. Working on a PR now.
