Skip to content

Commit 7e32609

Browse files
streamlined some code
1 parent ae46562 commit 7e32609

File tree

7 files changed

+131
-192
lines changed

7 files changed

+131
-192
lines changed

mmf_sa/Forecaster.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -114,12 +114,15 @@ def prepare_data_for_global_model(self, mode: str = None):
114114
"""
115115
src_df = self.resolve_source("train_data")
116116
src_df, removed = DataQualityChecks(src_df, self.conf, self.spark).run()
117+
118+
# This block runs when preparing data for scoring
117119
if (mode == "scoring") \
118120
and (self.conf["scoring_data"]) \
119121
and (self.conf["scoring_data"] != self.conf["train_data"]):
120122
score_df = self.resolve_source("scoring_data")
121123
score_df = score_df.where(~col(self.conf["group_id"]).isin(removed))
122124
src_df = src_df.unionByName(score_df, allowMissingColumns=True)
125+
123126
src_df = src_df.toPandas()
124127
return src_df, removed
125128

@@ -323,9 +326,7 @@ def evaluate_global_model(self, model_conf):
323326
model=final_model,
324327
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}",
325328
input_example=input_example,
326-
run_id=self.run_id,
327329
)
328-
329330
# Next, we train the model only with train_df and run detailed backtesting
330331
model = self.model_registry.get_model(model_name)
331332
model.fit(pd.concat([train_df]))
@@ -336,6 +337,8 @@ def evaluate_global_model(self, model_conf):
336337
model_uri=model_info.model_uri, # This model_uri is from the final model
337338
write=True,
338339
)
340+
mlflow.set_tag("run_id", self.run_id)
341+
mlflow.set_tag("model_name", model.params["name"])
339342

340343
def backtest_global_model(
341344
self,
@@ -423,13 +426,14 @@ def evaluate_foundation_model(self, model_conf):
423426
with mlflow.start_run(experiment_id=self.experiment_id) as run:
424427
model_name = model_conf["name"]
425428
model = self.model_registry.get_model(model_name)
426-
# For now, only support registering chronos, moirai and moment models
429+
hist_df, removed = self.prepare_data_for_global_model("evaluating") # Reuse the same as global
430+
train_df, val_df = self.split_df_train_val(hist_df)
431+
input_example = train_df[train_df[self.conf['group_id']] == train_df[self.conf['group_id']] \
432+
.unique()[0]].sort_values(by=[self.conf['date_col']])
427433
if model_conf["framework"] in ["Chronos", "Moirai", "Moment", "TimesFM"]:
428434
model.register(
429-
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
435+
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}",
430436
)
431-
hist_df, removed = self.prepare_data_for_global_model("evaluating") # Reuse the same as global
432-
train_df, val_df = self.split_df_train_val(hist_df)
433437
model_uri = f"runs:/{run.info.run_id}/model"
434438
metrics = self.backtest_global_model( # Reuse the same as global
435439
model=model,

mmf_sa/models/abstract_model.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@ def backtest(
4747
df: pd.DataFrame,
4848
start: pd.Timestamp,
4949
group_id: Union[str, int] = None,
50-
stride: int = None,
5150
# backtest_retrain: bool = False,
5251
spark=None,
5352
) -> pd.DataFrame:
@@ -58,19 +57,18 @@ def backtest(
5857
df (pd.DataFrame): A pandas DataFrame.
5958
start (pd.Timestamp): A pandas Timestamp object.
6059
group_id (Union[str, int], optional): A string or an integer specifying the group id. Default is None.
61-
stride (int, optional): An integer specifying the stride. Default is None.
6260
spark (SparkSession, optional): A SparkSession object. Default is None.
6361
Returns: res_df (pd.DataFrame): A pandas DataFrame.
6462
"""
65-
if stride is None:
66-
stride = int(self.params.get("stride", 7))
63+
stride = int(self.params["stride"]) # Read in stride
6764
stride_offset = (
6865
pd.offsets.MonthEnd(stride)
6966
if self.freq == "M"
7067
else pd.DateOffset(days=stride)
7168
)
7269
df = df.copy().sort_values(by=[self.params["date_col"]])
73-
end_date = df[self.params["date_col"]].max()
70+
end_date = df[self.params["date_col"]].max() # Last date from the training data
71+
# Offsets the timestamp: e.g. if it's in the middle of the month, makes it the end of the month
7472
curr_date = start + self.one_ts_offset
7573
# print("end_date = ", end_date)
7674

