-
Notifications
You must be signed in to change notification settings - Fork 0
/
p2p.py
209 lines (181 loc) · 9.12 KB
/
p2p.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
from typing import Tuple, Union, Optional, List
import torch
import torch.nn as nn
from torch.optim.adamw import AdamW
from torch.optim.sgd import SGD
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from IPython.display import display, clear_output
# Short aliases used by the type hints in this file.
T = torch.Tensor
TN = Optional[T]  # optional tensor
TS = Union[Tuple[T, ...], List[T]]  # tuple or list of tensors

# The experiments were pinned to GPU index 5. Keep that choice when the machine
# actually has such a GPU; otherwise fall back to the default CUDA device, or to
# the CPU when CUDA is unavailable, instead of failing at the first `.to(device)`.
_PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available() and torch.cuda.device_count() > _PREFERRED_CUDA_INDEX:
    device = torch.device(f'cuda:{_PREFERRED_CUDA_INDEX}')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_512(image_path: Union[str, np.ndarray], left=0, right=0, top=0, bottom=0):
    """Load an image, trim its borders, center-crop it to a square and resize it to 512x512.

    Args:
        image_path: Path of an image file, or an already-loaded ``(H, W, C)`` array.
        left, right, top, bottom: Pixels to trim from each border before the square
            center crop. Each trim is clamped so at least one row/column survives.

    Returns:
        A 512x512 array with at most three channels (``uint8`` RGB for ordinary
        8-bit image files).
    """
    if isinstance(image_path, str):
        # Use the context manager so the file handle is closed right away.
        with Image.open(image_path) as pil_image:
            # Grayscale / palette / other modes have no 3-channel layout; convert them.
            # RGB and RGBA are read as-is (alpha is dropped by the slice below).
            if pil_image.mode not in ("RGB", "RGBA"):
                pil_image = pil_image.convert("RGB")
            image = np.array(pil_image)[:, :, :3]
    else:
        image = np.asarray(image_path)[:, :, :3]

    h, w, _ = image.shape
    # Clamp the trims so the cropped image keeps at least one row and one column.
    left = min(left, w - 1)
    right = min(right, w - left - 1)
    top = min(top, h - 1)  # vertical clamp must use the height only (was `h - left - 1`)
    bottom = min(bottom, h - top - 1)
    image = image[top:h - bottom, left:w - right]

    # Center-crop the longer side so the image becomes square.
    h, w, _ = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    return np.array(Image.fromarray(image).resize((512, 512)))
