Skip to content

Commit

Permalink
lsun
Browse files Browse the repository at this point in the history
  • Loading branch information
jph00 committed Mar 2, 2023
1 parent 127d8a3 commit 3367033
Show file tree
Hide file tree
Showing 11 changed files with 4,914 additions and 295 deletions.
41 changes: 41 additions & 0 deletions miniai/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,46 @@
'miniai.datasets.show_image': ('datasets.html#show_image', 'miniai/datasets.py'),
'miniai.datasets.show_images': ('datasets.html#show_images', 'miniai/datasets.py'),
'miniai.datasets.subplots': ('datasets.html#subplots', 'miniai/datasets.py')},
'miniai.diffusion': { 'miniai.diffusion.DownBlock': ('diffusion-attn-cond.html#downblock', 'miniai/diffusion.py'),
'miniai.diffusion.DownBlock.__init__': ( 'diffusion-attn-cond.html#downblock.__init__',
'miniai/diffusion.py'),
'miniai.diffusion.DownBlock.forward': ( 'diffusion-attn-cond.html#downblock.forward',
'miniai/diffusion.py'),
'miniai.diffusion.EmbResBlock': ('diffusion-attn-cond.html#embresblock', 'miniai/diffusion.py'),
'miniai.diffusion.EmbResBlock.__init__': ( 'diffusion-attn-cond.html#embresblock.__init__',
'miniai/diffusion.py'),
'miniai.diffusion.EmbResBlock.forward': ( 'diffusion-attn-cond.html#embresblock.forward',
'miniai/diffusion.py'),
'miniai.diffusion.EmbUNetModel': ('diffusion-attn-cond.html#embunetmodel', 'miniai/diffusion.py'),
'miniai.diffusion.EmbUNetModel.__init__': ( 'diffusion-attn-cond.html#embunetmodel.__init__',
'miniai/diffusion.py'),
'miniai.diffusion.EmbUNetModel.forward': ( 'diffusion-attn-cond.html#embunetmodel.forward',
'miniai/diffusion.py'),
'miniai.diffusion.SelfAttention': ('diffusion-attn-cond.html#selfattention', 'miniai/diffusion.py'),
'miniai.diffusion.SelfAttention.__init__': ( 'diffusion-attn-cond.html#selfattention.__init__',
'miniai/diffusion.py'),
'miniai.diffusion.SelfAttention.forward': ( 'diffusion-attn-cond.html#selfattention.forward',
'miniai/diffusion.py'),
'miniai.diffusion.SelfAttention2D': ('diffusion-attn-cond.html#selfattention2d', 'miniai/diffusion.py'),
'miniai.diffusion.SelfAttention2D.forward': ( 'diffusion-attn-cond.html#selfattention2d.forward',
'miniai/diffusion.py'),
'miniai.diffusion.UpBlock': ('diffusion-attn-cond.html#upblock', 'miniai/diffusion.py'),
'miniai.diffusion.UpBlock.__init__': ('diffusion-attn-cond.html#upblock.__init__', 'miniai/diffusion.py'),
'miniai.diffusion.UpBlock.forward': ('diffusion-attn-cond.html#upblock.forward', 'miniai/diffusion.py'),
'miniai.diffusion.abar': ('diffusion-attn-cond.html#abar', 'miniai/diffusion.py'),
'miniai.diffusion.collate_ddpm': ('diffusion-attn-cond.html#collate_ddpm', 'miniai/diffusion.py'),
'miniai.diffusion.cond_sample': ('diffusion-attn-cond.html#cond_sample', 'miniai/diffusion.py'),
'miniai.diffusion.ddim_step': ('diffusion-attn-cond.html#ddim_step', 'miniai/diffusion.py'),
'miniai.diffusion.dl_ddpm': ('diffusion-attn-cond.html#dl_ddpm', 'miniai/diffusion.py'),
'miniai.diffusion.inv_abar': ('diffusion-attn-cond.html#inv_abar', 'miniai/diffusion.py'),
'miniai.diffusion.lin': ('diffusion-attn-cond.html#lin', 'miniai/diffusion.py'),
'miniai.diffusion.noisify': ('diffusion-attn-cond.html#noisify', 'miniai/diffusion.py'),
'miniai.diffusion.pre_conv': ('diffusion-attn-cond.html#pre_conv', 'miniai/diffusion.py'),
'miniai.diffusion.sample': ('diffusion-attn-cond.html#sample', 'miniai/diffusion.py'),
'miniai.diffusion.saved': ('diffusion-attn-cond.html#saved', 'miniai/diffusion.py'),
'miniai.diffusion.timestep_embedding': ( 'diffusion-attn-cond.html#timestep_embedding',
'miniai/diffusion.py'),
'miniai.diffusion.upsample': ('diffusion-attn-cond.html#upsample', 'miniai/diffusion.py')},
'miniai.fid': { 'miniai.fid.ImageEval': ('fid.html#imageeval', 'miniai/fid.py'),
'miniai.fid.ImageEval.__init__': ('fid.html#imageeval.__init__', 'miniai/fid.py'),
'miniai.fid.ImageEval.fid': ('fid.html#imageeval.fid', 'miniai/fid.py'),
Expand All @@ -92,6 +132,7 @@
'miniai.fid._calc_stats': ('fid.html#_calc_stats', 'miniai/fid.py'),
'miniai.fid._sqrtm_newton_schulz': ('fid.html#_sqrtm_newton_schulz', 'miniai/fid.py'),
'miniai.fid._squared_mmd': ('fid.html#_squared_mmd', 'miniai/fid.py')},
'miniai.imports': {},
'miniai.init': { 'miniai.init.BatchTransformCB': ('initializing.html#batchtransformcb', 'miniai/init.py'),
'miniai.init.BatchTransformCB.__init__': ('initializing.html#batchtransformcb.__init__', 'miniai/init.py'),
'miniai.init.BatchTransformCB.before_batch': ( 'initializing.html#batchtransformcb.before_batch',
Expand Down
230 changes: 230 additions & 0 deletions miniai/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/28_diffusion-attn-cond.ipynb.

# %% auto 0
# Public API of this module — auto-generated by nbdev from the notebook's export directives.
__all__ = ['abar', 'inv_abar', 'noisify', 'collate_ddpm', 'dl_ddpm', 'timestep_embedding', 'pre_conv', 'upsample', 'lin',
           'SelfAttention', 'SelfAttention2D', 'EmbResBlock', 'saved', 'DownBlock', 'UpBlock', 'EmbUNetModel',
           'ddim_step', 'sample', 'cond_sample']

# %% ../nbs/28_diffusion-attn-cond.ipynb 3
from .imports import *

from einops import rearrange
from fastprogress import progress_bar

# %% ../nbs/28_diffusion-attn-cond.ipynb 6
def abar(t):
    "Cosine noise schedule: alpha-bar(t) = cos²(t·π/2) for t in [0, 1]."
    return (t * math.pi / 2).cos() ** 2

def inv_abar(x):
    "Inverse of `abar`: recover t from an alpha-bar value."
    return x.sqrt().acos() * 2 / math.pi

def noisify(x0):
    """Draw one random timestep per item of `x0` and noise the batch accordingly.

    Returns `((xt, t), eps)`: the noised images, the timesteps, and the noise
    target the model should predict. Assumes `x0` is a 4-D image batch
    (the schedule is broadcast as (-1, 1, 1, 1)).
    """
    dev = x0.device
    t = torch.rand(len(x0)).to(x0).clamp(0, 0.999)  # one t per item, matching x0 dtype/device
    eps = torch.randn(x0.shape, device=dev)         # noise target
    ab = abar(t).reshape(-1, 1, 1, 1).to(dev)
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * eps
    return (xt, t.to(dev)), eps

# NOTE(review): `xl` (the dataset's image key/column) and `bs` (batch size) are
# not defined anywhere in this file — they are globals from the exporting
# notebook. Confirm they are set at module scope before these are called.
def collate_ddpm(b): return noisify(default_collate(b)[xl])
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)

# %% ../nbs/28_diffusion-attn-cond.ipynb 10
def timestep_embedding(tsteps, emb_dim, max_period= 10000):
    """Sinusoidal embedding of timesteps `tsteps` into `emb_dim` dimensions.

    Frequencies span geometrically from 1 down to 1/`max_period`; the output is
    `[sin | cos]` halves, zero-padded on the right when `emb_dim` is odd.
    """
    half = emb_dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.linspace(0, 1, half, device=tsteps.device))
    args = tsteps[:, None].float() * freqs[None, :]
    out = torch.cat([args.sin(), args.cos()], dim=-1)
    if emb_dim % 2 == 1: out = F.pad(out, (0, 1, 0, 0))
    return out

