Skip to content

Commit 2a625ed

Browse files
authored
Feature/forecast auto select (#825)
2 parents 5fe8c49 + c2d1416 commit 2a625ed

File tree

13 files changed

+264
-110
lines changed

13 files changed

+264
-110
lines changed

ads/opctl/operator/lowcode/anomaly/model/base_model.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,29 @@
44
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import fsspec
8+
import numpy as np
79
import os
10+
import pandas as pd
811
import tempfile
912
import time
1013
from abc import ABC, abstractmethod
11-
from typing import Tuple
12-
13-
import fsspec
14-
import pandas as pd
15-
import numpy as np
1614
from sklearn import linear_model
15+
from typing import Tuple
1716

17+
from ads.common.object_storage_details import ObjectStorageDetails
1818
from ads.opctl import logger
19-
20-
from ..operator_config import AnomalyOperatorConfig, AnomalyOperatorSpec
21-
from .anomaly_dataset import AnomalyDatasets, AnomalyOutput, TestData
2219
from ads.opctl.operator.lowcode.anomaly.const import OutputColumns, SupportedMetrics
23-
from ..const import SupportedModels
20+
from ads.opctl.operator.lowcode.anomaly.utils import _build_metrics_df, default_signer
2421
from ads.opctl.operator.lowcode.common.utils import (
2522
human_time_friendly,
2623
enable_print,
2724
disable_print,
2825
write_data,
29-
merge_category_columns,
30-
find_output_dirname,
3126
)
32-
from ads.opctl.operator.lowcode.anomaly.utils import _build_metrics_df, default_signer
33-
from ads.common.object_storage_details import ObjectStorageDetails
27+
from .anomaly_dataset import AnomalyDatasets, AnomalyOutput, TestData
28+
from ..const import SupportedModels
29+
from ..operator_config import AnomalyOperatorConfig, AnomalyOperatorSpec
3430

3531

3632
class AnomalyOperatorBaseModel(ABC):
@@ -246,7 +242,7 @@ def _save_report(
246242
"""Saves resulting reports to the given folder."""
247243
import report_creator as rc
248244

249-
unique_output_dir = find_output_dirname(self.spec.output_directory)
245+
unique_output_dir = self.spec.output_directory.url
250246

251247
if ObjectStorageDetails.is_oci_path(unique_output_dir):
252248
storage_options = default_signer()

ads/opctl/operator/lowcode/anomaly/operator_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
InputData,
1717
)
1818
from .const import SupportedModels
19+
from ads.opctl.operator.lowcode.common.utils import find_output_dirname
1920

2021

2122
@dataclass(repr=True)
@@ -79,6 +80,7 @@ class AnomalyOperatorSpec(DataClassSerializable):
7980

8081
def __post_init__(self):
8182
"""Adjusts the specification details."""
83+
self.output_directory = self.output_directory or OutputDirectory(url=find_output_dirname(self.output_directory))
8284
self.report_file_name = self.report_file_name or "report.html"
8385
self.report_theme = self.report_theme or "light"
8486
self.inliers_filename = self.inliers_filename or "inliers.csv"

ads/opctl/operator/lowcode/common/transformations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def _set_series_id_column(self, df):
9797
for value in merged_values:
9898
self._target_category_columns_map[value] = df[df[DataColumns.Series] == value][self.target_category_columns].drop_duplicates().iloc[0].to_dict()
9999

100-
df = df.drop(self.target_category_columns, axis=1)
100+
if self.target_category_columns != [DataColumns.Series]:
101+
df = df.drop(self.target_category_columns, axis=1)
101102
return df
102103

103104
def _format_datetime_col(self, df):

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def human_time_friendly(seconds):
215215

216216

217217
def find_output_dirname(output_dir: OutputDirectory):
218-
if output_dir:
218+
if output_dir and output_dir.url:
219219
return output_dir.url
220220
output_dir = "results"
221221

ads/opctl/operator/lowcode/forecast/__main__.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,9 @@ def operate(operator_config: ForecastOperatorConfig) -> None:
2424
from .model.factory import ForecastOperatorModelFactory
2525

2626
datasets = ForecastDatasets(operator_config)
27-
try:
28-
ForecastOperatorModelFactory.get_model(
29-
operator_config, datasets
30-
).generate_report()
31-
except Exception as e:
32-
if operator_config.spec.model == "auto":
33-
logger.debug(
34-
f"Failed to forecast with error {e.args}. Trying again with model `prophet`."
35-
)
36-
operator_config.spec.model = "prophet"
37-
operator_config.spec.model_kwargs = dict()
38-
datasets = ForecastDatasets(operator_config)
39-
ForecastOperatorModelFactory.get_model(
40-
operator_config, datasets
41-
).generate_report()
42-
else:
43-
raise
44-
27+
ForecastOperatorModelFactory.get_model(
28+
operator_config, datasets
29+
).generate_report()
4530

4631
def verify(spec: Dict, **kwargs: Dict) -> bool:
4732
"""Verifies the forecasting operator config."""

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
8787
SUMMARY_METRICS_HORIZON_LIMIT = 10
8888
PROPHET_INTERNAL_DATE_COL = "ds"
8989
RENDER_LIMIT = 5000
90+
AUTO_SELECT = "auto-select"

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,19 @@
44
# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7-
import json
7+
import fsspec
8+
import numpy as np
89
import os
10+
import pandas as pd
911
import tempfile
1012
import time
13+
import traceback
1114
from abc import ABC, abstractmethod
1215
from typing import Tuple
13-
import traceback
14-
15-
import fsspec
16-
import numpy as np
17-
import pandas as pd
1816

19-
from ads.opctl.operator.lowcode.forecast.utils import (
20-
default_signer,
21-
evaluate_train_metrics,
22-
get_forecast_plots,
23-
_build_metrics_df,
24-
_build_metrics_per_horizon,
25-
load_pkl,
26-
write_pkl,
27-
_label_encode_dataframe,
28-
)
17+
from ads.common.decorator.runtime_dependency import runtime_dependency
2918
from ads.common.object_storage_details import ObjectStorageDetails
3019
from ads.opctl import logger
31-
3220
from ads.opctl.operator.lowcode.common.utils import (
3321
human_time_friendly,
3422
enable_print,
@@ -37,18 +25,28 @@
3725
merged_category_column_name,
3826
datetime_to_seconds,
3927
seconds_to_datetime,
40-
find_output_dirname,
4128
)
29+
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
30+
from ads.opctl.operator.lowcode.forecast.utils import (
31+
default_signer,
32+
evaluate_train_metrics,
33+
get_forecast_plots,
34+
get_auto_select_plot,
35+
_build_metrics_df,
36+
_build_metrics_per_horizon,
37+
load_pkl,
38+
write_pkl,
39+
_label_encode_dataframe,
40+
)
41+
from .forecast_datasets import ForecastDatasets
4242
from ..const import (
4343
SUMMARY_METRICS_HORIZON_LIMIT,
4444
SupportedMetrics,
4545
SupportedModels,
4646
SpeedAccuracyMode,
47+
AUTO_SELECT
4748
)
4849
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
49-
from ads.common.decorator.runtime_dependency import runtime_dependency
50-
from .forecast_datasets import ForecastDatasets, ForecastOutput
51-
from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData
5250

5351

5452
class ForecastOperatorBaseModel(ABC):
@@ -250,6 +248,23 @@ def generate_report(self):
250248
sec9 = rc.DataTable(self.eval_metrics, index=True)
251249
train_metrics_sections = [sec9_text, sec9]
252250

251+
backtest_sections = []
252+
if self.spec.model == AUTO_SELECT:
253+
output_dir = self.spec.output_directory.url
254+
backtest_report_name = "backtest_stats.csv"
255+
backtest_stats = pd.read_csv(f"{output_dir}/{backtest_report_name}")
256+
average_dict = backtest_stats.mean().to_dict()
257+
del average_dict['backtest']
258+
best_model = min(average_dict, key=average_dict.get)
259+
backtest_text = rc.Heading("Back Testing Metrics", level=2)
260+
summary_text = rc.Text(
261+
f"Overall, the average scores for the models are {average_dict}, with {best_model}"
262+
f" being identified as the top-performing model during backtesting.")
263+
backtest_table = rc.DataTable(backtest_stats, index=True)
264+
liner_plot = get_auto_select_plot(backtest_stats)
265+
backtest_sections = [backtest_text, backtest_table, summary_text, liner_plot]
266+
267+
253268
forecast_plots = []
254269
if len(self.forecast_output.list_series_ids()) > 0:
255270
forecast_text = rc.Heading(
@@ -276,6 +291,7 @@ def generate_report(self):
276291
yaml_appendix = rc.Yaml(self.config.to_dict())
277292
report_sections = (
278293
[summary]
294+
+ backtest_sections
279295
+ forecast_plots
280296
+ other_sections
281297
+ test_metrics_sections
@@ -409,7 +425,7 @@ def _save_report(
409425
"""Saves resulting reports to the given folder."""
410426
import report_creator as rc
411427

412-
unique_output_dir = find_output_dirname(self.spec.output_directory)
428+
unique_output_dir = self.spec.output_directory.url
413429

414430
if ObjectStorageDetails.is_oci_path(unique_output_dir):
415431
storage_options = default_signer()

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7-
from ..const import SupportedModels
7+
from ..const import SupportedModels, AUTO_SELECT
88
from ..operator_config import ForecastOperatorConfig
99
from .arima import ArimaOperatorModel
1010
from .automlx import AutoMLXOperatorModel
1111
from .autots import AutoTSOperatorModel
1212
from .base_model import ForecastOperatorBaseModel
1313
from .neuralprophet import NeuralProphetOperatorModel
1414
from .prophet import ProphetOperatorModel
15-
from .ml_forecast import MLForecastOperatorModel
16-
from ..utils import select_auto_model
1715
from .forecast_datasets import ForecastDatasets
16+
from .ml_forecast import MLForecastOperatorModel
17+
from ..model_evaluator import ModelEvaluator
1818

1919
class UnSupportedModelError(Exception):
2020
def __init__(self, model_type: str):
@@ -63,8 +63,35 @@ def get_model(
6363
In case of not supported model.
6464
"""
6565
model_type = operator_config.spec.model
66-
if model_type == "auto":
67-
model_type = select_auto_model(datasets, operator_config)
66+
if model_type == AUTO_SELECT:
67+
model_type = cls.auto_select_model(datasets, operator_config)
68+
operator_config.spec.model_kwargs = dict()
6869
if model_type not in cls._MAP:
6970
raise UnSupportedModelError(model_type)
7071
return cls._MAP[model_type](config=operator_config, datasets=datasets)
72+
73+
@classmethod
74+
def auto_select_model(
75+
cls, datasets: ForecastDatasets, operator_config: ForecastOperatorConfig
76+
) -> str:
77+
"""
78+
Selects the model type that ModelEvaluator identifies as the best for the given datasets.
79+
80+
Reads `model_list` (default: all supported models), `num_backtests` (default: 5) and
81+
`sample_ratio` (default: 0.20) from `operator_config.spec.model_kwargs` to configure the evaluator.
82+
83+
Parameters
84+
------------
85+
datasets: ForecastDatasets
86+
Datasets for predictions
87+
88+
Returns
89+
--------
90+
str
91+
The type of the model.
92+
"""
93+
all_models = operator_config.spec.model_kwargs.get("model_list", cls._MAP.keys())
94+
num_backtests = operator_config.spec.model_kwargs.get("num_backtests", 5)
95+
sample_ratio = operator_config.spec.model_kwargs.get("sample_ratio", 0.20)
96+
model_evaluator = ModelEvaluator(all_models, num_backtests, sample_ratio)
97+
return model_evaluator.find_best_model(datasets, operator_config)

0 commit comments

Comments
 (0)