@torch.no_grad()
def get_text_embeddings(pipe: StableDiffusionPipeline, text: str) -> T:
    """Encode a single prompt with the pipeline's tokenizer and text encoder.

    The prompt is padded/truncated to the tokenizer's maximum length (77 tokens
    for the CLIP tokenizer of Stable Diffusion v1.x), and the text encoder's last
    hidden state is returned, detached, with a batch dimension of 1.
    """
    # No `return_overflowing_tokens`: its output was never used, and with fast
    # tokenizers it appends overflow rows to `input_ids` for long prompts, which
    # changes the batch size of the returned embedding.
    tokens = pipe.tokenizer(
        [text],
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    # Place the ids wherever the text encoder lives rather than on a global device.
    tokens = tokens.to(pipe.text_encoder.device)
    return pipe.text_encoder(tokens).last_hidden_state.detach()
@torch.no_grad()
def denormalize(image):
    """Turn the first image of a ``[-1, 1]``-ranged NCHW batch into an HWC ``uint8`` array.

    Values are shifted to ``[0, 1]``, clamped, scaled by 255 and truncated to
    ``uint8``. Only ``image[0]`` is converted; the rest of the batch is ignored.
    """
    first = image[0]
    unit_range = (first / 2 + 0.5).clamp(0, 1)
    hwc = unit_range.cpu().permute(1, 2, 0).numpy()
    return (hwc * 255).astype(np.uint8)
@torch.no_grad()
def decode(latent: T, pipe: StableDiffusionPipeline, im_cat: Optional[np.ndarray] = None):
    """Decode a VAE latent to a PIL image, optionally placing a reference image to its left.

    Args:
        latent: Scaled VAE latent batch; only the first decoded image is kept
            (``denormalize`` returns ``image[0]``).
        pipe: Pipeline whose VAE performs the decoding.
        im_cat: Optional HWC ``uint8`` array (e.g. the output of ``load_512``)
            concatenated to the left of the decoded image along the width axis;
            it must have the same height and channel count.

    Returns:
        A ``PIL.Image.Image``.
    """
    # Latents are kept multiplied by the VAE scaling factor (0.18215 for SD v1.x);
    # undo it before decoding. Older diffusers configs lack the field.
    scaling_factor = getattr(pipe.vae.config, "scaling_factor", 0.18215)
    image = pipe.vae.decode((1 / scaling_factor) * latent, return_dict=False)[0]
    image = denormalize(image)
    if im_cat is not None:
        image = np.concatenate((im_cat, image), axis=1)
    return Image.fromarray(image)
def init_pipe(device, dtype, unet, scheduler) -> Tuple[UNet2DConditionModel, T, T]:
    """Freeze ``unet`` and precompute the per-timestep noise-schedule coefficients.

    Returns:
        ``(unet, alphas, sigmas)`` where ``alphas = sqrt(alphas_cumprod)`` and
        ``sigmas = sqrt(1 - alphas_cumprod)``, both moved to ``device`` / ``dtype``.
        The two coefficient tensors are created under ``torch.inference_mode()``.
    """
    with torch.inference_mode():
        cumprod = scheduler.alphas_cumprod
        alphas = cumprod.sqrt().to(device, dtype=dtype)
        sigmas = (1 - cumprod).sqrt().to(device, dtype=dtype)
    # Freeze every UNet parameter so no gradients are accumulated into its weights.
    unet.requires_grad_(False)
    return unet, alphas, sigmas
class DDSLoss:
    """SDS and Delta Denoising Score (DDS) losses on top of a frozen Stable Diffusion UNet.

    Both losses compute a denoising residual without autograd and return a
    surrogate ``sum(grad * z) / (H * W)`` whose gradient with respect to the
    optimized latent ``z`` is that residual divided by ``H * W``.
    """

    def __init__(self, device, pipe: StableDiffusionPipeline, dtype=torch.float32):
        # Random timesteps are drawn from the half-open range [t_min, min(t_max, T) - 1).
        self.t_min = 50
        self.t_max = 950
        # Exponents of the alpha_t / sigma_t weighting applied to the gradient;
        # 0 disables any timestep-dependent weighting.
        self.alpha_exp = 0
        self.sigma_exp = 0
        # Weight of the source-side term in `get_dds_loss(symmetric=True)`; it used
        # to be read without ever being set. 1.0 weighs both sides equally.
        # NOTE(review): confirm the intended value.
        self.rescale = 1.0
        self.dtype = dtype
        self.unet, self.alphas, self.sigmas = init_pipe(device, dtype, pipe.unet, pipe.scheduler)
        self.prediction_type = pipe.scheduler.prediction_type

    def noise_input(self, z, eps=None, timestep: Optional[Union[int, T]] = None):
        """Diffuse ``z`` to a timestep: ``z_t = alpha_t * z + sigma_t * eps``.

        Args:
            z: Latent batch; ``z.shape[0]`` is the batch size ``b``.
            eps: Noise to add; fresh Gaussian noise when ``None``.
            timestep: Per-sample timestep tensor, a single ``int`` used for every
                sample, or ``None`` to sample one uniformly per sample.

        Returns:
            ``(z_t, eps, timestep, alpha_t, sigma_t)``. When ``timestep`` is a
            ``(b,)`` tensor, ``alpha_t`` / ``sigma_t`` have shape ``(b, 1, 1, 1)``.
        """
        b = z.shape[0]
        if timestep is None:
            timestep = torch.randint(
                low=self.t_min,
                high=min(self.t_max, self.alphas.shape[0]) - 1,  # Avoid the highest timestep.
                size=(b,),
                device=z.device, dtype=torch.long)
        elif isinstance(timestep, int):
            # A bare int would index `alphas` to a lower rank and cannot be
            # torch.cat'ed downstream; expand it to one timestep per sample.
            timestep = torch.full((b,), timestep, device=z.device, dtype=torch.long)
        if eps is None:
            eps = torch.randn_like(z)
        alpha_t = self.alphas[timestep, None, None, None]
        sigma_t = self.sigmas[timestep, None, None, None]
        z_t = alpha_t * z + sigma_t * eps
        return z_t, eps, timestep, alpha_t, sigma_t

    def get_eps_prediction(self, z_t: T, timestep: T, text_embeddings: T, alpha_t: T, sigma_t: T, get_raw=False,
                           guidance_scale=7.5):
        """Predict the noise in ``z_t`` with classifier-free guidance.

        ``text_embeddings`` is a rank-4 tensor whose dim 1 holds the unconditional
        embedding at index 0 and the conditional one at index 1. Both branches run
        in a single UNet call under float16 autocast.

        Args:
            get_raw: Return ``(e_t_uncond, e_t_cond)`` without applying guidance.
            guidance_scale: Classifier-free guidance weight.

        Returns:
            ``(e_t_uncond, e_t_cond)`` if ``get_raw`` else ``(e_t, pred_z0)`` with
            ``pred_z0 = (z_t - sigma_t * e_t) / alpha_t``.
        """
        latent_input = torch.cat([z_t] * 2)
        timestep = torch.cat([timestep] * 2)
        # (b, 2, ...) -> (2b, ...): all unconditional rows first, then all
        # conditional rows, matching the [z_t; z_t] layout of `latent_input`.
        embedd = text_embeddings.permute(1, 0, 2, 3).reshape(-1, *text_embeddings.shape[2:])
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            e_t = self.unet(latent_input, timestep, embedd).sample
            if self.prediction_type == 'v_prediction':
                # Convert a v-prediction to an epsilon-prediction: eps = alpha_t * v + sigma_t * z_t.
                e_t = torch.cat([alpha_t] * 2) * e_t + torch.cat([sigma_t] * 2) * latent_input
            e_t_uncond, e_t = e_t.chunk(2)
            if get_raw:
                return e_t_uncond, e_t
            e_t = e_t_uncond + guidance_scale * (e_t - e_t_uncond)
            assert torch.isfinite(e_t).all()
        pred_z0 = (z_t - sigma_t * e_t) / alpha_t
        return e_t, pred_z0

    def get_sds_loss(self, z: T, text_embeddings: T, eps: TN = None, mask=None, t=None,
                     timestep: Optional[Union[int, T]] = None, guidance_scale=7.5) -> TS:
        """Score Distillation Sampling loss for the latent ``z``.

        The weighted residual ``alpha_t**alpha_exp * sigma_t**sigma_exp * (e_t - eps)``
        is computed without autograd and used as a constant gradient.

        Args:
            mask: Optional multiplier applied to the gradient.
            t: Unused; kept for interface compatibility.

        Returns:
            ``(loss, log_loss)``: the surrogate loss and the mean squared gradient.
        """
        with torch.inference_mode():
            z_t, eps, timestep, alpha_t, sigma_t = self.noise_input(z, eps=eps, timestep=timestep)
            e_t, _ = self.get_eps_prediction(z_t, timestep, text_embeddings, alpha_t, sigma_t,
                                             guidance_scale=guidance_scale)
            grad_z = (alpha_t ** self.alpha_exp) * (sigma_t ** self.sigma_exp) * (e_t - eps)
            assert torch.isfinite(grad_z).all()
            grad_z = torch.nan_to_num(grad_z.detach(), 0.0, 0.0, 0.0)
            if mask is not None:
                grad_z = grad_z * mask
            log_loss = (grad_z ** 2).mean()
        # Tensors created under inference_mode cannot be saved for backward;
        # cloning outside that mode yields a regular tensor usable by autograd.
        sds_loss = grad_z.clone() * z
        del grad_z
        return sds_loss.sum() / (z.shape[2] * z.shape[3]), log_loss

    def get_dds_loss(self, z_source: T, z_target: T, text_emb_source: T, text_emb_target: T,
                     eps=None, reduction='mean', symmetric: bool = False, calibration_grad=None,
                     timestep: Optional[Union[int, T]] = None, guidance_scale=7.5, raw_log=False) -> TS:
        """Delta Denoising Score loss for ``z_target`` relative to ``z_source``.

        Both latents are noised with the same ``eps`` and timestep. The gradient is
        the weighted difference between the guided noise prediction for the target
        (with ``text_emb_target``) and for the source (with ``text_emb_source``).

        Args:
            reduction: ``'mean'`` sums the surrogate and divides by ``H * W``; any
                other value returns the unreduced surrogate (ignored if ``symmetric``).
            symmetric: Also add a source-side term driven by the negated gradient,
                weighted by ``self.rescale``.
            calibration_grad: Optional gradient subtracted from the DDS gradient;
                a 4-D tensor is used directly, anything else is indexed by
                ``timestep - t_min``.
            raw_log: Return ``(eps, eps_pred_target, eps_pred_source)`` on CPU as the
                log value instead of the mean squared gradient.

        Returns:
            ``(loss, log_loss)``.
        """
        with torch.inference_mode():
            z_t_source, eps, timestep, alpha_t, sigma_t = self.noise_input(z_source, eps, timestep)
            # Reuse the same noise and timestep so the two predictions differ only
            # through the latent and the prompt.
            z_t_target, _, _, _, _ = self.noise_input(z_target, eps, timestep)
            eps_pred, _ = self.get_eps_prediction(torch.cat((z_t_source, z_t_target)),
                                                  torch.cat((timestep, timestep)),
                                                  torch.cat((text_emb_source, text_emb_target)),
                                                  torch.cat((alpha_t, alpha_t)),
                                                  torch.cat((sigma_t, sigma_t)),
                                                  guidance_scale=guidance_scale)
            eps_pred_source, eps_pred_target = eps_pred.chunk(2)
            grad = (alpha_t ** self.alpha_exp) * (sigma_t ** self.sigma_exp) * (eps_pred_target - eps_pred_source)
            if calibration_grad is not None:
                if calibration_grad.dim() == 4:
                    grad = grad - calibration_grad
                else:
                    grad = grad - calibration_grad[timestep - self.t_min]
            if raw_log:
                log_loss = eps.detach().cpu(), eps_pred_target.detach().cpu(), eps_pred_source.detach().cpu()
            else:
                log_loss = (grad ** 2).mean()
        # `clone()` turns the inference-mode gradient into a tensor autograd can save.
        loss = z_target * grad.clone()
        pixels = z_target.shape[2] * z_target.shape[3]
        if symmetric:
            loss = loss.sum() / pixels
            loss_symm = self.rescale * z_source * (-grad.clone())
            loss += loss_symm.sum() / pixels
        elif reduction == 'mean':
            loss = loss.sum() / pixels
        return loss, log_loss
# Stable Diffusion v1.5 weights. The original `runwayml/stable-diffusion-v1-5`
# repository was removed from the Hugging Face Hub; this ID is a re-upload of the
# same checkpoint.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id).to(device)
def image_optimization(pipe: StableDiffusionPipeline, image: np.ndarray, text_source: str, text_target: str,
                       num_iters=200, use_dds=True, guidance_scale=7.5) -> None:
    """Edit ``image`` from ``text_source`` toward ``text_target`` by optimizing its VAE latent.

    The latent is optimized with SGD against the DDS loss (``use_dds=True``) or
    plain SDS on the target prompt. Every 10 iterations the decoded latent is
    displayed to the right of the input image.

    Args:
        pipe: Pipeline providing the VAE, tokenizer, text encoder and UNet.
        image: HWC array with values in ``[0, 255]`` (e.g. the output of ``load_512``).
        text_source: Prompt describing the input image.
        text_target: Prompt describing the desired edit.
        num_iters: Number of SGD steps.
        use_dds: Use DDS when True, SDS otherwise.
        guidance_scale: Classifier-free guidance weight passed to the loss.
    """
    dds_loss = DDSLoss(device, pipe)

    # Map pixels to [-1, 1] (NCHW) and encode to a scaled VAE latent.
    image_source = torch.from_numpy(image).float().permute(2, 0, 1) / 127.5 - 1
    image_source = image_source.unsqueeze(0).to(device)
    scaling_factor = getattr(pipe.vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        z_source = pipe.vae.encode(image_source)['latent_dist'].mean * scaling_factor

    embedding_null = get_text_embeddings(pipe, "")
    embedding_text_source = get_text_embeddings(pipe, text_source)
    embedding_text_target = get_text_embeddings(pipe, text_target)
    # Stack along dim 1 as [unconditional, conditional] for classifier-free guidance.
    embedding_source = torch.stack([embedding_null, embedding_text_source], dim=1)
    embedding_target = torch.stack([embedding_null, embedding_text_target], dim=1)

    # Only the target latent is optimized; it starts from the source latent.
    z_target = z_source.clone()
    z_target.requires_grad = True
    optimizer = SGD(params=[z_target], lr=1e-1)
    for i in range(num_iters):
        if use_dds:
            loss, _ = dds_loss.get_dds_loss(z_source, z_target, embedding_source, embedding_target,
                                            guidance_scale=guidance_scale)
        else:
            loss, _ = dds_loss.get_sds_loss(z_target, embedding_target, guidance_scale=guidance_scale)
        optimizer.zero_grad()
        # With plain SGD this factor simply scales the step size (lr * 2000).
        (2000 * loss).backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            out = decode(z_target, pipe, im_cat=image)
            clear_output(wait=True)
            display(out)
## Load the example photo once; `image_optimization` does not modify it in place.
image = load_512("./example_images/gnochi_mirror.jpeg")
## Baseline: turn the cat into a tiger with plain Score Distillation Sampling (SDS).
image_optimization(pipeline, image, "a photo of a cat.", "a photo of a tiger.", use_dds=False)
## Same edit with Delta Denoising Score (DDS), which subtracts the source-prompt prediction.
image_optimization(pipeline, image, "a photo of a cat.", "a photo of a tiger.", use_dds=True)