# %% ../nbs/28_diffusion-attn-cond.ipynb 11
def pre_conv(ni, nf, ks=3, stride=1, act=nn.SiLU, norm=None, bias=True):
    """Pre-activation conv block: optional norm, optional activation, then conv.

    `norm` and `act` are layer *classes* (or None to skip); the conv uses
    `ks//2` padding so spatial size is preserved at stride 1.
    """
    stages = []
    if norm is not None: stages.append(norm(ni))
    if act is not None: stages.append(act())
    stages.append(nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks // 2, bias=bias))
    return nn.Sequential(*stages)

# %% ../nbs/28_diffusion-attn-cond.ipynb 12
def upsample(nf): return nn.Sequential(nn.Upsample(scale_factor=2.), nn.Conv2d(nf, nf, 3, padding=1))

# %% ../nbs/28_diffusion-attn-cond.ipynb 13
def lin(ni, nf, act=nn.SiLU, norm=None, bias=True):
    """Pre-activation linear block: optional norm, optional activation, then Linear.

    Mirrors `pre_conv` but for fully-connected layers; `norm`/`act` are layer
    classes or None to skip.
    """
    stages = []
    if norm is not None: stages.append(norm(ni))
    if act is not None: stages.append(act())
    stages.append(nn.Linear(ni, nf, bias=bias))
    return nn.Sequential(*stages)

# %% ../nbs/28_diffusion-attn-cond.ipynb 15
class SelfAttention(nn.Module):
    """Multi-head self-attention over a sequence with pre-LayerNorm.

    `attn_chans` is the number of channels per head. When `transpose` is True
    the input is channel-first `(n, c, s)` and is flipped to `(n, s, c)` for
    attention (and flipped back on output).
    """
    def __init__(self, ni, attn_chans, transpose=True):
        super().__init__()
        self.nheads = ni // attn_chans
        self.scale = math.sqrt(ni / self.nheads)  # sqrt of per-head dim, standard attention scaling
        self.norm = nn.LayerNorm(ni)
        self.qkv = nn.Linear(ni, ni * 3)
        self.proj = nn.Linear(ni, ni)
        self.t = transpose

    def forward(self, x):
        if self.t: x = x.transpose(1, 2)          # (n, c, s) -> (n, s, c)
        bs, seq, _ = x.shape
        h = self.nheads
        y = self.qkv(self.norm(x))                # (bs, seq, 3*ni)
        # split heads: (bs, seq, h*d) -> (bs*h, seq, d), d = 3*ni/h
        y = y.view(bs, seq, h, -1).transpose(1, 2).reshape(bs * h, seq, -1)
        q, k, v = torch.chunk(y, 3, dim=-1)
        attn = ((q @ k.transpose(1, 2)) / self.scale).softmax(dim=-1)
        y = attn @ v
        # merge heads back: (bs*h, seq, d) -> (bs, seq, h*d)
        y = y.view(bs, h, seq, -1).transpose(1, 2).reshape(bs, seq, -1)
        y = self.proj(y)
        return y.transpose(1, 2) if self.t else y

# %% ../nbs/28_diffusion-attn-cond.ipynb 16
class SelfAttention2D(SelfAttention):
    "Self-attention over a 2-D feature map: flattens (h, w) into one sequence axis."
    def forward(self, x):
        bs, ch, ht, wd = x.shape
        flat = x.view(bs, ch, ht * wd)            # treat every spatial position as a token
        return super().forward(flat).reshape(bs, ch, ht, wd)

# %% ../nbs/28_diffusion-attn-cond.ipynb 17
class EmbResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding via per-channel scale/shift.

    Projects the embedding `t` to a scale and shift applied between the two
    convs, adds a residual connection (1x1 conv when channel counts differ),
    and optionally follows with residual 2-D self-attention.
    """
    def __init__(self, n_emb, ni, nf=None, ks=3, act=nn.SiLU, norm=nn.BatchNorm2d, attn_chans=0):
        super().__init__()
        nf = ni if nf is None else nf
        self.emb_proj = nn.Linear(n_emb, nf * 2)   # -> concatenated [scale, shift]
        self.conv1 = pre_conv(ni, nf, ks, act=act, norm=norm)
        self.conv2 = pre_conv(nf, nf, ks, act=act, norm=norm)
        # identity path: pass-through when shapes match, else 1x1 channel projection
        self.idconv = fc.noop if ni == nf else nn.Conv2d(ni, nf, 1)
        self.attn = SelfAttention2D(nf, attn_chans) if attn_chans else False

    def forward(self, x, t):
        inp = x
        x = self.conv1(x)
        cond = self.emb_proj(F.silu(t))[:, :, None, None]
        scale, shift = torch.chunk(cond, 2, dim=1)
        x = x * (1 + scale) + shift                # embedding-conditioned modulation
        x = self.conv2(x)
        x = x + self.idconv(inp)                   # residual connection
        if self.attn: x = x + self.attn(x)         # residual attention
        return x

# %% ../nbs/28_diffusion-attn-cond.ipynb 18
def saved(m, blk):
    """Patch module `m` so each forward result is also appended to `blk.saved`.

    Used to record intermediate activations (e.g. U-Net skip connections)
    without changing `m`'s return value. Returns `m` for chaining.
    """
    orig_forward = m.forward

    @wraps(m.forward)
    def hooked(*args, **kwargs):
        out = orig_forward(*args, **kwargs)
        blk.saved.append(out)
        return out

    m.forward = hooked
    return m

# %% ../nbs/28_diffusion-attn-cond.ipynb 19
class DownBlock(nn.Module):
    """U-Net encoder stage: `num_layers` embedding-conditioned resblocks, then an
    optional stride-2 downsampling conv.

    Every resnet (and the downsample conv, when present) is wrapped with
    `saved`, so its output is appended to `self.saved` each forward — these
    activations become the skip connections consumed by the matching UpBlock.
    Note the `nn.Identity()` used when `add_down=False` is *not* wrapped, so
    the final stage records only its resnet outputs.
    """
    def __init__(self, n_emb, ni, nf, add_down=True, num_layers=1, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList([saved(EmbResBlock(n_emb, ni if i==0 else nf, nf, attn_chans=attn_chans), self)
                                      for i in range(num_layers)])
        self.down = saved(nn.Conv2d(nf, nf, 3, stride=2, padding=1), self) if add_down else nn.Identity()

    def forward(self, x, t):
        self.saved = []  # reset each forward; repopulated by the `saved` wrappers above
        for resnet in self.resnets: x = resnet(x, t)
        x = self.down(x)
        return x

# %% ../nbs/28_diffusion-attn-cond.ipynb 20
class UpBlock(nn.Module):
    """U-Net decoder stage: resblocks over `[x, skip]` concatenations, then an
    optional 2x upsample.

    Per-resnet input channels: the first resnet receives `prev_nf` channels
    from the layer below (subsequent ones receive `nf`), concatenated with a
    popped skip — `ni` channels for the *last* resnet of the stage, `nf` for
    the earlier ones (mirroring the order the encoder saved them in).
    """
    def __init__(self, n_emb, ni, prev_nf, nf, add_up=True, num_layers=2, attn_chans=0):
        super().__init__()
        self.resnets = nn.ModuleList(
            [EmbResBlock(n_emb, (prev_nf if i==0 else nf)+(ni if (i==num_layers-1) else nf), nf, attn_chans=attn_chans)
             for i in range(num_layers)])
        self.up = upsample(nf) if add_up else nn.Identity()

    def forward(self, x, t, ups):
        # `ups` is a shared stack of saved encoder activations; `pop()` consumes
        # them in reverse order of creation (deepest skips first).
        for resnet in self.resnets: x = resnet(torch.cat([x, ups.pop()], dim=1), t)
        return self.up(x)

# %% ../nbs/28_diffusion-attn-cond.ipynb 21
class EmbUNetModel(nn.Module):
    """U-Net noise-prediction model with timestep-embedding conditioning and
    optional self-attention.

    `nfs` gives the channel width of each resolution stage; `attn_start` is the
    first encoder stage index that uses attention (shallower stages pass
    `attn_chans=0`, i.e. no attention — mirrored on the decoder side).
    Input to `forward` is a tuple `(x, t)` of images and timesteps.
    """
    def __init__( self, in_channels=3, out_channels=3, nfs=(224,448,672,896), num_layers=1, attn_chans=8, attn_start=1):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, nfs[0], kernel_size=3, padding=1)
        self.n_temb = nf = nfs[0]
        n_emb = nf*4  # embedding MLP works at 4x the base width
        self.emb_mlp = nn.Sequential(lin(self.n_temb, n_emb, norm=nn.BatchNorm1d),
                                     lin(n_emb, n_emb))
        self.downs = nn.ModuleList()
        n = len(nfs)
        # Encoder: one DownBlock per stage; the last stage does not downsample.
        for i in range(n):
            ni = nf
            nf = nfs[i]
            self.downs.append(DownBlock(n_emb, ni, nf, add_down=i!=n-1, num_layers=num_layers,
                                        attn_chans=0 if i<attn_start else attn_chans))
        self.mid_block = EmbResBlock(n_emb, nfs[-1])

        # Decoder: widths walked in reverse; `ni` is the width of the encoder
        # stage whose skips this UpBlock consumes; the last stage does not upsample.
        rev_nfs = list(reversed(nfs))
        nf = rev_nfs[0]
        self.ups = nn.ModuleList()
        for i in range(n):
            prev_nf = nf
            nf = rev_nfs[i]
            ni = rev_nfs[min(i+1, len(nfs)-1)]
            self.ups.append(UpBlock(n_emb, ni, prev_nf, nf, add_up=i!=n-1, num_layers=num_layers+1,
                                    attn_chans=0 if i>=n-attn_start else attn_chans))
        self.conv_out = pre_conv(nfs[0], out_channels, act=nn.SiLU, norm=nn.BatchNorm2d, bias=False)

    def forward(self, inp):
        x,t = inp
        temb = timestep_embedding(t, self.n_temb)
        emb = self.emb_mlp(temb)
        x = self.conv_in(x)
        saved = [x]  # the skip stack: conv_in output first, then every DownBlock activation
        for block in self.downs: x = block(x, emb)
        # Each DownBlock stored its activations in `block.saved` during the loop
        # above; flatten them (in encoder order) onto the stack for the UpBlocks to pop.
        saved += [p for o in self.downs for p in o.saved]
        x = self.mid_block(x, emb)
        for block in self.ups: x = block(x, emb, saved)
        return self.conv_out(x)

# %% ../nbs/28_diffusion-attn-cond.ipynb 28
def ddim_step(x_t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta, sig):
    """One DDIM denoising step: predict x_0, then form x_{t-1}.

    `eta` scales the stochastic term (0 = fully deterministic DDIM).
    Note: the incoming `sig` argument is ignored — it is kept only for caller
    compatibility and is immediately recomputed from `eta` below.
    Returns `(x_0_hat, x_t)`: the clamped x_0 estimate and the next sample.
    """
    sig = ((bbar_t1/bbar_t).sqrt() * (1-abar_t/abar_t1).sqrt()) * eta
    # Estimate x_0 from the current sample and the model's noise prediction.
    x_0_hat = ((x_t-(1-abar_t).sqrt()*noise) / abar_t.sqrt()).clamp(-1,1)
    # Zero the stochastic term when bbar_t1 is very small OR sig is NaN.
    # `not (a > b)` is used instead of `a <= b` because any comparison with a
    # NaN is False: the original `bbar_t1 <= sig**2+0.01` let NaN sigmas
    # through despite the comment claiming otherwise.
    if not (bbar_t1 > sig**2+0.01): sig = 0.  # set to zero if very small or NaN
    x_t = abar_t1.sqrt()*x_0_hat + (bbar_t1-sig**2).sqrt()*noise
    x_t += sig * torch.randn(x_t.shape).to(x_t)
    return x_0_hat,x_t

# %% ../nbs/28_diffusion-attn-cond.ipynb 29
@torch.no_grad()
def sample(f, model, sz, steps, eta=1.):
    """Unconditional sampling loop: run `steps` denoising steps from pure noise.

    `f` is the step function (e.g. `ddim_step`); `model` maps `(x_t, t)` to a
    noise prediction; `sz` is the batch shape to generate. Returns the list of
    per-step x_0 estimates (on CPU). Requires a CUDA device (`.cuda()` calls).
    """
    model.eval()
    ts = torch.linspace(1-1/steps,0,steps)  # timesteps from near 1 down to 0
    x_t = torch.randn(sz).cuda()
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t))
        # alpha-bar at the next (earlier) timestep; defined as 1 past the final step
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        # NOTE(review): the last argument assumes steps==100, but `ddim_step`
        # ignores it (sigma is recomputed from `eta`) — confirm before reusing
        # with a step function that does read it.
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100))
        preds.append(x_0_hat.float().cpu())
    return preds

# %% ../nbs/28_diffusion-attn-cond.ipynb 43
@torch.no_grad()
def cond_sample(c, f, model, sz, steps, eta=1.):
    """Class-conditional sampling loop: like `sample`, but feeds a class label.

    `c` is an integer class label applied to the whole batch; `f` is the step
    function (e.g. `ddim_step`); `model` maps `(x_t, t, c)` to a noise
    prediction; `sz` is the batch shape. Returns the list of per-step x_0
    estimates (on CPU). Requires a CUDA device (`.cuda()` calls).
    """
    model.eval()  # consistency fix: `sample` switches to eval mode, this loop previously did not
    ts = torch.linspace(1-1/steps,0,steps)  # timesteps from near 1 down to 0
    x_t = torch.randn(sz).cuda()
    c = x_t.new_full((sz[0],), c, dtype=torch.int32)  # broadcast the label over the batch
    preds = []
    for i,t in enumerate(progress_bar(ts)):
        t = t[None].cuda()
        abar_t = abar(t)
        noise = model((x_t, t, c))
        # alpha-bar at the next (earlier) timestep; defined as 1 past the final step
        abar_t1 = abar(t-1/steps) if t>=1/steps else torch.tensor(1)
        # NOTE(review): the last argument assumes steps==100, but `ddim_step`
        # ignores it (sigma is recomputed from `eta`).
        x_0_hat,x_t = f(x_t, noise, abar_t, abar_t1, 1-abar_t, 1-abar_t1, eta, 1-((i+1)/100))
        preds.append(x_0_hat.float().cpu())
    return preds
27 changes: 27 additions & 0 deletions miniai/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch, random, datasets, math, fastcore.all as fc, numpy as np, matplotlib as mpl, matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision.transforms.functional as TF,torch.nn.functional as F

from torch.utils.data import DataLoader,default_collate
from pathlib import Path
from torch import nn,tensor
from torch.nn import init
from fastcore.foundation import L
from datasets import load_dataset
from operator import itemgetter,attrgetter
from functools import partial,wraps
from torch.optim import lr_scheduler
from torch import optim
from torchvision.io import read_image,ImageReadMode

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *
from miniai.training import *

3 changes: 2 additions & 1 deletion miniai/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def to_cpu(x):
if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}
if isinstance(x, list): return [to_cpu(o) for o in x]
if isinstance(x, tuple): return tuple(to_cpu(list(x)))
return x.detach().cpu()
res = x.detach().cpu()
return res.float() if res.dtype==torch.float16 else res

# %% ../nbs/09_learner.ipynb 35
class MetricsCB(Callback):
Expand Down
9 changes: 5 additions & 4 deletions nbs/09_learner.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6ca050462ee540518c7028378f76b22d",
"model_id": "e7b7e934e9d347f580ec07a2e01d5883",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -585,7 +585,8 @@
" if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()}\n",
" if isinstance(x, list): return [to_cpu(o) for o in x]\n",
" if isinstance(x, tuple): return tuple(to_cpu(list(x)))\n",
" return x.detach().cpu()"
" res = x.detach().cpu()\n",
" return res.float() if res.dtype==torch.float16 else res"
]
},
{
Expand Down Expand Up @@ -644,8 +645,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'accuracy': '0.602', 'loss': '1.183', 'epoch': 0, 'train': 'train'}\n",
"{'accuracy': '0.700', 'loss': '0.847', 'epoch': 0, 'train': 'eval'}\n"
"{'accuracy': '0.613', 'loss': '1.152', 'epoch': 0, 'train': 'train'}\n",
"{'accuracy': '0.679', 'loss': '0.814', 'epoch': 0, 'train': 'eval'}\n"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion nbs/27_attention.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "python3",
"language": "python",
"name": "python3"
}
Expand Down
Loading

0 comments on commit 3367033

Please sign in to comment.