Skip to content

Commit 55d134e

Browse files
authored
Normalize gaussian basis (#1912)
* Normalize gaussian basis * Adapt test to normalization
1 parent 53fa538 commit 55d134e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pymc_marketing/mmm/events.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,9 @@ class GaussianBasis(Basis):
261261

262262
def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
263263
"""Gaussian bump function."""
264-
return pm.math.exp(-0.5 * (x / sigma) ** 2)
264+
rv = pm.Normal.dist(mu=0.0, sigma=sigma)
265+
out = pm.math.exp(pm.logp(rv, x))
266+
return out
265267

266268
default_priors = {
267269
"sigma": Prior("Gamma", mu=7, sigma=1),

tests/mmm/test_events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def test_gaussian_basis_function():
171171
sigma = np.array([1.0])
172172

173173
result = gaussian.function(x, sigma).eval()
174-
expected = np.exp(-0.5 * (x / sigma) ** 2)
174+
expected = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * (x / sigma) ** 2)
175175

176176
np.testing.assert_array_almost_equal(result, expected)
177177

0 commit comments

Comments
 (0)