From 1e5a016f7d8ccb138f48a707d6a028e0336e49db Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Wed, 17 Apr 2024 08:44:04 +0200
Subject: [PATCH 1/9] updated implementation

---
 src/concrete_dropout/__main__.py | 20 ++++++++++++++++++++
 src/concrete_dropout/default.py  |  7 ++++---
 src/concrete_dropout/improved.py | 22 ++++++++++------------
 src/fido/module.py               | 19 ++++++++++++++-----
 4 files changed, 48 insertions(+), 20 deletions(-)
 create mode 100644 src/concrete_dropout/__main__.py

diff --git a/src/concrete_dropout/__main__.py b/src/concrete_dropout/__main__.py
new file mode 100644
index 0000000..a1edfa6
--- /dev/null
+++ b/src/concrete_dropout/__main__.py
@@ -0,0 +1,20 @@
+import torch as th
+import sys
+
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from concrete_dropout.default import concrete_dropout
+from concrete_dropout.improved import sigmoid_concrete_dropout
+
+def main():
+    size = 32
+    X = th.zeros((size, size), requires_grad=True, dtype=th.float32)
+    u = th.zeros_like(X).uniform_()
+
+    y0 = concrete_dropout(X.sigmoid(), u=u)
+    y1 = sigmoid_concrete_dropout(X, u=u)
+    assert th.allclose(y0, y1)
+
+main()
diff --git a/src/concrete_dropout/default.py b/src/concrete_dropout/default.py
index fe8c08f..ae3c94d 100644
--- a/src/concrete_dropout/default.py
+++ b/src/concrete_dropout/default.py
@@ -1,8 +1,10 @@
 import numpy as np
 import torch as th
 
-def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np.finfo(float).eps):
-    u = th.zeros_like(dropout_rate).uniform_()
+
+def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, u = None, eps: float = np.finfo(float).eps):
+    if u is None:
+        u = th.zeros_like(dropout_rate).uniform_()
 
     approx = (
         (dropout_rate + eps).log() -
@@ -13,4 +15,3 @@ def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np
     )
 
     return 1 - (approx / temp).sigmoid()
-
diff --git a/src/concrete_dropout/improved.py b/src/concrete_dropout/improved.py
index 56f3d57..fec2a8a 100644
--- a/src/concrete_dropout/improved.py
+++ b/src/concrete_dropout/improved.py
@@ -1,26 +1,24 @@
 import numpy as np
 import torch as th
 
-EPS = 1e-7
 
 class SigmoidConcreteDropout(th.autograd.Function):
 
     @staticmethod
-    def forward(ctx, logit_p, temp: float = 1.0, u = None):
+    def forward(ctx, logit_p, temp: float = 1.0, u = None, *, eps = np.finfo(float).eps):
         """ our proposed simplification """
-        global EPS
         temp = th.scalar_tensor(temp)
 
         if u is None:
             u = th.zeros_like(logit_p).uniform_()
 
-        noise = ((u + EPS) / (1 - u + EPS)).log()
+        noise = ((u + eps) / (1 - u + eps)).log()
 
         logit_p_temp = (logit_p + noise) / temp
-        res = logit_p_temp.sigmoid()
+        keep_rate = logit_p_temp.sigmoid()
 
-        ctx.save_for_backward(res, logit_p_temp, temp)
-        return res
+        ctx.save_for_backward(keep_rate, logit_p_temp, temp)
+        return 1 - keep_rate
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -28,12 +26,12 @@ def backward(ctx, grad_output):
         Gradient of random_tensor w.r.t logit_p is simply
         1/temp * sigmoid(logit_p_temp)^2 * exp(-logit_p_temp)
         """
-        res, logit_p_temp, temp = ctx.saved_tensors
-        grad = th.zeros_like(res)
-        mask = res != 0
-        grad[mask] = res[mask]**2 * (-logit_p_temp[mask]).exp() / temp
+        keep_rate, logit_p_temp, temp = ctx.saved_tensors
+        grad = th.zeros_like(keep_rate)
+        mask = keep_rate != 0
+        grad[mask] = keep_rate[mask]**2 * (-logit_p_temp[mask]).exp() / temp
 
-        return grad * grad_output, None, None
+        return -grad * grad_output, None, None
 
     @classmethod
     def _check_grad(clf, temp: float = 1.0, size: int = 16, *, dtype = th.float64):
diff --git a/src/fido/module.py b/src/fido/module.py
index 0cd8ae9..a16077c 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -69,7 +69,7 @@ def tv_loss(self, *, reduction="mean"):
         # p = th.stack([self.ssr_logit_p, self.sdr_logit_p], axis=0)
         p = th.stack([self.ssr_dropout_rate, self.sdr_dropout_rate], axis=0)
         # p = self.joint_dropout_rate[None]
-        return tm.TotalVariation(reduction=reduction).to(p.device)(p[None])
+        return tm.image.TotalVariation(reduction=reduction).to(p.device)(p[None])
 
     def l1_norm(self):
         # return self.joint_dropout_rate.sum()
@@ -120,7 +120,7 @@ def forward(self, X, **kwargs):
         sdr = self.sdr(X, **kwargs)
         return ssr, sdr
 
-    def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False):
+    def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False, add_deterministic: bool = False):
         ssr, sdr = self(X, batch_size=batch_size, deterministic=deterministic)
         n = len(ssr)
         _x = th.concatenate([ssr, sdr])
@@ -129,11 +129,17 @@ def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False):
         ssr_odds, sdr_odds = odds[:n], odds[n:]
         # sdr - ssr <== minimizing sdr and maximizing ssr probabilities
         loss = (sdr_odds - ssr_odds).mean()
+
+        if not deterministic and add_deterministic:
+            _, _, det_loss, _, _ = self.objective(X, y, clf, batch_size=1, deterministic=True, add_deterministic=False)
+            loss += det_loss
+
         return (ssr, sdr), (ssr_prob.mean(), sdr_prob.mean()), loss, self.l1_norm(), self.tv_loss()
 
     def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update_callback = None):
 
         opt = th.optim.AdamW(self.params, lr=config.learning_rate, eps=0.1, weight_decay=config.l2)
+        # opt = th.optim.SGD(self.params, lr=config.learning_rate, momentum=0.9, weight_decay=config.l2)
         opt.zero_grad()
 
         ssr_grad, sdr_grad = 1, 1
@@ -142,7 +148,6 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
             if config.approx_steps is None:
                 # no approximation
                 masks, probs, loss, l1_norm, tvl = self.objective(im, y, clf, batch_size=config.batch_size)
-
             else:
                 if i % config.approx_steps == 0:
                     (ssr, sdr), probs, loss, l1_norm, tvl = self.objective(im, y, clf, batch_size=config.batch_size)
@@ -184,7 +189,7 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
 
         update_callback(i, is_last_step=True)
 
-    def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
+    def plot(self, im, pred, clf, *, output=None, metrics: dict = None, thresh: float = None):
         ssr_keep_rate = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=True))
         ssr_keep_rate = ssr_keep_rate.detach().cpu().squeeze(0).numpy()
         # ssr_bernouli = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=False))
@@ -195,6 +200,9 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
 
         joint_keep_rate = np.sqrt(ssr_keep_rate * (1-sdr_keep_rate))
 
+        if thresh is not None:
+            joint_keep_rate[joint_keep_rate < thresh] = np.nan
+
         cls_id = pred.argmax()
 
         orig_im = (im * 0.5 + 0.5).permute(1, 2, 0).numpy()
@@ -256,6 +264,7 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
             plt.show()
         else:
             plt.savefig(output)
+            plt.close(fig=fig)
 
         if metrics is not None:
@@ -274,4 +283,4 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
             loss_graphs = output.with_suffix(f".losses{output.suffix}")
 
             plt.savefig(loss_graphs)
-            plt.close()
+            plt.close(fig=fig)

From 93c73f3ebb532f2ac2b619d2500512b8ae65bc57 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Fri, 3 May 2024 08:49:40 +0200
Subject: [PATCH 2/9] refactored infills

---
 src/fido/configs.py |  3 +-
 src/fido/infill.py  | 77 +++++++++++++++++++++++++--------------------
 src/fido/module.py  |  5 ++-
 3 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/src/fido/configs.py b/src/fido/configs.py
index 3f09232..c656ba3 100644
--- a/src/fido/configs.py
+++ b/src/fido/configs.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from fido.infill import Infill
 
 @dataclass
 class FIDOConfig:
@@ -23,4 +24,4 @@ def reg_params(self) -> dict:
 class MaskConfig:
     optimized: bool = True
     mask_size: int = None
-    infill_strategy: str = "blur"
+    infill_strategy: Infill = Infill.BLUR
diff --git a/src/fido/infill.py b/src/fido/infill.py
index 032f13f..d67094a 100644
--- a/src/fido/infill.py
+++ b/src/fido/infill.py
@@ -1,7 +1,9 @@
 import cv2
+import enum
 import numpy as np
 import torch as th
 
+from functools import partial
 from numpy.lib.stride_tricks import sliding_window_view
 from skimage.measure import regionprops
 from torchvision.transforms import functional as tr
 
 from fido.metrics import _thresh
@@ -13,55 +15,62 @@ def patches(im, size):
     window_shape = (size, size, im.shape[-1])
     return sliding_window_view(im, window_shape)[::size, ::size]
 
-def new_infill(im: th.Tensor, strategy: str, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device=None):
-    if strategy == "original":
-        infill = im.clone()
+class Infill(enum.Enum):
+    ORIGINIAL = enum.auto()
+    ZEROS = enum.auto()
+    UNIFORM = enum.auto()
+    NORMAL = enum.auto()
+    MEAN = enum.auto()
+    BLUR = enum.auto()
+    KNOCKOFF = enum.auto()
 
-    elif strategy == "zeros":
-        infill = th.zeros_like(im, device=device)
+    def normalize(self, arr: th.Tensor, mean: float, std: float) -> th.Tensor:
+        return (arr - mean) / std
 
-    elif strategy == "uniform":
-        infill = th.zeros_like(im, device=device).uniform_(0, 1)
+    def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device: th.device = None):
+        _normalize = partial(self.normalize, mean=mean, std=std)
+        if self == Infill.BLUR:
+            # should result in std=10, same as in the paper
+            # de-normalize first
+            return _normalize(tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device))
 
-    elif strategy == "normal":
-        infill = th.zeros_like(im, device=device).normal_(std=0.2)
+        if self == Infill.ORIGINIAL:
+            return _normalize(im.clone())
 
-    elif strategy == "mean":
-        mean_pixel = im.mean(axis=(1,2), keepdim=True)
-        infill = mean_pixel.expand(im.shape)
+        if self == Infill.ZEROS:
+            return _normalize(th.zeros_like(im, device=device))
 
-    elif strategy == "blur":
-        # should result in std=10, same as in the paper
-        # de-normalize first
-        infill = tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device)
+        if self == Infill.UNIFORM:
+            return _normalize(th.zeros_like(im, device=device).uniform_(0, 1))
 
-    elif strategy == "local":
-        pass
+        if self == Infill.NORMAL:
+            return _normalize(th.zeros_like(im, device=device).normal_(std=0.2))
 
-    elif strategy == "knockoff":
-        infill = np.zeros_like(im)
-        windows = patches(im.detach().cpu().numpy().transpose(1,2,0), size=size)
-        rows, cols, *_ = windows.shape
-        for i in range(rows*cols):
-            row, col, _chan = np.unravel_index(i, (rows, cols, 1))
-            patch = windows[row, col, _chan]
+        if self == Infill.MEAN:
+            mean_pixel = im.mean(axis=(1,2), keepdim=True)
+            return _normalize(mean_pixel.expand(im.shape))
 
-            x0, y0 = col*size, row*size
-            x1, y1 = (col+1)*size, (row+1)*size
+        if self == Infill.KNOCKOFF:
 
-            idxs = np.random.permutation(size**2)
-            idxs = np.unravel_index(idxs, (size, size))
+            infill = np.zeros_like(im)
+            windows = patches(im.detach().cpu().numpy().transpose(1,2,0), size=size)
+            rows, cols, *_ = windows.shape
+            for i in range(rows*cols):
+                row, col, _chan = np.unravel_index(i, (rows, cols, 1))
+                patch = windows[row, col, _chan]
 
-            infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2,0,1)
+                x0, y0 = col*size, row*size
+                x1, y1 = (col+1)*size, (row+1)*size
 
-        infill = th.tensor(infill)
-    else:
-        raise ValueError(f"Unknown in-fill strategy: {strategy}")
+                idxs = np.random.permutation(size**2)
+                idxs = np.unravel_index(idxs, (size, size))
 
-    return (infill - mean) / std
+                infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2,0,1)
 
+            return _normalize(th.tensor(infill))
 
+        raise NotImplementedError(f"Strategy is not implemented yet: {self}!")
 
 def _calc_bbox(mask, min_size, *, pad: int = 10, squared: bool = True):
diff --git a/src/fido/module.py b/src/fido/module.py
index a16077c..c348a1e 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -10,7 +10,6 @@
 from concrete_dropout import sigmoid_concrete_dropout
 from fido.configs import FIDOConfig
 from fido.configs import MaskConfig
-from fido.infill import new_infill
 from fido.metrics import Metrics
 
 def reg_scale(i, max_value):
@@ -21,10 +20,10 @@ def reg_scale(i, max_value):
 class FIDO(th.nn.Module):
 
     @classmethod
-    def new(cls, im, params: MaskConfig, *, device):
+    def new(cls, im, params: MaskConfig, *, device: th.device):
 
         # de-normalize first
-        infill = new_infill(im * 0.5 + 0.5, params.infill_strategy, device=device)
+        infill = params.infill_strategy.new(im * 0.5 + 0.5, device=device)
 
         C, H, W = im.shape
         size = (params.mask_size, params.mask_size)

From 4ba05c24a51c80b59685b35a8b615d5442a81398 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Sat, 4 May 2024 14:43:22 +0200
Subject: [PATCH 3/9] added GAN-based infill generation

---
 src/fido/infill.py | 57 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 56 insertions(+), 1 deletion(-)

diff --git a/src/fido/infill.py b/src/fido/infill.py
index d67094a..5e60880 100644
--- a/src/fido/infill.py
+++ b/src/fido/infill.py
@@ -7,7 +7,7 @@
 from numpy.lib.stride_tricks import sliding_window_view
 from skimage.measure import regionprops
 from torchvision.transforms import functional as tr
-
+from pathlib import Path
 from fido.metrics import _thresh
 
@@ -15,6 +15,7 @@ def patches(im, size):
     window_shape = (size, size, im.shape[-1])
     return sliding_window_view(im, window_shape)[::size, ::size]
 
+gan = None
 
 class Infill(enum.Enum):
     ORIGINIAL = enum.auto()
@@ -23,12 +24,14 @@ class Infill(enum.Enum):
     NORMAL = enum.auto()
     MEAN = enum.auto()
     BLUR = enum.auto()
+    GAN = enum.auto()
     KNOCKOFF = enum.auto()
 
     def normalize(self, arr: th.Tensor, mean: float, std: float) -> th.Tensor:
         return (arr - mean) / std
 
     def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device: th.device = None):
+        global gan
         _normalize = partial(self.normalize, mean=mean, std=std)
         if self == Infill.BLUR:
             # should result in std=10, same as in the paper
             # de-normalize first
@@ -70,8 +73,60 @@ def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int
 
             return _normalize(th.tensor(infill))
 
+        if self == Infill.GAN:
+            if gan is None:
+                gan = GAN(device=im.device)
+            infill = gan(im)
+            return infill
+
         raise NotImplementedError(f"Strategy is not implemented yet: {self}!")
 
+
+import dmfn  # noqa: E402
+from dmfn.models.networks import define_G  # noqa: E402
+
+
+class GAN:
+
+    DEFAULT_WEIGHTS: Path = Path(dmfn.__file__).resolve().parent.parent.parent / "outputs/cub/checkpoints/latest_G.pth"
+
+    @staticmethod
+    def weights_init(m):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+        elif classname.find('BatchNorm') != -1:
+            m.weight.data.normal_(1.0, 0.02)
+            m.bias.data.fill_(0)
+        elif classname.find('Linear') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+            if m.bias is not None:
+                m.bias.data.fill_(0.0)
+
+    def __init__(self, weights: Path = None, device: th.device = th.device("cpu")):
+
+        if weights is None:
+            weights = GAN.DEFAULT_WEIGHTS
+        assert Path(weights).exists(), \
+            f"Could not find weights: {weights}"
+
+        opt = dict(network_G=dict(which_model_G='DMFN', in_nc=4, out_nc=3, nf=64, n_res=8), is_train=False)
+        self.generator = define_G(opt).to(device)
+        self.generator.load_state_dict(th.load(weights), strict=True)
+        self.generator.eval()
+
+    def __call__(self, im: th.Tensor, *, grid_size: int = 16, size: tuple = (256, 256)):
+        mask = th.randint(0, 2, size=(1, 1, grid_size, grid_size), dtype=im.dtype, device=im.device)
+        orig_size = im.size()[-2:]
+        im = tr.resize(im, size)
+        mask = tr.resize(mask, size, interpolation=tr.InterpolationMode.NEAREST)
+
+        X1 = th.cat([im * mask, 1-mask], dim=1)
+        X2 = th.cat([im * (1-mask), mask], dim=1)
+
+        X = th.cat([X1, X2], dim=0)
+        out = self.generator(X).detach()
+        # combine the result from the generated images
+        res = out[0] * (1 - mask[0]) + out[1] * mask[0]
+        return tr.resize(res, orig_size)
+
 
 def _calc_bbox(mask, min_size, *, pad: int = 10, squared: bool = True):
     props = regionprops(mask)

From c7ca2ffe2ee7d7b16b1c028a26d128f9edb5c3d5 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Thu, 16 May 2024 16:00:45 +0200
Subject: [PATCH 4/9] fix in variable name

---
 src/fido/storage.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/fido/storage.py b/src/fido/storage.py
index 51efc30..9f1cce1 100644
--- a/src/fido/storage.py
+++ b/src/fido/storage.py
@@ -51,7 +51,7 @@ def __getitem__(self, i):
         return self._cache[key].unpack()
 
     def __len__(self) -> int:
-        return len(self._npzfile.files) // 2
+        return len(self.npzfile.files) // 2
 
 def _maps_file_name(subset, *, lazy_load: bool = True):
     suffix = ".lazy" if lazy_load else ""

From ea3c610c59dae9fa85654004fe9f78ac4b3a33c4 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Fri, 14 Jun 2024 07:26:45 +0200
Subject: [PATCH 5/9] refactor infill generation for better performance and
 readability

---
 src/fido/infill.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/fido/infill.py b/src/fido/infill.py
index 5e60880..20ff347 100644
--- a/src/fido/infill.py
+++ b/src/fido/infill.py
@@ -45,10 +45,12 @@ def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int
             return _normalize(th.zeros_like(im, device=device))
 
         if self == Infill.UNIFORM:
-            return _normalize(th.zeros_like(im, device=device).uniform_(0, 1))
+            res = th.zeros_like(im, device=device).uniform_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
 
         if self == Infill.NORMAL:
-            return _normalize(th.zeros_like(im, device=device).normal_(std=0.2))
+            res = th.zeros_like(im, device=device).normal_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
 
         if self == Infill.MEAN:
             mean_pixel = im.mean(axis=(1,2), keepdim=True)

From 4af84fb2933479b277ac891b7ba0cd1ed5463fd6 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Mon, 29 Jul 2024 14:32:20 +0200
Subject: [PATCH 6/9] added a check if the callback function is given or not

---
 src/fido/module.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/fido/module.py b/src/fido/module.py
index c348a1e..d6c5923 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -185,7 +185,8 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
             opt.step()
             opt.zero_grad()
 
-        update_callback(i, is_last_step=True)
+        if update_callback is not None:
+            update_callback(i, is_last_step=True)
 
 
     def plot(self, im, pred, clf, *, output=None, metrics: dict = None, thresh: float = None):

From 5d7abfebb2c091d890896e01678a67d1c377f041 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Tue, 30 Jul 2024 07:50:22 +0200
Subject: [PATCH 7/9] added a check whether the classifier has log_odds
 implemented

---
 src/fido/__init__.py | 19 +++++++++++++++++++
 src/fido/module.py   | 27 ++++++++++++++++++++++++++-
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/src/fido/__init__.py b/src/fido/__init__.py
index e69de29..d3f7573 100644
--- a/src/fido/__init__.py
+++ b/src/fido/__init__.py
@@ -0,0 +1,19 @@
+from fido.module import FIDO
+from fido.infill import Infill
+from fido.configs import FIDOConfig
+from fido.configs import MaskConfig
+from fido.metrics import Metrics
+from fido.module import log_odds
+from fido.storage import load_maps
+from fido.storage import dump_maps
+
+__all__ = [
+    "FIDO",
+    "Infill",
+    "FIDOConfig",
+    "MaskConfig",
+    "Metrics",
+    "log_odds",
+    "load_maps",
+    "dump_maps",
+]
diff --git a/src/fido/module.py b/src/fido/module.py
index d6c5923..7c2ba53 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -47,6 +47,17 @@ def __init__(self, size: tuple, *,
 
         self._optimized = optimized
         self._infill = infill
 
+    def to(self, *args, **kwargs):
+        self._infill = self._infill.to(*args, **kwargs)
+        self.ssr_logit_p = self.ssr_logit_p.to(*args, **kwargs)
+        self.sdr_logit_p = self.sdr_logit_p.to(*args, **kwargs)
+        return super().to(*args, **kwargs)
+
+    def cpu(self):
+        self._infill = self._infill.cpu()
+        self.ssr_logit_p = self.ssr_logit_p.cpu()
+        self.sdr_logit_p = self.sdr_logit_p.cpu()
+        return super().cpu()
 
     @property
     def params(self):
@@ -123,7 +134,12 @@ def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False,
         ssr, sdr = self(X, batch_size=batch_size, deterministic=deterministic)
         n = len(ssr)
         _x = th.concatenate([ssr, sdr])
-        prob, odds = clf.log_odds(_x, c=y)
+        if hasattr(clf, "log_odds"):
+            prob, odds = clf.log_odds(_x, c=y)
+        else:
+            logits = clf(_x)
+            prob, odds = log_odds(logits, c=y)
+
         ssr_prob, sdr_prob = prob[:n], prob[n:]
         ssr_odds, sdr_odds = odds[:n], odds[n:]
         # sdr - ssr <== minimizing sdr and maximizing ssr probabilities
@@ -284,3 +300,12 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None, thresh: floa
 
             plt.savefig(loss_graphs)
             plt.close(fig=fig)
+
+
+def log_odds(logits, c):
+    # normalized log-probabilities
+    log_probs = F.log_softmax(logits, dim=1)
+    mask = th.ones_like(log_probs[:1])
+    mask[:, c] = 0
+    odds = log_probs[:, c] - th.logsumexp(log_probs * mask, dim=1)
+    return log_probs[:, c].exp(), odds

From f3852f010c2671cf52b6528bc53ae04a57e5b519 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Tue, 30 Jul 2024 08:20:47 +0200
Subject: [PATCH 8/9] moved metric tracking into the FIDO module

---
 src/fido/module.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/fido/module.py b/src/fido/module.py
index 7c2ba53..68ec4c1 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -20,7 +20,7 @@ def reg_scale(i, max_value):
 class FIDO(th.nn.Module):
 
     @classmethod
-    def new(cls, im, params: MaskConfig, *, device: th.device):
+    def new(cls, im, params: MaskConfig, *, device: th.device, track_metrics: bool = False):
 
         # de-normalize first
         infill = params.infill_strategy.new(im * 0.5 + 0.5, device=device)
@@ -30,14 +30,22 @@ def new(cls, im, params: MaskConfig, *, device: th.device):
         if params.mask_size is None:
             size = (H, W)
 
-        fido = cls(size, infill=infill.unsqueeze(0), device=device, optimized=params.optimized)
+        metrics = Metrics() if track_metrics else None
+
+        fido = cls(size,
+                   infill=infill.unsqueeze(0),
+                   device=device,
+                   optimized=params.optimized,
+                   metrics=metrics,
+                   )
 
         return fido
 
     def __init__(self, size: tuple, *,
                  infill: th.Tensor,
                  optimized: bool = True,
-                 device = None
+                 device = None,
+                 metrics: Metrics = None
                  ):
         super().__init__()
@@ -46,6 +54,7 @@ def __init__(self, size: tuple, *,
 
         self._optimized = optimized
         self._infill = infill
+        self.metrics = metrics
 
     def to(self, *args, **kwargs):
         self._infill = self._infill.to(*args, **kwargs)
@@ -205,7 +214,7 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
             update_callback(i, is_last_step=True)
 
-    def plot(self, im, pred, clf, *, output=None, metrics: dict = None, thresh: float = None):
+    def plot(self, im, pred, clf, *, output=None, thresh: float = None):
         ssr_keep_rate = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=True))
         ssr_keep_rate = ssr_keep_rate.detach().cpu().squeeze(0).numpy()
         # ssr_bernouli = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=False))
@@ -282,8 +291,8 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None, thresh: floa
             plt.savefig(output)
             plt.close(fig=fig)
 
-
-        if metrics is not None:
+        if self.metrics is not None:
+            metrics = self.metrics.as_dict()
             fig, axs = plt.subplots(len(metrics), 1, figsize=(16,9), squeeze=False)
             for i, (key, values) in enumerate(metrics.items()):
                 ax = axs[np.unravel_index(i, axs.shape)]

From ed30b750415c11507f707a7698c704300e914097 Mon Sep 17 00:00:00 2001
From: Dimitri Korsch
Date: Thu, 1 Aug 2024 14:09:56 +0200
Subject: [PATCH 9/9] added possibility to initialize the dropout masks

---
 src/fido/module.py | 53 +++++++++++++++++++++++++-------------------------
 1 file changed, 24 insertions(+), 29 deletions(-)

diff --git a/src/fido/module.py b/src/fido/module.py
index 68ec4c1..c913fc7 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -2,6 +2,7 @@
 import torch as th
 import torch.nn.functional as F
 import torchmetrics as tm
+import typing as T
 
 from matplotlib import pyplot as plt
 from pathlib import Path
@@ -17,10 +18,13 @@ def reg_scale(i, max_value):
         return 1.0
     return i / max_value
 
+def sigm_inv(x):
+    return th.log(x / (1 - x))
+
 
 class FIDO(th.nn.Module):
 
     @classmethod
-    def new(cls, im, params: MaskConfig, *, device: th.device, track_metrics: bool = False):
+    def new(cls, im, params: MaskConfig, *, device: th.device, init: th.Tensor = None, track_metrics: bool = False):
 
         # de-normalize first
         infill = params.infill_strategy.new(im * 0.5 + 0.5, device=device)
@@ -34,44 +38,35 @@ def new(cls, im, params: MaskConfig, *, device: th.device, track_metrics: bool =
         if params.mask_size is None:
             size = (H, W)
 
         metrics = Metrics() if track_metrics else None
 
         fido = cls(size,
                    infill=infill.unsqueeze(0),
-                   device=device,
                    optimized=params.optimized,
                    metrics=metrics,
-                   )
+                   init=init,
+                   )
 
-        return fido
+        return fido.to(device)
 
     def __init__(self, size: tuple, *,
                  infill: th.Tensor,
+                 init: T.Optional[th.Tensor] = None,
                  optimized: bool = True,
-                 device = None,
                  metrics: Metrics = None
                  ):
         super().__init__()
 
+        if init is None:
+            self.ssr_logit_p = th.nn.Parameter(th.zeros(size, requires_grad=True))
+            self.sdr_logit_p = th.nn.Parameter(th.zeros(size, requires_grad=True))
+        else:
+            assert init.shape == size, f"Expected shape {size} but got {init.shape}"
+            ssr_init = sigm_inv(1-init)  # init = 1 - ssr_init.sigmoid()
+            sdr_init = sigm_inv(init)  # init = sdr_init.sigmoid()
+
+            self.ssr_logit_p = th.nn.Parameter(ssr_init, requires_grad=True)
+            self.sdr_logit_p = th.nn.Parameter(sdr_init, requires_grad=True)
-        self.ssr_logit_p = th.zeros(size, device=device, requires_grad=True)
-        self.sdr_logit_p = th.zeros(size, device=device, requires_grad=True)
 
-        self._optimized = optimized
-        self._infill = infill
+        self.register_buffer("optimized", th.as_tensor(optimized))
+        self.register_buffer("infill", infill)
         self.metrics = metrics
 
-    def to(self, *args, **kwargs):
-        self._infill = self._infill.to(*args, **kwargs)
-        self.ssr_logit_p = self.ssr_logit_p.to(*args, **kwargs)
-        self.sdr_logit_p = self.sdr_logit_p.to(*args, **kwargs)
-        return super().to(*args, **kwargs)
-
-    def cpu(self):
-        self._infill = self._infill.cpu()
-        self.ssr_logit_p = self.ssr_logit_p.cpu()
-        self.sdr_logit_p = self.sdr_logit_p.cpu()
-        return super().cpu()
-
-    @property
-    def params(self):
-        return [self.ssr_logit_p, self.sdr_logit_p]
 
     @property
     def sdr_dropout_rate(self):
         return self.sdr_logit_p.sigmoid()
@@ -102,7 +97,7 @@ def _keep_rate(self, logit_p, *, batch_size: int = 1, deterministic: bool = Fals
             return th.stack([1 - logit_p.sigmoid()], axis=0)
         else:
             # bernouli sampling
-            if self._optimized:
+            if self.optimized:
                 return th.stack([sigmoid_concrete_dropout(logit_p) for _ in range(batch_size)], axis=0)
             else:
                 return th.stack([concrete_dropout(logit_p.sigmoid()) for _ in range(batch_size)], axis=0)
@@ -125,7 +120,7 @@ def _blend(self, X, logit_p, **kwargs):
         keep_rate = keep_rate.unsqueeze(1)
         X = X.unsqueeze(0)
 
-        return keep_rate * X + (1 - keep_rate) * self._infill
+        return keep_rate * X + (1 - keep_rate) * self.infill
 
     def ssr(self, X, **kwargs):
         return self._blend(X, self.ssr_logit_p, **kwargs)
@@ -162,7 +157,7 @@ def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False,
 
     def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update_callback = None):
 
-        opt = th.optim.AdamW(self.params, lr=config.learning_rate, eps=0.1, weight_decay=config.l2)
+        opt = th.optim.AdamW(self.parameters(), lr=config.learning_rate, eps=0.1, weight_decay=config.l2)
         # opt = th.optim.SGD(self.params, lr=config.learning_rate, momentum=0.9, weight_decay=config.l2)
         opt.zero_grad()
 
@@ -264,7 +259,7 @@ def plot(self, im, pred, clf, *, output=None, thresh: float = None):
                 ssr_keep_rate
             ), (
-                f"[1-SDR mask] min: {float(sdr_keep_rate.min()):.3f} | max: {float(sdr_keep_rate.max()):.3f}",
+                f"[1-SDR mask] min: {float((1-sdr_keep_rate).min()):.3f} | max: {float((1-sdr_keep_rate).max()):.3f}",
                 1-sdr_keep_rate
             ),
         ]
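
PATCH 1/9 changes SigmoidConcreteDropout.forward to return the drop rate (1 - keep_rate) and negates the hand-written gradient accordingly, but the check in src/concrete_dropout/__main__.py only compares forward passes. The following standalone sketch also compares the custom backward against autograd through the default implementation; it assumes the two modules are importable as laid out in the patch (e.g. with src/ on PYTHONPATH) and passes the temperature explicitly to both paths, since their defaults differ.

import torch as th

from concrete_dropout.default import concrete_dropout
from concrete_dropout.improved import SigmoidConcreteDropout


def compare(temp: float = 1.0, size: int = 32) -> None:
    # shared logits and shared uniform noise, so both paths see the same sample
    logit_p = th.randn(size, size, dtype=th.float64, requires_grad=True)
    u = th.zeros_like(logit_p).uniform_()

    # reference: relaxed Bernoulli mask built from the dropout rate p = sigmoid(logit_p)
    y_ref = concrete_dropout(logit_p.sigmoid(), temp=temp, u=u)
    g_ref, = th.autograd.grad(y_ref.sum(), logit_p)

    # proposed simplification: operates directly on the logits, uses the custom backward
    logit_p2 = logit_p.detach().clone().requires_grad_(True)
    y_new = SigmoidConcreteDropout.apply(logit_p2, temp, u)
    g_new, = th.autograd.grad(y_new.sum(), logit_p2)

    assert th.allclose(y_ref, y_new, atol=1e-6), "forward passes disagree"
    assert th.allclose(g_ref, g_new, atol=1e-6), "custom backward disagrees with autograd"


if __name__ == "__main__":
    compare()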
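
PATCH 2/9 turns the infill strategies into an Infill enum whose new() method builds and re-normalizes the infill in one call, and PATCH 7/9 later re-exports it from the package root. A short usage sketch, assuming the fido package and its dependencies are installed; the dummy image and the explicit device argument are only for illustration.

import torch as th

from fido.infill import Infill

# a dummy, de-normalized RGB image in [0, 1] (FIDO.new de-normalizes with im * 0.5 + 0.5)
im = th.rand(3, 224, 224)

# each strategy returns an infill re-normalized with the given mean/std (defaults 0.5 / 0.5)
blurred = Infill.BLUR.new(im, device=im.device)
zeros = Infill.ZEROS.new(im)

print(blurred.shape, zeros.unique())  # torch.Size([3, 224, 224]) tensor([-1.])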
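
PATCH 9/9 initializes the logit maps from an optional tensor via sigm_inv, the inverse of the sigmoid. Restated as a standalone check (sigm_inv is re-defined locally rather than imported, and the saliency map is made up): with ssr_logit_p = sigm_inv(1 - init) the deterministic SSR keep rate 1 - sigmoid(ssr_logit_p) starts at init, and with sdr_logit_p = sigm_inv(init) the SDR dropout rate sigmoid(sdr_logit_p) starts at init.

import torch as th


def sigm_inv(x: th.Tensor) -> th.Tensor:
    # inverse of the sigmoid (the logit function), as defined in the patch
    return th.log(x / (1 - x))


# hypothetical target saliency map, strictly inside (0, 1)
init = th.rand(8, 8).clamp(0.05, 0.95)

ssr_logit_p = sigm_inv(1 - init)   # SSR: keep rate starts at `init`
sdr_logit_p = sigm_inv(init)       # SDR: dropout rate starts at `init`

# the module's deterministic keep rate is 1 - sigmoid(logit_p)
assert th.allclose(1 - ssr_logit_p.sigmoid(), init, atol=1e-6)
assert th.allclose(sdr_logit_p.sigmoid(), init, atol=1e-6)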