Skip to content

Commit c2d1416

Browse files
committed
added doc strings and exposed auto-select parameters
1 parent faf1684 commit c2d1416

File tree

6 files changed

+30
-10
lines changed

6 files changed

+30
-10
lines changed

ads/opctl/operator/lowcode/anomaly/operator_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
InputData,
1717
)
1818
from .const import SupportedModels
19+
from ads.opctl.operator.lowcode.common.utils import find_output_dirname
1920

2021

2122
@dataclass(repr=True)
@@ -79,6 +80,7 @@ class AnomalyOperatorSpec(DataClassSerializable):
7980

8081
def __post_init__(self):
8182
"""Adjusts the specification details."""
83+
self.output_directory = self.output_directory or OutputDirectory(url=find_output_dirname(self.output_directory))
8284
self.report_file_name = self.report_file_name or "report.html"
8385
self.report_theme = self.report_theme or "light"
8486
self.inliers_filename = self.inliers_filename or "inliers.csv"

ads/opctl/operator/lowcode/forecast/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta):
8787
SUMMARY_METRICS_HORIZON_LIMIT = 10
8888
PROPHET_INTERNAL_DATE_COL = "ds"
8989
RENDER_LIMIT = 5000
90+
AUTO_SELECT = "auto-select"

ads/opctl/operator/lowcode/forecast/model/base_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
SupportedMetrics,
4545
SupportedModels,
4646
SpeedAccuracyMode,
47+
AUTO_SELECT
4748
)
4849
from ..operator_config import ForecastOperatorConfig, ForecastOperatorSpec
4950

@@ -248,7 +249,7 @@ def generate_report(self):
248249
train_metrics_sections = [sec9_text, sec9]
249250

250251
backtest_sections = []
251-
if self.spec.model == "auto-select":
252+
if self.spec.model == AUTO_SELECT:
252253
output_dir = self.spec.output_directory.url
253254
backtest_report_name = "backtest_stats.csv"
254255
backtest_stats = pd.read_csv(f"{output_dir}/{backtest_report_name}")

ads/opctl/operator/lowcode/forecast/model/factory.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7-
from ..const import SupportedModels
7+
from ..const import SupportedModels, AUTO_SELECT
88
from ..operator_config import ForecastOperatorConfig
99
from .arima import ArimaOperatorModel
1010
from .automlx import AutoMLXOperatorModel
@@ -14,6 +14,7 @@
1414
from .prophet import ProphetOperatorModel
1515
from .forecast_datasets import ForecastDatasets
1616
from .ml_forecast import MLForecastOperatorModel
17+
from ..model_evaluator import ModelEvaluator
1718

1819
class UnSupportedModelError(Exception):
1920
def __init__(self, model_type: str):
@@ -62,8 +63,9 @@ def get_model(
6263
In case of not supported model.
6364
"""
6465
model_type = operator_config.spec.model
65-
if model_type == "auto-select":
66+
if model_type == AUTO_SELECT:
6667
model_type = cls.auto_select_model(datasets, operator_config)
68+
operator_config.spec.model_kwargs = dict()
6769
if model_type not in cls._MAP:
6870
raise UnSupportedModelError(model_type)
6971
return cls._MAP[model_type](config=operator_config, datasets=datasets)
@@ -88,7 +90,8 @@ def auto_select_model(
8890
str
8991
The type of the model.
9092
"""
91-
from ..model_evaluator import ModelEvaluator
92-
all_models = cls._MAP.keys()
93-
model_evaluator = ModelEvaluator(all_models)
93+
all_models = operator_config.spec.model_kwargs.get("model_list", cls._MAP.keys())
94+
num_backtests = operator_config.spec.model_kwargs.get("num_backtests", 5)
95+
sample_ratio = operator_config.spec.model_kwargs.get("sample_ratio", 0.20)
96+
model_evaluator = ModelEvaluator(all_models, num_backtests, sample_ratio)
9497
return model_evaluator.find_best_model(datasets, operator_config)

ads/opctl/operator/lowcode/forecast/model_evaluator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,22 @@
1515

1616

1717
class ModelEvaluator:
18+
"""
19+
A class used to evaluate and determine the best model or framework from a given set of candidates.
20+
21+
This class is responsible for comparing different models or frameworks based on specified evaluation
22+
metrics and returning the best-performing option.
23+
"""
1824
def __init__(self, models, k=5, subsample_ratio=0.20):
25+
"""
26+
Initializes the ModelEvaluator with a list of models, number of backtests and subsample ratio.
27+
28+
Properties:
29+
----------
30+
models (list): The list of model to be evaluated.
31+
k (int): The number of times each model is backtested to verify its performance.
32+
subsample_ratio (float): The proportion of the data used in the evaluation process.
33+
"""
1934
self.models = models
2035
self.k = k
2136
self.subsample_ratio = subsample_ratio
@@ -83,6 +98,7 @@ def create_operator_config(self, operator_config, backtest, model, historical_da
8398
backtest_spec["additional_data"]["url"] = additional_data_url
8499
backtest_spec["test_data"]["url"] = test_data_url
85100
backtest_spec["model"] = model
101+
backtest_spec['model_kwargs'] = None
86102
backtest_spec["output_directory"] = {"url": output_file_path}
87103
backtest_spec["target_category_columns"] = [DataColumns.Series]
88104
backtest_spec['generate_explanations'] = False

ads/opctl/operator/lowcode/forecast/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from .operator_config import ForecastOperatorSpec, ForecastOperatorConfig
3434
from ads.opctl.operator.lowcode.common.utils import merge_category_columns
3535
from ads.opctl.operator.lowcode.forecast.const import ForecastOutputColumns
36-
# from ads.opctl.operator.lowcode.forecast.model.forecast_datasets import TestData, ForecastOutput
36+
import report_creator as rc
3737

3838

3939
def _label_encode_dataframe(df, no_encode=set()):
@@ -256,8 +256,6 @@ def evaluate_train_metrics(output, metrics_col_name=None):
256256

257257

258258
def _select_plot_list(fn, series_ids):
259-
import report_creator as rc
260-
261259
blocks = [rc.Widget(fn(s_id=s_id), label=s_id) for s_id in series_ids]
262260
return rc.Select(blocks=blocks) if len(blocks) > 1 else blocks[0]
263261

@@ -280,7 +278,6 @@ def get_auto_select_plot(backtest_results):
280278
name=column,
281279
))
282280

283-
import report_creator as rc
284281
return rc.Widget(fig)
285282

286283

0 commit comments

Comments
 (0)