-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6773278
commit 7c8c1fa
Showing
42 changed files
with
7,786 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector | ||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector | ||
from apps.stable_diffusion.src.utils.stencils.normal_bae import ( | ||
NormalBaeDetector, | ||
) |
140 changes: 140 additions & 0 deletions
140
apps/stable_diffusion/src/utils/stencils/normal_bae/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import os | ||
import types | ||
import warnings | ||
|
||
import cv2 | ||
import numpy as np | ||
import torch | ||
import torchvision.transforms as transforms | ||
from einops import rearrange | ||
from huggingface_hub import hf_hub_download | ||
from PIL import Image | ||
|
||
from .nets.NNET import NNET | ||
|
||
|
||
def HWC3(x):
    """Return ``x`` as a 3-channel (H, W, 3) ``uint8`` image.

    ``x`` must be a ``uint8`` array that is either 2-D or 3-D with 1, 3 or 4
    channels on the last axis:

    * 2-D / 1 channel: the single channel is repeated three times.
    * 3 channels: ``x`` is returned unchanged (same object).
    * 4 channels: the first three channels are blended over a white
      background, using the fourth channel (scaled to [0, 1]) as alpha.
    """
    assert x.dtype == np.uint8
    if x.ndim == 2:
        # Promote a flat (H, W) image to (H, W, 1).
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    if channels == 3:
        return x
    if channels == 4:
        rgb = x[:, :, 0:3].astype(np.float32)
        opacity = x[:, :, 3:4].astype(np.float32) / 255.0
        # Alpha-composite onto white (255).
        blended = rgb * opacity + 255.0 * (1.0 - opacity)
        return blended.clip(0, 255).astype(np.uint8)
|
||
|
||
# load model | ||
def load_checkpoint(fpath, model):
    """Load the weights stored under ``"model"`` in ``fpath`` into ``model``.

    The checkpoint is read onto the CPU. Any key that starts with
    ``"module."`` (as produced, for example, when the state dict was saved
    from a wrapped/parallel model) has that leading prefix removed before the
    weights are handed to ``model.load_state_dict``.

    Args:
        fpath: Path (or file-like object) accepted by ``torch.load``.
        model: Module whose ``load_state_dict`` receives the cleaned keys.

    Returns:
        The same ``model`` object, with the weights loaded.
    """
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    ckpt = torch.load(fpath, map_location="cpu")["model"]

    # Strip only the *leading* "module." prefix. A plain str.replace would
    # also delete "module." from the middle of a key (for example
    # "module.inner.module.weight") and corrupt it.
    prefix = "module."
    load_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt.items()
    }

    model.load_state_dict(load_dict)
    return model
|
||
|
||
class NormalBaeDetector:
    """Surface-normal detector built around an ``NNET`` model.

    Calling an instance on an image runs the wrapped model and returns the
    predicted normals rendered as an 8-bit, 3-channel image with the same
    height and width as the input.
    """

    def __init__(self, model):
        """Wrap ``model`` and set up the input normalisation transform."""
        self.model = model
        # Per-channel mean/std applied to the [0, 1]-scaled input
        # (these look like the usual ImageNet statistics — confirm against
        # the model's training setup).
        self.norm = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    @classmethod
    def from_pretrained(
        cls, pretrained_model_or_path, filename=None, cache_dir=None
    ):
        """Build a detector from a local directory or a Hugging Face Hub repo.

        Args:
            pretrained_model_or_path: An existing local directory containing
                the checkpoint, otherwise a repo id passed to
                ``hf_hub_download``.
            filename: Checkpoint file name; defaults to ``"scannet.pt"``.
            cache_dir: Forwarded to ``hf_hub_download`` when downloading.

        Returns:
            A new instance of ``cls`` wrapping the loaded model in eval mode.
        """
        filename = filename or "scannet.pt"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(
                pretrained_model_or_path, filename, cache_dir=cache_dir
            )

        # Fixed configuration handed to NNET (and on to its decoder).
        args = types.SimpleNamespace()
        args.mode = "client"
        args.architecture = "BN"
        args.pretrained = "scannet"
        args.sampling_ratio = 0.4
        args.importance_ratio = 0.7
        model = NNET(args)
        model = load_checkpoint(model_path, model)
        model.eval()

        return cls(model)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self``."""
        self.model.to(device)
        return self

    def __call__(self, input_image, output_type="pil", **kwargs):
        """Estimate a normal map for ``input_image``.

        Args:
            input_image: A NumPy array, or anything ``np.array`` can convert
                to a ``uint8`` array (e.g. a PIL image). It is normalised to
                3 channels with ``HWC3``.
            output_type: ``"pil"`` to get a ``PIL.Image``; any other value
                returns a NumPy ``uint8`` array. Passing a bool is deprecated.
            **kwargs: Only the deprecated ``return_pil`` flag is consulted.

        Returns:
            The normal map as a PIL image or an (H, W, 3) ``uint8`` array,
            resized to the input's height and width.
        """
        if "return_pil" in kwargs:
            warnings.warn(
                "return_pil is deprecated. Use output_type instead.",
                DeprecationWarning,
            )
            output_type = "pil" if kwargs["return_pil"] else "np"
        if isinstance(output_type, bool):
            warnings.warn(
                "Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions"
            )
            # Map the legacy boolean onto the string form so that no bool
            # leaks into the comparison below.
            output_type = "pil" if output_type else "np"

        # Run on whatever device the model's parameters live on.
        device = next(self.model.parameters()).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)

        assert input_image.ndim == 3
        image_normal = input_image
        with torch.no_grad():
            image_normal = torch.from_numpy(image_normal).float().to(device)
            image_normal = image_normal / 255.0
            image_normal = rearrange(image_normal, "h w c -> 1 c h w")
            image_normal = self.norm(image_normal)

            normal = self.model(image_normal)
            # Take the last entry of the first output element and keep its
            # first three channels.
            normal = normal[0][-1][:, :3]
            # Shift from [-1, 1] to [0, 1] for rendering.
            normal = ((normal + 1) * 0.5).clip(0, 1)

            normal = rearrange(normal[0], "c h w -> h w c").cpu().numpy()
            normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = normal_image
        detected_map = HWC3(detected_map)

        H, W, C = input_image.shape

        # Bring the prediction back to the input's spatial size.
        detected_map = cv2.resize(
            detected_map, (W, H), interpolation=cv2.INTER_LINEAR
        )

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
22 changes: 22 additions & 0 deletions
22
apps/stable_diffusion/src/utils/stencils/normal_bae/nets/NNET.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .submodules.encoder import Encoder | ||
from .submodules.decoder import Decoder | ||
|
||
|
||
class NNET(nn.Module):
    """Encoder-decoder network: ``Encoder`` output is fed to ``Decoder``."""

    def __init__(self, args):
        """Create the encoder and a decoder configured by ``args``."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(args)

    def get_1x_lr_params(self):  # lr/10 learning rate
        """Return the encoder's parameters."""
        return self.encoder.parameters()

    def get_10x_lr_params(self):  # lr learning rate
        """Return the decoder's parameters."""
        return self.decoder.parameters()

    def forward(self, img, **kwargs):
        """Encode ``img`` and decode it, forwarding ``kwargs`` to the decoder."""
        encoded = self.encoder(img)
        return self.decoder(encoded, **kwargs)
