@@ -24,6 +24,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
24
24
self .local_explanation = {}
25
25
self .formatted_global_explanation = None
26
26
self .formatted_local_explanation = None
27
+ self .date_col = config .spec .datetime_column .name
27
28
28
29
def set_kwargs (self ):
29
30
"""
@@ -77,8 +78,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
77
78
alpha = model_kwargs ["lower_quantile" ],
78
79
),
79
80
},
80
- freq = pd .infer_freq (data_train ["Date" ].drop_duplicates ())
81
- or pd .infer_freq (data_train ["Date" ].drop_duplicates ()[- 5 :]),
81
+ freq = pd .infer_freq (data_train [self . date_col ].drop_duplicates ())
82
+ or pd .infer_freq (data_train [self . date_col ].drop_duplicates ()[- 5 :]),
82
83
target_transforms = [Differences ([12 ])],
83
84
lags = model_kwargs .get (
84
85
"lags" ,
@@ -108,7 +109,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
108
109
data_train [self .model_columns ],
109
110
static_features = model_kwargs .get ("static_features" , []),
110
111
id_col = ForecastOutputColumns .SERIES ,
111
- time_col = self .spec . datetime_column . name ,
112
+ time_col = self .date_col ,
112
113
target_col = self .spec .target_column ,
113
114
fitted = True ,
114
115
max_horizon = None if num_models is False else self .spec .horizon ,
@@ -173,7 +174,7 @@ def _build_model(self) -> pd.DataFrame:
173
174
confidence_interval_width = self .spec .confidence_interval_width ,
174
175
horizon = self .spec .horizon ,
175
176
target_column = self .original_target_column ,
176
- dt_column = self .spec . datetime_column . name ,
177
+ dt_column = self .date_col ,
177
178
)
178
179
self ._train_model (data_train , data_test , model_kwargs )
179
180
return self .forecast_output .get_forecast_long ()
0 commit comments