Skip to content

Commit

Permalink
Add normal stencil
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetters94 committed Sep 25, 2023
1 parent 6773278 commit c39ec29
Show file tree
Hide file tree
Showing 42 changed files with 7,689 additions and 2 deletions.
2 changes: 1 addition & 1 deletion apps/stable_diffusion/src/utils/stable_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def is_valid_file(arg):

p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
choices=["canny", "openpose", "scribble", "normal"],
help="Enable the stencil feature.",
)

Expand Down
3 changes: 3 additions & 0 deletions apps/stable_diffusion/src/utils/stencils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
from apps.stable_diffusion.src.utils.stencils.normal_bae import (
NormalBaeDetector,
)
140 changes: 140 additions & 0 deletions apps/stable_diffusion/src/utils/stencils/normal_bae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import types
import warnings

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image

from .nets.NNET import NNET


def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y


# load model
def load_checkpoint(fpath, model):
ckpt = torch.load(fpath, map_location="cpu")["model"]

load_dict = {}
for k, v in ckpt.items():
if k.startswith("module."):
k_ = k.replace("module.", "")
load_dict[k_] = v
else:
load_dict[k] = v

model.load_state_dict(load_dict)
return model


class NormalBaeDetector:
def __init__(self, model):
self.model = model
self.norm = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)

@classmethod
def from_pretrained(
cls, pretrained_model_or_path, filename=None, cache_dir=None
):
filename = filename or "scannet.pt"

if os.path.isdir(pretrained_model_or_path):
model_path = os.path.join(pretrained_model_or_path, filename)
else:
model_path = hf_hub_download(
pretrained_model_or_path, filename, cache_dir=cache_dir
)

args = types.SimpleNamespace()
args.mode = "client"
args.architecture = "BN"
args.pretrained = "scannet"
args.sampling_ratio = 0.4
args.importance_ratio = 0.7
model = NNET(args)
model = load_checkpoint(model_path, model)
model.eval()

return cls(model)

def to(self, device):
self.model.to(device)
return self

# def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
def __call__(self, input_image, output_type="pil", **kwargs):
if "return_pil" in kwargs:
warnings.warn(
"return_pil is deprecated. Use output_type instead.",
DeprecationWarning,
)
output_type = "pil" if kwargs["return_pil"] else "np"
if type(output_type) is bool:
warnings.warn(
"Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions"
)
if output_type:
output_type = "pil"

device = next(iter(self.model.parameters())).device
if not isinstance(input_image, np.ndarray):
input_image = np.array(input_image, dtype=np.uint8)

input_image = HWC3(input_image)
# input_image = resize_image(input_image, detect_resolution)

assert input_image.ndim == 3
image_normal = input_image
with torch.no_grad():
image_normal = torch.from_numpy(image_normal).float().to(device)
image_normal = image_normal / 255.0
image_normal = rearrange(image_normal, "h w c -> 1 c h w")
image_normal = self.norm(image_normal)

normal = self.model(image_normal)
normal = normal[0][-1][:, :3]
# d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5
# d = torch.maximum(d, torch.ones_like(d) * 1e-5)
# normal /= d
normal = ((normal + 1) * 0.5).clip(0, 1)

normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)

detected_map = normal_image
detected_map = HWC3(detected_map)

# img = resize_image(input_image, image_resolution)
H, W, C = input_image.shape

detected_map = cv2.resize(
detected_map, (W, H), interpolation=cv2.INTER_LINEAR
)

if output_type == "pil":
detected_map = Image.fromarray(detected_map)

return detected_map
22 changes: 22 additions & 0 deletions apps/stable_diffusion/src/utils/stencils/normal_bae/nets/NNET.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.encoder import Encoder
from .submodules.decoder import Decoder


class NNET(nn.Module):
def __init__(self, args):
super(NNET, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(args)

def get_1x_lr_params(self): # lr/10 learning rate
return self.encoder.parameters()

def get_10x_lr_params(self): # lr learning rate
return self.decoder.parameters()

def forward(self, img, **kwargs):
return self.decoder(self.encoder(img), **kwargs)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules.submodules import UpSampleBN, norm_normalize


# This is the baseline encoder-decoder we used in the ablation study
class NNET(nn.Module):
def __init__(self, args=None):
super(NNET, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(num_classes=4)

def forward(self, x, **kwargs):
out = self.decoder(self.encoder(x), **kwargs)

# Bilinearly upsample the output to match the input resolution
up_out = F.interpolate(
out,
size=[x.size(2), x.size(3)],
mode="bilinear",
align_corners=False,
)

# L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
up_out = norm_normalize(up_out)
return up_out

def get_1x_lr_params(self): # lr/10 learning rate
return self.encoder.parameters()

def get_10x_lr_params(self): # lr learning rate
modules = [self.decoder]
for m in modules:
yield from m.parameters()


# Encoder
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()

basemodel_name = "tf_efficientnet_b5_ap"
basemodel = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
basemodel_name,
pretrained=True,
)

# Remove last layer
basemodel.global_pool = nn.Identity()
basemodel.classifier = nn.Identity()

self.original_model = basemodel

def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items():
if k == "blocks":
for ki, vi in v._modules.items():
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
return features


# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
class Decoder(nn.Module):
def __init__(self, num_classes=4):
super(Decoder, self).__init__()
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
self.conv3 = nn.Conv2d(
128, num_classes, kernel_size=3, stride=1, padding=1
)

def forward(self, features):
x_block0, x_block1, x_block2, x_block3, x_block4 = (
features[4],
features[5],
features[6],
features[8],
features[11],
)
x_d0 = self.conv2(x_block4)
x_d1 = self.up1(x_d0, x_block3)
x_d2 = self.up2(x_d1, x_block2)
x_d3 = self.up3(x_d2, x_block1)
x_d4 = self.up4(x_d3, x_block0)
out = self.conv3(x_d4)
return out
Empty file.
Loading

0 comments on commit c39ec29

Please sign in to comment.