@@ -71,7 +71,8 @@ def prepare_data(self, df: pd.DataFrame, future: bool = False, spark=None) -> Da
71
71
.agg (
72
72
collect_list (self .params .date_col ).alias ('ds' ),
73
73
collect_list (self .params .target ).alias ('y' ),
74
- ))
74
+ )).withColumnRenamed (self .params .group_id , "unique_id" )
75
+
75
76
return df
76
77
77
78
def predict (self ,
@@ -110,37 +111,24 @@ def calculate_metrics(
110
111
pred_df , model_pretrained = self .predict (hist_df , val_df , curr_date , spark )
111
112
keys = pred_df [self .params ["group_id" ]].unique ()
112
113
metrics = []
113
- if self .params ["metric" ] == "smape" :
114
- metric_name = "smape"
115
- elif self .params ["metric" ] == "mape" :
116
- metric_name = "mape"
117
- elif self .params ["metric" ] == "mae" :
118
- metric_name = "mae"
119
- elif self .params ["metric" ] == "mse" :
120
- metric_name = "mse"
121
- elif self .params ["metric" ] == "rmse" :
122
- metric_name = "rmse"
123
- else :
114
+ metric_name = self .params ["metric" ]
115
+ if metric_name not in ("smape" , "mape" , "mae" , "mse" , "rmse" ):
124
116
raise Exception (f"Metric { self .params ['metric' ]} not supported!" )
125
117
for key in keys :
126
118
actual = val_df [val_df [self .params ["group_id" ]] == key ][self .params ["target" ]].to_numpy ()
127
119
forecast = pred_df [pred_df [self .params ["group_id" ]] == key ][self .params ["target" ]].to_numpy ()[0 ]
120
+ # Mapping metric names to their respective classes
121
+ metric_classes = {
122
+ "smape" : MeanAbsolutePercentageError (symmetric = True ),
123
+ "mape" : MeanAbsolutePercentageError (symmetric = False ),
124
+ "mae" : MeanAbsoluteError (),
125
+ "mse" : MeanSquaredError (square_root = False ),
126
+ "rmse" : MeanSquaredError (square_root = True ),
127
+ }
128
128
try :
129
- if metric_name == "smape" :
130
- smape = MeanAbsolutePercentageError (symmetric = True )
131
- metric_value = smape (actual , forecast )
132
- elif metric_name == "mape" :
133
- mape = MeanAbsolutePercentageError (symmetric = False )
134
- metric_value = mape (actual , forecast )
135
- elif metric_name == "mae" :
136
- mae = MeanAbsoluteError ()
137
- metric_value = mae (actual , forecast )
138
- elif metric_name == "mse" :
139
- mse = MeanSquaredError (square_root = False )
140
- metric_value = mse (actual , forecast )
141
- elif metric_name == "rmse" :
142
- rmse = MeanSquaredError (square_root = True )
143
- metric_value = rmse (actual , forecast )
129
+ if metric_name in metric_classes :
130
+ metric_function = metric_classes [metric_name ]
131
+ metric_value = metric_function (actual , forecast )
144
132
metrics .extend (
145
133
[(
146
134
key ,
@@ -240,6 +228,7 @@ def __init__(self, params):
240
228
self .params = params
241
229
self .repo = "amazon/chronos-bolt-small"
242
230
231
+
243
232
class ChronosBoltBase (ChronosForecaster ):
244
233
def __init__ (self , params ):
245
234
super ().__init__ (params )
@@ -268,4 +257,3 @@ def predict(self, context, input_data, params=None):
268
257
prediction_length = self .prediction_length ,
269
258
)
270
259
return forecast .numpy ()
271
-
0 commit comments