Skip to content

Commit

Permalink
Merge branch 'main' of github.com:bambinos/bambi into hsgp-multivaria…
Browse files Browse the repository at this point in the history
…te-responses
  • Loading branch information
tomicapretto committed Nov 10, 2024
2 parents 59d7059 + 7a18fb9 commit 0a78220
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 17 deletions.
43 changes: 39 additions & 4 deletions bambi/families/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ def transform_coords(self, model, mean):
return mean

# Draw posterior predictive samples by delegating to the parent class and
# handing it the per-row number of trials under the keyword `n`.
#
# NOTE(review): this span appears to be a rendered diff with the removed and
# added lines interleaved (both `n = ...` and `trials = ...`, and two
# consecutive `return` statements) -- confirm which lines are current before
# relying on it.
def posterior_predictive(self, model, posterior, **kwargs):
n = model.response_component.term.data.sum(1).astype(int)
# `data` is read with a plain subscript, so a call that does not pass the
# `data` keyword raises KeyError here.
data = kwargs["data"]
if data is None:
# No new data: use the response stored on the model.
# `y` is not read again in this method on this branch.
y = model.response_component.term.data
trials = model.response_component.term.data.sum(1).astype(int)
else:
# New data: re-evaluate the response on it; the row sums give the trials.
y = response_evaluate_new_data(model, data).astype(int)
trials = y.sum(1).astype(int)

# Prepend 'draw' and 'chain' dimensions (two leading axes of length 1)
trials = trials[np.newaxis, np.newaxis, :]
# presumably tells the parent implementation to leave `n` un-reshaped -- confirm
dont_reshape = ["n"]
return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape)
return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape)

def log_likelihood(self, model, posterior, data, **kwargs):
if data is None:
Expand Down Expand Up @@ -91,9 +100,35 @@ class DirichletMultinomial(MultivariateFamily):
SUPPORTED_LINKS = {"a": ["log"]}

# Draw posterior predictive samples by delegating to the parent class and
# handing it the per-row number of trials under the keyword `n`.
#
# NOTE(review): this span appears to be a rendered diff with the removed and
# added lines interleaved (both `n = ...` and `trials = ...`, and two
# consecutive `return` statements) -- confirm which lines are current before
# relying on it.
def posterior_predictive(self, model, posterior, **kwargs):
n = model.response_component.term.data.sum(1).astype(int)
# `data` is read with a plain subscript, so a call that does not pass the
# `data` keyword raises KeyError here.
data = kwargs["data"]
if data is None:
# No new data: use the response stored on the model.
# `y` is not read again in this method on this branch.
y = model.response_component.term.data
trials = model.response_component.term.data.sum(1).astype(int)
else:
# New data: re-evaluate the response on it; the row sums give the trials.
y = response_evaluate_new_data(model, data).astype(int)
trials = y.sum(1).astype(int)

# Prepend 'draw' and 'chain' dimensions (two leading axes of length 1)
trials = trials[np.newaxis, np.newaxis, :]
# presumably tells the parent implementation to leave `n` un-reshaped -- confirm
dont_reshape = ["n"]
return super().posterior_predictive(model, posterior, n=n, dont_reshape=dont_reshape)
return super().posterior_predictive(model, posterior, n=trials, dont_reshape=dont_reshape)

def log_likelihood(self, model, posterior, data, **kwargs):
    """Evaluate the log-likelihood through the parent implementation.

    The response is taken from the model when ``data`` is None; otherwise it
    is re-evaluated on ``data`` and cast to integers. The number of trials is
    the integer row sum of that response. Both arrays are handed to the
    parent with 'draw' and 'chain' dimensions prepended, and the parent is
    always called with ``data=None`` because ``y`` and ``n`` are supplied
    explicitly.
    """
    if data is not None:
        observed = response_evaluate_new_data(model, data).astype(int)
    else:
        observed = model.response_component.term.data

    # Row totals are the number of trials for each observation
    totals = observed.sum(1).astype(int)

    return super().log_likelihood(
        model,
        posterior,
        data=None,
        # Prepend 'draw' and 'chain' dimensions
        y=observed[np.newaxis, np.newaxis, :],
        n=totals[np.newaxis, np.newaxis, :],
        dont_reshape=["n"],
        **kwargs,
    )

def get_coords(self, response):
name = get_aliased_name(response) + "_dim"
Expand Down
22 changes: 20 additions & 2 deletions bambi/families/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,17 @@ def transform_backend_eta(eta, kwargs):
# Shift the linear predictor by the thresholds.
# shape(threshold) = (K, )
# shape(eta) = (n, )
# shape(threshold - shape_padright(eta)) = (n, K)

threshold = kwargs["threshold"]
# NOTE(review): this assignment is immediately overwritten by the if/else below;
# it looks like the pre-change line of a rendered diff -- confirm.
eta_shifted = threshold - pt.shape_padright(eta)

# `eta == 0` is taken to mean the model does not have any predictors. In that
# case a vector of zeros with one entry per observation stands in for `eta`.
# Inference can be slower, as this can potentially build a larger object,
# but it is needed for consistency with other parts of the codebase.
# NOTE(review): presumably `eta` is the plain scalar 0 here, never a tensor --
# confirm that `==` behaves as intended for every type `eta` can take.
if eta == 0:
eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"])))
else:
eta_shifted = threshold - pt.shape_padright(eta)

