Skip to content

Commit

Permalink
Feature LOF outlier detector (#746)
Browse files Browse the repository at this point in the history
* Add LOF torch backend

* Add LOF Frontend

* Add tests
  • Loading branch information
mauicv authored Jun 12, 2023
1 parent df48c94 commit 5e69f4b
Show file tree
Hide file tree
Showing 7 changed files with 874 additions and 2 deletions.
217 changes: 217 additions & 0 deletions alibi_detect/od/_lof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
from typing import Callable, Union, Optional, Dict, Any, List, Tuple
from typing import TYPE_CHECKING
from typing_extensions import Literal

import numpy as np

from alibi_detect.base import outlier_prediction_dict
from alibi_detect.exceptions import _catch_error as catch_error
from alibi_detect.od.base import TransformProtocol, TransformProtocolType
from alibi_detect.base import BaseDetector, FitMixin, ThresholdMixin
from alibi_detect.od.pytorch import LOFTorch, Ensembler
from alibi_detect.od.base import get_aggregator, get_normalizer, NormalizerLiterals, AggregatorLiterals
from alibi_detect.utils.frameworks import BackendValidator
from alibi_detect.version import __version__


if TYPE_CHECKING:
import torch


# Registry of supported backends: maps the (lower-case) backend name to the pair
# (backend detector class, ensembler class) used by the `LOF` frontend.
backends = {'pytorch': (LOFTorch, Ensembler)}


class LOF(BaseDetector, FitMixin, ThresholdMixin):
    def __init__(
        self,
        k: Union[int, np.ndarray, List[int], Tuple[int]],
        kernel: Optional[Callable] = None,
        normalizer: Optional[Union[TransformProtocolType, NormalizerLiterals]] = 'PValNormalizer',
        aggregator: Union[TransformProtocol, AggregatorLiterals] = 'AverageAggregator',
        backend: Literal['pytorch'] = 'pytorch',
        device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
    ) -> None:
        """
        Local Outlier Factor (LOF) outlier detector.

        The LOF detector is a non-parametric method for outlier detection. It computes the local density
        deviation of a given data point with respect to its neighbors. It considers as outliers the
        samples that have a substantially lower density than their neighbors.

        The detector can be initialized with `k` a single value or an array of values. If `k` is a single value
        then the score method computes the local outlier factor using the `k` nearest neighbors of each instance.
        If `k` is an array of values then the local outlier factor is computed for each value in `k`. In the
        latter case, an `aggregator` must be specified to aggregate the scores.

        Note that, in the multiple k case, a normalizer can be provided. If a normalizer is passed then it is fit in
        the `infer_threshold` method and so this method must be called before the `predict` method. If this is not
        done an exception is raised. If `k` is a single value then the predict method can be called without first
        calling `infer_threshold` but only scores will be returned and not outlier predictions.

        Parameters
        ----------
        k
            Number of nearest neighbors to compute distance to. `k` can be a single value or
            an array of integers. If an array is passed, an aggregator is required to aggregate
            the scores. If `k` is a single value we compute the local outlier factor for that `k`.
            Otherwise if `k` is a list then we compute and aggregate the local outlier factor for each
            value in `k`.
        kernel
            Kernel function to use for outlier detection. If ``None``, `torch.cdist` is used.
            Otherwise if a kernel is specified then instead of using `torch.cdist` the kernel
            defines the k nearest neighbor distance.
        normalizer
            Normalizer to use for outlier detection. If ``None``, no normalization is applied.
            For a list of available normalizers, see :mod:`alibi_detect.od.pytorch.ensemble`.
        aggregator
            Aggregator to use for outlier detection. Can be set to ``None`` if `k` is a single
            value. For a list of available aggregators, see :mod:`alibi_detect.od.pytorch.ensemble`.
        backend
            Backend used for outlier detection. Defaults to ``'pytorch'``. Options are ``'pytorch'``.
            The name is matched case-insensitively.
        device
            Device type used. The default tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
            ``torch.device``.

        Raises
        ------
        ValueError
            If `k` is an array and `aggregator` is None.
        NotImplementedError
            If choice of `backend` is not implemented.
        """
        super().__init__()

        backend_str: str = backend.lower()
        BackendValidator(
            backend_options={'pytorch': ['pytorch']},
            construct_name=self.__class__.__name__
        ).verify_backend(backend_str)

        # Index the registry with the normalised name that was just validated. Using the raw
        # `backend` argument here would raise a `KeyError` for e.g. ``'PyTorch'`` even though
        # it passes validation.
        backend_cls, ensembler_cls = backends[backend_str]

        # An ensemble of scores (one per value of `k`) is only produced when `k` is array-like.
        is_ensemble = isinstance(k, (list, np.ndarray, tuple))

        if aggregator is None and is_ensemble:
            raise ValueError('If `k` is a `np.ndarray`, `list` or `tuple`, '
                             'the `aggregator` argument cannot be ``None``.')

        ensembler = None
        if is_ensemble:
            ensembler = ensembler_cls(
                normalizer=get_normalizer(normalizer),
                aggregator=get_aggregator(aggregator)
            )

        self.backend = backend_cls(k, kernel=kernel, ensembler=ensembler, device=device)

        # set metadata
        self.meta['detector_type'] = 'outlier'
        self.meta['data_type'] = 'numeric'
        self.meta['online'] = False

    def fit(self, x_ref: np.ndarray) -> None:
        """Fit the detector on reference data.

        Parameters
        ----------
        x_ref
            Reference data used to fit the detector.
        """
        self.backend.fit(self.backend._to_tensor(x_ref))

    @catch_error('NotFittedError')
    @catch_error('ThresholdNotInferredError')
    def score(self, x: np.ndarray) -> np.ndarray:
        """Score `x` instances using the detector.

        Computes the local outlier factor for each point in `x`. This is the density of each point `x`
        relative to those of its neighbors in `x_ref`. If `k` is an array of values then the score for
        each `k` is aggregated using the ensembler.

        Parameters
        ----------
        x
            Data to score. The shape of `x` should be `(n_instances, n_features)`.

        Returns
        -------
        Outlier scores. The shape of the scores is `(n_instances,)`. The higher the score, the more anomalous
        the instance.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        ThresholdNotInferredError
            If k is a list and a threshold was not inferred.
        """
        score = self.backend.score(self.backend._to_tensor(x))
        # Combine the per-`k` scores into a single score per instance via the backend's ensembler hook.
        score = self.backend._ensembler(score)
        return self.backend._to_numpy(score)

    @catch_error('NotFittedError')
    def infer_threshold(self, x: np.ndarray, fpr: float) -> None:
        """Infer the threshold for the LOF detector.

        The threshold is computed so that the outlier detector would incorrectly classify `fpr` proportion of the
        reference data as outliers.

        Parameters
        ----------
        x
            Reference data used to infer the threshold.
        fpr
            False positive rate used to infer the threshold. The false positive rate is the proportion of
            instances in `x` that are incorrectly classified as outliers. The false positive rate should
            be in the range ``(0, 1)``.

        Raises
        ------
        ValueError
            Raised if `fpr` is not in ``(0, 1)``.
        NotFittedError
            If called before detector has been fit.
        """
        self.backend.infer_threshold(self.backend._to_tensor(x), fpr)

    @catch_error('NotFittedError')
    @catch_error('ThresholdNotInferredError')
    def predict(self, x: np.ndarray) -> Dict[str, Any]:
        """Predict whether the instances in `x` are outliers or not.

        Scores the instances in `x` and if the threshold was inferred, returns the outlier labels and p-values
        as well.

        Parameters
        ----------
        x
            Data to predict. The shape of `x` should be `(n_instances, n_features)`.

        Returns
        -------
        Dictionary with keys 'data' and 'meta'. 'data' contains the outlier scores. If threshold inference was
        performed, 'data' also contains the threshold value, outlier labels and p-vals. The shape of the scores
        is `(n_instances,)`. The higher the score, the more anomalous the instance. 'meta' contains information
        about the detector.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        ThresholdNotInferredError
            If k is a list and a threshold was not inferred.
        """
        outputs = self.backend.predict(self.backend._to_tensor(x))
        # Start from the standard prediction template and overlay the backend outputs and detector metadata.
        output = outlier_prediction_dict()
        output['data'] = {
            **output['data'],
            **self.backend._to_numpy(outputs)
        }
        output['meta'] = {
            **output['meta'],
            'name': self.__class__.__name__,
            'detector_type': 'outlier',
            'online': False,
            'version': __version__,
        }
        return output
1 change: 1 addition & 0 deletions alibi_detect/od/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from alibi_detect.utils.missing_optional_dependency import import_optional

KNNTorch = import_optional('alibi_detect.od.pytorch.knn', ['KNNTorch'])
LOFTorch = import_optional('alibi_detect.od.pytorch.lof', ['LOFTorch'])
MahalanobisTorch = import_optional('alibi_detect.od.pytorch.mahalanobis', ['MahalanobisTorch'])
KernelPCATorch, LinearPCATorch = import_optional('alibi_detect.od.pytorch.pca', ['KernelPCATorch', 'LinearPCATorch'])
Ensembler = import_optional('alibi_detect.od.pytorch.ensemble', ['Ensembler'])
Expand Down
164 changes: 164 additions & 0 deletions alibi_detect/od/pytorch/lof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Optional, Union, List, Tuple
from typing_extensions import Literal
import numpy as np
import torch

from alibi_detect.od.pytorch.ensemble import Ensembler
from alibi_detect.od.pytorch.base import TorchOutlierDetector


class LOFTorch(TorchOutlierDetector):
    def __init__(
        self,
        k: Union[np.ndarray, List, Tuple, int],
        kernel: Optional[torch.nn.Module] = None,
        ensembler: Optional[Ensembler] = None,
        device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
    ):
        """PyTorch backend for LOF detector.

        Parameters
        ----------
        k
            Number of nearest neighbors used to compute the local outlier factor. `k` can be a single
            value or an array of integers. If `k` is a single value the score method returns the local
            outlier factor computed for that `k`. If `k` is an array then one local outlier factor is
            computed for each of the specified values.
        kernel
            If a kernel is specified then instead of using `torch.cdist` the kernel defines the `k` nearest
            neighbor distance.
        ensembler
            If `k` is an array of integers then the ensembler must not be ``None``. Should be an instance
            of :py:obj:`alibi_detect.od.pytorch.ensemble.ensembler`. Responsible for combining
            multiple scores into a single score.
        device
            Device type used. The default tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of
            ``torch.device``.
        """
        TorchOutlierDetector.__init__(self, device=device)
        self.kernel = kernel
        self.ensemble = isinstance(k, (np.ndarray, list, tuple))
        # Always store `ks` as a 1-D tensor on the detector's device. Previously only the
        # single-`k` branch was placed on `self.device`, leaving the ensemble case on the
        # default device.
        self.ks = torch.tensor(k if self.ensemble else [k], device=self.device)
        self.ensembler = ensembler

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Detect if `x` is an outlier.

        Parameters
        ----------
        x
            `torch.Tensor` with leading batch dimension.

        Returns
        -------
        `torch.Tensor` of ``bool`` values with leading batch dimension.

        Raises
        ------
        ThresholdNotInferredError
            If called before detector has had `infer_threshold` method called.
        """
        raw_scores = self.score(x)
        scores = self._ensembler(raw_scores)
        # The threshold check is skipped while scripting with TorchScript.
        if not torch.jit.is_scripting():
            self.check_threshold_inferred()
        preds = scores > self.threshold
        return preds

    def _make_mask(self, reachabilities: torch.Tensor):
        """Generate a mask for computing the average reachability.

        If k is an array then we need to compute the average reachability for each k separately. To do
        this we use a mask to weight the reachability of each k-close neighbor by 1/k and the rest to 0.

        Parameters
        ----------
        reachabilities
            Reachability tensor. The mask takes the shape of `reachabilities[0]`, i.e. the shape without
            the leading batch dimension: one row per neighbor and one column per value in `self.ks`.

        Returns
        -------
        Mask tensor where column `i` holds `1/k` in its first `k` rows (for the `i`-th `k`) and zeros elsewhere.
        """
        mask = torch.zeros_like(reachabilities[0], device=self.device)
        for i, k in enumerate(self.ks):
            mask[:k, i] = torch.ones(k, device=self.device)/k
        return mask

    def _compute_K(self, x, y):
        """Compute the distance matrix between `x` and `y`.

        If a kernel was provided, `exp(-kernel(x, y))` is used so that larger kernel values map to
        smaller values in the returned matrix; otherwise `torch.cdist` is used.
        """
        return torch.exp(-self.kernel(x, y)) if self.kernel is not None else torch.cdist(x, y)

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the score of `x`

        Parameters
        ----------
        x
            The tensor of instances. First dimension corresponds to batch.

        Returns
        -------
        Tensor of scores for each element in `x`. If `k` was given as an array the result has one column
        per value of `k`; otherwise a single score per instance is returned.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        """
        self.check_fitted()

        # compute the distance matrix between x and x_ref
        K = self._compute_K(x, self.x_ref)

        # compute k nearest neighbors for maximum k in self.ks
        max_k = torch.max(self.ks)
        bot_k_items = torch.topk(K, int(max_k), dim=1, largest=False)
        bot_k_inds, bot_k_dists = bot_k_items.indices, bot_k_items.values

        # To compute the reachabilities we get the k-distances of each object in the instances
        # k nearest neighbors. Then we take the maximum of their k-distances and the distance
        # to the instance.
        lower_bounds = self.knn_dists_ref[bot_k_inds]
        reachabilities = torch.max(bot_k_dists[:, :, None], lower_bounds)

        # Compute the average reachability for each instance. We use a mask to manage each k in
        # self.ks separately.
        mask = self._make_mask(reachabilities)
        avg_reachabilities = (reachabilities*mask[None, :, :]).sum(1)

        # Compute the LOF score for each instance. Note we don't take 1/avg_reachabilities as
        # avg_reachabilities is the denominator in the LOF formula.
        factors = (self.ref_inv_avg_reachabilities[bot_k_inds] * mask[None, :, :]).sum(1)
        lofs = (avg_reachabilities * factors)
        # Single-`k` detectors return a flat vector of scores rather than a one-column matrix.
        return lofs if self.ensemble else lofs[:, 0]

    def fit(self, x_ref: torch.Tensor):
        """Fits the detector

        Stores the reference data together with, for each reference instance and each `k` in `self.ks`,
        its k-distance and its inverse average reachability. These are reused by `score`.

        Parameters
        ----------
        x_ref
            The Dataset tensor.
        """
        # compute the distance matrix
        K = self._compute_K(x_ref, x_ref)
        # set diagonal to max distance to prevent torch.topk from returning the instance itself
        K += torch.eye(len(K), device=self.device) * torch.max(K)

        # compute k nearest neighbors for maximum k in self.ks
        max_k = torch.max(self.ks)
        bot_k_items = torch.topk(K, int(max_k), dim=1, largest=False)
        bot_k_inds, bot_k_dists = bot_k_items.indices, bot_k_items.values

        # store the k-distances for each instance for each k.
        self.knn_dists_ref = bot_k_dists[:, self.ks-1]

        # To compute the reachabilities we get the k-distances of each object in the instances
        # k nearest neighbors. Then we take the maximum of their k-distances and the distance
        # to the instance.
        lower_bounds = self.knn_dists_ref[bot_k_inds]
        reachabilities = torch.max(bot_k_dists[:, :, None], lower_bounds)

        # Compute the average reachability for each instance. We use a mask to manage each k in
        # self.ks separately.
        mask = self._make_mask(reachabilities)
        avg_reachabilities = (reachabilities*mask[None, :, :]).sum(1)

        # Compute the inverse average reachability for each instance.
        self.ref_inv_avg_reachabilities = 1/avg_reachabilities

        self.x_ref = x_ref
        self._set_fitted()
Loading

0 comments on commit 5e69f4b

Please sign in to comment.