Skip to content

Commit 88a864b

Browse files
TeemuSailynojacetagostinitimbo112711juanitorduz
authored
Add more Basis implementations for use with EventEffect (#1911)
* Add HalfGaussianBasis, and AsymmetricGaussianBasis classes for one-sided and asymmetric Gaussian EventEffects in the Multidimensional MMM module. * Added tests for the new classes * Add example plots of HalfGaussian and AsymmetricGaussian * Updates mmm_compenents notebook with new AsymmetricGaussianBasis * Normalize HalfGaussianBasis and AsymmetricGaussianBasis --------- Co-authored-by: Carlos Trujillo <[email protected]> Co-authored-by: tim mcwilliams <[email protected]> Co-authored-by: Juan Orduz <[email protected]>
1 parent 55d134e commit 88a864b

File tree

4 files changed

+1802
-734
lines changed

4 files changed

+1802
-734
lines changed

docs/source/notebooks/mmm/mmm_components.ipynb

Lines changed: 1206 additions & 732 deletions
Large diffs are not rendered by default.

pymc_marketing/mmm/events.py

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
9292
9393
"""
9494

95-
from typing import cast
95+
from typing import Literal, cast
9696

9797
import numpy as np
9898
import numpy.typing as npt
@@ -270,6 +270,183 @@ def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
270270
}
271271

272272

273+
class HalfGaussianBasis(Basis):
274+
R"""One-sided Gaussian basis transformation.
275+
276+
.. plot::
277+
:context: close-figs
278+
279+
import matplotlib.pyplot as plt
280+
from pymc_marketing.mmm.events import HalfGaussianBasis
281+
from pymc_extras.prior import Prior
282+
half_gaussian = HalfGaussianBasis(
283+
priors={
284+
"sigma": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
285+
}
286+
)
287+
coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
288+
prior = half_gaussian.sample_prior(coords=coords)
289+
curve = half_gaussian.sample_curve(prior)
290+
fig, axes = half_gaussian.plot_curve(
291+
curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
292+
)
293+
for ax in axes:
294+
ax.set_xlabel("")
295+
plt.show()
296+
297+
Parameters
298+
----------
299+
mode : Literal["after", "before"]
300+
Whether the basis is located before or after the event.
301+
include_event : bool
302+
Whether to include the event days in the basis.
303+
priors : dict[str, Prior]
304+
Prior for the sigma parameter.
305+
prefix : str
306+
Prefix for the parameter names.
307+
"""
308+
309+
lookup_name = "half_gaussian"
310+
311+
def __init__(
312+
self,
313+
mode: Literal["after", "before"] = "after",
314+
include_event: bool = True,
315+
**kwargs,
316+
):
317+
super().__init__(**kwargs)
318+
self.mode = mode
319+
self.include_event = include_event
320+
321+
def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
322+
"""One-sided Gaussian bump function."""
323+
rv = pm.Normal.dist(mu=0.0, sigma=sigma)
324+
out = pm.math.exp(pm.logp(rv, x))
325+
# Sign determines if the zeroing happens after or before the event.
326+
sign = 1 if self.mode == "after" else -1
327+
# Build boolean mask(s) in x's shape and broadcast to out's shape.
328+
pre_mask = sign * x < 0
329+
if not self.include_event:
330+
pre_mask = pm.math.or_(pre_mask, sign * x == 0)
331+
332+
# Ensure mask matches output shape for elementwise switch
333+
pre_mask = pt.broadcast_to(pre_mask, out.shape)
334+
335+
return pt.switch(pre_mask, 0, out)
336+
337+
def to_dict(self) -> dict:
338+
"""Convert the half Gaussian basis to a dictionary."""
339+
return {
340+
**super().to_dict(),
341+
"mode": self.mode,
342+
"include_event": self.include_event,
343+
}
344+
345+
default_priors = {
346+
"sigma": Prior("Gamma", mu=7, sigma=1),
347+
}
348+
349+
350+
class AsymmetricGaussianBasis(Basis):
351+
R"""Asymmetric Gaussian bump basis transformation.
352+
353+
Allows different widths (sigma_before, sigma_after) and amplitudes (a_after)
354+
after the event.
355+
356+
.. plot::
357+
:context: close-figs
358+
359+
import matplotlib.pyplot as plt
360+
from pymc_marketing.mmm.events import AsymmetricGaussianBasis
361+
from pymc_extras.prior import Prior
362+
asy_gaussian = AsymmetricGaussianBasis(
363+
priors={
364+
"sigma_before": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
365+
"a_after": Prior("Normal", mu=[-.75, .5], sigma=.2, dims="event"),
366+
}
367+
)
368+
coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
369+
prior = asy_gaussian.sample_prior(coords=coords)
370+
curve = asy_gaussian.sample_curve(prior)
371+
fig, axes = asy_gaussian.plot_curve(
372+
curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
373+
)
374+
for ax in axes:
375+
ax.set_xlabel("")
376+
plt.show()
377+
378+
Parameters
379+
----------
380+
event_in : Literal["before", "after", "exclude"]
381+
Whether to include the event in the before or after part of the basis,
382+
or leave it out entirely. Default is "after".
383+
priors : dict[str, Prior]
384+
Prior for the sigma_before, sigma_after, a_before, and a_after parameters.
385+
prefix : str
386+
Prefix for the parameters.
387+
"""
388+
389+
lookup_name = "asymmetric_gaussian"
390+
391+
def __init__(
392+
self,
393+
event_in: Literal["before", "after", "exclude"] = "after",
394+
**kwargs,
395+
):
396+
super().__init__(**kwargs)
397+
self.event_in = event_in
398+
399+
def function(
400+
self,
401+
x: pt.TensorLike,
402+
sigma_before: pt.TensorLike,
403+
sigma_after: pt.TensorLike,
404+
a_after: pt.TensorLike,
405+
) -> pt.TensorVariable:
406+
"""Asymmetric Gaussian bump function."""
407+
match self.event_in:
408+
case "before":
409+
indicator_before = pt.cast(x <= 0, "float32")
410+
indicator_after = pt.cast(x > 0, "float32")
411+
case "after":
412+
indicator_before = pt.cast(x < 0, "float32")
413+
indicator_after = pt.cast(x >= 0, "float32")
414+
case "exclude":
415+
indicator_before = pt.cast(x < 0, "float32")
416+
indicator_after = pt.cast(x > 0, "float32")
417+
case _:
418+
raise ValueError(f"Invalid event_in: {self.event_in}")
419+
420+
rv_before = pm.Normal.dist(mu=0.0, sigma=sigma_before)
421+
rv_after = pm.Normal.dist(mu=0.0, sigma=sigma_after)
422+
423+
y_before = pt.switch(
424+
indicator_before,
425+
pm.math.exp(pm.logp(rv_before, x)),
426+
0,
427+
)
428+
y_after = pt.switch(
429+
indicator_after,
430+
pm.math.exp(pm.logp(rv_after, x)) * a_after,
431+
0,
432+
)
433+
434+
return y_before + y_after
435+
436+
def to_dict(self) -> dict:
437+
"""Convert the asymmetric Gaussian basis to a dictionary."""
438+
return {
439+
**super().to_dict(),
440+
"event_in": self.event_in,
441+
}
442+
443+
default_priors = {
444+
"sigma_before": Prior("Gamma", mu=3, sigma=1),
445+
"sigma_after": Prior("Gamma", mu=7, sigma=2),
446+
"a_after": Prior("Normal", mu=1, sigma=0.5),
447+
}
448+
449+
273450
def days_from_reference(
274451
dates: pd.Series | pd.DatetimeIndex,
275452
reference_date: str | pd.Timestamp,

0 commit comments

Comments
 (0)