5
5
import jax .numpy as jnp
6
6
import pandas as pd
7
7
8
- from prophetverse .effects .base import BaseEffect , Stage
8
+ from prophetverse .effects .base import BaseEffect
9
9
from prophetverse .utils .frame_to_array import series_to_tensor_or_array
10
10
11
11
@@ -19,33 +19,42 @@ class MyEffectName(BaseEffect):
19
19
# If no columns are found, should
20
20
# _predict be skipped?
21
21
"skip_predict_if_no_match" : True ,
22
+ # Should only the indexes related to the forecasting horizon be passed to
23
+ # _transform?
24
+ "filter_indexes_with_forecating_horizon_at_transform" : True ,
22
25
}
23
26
24
27
def __init__ (self , param1 : Any , param2 : Any ):
25
28
self .param1 = param1
26
29
self .param2 = param2
27
30
28
- def _fit (self , X : pd .DataFrame , scale : float = 1.0 ):
31
+ def _fit (self , y : pd . DataFrame , X : pd .DataFrame , scale : float = 1.0 ):
29
32
"""Customize the initialization of the effect.
30
33
31
34
This method is called by the `fit()` method and can be overridden by
32
35
subclasses to provide additional initialization logic.
33
36
34
37
Parameters
35
38
----------
39
+ y : pd.DataFrame
40
+ The timeseries dataframe
41
+
36
42
X : pd.DataFrame
37
43
The DataFrame to initialize the effect.
44
+
45
+ scale : float, optional
46
+ The scale of the timeseries. For multivariate timeseries, this is
47
+ a dataframe. For univariate, it is a simple float.
38
48
"""
39
49
# Do something with X, scale, and other parameters
40
50
pass
41
51
42
- def _transform (
43
- self , X : pd .DataFrame , stage : Stage = Stage .TRAIN
44
- ) -> Dict [str , jnp .ndarray ]:
45
- """Prepare the input data in a dict of jax arrays.
52
+ def _transform (self , X : pd .DataFrame , fh : pd .Index ) -> Any :
53
+ """Prepare input data to be passed to numpyro model.
46
54
47
- This method is called by the `fit()` method and can be overridden
48
- by subclasses to provide additional data preparation logic.
55
+ This method receives the Exogenous variables DataFrame and should return a
56
+ the data needed for the effect. Those data will be passed to the `predict`
57
+ method as `data` argument.
49
58
50
59
Parameters
51
60
----------
@@ -54,43 +63,46 @@ def _transform(
54
63
time indexes, if passed during fit, or for the forecasting time indexes, if
55
64
passed during predict.
56
65
57
- stage : Stage, optional
58
- The stage of the effect, by default Stage.TRAIN. This can be used to
59
- differentiate between training and prediction stages and apply different
60
- transformations accordingly.
66
+ fh : pd.Index
67
+ The forecasting horizon as a pandas Index.
61
68
62
69
Returns
63
70
-------
64
- Dict[str, jnp.ndarray]
65
- A dictionary containing the data needed for the effect. The keys of the
66
- dictionary should be the names of the arguments of the `apply` method, and
67
- the values should be the corresponding data as jnp.ndarray.
71
+ Any
72
+ Any object containing the data needed for the effect. The object will be
73
+ passed to `predict` method as `data` argument.
68
74
"""
69
75
# Do something with X
70
- if stage == "train" :
71
- array = series_to_tensor_or_array (X )
72
- else :
73
- # something else
74
- pass
75
- return {"data" : array }
76
-
77
- def _predict (self , trend : jnp .ndarray , ** kwargs ) -> jnp .ndarray :
78
- """Apply the effect.
76
+ array = series_to_tensor_or_array (X )
77
+ return array
79
78
80
- This method is called by the `apply()` method and must be overridden by
81
- subclasses to provide the actual effect computation logic.
79
+ def predict (
80
+ self ,
81
+ data : Dict ,
82
+ predicted_effects : Dict [str , jnp .ndarray ],
83
+ ) -> jnp .ndarray :
84
+ """Apply and return the effect values.
82
85
83
86
Parameters
84
87
----------
85
- trend : jnp.ndarray
86
- An array containing the trend values .
88
+ data : Any
89
+ Data obtained from the transformed method .
87
90
88
- kwargs: dict
89
- Additional keyword arguments that may be needed to compute the effect .
91
+ predicted_effects : Dict[str, jnp.ndarray], optional
92
+ A dictionary containing the predicted effects, by default None .
90
93
91
94
Returns
92
95
-------
93
96
jnp.ndarray
94
- The effect values.
97
+ An array with shape (T,1) for univariate timeseries, or (N, T, 1) for
98
+ multivariate timeseries, where T is the number of timepoints and N is the
99
+ number of series.
95
100
"""
101
+ # Get the trend
102
+ # (T, 1) shaped array for univariate timeseries
103
+ # (N, T, 1) shaped array for multivariate timeseries, where N is the number of
104
+ # series
105
+ # trend: jnp.ndarray = predicted_effects["trend"]
106
+ # Or user predicted_effects.get("trend") to return None if the trend is
107
+ # not found
96
108
raise NotImplementedError ("Subclasses must implement _predict()" )
0 commit comments