return eta_shifted

@staticmethod
Expand Down Expand Up @@ -393,8 +402,17 @@ def transform_backend_eta(eta, kwargs):
# Shift the linear predictor by the thresholds.
# shape(threshold) = (K, )
# shape(eta) = (n, )
# shape(threshold - shape_padright(eta)) = (n, K)

threshold = kwargs["threshold"]
# NOTE(review): this assignment is immediately overwritten by the if/else below;
# it looks like the pre-change line of a rendered diff -- confirm.
eta_shifted = threshold - pt.shape_padright(eta)

# `eta == 0` is taken to mean the model does not have any predictors. In that
# case a vector of zeros with one entry per observation stands in for `eta`.
# Inference can be slower, as this can potentially build a larger object,
# but it is needed for consistency with other parts of the codebase.
# NOTE(review): presumably `eta` is the plain scalar 0 here, never a tensor --
# confirm that `==` behaves as intended for every type `eta` can take.
if eta == 0:
eta_shifted = threshold - pt.shape_padright(pt.zeros(len(kwargs["observed"])))
else:
eta_shifted = threshold - pt.shape_padright(eta)

return eta_shifted

@staticmethod
Expand Down
12 changes: 10 additions & 2 deletions bambi/priors/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,19 @@ def scale_response(self):
# Replace the auxiliary-parameter prior with HalfStudentT(nu=4, sigma=response_std)
# when that parameter is a constant component whose prior has `auto_scale` set.
# Only Gaussian/StudentT (`sigma`) and VonMises (`kappa`) are handled;
# here we would add cases for other families if we wanted.
# NOTE(review): each single-line `if isinstance(...) and ....auto_scale:` below is
# directly followed by a multi-line `if (...)` with the same test plus a `hasattr`
# guard; this looks like the old/new pair of a rendered diff -- confirm.
if isinstance(self.model.family, (Gaussian, StudentT)):
sigma = self.model.components["sigma"]
if isinstance(sigma, ConstantComponent) and sigma.prior.auto_scale:
if (
isinstance(sigma, ConstantComponent)
and hasattr(sigma.prior, "auto_scale") # not available when `.prior` is a scalar
and sigma.prior.auto_scale
):
sigma.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std)
elif isinstance(self.model.family, VonMises):
kappa = self.model.components["kappa"]
if isinstance(kappa, ConstantComponent) and kappa.prior.auto_scale:
if (
isinstance(kappa, ConstantComponent)
and hasattr(kappa.prior, "auto_scale") # not available when `.prior` is a scalar
and kappa.prior.auto_scale
):
kappa.prior = Prior("HalfStudentT", nu=4, sigma=self.response_std)

def scale_intercept(self, term):
Expand Down
22 changes: 13 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"formulae>=0.5.3",
"graphviz",
"pandas>=1.0.0",
"pymc>=5.16.1",
"pymc>=5.18.0",
]

[project.optional-dependencies]
Expand All @@ -38,8 +38,12 @@ dev = [
"seaborn>=0.9.0",
]

# TODO: Unpin this before making a release
jax = [
"bayeux-ml>=0.1.13",
"bayeux-ml==0.1.14",
"blackjax==1.2.3",
"jax<=0.4.33",
"jaxlib<=0.4.33",
]

[project.urls]
Expand All @@ -50,14 +54,14 @@ changelog = "https://github.com/bambinos/bambi/blob/main/docs/CHANGELOG.md"

[tool.setuptools]
packages = [
"bambi",
"bambi.backend",
"bambi.data",
"bambi.defaults",
"bambi",
"bambi.backend",
"bambi.data",
"bambi.defaults",
"bambi.families",
"bambi.interpret",
"bambi.priors",
"bambi.terms",
"bambi.interpret",
"bambi.priors",
"bambi.terms",
]

[tool.black]
Expand Down
8 changes: 8 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,10 @@ def test_intercept_only(self, multinomial_data):
idata = self.predict_oos(model, idata, data=model.data)
self.assert_posterior_predictive(model, idata)

# Out of sample with different number of rows, see issue #845
idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211))
self.assert_posterior_predictive(model, idata)

def test_numerical_predictors(self, multinomial_data):
model = bmb.Model(
"c(y1, y2, y3, y4) ~ treat + carry", multinomial_data, family="multinomial"
Expand Down Expand Up @@ -1242,6 +1246,10 @@ def test_intercept_only(self, multinomial_data):
idata = self.predict_oos(model, idata, model.data)
self.assert_posterior_predictive(model, idata)

# Out of sample with different number of rows, see issue #845
idata = self.predict_oos(model, idata, data=model.data.sample(frac=0.8, random_state=1211))
self.assert_posterior_predictive(model, idata)

def test_predictor(self, multinomial_data):
model = bmb.Model(
"c(y1, y2, y3, y4) ~ 0 + treat", multinomial_data, family="dirichlet_multinomial"
Expand Down

0 comments on commit 0a78220

Please sign in to comment.