diff --git a/alibi_detect/cd/base.py b/alibi_detect/cd/base.py index 937a08c7d..9c1d925f2 100644 --- a/alibi_detect/cd/base.py +++ b/alibi_detect/cd/base.py @@ -506,7 +506,7 @@ def __init__( preprocess_fn: Optional[Callable] = None, sigma: Optional[np.ndarray] = None, configure_kernel_from_x_ref: bool = True, - n_permutations: int = 100, + n_permutations: int = None, input_shape: Optional[tuple] = None, data_type: Optional[str] = None ) -> None: diff --git a/alibi_detect/cd/mmd.py b/alibi_detect/cd/mmd.py index 4391f2ccd..38f8f10f8 100644 --- a/alibi_detect/cd/mmd.py +++ b/alibi_detect/cd/mmd.py @@ -4,12 +4,13 @@ from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow, BackendValidator from alibi_detect.utils.warnings import deprecated_alias from alibi_detect.base import DriftConfigMixin +from alibi_detect.utils._types import Literal if has_pytorch: - from alibi_detect.cd.pytorch.mmd import MMDDriftTorch + from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeMMDDriftTorch if has_tensorflow: - from alibi_detect.cd.tensorflow.mmd import MMDDriftTF + from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ def __init__( x_ref: Union[np.ndarray, list], backend: str = 'tensorflow', p_val: float = .05, + estimator: Literal['quad', 'linear'] = 'quad', x_ref_preprocessed: bool = False, preprocess_at_init: bool = True, update_x_ref: Optional[Dict[str, int]] = None, @@ -44,6 +46,11 @@ def __init__( Backend used for the MMD implementation. p_val p-value used for the significance of the permutation test. + estimator + Estimator used for the MMD^2 computation. 'quad' is the default and + uses the quadratic u-statistics on each square kernel matrix. 'linear' uses the linear + time estimator as in Gretton et al. (JMLR 2014, sec 6), and the threshold is computed + using the Gaussian asympotic distribution under null. x_ref_preprocessed Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference @@ -65,7 +72,8 @@ def __init__( configure_kernel_from_x_ref Whether to already configure the kernel bandwidth from the reference data. n_permutations - Number of permutations used in the permutation test. + Number of permutations used in the permutation test, only used for the quadratic estimator + (estimator='quad'). device Device type used. The default None tries to use the GPU and falls back on CPU if needed. Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend. @@ -80,6 +88,7 @@ def __init__( self._set_config(locals()) backend = backend.lower() + estimator = estimator.lower() # type: ignore BackendValidator( backend_options={'tensorflow': ['tensorflow'], 'pytorch': ['pytorch']}, @@ -88,7 +97,7 @@ def __init__( kwargs = locals() args = [kwargs['x_ref']] - pop_kwargs = ['self', 'x_ref', 'backend', '__class__'] + pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator'] [kwargs.pop(k, None) for k in pop_kwargs] if kernel is None: @@ -100,9 +109,21 @@ def __init__( if backend == 'tensorflow' and has_tensorflow: kwargs.pop('device', None) - self._detector = MMDDriftTF(*args, **kwargs) # type: ignore + if estimator == 'quad': + self._detector = MMDDriftTF(*args, **kwargs) # type: ignore + elif estimator == 'linear': + kwargs.pop('n_permutations', None) + self._detector = LinearTimeMMDDriftTF(*args, **kwargs) # type: ignore + else: + raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') else: - self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore + if estimator == 'quad': + self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore + elif estimator == 'linear': + kwargs.pop('n_permutations', None) + self._detector = LinearTimeMMDDriftTorch(*args, **kwargs) # type: ignore + else: + raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') self.meta = self._detector.meta def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \ @@ -139,7 +160,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: Returns ------- - p-value obtained from the permutation test, the MMD^2 between the reference and test set, + p-value obtained from the test, the MMD^2 between the reference and test set, and the MMD^2 threshold above which drift is flagged. """ return self._detector.score(x) diff --git a/alibi_detect/cd/pytorch/mmd.py b/alibi_detect/cd/pytorch/mmd.py index 97ebc4790..ea1a70233 100644 --- a/alibi_detect/cd/pytorch/mmd.py +++ b/alibi_detect/cd/pytorch/mmd.py @@ -1,10 +1,11 @@ import logging import numpy as np +import scipy.stats as stats import torch from typing import Callable, Dict, Optional, Tuple, Union from alibi_detect.cd.base import BaseMMDDrift +from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2 from alibi_detect.utils.pytorch import get_device -from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix from alibi_detect.utils.pytorch.kernels import GaussianRBF from alibi_detect.utils.warnings import deprecated_alias @@ -123,6 +124,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: and the MMD^2 threshold above which drift is flagged. """ x_ref, x = self.preprocess(x) + n = x.shape[0] x_ref = torch.from_numpy(x_ref).to(self.device) # type: ignore[assignment] x = torch.from_numpy(x).to(self.device) # type: ignore[assignment] # compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix @@ -130,10 +132,11 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: n = x.shape[0] # type: ignore kernel_mat = self.kernel_matrix(x_ref, x) # type: ignore[arg-type] kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal - mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) + mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) # type: ignore[assignment] mmd2_permuted = torch.Tensor( - [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)] - ) + [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) + for _ in range(self.n_permutations)] + ) if self.device.type == 'cuda': mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu() p_val = (mmd2 <= mmd2_permuted).float().mean() @@ -141,3 +144,142 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: idx_threshold = int(self.p_val * len(mmd2_permuted)) distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold] return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy() + + +class LinearTimeMMDDriftTorch(BaseMMDDrift): + def __init__( + self, + x_ref: Union[np.ndarray, list], + p_val: float = .05, + x_ref_preprocessed: bool = False, + preprocess_at_init: bool = True, + update_x_ref: Optional[Dict[str, int]] = None, + preprocess_fn: Optional[Callable] = None, + kernel: Callable = GaussianRBF, + sigma: Optional[np.ndarray] = None, + configure_kernel_from_x_ref: bool = True, + device: Optional[str] = None, + input_shape: Optional[tuple] = None, + data_type: Optional[str] = None + ) -> None: + """ + Maximum Mean Discrepancy (MMD) data drift detector using a linear-time estimator. + + Parameters + ---------- + x_ref + Data used as reference distribution. + p_val + p-value used for the significance of the permutation test. + x_ref_preprocessed + Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only + the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference + data will also be preprocessed. + preprocess_at_init + Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference + data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`. + update_x_ref + Reference data can optionally be updated to the last n instances seen by the detector + or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while + for reservoir sampling {'reservoir_sampling': n} is passed. + preprocess_fn + Function to preprocess the data before computing the data drift metrics. + kernel + Kernel used for the MMD computation, defaults to Gaussian RBF kernel. + sigma + Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array. + The kernel evaluation is then averaged over those bandwidths. + configure_kernel_from_x_ref + Whether to already configure the kernel bandwidth from the reference data. + device + Device type used. The default None tries to use the GPU and falls back on CPU if needed. + Can be specified by passing either 'cuda', 'gpu' or 'cpu'. + input_shape + Shape of input data. + data_type + Optionally specify the data type (tabular, image or time-series). Added to metadata. + """ + super().__init__( + x_ref=x_ref, + p_val=p_val, + x_ref_preprocessed=x_ref_preprocessed, + preprocess_at_init=preprocess_at_init, + update_x_ref=update_x_ref, + preprocess_fn=preprocess_fn, + sigma=sigma, + configure_kernel_from_x_ref=configure_kernel_from_x_ref, + input_shape=input_shape, + data_type=data_type + ) + self.meta.update({'backend': 'pytorch'}) + + # set backend + if device is None or device.lower() in ['gpu', 'cuda']: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if self.device.type == 'cpu': + print('No GPU detected, fall back on CPU.') + else: + self.device = torch.device('cpu') + + # initialize kernel + sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment] + np.ndarray) else None + self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel + + # compute kernel matrix for the reference data + if self.infer_sigma or isinstance(sigma, torch.Tensor): + n = self.x_ref.shape[0] + n_hat = int(np.floor(n / 2) * 2) + x = torch.from_numpy(self.x_ref[:n_hat, :]).to(self.device) + self.k_xx = self.kernel(x=x[0::2, :], y=x[1::2, :], + pairwise=False, infer_sigma=self.infer_sigma) + self.infer_sigma = False + else: + self.k_xx, self.infer_sigma = None, True + + def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ Compute and return full kernel matrix between arrays x and y. """ + k_xy = self.kernel(x, y, self.infer_sigma) + k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x) + k_yy = self.kernel(y, y) + kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0) + return kernel_mat + + def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: + """ + Compute the p-value using the maximum mean discrepancy as a distance measure between the + reference data and the data to be tested. x and x_ref are required to have the same size. + The sample size is then specified as the maximal even number below the data size. + + Parameters + ---------- + x + Batch of instances. + + Returns + ------- + p-value obtained from the null hypothesis, the MMD^2 between the reference and test set + and the MMD^2 threshold for the given significance level. + """ + x_ref, x = self.preprocess(x) + n = x.shape[0] + m = x_ref.shape[0] + if n != m: + raise ValueError('x and x_ref must have the same size.') + n_hat = int(np.floor(n / 2) * 2) + x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device) # type: ignore[assignment] + x = torch.from_numpy(x[:n_hat, :]).to(self.device) # type: ignore[assignment] + if self.k_xx is not None and self.update_x_ref is None: + k_xx = self.k_xx + else: + k_xx = self.kernel(x=x_ref[0::2, :], y=x_ref[1::2, :], pairwise=False) + mmd2, var_mmd2 = linear_mmd2(k_xx, x_ref, x, self.kernel) # type: ignore[arg-type] + if self.device.type == 'cuda': + mmd2 = mmd2.cpu() + mmd2 = mmd2.numpy().item() + var_mmd2 = np.clip(var_mmd2.numpy().item(), 1e-8, 1e8) + std_mmd2 = np.sqrt(var_mmd2) + t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.)) + p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1) + distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1) + return p_val, t, distance_threshold diff --git a/alibi_detect/cd/pytorch/tests/test_linear_time_mmd_pt.py b/alibi_detect/cd/pytorch/tests/test_linear_time_mmd_pt.py new file mode 100644 index 000000000..ffd46505c --- /dev/null +++ b/alibi_detect/cd/pytorch/tests/test_linear_time_mmd_pt.py @@ -0,0 +1,96 @@ +from functools import partial +from itertools import product +import numpy as np +import pytest +import torch +import torch.nn as nn +from typing import Callable, List +from alibi_detect.cd.pytorch.mmd import LinearTimeMMDDriftTorch +from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift + +n, n_hidden, n_classes = 500, 10, 5 + + +class MyModel(nn.Module): + def __init__(self, n_features: int): + super().__init__() + self.dense1 = nn.Linear(n_features, 20) + self.dense2 = nn.Linear(20, 2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.ReLU()(self.dense1(x)) + return self.dense2(x) + + +# test List[Any] inputs to the detector +def preprocess_list(x: List[np.ndarray]) -> np.ndarray: + return np.concatenate(x, axis=0) + + +n_features = [10] +n_enc = [None, 3] +preprocess = [ + (None, None), + (preprocess_drift, {'model': HiddenOutput, 'layer': -1}), + (preprocess_list, None) +] +update_x_ref = [{'last': 500}, {'reservoir_sampling': 500}, None] +preprocess_at_init = [True, False] +tests_mmddrift = list(product(n_features, n_enc, preprocess, + update_x_ref, preprocess_at_init)) +n_tests = len(tests_mmddrift) + + +@pytest.fixture +def mmd_params(request): + return tests_mmddrift[request.param] + + +@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True) +def test_mmd(mmd_params): + n_features, n_enc, preprocess, update_x_ref, preprocess_at_init = mmd_params + + np.random.seed(0) + torch.manual_seed(0) + + x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) + preprocess_fn, preprocess_kwargs = preprocess + to_list = False + if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list': + if not preprocess_at_init: + return + to_list = True + x_ref = [_[None, :] for _ in x_ref] + elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \ + and preprocess_kwargs['model'].__name__ == 'HiddenOutput': + model = MyModel(n_features) + layer = preprocess_kwargs['layer'] + preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer)) + else: + preprocess_fn = None + + cd = LinearTimeMMDDriftTorch( + x_ref=x_ref, + p_val=.05, + preprocess_at_init=preprocess_at_init if isinstance(preprocess_fn, Callable) else False, + update_x_ref=update_x_ref, + preprocess_fn=preprocess_fn + ) + x = x_ref.copy() + preds = cd.predict(x, return_p_val=True) + assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val + if isinstance(update_x_ref, dict): + k = list(update_x_ref.keys())[0] + assert cd.n == len(x) + len(x_ref) + assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref)) + + x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) + if to_list: + x_h1 = [_[None, :] for _ in x_h1] + preds = cd.predict(x_h1, return_p_val=True) + if preds['data']['is_drift'] == 1: + assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val + assert preds['data']['distance'] > preds['data']['distance_threshold'] + else: + assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val + assert preds['data']['distance'] <= preds['data']['distance_threshold'] diff --git a/alibi_detect/cd/tensorflow/mmd.py b/alibi_detect/cd/tensorflow/mmd.py index 1de8d908a..10b4d3a11 100644 --- a/alibi_detect/cd/tensorflow/mmd.py +++ b/alibi_detect/cd/tensorflow/mmd.py @@ -1,9 +1,10 @@ import logging import numpy as np +import scipy.stats as stats import tensorflow as tf from typing import Callable, Dict, Optional, Tuple, Union from alibi_detect.cd.base import BaseMMDDrift -from alibi_detect.utils.tensorflow.distance import mmd2_from_kernel_matrix +from alibi_detect.utils.tensorflow.distance import mmd2_from_kernel_matrix, linear_mmd2 from alibi_detect.utils.tensorflow.kernels import GaussianRBF from alibi_detect.utils.warnings import deprecated_alias @@ -121,10 +122,135 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False).numpy() mmd2_permuted = np.array( [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False).numpy() - for _ in range(self.n_permutations)] - ) + for _ in range(self.n_permutations)]) p_val = (mmd2 <= mmd2_permuted).mean() # compute distance threshold idx_threshold = int(self.p_val * len(mmd2_permuted)) distance_threshold = np.sort(mmd2_permuted)[::-1][idx_threshold] return p_val, mmd2, distance_threshold + + +class LinearTimeMMDDriftTF(BaseMMDDrift): + def __init__( + self, + x_ref: Union[np.ndarray, list], + p_val: float = .05, + x_ref_preprocessed: bool = False, + preprocess_at_init: bool = True, + update_x_ref: Optional[Dict[str, int]] = None, + preprocess_fn: Optional[Callable] = None, + kernel: Callable = GaussianRBF, + sigma: Optional[np.ndarray] = None, + configure_kernel_from_x_ref: bool = True, + input_shape: Optional[tuple] = None, + data_type: Optional[str] = None + ) -> None: + """ + Maximum Mean Discrepancy (MMD) data drift detector using a linear-time estimator. + + Parameters + ---------- + x_ref + Data used as reference distribution. + p_val + p-value used for the significance of the permutation test. + x_ref_preprocessed + Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only + the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference + data will also be preprocessed. + preprocess_at_init + Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference + data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`. + update_x_ref + Reference data can optionally be updated to the last n instances seen by the detector + or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while + for reservoir sampling {'reservoir_sampling': n} is passed. + preprocess_fn + Function to preprocess the data before computing the data drift metrics. + kernel + Kernel used for the MMD computation, defaults to Gaussian RBF kernel. + sigma + Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array. + The kernel evaluation is then averaged over those bandwidths. + configure_kernel_from_x_ref + Whether to already configure the kernel bandwidth from the reference data. + input_shape + Shape of input data. + data_type + Optionally specify the data type (tabular, image or time-series). Added to metadata. + """ + super().__init__( + x_ref=x_ref, + p_val=p_val, + x_ref_preprocessed=x_ref_preprocessed, + preprocess_at_init=preprocess_at_init, + update_x_ref=update_x_ref, + preprocess_fn=preprocess_fn, + sigma=sigma, + configure_kernel_from_x_ref=configure_kernel_from_x_ref, + input_shape=input_shape, + data_type=data_type + ) + self.meta.update({'backend': 'tensorflow'}) + + # initialize kernel + if isinstance(sigma, np.ndarray): + sigma = tf.convert_to_tensor(sigma) + self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel + + # compute kernel matrix for the reference data + if self.infer_sigma or isinstance(sigma, tf.Tensor): + n = self.x_ref.shape[0] + n_hat = int(np.floor(n / 2) * 2) + x = self.x_ref[:n_hat, :] + self.k_xx = self.kernel(x=x[0::2, :], y=x[1::2, :], + pairwise=False, infer_sigma=self.infer_sigma) + self.infer_sigma = False + else: + self.k_xx, self.infer_sigma = None, True + + def kernel_matrix(self, x: Union[np.ndarray, tf.Tensor], y: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: + """ Compute and return full kernel matrix between arrays x and y. """ + k_xy = self.kernel(x, y, self.infer_sigma) + k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x) + k_yy = self.kernel(y, y) + kernel_mat = tf.concat([tf.concat([k_xx, k_xy], 1), tf.concat([tf.transpose(k_xy, (1, 0)), k_yy], 1)], 0) + return kernel_mat + + def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]: + """ + Compute the p-value using the maximum mean discrepancy as a distance measure between the + reference data and the data to be tested. The sample size is specified as the maximal even + number in the smaller dataset between x and x_ref. + + Parameters + ---------- + x + Batch of instances. + + Returns + ------- + p-value obtained from the null hypothesis, the MMD^2 between the reference and test set + and the MMD^2 threshold for the given significance level. + """ + x_ref, x = self.preprocess(x) + # compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix + n = x.shape[0] + m = x_ref.shape[0] + if n != m: + raise ValueError('x and x_ref must have the same size.') + n_hat = int(np.floor(n / 2) * 2) + x_ref = x_ref[:n_hat, :] + x = x[:n_hat, :] + if self.k_xx is not None and self.update_x_ref is None: + k_xx = self.k_xx + else: + k_xx = self.kernel(x=x_ref[0::2, :], y=x_ref[1::2, :], pairwise=False) + mmd2, var_mmd2 = linear_mmd2(k_xx, x_ref, x, self.kernel) + mmd2 = mmd2.numpy() + var_mmd2 = np.clip(var_mmd2.numpy(), 1e-8, 1e8) + std_mmd2 = np.sqrt(var_mmd2) + t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.)) + p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1) + distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1) + return p_val, t, distance_threshold diff --git a/alibi_detect/cd/tensorflow/tests/test_linear_time_mmd_tf.py b/alibi_detect/cd/tensorflow/tests/test_linear_time_mmd_tf.py new file mode 100644 index 000000000..48377cb10 --- /dev/null +++ b/alibi_detect/cd/tensorflow/tests/test_linear_time_mmd_tf.py @@ -0,0 +1,107 @@ +from functools import partial +from itertools import product +import numpy as np +import pytest +import tensorflow as tf +from tensorflow.keras.layers import Dense, Input, InputLayer +from typing import Callable, List +from alibi_detect.cd.tensorflow.mmd import LinearTimeMMDDriftTF +from alibi_detect.cd.tensorflow.preprocess import HiddenOutput, UAE, preprocess_drift + +n, n_hidden, n_classes = 500, 10, 5 + +tf.random.set_seed(0) + + +def mymodel(shape): + x_in = Input(shape=shape) + x = Dense(n_hidden)(x_in) + x_out = Dense(n_classes, activation='softmax')(x) + return tf.keras.models.Model(inputs=x_in, outputs=x_out) + + +# test List[Any] inputs to the detector +def preprocess_list(x: List[np.ndarray]) -> np.ndarray: + return np.concatenate(x, axis=0) + + +n_features = [10] +n_enc = [None, 3] +preprocess = [ + (None, None), + (preprocess_drift, {'model': HiddenOutput, 'layer': -1}), + (preprocess_drift, {'model': UAE}), + (preprocess_list, None) +] +update_x_ref = [{'last': 500}, {'reservoir_sampling': 500}, None] +preprocess_at_init = [True, False] +tests_mmddrift = list(product(n_features, n_enc, preprocess, + update_x_ref, preprocess_at_init)) +n_tests = len(tests_mmddrift) + + +@pytest.fixture +def mmd_params(request): + return tests_mmddrift[request.param] + + +@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True) +def test_mmd(mmd_params): + n_features, n_enc, preprocess, update_x_ref, preprocess_at_init = mmd_params + + np.random.seed(0) + + x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) + preprocess_fn, preprocess_kwargs = preprocess + to_list = False + if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list': + if not preprocess_at_init: + return + to_list = True + x_ref = [_[None, :] for _ in x_ref] + elif isinstance(preprocess_fn, Callable): + if 'layer' in list(preprocess_kwargs.keys()) \ + and preprocess_kwargs['model'].__name__ == 'HiddenOutput': + model = mymodel((n_features,)) + layer = preprocess_kwargs['layer'] + preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer)) + elif preprocess_kwargs['model'].__name__ == 'UAE' \ + and n_features > 1 and isinstance(n_enc, int): + tf.random.set_seed(0) + encoder_net = tf.keras.Sequential( + [ + InputLayer(input_shape=(n_features,)), + Dense(n_enc) + ] + ) + preprocess_fn = partial(preprocess_fn, model=UAE(encoder_net=encoder_net)) + else: + preprocess_fn = None + else: + preprocess_fn = None + + cd = LinearTimeMMDDriftTF( + x_ref=x_ref, + p_val=.05, + preprocess_at_init=preprocess_at_init if isinstance(preprocess_fn, Callable) else False, + update_x_ref=update_x_ref, + preprocess_fn=preprocess_fn + ) + x = x_ref.copy() + preds = cd.predict(x, return_p_val=True) + assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val + if isinstance(update_x_ref, dict): + k = list(update_x_ref.keys())[0] + assert cd.n == len(x) + len(x_ref) + assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref)) + + x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) + if to_list: + x_h1 = [_[None, :] for _ in x_h1] + preds = cd.predict(x_h1, return_p_val=True) + if preds['data']['is_drift'] == 1: + assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val + assert preds['data']['distance'] > preds['data']['distance_threshold'] + else: + assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val + assert preds['data']['distance'] <= preds['data']['distance_threshold'] diff --git a/alibi_detect/cd/tests/test_mmd.py b/alibi_detect/cd/tests/test_mmd.py index 33e776e14..93adf70a1 100644 --- a/alibi_detect/cd/tests/test_mmd.py +++ b/alibi_detect/cd/tests/test_mmd.py @@ -1,12 +1,15 @@ +import itertools import numpy as np import pytest from alibi_detect.cd import MMDDrift -from alibi_detect.cd.pytorch.mmd import MMDDriftTorch -from alibi_detect.cd.tensorflow.mmd import MMDDriftTF +from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeMMDDriftTorch +from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF n, n_features = 100, 5 -tests_mmddrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet'] +tests_backend_mmddrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet'] +tests_estimator_mmddrift = ['quad', 'linear', 'Quad', 'Linear', 'cubic'] +tests_mmddrift = list(itertools.product(tests_backend_mmddrift, tests_estimator_mmddrift)) n_tests = len(tests_mmddrift) @@ -17,17 +20,31 @@ def mmddrift_params(request): @pytest.mark.parametrize('mmddrift_params', list(range(n_tests)), indirect=True) def test_mmddrift(mmddrift_params): - backend = mmddrift_params + backend, estimator = mmddrift_params x_ref = np.random.randn(*(n, n_features)) try: - cd = MMDDrift(x_ref=x_ref, backend=backend) + cd = MMDDrift(x_ref=x_ref, backend=backend, estimator=estimator) except NotImplementedError: cd = None if backend.lower() == 'pytorch': - assert isinstance(cd._detector, MMDDriftTorch) + if estimator.lower() == 'quad': + assert isinstance(cd._detector, MMDDriftTorch) + assert isinstance(cd._detector.n_permutations, int) + elif estimator.lower() == 'linear': + assert isinstance(cd._detector, LinearTimeMMDDriftTorch) + assert hasattr(cd._detector, 'n_permutations') + else: + assert cd is None elif backend.lower() == 'tensorflow': - assert isinstance(cd._detector, MMDDriftTF) + if estimator.lower() == 'quad': + assert isinstance(cd._detector, MMDDriftTF) + assert isinstance(cd._detector.n_permutations, int) + elif estimator.lower() == 'linear': + assert isinstance(cd._detector, LinearTimeMMDDriftTF) + assert hasattr(cd._detector, 'n_permutations') + else: + assert cd is None else: assert cd is None diff --git a/alibi_detect/saving/schemas.py b/alibi_detect/saving/schemas.py index c80ce4db0..559d52c90 100644 --- a/alibi_detect/saving/schemas.py +++ b/alibi_detect/saving/schemas.py @@ -628,6 +628,7 @@ class MMDDriftConfig(DriftDetectorConfig): :class:`~alibi_detect.cd.MMDDrift` documentation for a description of each field. """ p_val: float = .05 + estimator: Literal['quad', 'linear'] = 'quad' preprocess_at_init: bool = True update_x_ref: Optional[Dict[str, int]] = None kernel: Optional[Union[str, KernelConfig]] = None @@ -646,6 +647,7 @@ class MMDDriftConfigResolved(DriftDetectorConfigResolved): :class:`~alibi_detect.cd.MMDDrift` documentation for a description of each field. """ p_val: float = .05 + estimator: Literal['quad', 'linear'] = 'quad' preprocess_at_init: bool = True update_x_ref: Optional[Dict[str, int]] = None kernel: Optional[Callable] = None diff --git a/alibi_detect/saving/tests/test_saving.py b/alibi_detect/saving/tests/test_saving.py index e99d2a6f8..105f34b07 100644 --- a/alibi_detect/saving/tests/test_saving.py +++ b/alibi_detect/saving/tests/test_saving.py @@ -355,8 +355,9 @@ def test_save_cvmdrift(data, preprocess_custom, tmp_path): {'sigma': 0.5, 'trainable': False}, # pass kernel as object ], indirect=True ) +@parametrize('estimator', ['quad', 'linear']) @parametrize_with_cases("data", cases=ContinuousData, prefix='data_') -def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed): +def test_save_mmddrift(data, kernel, estimator, preprocess_custom, backend, tmp_path, seed): """ Test MMDDrift on continuous datasets, with UAE as preprocess_fn. @@ -368,6 +369,7 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed) cd = MMDDrift(X_ref, p_val=P_VAL, backend=backend, + estimator=estimator, preprocess_fn=preprocess_custom, n_permutations=N_PERMUTATIONS, preprocess_at_init=True, @@ -386,7 +388,8 @@ def test_save_mmddrift(data, kernel, preprocess_custom, backend, tmp_path, seed) # assertions np.testing.assert_array_equal(preprocess_custom(X_ref), cd_load._detector.x_ref) assert not cd_load._detector.infer_sigma - assert cd_load._detector.n_permutations == N_PERMUTATIONS + if estimator == 'quad': + assert cd_load._detector.n_permutations == N_PERMUTATIONS assert cd_load._detector.p_val == P_VAL assert isinstance(cd_load._detector.preprocess_fn, Callable) assert cd_load._detector.preprocess_fn.func.__name__ == 'preprocess_drift' diff --git a/alibi_detect/utils/pytorch/distance.py b/alibi_detect/utils/pytorch/distance.py index b5b5e85de..85a1ed710 100644 --- a/alibi_detect/utils/pytorch/distance.py +++ b/alibi_detect/utils/pytorch/distance.py @@ -8,7 +8,12 @@ @torch.jit.script -def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30) -> torch.Tensor: +def squared_pairwise_distance( + x: torch.Tensor, + y: torch.Tensor, + a_min: float = 1e-30, + pairwise: bool = True +) -> torch.Tensor: """ PyTorch pairwise squared Euclidean distance between samples x and y. @@ -20,13 +25,40 @@ def squared_pairwise_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1 Batch of instances of shape [Ny, features]. a_min Lower bound to clip distance values. + pairwise + Whether to compute pairwise distances or not. Returns ------- Pairwise squared Euclidean distance [Nx, Ny]. """ x2 = x.pow(2).sum(dim=-1, keepdim=True) y2 = y.pow(2).sum(dim=-1, keepdim=True) - dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2) + if pairwise: + dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2) + else: + dist = x2 + y2 - (2 * x * y).sum(dim=-1, keepdim=True) + return dist.clamp_min_(a_min) + + +def squared_distance(x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30) -> torch.Tensor: + """ + PyTorch squared Euclidean distance between samples x and y. + + Parameters + ---------- + x + Batch of instances of shape [N, features]. + y + Batch of instances of shape [N, features]. + a_min + Lower bound to clip distance values. + Returns + ------- + Squared Euclidean distance [N, 1]. + """ + x2 = x.pow(2).sum(dim=-1, keepdim=True) + y2 = y.pow(2).sum(dim=-1, keepdim=True) + dist = x2 + y2 - (2 * x * y).sum(dim=-1, keepdim=True) return dist.clamp_min_(a_min) @@ -93,8 +125,43 @@ def batch_compute_kernel_matrix( return k_mat -def mmd2_from_kernel_matrix(kernel_mat: torch.Tensor, m: int, permute: bool = False, - zero_diag: bool = True) -> torch.Tensor: +def linear_mmd2( + k_xx: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + kernel: Callable +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute maximum mean discrepancy (MMD^2) between 2 samples x and y with the + linear-time estimator. + + Parameters + ---------- + x + Batch of instances of shape [Nx, features]. + y + Batch of instances of shape [Ny, features]. + kernel + Kernel function. + Returns + ------- + MMD^2 between the samples. + """ + k_yy = kernel(x=y[0::2, :], y=y[1::2, :], pairwise=False) + k_xy = kernel(x=x[0::2, :], y=y[1::2, :], pairwise=False) + k_yx = kernel(x=y[0::2, :], y=x[1::2, :], pairwise=False) + h = k_xx + k_yy - k_xy - k_yx + mmd2 = h.mean() + var_mmd2 = torch.var(h, unbiased=True) + return mmd2, var_mmd2 + + +def mmd2_from_kernel_matrix( + kernel_mat: torch.Tensor, + m: int, + permute: bool = False, + zero_diag: bool = True +) -> torch.Tensor: """ Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the full kernel matrix between the samples. diff --git a/alibi_detect/utils/pytorch/kernels.py b/alibi_detect/utils/pytorch/kernels.py index d5e9349c2..77a398459 100644 --- a/alibi_detect/utils/pytorch/kernels.py +++ b/alibi_detect/utils/pytorch/kernels.py @@ -5,7 +5,11 @@ from typing import Optional, Union, Callable -def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor: +def sigma_median( + x: torch.Tensor, + y: torch.Tensor, + dist: torch.Tensor, +) -> torch.Tensor: """ Bandwidth estimation using the median heuristic :cite:t:`Gretton2012`. @@ -16,8 +20,8 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch. y Tensor of instances with dimension [Ny, features]. dist - Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`. - + Tensor with dimensions [Nx, Ny] when pairwise=True, containing the pairwise distances between `x` and `y`. + Dimensions are [Nx, 1] when pairwise=False. Returns ------- The computed bandwidth, `sigma`. @@ -70,16 +74,25 @@ def __init__( def sigma(self) -> torch.Tensor: return self.log_sigma.exp() - def forward(self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor], - infer_sigma: bool = False) -> torch.Tensor: + def forward( + self, x: Union[np.ndarray, torch.Tensor], + y: Union[np.ndarray, torch.Tensor], + infer_sigma: bool = False, + pairwise: bool = True + ) -> torch.Tensor: x, y = torch.as_tensor(x), torch.as_tensor(y) - dist = distance.squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny] + dist = distance.squared_pairwise_distance(x=x.flatten(1), + y=y.flatten(1), + pairwise=pairwise) # [Nx, Ny] or [Nx, 1] if infer_sigma or self.init_required: if self.trainable and infer_sigma: raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value") - sigma = self.init_sigma_fn(x, y, dist) + if pairwise: + sigma = self.init_sigma_fn(x, y, dist) + else: + sigma = (.5 * dist.flatten().sort().values[dist.shape[0] // 2 - 1].unsqueeze(dim=-1)) ** .5 with torch.no_grad(): self.log_sigma.copy_(sigma.log().clone()) self.init_required = False diff --git a/alibi_detect/utils/tensorflow/distance.py b/alibi_detect/utils/tensorflow/distance.py index 7b003fd7f..ddc66b55f 100644 --- a/alibi_detect/utils/tensorflow/distance.py +++ b/alibi_detect/utils/tensorflow/distance.py @@ -6,7 +6,13 @@ logger = logging.getLogger(__name__) -def squared_pairwise_distance(x: tf.Tensor, y: tf.Tensor, a_min: float = 1e-30, a_max: float = 1e30) -> tf.Tensor: +def squared_pairwise_distance( + x: tf.Tensor, + y: tf.Tensor, + a_min: float = 1e-30, + a_max: float = 1e30, + pairwise: bool = True +) -> tf.Tensor: """ TensorFlow pairwise squared Euclidean distance between samples x and y. @@ -20,14 +26,44 @@ def squared_pairwise_distance(x: tf.Tensor, y: tf.Tensor, a_min: float = 1e-30, Lower bound to clip distance values. a_max Upper bound to clip distance values. - + pairwise + Whether to compute pairwise distances or not. Returns ------- Pairwise squared Euclidean distance [Nx, Ny]. """ x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True) y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True) - dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0)) + if pairwise: + dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0)) + else: + dist = x2 + y2 - 2. * tf.reduce_sum(x * y, axis=-1, keepdims=True) + return tf.clip_by_value(dist, a_min, a_max) + + +def squared_distance(x: tf.Tensor, y: tf.Tensor, + a_min: float = 1e-30, a_max: float = 1e30) -> tf.Tensor: + """ + TensorFlow squared Euclidean distance between samples x and y. + + Parameters + ---------- + x + Batch of instances of shape [N, features]. + y + Batch of instances of shape [N, features]. + a_min + Lower bound to clip distance values. + a_max + Upper bound to clip distance values. + + Returns + ------- + Squared Euclidean distance [N, 1]. + """ + x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True) + y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True) + dist = x2 + y2 - 2. * tf.reduce_sum(x * y, axis=-1, keepdims=True) return tf.clip_by_value(dist, a_min, a_max) @@ -83,8 +119,41 @@ def batch_compute_kernel_matrix( return k_mat -def mmd2_from_kernel_matrix(kernel_mat: tf.Tensor, m: int, permute: bool = False, - zero_diag: bool = True) -> tf.Tensor: +def linear_mmd2( + k_xx: tf.Tensor, + x: tf.Tensor, + y: tf.Tensor, + kernel: Callable +) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute maximum mean discrepancy (MMD^2) between 2 samples x and y with the + linear-time estimator. + + Parameters + ---------- + x + Batch of instances of shape [Nx, features]. + y + Batch of instances of shape [Ny, features]. + kernel + Kernel function. + """ + n = x.shape[0] + k_yy = kernel(x=y[0::2, :], y=y[1::2, :], pairwise=False) + k_xy = kernel(x=x[0::2, :], y=y[1::2, :], pairwise=False) + k_yx = kernel(x=y[0::2, :], y=x[1::2, :], pairwise=False) + h = k_xx + k_yy - k_xy - k_yx + mmd2 = tf.reduce_mean(h) + var_mmd2 = tf.math.reduce_sum(h ** 2) / ((n / 2.) - 1) - (mmd2 ** 2) + return mmd2, var_mmd2 + + +def mmd2_from_kernel_matrix( + kernel_mat: tf.Tensor, + m: int, + permute: bool = False, + zero_diag: bool = True +) -> tf.Tensor: """ Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the full kernel matrix between the samples. diff --git a/alibi_detect/utils/tensorflow/kernels.py b/alibi_detect/utils/tensorflow/kernels.py index 6236f713d..9a21ddd48 100644 --- a/alibi_detect/utils/tensorflow/kernels.py +++ b/alibi_detect/utils/tensorflow/kernels.py @@ -5,7 +5,11 @@ from scipy.special import logit -def sigma_median(x: tf.Tensor, y: tf.Tensor, dist: tf.Tensor) -> tf.Tensor: +def sigma_median( + x: tf.Tensor, + y: tf.Tensor, + dist: tf.Tensor +) -> tf.Tensor: """ Bandwidth estimation using the median heuristic :cite:t:`Gretton2012`. @@ -17,7 +21,6 @@ def sigma_median(x: tf.Tensor, y: tf.Tensor, dist: tf.Tensor) -> tf.Tensor: Tensor of instances with dimension [Ny, features]. dist Tensor with dimensions [Nx, Ny], containing the pairwise distances between `x` and `y`. - Returns ------- The computed bandwidth, `sigma`. @@ -71,15 +74,23 @@ def __init__( def sigma(self) -> tf.Tensor: return tf.math.exp(self.log_sigma) - def call(self, x: tf.Tensor, y: tf.Tensor, infer_sigma: bool = False) -> tf.Tensor: + def call( + self, x: tf.Tensor, + y: tf.Tensor, + infer_sigma: bool = False, + pairwise: bool = True + ) -> tf.Tensor: y = tf.cast(y, x.dtype) x, y = tf.reshape(x, (x.shape[0], -1)), tf.reshape(y, (y.shape[0], -1)) # flatten - dist = distance.squared_pairwise_distance(x, y) # [Nx, Ny] + dist = distance.squared_pairwise_distance(x=x, y=y, pairwise=pairwise) # [Nx, Ny] or [Nx, 1] if infer_sigma or self.init_required: if self.trainable and infer_sigma: raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value") - sigma = self.init_sigma_fn(x, y, dist) + if pairwise: + sigma = self.init_sigma_fn(x, y, dist) + else: + sigma = tf.expand_dims((.5 * tf.sort(tf.reshape(dist, (-1,)))[dist.shape[0] // 2 - 1]) ** .5, axis=0) self.log_sigma.assign(tf.math.log(sigma)) self.init_required = False diff --git a/doc/source/cd/methods/mmddrift.ipynb b/doc/source/cd/methods/mmddrift.ipynb index 3dc2e1d99..99b22272d 100644 --- a/doc/source/cd/methods/mmddrift.ipynb +++ b/doc/source/cd/methods/mmddrift.ipynb @@ -48,6 +48,8 @@ "\n", "* `p_val`: p-value used for significance of the permutation test.\n", "\n", + "* `estimator`: Estimator used for the MMD^2 computation. *'quad'* is the default and uses the quadratic u-statistics on each square kernel matrix. *'linear'* uses the linear time estimator as in Gretton et al. (2014), and the threshold is computed using the Gaussian asympotic distribution under null.\n", + "\n", "* `preprocess_at_init`: Whether to already apply the (optional) preprocessing step to the reference data at initialization and store the preprocessed data. Dependent on the preprocessing step, this can reduce the computation time for the predict step significantly, especially when the reference dataset is large. Defaults to *True*. It is possible that it needs to be set to *False* if the preprocessing step requires statistics from both the reference and test data, such as the mean or standard deviation.\n", "\n", "* `x_ref_preprocessed`: Whether or not the reference data `x_ref` has already been preprocessed. If *True*, the reference data will be skipped and preprocessing will only be applied to the test data passed to `predict`.\n",