diff --git a/src/concrete_dropout/__main__.py b/src/concrete_dropout/__main__.py
new file mode 100644
index 0000000..a1edfa6
--- /dev/null
+++ b/src/concrete_dropout/__main__.py
@@ -0,0 +1,20 @@
+import torch as th
+import sys
+
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from concrete_dropout.default import concrete_dropout
+from concrete_dropout.improved import sigmoid_concrete_dropout
+
+def main():
+    size = 32
+    X = th.zeros((size, size), requires_grad=True, dtype=th.float32)
+    u = th.zeros_like(X).uniform_()
+
+    y0 = concrete_dropout(X.sigmoid(), u=u)
+    y1 = sigmoid_concrete_dropout(X, u=u)
+    assert th.allclose(y0, y1)
+
+main()
diff --git a/src/concrete_dropout/default.py b/src/concrete_dropout/default.py
index fe8c08f..ae3c94d 100644
--- a/src/concrete_dropout/default.py
+++ b/src/concrete_dropout/default.py
@@ -1,8 +1,10 @@
 import numpy as np
 import torch as th
 
-def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np.finfo(float).eps):
-    u = th.zeros_like(dropout_rate).uniform_()
+
+def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, u = None, eps: float = np.finfo(float).eps):
+    if u is None:
+        u = th.zeros_like(dropout_rate).uniform_()
     approx = (
         (dropout_rate + eps).log()
         -
@@ -13,4 +15,3 @@ def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np
     )
 
     return 1 - (approx / temp).sigmoid()
-
diff --git a/src/concrete_dropout/improved.py b/src/concrete_dropout/improved.py
index 56f3d57..fec2a8a 100644
--- a/src/concrete_dropout/improved.py
+++ b/src/concrete_dropout/improved.py
@@ -1,26 +1,24 @@
 import numpy as np
 import torch as th
 
-EPS = 1e-7
 
 
 class SigmoidConcreteDropout(th.autograd.Function):
 
     @staticmethod
-    def forward(ctx, logit_p, temp: float = 1.0, u = None):
+    def forward(ctx, logit_p, temp: float = 1.0, u = None, *, eps = np.finfo(float).eps):
         """ our proposed simplification """
-        global EPS
         temp = th.scalar_tensor(temp)
         if u is None:
             u = th.zeros_like(logit_p).uniform_()
 
-        noise = ((u + EPS) / (1 - u + EPS)).log()
+        noise = ((u + eps) / (1 - u + eps)).log()
         logit_p_temp = (logit_p + noise) / temp
 
-        res = logit_p_temp.sigmoid()
+        keep_rate = logit_p_temp.sigmoid()
 
-        ctx.save_for_backward(res, logit_p_temp, temp)
-        return res
+        ctx.save_for_backward(keep_rate, logit_p_temp, temp)
+        return 1 - keep_rate
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -28,12 +26,12 @@ def backward(ctx, grad_output):
         """
             Gradient of random_tensor w.r.t logit_p is simply 1/temp * sigmoid(logit_p_temp)^2 * exp(-logit_p_temp)
         """
-        res, logit_p_temp, temp = ctx.saved_tensors
-        grad = th.zeros_like(res)
-        mask = res != 0
-        grad[mask] = res[mask]**2 * (-logit_p_temp[mask]).exp() / temp
+        keep_rate, logit_p_temp, temp = ctx.saved_tensors
+        grad = th.zeros_like(keep_rate)
+        mask = keep_rate != 0
+        grad[mask] = keep_rate[mask]**2 * (-logit_p_temp[mask]).exp() / temp
 
-        return grad * grad_output, None, None
+        return -grad * grad_output, None, None
 
     @classmethod
     def _check_grad(clf, temp: float = 1.0, size: int = 16, *, dtype = th.float64):
diff --git a/src/fido/__init__.py b/src/fido/__init__.py
index e69de29..d3f7573 100644
--- a/src/fido/__init__.py
+++ b/src/fido/__init__.py
@@ -0,0 +1,19 @@
+from fido.module import FIDO
+from fido.infill import Infill
+from fido.configs import FIDOConfig
+from fido.configs import MaskConfig
+from fido.metrics import Metrics
+from fido.module import log_odds
+from fido.storage import load_maps
+from fido.storage import dump_maps
+
+__all__ = [
+    "FIDO",
+    "Infill",
+    "FIDOConfig",
+    "MaskConfig",
+    "Metrics",
+    "log_odds",
+    "load_maps",
+    "dump_maps",
+]
diff --git a/src/fido/configs.py b/src/fido/configs.py
index 3f09232..c656ba3 100644
--- a/src/fido/configs.py
+++ b/src/fido/configs.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from fido.infill import Infill
 
 @dataclass
 class FIDOConfig:
@@ -23,4 +24,4 @@ def reg_params(self) -> dict:
 class MaskConfig:
     optimized: bool = True
     mask_size: int = None
-    infill_strategy: str = "blur"
+    infill_strategy: Infill = Infill.BLUR
diff --git a/src/fido/infill.py b/src/fido/infill.py
index 032f13f..20ff347 100644
--- a/src/fido/infill.py
+++ b/src/fido/infill.py
@@ -1,11 +1,13 @@
 import cv2
+import enum
 import numpy as np
 import torch as th
 
+from functools import partial
 from numpy.lib.stride_tricks import sliding_window_view
 from skimage.measure import regionprops
 from torchvision.transforms import functional as tr
-
+from pathlib import Path
 
 
 from fido.metrics import _thresh
@@ -13,56 +15,120 @@
 def patches(im, size):
     window_shape = (size, size, im.shape[-1])
     return sliding_window_view(im, window_shape)[::size, ::size]
 
-def new_infill(im: th.Tensor, strategy: str, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device=None):
-
-    if strategy == "original":
-        infill = im.clone()
-
-    elif strategy == "zeros":
-        infill = th.zeros_like(im, device=device)
-
-    elif strategy == "uniform":
-        infill = th.zeros_like(im, device=device).uniform_(0, 1)
-
-    elif strategy == "normal":
-        infill = th.zeros_like(im, device=device).normal_(std=0.2)
-
-    elif strategy == "mean":
-        mean_pixel = im.mean(axis=(1,2), keepdim=True)
-        infill = mean_pixel.expand(im.shape)
-
-    elif strategy == "blur":
-        # should result in std=10, same as in the paper
-        # de-normalize first
-        infill = tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device)
-
-    elif strategy == "local":
-        pass
-
-    elif strategy == "knockoff":
-        infill = np.zeros_like(im)
-        windows = patches(im.detach().cpu().numpy().transpose(1,2,0), size=size)
-        rows, cols, *_ = windows.shape
-        for i in range(rows*cols):
-            row, col, _chan = np.unravel_index(i, (rows, cols, 1))
-            patch = windows[row, col, _chan]
-
-            x0, y0 = col*size, row*size
-            x1, y1 = (col+1)*size, (row+1)*size
-
-            idxs = np.random.permutation(size**2)
-            idxs = np.unravel_index(idxs, (size, size))
-
-            infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2,0,1)
-
-        infill = th.tensor(infill)
-    else:
-        raise ValueError(f"Unknown in-fill strategy: {strategy}")
-
-    return (infill - mean) / std
-
-
+gan = None
+class Infill(enum.Enum):
+    ORIGINIAL = enum.auto()
+    ZEROS = enum.auto()
+    UNIFORM = enum.auto()
+    NORMAL = enum.auto()
+    MEAN = enum.auto()
+    BLUR = enum.auto()
+    GAN = enum.auto()
+    KNOCKOFF = enum.auto()
+
+    def normalize(self, arr: th.Tensor, mean: float, std: float) -> th.Tensor:
+        return (arr - mean) / std
+
+    def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device: th.device = None):
+        global gan
+        _normalize = partial(self.normalize, mean=mean, std=std)
+        if self == Infill.BLUR:
+            # should result in std=10, same as in the paper
+            # de-normalize first
+            return _normalize(tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device))
+
+        if self == Infill.ORIGINIAL:
+            return _normalize(im.clone())
+
+        if self == Infill.ZEROS:
+            return _normalize(th.zeros_like(im, device=device))
+
+        if self == Infill.UNIFORM:
+            res = th.zeros_like(im, device=device).uniform_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
+
+        if self == Infill.NORMAL:
+            res = th.zeros_like(im, device=device).normal_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
+
+        if self == Infill.MEAN:
+            mean_pixel = im.mean(axis=(1,2), keepdim=True)
+            return _normalize(mean_pixel.expand(im.shape))
+
+        if self == Infill.KNOCKOFF:
+
+            infill = np.zeros_like(im)
+            windows = patches(im.detach().cpu().numpy().transpose(1,2,0), size=size)
+            rows, cols, *_ = windows.shape
+            for i in range(rows*cols):
+                row, col, _chan = np.unravel_index(i, (rows, cols, 1))
+                patch = windows[row, col, _chan]
+
+                x0, y0 = col*size, row*size
+                x1, y1 = (col+1)*size, (row+1)*size
+
+                idxs = np.random.permutation(size**2)
+                idxs = np.unravel_index(idxs, (size, size))
+
+                infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2,0,1)
+
+            return _normalize(th.tensor(infill))
+
+        if self == Infill.GAN:
+            if gan is None:
+                gan = GAN(device=im.device)
+            infill = gan(im)
+            return infill
+
+        raise NotImplementedError(f"Strategy is not implemented yet: {self}!")
+
+import dmfn  # noqa: E402
+from dmfn.models.networks import define_G  # noqa: E402
+
+class GAN:
+
+    DEFAULT_WEIGHTS: Path = Path(dmfn.__file__).resolve().parent.parent.parent / "outputs/cub/checkpoints/latest_G.pth"
+
+    @staticmethod
+    def weights_init(m):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+        elif classname.find('BatchNorm') != -1:
+            m.weight.data.normal_(1.0, 0.02)
+            m.bias.data.fill_(0)
+        elif classname.find('Linear') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+            if m.bias is not None:
+                m.bias.data.fill_(0.0)
+
+    def __init__(self, weights: Path = None, device: th.device = th.device("cpu")):
+
+        if weights is None:
+            weights = GAN.DEFAULT_WEIGHTS
+        assert Path(weights).exists(), \
+            f"Could not find weights: {weights}"
+
+        opt = dict(network_G=dict( which_model_G='DMFN', in_nc=4, out_nc=3, nf=64, n_res=8),is_train=False)
+        self.generator = define_G(opt).to(device)
+        self.generator.load_state_dict(th.load(weights), strict=True)
+        self.generator.eval()
+
+    def __call__(self, im: th.Tensor, *, grid_size: int = 16, size: tuple = (256, 256)):
+        mask = th.randint(0, 2, size=(1, 1, grid_size, grid_size), dtype=im.dtype, device=im.device)
+        orig_size = im.size()[-2:]
+        im = tr.resize(im, size)
+        mask = tr.resize(mask, size, interpolation=tr.InterpolationMode.NEAREST)
+
+        X1 = th.cat([im * mask, 1-mask], dim=1)
+        X2 = th.cat([im * (1-mask), mask], dim=1)
+
+        X = th.cat([X1, X2], dim=0)
+        out = self.generator(X).detach()
+        # combine the result from the generated images
+        res = out[0] * (1 - mask[0]) + out[1] * mask[0]
+        return tr.resize(res, orig_size)
 
 def _calc_bbox(mask, min_size, *, pad: int = 10, squared: bool = True):
     props = regionprops(mask)
diff --git a/src/fido/module.py b/src/fido/module.py
index 0cd8ae9..c913fc7 100644
--- a/src/fido/module.py
+++ b/src/fido/module.py
@@ -2,6 +2,7 @@
 import torch as th
 import torch.nn.functional as F
 import torchmetrics as tm
+import typing as T
 
 from matplotlib import pyplot as plt
 from pathlib import Path
@@ -10,7 +11,6 @@
 from concrete_dropout import sigmoid_concrete_dropout
 from fido.configs import FIDOConfig
 from fido.configs import MaskConfig
-from fido.infill import new_infill
 from fido.metrics import Metrics
 
 def reg_scale(i, max_value):
@@ -18,40 +18,54 @@ def reg_scale(i, max_value):
         return 1.0
     return i / max_value
 
+def sigm_inv(x):
+    return th.log(x / (1 - x))
+
 class FIDO(th.nn.Module):
 
     @classmethod
-    def new(cls, im, params: MaskConfig, *, device):
+    def new(cls, im, params: MaskConfig, *, device: th.device, init: th.Tensor = None, track_metrics: bool = False):
         # de-normalize first
-        infill = new_infill(im * 0.5 + 0.5, params.infill_strategy, device=device)
+        infill = params.infill_strategy.new(im * 0.5 + 0.5, device=device)
 
         C, H, W = im.shape
         size = (params.mask_size, params.mask_size)
         if params.mask_size is None:
            size = (H, W)
 
-        fido = cls(size, infill=infill.unsqueeze(0), device=device, optimized=params.optimized)
+        metrics = Metrics() if track_metrics else None
+
+        fido = cls(size,
+            infill=infill.unsqueeze(0),
+            optimized=params.optimized,
+            metrics=metrics,
+            init=init,
+        )
 
-        return fido
+        return fido.to(device)
 
     def __init__(self, size: tuple, *,
                  infill: th.Tensor,
+                 init: T.Optional[th.Tensor] = None,
                  optimized: bool = True,
-                 device = None
+                 metrics: Metrics = None
                 ):
        super().__init__()
 
+        if init is None:
+            self.ssr_logit_p = th.nn.Parameter(th.zeros(size, requires_grad=True))
+            self.sdr_logit_p = th.nn.Parameter(th.zeros(size, requires_grad=True))
+        else:
+            assert init.shape == size, f"Expected shape {size} but got {init.shape}"
+            ssr_init = sigm_inv(1-init) # init = 1 - ssr_init.sigmoid()
+            sdr_init = sigm_inv(init) # init = sdr_init.sigmoid()
 
-        self.ssr_logit_p = th.zeros(size, device=device, requires_grad=True)
-        self.sdr_logit_p = th.zeros(size, device=device, requires_grad=True)
-
-        self._optimized = optimized
-        self._infill = infill
-
+            self.ssr_logit_p = th.nn.Parameter(ssr_init, requires_grad=True)
+            self.sdr_logit_p = th.nn.Parameter(sdr_init, requires_grad=True)
 
-    @property
-    def params(self):
-        return [self.ssr_logit_p, self.sdr_logit_p]
+        self.register_buffer("optimized", th.as_tensor(optimized))
+        self.register_buffer("infill", infill)
+        self.metrics = metrics
 
     @property
     def sdr_dropout_rate(self):
@@ -69,7 +83,7 @@ def tv_loss(self, *, reduction="mean"):
         # p = th.stack([self.ssr_logit_p, self.sdr_logit_p], axis=0)
         p = th.stack([self.ssr_dropout_rate, self.sdr_dropout_rate], axis=0)
         # p = self.joint_dropout_rate[None]
-        return tm.TotalVariation(reduction=reduction).to(p.device)(p[None])
+        return tm.image.TotalVariation(reduction=reduction).to(p.device)(p[None])
 
     def l1_norm(self):
         # return self.joint_dropout_rate.sum()
@@ -83,7 +97,7 @@ def _keep_rate(self, logit_p, *, batch_size: int = 1, deterministic: bool = Fals
             return th.stack([1 - logit_p.sigmoid()], axis=0)
         else:
             # bernouli sampling
-            if self._optimized:
+            if self.optimized:
                 return th.stack([sigmoid_concrete_dropout(logit_p) for _ in range(batch_size)], axis=0)
             else:
                 return th.stack([concrete_dropout(logit_p.sigmoid()) for _ in range(batch_size)], axis=0)
@@ -106,7 +120,7 @@ def _blend(self, X, logit_p, **kwargs):
         keep_rate = keep_rate.unsqueeze(1)
         X = X.unsqueeze(0)
 
-        return keep_rate * X + (1 - keep_rate) * self._infill
+        return keep_rate * X + (1 - keep_rate) * self.infill
 
     def ssr(self, X, **kwargs):
         return self._blend(X, self.ssr_logit_p, **kwargs)
@@ -120,20 +134,31 @@ def forward(self, X, **kwargs):
         sdr = self.sdr(X, **kwargs)
         return ssr, sdr
 
-    def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False):
+    def objective(self, X, y, clf, *, batch_size: int, deterministic: bool = False, add_deterministic: bool = False):
         ssr, sdr = self(X, batch_size=batch_size, deterministic=deterministic)
         n = len(ssr)
         _x = th.concatenate([ssr, sdr])
-        prob, odds = clf.log_odds(_x, c=y)
+        if hasattr(clf, "log_odds"):
+            prob, odds = clf.log_odds(_x, c=y)
+        else:
+            logits = clf(_x)
+            prob, odds = log_odds(logits, c=y)
+
         ssr_prob, sdr_prob = prob[:n], prob[n:]
         ssr_odds, sdr_odds = odds[:n], odds[n:]
 
         # sdr - ssr <== minimizing sdr and maximizing ssr probabilities
         loss = (sdr_odds - ssr_odds).mean()
+
+        if not deterministic and add_deterministic:
+            _, _, det_loss, _, _ = self.objective(X, y, clf, batch_size=1, deterministic=True, add_deterministic=False)
+            loss += det_loss
+
         return (ssr, sdr), (ssr_prob.mean(), sdr_prob.mean()), loss, self.l1_norm(), self.tv_loss()
 
     def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update_callback = None):
-        opt = th.optim.AdamW(self.params, lr=config.learning_rate, eps=0.1, weight_decay=config.l2)
+        opt = th.optim.AdamW(self.parameters(), lr=config.learning_rate, eps=0.1, weight_decay=config.l2)
+        # opt = th.optim.SGD(self.params, lr=config.learning_rate, momentum=0.9, weight_decay=config.l2)
         opt.zero_grad()
 
         ssr_grad, sdr_grad = 1, 1
@@ -142,7 +167,6 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
             if config.approx_steps is None:
                 # no approximation
                 masks, probs, loss, l1_norm, tvl = self.objective(im, y, clf, batch_size=config.batch_size)
-
             else:
                 if i % config.approx_steps == 0:
                     (ssr, sdr), probs, loss, l1_norm, tvl = self.objective(im, y, clf, batch_size=config.batch_size)
@@ -181,10 +205,11 @@ def fit(self, im, y, clf, *, config: FIDOConfig, metrics: Metrics = None, update
                 opt.step()
                 opt.zero_grad()
 
-        update_callback(i, is_last_step=True)
+        if update_callback is not None:
+            update_callback(i, is_last_step=True)
 
-    def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
+    def plot(self, im, pred, clf, *, output=None, thresh: float = None):
         ssr_keep_rate = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=True))
         ssr_keep_rate = ssr_keep_rate.detach().cpu().squeeze(0).numpy()
         # ssr_bernouli = self._upsample(im, self._keep_rate(self.ssr_logit_p, deterministic=False))
@@ -195,6 +220,9 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
 
         joint_keep_rate = np.sqrt(ssr_keep_rate * (1-sdr_keep_rate))
 
+        if thresh is not None:
+            joint_keep_rate[joint_keep_rate < thresh] = np.nan
+
         cls_id = pred.argmax()
 
         orig_im = (im * 0.5 + 0.5).permute(1, 2, 0).numpy()
@@ -231,7 +259,7 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
                 ssr_keep_rate
             ),
             (
-                f"[1-SDR mask] min: {float(sdr_keep_rate.min()):.3f} | max: {float(sdr_keep_rate.max()):.3f}",
+                f"[1-SDR mask] min: {float((1-sdr_keep_rate).min()):.3f} | max: {float((1-sdr_keep_rate).max()):.3f}",
                 1-sdr_keep_rate
             ),
         ]
@@ -256,9 +284,10 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
             plt.show()
         else:
             plt.savefig(output)
+            plt.close(fig=fig)
 
-
-        if metrics is not None:
+        if self.metrics is not None:
+            metrics = self.metrics.as_dict()
             fig, axs = plt.subplots(len(metrics), 1, figsize=(16,9), squeeze=False)
             for i, (key, values) in enumerate(metrics.items()):
                 ax = axs[np.unravel_index(i, axs.shape)]
@@ -274,4 +303,13 @@ def plot(self, im, pred, clf, *, output=None, metrics: dict = None):
             loss_graphs = output.with_suffix(f".losses{output.suffix}")
             plt.savefig(loss_graphs)
 
-        plt.close()
+        plt.close(fig=fig)
+
+
+def log_odds(logits, c):
+    # normalized log-probabilities
+    log_probs = F.log_softmax(logits, dim=1)
+    mask = th.ones_like(log_probs[:1])
+    mask[:, c] = 0
+    odds = log_probs[:, c] - th.logsumexp(log_probs * mask, dim=1)
+    return log_probs[:, c].exp(), odds
diff --git a/src/fido/storage.py b/src/fido/storage.py
index 51efc30..9f1cce1 100644
--- a/src/fido/storage.py
+++ b/src/fido/storage.py
@@ -51,7 +51,7 @@ def __getitem__(self, i):
         return self._cache[key].unpack()
 
     def __len__(self) -> int:
-        return len(self._npzfile.files) // 2
+        return len(self.npzfile.files) // 2
 
 def _maps_file_name(subset, *, lazy_load: bool = True):
     suffix = ".lazy" if lazy_load else ""
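
Note (illustrative, not part of the diff): the sign flip in SigmoidConcreteDropout.backward, which follows from forward now returning 1 - keep_rate, can be sanity-checked with torch.autograd.gradcheck, much like the _check_grad helper the class already ships. A minimal sketch, assuming the package layout above:

    import torch as th

    from concrete_dropout.improved import SigmoidConcreteDropout

    # fix the uniform noise so the output is a deterministic function of logit_p
    logit_p = th.randn(8, 8, dtype=th.float64, requires_grad=True)
    u = th.zeros_like(logit_p).uniform_()

    # gradcheck compares the hand-written backward (note the leading minus)
    # against a finite-difference approximation of the gradient
    assert th.autograd.gradcheck(lambda x: SigmoidConcreteDropout.apply(x, 1.0, u), (logit_p,))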