diff --git a/sdks/python/apache_beam/ml/anomaly/detectors/pyod_adapter.py b/sdks/python/apache_beam/ml/anomaly/detectors/pyod_adapter.py index 10bd25514761..73dad118e17a 100644 --- a/sdks/python/apache_beam/ml/anomaly/detectors/pyod_adapter.py +++ b/sdks/python/apache_beam/ml/anomaly/detectors/pyod_adapter.py @@ -15,95 +15,106 @@ # limitations under the License. # + +"""Utilities to adapt PyOD models for Beam's anomaly detection APIs. + +Provides a ModelHandler implementation for PyOD detectors and a factory for +creating OfflineDetector wrappers around pickled PyOD models. +""" + import pickle -from collections.abc import Iterable -from collections.abc import Sequence -from typing import Any -from typing import Optional +from collections.abc import Iterable, Sequence +from typing import Optional, cast import numpy as np +from pyod.models.base import BaseDetector as PyODBaseDetector import apache_beam as beam from apache_beam.io.filesystems import FileSystems from apache_beam.ml.anomaly.detectors.offline import OfflineDetector from apache_beam.ml.anomaly.specifiable import specifiable from apache_beam.ml.anomaly.thresholds import FixedThreshold -from apache_beam.ml.inference.base import KeyedModelHandler -from apache_beam.ml.inference.base import ModelHandler -from apache_beam.ml.inference.base import PredictionResult -from apache_beam.ml.inference.base import _PostProcessingModelHandler +from apache_beam.ml.inference.base import ( + KeyedModelHandler, + ModelHandler, + PredictionResult, + _PostProcessingModelHandler, +) from apache_beam.ml.inference.utils import _convert_to_result -from pyod.models.base import BaseDetector as PyODBaseDetector + # Turn the used ModelHandler into specifiable, but without lazy init. KeyedModelHandler = specifiable( # type: ignore[misc] KeyedModelHandler, on_demand_init=False, - just_in_time_init=False) + just_in_time_init=False, +) + _PostProcessingModelHandler = specifiable( # type: ignore[misc] _PostProcessingModelHandler, on_demand_init=False, - just_in_time_init=False) + just_in_time_init=False, +) @specifiable -class PyODModelHandler(ModelHandler[beam.Row, - PredictionResult, - PyODBaseDetector]): - """Implementation of the ModelHandler interface for PyOD [#]_ Models. - - The ModelHandler processes input data as `beam.Row` objects. +class PyODModelHandler(ModelHandler[beam.Row, PredictionResult, PyODBaseDetector]): + """ModelHandler implementation for PyOD models. - **NOTE:** This API and its implementation are currently under active - development and may not be backward compatible. - - Args: - model_uri: The URI specifying the location of the pickled PyOD model. - - .. [#] https://github.com/yzhao062/pyod - """ - def __init__(self, model_uri: str): - self._model_uri = model_uri - - def load_model(self) -> PyODBaseDetector: - file = FileSystems.open(self._model_uri, 'rb') - return pickle.load(file) - - def run_inference( - self, - batch: Sequence[beam.Row], - model: PyODBaseDetector, - inference_args: Optional[dict[str, Any]] = None - ) -> Iterable[PredictionResult]: - np_batch = [] - for row in batch: - np_batch.append(np.fromiter(row, dtype=np.float64)) - - # stack a batch of samples into a 2-D array for better performance - vectorized_batch = np.stack(np_batch, axis=0) - predictions = model.decision_function(vectorized_batch) - - return _convert_to_result(batch, predictions, model_id=self._model_uri) - - -class PyODFactory(): - @staticmethod - def create_detector(model_uri: str, **kwargs) -> OfflineDetector: - """A utility function to create OfflineDetector for a PyOD model. - - **NOTE:** This API and its implementation are currently under active - development and may not be backward compatible. + Processes `beam.Row` inputs, flattening vector-like fields (list/tuple/ndarray) + into a single numeric feature vector per row before invoking the PyOD + model's ``decision_function``. + NOTE: Experimental; interface may change. Args: - model_uri: The URI specifying the location of the pickled PyOD model. - **kwargs: Additional keyword arguments. + model_uri: Location of the pickled PyOD model. """ - model_handler = KeyedModelHandler( - PyODModelHandler(model_uri=model_uri)).with_postprocess_fn( + + def __init__(self, model_uri: str): + super().__init__() + self._model_uri = model_uri + + def load_model(self) -> PyODBaseDetector: + with FileSystems.open(self._model_uri, 'rb') as file: + return pickle.load(file) + + def run_inference( # type: ignore[override] + self, + batch: Sequence[beam.Row], + model: PyODBaseDetector, + inference_args: Optional[dict[str, object]] = None, + ) -> Iterable[PredictionResult]: + """Run inference on a batch of rows. + + Flattens vector-like fields, stacks the batch into a 2-D array and + returns PredictionResult objects. + """ + + def _flatten_row(row_values): + for value in row_values: + if isinstance(value, (list, tuple, np.ndarray)): + yield from value + else: + yield value + + np_batch = [np.fromiter(_flatten_row(row), dtype=np.float64) for row in batch] + vectorized_batch = np.stack(np_batch, axis=0) + predictions = model.decision_function(vectorized_batch) + return _convert_to_result(batch, predictions, model_id=self._model_uri) + + +class PyODFactory: + """Factory helpers to create OfflineDetector instances from PyOD models.""" + + @staticmethod + def create_detector(model_uri: str, **kwargs) -> OfflineDetector: + handler = KeyedModelHandler(PyODModelHandler(model_uri=model_uri)).with_postprocess_fn( OfflineDetector.score_prediction_adapter) - m = model_handler.load_model() - assert (isinstance(m, PyODBaseDetector)) - threshold = float(m.threshold_) - detector = OfflineDetector( - model_handler, threshold_criterion=FixedThreshold(threshold), **kwargs) # type: ignore[arg-type] - return detector + model = handler.load_model() + assert isinstance(model, PyODBaseDetector) + threshold = float(model.threshold_) + return OfflineDetector( + cast(object, handler), + threshold_criterion=FixedThreshold(threshold), + **kwargs, + )