@@ -92,7 +92,7 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
92
92
93
93
"""
94
94
95
- from typing import cast
95
+ from typing import Literal , cast
96
96
97
97
import numpy as np
98
98
import numpy .typing as npt
@@ -270,6 +270,183 @@ def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
270
270
}
271
271
272
272
273
+ class HalfGaussianBasis (Basis ):
274
+ R"""One-sided Gaussian basis transformation.
275
+
276
+ .. plot::
277
+ :context: close-figs
278
+
279
+ import matplotlib.pyplot as plt
280
+ from pymc_marketing.mmm.events import HalfGaussianBasis
281
+ from pymc_extras.prior import Prior
282
+ half_gaussian = HalfGaussianBasis(
283
+ priors={
284
+ "sigma": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
285
+ }
286
+ )
287
+ coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
288
+ prior = half_gaussian.sample_prior(coords=coords)
289
+ curve = half_gaussian.sample_curve(prior)
290
+ fig, axes = half_gaussian.plot_curve(
291
+ curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
292
+ )
293
+ for ax in axes:
294
+ ax.set_xlabel("")
295
+ plt.show()
296
+
297
+ Parameters
298
+ ----------
299
+ mode : Literal["after", "before"]
300
+ Whether the basis is located before or after the event.
301
+ include_event : bool
302
+ Whether to include the event days in the basis.
303
+ priors : dict[str, Prior]
304
+ Prior for the sigma parameter.
305
+ prefix : str
306
+ Prefix for the parameter names.
307
+ """
308
+
309
+ lookup_name = "half_gaussian"
310
+
311
+ def __init__ (
312
+ self ,
313
+ mode : Literal ["after" , "before" ] = "after" ,
314
+ include_event : bool = True ,
315
+ ** kwargs ,
316
+ ):
317
+ super ().__init__ (** kwargs )
318
+ self .mode = mode
319
+ self .include_event = include_event
320
+
321
+ def function (self , x : pt .TensorLike , sigma : pt .TensorLike ) -> TensorVariable :
322
+ """One-sided Gaussian bump function."""
323
+ rv = pm .Normal .dist (mu = 0.0 , sigma = sigma )
324
+ out = pm .math .exp (pm .logp (rv , x ))
325
+ # Sign determines if the zeroing happens after or before the event.
326
+ sign = 1 if self .mode == "after" else - 1
327
+ # Build boolean mask(s) in x's shape and broadcast to out's shape.
328
+ pre_mask = sign * x < 0
329
+ if not self .include_event :
330
+ pre_mask = pm .math .or_ (pre_mask , sign * x == 0 )
331
+
332
+ # Ensure mask matches output shape for elementwise switch
333
+ pre_mask = pt .broadcast_to (pre_mask , out .shape )
334
+
335
+ return pt .switch (pre_mask , 0 , out )
336
+
337
+ def to_dict (self ) -> dict :
338
+ """Convert the half Gaussian basis to a dictionary."""
339
+ return {
340
+ ** super ().to_dict (),
341
+ "mode" : self .mode ,
342
+ "include_event" : self .include_event ,
343
+ }
344
+
345
+ default_priors = {
346
+ "sigma" : Prior ("Gamma" , mu = 7 , sigma = 1 ),
347
+ }
348
+
349
+
350
+ class AsymmetricGaussianBasis (Basis ):
351
+ R"""Asymmetric Gaussian bump basis transformation.
352
+
353
+ Allows different widths (sigma_before, sigma_after) and amplitudes (a_after)
354
+ after the event.
355
+
356
+ .. plot::
357
+ :context: close-figs
358
+
359
+ import matplotlib.pyplot as plt
360
+ from pymc_marketing.mmm.events import AsymmetricGaussianBasis
361
+ from pymc_extras.prior import Prior
362
+ asy_gaussian = AsymmetricGaussianBasis(
363
+ priors={
364
+ "sigma_before": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
365
+ "a_after": Prior("Normal", mu=[-.75, .5], sigma=.2, dims="event"),
366
+ }
367
+ )
368
+ coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
369
+ prior = asy_gaussian.sample_prior(coords=coords)
370
+ curve = asy_gaussian.sample_curve(prior)
371
+ fig, axes = asy_gaussian.plot_curve(
372
+ curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
373
+ )
374
+ for ax in axes:
375
+ ax.set_xlabel("")
376
+ plt.show()
377
+
378
+ Parameters
379
+ ----------
380
+ event_in : Literal["before", "after", "exclude"]
381
+ Whether to include the event in the before or after part of the basis,
382
+ or leave it out entirely. Default is "after".
383
+ priors : dict[str, Prior]
384
+ Prior for the sigma_before, sigma_after, a_before, and a_after parameters.
385
+ prefix : str
386
+ Prefix for the parameters.
387
+ """
388
+
389
+ lookup_name = "asymmetric_gaussian"
390
+
391
+ def __init__ (
392
+ self ,
393
+ event_in : Literal ["before" , "after" , "exclude" ] = "after" ,
394
+ ** kwargs ,
395
+ ):
396
+ super ().__init__ (** kwargs )
397
+ self .event_in = event_in
398
+
399
+ def function (
400
+ self ,
401
+ x : pt .TensorLike ,
402
+ sigma_before : pt .TensorLike ,
403
+ sigma_after : pt .TensorLike ,
404
+ a_after : pt .TensorLike ,
405
+ ) -> pt .TensorVariable :
406
+ """Asymmetric Gaussian bump function."""
407
+ match self .event_in :
408
+ case "before" :
409
+ indicator_before = pt .cast (x <= 0 , "float32" )
410
+ indicator_after = pt .cast (x > 0 , "float32" )
411
+ case "after" :
412
+ indicator_before = pt .cast (x < 0 , "float32" )
413
+ indicator_after = pt .cast (x >= 0 , "float32" )
414
+ case "exclude" :
415
+ indicator_before = pt .cast (x < 0 , "float32" )
416
+ indicator_after = pt .cast (x > 0 , "float32" )
417
+ case _:
418
+ raise ValueError (f"Invalid event_in: { self .event_in } " )
419
+
420
+ rv_before = pm .Normal .dist (mu = 0.0 , sigma = sigma_before )
421
+ rv_after = pm .Normal .dist (mu = 0.0 , sigma = sigma_after )
422
+
423
+ y_before = pt .switch (
424
+ indicator_before ,
425
+ pm .math .exp (pm .logp (rv_before , x )),
426
+ 0 ,
427
+ )
428
+ y_after = pt .switch (
429
+ indicator_after ,
430
+ pm .math .exp (pm .logp (rv_after , x )) * a_after ,
431
+ 0 ,
432
+ )
433
+
434
+ return y_before + y_after
435
+
436
+ def to_dict (self ) -> dict :
437
+ """Convert the asymmetric Gaussian basis to a dictionary."""
438
+ return {
439
+ ** super ().to_dict (),
440
+ "event_in" : self .event_in ,
441
+ }
442
+
443
+ default_priors = {
444
+ "sigma_before" : Prior ("Gamma" , mu = 3 , sigma = 1 ),
445
+ "sigma_after" : Prior ("Gamma" , mu = 7 , sigma = 2 ),
446
+ "a_after" : Prior ("Normal" , mu = 1 , sigma = 0.5 ),
447
+ }
448
+
449
+
273
450
def days_from_reference (
274
451
dates : pd .Series | pd .DatetimeIndex ,
275
452
reference_date : str | pd .Timestamp ,
0 commit comments