Some updates for readability and performance #3

Open · wants to merge 9 commits into main

Changes from all commits
20 changes: 20 additions & 0 deletions src/concrete_dropout/__main__.py
@@ -0,0 +1,20 @@
import torch as th
import sys

from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))

from concrete_dropout.default import concrete_dropout
from concrete_dropout.improved import sigmoid_concrete_dropout

def main():
    size = 32
    X = th.zeros((size, size), requires_grad=True, dtype=th.float32)
    # shared uniform noise, so both implementations see the same randomness
    u = th.zeros_like(X).uniform_()

    y0 = concrete_dropout(X.sigmoid(), u=u)
    y1 = sigmoid_concrete_dropout(X, u=u)
    assert th.allclose(y0, y1)

main()
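Note: since this file is src/concrete_dropout/__main__.py, the check can be run with python -m concrete_dropout from src/ (assuming no further packaging changes). Passing the same uniform noise u to both implementations is what makes their outputs directly comparable.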
7 changes: 4 additions & 3 deletions src/concrete_dropout/default.py
@@ -1,8 +1,10 @@
import numpy as np
import torch as th

-def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np.finfo(float).eps):
-    u = th.zeros_like(dropout_rate).uniform_()
+def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, u = None, eps: float = np.finfo(float).eps):
+    if u is None:
+        u = th.zeros_like(dropout_rate).uniform_()

    approx = (
        (dropout_rate + eps).log() -
@@ -13,4 +15,3 @@ def concrete_dropout(dropout_rate: th.Tensor, temp: float = 0.1, eps: float = np
    )

    return 1 - (approx / temp).sigmoid()

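For reference (the middle of the expression is collapsed above), the quantity this function computes is the usual binary Concrete relaxation of dropout; as a hedged sketch, with drop rate p, temperature t and uniform noise u:

    z = 1 - sigmoid((log(p) - log(1 - p) + log(u) - log(1 - u)) / t)

Because log(p) - log(1 - p) is just logit(p), substituting p = sigmoid(l) collapses the first two terms to the raw logit l. That identity is exactly what the simplified version in improved.py exploits, and what __main__.py asserts.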
22 changes: 10 additions & 12 deletions src/concrete_dropout/improved.py
@@ -1,39 +1,37 @@
import numpy as np
import torch as th

-EPS = 1e-7
class SigmoidConcreteDropout(th.autograd.Function):

    @staticmethod
-    def forward(ctx, logit_p, temp: float = 1.0, u = None):
+    def forward(ctx, logit_p, temp: float = 1.0, u = None, *, eps = np.finfo(float).eps):
        """
        our proposed simplification
        """
-        global EPS
        temp = th.scalar_tensor(temp)

        if u is None:
            u = th.zeros_like(logit_p).uniform_()

-        noise = ((u + EPS) / (1 - u + EPS)).log()
+        noise = ((u + eps) / (1 - u + eps)).log()
        logit_p_temp = (logit_p + noise) / temp
-        res = logit_p_temp.sigmoid()
+        keep_rate = logit_p_temp.sigmoid()

-        ctx.save_for_backward(res, logit_p_temp, temp)
-        return res
+        ctx.save_for_backward(keep_rate, logit_p_temp, temp)
+        return 1 - keep_rate

    @staticmethod
    def backward(ctx, grad_output):
        """
        Gradient of random_tensor w.r.t. logit_p is simply
        1/temp * sigmoid(logit_p_temp)^2 * exp(-logit_p_temp)
        """
-        res, logit_p_temp, temp = ctx.saved_tensors
-        grad = th.zeros_like(res)
-        mask = res != 0
-        grad[mask] = res[mask]**2 * (-logit_p_temp[mask]).exp() / temp
+        keep_rate, logit_p_temp, temp = ctx.saved_tensors
+        grad = th.zeros_like(keep_rate)
+        mask = keep_rate != 0
+        grad[mask] = keep_rate[mask]**2 * (-logit_p_temp[mask]).exp() / temp

-        return grad * grad_output, None, None
+        return -grad * grad_output, None, None

    @classmethod
    def _check_grad(clf, temp: float = 1.0, size: int = 16, *, dtype = th.float64):
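The body of _check_grad is collapsed in this diff. As an illustration only (names and defaults here are assumptions, not the repository's code), a minimal gradient check of the Function above could look like:

    def check_grad_sketch(temp: float = 1.0, size: int = 16, *, dtype=th.float64):
        # fixed noise u makes the Function a deterministic map of logit_p,
        # so torch.autograd.gradcheck can compare the custom backward
        # against finite differences (gradcheck needs float64 inputs)
        logit_p = th.randn(size, size, dtype=dtype, requires_grad=True)
        u = th.rand(size, size, dtype=dtype)
        return th.autograd.gradcheck(
            lambda x: SigmoidConcreteDropout.apply(x, temp, u),
            (logit_p,),
        )

Such a check should pass for the code above: keep_rate**2 * exp(-logit_p_temp) equals the sigmoid derivative s * (1 - s), and the new leading minus sign in backward matches switching the forward output to 1 - keep_rate.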
19 changes: 19 additions & 0 deletions src/fido/__init__.py
@@ -0,0 +1,19 @@
from fido.module import FIDO
from fido.infill import Infill
from fido.configs import FIDOConfig
from fido.configs import MaskConfig
from fido.metrics import Metrics
from fido.module import log_odds
from fido.storage import load_maps
from fido.storage import dump_maps

__all__ = [
    "FIDO",
    "Infill",
    "FIDOConfig",
    "MaskConfig",
    "Metrics",
    "log_odds",
    "load_maps",
    "dump_maps",
]
3 changes: 2 additions & 1 deletion src/fido/configs.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass
+from fido.infill import Infill

@dataclass
class FIDOConfig:
@@ -23,4 +24,4 @@ def reg_params(self) -> dict:
class MaskConfig:
    optimized: bool = True
    mask_size: int = None
-    infill_strategy: str = "blur"
+    infill_strategy: Infill = Infill.BLUR
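A hypothetical usage of the now-typed field (the values are illustrative, not taken from the repository):

    from fido.configs import MaskConfig
    from fido.infill import Infill

    cfg = MaskConfig(optimized=True, mask_size=56, infill_strategy=Infill.BLUR)
    assert isinstance(cfg.infill_strategy, Infill)  # typo-safe, unlike the old "blur" string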
166 changes: 116 additions & 50 deletions src/fido/infill.py
@@ -1,68 +1,134 @@
import cv2
+import enum
import numpy as np
import torch as th

+from functools import partial
from numpy.lib.stride_tricks import sliding_window_view
from skimage.measure import regionprops
from torchvision.transforms import functional as tr

+from pathlib import Path
from fido.metrics import _thresh

def patches(im, size):
    window_shape = (size, size, im.shape[-1])
    return sliding_window_view(im, window_shape)[::size, ::size]

-def new_infill(im: th.Tensor, strategy: str, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device=None):
-
-    if strategy == "original":
-        infill = im.clone()
-
-    elif strategy == "zeros":
-        infill = th.zeros_like(im, device=device)
-
-    elif strategy == "uniform":
-        infill = th.zeros_like(im, device=device).uniform_(0, 1)
-
-    elif strategy == "normal":
-        infill = th.zeros_like(im, device=device).normal_(std=0.2)
-
-    elif strategy == "mean":
-        mean_pixel = im.mean(axis=(1, 2), keepdim=True)
-        infill = mean_pixel.expand(im.shape)
-
-    elif strategy == "blur":
-        # should result in std=10, same as in the paper
-        # de-normalize first
-        infill = tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device)
-
-    elif strategy == "local":
-        pass
-
-    elif strategy == "knockoff":
-        infill = np.zeros_like(im)
-        windows = patches(im.detach().cpu().numpy().transpose(1, 2, 0), size=size)
-        rows, cols, *_ = windows.shape
-        for i in range(rows * cols):
-            row, col, _chan = np.unravel_index(i, (rows, cols, 1))
-            patch = windows[row, col, _chan]
-
-            x0, y0 = col * size, row * size
-            x1, y1 = (col + 1) * size, (row + 1) * size
-
-            idxs = np.random.permutation(size**2)
-            idxs = np.unravel_index(idxs, (size, size))
-
-            infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2, 0, 1)
-
-        infill = th.tensor(infill)
-    else:
-        raise ValueError(f"Unknown in-fill strategy: {strategy}")
-
-    return (infill - mean) / std


+gan = None
+
+class Infill(enum.Enum):
+    ORIGINIAL = enum.auto()
+    ZEROS = enum.auto()
+    UNIFORM = enum.auto()
+    NORMAL = enum.auto()
+    MEAN = enum.auto()
+    BLUR = enum.auto()
+    GAN = enum.auto()
+    KNOCKOFF = enum.auto()
+
+    def normalize(self, arr: th.Tensor, mean: float, std: float) -> th.Tensor:
+        return (arr - mean) / std
+
+    def new(self, im: th.Tensor, *, mean: float = 0.5, std: float = 0.5, size: int = 9, device: th.device = None):
+        global gan
+        _normalize = partial(self.normalize, mean=mean, std=std)
+        if self == Infill.BLUR:
+            # should result in std=10, same as in the paper
+            # de-normalize first
+            return _normalize(tr.gaussian_blur(im.to(device), kernel_size=75).to(im.device))
+
+        if self == Infill.ORIGINIAL:
+            return _normalize(im.clone())
+
+        if self == Infill.ZEROS:
+            return _normalize(th.zeros_like(im, device=device))
+
+        if self == Infill.UNIFORM:
+            res = th.zeros_like(im, device=device).uniform_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
+
+        if self == Infill.NORMAL:
+            res = th.zeros_like(im, device=device).normal_(0, 1)
+            return _normalize(self.normalize(res, mean=res.mean(), std=res.std()))
+
+        if self == Infill.MEAN:
+            mean_pixel = im.mean(axis=(1, 2), keepdim=True)
+            return _normalize(mean_pixel.expand(im.shape))
+
+        if self == Infill.KNOCKOFF:
+            infill = np.zeros_like(im)
+            windows = patches(im.detach().cpu().numpy().transpose(1, 2, 0), size=size)
+            rows, cols, *_ = windows.shape
+            for i in range(rows * cols):
+                row, col, _chan = np.unravel_index(i, (rows, cols, 1))
+                patch = windows[row, col, _chan]
+
+                x0, y0 = col * size, row * size
+                x1, y1 = (col + 1) * size, (row + 1) * size
+
+                idxs = np.random.permutation(size**2)
+                idxs = np.unravel_index(idxs, (size, size))
+
+                infill[:, y0:y1, x0:x1] = patch[idxs].reshape(patch.shape).transpose(2, 0, 1)
+
+            return _normalize(th.tensor(infill))
+
+        if self == Infill.GAN:
+            if gan is None:
+                gan = GAN(device=im.device)
+            infill = gan(im)
+            return infill
+
+        raise NotImplementedError(f"Strategy is not implemented yet: {self}!")

+import dmfn  # noqa: E402
+from dmfn.models.networks import define_G  # noqa: E402
+
+class GAN:
+
+    DEFAULT_WEIGHTS: Path = Path(dmfn.__file__).resolve().parent.parent.parent / "outputs/cub/checkpoints/latest_G.pth"
+
+    @staticmethod
+    def weights_init(m):
+        classname = m.__class__.__name__
+        if classname.find('Conv') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+        elif classname.find('BatchNorm') != -1:
+            m.weight.data.normal_(1.0, 0.02)
+            m.bias.data.fill_(0)
+        elif classname.find('Linear') != -1:
+            th.nn.init.orthogonal_(m.weight.data, 1.0)
+            if m.bias is not None:
+                m.bias.data.fill_(0.0)
+
+    def __init__(self, weights: Path = None, device: th.device = th.device("cpu")):
+
+        if weights is None:
+            weights = GAN.DEFAULT_WEIGHTS
+        assert Path(weights).exists(), \
+            f"Could not find weights: {weights}"
+
+        opt = dict(network_G=dict(which_model_G='DMFN', in_nc=4, out_nc=3, nf=64, n_res=8), is_train=False)
+        self.generator = define_G(opt).to(device)
+        self.generator.load_state_dict(th.load(weights), strict=True)
+        self.generator.eval()
+
+    def __call__(self, im: th.Tensor, *, grid_size: int = 16, size: tuple = (256, 256)):
+        mask = th.randint(0, 2, size=(1, 1, grid_size, grid_size), dtype=im.dtype, device=im.device)
+        orig_size = im.size()[-2:]
+        im = tr.resize(im, size)
+        mask = tr.resize(mask, size, interpolation=tr.InterpolationMode.NEAREST)
+
+        X1 = th.cat([im * mask, 1 - mask], dim=1)
+        X2 = th.cat([im * (1 - mask), mask], dim=1)
+
+        X = th.cat([X1, X2], dim=0)
+        out = self.generator(X).detach()
+        # combine the result from the generated images
+        res = out[0] * (1 - mask[0]) + out[1] * mask[0]
+        return tr.resize(res, orig_size)

def _calc_bbox(mask, min_size, *, pad: int = 10, squared: bool = True):
    props = regionprops(mask)
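Finally, a hedged sketch of how the new enum API might be called (image shape and strategy are illustrative; Infill.GAN additionally requires the dmfn checkpoint to be present):

    import torch as th
    from fido.infill import Infill

    im = th.rand(3, 224, 224)  # CHW image in [0, 1]
    infill = Infill.BLUR.new(im, mean=0.5, std=0.5)
    assert infill.shape == im.shape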