Topic recognition #469

Open
wants to merge 63 commits into
base: topic-recognition
Changes from 1 commit
Commits
d4943af
Initial commit and filling out basic project structure.
4lon Oct 11, 2022
32876fb
Added basic dataset class that loads all images in directory and retu…
4lon Oct 11, 2022
0c9dc73
Added template for VAE
4lon Oct 11, 2022
0d501b0
Corrected error in dataset class location, moved from train to dataset.
4lon Oct 11, 2022
4176432
Added Vector Quantizer
4lon Oct 14, 2022
067b1fc
Updated encoder and decoders
4lon Oct 14, 2022
d9e43b3
Fixed wrong layer errors and added vqvae model, likely need to fix pa…
4lon Oct 14, 2022
5cb7a40
Updated to use pytorch embeddings and fleshed out vector quantizer.
4lon Oct 20, 2022
662ef51
Cleanup
4lon Oct 20, 2022
d2611f7
Cleanup and redid vq_vae class to add more individual functions. This…
4lon Oct 20, 2022
eb289e9
Updated to make dataloaders in dataset.py
4lon Oct 20, 2022
dbea67a
Basis of training script setting up
4lon Oct 20, 2022
43cc875
Built out saving model and modularised functions
4lon Oct 20, 2022
6a60fba
Fixed up incorrect dataset formatting and loading of imgs
4lon Oct 20, 2022
40bb45c
Fleshed out training function
4lon Oct 20, 2022
eecfde1
Fleshed out testing function
4lon Oct 20, 2022
aa65485
Added image generation and saving to tensorboard for visual analysis
4lon Oct 20, 2022
9ef9451
Refactored modules to include original functions file and updated dat…
4lon Oct 20, 2022
670cc67
Increase HPC utilisation
4lon Oct 20, 2022
3c7ebce
add checks
4lon Oct 20, 2022
968ad8a
Increased Epochs
4lon Oct 20, 2022
cd008ee
Original pytorch attempt was not working, start of tensorflow attempt…
4lon Oct 20, 2022
eee3087
Old train model, not relevant anymore but good for looking back and n…
4lon Oct 20, 2022
65a76a4
Redesigned VQVAE modules for tensorflow. This is a simpler design ove…
4lon Oct 20, 2022
c54bd95
Reimplemented basic training method, but this one does not evaluate a…
4lon Oct 20, 2022
00490b4
Fixed some errors encountered when trying to train because of tracker…
4lon Oct 20, 2022
09ca20c
Added plotting of results including losses and reconstruction post le…
4lon Oct 21, 2022
53811cf
Updated directory
4lon Oct 21, 2022
bd443ee
Fixed missing param
4lon Oct 21, 2022
120f5e1
Train on more images
4lon Oct 21, 2022
d7400d1
Updated optional param
4lon Oct 21, 2022
07950f1
Improved plotting
4lon Oct 21, 2022
e65e287
Removed print statement
4lon Oct 21, 2022
36a4200
Fixed image formatting
4lon Oct 21, 2022
ba208a4
New dataset dtype to remove reconstruction issues.
4lon Oct 21, 2022
178c251
Reduced batch size because of memory running out in new datatype (goi…
4lon Oct 21, 2022
f6eea6d
Reduced batch size again.
4lon Oct 21, 2022
ed9a611
Changed data type because space requirement was unachievable
4lon Oct 21, 2022
d32fda4
Built basic masked convolution layer and limited dataset size (memory…
4lon Oct 21, 2022
12bbfa6
Implemented residual block for pixelcnn
4lon Oct 21, 2022
31bcdef
Implemented basic pixel cnn architecture but not sure it works.
4lon Oct 21, 2022
c0ede93
Added pixelcnn training to learn codebook production
4lon Oct 21, 2022
cd4838a
Updated plotting to also save pixel cnn performance
4lon Oct 21, 2022
3660c53
Updated training with ssim metrics
4lon Oct 21, 2022
b1e9fda
reduced dataset size
4lon Oct 21, 2022
fafcd18
reduced dataset size
4lon Oct 21, 2022
b9d4ce0
reduced dataset size
4lon Oct 21, 2022
2de0a80
reduced dataset size
4lon Oct 21, 2022
570a4b6
Added sample model outputs
4lon Oct 21, 2022
df901ca
Added better sample model outputs
4lon Oct 21, 2022
14d2ebb
Refactored training to seperate pixel cnn and vqvae in case of sepera…
4lon Oct 21, 2022
b193525
Fleshing out report
4lon Oct 21, 2022
a15a457
Fleshing out report
4lon Oct 21, 2022
19d4895
Add brain generating function but doesn't really work.
4lon Oct 21, 2022
9de9e81
Removed old attempt in pytorch and added snapshots of inference.
4lon Oct 21, 2022
3fb1dfc
Fixed up incorrect function call from refactor and add SSIM evaulatio…
4lon Oct 21, 2022
fcbbb82
Finalising report
4lon Oct 21, 2022
0a71411
Fixed reconstruction a bit
4lon Oct 21, 2022
0231d01
Merge branch 'topic-recognition' into topic-recognition
4lon Nov 24, 2022
8be7c8e
Delete recognition/44801582_OASIS_VAE/samples/pixelcnn_model directory
4lon Nov 24, 2022
67ca8b2
Delete recognition/44801582_OASIS_VAE/samples/vqvae_model directory
4lon Nov 24, 2022
949fd52
Delete pixelcnn_model_weights.h5
4lon Nov 24, 2022
308597f
Delete vqvae_model_weights.h5
4lon Nov 24, 2022
Refactored modules to include original functions file and updated dataset location.
4lon committed Oct 20, 2022
commit 9ef94512a7988795968703002a082537541a881e
6 changes: 3 additions & 3 deletions recognition/44801582_OASIS_VAE/dataset.py
@@ -28,11 +28,11 @@ def get_loaders():
     batch_size = 4
     num_workers = 1

-    train_loader = DataLoader(OASISDataset("keras_png_slices_data/keras_png_slices_train"),
+    train_loader = DataLoader(OASISDataset("data/keras_png_slices_data/keras_png_slices_train"),
                               batch_size=batch_size, num_workers=num_workers, pin_memory=True)
-    test_loader = DataLoader(OASISDataset("keras_png_slices_data/keras_png_slices_test"),
+    test_loader = DataLoader(OASISDataset("data/keras_png_slices_data/keras_png_slices_test"),
                              batch_size=batch_size, shuffle=True)
-    validation_loader = DataLoader(OASISDataset("keras_png_slices_data/keras_png_slices_validate"),
+    validation_loader = DataLoader(OASISDataset("data/keras_png_slices_data/keras_png_slices_validate"),
                                    batch_size=batch_size, drop_last=True,
                                    num_workers=num_workers, pin_memory=True)
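
The three dataset paths in this hunk share one root, so the prefix change has to be repeated for each split. A minimal sketch of deriving them from a single root follows. `SPLITS`, `split_dirs`, and the `data_root` parameter are hypothetical names used for illustration and are not part of this PR.

```python
from pathlib import Path

# Split names as they appear in the directory names used by get_loaders().
SPLITS = ("train", "test", "validate")


def split_dirs(data_root="data"):
    """Map each split name to its image directory under one shared root."""
    base = Path(data_root) / "keras_png_slices_data"
    return {split: base / f"keras_png_slices_{split}" for split in SPLITS}


if __name__ == "__main__":
    for split, path in split_dirs().items():
        print(split, path.as_posix())
```

With a helper like this, moving the dataset again would mean changing only the `data_root` argument.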

105 changes: 86 additions & 19 deletions recognition/44801582_OASIS_VAE/modules.py
@@ -4,16 +4,33 @@

 import torch
 import torch.nn as nn
-from functions import vector_quantizer, vector_quantizer_straight_through
+from torch.autograd import Function


+class ResBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.ReLU(True),
+            nn.LazyConv2d(dim, 3, 1, 1),
+            nn.LazyBatchNorm2d(),
+            nn.ReLU(True),
+            nn.LazyConv2d(dim, 1),
+            nn.LazyBatchNorm2d()
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
+
+
 def get_encoder(latent_dim=16):
     enc_model = nn.Sequential(
-        nn.LazyConv2d(32, 3, stride=2, padding=1),
-        nn.ReLU(),
-        nn.LazyConv2d(64, 3, stride=2, padding=1),
-        nn.ReLU(),
-        nn.LazyConv2d(latent_dim, 1, padding=1)
+        nn.LazyConv2d(latent_dim, 4, 2, 1),
+        nn.LazyBatchNorm2d(),
+        nn.ReLU(True),
+        nn.LazyConv2d(latent_dim, 4, 2, 1),
+        ResBlock(latent_dim),
+        ResBlock(latent_dim),
     )

     return enc_model
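
The encoder rewrite changes how the spatial size evolves through the layers. The sketch below applies the standard output-size formula for a 2D convolution with dilation 1 to the removed and added encoder layers. The 256-pixel input side is an assumption, since the image size is not stated in this diff, and `conv_out` and `run` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def run(size, layers):
    """Apply (kernel, stride, padding) layers in order; return the final side length."""
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


INPUT_SIDE = 256  # assumption: the input image size is not given in the diff

old_encoder = [(3, 2, 1), (3, 2, 1), (1, 1, 1)]  # removed conv layers
new_encoder = [(4, 2, 1), (4, 2, 1)]             # added strided conv layers
res_block = [(3, 1, 1), (1, 1, 0)]               # the two convs inside ResBlock

old_side = run(INPUT_SIDE, old_encoder)
new_side = run(INPUT_SIDE, new_encoder)
res_side = run(new_side, res_block)

if __name__ == "__main__":
    print(old_side, new_side, res_side)
```

Under that assumption the removed encoder ends at a side of 66, because its final 1x1 convolution uses `padding=1`, while the new encoder ends at 64. The two convolutions inside `ResBlock` preserve the size, which the `x + self.block(x)` skip connection requires.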
@@ -23,42 +40,92 @@ class VQ(nn.Module):
     def __init__(self, num_embeddings, embedding_dim):
         super().__init__()
         self.embed = nn.Embedding(num_embeddings, embedding_dim)
+        self.vector_quantizer = self.def_vq()
+        self.vector_quantizer_straight_through = self.def_vq_straight_through()
+
+    def def_vq(self):
+        class Quantization(Function):
+            def forward(ctx, inputs, codebook):
+                with torch.no_grad():
+                    codebook_sqr = torch.sum(codebook ** 2, dim=1)
+                    inputs_sqr = torch.sum(inputs.view(-1, codebook.size(1)) ** 2, dim=1, keepdim=True)
+
+                    distances = torch.addmm(codebook_sqr + inputs_sqr,
+                                            inputs.view(-1, codebook.size(1)), codebook.t(), alpha=-2.0, beta=1.0)
+
+                    indices = torch.min(distances, dim=1)[1].view(*inputs.size()[:-1])
+                    ctx.mark_non_differentiable(indices)
+
+                    return indices
+
+        return Quantization.apply
+
+    def def_vq_straight_through(self):
+        class QuantizationST(Function):
+            def forward(ctx, inputs, codebook):
+                indices = self.vector_quantizer(inputs, codebook).view(-1)
+                ctx.save_for_backward(indices, codebook)
+                ctx.mark_non_differentiable(indices)
+
+                codes_flatten = torch.index_select(codebook, dim=0,
+                                                   index=indices)
+                codes = codes_flatten.view_as(inputs)
+
+                return codes, indices
+
+            def backward(ctx, grad_output, grad_indices):
+                grad_inputs, grad_codebook = None, None
+
+                if ctx.needs_input_grad[0]:
+                    grad_inputs = grad_output.clone()
+                if ctx.needs_input_grad[1]:
+                    indices, codebook = ctx.saved_tensors
+
+                    grad_codebook = torch.zeros_like(codebook)
+                    grad_codebook.index_add_(0, indices, grad_output.contiguous().view(-1, codebook.size(1)))
+
+                return grad_inputs, grad_codebook
+
+        return QuantizationST.apply

     def forward(self, x):
         x = torch.permute(x, (0, 2, 3, 1)).contiguous()
-        latent = vector_quantizer(x, self.embed.weight)
+        latent = self.vector_quantizer(x, self.embed.weight)

         return latent

     def straight_through(self, x):
         x = x.permute(0, 2, 3, 1).contiguous()
-        x_quantized_straight_through, ind = vector_quantizer_straight_through(x, self.embed.weight.detach())
+        x_quantized_straight_through, ind = self.vector_quantizer_straight_through(x, self.embed.weight.detach())
         x_quantized_straight_through = x_quantized_straight_through.permute(0, 3, 1, 2).contiguous()

-        x_quantized = torch.index_select(self.embedding.weight, dim=0, index=ind)\
+        x_quantized = torch.index_select(self.embed.weight, dim=0, index=ind)\
             .view_as(x).permute(0, 3, 1, 2).contiguous()

         return x_quantized_straight_through, x_quantized


-def get_decoder():
+def get_decoder(latent_dim, final_dim):
     dec_model = nn.Sequential(
-        nn.LazyConvTranspose2d(64, 3, stride=2, padding=1),
-        nn.ReLU(),
-        nn.LazyConvTranspose2d(32, 3, stride=2, padding=1),
-        nn.ReLU(),
-        nn.LazyConvTranspose2d(1, 3, stride=2, padding=1)
+        ResBlock(latent_dim),
+        ResBlock(latent_dim),
+        nn.ReLU(True),
+        nn.LazyConvTranspose2d(latent_dim, 4, 2, 1),
+        nn.LazyBatchNorm2d(),
+        nn.ReLU(True),
+        nn.LazyConvTranspose2d(final_dim, 4, 2, 1),
+        nn.Tanh()
     )

     return dec_model


 class VQ_VAE(nn.Module):
-    def __init__(self, latent_dim=16, num_embeddings=64):
+    def __init__(self, dim=16, num_embeddings=64):
         super().__init__()
-        self.encoder = get_encoder()
-        self.codebook = VQ(latent_dim, num_embeddings)
-        self.decoder = get_decoder()
+        self.encoder = get_encoder(dim)
+        self.codebook = VQ(num_embeddings, dim)
+        self.decoder = get_decoder(dim, 1)

     def encode(self, x):
         return self.codebook(self.encoder(x))
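
The `VQ` class in this file does two things: it picks the nearest codebook row for each encoder output via the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x.e, and it copies gradients past the non-differentiable lookup. The sketch below illustrates the same idea with the `detach()` idiom instead of the custom `torch.autograd.Function` used in the diff, so it is a substitute technique and not the author's implementation. `nearest_codes` and `quantize_straight_through` are hypothetical names.

```python
import torch


def nearest_codes(inputs, codebook):
    """Index of the closest codebook row for each channels-last input vector.

    inputs:   (..., D) tensor of encoder outputs
    codebook: (K, D) tensor of embedding vectors
    """
    flat = inputs.reshape(-1, codebook.size(1))        # (N, D)
    # Squared distances for all (input, code) pairs via the expansion above.
    distances = (
        flat.pow(2).sum(dim=1, keepdim=True)           # (N, 1)
        + codebook.pow(2).sum(dim=1)                   # (K,), broadcast over rows
        - 2.0 * flat @ codebook.t()                    # (N, K)
    )
    return distances.argmin(dim=1).view(inputs.shape[:-1])


def quantize_straight_through(inputs, codebook):
    """Return (straight-through codes, plain codes, indices)."""
    indices = nearest_codes(inputs.detach(), codebook.detach())
    codes = codebook[indices]                          # differentiable w.r.t. codebook
    # Forward value equals `codes`; the gradient flows to `inputs` unchanged.
    straight_through = inputs + (codes - inputs).detach()
    return straight_through, codes, indices


if __name__ == "__main__":
    codebook = torch.randn(8, 4, requires_grad=True)
    x = torch.randn(2, 3, 3, 4, requires_grad=True)    # (B, H, W, D), channels last
    st, codes, idx = quantize_straight_through(x, codebook)
    st.sum().backward()
    print(idx.shape, x.grad.shape)
```

In a typical VQ-VAE setup the straight-through output is what gets decoded, so the reconstruction loss reaches the encoder through `inputs`, while the plain `codes` tensor can feed a codebook loss term.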
2 changes: 1 addition & 1 deletion recognition/44801582_OASIS_VAE/train.py
@@ -62,7 +62,7 @@ def main():

     train_loader, test_loader, validation_loader = get_loaders()

-    model = VQ_VAE(1, 256, 512).to(device)
+    model = VQ_VAE(256, 512).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)

     fixed_images = next(iter(test_loader))