Skip to content

Commit ab7b1f2

Browse files
[ENH] Add geometric adstock effect
[ENH] Add geometric adstock effect
2 parents 9763d46 + b08152f commit ab7b1f2

File tree

2 files changed

+202
-0
lines changed

2 files changed

+202
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Definition of Geometric Adstock Effect class."""
2+
3+
from typing import Dict
4+
5+
import jax
6+
import jax.numpy as jnp
7+
import numpyro
8+
from numpyro import distributions as dist
9+
10+
from prophetverse.effects.base import BaseEffect
11+
12+
__all__ = ["GeometricAdstockEffect"]
13+
14+
15+
class GeometricAdstockEffect(BaseEffect):
16+
"""Represents a Geometric Adstock effect in a time series model.
17+
18+
Parameters
19+
----------
20+
decay_prior : Distribution, optional
21+
Prior distribution for the decay parameter (controls the rate of decay).
22+
rase_error_if_fh_changes : bool, optional
23+
Whether to raise an error if the forecasting horizon changes during predict
24+
"""
25+
26+
_tags = {
27+
"supports_multivariate": False,
28+
"skip_predict_if_no_match": True,
29+
"filter_indexes_with_forecating_horizon_at_transform": True,
30+
}
31+
32+
def __init__(
33+
self,
34+
decay_prior: dist.Distribution = None,
35+
raise_error_if_fh_changes: bool = True,
36+
):
37+
self.decay_prior = decay_prior or dist.Beta(
38+
2, 2
39+
) # Default Beta distribution for decay rate.
40+
self.raise_errror_if_fh_changes = raise_error_if_fh_changes
41+
super().__init__()
42+
43+
self._min_date = None
44+
45+
def _transform(self, X, fh):
46+
"""Transform the dataframe and horizon to array.
47+
48+
Parameters
49+
----------
50+
X : pd.DataFrame
51+
dataframe with exogenous variables
52+
fh : pd.Index
53+
Forecast horizon
54+
55+
Returns
56+
-------
57+
jnp.ndarray
58+
the array with data for _predict
59+
60+
Raises
61+
------
62+
ValueError
63+
If the forecasting horizon is different during predict and fit.
64+
"""
65+
if self._min_date is None:
66+
self._min_date = X.index.min()
67+
else:
68+
if self._min_date != X.index.min() and self.raise_errror_if_fh_changes:
69+
raise ValueError(
70+
"The X dataframe and forecat horizon"
71+
"must be start at the same"
72+
"date as the previous one"
73+
)
74+
return super()._transform(X, fh)
75+
76+
def _sample_params(
77+
self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray]
78+
) -> Dict[str, jnp.ndarray]:
79+
"""
80+
Sample the parameters of the effect.
81+
82+
Parameters
83+
----------
84+
data : jnp.ndarray
85+
Data obtained from the transformed method.
86+
predicted_effects : Dict[str, jnp.ndarray]
87+
A dictionary containing the predicted effects.
88+
89+
Returns
90+
-------
91+
Dict[str, jnp.ndarray]
92+
A dictionary containing the sampled parameters of the effect.
93+
"""
94+
return {
95+
"decay": numpyro.sample("decay", self.decay_prior),
96+
}
97+
98+
def _predict(
99+
self,
100+
data: jnp.ndarray,
101+
predicted_effects: Dict[str, jnp.ndarray],
102+
params: Dict[str, jnp.ndarray],
103+
) -> jnp.ndarray:
104+
"""
105+
Apply and return the geometric adstock effect values.
106+
107+
Parameters
108+
----------
109+
data : jnp.ndarray
110+
Data obtained from the transformed method (shape: T, 1).
111+
predicted_effects : Dict[str, jnp.ndarray]
112+
A dictionary containing the predicted effects.
113+
params : Dict[str, jnp.ndarray]
114+
A dictionary containing the sampled parameters of the effect.
115+
116+
Returns
117+
-------
118+
jnp.ndarray
119+
An array with shape (T, 1) for univariate timeseries.
120+
"""
121+
decay = params["decay"]
122+
123+
# Apply geometric adstock using jax.lax.scan for efficiency
124+
def adstock_step(carry, current):
125+
prev_adstock = carry
126+
new_adstock = current + decay * prev_adstock
127+
return new_adstock, new_adstock
128+
129+
_, adstock = jax.lax.scan(
130+
adstock_step, init=jnp.array([0], dtype=data.dtype), xs=data
131+
)
132+
return adstock.reshape(-1, 1)

tests/effects/test_adstock.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Pytest for Geometric Adstock Effect class."""
2+
3+
import jax.numpy as jnp
4+
import pandas as pd
5+
import pytest
6+
from numpyro import handlers
7+
from numpyro.distributions import Beta
8+
9+
from prophetverse.effects.adstock import GeometricAdstockEffect
10+
11+
12+
def test_geometric_adstock_sampling():
13+
"""Test parameter sampling using numpyro.handlers.trace."""
14+
effect = GeometricAdstockEffect(decay_prior=Beta(2, 2))
15+
data = jnp.ones((10, 1)) # Dummy data
16+
predicted_effects = {}
17+
18+
with handlers.trace() as trace, handlers.seed(rng_seed=0):
19+
effect._sample_params(data, predicted_effects)
20+
21+
# Verify trace contains decay site
22+
assert "decay" in trace, "Decay parameter not found in trace."
23+
24+
# Verify decay is sampled from the correct prior
25+
assert trace["decay"]["type"] == "sample", "Decay parameter not sampled."
26+
assert isinstance(
27+
trace["decay"]["fn"], Beta
28+
), "Decay parameter not sampled from Beta distribution."
29+
30+
31+
def test_geometric_adstock_predict():
32+
"""Test the predict method for correctness with predefined parameters."""
33+
effect = GeometricAdstockEffect()
34+
35+
# Define mock data and parameters
36+
data = jnp.array([[10.0], [20.0], [30.0]]) # Example input data (T, 1)
37+
params = {"decay": jnp.array(0.5)}
38+
predicted_effects = {}
39+
40+
# Call _predict
41+
result = effect._predict(data, predicted_effects, params)
42+
43+
# Expected adstock output
44+
expected = jnp.array(
45+
[
46+
[10.0],
47+
[20.0 + 0.5 * 10.0],
48+
[30.0 + 0.5 * (20.0 + 0.5 * 10.0)],
49+
]
50+
)
51+
52+
# Verify output shape
53+
assert result.shape == data.shape, "Output shape mismatch."
54+
55+
# Verify output values
56+
assert jnp.allclose(result, expected), "Adstock computation incorrect."
57+
58+
59+
def test_error_when_different_fh():
60+
effect = GeometricAdstockEffect()
61+
X = pd.DataFrame(
62+
data={"exog": [10.0, 20.0, 30.0, 30.0, 40.0, 50.0]},
63+
index=pd.date_range("2021-01-01", periods=6),
64+
)
65+
fh = X.index
66+
effect.transform(X=X, fh=fh)
67+
68+
effect.transform(X=X.iloc[:1], fh=fh[:1])
69+
with pytest.raises(ValueError):
70+
effect.transform(X=X, fh=fh[1:])

0 commit comments

Comments
 (0)