Skip to content

Commit 39c9061

Browse files
Merge pull request #154 from felipeangelimvieira/feature/add_chaineffect
[ENH] Add chain effect
2 parents ab7b1f2 + bfda9fc commit 39c9061

File tree

3 files changed

+271
-0
lines changed

3 files changed

+271
-0
lines changed

src/prophetverse/effects/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Effects that define relationships between variables and the target."""
22

3+
from .adstock import GeometricAdstockEffect
34
from .base import BaseEffect
5+
from .chain import ChainedEffects
46
from .exact_likelihood import ExactLikelihood
57
from .fourier import LinearFourierSeasonality
68
from .hill import HillEffect
@@ -16,4 +18,6 @@
1618
"ExactLikelihood",
1719
"LiftExperimentLikelihood",
1820
"LinearFourierSeasonality",
21+
"GeometricAdstockEffect",
22+
"ChainedEffects",
1923
]

src/prophetverse/effects/chain.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
"""Definition of Chained Effects class."""
2+
3+
from typing import Any, Dict, List
4+
5+
import jax.numpy as jnp
6+
from numpyro import handlers
7+
from skbase.base import BaseMetaEstimatorMixin
8+
9+
from prophetverse.effects.base import BaseEffect
10+
11+
__all__ = ["ChainedEffects"]
12+
13+
14+
class ChainedEffects(BaseMetaEstimatorMixin, BaseEffect):
15+
"""
16+
Chains multiple effects sequentially, applying them one after the other.
17+
18+
Parameters
19+
----------
20+
steps : List[BaseEffect]
21+
A list of effects to be applied sequentially.
22+
"""
23+
24+
_tags = {
25+
"supports_multivariate": True,
26+
"skip_predict_if_no_match": True,
27+
"filter_indexes_with_forecating_horizon_at_transform": True,
28+
}
29+
30+
def __init__(self, steps: List[BaseEffect]):
31+
self.steps = steps
32+
super().__init__()
33+
34+
def _fit(self, y: Any, X: Any, scale: float = 1.0):
35+
"""
36+
Fit all chained effects sequentially.
37+
38+
Parameters
39+
----------
40+
y : Any
41+
Target data (e.g., time series values).
42+
X : Any
43+
Exogenous variables.
44+
scale : float, optional
45+
Scale of the timeseries.
46+
"""
47+
for effect in self.steps:
48+
effect.fit(y, X, scale)
49+
50+
def _transform(self, X: Any, fh: Any) -> Any:
51+
"""
52+
Transform input data sequentially through all chained effects.
53+
54+
Parameters
55+
----------
56+
X : Any
57+
Input data (e.g., exogenous variables).
58+
fh : Any
59+
Forecasting horizon.
60+
61+
Returns
62+
-------
63+
Any
64+
Transformed data after applying all effects.
65+
"""
66+
output = X
67+
output = self.steps[0].transform(output, fh)
68+
return output
69+
70+
def _sample_params(
71+
self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray]
72+
) -> Dict[str, jnp.ndarray]:
73+
"""
74+
Sample parameters for all chained effects.
75+
76+
Parameters
77+
----------
78+
data : jnp.ndarray
79+
Data obtained from the transformed method.
80+
predicted_effects : Dict[str, jnp.ndarray]
81+
A dictionary containing the predicted effects.
82+
83+
Returns
84+
-------
85+
Dict[str, jnp.ndarray]
86+
A dictionary containing the sampled parameters for all effects.
87+
"""
88+
params = {}
89+
for idx, effect in enumerate(self.steps):
90+
with handlers.scope(prefix=f"{idx}"):
91+
effect_params = effect.sample_params(data, predicted_effects)
92+
params[f"effect_{idx}"] = effect_params
93+
return params
94+
95+
def _predict(
96+
self,
97+
data: jnp.ndarray,
98+
predicted_effects: Dict[str, jnp.ndarray],
99+
params: Dict[str, Dict[str, jnp.ndarray]],
100+
) -> jnp.ndarray:
101+
"""
102+
Apply all chained effects sequentially.
103+
104+
Parameters
105+
----------
106+
data : jnp.ndarray
107+
Data obtained from the transformed method (shape: T, 1).
108+
predicted_effects : Dict[str, jnp.ndarray]
109+
A dictionary containing the predicted effects.
110+
params : Dict[str, Dict[str, jnp.ndarray]]
111+
A dictionary containing the sampled parameters for each effect.
112+
113+
Returns
114+
-------
115+
jnp.ndarray
116+
The transformed data after applying all effects.
117+
"""
118+
output = data
119+
for idx, effect in enumerate(self.steps):
120+
effect_params = params[f"effect_{idx}"]
121+
output = effect._predict(output, predicted_effects, effect_params)
122+
return output
123+
124+
def _coerce_to_named_object_tuples(self, objs, clone=False, make_unique=True):
125+
"""Coerce sequence of objects or named objects to list of (str, obj) tuples.
126+
127+
Input that is sequence of objects, list of (str, obj) tuples or
128+
dict[str, object] will be coerced to list of (str, obj) tuples on return.
129+
130+
Parameters
131+
----------
132+
objs : list of objects, list of (str, object tuples) or dict[str, object]
133+
The input should be coerced to list of (str, object) tuples. Should
134+
be a sequence of objects, or follow named object API.
135+
clone : bool, default=False.
136+
Whether objects in the returned list of (str, object) tuples are
137+
cloned (True) or references (False).
138+
make_unique : bool, default=True
139+
Whether the str names in the returned list of (str, object) tuples
140+
should be coerced to unique str values (if str names in input
141+
are already unique they will not be changed).
142+
143+
Returns
144+
-------
145+
list[tuple[str, BaseObject]]
146+
List of tuples following named object API.
147+
148+
- If `objs` was already a list of (str, object) tuples then either the
149+
same named objects (as with other cases cloned versions are
150+
returned if ``clone=True``).
151+
- If `objs` was a dict[str, object] then the named objects are unpacked
152+
into a list of (str, object) tuples.
153+
- If `objs` was a list of objects then string names were generated based
154+
on the object's class names (with coercion to unique strings if
155+
necessary).
156+
"""
157+
objs = [(f"effect_{idx}", obj) for idx, obj in enumerate(objs)]
158+
return super()._coerce_to_named_object_tuples(objs, clone, make_unique)

tests/effects/test_chain.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Pytest for Chained Effects class."""
2+
3+
import jax.numpy as jnp
4+
import numpyro
5+
import pandas as pd
6+
import pytest
7+
from numpyro import handlers
8+
9+
from prophetverse.effects.base import BaseEffect
10+
from prophetverse.effects.chain import ChainedEffects
11+
12+
13+
class MockEffect(BaseEffect):
14+
def __init__(self, value):
15+
self.value = value
16+
super().__init__()
17+
18+
self._transform_called = False
19+
20+
def _transform(self, X, fh):
21+
self._transform_called = True
22+
return super()._transform(X, fh)
23+
24+
def _sample_params(self, data, predicted_effects):
25+
return {
26+
"param": numpyro.sample("param", numpyro.distributions.Delta(self.value))
27+
}
28+
29+
def _predict(self, data, predicted_effects, params):
30+
return data * params["param"]
31+
32+
33+
@pytest.fixture
34+
def index():
35+
return pd.date_range("2021-01-01", periods=6)
36+
37+
38+
@pytest.fixture
39+
def y(index):
40+
return pd.DataFrame(index=index, data=[1] * len(index))
41+
42+
43+
@pytest.fixture
44+
def X(index):
45+
return pd.DataFrame(
46+
data={"exog": [10, 20, 30, 40, 50, 60]},
47+
index=index,
48+
)
49+
50+
51+
def test_chained_effects_fit(X, y):
52+
"""Test the fit method of ChainedEffects."""
53+
effects = [MockEffect(2), MockEffect(3)]
54+
chained = ChainedEffects(steps=effects)
55+
56+
scale = 1
57+
chained.fit(y=y, X=X, scale=scale)
58+
# Ensure no exceptions occur in fit
59+
60+
61+
def test_chained_effects_transform(X, y):
62+
"""Test the transform method of ChainedEffects."""
63+
effects = [MockEffect(2), MockEffect(3)]
64+
chained = ChainedEffects(steps=effects)
65+
transformed = chained.transform(X, fh=X.index)
66+
expected = MockEffect(2).transform(X, fh=X.index)
67+
assert jnp.allclose(transformed, expected), "Chained transform output mismatch."
68+
69+
70+
def test_chained_effects_sample_params(X, y):
71+
"""Test the sample_params method of ChainedEffects."""
72+
effects = [MockEffect(2), MockEffect(3)]
73+
chained = ChainedEffects(steps=effects)
74+
chained.fit(y=y, X=X, scale=1)
75+
data = chained.transform(X, fh=X.index)
76+
77+
with handlers.trace() as trace:
78+
params = chained.sample_params(data, {})
79+
80+
assert "effect_0" in params, "Missing effect_0 params."
81+
assert "effect_1" in params, "Missing effect_1 params."
82+
assert params["effect_0"]["param"] == 2, "Incorrect effect_0 param."
83+
assert params["effect_1"]["param"] == 3, "Incorrect effect_1 param."
84+
85+
assert "0/param" in trace, "Missing effect_0 trace."
86+
assert "1/param" in trace, "Missing effect_1 trace."
87+
88+
89+
def test_chained_effects_predict(X, y):
90+
"""Test the predict method of ChainedEffects."""
91+
effects = [MockEffect(2), MockEffect(3)]
92+
chained = ChainedEffects(steps=effects)
93+
chained.fit(y=y, X=X, scale=1)
94+
data = chained.transform(X, fh=X.index)
95+
predicted_effects = {}
96+
97+
predicted = chained.predict(data, predicted_effects)
98+
expected = data * 2 * 3
99+
assert jnp.allclose(predicted, expected), "Chained predict output mismatch."
100+
101+
102+
def test_get_params():
103+
effects = [MockEffect(2), MockEffect(3)]
104+
chained = ChainedEffects(steps=effects)
105+
106+
params = chained.get_params()
107+
108+
assert params["effect_0__value"] == 2, "Incorrect effect_0 param."
109+
assert params["effect_1__value"] == 3, "Incorrect effect_1 param."

0 commit comments

Comments
 (0)