Empty file.
102 changes: 102 additions & 0 deletions
102
apps/stable_diffusion/src/utils/stencils/normal_bae/nets/baseline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from .submodules.submodules import UpSampleBN, norm_normalize | ||
|
||
|
||
# This is the baseline encoder-decoder we used in the ablation study | ||
class NNET(nn.Module):
    """Baseline encoder-decoder whose output is upsampled to the input size."""

    def __init__(self, args=None):
        """Build the encoder and a 4-channel decoder; ``args`` is unused."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(num_classes=4)

    def forward(self, x, **kwargs):
        """Run the network on ``x`` and return the normalised, upsampled output."""
        features = self.encoder(x)
        coarse = self.decoder(features, **kwargs)

        # Bilinearly upsample the output to match the input resolution
        target_size = [x.size(2), x.size(3)]
        upsampled = F.interpolate(
            coarse,
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

        # L2-normalize the first three channels / ensure positive value for
        # concentration parameters (kappa)
        return norm_normalize(upsampled)

    def get_1x_lr_params(self):  # lr/10 learning rate
        """Return the encoder's parameters."""
        return self.encoder.parameters()

    def get_10x_lr_params(self):  # lr learning rate
        """Yield the decoder's parameters one by one."""
        yield from self.decoder.parameters()
|
||
|
||
# Encoder | ||
class Encoder(nn.Module):
    """Feature extractor that records the output of every backbone stage."""

    def __init__(self):
        """Fetch the pretrained backbone via ``torch.hub`` and drop its head."""
        super().__init__()

        backbone = torch.hub.load(
            "rwightman/gen-efficientnet-pytorch",
            "tf_efficientnet_b5_ap",
            pretrained=True,
        )

        # Remove last layer
        backbone.global_pool = nn.Identity()
        backbone.classifier = nn.Identity()

        self.original_model = backbone

    def forward(self, x):
        """Return ``[x, f1, f2, ...]``, each entry computed from the previous.

        Direct children of the backbone are applied in registration order;
        the child named ``"blocks"`` is expanded so that each of its own
        children contributes a separate entry.
        """
        features = [x]
        for name, child in self.original_model._modules.items():
            if name == "blocks":
                stages = child._modules.values()
            else:
                stages = (child,)
            for stage in stages:
                features.append(stage(features[-1]))
        return features
|
||
|
||
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling) | ||
class Decoder(nn.Module):
    """Decoder: a 1x1 conv, four skip-connected upsampling stages, a 3x3 conv."""

    def __init__(self, num_classes=4):
        """Create the layers; ``num_classes`` is the number of output channels."""
        super().__init__()
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
        self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
        self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
        self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
        self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
        self.conv3 = nn.Conv2d(
            128, num_classes, kernel_size=3, stride=1, padding=1
        )

    def forward(self, features):
        """Decode the encoder's feature list into the output map.

        Uses entries 4, 5, 6, 8 and 11 of ``features``: entry 11 is the
        starting point and the others are skip inputs, consumed from index 8
        down to index 4.
        """
        skip0, skip1, skip2, skip3, deepest = (
            features[idx] for idx in (4, 5, 6, 8, 11)
        )
        x = self.conv2(deepest)
        stages = (
            (self.up1, skip3),
            (self.up2, skip2),
            (self.up3, skip1),
            (self.up4, skip0),
        )
        for upsample, skip in stages:
            x = upsample(x, skip)
        return self.conv3(x)
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build the model defined in this file and print the output
    # shape for a random batch. The class here is NNET; the previous name,
    # `Baseline`, is not defined in this module and raised NameError.
    model = NNET()
    x = torch.rand(2, 3, 480, 640)
    out = model(x)
    print(out.shape)
Empty file.
Oops, something went wrong.