|
| 1 | +"""Definition of Geometric Adstock Effect class.""" |
| 2 | + |
| 3 | +from typing import Dict |
| 4 | + |
| 5 | +import jax |
| 6 | +import jax.numpy as jnp |
| 7 | +import numpyro |
| 8 | +from numpyro import distributions as dist |
| 9 | + |
| 10 | +from prophetverse.effects.base import BaseEffect |
| 11 | + |
| 12 | +__all__ = ["GeometricAdstockEffect"] |
| 13 | + |
| 14 | + |
| 15 | +class GeometricAdstockEffect(BaseEffect): |
| 16 | + """Represents a Geometric Adstock effect in a time series model. |
| 17 | +
|
| 18 | + Parameters |
| 19 | + ---------- |
| 20 | + decay_prior : Distribution, optional |
| 21 | + Prior distribution for the decay parameter (controls the rate of decay). |
| 22 | + rase_error_if_fh_changes : bool, optional |
| 23 | + Whether to raise an error if the forecasting horizon changes during predict |
| 24 | + """ |
| 25 | + |
| 26 | + _tags = { |
| 27 | + "supports_multivariate": False, |
| 28 | + "skip_predict_if_no_match": True, |
| 29 | + "filter_indexes_with_forecating_horizon_at_transform": True, |
| 30 | + } |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + decay_prior: dist.Distribution = None, |
| 35 | + raise_error_if_fh_changes: bool = True, |
| 36 | + ): |
| 37 | + self.decay_prior = decay_prior or dist.Beta( |
| 38 | + 2, 2 |
| 39 | + ) # Default Beta distribution for decay rate. |
| 40 | + self.raise_errror_if_fh_changes = raise_error_if_fh_changes |
| 41 | + super().__init__() |
| 42 | + |
| 43 | + self._min_date = None |
| 44 | + |
| 45 | + def _transform(self, X, fh): |
| 46 | + """Transform the dataframe and horizon to array. |
| 47 | +
|
| 48 | + Parameters |
| 49 | + ---------- |
| 50 | + X : pd.DataFrame |
| 51 | + dataframe with exogenous variables |
| 52 | + fh : pd.Index |
| 53 | + Forecast horizon |
| 54 | +
|
| 55 | + Returns |
| 56 | + ------- |
| 57 | + jnp.ndarray |
| 58 | + the array with data for _predict |
| 59 | +
|
| 60 | + Raises |
| 61 | + ------ |
| 62 | + ValueError |
| 63 | + If the forecasting horizon is different during predict and fit. |
| 64 | + """ |
| 65 | + if self._min_date is None: |
| 66 | + self._min_date = X.index.min() |
| 67 | + else: |
| 68 | + if self._min_date != X.index.min() and self.raise_errror_if_fh_changes: |
| 69 | + raise ValueError( |
| 70 | + "The X dataframe and forecat horizon" |
| 71 | + "must be start at the same" |
| 72 | + "date as the previous one" |
| 73 | + ) |
| 74 | + return super()._transform(X, fh) |
| 75 | + |
| 76 | + def _sample_params( |
| 77 | + self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray] |
| 78 | + ) -> Dict[str, jnp.ndarray]: |
| 79 | + """ |
| 80 | + Sample the parameters of the effect. |
| 81 | +
|
| 82 | + Parameters |
| 83 | + ---------- |
| 84 | + data : jnp.ndarray |
| 85 | + Data obtained from the transformed method. |
| 86 | + predicted_effects : Dict[str, jnp.ndarray] |
| 87 | + A dictionary containing the predicted effects. |
| 88 | +
|
| 89 | + Returns |
| 90 | + ------- |
| 91 | + Dict[str, jnp.ndarray] |
| 92 | + A dictionary containing the sampled parameters of the effect. |
| 93 | + """ |
| 94 | + return { |
| 95 | + "decay": numpyro.sample("decay", self.decay_prior), |
| 96 | + } |
| 97 | + |
| 98 | + def _predict( |
| 99 | + self, |
| 100 | + data: jnp.ndarray, |
| 101 | + predicted_effects: Dict[str, jnp.ndarray], |
| 102 | + params: Dict[str, jnp.ndarray], |
| 103 | + ) -> jnp.ndarray: |
| 104 | + """ |
| 105 | + Apply and return the geometric adstock effect values. |
| 106 | +
|
| 107 | + Parameters |
| 108 | + ---------- |
| 109 | + data : jnp.ndarray |
| 110 | + Data obtained from the transformed method (shape: T, 1). |
| 111 | + predicted_effects : Dict[str, jnp.ndarray] |
| 112 | + A dictionary containing the predicted effects. |
| 113 | + params : Dict[str, jnp.ndarray] |
| 114 | + A dictionary containing the sampled parameters of the effect. |
| 115 | +
|
| 116 | + Returns |
| 117 | + ------- |
| 118 | + jnp.ndarray |
| 119 | + An array with shape (T, 1) for univariate timeseries. |
| 120 | + """ |
| 121 | + decay = params["decay"] |
| 122 | + |
| 123 | + # Apply geometric adstock using jax.lax.scan for efficiency |
| 124 | + def adstock_step(carry, current): |
| 125 | + prev_adstock = carry |
| 126 | + new_adstock = current + decay * prev_adstock |
| 127 | + return new_adstock, new_adstock |
| 128 | + |
| 129 | + _, adstock = jax.lax.scan( |
| 130 | + adstock_step, init=jnp.array([0], dtype=data.dtype), xs=data |
| 131 | + ) |
| 132 | + return adstock.reshape(-1, 1) |
0 commit comments