From cdcb283262c9f559305c3fc7ed715c66e5ff3a1d Mon Sep 17 00:00:00 2001 From: Yuehan Qin Date: Wed, 3 Jul 2024 23:17:39 -0500 Subject: [PATCH 1/2] device for deep_svdd.py --- pyod/models/deep_svdd.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyod/models/deep_svdd.py b/pyod/models/deep_svdd.py index 18d429ae..dbdd83ab 100644 --- a/pyod/models/deep_svdd.py +++ b/pyod/models/deep_svdd.py @@ -241,7 +241,7 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None, batch_size=32, dropout_rate=0.2, l2_regularizer=0.1, validation_size=0.1, preprocessing=True, - verbose=1, random_state=None, contamination=0.1): + verbose=1, random_state=None, contamination=0.1, device=None): super(DeepSVDD, self).__init__(contamination=contamination) self.n_features = n_features @@ -261,6 +261,7 @@ def __init__(self, n_features, c=None, use_ae=False, hidden_neurons=None, self.random_state = random_state self.model_ = None self.best_model_dict = None + self.device = device if self.random_state is not None: torch.manual_seed(self.random_state) @@ -313,10 +314,11 @@ def fit(self, X, y=None): output_activation=self.output_activation, dropout_rate=self.dropout_rate, l2_regularizer=self.l2_regularizer) + self.model_.to(self.device) X_norm = torch.tensor(X_norm, dtype=torch.float32) if self.c is None: self.c = 0.0 - self.model_._init_c(X_norm) + self.model_._init_c(X_norm.to(self.device)) # Predict on X itself and calculate the reconstruction error as # the outlier scores. Noted X_norm was shuffled has to recreate @@ -325,7 +327,7 @@ def fit(self, X, y=None): else: X_norm = np.copy(X) - X_norm = torch.tensor(X_norm, dtype=torch.float32) + X_norm = torch.tensor(X_norm, dtype=torch.float32).to(self.device) dataset = TensorDataset(X_norm, X_norm) dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) @@ -342,6 +344,7 @@ def fit(self, X, y=None): self.model_.train() epoch_loss = 0 for batch_x, _ in dataloader: + batch_x = batch_x.to(self.device) optimizer.zero_grad() outputs = self.model_(batch_x) dist = torch.sum((outputs - self.c) ** 2, dim=-1) @@ -389,10 +392,10 @@ def decision_function(self, X): X_norm = self.scaler_.transform(X) else: X_norm = np.copy(X) - X_norm = torch.tensor(X_norm, dtype=torch.float32) + X_norm = torch.tensor(X_norm, dtype=torch.float32).to(self.device) self.model_.eval() with torch.no_grad(): outputs = self.model_(X_norm) dist = torch.sum((outputs - self.c) ** 2, dim=-1) - anomaly_scores = dist.numpy() + anomaly_scores = dist.cpu().numpy() return anomaly_scores From b97f6f24e47d613107c72bcf108a1b31745a08f8 Mon Sep 17 00:00:00 2001 From: Yuehan Qin Date: Wed, 3 Jul 2024 23:55:49 -0500 Subject: [PATCH 2/2] deep_svdd base --- pyod/models/deep_svdd.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/pyod/models/deep_svdd.py b/pyod/models/deep_svdd.py index dbdd83ab..b6a1074c 100644 --- a/pyod/models/deep_svdd.py +++ b/pyod/models/deep_svdd.py @@ -11,30 +11,15 @@ import numpy as np import torch import torch.nn as nn -import torch.optim as optim from sklearn.preprocessing import StandardScaler from sklearn.utils import check_array from torch.utils.data import DataLoader, TensorDataset from .base import BaseDetector -from ..utils.torch_utility import get_activation_by_name +from ..utils.torch_utility import get_activation_by_name, get_optimizer_by_name from ..utils.utility import check_parameter -optimizer_dict = { - 'sgd': optim.SGD, - 'adam': optim.Adam, - 'rmsprop': optim.RMSprop, - 'adagrad': optim.Adagrad, - 'adadelta': optim.Adadelta, - 'adamw': optim.AdamW, - 'nadam': optim.NAdam, - 'sparseadam': optim.SparseAdam, - 'asgd': optim.ASGD, - 'lbfgs': optim.LBFGS -} - - class InnerDeepSVDD(nn.Module): """Inner class for DeepSVDD model. @@ -335,8 +320,7 @@ def fit(self, X, y=None): best_loss = float('inf') best_model_dict = None - optimizer = optimizer_dict[self.optimizer](self.model_.parameters(), - weight_decay=self.l2_regularizer) + optimizer = get_optimizer_by_name(self.model_, self.optimizer, weight_decay=self.l2_regularizer) w_d = 1e-6 * sum( [torch.linalg.norm(w) for w in self.model_.parameters()])