Skip to content

Commit ae46562

Browse files
Merge pull request #82 from databricks-industry-solutions/fix-neuralforecast-covariates
bug fixed for neuralforecast models
2 parents dfdcee8 + 39eacb2 commit ae46562

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

mmf_sa/Forecaster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def backtest_global_model(
378378
StructField("model_pickle", BinaryType()),
379379
]
380380
)
381+
# Covert to Python-native types before converting to pyspark dataframe
382+
res_pdf['forecast'] = res_pdf['forecast'].apply(lambda x: [float(i) for i in x])
383+
res_pdf['actual'] = res_pdf['actual'].apply(lambda x: [float(i) for i in x])
381384
res_sdf = self.spark.createDataFrame(res_pdf, schema)
382385
# Write evaluation results to a delta table
383386
if write:

mmf_sa/models/neuralforecast/NeuralForecastPipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ def calculate_metrics(
217217
else:
218218
raise Exception(f"Metric {self.params['metric']} not supported!")
219219
for key in keys:
220-
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]]
220+
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]].reset_index(drop=True)
221221
forecast = pred_df[pred_df[self.params["group_id"]] == key][self.params["target"]].\
222-
iloc[-self.params["prediction_length"]:]
222+
iloc[-self.params["prediction_length"]:].reset_index(drop=True)
223223
try:
224224
if metric_name == "smape":
225225
smape = MeanAbsolutePercentageError(symmetric=True)

0 commit comments

Comments
 (0)