forked from fastai/course22p2
Showing 11 changed files with 4,914 additions and 295 deletions.
@@ -0,0 +1,230 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/28_diffusion-attn-cond.ipynb.

# %% auto 0
__all__ = ['abar', 'inv_abar', 'noisify', 'collate_ddpm', 'dl_ddpm', 'timestep_embedding', 'pre_conv', 'upsample', 'lin',
           'SelfAttention', 'SelfAttention2D', 'EmbResBlock', 'saved', 'DownBlock', 'UpBlock', 'EmbUNetModel',
           'ddim_step', 'sample', 'cond_sample']

# %% ../nbs/28_diffusion-attn-cond.ipynb 3
from .imports import *

from einops import rearrange
from fastprogress import progress_bar

# %% ../nbs/28_diffusion-attn-cond.ipynb 6
# Cosine noise schedule: abar(t) is the cumulative signal level at continuous time t in [0, 1); inv_abar is its inverse.
def abar(t): return (t*math.pi/2).cos()**2
def inv_abar(x): return x.sqrt().acos()*2/math.pi

# Draw one random timestep per image and mix in Gaussian noise according to the schedule.
def noisify(x0):
    device = x0.device
    n = len(x0)
    t = torch.rand(n,).to(x0).clamp(0,0.999)
    ε = torch.randn(x0.shape, device=device)
    abar_t = abar(t).reshape(-1, 1, 1, 1).to(device)
    xt = abar_t.sqrt()*x0 + (1-abar_t).sqrt()*ε
    return (xt, t.to(device)), ε

# `xl` (the batch key) and `bs` (the batch size) are free globals here; the exporting notebook
# is expected to define them before these functions are used.
def collate_ddpm(b): return noisify(default_collate(b)[xl])
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)
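As an illustration of the data pipeline above (not part of the exported module): `noisify` takes a batch of images scaled to roughly [-1, 1], picks a random timestep per image, and blends in Gaussian noise according to `abar`; the tuple it returns is the model input and the noise target. The tensor sizes below are hypothetical.

x0 = torch.rand(4, 1, 32, 32)*2 - 1   # hypothetical batch of four 1x32x32 images in [-1, 1]
(xt, t), eps = noisify(x0)
print(xt.shape, t.shape, eps.shape)   # torch.Size([4, 1, 32, 32]) torch.Size([4]) torch.Size([4, 1, 32, 32])
# xt = sqrt(abar(t))*x0 + sqrt(1-abar(t))*eps, so t near 0 is almost clean and t near 1 is almost pure noise.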
# %% ../nbs/28_diffusion-attn-cond.ipynb 10 | ||
def timestep_embedding(tsteps, emb_dim, max_period= 10000): | ||
exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim//2, device=tsteps.device) | ||
emb = tsteps[:,None].float() * exponent.exp()[None,:] | ||
emb = torch.cat([emb.sin(), emb.cos()], dim=-1) | ||
return F.pad(emb, (0,1,0,0)) if emb_dim%2==1 else emb | ||
|
||
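A quick sketch of what `timestep_embedding` produces, assuming only the definition above: each scalar timestep becomes a vector of sines and cosines at geometrically spaced frequencies, which `EmbUNetModel` below then projects through a small MLP.

t = torch.rand(8)                 # eight hypothetical timesteps in [0, 1)
emb = timestep_embedding(t, 224)  # 224 matches nfs[0] in the default EmbUNetModel below
print(emb.shape)                  # torch.Size([8, 224])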
# %% ../nbs/28_diffusion-attn-cond.ipynb 11
# Pre-activation conv: norm -> activation -> conv.
def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias))
    return layers

# %% ../nbs/28_diffusion-attn-cond.ipynb 12
def upsample(nf): return nn.Sequential(nn.Upsample(scale_factor=2.), nn.Conv2d(nf, nf, 3, padding=1))

# %% ../nbs/28_diffusion-attn-cond.ipynb 13
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    layers = nn.Sequential()
    if norm: layers.append(norm(ni))
    if act : layers.append(act())
    layers.append(nn.Linear(ni, nf, bias=bias))
    return layers

# %% ../nbs/28_diffusion-attn-cond.ipynb 15
# Multi-head self-attention over the sequence dimension, with `attn_chans` channels per head.
class SelfAttention(nn.Module):
    def __init__(self, ni, attn_chans, transpose=True):
        super().__init__()
        self.nheads = ni//attn_chans
        self.scale = math.sqrt(ni/self.nheads)
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, ni*3)
        self.proj = nn.Linear(ni, ni)
        self.t = transpose

    def forward(self, x):
        n,c,s = x.shape
        if self.t: x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.qkv(x)
        x = rearrange(x, 'n s (h d) -> (n h) s d', h=self.nheads)
        q,k,v = torch.chunk(x, 3, dim=-1)
        s = (q@k.transpose(1,2))/self.scale
        x = s.softmax(dim=-1)@v
        x = rearrange(x, '(n h) s d -> n s (h d)', h=self.nheads)
        x = self.proj(x)
        if self.t: x = x.transpose(1, 2)
        return x

# %% ../nbs/28_diffusion-attn-cond.ipynb 16
# Self-attention over the flattened spatial dims of an NCHW feature map.
class SelfAttention2D(SelfAttention):
    def forward(self, x):
        n,c,h,w = x.shape
        return super().forward(x.view(n, c, -1)).reshape(n,c,h,w)
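A minimal usage sketch for `SelfAttention2D`, with made-up sizes: the feature map is flattened to a sequence of h*w positions, attended over with ni//attn_chans heads, and reshaped back, so the output shape matches the input.

sa = SelfAttention2D(32, attn_chans=8)  # 32 channels, 8 channels per head -> 4 heads
x = torch.randn(2, 32, 16, 16)
print(sa(x).shape)                      # torch.Size([2, 32, 16, 16])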
# %% ../nbs/28_diffusion-attn-cond.ipynb 17
# ResBlock conditioned on the timestep embedding via a learned per-channel scale/shift,
# with optional self-attention on the output.
class EmbResBlock(nn.Module):
    def __init__(self, n_emb, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_chans=0):
        super().__init__()
        if nf is None: nf = ni
        self.emb_proj = nn.Linear(n_emb, nf*2)
        self.conv1 = pre_conv(ni, nf, ks, act=act, norm=norm)
        self.conv2 = pre_conv(nf, nf, ks, act=act, norm=norm)
        self.idconv = fc.noop if ni==nf else nn.Conv2d(ni, nf, 1)
        self.attn = False
        if attn_chans: self.attn = SelfAttention2D(nf, attn_chans)

    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        emb = self.emb_proj(F.silu(t))[:, :, None, None]
        scale,shift = torch.chunk(emb, 2, dim=1)
        x = x*(1+scale) + shift
        x = self.conv2(x)
        x = x + self.idconv(inp)
        if self.attn: x = x + self.attn(x)
        return x
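A sketch of how `EmbResBlock` consumes the timestep embedding, with hypothetical sizes: the embedding is projected to a scale and shift that modulate the features between the two convolutions, and the result is added to a (possibly 1x1-projected) identity path.

blk = EmbResBlock(n_emb=896, ni=32, nf=64, attn_chans=8)
x = torch.randn(2, 32, 16, 16)          # feature map
t = torch.randn(2, 896)                 # timestep embedding after the model's MLP
print(blk(x, t).shape)                  # torch.Size([2, 64, 16, 16])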
# %% ../nbs/28_diffusion-attn-cond.ipynb 18
# Patch `m.forward` so every output is also appended to `blk.saved` (used for U-Net skip connections).
def saved(m, blk):
    m_ = m.forward

    @wraps(m.forward)
    def _f(*args, **kwargs):
        res = m_(*args, **kwargs)
        blk.saved.append(res)
        return res

    m.forward = _f
    return m
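To make the skip-connection bookkeeping concrete, here is a small hypothetical example of `saved` (the `Holder` class is invented for illustration): the wrapped layer still behaves normally, but every output it produces is also stashed on the owning block's `saved` list, which is exactly how `DownBlock` below collects activations for the decoder.

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = saved(nn.Conv2d(3, 8, 3, padding=1), self)
    def forward(self, x):
        self.saved = []                 # reset before each pass, as DownBlock does
        return self.layer(x)

h = Holder()
y = h(torch.randn(1, 3, 8, 8))
assert h.saved[0] is y                  # the output was recorded as a side effect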
# %% ../nbs/28_diffusion-attn-cond.ipynb 19
class DownBlock(nn.Module):
    def __init__(self, n_emb, ni, nf, add_down=True, num_layers=1, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList([saved(EmbResBlock(n_emb, ni if i==0 else nf, nf, attn_chans=attn_chans), self)
                                      for i in range(num_layers)])
        self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self) if add_down else nn.Identity()

    def forward(self, x, t):
        self.saved = []
        for resnet in self.resnets: x = resnet(x, t)
        x = self.down(x)
        return x

# %% ../nbs/28_diffusion-attn-cond.ipynb 20
class UpBlock(nn.Module):
    def __init__(self, n_emb, ni, prev_nf, nf, add_up=True, num_layers=2, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, (prev_nf if i==0 else nf)+(ni if (i==num_layers-1) else nf), nf, attn_chans=attn_chans)
             for i in range(num_layers)])
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        for resnet in self.resnets: x = resnet(torch.cat([x, ups.pop()], dim=1), t)
        return self.up(x)
# %% ../nbs/28_diffusion-attn-cond.ipynb 21
class EmbUNetModel(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1, attn_chans=8, attn_start=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        n = len(nfs)
        for i in range(n):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=n-1, num_layers=num_layers,
                                        attn_chans=0 if i<attn_start else attn_chans))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(n):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=n-1, num_layers=num_layers+1,
                                    attn_chans=0 if i>=n-attn_start else attn_chans))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.emb_mlp(temb)
        x = self.conv_in(x)
        saved = [x]
        for block in self.downs: x = block(x, emb)
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)
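A small end-to-end sketch of the training forward pass, using deliberately reduced channel counts (the defaults are nfs=(224,448,672,896)); everything else comes from the definitions above. The model receives the noised batch and timesteps from `noisify` and predicts the noise, which would typically be compared to `eps` with an MSE loss.

model = EmbUNetModel(in_channels=1, out_channels=1, nfs=(32,64,128,256), num_layers=1)
x0 = torch.rand(2, 1, 32, 32)*2 - 1     # hypothetical mini-batch
(xt, t), eps = noisify(x0)
pred = model((xt, t))
print(pred.shape)                       # torch.Size([2, 1, 32, 32])
loss = F.mse_loss(pred, eps)            # noise-prediction objective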
# %% ../nbs/28_diffusion-attn-cond.ipynb 28
def ddim_step(x_t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta, sig):
    # Note: the incoming `sig` is immediately recomputed from the schedule, so the caller's value is unused.
    sig = ((bbar_t1/bbar_t).sqrt() * (1-abar_t/abar_t1).sqrt()) * eta
    x_0_hat = ((x_t-(1-abar_t).sqrt()*noise) / abar_t.sqrt()).clamp(-1,1)
    if bbar_t1<=sig**2+0.01: sig=0.  # set to zero if very small or NaN
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1-sig**2).sqrt()*noise
    x_t += sig * torch.randn(x_t.shape).to(x_t)
    return x_0_hat,x_t
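In the function's own variables, `ddim_step` computes sig = eta*sqrt(bbar_t1/bbar_t)*sqrt(1-abar_t/abar_t1), estimates x_0_hat = (x_t - sqrt(1-abar_t)*noise)/sqrt(abar_t), and forms the next sample as sqrt(abar_t1)*x_0_hat + sqrt(bbar_t1-sig^2)*noise plus sig-scaled fresh noise. A single hypothetical step, runnable with just the definitions above:

x_t = torch.randn(1, 1, 32, 32)
noise = torch.randn_like(x_t)           # stand-in for a model prediction
t, t1 = torch.tensor(0.50), torch.tensor(0.45)
x0_hat, x_next = ddim_step(x_t, noise, abar(t), abar(t1), 1-abar(t), 1-abar(t1), eta=1., sig=0.)
print(x0_hat.shape, x_next.shape)       # both torch.Size([1, 1, 32, 32])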
# %% ../nbs/28_diffusion-attn-cond.ipynb 29
@torch.no_grad()
def sample(f, model, sz, steps, eta=1.):
    model.eval()
    ts = torch.linspace(1-1/steps,0,steps)
    x_t = torch.randn(sz).cuda()
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t))
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100))
        preds.append(x_0_hat.float().cpu())
    return preds
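A hypothetical sampling call, assuming a CUDA device and a `model` trained with the pipeline above: `sample` walks the timesteps from 1-1/steps down to 0, applies `ddim_step` at each one, and returns the list of intermediate x_0_hat estimates, the last of which is the final image batch.

preds = sample(ddim_step, model, sz=(16, 1, 32, 32), steps=100, eta=1.)
imgs = preds[-1]                        # final denoised batch, already moved to the CPU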
# %% ../nbs/28_diffusion-attn-cond.ipynb 43
@torch.no_grad()
def cond_sample(c, f, model, sz, steps, eta=1.):
    ts = torch.linspace(1-1/steps,0,steps)
    x_t = torch.randn(sz).cuda()
    c = x_t.new_full((sz[0],), c, dtype=torch.int32)
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t, c))
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100))
        preds.append(x_0_hat.float().cpu())
    return preds
@@ -0,0 +1,27 @@
import torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision.transforms.functional as TF, torch.nn.functional as F

from torch.utils.data import DataLoader, default_collate
from pathlib import Path
from torch import nn, tensor
from torch.nn import init
from fastcore.foundation import L
from datasets import load_dataset
from operator import itemgetter, attrgetter
from functools import partial, wraps
from torch.optim import lr_scheduler
from torch import optim
from torchvision.io import read_image, ImageReadMode

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *
from miniai.training import *