add final steps and output
mfranzon committed Oct 26, 2023
1 parent bba66a3 commit f26e012
Showing 4 changed files with 291 additions and 14 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1 +1,3 @@
-.env
+.env
+data
+MNIST
3 changes: 2 additions & 1 deletion simple_gan/README.md
@@ -23,4 +23,5 @@ There are several variants and extensions of the original GAN architecture, each

- StyleGAN and StyleGAN2: These models are known for generating highly realistic images, particularly faces. They introduce techniques like style-based architecture and progressive growing to improve image quality.

## Code and Comments

![Vanilla GAN result for MNIST dataset](output.gif)
Binary file added simple_gan/output.gif
298 changes: 286 additions & 12 deletions simple_gan/vanillaGan.ipynb
@@ -2,36 +2,41 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"%pip install torch"
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from torch import nn"
"# Configurable variables\n",
"DIR = \"./data/\"\n",
"NOISE_DIMENSION = 50\n",
"GENERATOR_OUTPUT_IMAGE_SHAPE = 28 * 28 * 1\n",
"NUM_EPOCHS = 50\n",
"BATCH_SIZE = 128\n",
"PRINT_STATS_AFTER_BATCH = 50\n",
"OPTIMIZER_LR = 0.0002\n",
"OPTIMIZER_BETAS = (0.5, 0.999)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"# Configurable variables\n",
"NOISE_DIMENSION = 50\n",
"GENERATOR_OUTPUT_IMAGE_SHAPE = 28 * 28 * 1"
"## Generator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -74,9 +79,16 @@
" return self.layers(x)"
]
},
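The generator's layer stack is collapsed in the diff view above; only its forward pass is visible. Below is a minimal sketch of a fully-connected vanilla GAN generator consistent with the constants in this notebook. The hidden sizes (128, 256) and the BatchNorm/LeakyReLU choices are assumptions, not the committed code; the Tanh output is chosen to match the `Normalize((0.5,), (0.5,))` transform used later, which scales MNIST pixels to [-1, 1].

```python
import torch
from torch import nn

NOISE_DIMENSION = 50
GENERATOR_OUTPUT_IMAGE_SHAPE = 28 * 28 * 1

class Generator(nn.Module):
    """Map a noise vector to a flattened 28x28 image in [-1, 1]."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(NOISE_DIMENSION, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, GENERATOR_OUTPUT_IMAGE_SHAPE),
            nn.Tanh(),  # match the [-1, 1] range of the normalized training images
        )

    def forward(self, x):
        """Forward pass"""
        return self.layers(x)
```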
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Discriminator"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@@ -104,6 +116,268 @@
" \"\"\"Forward pass\"\"\"\n",
" return self.layers(x)"
]
},
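The discriminator's body is likewise collapsed. Because training uses `nn.BCELoss` with labels of shape `(actual_batch_size, 1)`, the last layer must be a single sigmoid unit over the flattened 784-pixel input. A hedged sketch, with the layer widths again being assumptions rather than the committed code:

```python
class Discriminator(nn.Module):
    """Score a flattened 28x28 image as real (~1) or generated (~0)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(GENERATOR_OUTPUT_IMAGE_SHAPE, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # probability output, as nn.BCELoss requires
        )

    def forward(self, x):
        """Forward pass"""
        return self.layers(x)
```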
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utilities "
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"def generate_noise(number_of_images = 1, noise_dimension = NOISE_DIMENSION, device=None):\n",
"\n",
" return torch.randn(number_of_images, noise_dimension, device=device)\n",
"\n",
"def generate_image(generator, epoch = 0, batch = 0, device=device):\n",
"\n",
" images = []\n",
" noise = generate_noise(BATCH_SIZE, device=device)\n",
" generator.eval()\n",
" images = generator(noise)\n",
" plt.figure(figsize=(10, 10))\n",
" for i in range(16):\n",
"\n",
" image = images[i]\n",
"\n",
" image = image.cpu().detach().numpy()\n",
" image = np.reshape(image, (28, 28))\n",
"\n",
" plt.subplot(4, 4, i+1)\n",
" plt.imshow(image, cmap='gray')\n",
" plt.axis('off')\n",
" if not os.path.exists(f'./data/images'):\n",
" os.mkdir(f'./data/images')\n",
" plt.savefig(f'./data/images/epoch{epoch}_batch{batch}.jpg')\n",
"\n",
"\n",
"def save_models(generator, discriminator, epoch):\n",
"\n",
" torch.save(generator.state_dict(), f'./data/generator_{epoch}.pth')\n",
" torch.save(discriminator.state_dict(), f'./data/discriminator_{epoch}.pth')\n",
"\n",
"\n",
"def print_training_progress(batch, generator_loss, discriminator_loss):\n",
" print('Losses after mini-batch %5d: generator %e, discriminator %e' %\n",
" (batch, generator_loss, discriminator_loss))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"from torchvision.datasets import MNIST\n",
"from torchvision import transforms\n",
"from torch.utils.data import DataLoader\n",
"\n",
"def prepare_dataset():\n",
"\n",
" # MNIST dataset\n",
" dataset = MNIST(os.getcwd(), download=True, train=True, transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,))\n",
" ]))\n",
" # Batch and shuffle data with DataLoader\n",
" trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)\n",
" return trainloader"
]
},
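`ToTensor()` scales MNIST pixels to [0, 1] and `Normalize((0.5,), (0.5,))` then maps them to [-1, 1] via `(x - 0.5) / 0.5`, the same range a Tanh-output generator produces. A quick sanity check, assuming the cell above has run:

```python
# One batch should span roughly [-1, 1] and have shape (BATCH_SIZE, 1, 28, 28)
batch_images, _ = next(iter(prepare_dataset()))
print(batch_images.min().item(), batch_images.max().item())  # ~ -1.0, ~ 1.0
print(batch_images.shape)
```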
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"def initialize_models(device = device):\n",
"\n",
" generator = Generator()\n",
" discriminator = Discriminator()\n",
" # Move models to specific device\n",
" generator.to(device)\n",
" discriminator.to(device)\n",
" # Return models\n",
" return generator, discriminator\n",
"\n",
"\n",
"def initialize_loss():\n",
"\n",
" return nn.BCELoss()\n",
"\n",
"\n",
"def initialize_optimizers(generator, discriminator):\n",
"\n",
" generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=OPTIMIZER_LR,betas=OPTIMIZER_BETAS)\n",
" discriminator_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=OPTIMIZER_LR,betas=OPTIMIZER_BETAS)\n",
" return generator_optimizer, discriminator_optimizer\n",
"\n",
"def efficient_zero_grad(model):\n",
"\n",
" for param in model.parameters():\n",
" param.grad = None\n",
"\n",
"\n",
"def forward_and_backward(model, data, loss_function, targets):\n",
"\n",
" outputs = model(data)\n",
" error = loss_function(outputs, targets)\n",
" error.backward()\n",
" return error.item()"
]
},
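`efficient_zero_grad` sets each parameter's `.grad` to `None` instead of zero-filling the tensors, which skips a memset and lets the next `backward()` allocate fresh gradient buffers. On PyTorch 1.7 and later the built-in calls offer the same behaviour:

```python
# Equivalent to efficient_zero_grad(model):
optimizer.zero_grad(set_to_none=True)  # the default since PyTorch 2.0
model.zero_grad(set_to_none=True)      # same effect without an optimizer handle
```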
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generator and Discriminator Training"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def train_step(generator, discriminator, real_data, \\\n",
" loss_function, generator_optimizer, discriminator_optimizer, device = device):\n",
" \n",
" # 1. PREPARATION\n",
" # Set real and fake labels.\n",
" real_label, fake_label = 1.0, 0.0\n",
" # Get images on CPU or GPU as configured and available\n",
" # Also set 'actual batch size', whih can be smaller than BATCH_SIZE\n",
" # in some cases.\n",
" real_images = real_data[0].to(device)\n",
" actual_batch_size = real_images.size(0)\n",
" label = torch.full((actual_batch_size,1), real_label, device=device)\n",
" \n",
" # 2. TRAINING THE DISCRIMINATOR\n",
" # Zero the gradients for discriminator\n",
" efficient_zero_grad(discriminator)\n",
" # Forward + backward on real images, reshaped\n",
" real_images = real_images.view(real_images.size(0), -1)\n",
" error_real_images = forward_and_backward(discriminator, real_images, \\\n",
" loss_function, label)\n",
" # Forward + backward on generated images\n",
" noise = generate_noise(actual_batch_size, device=device)\n",
" generated_images = generator(noise)\n",
" label.fill_(fake_label)\n",
" error_generated_images =forward_and_backward(discriminator, \\\n",
" generated_images.detach(), loss_function, label)\n",
" # Optim for discriminator\n",
" discriminator_optimizer.step()\n",
" \n",
" # 3. TRAINING THE GENERATOR\n",
" # Forward + backward + optim for generator, including zero grad\n",
" efficient_zero_grad(generator)\n",
" label.fill_(real_label)\n",
" error_generator = forward_and_backward(discriminator, generated_images, loss_function, label)\n",
" generator_optimizer.step()\n",
" \n",
" # 4. COMPUTING RESULTS\n",
" # Compute loss values in floats for discriminator, which is joint loss.\n",
" error_discriminator = error_real_images + error_generated_images\n",
" # Return generator and discriminator loss so that it can be printed.\n",
" return error_generator, error_discriminator"
]
},
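The subtle point in `train_step` is `generated_images.detach()` in the discriminator update: detaching cuts the autograd graph so the discriminator's loss cannot backpropagate into the generator, while the original undetached tensor is reused in step 3, which saves a second generator forward pass. A minimal illustration of the mechanics:

```python
# Sketch: detach() gates gradient flow between the two updates
z = generate_noise(4, device=device)
fake = generator(z)                    # node in the generator's autograd graph
d_fake = discriminator(fake.detach())  # backward() stops at the detach boundary
g_fake = discriminator(fake)           # backward() reaches the generator weights
```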
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"def perform_epoch(dataloader, generator, discriminator, loss_function, \\\n",
" generator_optimizer, discriminator_optimizer, epoch):\n",
"\n",
" for batch_no, real_data in enumerate(dataloader, 0):\n",
" # Perform training step\n",
" generator_loss_val, discriminator_loss_val = train_step(generator, \\\n",
" discriminator, real_data, loss_function, \\\n",
" generator_optimizer, discriminator_optimizer)\n",
" # Print statistics and generate image after every n-th batch\n",
" if batch_no % PRINT_STATS_AFTER_BATCH == 0:\n",
" print_training_progress(batch_no, generator_loss_val, discriminator_loss_val)\n",
" generate_image(generator, epoch, batch_no)\n",
" \n",
" # Save models on epoch completion.\n",
" # save_models(generator, discriminator, epoch)\n",
" \n",
" # Clear memory after every epoch\n",
" torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train_gan():\n",
" \"\"\" \n",
" Train the DCGAN. \n",
" \"\"\"\n",
"\n",
" torch.manual_seed(42)\n",
"\n",
" dataloader = prepare_dataset()\n",
"\n",
" generator, discriminator = initialize_models()\n",
"\n",
" loss_function = initialize_loss()\n",
" generator_optimizer, discriminator_optimizer = initialize_optimizers(generator, discriminator)\n",
"\n",
" for epoch in range(NUM_EPOCHS):\n",
" print(f'Starting epoch {epoch}...')\n",
" perform_epoch(dataloader, generator, discriminator, loss_function, \\\n",
" generator_optimizer, discriminator_optimizer, epoch)\n",
"\n",
" print('Finished :-)')\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
" print(device)\n",
" train_gan()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create GIF"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"from glob import glob\n",
"\n",
"images = []\n",
"image_paths = sorted(glob(\"./data/images/*_batch100.jpg\"))\n",
"for image_path in image_paths:\n",
" img = Image.open(image_path)\n",
" images.append(img)\n",
"\n",
"output_path = \"output.gif\"\n",
"\n",
"images[0].save(output_path, save_all=True, append_images=images[1:], loop=0, duration=250)"
]
}
],
"metadata": {