mmf_sa/models/chronosforecast/ChronosPipeline.py

Lines changed: 16 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,8 @@ def prepare_data(self, df: pd.DataFrame, future: bool = False, spark=None) -> Da
7171
.agg(
7272
collect_list(self.params.date_col).alias('ds'),
7373
collect_list(self.params.target).alias('y'),
74-
))
74+
)).withColumnRenamed(self.params.group_id, "unique_id")
75+
7576
return df
7677

7778
def predict(self,
@@ -110,37 +111,24 @@ def calculate_metrics(
110111
pred_df, model_pretrained = self.predict(hist_df, val_df, curr_date, spark)
111112
keys = pred_df[self.params["group_id"]].unique()
112113
metrics = []
113-
if self.params["metric"] == "smape":
114-
metric_name = "smape"
115-
elif self.params["metric"] == "mape":
116-
metric_name = "mape"
117-
elif self.params["metric"] == "mae":
118-
metric_name = "mae"
119-
elif self.params["metric"] == "mse":
120-
metric_name = "mse"
121-
elif self.params["metric"] == "rmse":
122-
metric_name = "rmse"
123-
else:
114+
metric_name = self.params["metric"]
115+
if metric_name not in ("smape", "mape", "mae", "mse", "rmse"):
124116
raise Exception(f"Metric {self.params['metric']} not supported!")
125117
for key in keys:
126118
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()
127119
forecast = pred_df[pred_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()[0]
120+
# Mapping metric names to their respective classes
121+
metric_classes = {
122+
"smape": MeanAbsolutePercentageError(symmetric=True),
123+
"mape": MeanAbsolutePercentageError(symmetric=False),
124+
"mae": MeanAbsoluteError(),
125+
"mse": MeanSquaredError(square_root=False),
126+
"rmse": MeanSquaredError(square_root=True),
127+
}
128128
try:
129-
if metric_name == "smape":
130-
smape = MeanAbsolutePercentageError(symmetric=True)
131-
metric_value = smape(actual, forecast)
132-
elif metric_name == "mape":
133-
mape = MeanAbsolutePercentageError(symmetric=False)
134-
metric_value = mape(actual, forecast)
135-
elif metric_name == "mae":
136-
mae = MeanAbsoluteError()
137-
metric_value = mae(actual, forecast)
138-
elif metric_name == "mse":
139-
mse = MeanSquaredError(square_root=False)
140-
metric_value = mse(actual, forecast)
141-
elif metric_name == "rmse":
142-
rmse = MeanSquaredError(square_root=True)
143-
metric_value = rmse(actual, forecast)
129+
if metric_name in metric_classes:
130+
metric_function = metric_classes[metric_name]
131+
metric_value = metric_function(actual, forecast)
144132
metrics.extend(
145133
[(
146134
key,
@@ -240,6 +228,7 @@ def __init__(self, params):
240228
self.params = params
241229
self.repo = "amazon/chronos-bolt-small"
242230

231+
243232
class ChronosBoltBase(ChronosForecaster):
244233
def __init__(self, params):
245234
super().__init__(params)
@@ -268,4 +257,3 @@ def predict(self, context, input_data, params=None):
268257
prediction_length=self.prediction_length,
269258
)
270259
return forecast.numpy()
271-

mmf_sa/models/moiraiforecast/MoiraiPipeline.py

Lines changed: 14 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def prepare_data(self, df: pd.DataFrame, future: bool = False, spark=None) -> Da
6969
.agg(
7070
collect_list(self.params.date_col).alias('ds'),
7171
collect_list(self.params.target).alias('y'),
72-
))
72+
)).withColumnRenamed(self.params.group_id, "unique_id")
7373
return df
7474

7575
def predict(self,
@@ -110,37 +110,24 @@ def calculate_metrics(
110110
pred_df, model_pretrained = self.predict(hist_df, val_df, curr_date, spark)
111111
keys = pred_df[self.params["group_id"]].unique()
112112
metrics = []
113-
if self.params["metric"] == "smape":
114-
metric_name = "smape"
115-
elif self.params["metric"] == "mape":
116-
metric_name = "mape"
117-
elif self.params["metric"] == "mae":
118-
metric_name = "mae"
119-
elif self.params["metric"] == "mse":
120-
metric_name = "mse"
121-
elif self.params["metric"] == "rmse":
122-
metric_name = "rmse"
123-
else:
113+
metric_name = self.params["metric"]
114+
if metric_name not in ("smape", "mape", "mae", "mse", "rmse"):
124115
raise Exception(f"Metric {self.params['metric']} not supported!")
125116
for key in keys:
126117
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()
127118
forecast = pred_df[pred_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()[0]
119+
# Mapping metric names to their respective classes
120+
metric_classes = {
121+
"smape": MeanAbsolutePercentageError(symmetric=True),
122+
"mape": MeanAbsolutePercentageError(symmetric=False),
123+
"mae": MeanAbsoluteError(),
124+
"mse": MeanSquaredError(square_root=False),
125+
"rmse": MeanSquaredError(square_root=True),
126+
}
128127
try:
129-
if metric_name == "smape":
130-
smape = MeanAbsolutePercentageError(symmetric=True)
131-
metric_value = smape(actual, forecast)
132-
elif metric_name == "mape":
133-
mape = MeanAbsolutePercentageError(symmetric=False)
134-
metric_value = mape(actual, forecast)
135-
elif metric_name == "mae":
136-
mae = MeanAbsoluteError()
137-
metric_value = mae(actual, forecast)
138-
elif metric_name == "mse":
139-
mse = MeanSquaredError(square_root=False)
140-
metric_value = mse(actual, forecast)
141-
elif metric_name == "rmse":
142-
rmse = MeanSquaredError(square_root=True)
143-
metric_value = rmse(actual, forecast)
128+
if metric_name in metric_classes:
129+
metric_function = metric_classes[metric_name]
130+
metric_value = metric_function(actual, forecast)
144131
metrics.extend(
145132
[(
146133
key,

mmf_sa/models/momentforecast/MomentPipeline.py

Lines changed: 13 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -115,37 +115,24 @@ def calculate_metrics(
115115
pred_df, model_pretrained = self.predict(hist_df, val_df, curr_date, spark)
116116
keys = pred_df[self.params["group_id"]].unique()
117117
metrics = []
118-
if self.params["metric"] == "smape":
119-
metric_name = "smape"
120-
elif self.params["metric"] == "mape":
121-
metric_name = "mape"
122-
elif self.params["metric"] == "mae":
123-
metric_name = "mae"
124-
elif self.params["metric"] == "mse":
125-
metric_name = "mse"
126-
elif self.params["metric"] == "rmse":
127-
metric_name = "rmse"
128-
else:
118+
metric_name = self.params["metric"]
119+
if metric_name not in ("smape", "mape", "mae", "mse", "rmse"):
129120
raise Exception(f"Metric {self.params['metric']} not supported!")
130121
for key in keys:
131122
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()
132123
forecast = pred_df[pred_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()[0]
124+
# Mapping metric names to their respective classes
125+
metric_classes = {
126+
"smape": MeanAbsolutePercentageError(symmetric=True),
127+
"mape": MeanAbsolutePercentageError(symmetric=False),
128+
"mae": MeanAbsoluteError(),
129+
"mse": MeanSquaredError(square_root=False),
130+
"rmse": MeanSquaredError(square_root=True),
131+
}
133132
try:
134-
if metric_name == "smape":
135-
smape = MeanAbsolutePercentageError(symmetric=True)
136-
metric_value = smape(actual, forecast)
137-
elif metric_name == "mape":
138-
mape = MeanAbsolutePercentageError(symmetric=False)
139-
metric_value = mape(actual, forecast)
140-
elif metric_name == "mae":
141-
mae = MeanAbsoluteError()
142-
metric_value = mae(actual, forecast)
143-
elif metric_name == "mse":
144-
mse = MeanSquaredError(square_root=False)
145-
metric_value = mse(actual, forecast)
146-
elif metric_name == "rmse":
147-
rmse = MeanSquaredError(square_root=True)
148-
metric_value = rmse(actual, forecast)
133+
if metric_name in metric_classes:
134+
metric_function = metric_classes[metric_name]
135+
metric_value = metric_function(actual, forecast)
149136
metrics.extend(
150137
[(
151138
key,

0 commit comments

Comments
 (0)