Skip to content

Commit fd9b99f

Browse files
[ENH] Trend as an Effect
[ENH] Trend as an Effect
2 parents 591099e + cc948e5 commit fd9b99f

32 files changed

+959
-782
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ examples/personal
44
notebooks/*
55
.vscode/*
66
poetry.lock
7+
x.py
78

89
# Byte-compiled / optimized / DLL files
910
__pycache__/

docs/deprecation.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Deprecation policy

extension_templates/effect.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax.numpy as jnp
66
import pandas as pd
77

8-
from prophetverse.effects.base import BaseEffect, Stage
8+
from prophetverse.effects.base import BaseEffect
99
from prophetverse.utils.frame_to_array import series_to_tensor_or_array
1010

1111

@@ -19,33 +19,42 @@ class MyEffectName(BaseEffect):
1919
# If no columns are found, should
2020
# _predict be skipped?
2121
"skip_predict_if_no_match": True,
22+
# Should only the indexes related to the forecasting horizon be passed to
23+
# _transform?
24+
"filter_indexes_with_forecating_horizon_at_transform": True,
2225
}
2326

2427
def __init__(self, param1: Any, param2: Any):
2528
self.param1 = param1
2629
self.param2 = param2
2730

28-
def _fit(self, X: pd.DataFrame, scale: float = 1.0):
31+
def _fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1.0):
2932
"""Customize the initialization of the effect.
3033
3134
This method is called by the `fit()` method and can be overridden by
3235
subclasses to provide additional initialization logic.
3336
3437
Parameters
3538
----------
39+
y : pd.DataFrame
40+
The timeseries dataframe
41+
3642
X : pd.DataFrame
3743
The DataFrame to initialize the effect.
44+
45+
scale : float, optional
46+
The scale of the timeseries. For multivariate timeseries, this is
47+
a dataframe. For univariate, it is a simple float.
3848
"""
3949
# Do something with X, scale, and other parameters
4050
pass
4151

42-
def _transform(
43-
self, X: pd.DataFrame, stage: Stage = Stage.TRAIN
44-
) -> Dict[str, jnp.ndarray]:
45-
"""Prepare the input data in a dict of jax arrays.
52+
def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Any:
53+
"""Prepare input data to be passed to numpyro model.
4654
47-
This method is called by the `fit()` method and can be overridden
48-
by subclasses to provide additional data preparation logic.
55+
This method receives the Exogenous variables DataFrame and should return a
56+
the data needed for the effect. Those data will be passed to the `predict`
57+
method as `data` argument.
4958
5059
Parameters
5160
----------
@@ -54,43 +63,46 @@ def _transform(
5463
time indexes, if passed during fit, or for the forecasting time indexes, if
5564
passed during predict.
5665
57-
stage : Stage, optional
58-
The stage of the effect, by default Stage.TRAIN. This can be used to
59-
differentiate between training and prediction stages and apply different
60-
transformations accordingly.
66+
fh : pd.Index
67+
The forecasting horizon as a pandas Index.
6168
6269
Returns
6370
-------
64-
Dict[str, jnp.ndarray]
65-
A dictionary containing the data needed for the effect. The keys of the
66-
dictionary should be the names of the arguments of the `apply` method, and
67-
the values should be the corresponding data as jnp.ndarray.
71+
Any
72+
Any object containing the data needed for the effect. The object will be
73+
passed to `predict` method as `data` argument.
6874
"""
6975
# Do something with X
70-
if stage == "train":
71-
array = series_to_tensor_or_array(X)
72-
else:
73-
# something else
74-
pass
75-
return {"data": array}
76-
77-
def _predict(self, trend: jnp.ndarray, **kwargs) -> jnp.ndarray:
78-
"""Apply the effect.
76+
array = series_to_tensor_or_array(X)
77+
return array
7978

80-
This method is called by the `apply()` method and must be overridden by
81-
subclasses to provide the actual effect computation logic.
79+
def predict(
80+
self,
81+
data: Dict,
82+
predicted_effects: Dict[str, jnp.ndarray],
83+
) -> jnp.ndarray:
84+
"""Apply and return the effect values.
8285
8386
Parameters
8487
----------
85-
trend : jnp.ndarray
86-
An array containing the trend values.
88+
data : Any
89+
Data obtained from the transformed method.
8790
88-
kwargs: dict
89-
Additional keyword arguments that may be needed to compute the effect.
91+
predicted_effects : Dict[str, jnp.ndarray], optional
92+
A dictionary containing the predicted effects, by default None.
9093
9194
Returns
9295
-------
9396
jnp.ndarray
94-
The effect values.
97+
An array with shape (T,1) for univariate timeseries, or (N, T, 1) for
98+
multivariate timeseries, where T is the number of timepoints and N is the
99+
number of series.
95100
"""
101+
# Get the trend
102+
# (T, 1) shaped array for univariate timeseries
103+
# (N, T, 1) shaped array for multivariate timeseries, where N is the number of
104+
# series
105+
# trend: jnp.ndarray = predicted_effects["trend"]
106+
# Or user predicted_effects.get("trend") to return None if the trend is
107+
# not found
96108
raise NotImplementedError("Subclasses must implement _predict()")

0 commit comments

Comments
 (0)