diff --git a/Vision Transformer/README.md b/Vision Transformer/README.md new file mode 100644 index 000000000..79a8a355f --- /dev/null +++ b/Vision Transformer/README.md @@ -0,0 +1,7 @@ +# vision-transformers-from-scratch + +![architecture](images/Vision_transforemer.png) + +
+This repo contains the code for all layers of a Vision Transformer (ViT), implemented from scratch; the accompanying `vits.ipynb` notebook trains the model on MNIST. +
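With the hyperparameters used in the notebook (28×28 grayscale input, 4×4 patches), the `nn.Conv2d` patcher turns each image into 49 patch tokens of dimension 16, and a learnable CLS token is prepended before the positional embeddings are added, giving a sequence of length 50. Below is a minimal standalone sketch of that shape arithmetic; the constant names mirror the notebook, but the script itself is only illustrative.

```python
import torch
from torch import nn

IMG_SIZE, PATCH_SIZE, IN_CHANNELS = 28, 4, 1
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS   # 4*4*1 = 16
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2   # (28 // 4)**2 = 49

# Same patcher as the notebook's PatchEmbedding: one EMBED_DIM-dimensional
# token per non-overlapping PATCH_SIZE x PATCH_SIZE patch.
patcher = nn.Conv2d(IN_CHANNELS, EMBED_DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)

x = torch.randn(2, IN_CHANNELS, IMG_SIZE, IMG_SIZE)
patches = patcher(x).flatten(2).permute(0, 2, 1)  # (2, 49, 16)
cls_token = torch.zeros(2, 1, EMBED_DIM)          # a learnable nn.Parameter in the notebook; zeros here for shapes only
tokens = torch.cat([cls_token, patches], dim=1)   # (2, 50, 16): CLS token + 49 patch tokens
print(patches.shape, tokens.shape)
```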
diff --git a/Vision Transformer/images/Vision_transforemer.png b/Vision Transformer/images/Vision_transforemer.png new file mode 100644 index 000000000..622e60f49 Binary files /dev/null and b/Vision Transformer/images/Vision_transforemer.png differ diff --git a/Vision Transformer/vits.ipynb b/Vision Transformer/vits.ipynb new file mode 100644 index 000000000..e21eeeac1 --- /dev/null +++ b/Vision Transformer/vits.ipynb @@ -0,0 +1,438 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "425cd4b8-36ab-49a3-8262-6df42f12178b", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import pandas as pd\n", + "from torch import nn\n", + "from torch import optim\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torchvision import transforms\n", + "from sklearn.model_selection import train_test_split\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import random\n", + "import timeit\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "915d570a-4178-4e50-ab78-0979003d8b40", + "metadata": {}, + "outputs": [], + "source": [ + "RANDOM_SEED = 42\n", + "BATCH_SIZE = 512\n", + "EPOCHS = 40\n", + "LEARNING_RATE = 1e-4\n", + "NUM_CLASSES = 10\n", + "PATCH_SIZE = 4\n", + "IMG_SIZE = 28\n", + "IN_CHANNELS = 1\n", + "NUM_HEADS = 8\n", + "DROPOUT = 0.001\n", + "HIDDEN_DIM = 768\n", + "ADAM_WEIGHT_DECAY = 0\n", + "ADAM_BETAS = (0.9, 0.999)\n", + "ACTIVATION=\"gelu\"\n", + "NUM_ENCODERS = 4\n", + "EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # 16\n", + "NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2 # 49\n", + "\n", + "random.seed(RANDOM_SEED)\n", + "np.random.seed(RANDOM_SEED)\n", + "torch.manual_seed(RANDOM_SEED)\n", + "torch.cuda.manual_seed(RANDOM_SEED)\n", + "torch.cuda.manual_seed_all(RANDOM_SEED)\n", + "torch.backends.cudnn.deterministic = True\n", + "torch.backends.cudnn.benchmark = False\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31a145f8-6cf8-449b-9a0c-64a56342239e", + "metadata": {}, + "outputs": [], + "source": [ + "class PatchEmbedding(nn.Module):\n", + " def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):\n", + " super().__init__()\n", + " self.patcher = nn.Sequential(\n", + " nn.Conv2d(\n", + " in_channels=in_channels,\n", + " out_channels=embed_dim,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " ), \n", + " nn.Flatten(2))\n", + "\n", + " self.cls_token = nn.Parameter(torch.randn(size=(1, in_channels, embed_dim)), requires_grad=True)\n", + " self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+1, embed_dim)), requires_grad=True)\n", + " self.dropout = nn.Dropout(p=dropout)\n", + "\n", + " def forward(self, x):\n", + " cls_token = self.cls_token.expand(x.shape[0], -1, -1)\n", + "\n", + " x = self.patcher(x).permute(0, 2, 1)\n", + " x = torch.cat([cls_token, x], dim=1)\n", + " x = self.position_embeddings + x \n", + " x = self.dropout(x)\n", + " return x\n", + " \n", + "model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)\n", + "x = torch.randn(512, 1, 28, 28).to(device)\n", + "print(model(x).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e7d778e-b831-40a2-9d8e-f057e3d64b9f", + "metadata": {}, + "outputs": [], + "source": [ + "class ViT(nn.Module):\n", + " def __init__(self, num_patches, img_size, num_classes, patch_size, 
embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):\n", + " super().__init__()\n", + " self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)\n", + " \n", + " encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation, batch_first=True, norm_first=True)\n", + " self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)\n", + "\n", + " self.mlp_head = nn.Sequential(\n", + " nn.LayerNorm(normalized_shape=embed_dim),\n", + " nn.Linear(in_features=embed_dim, out_features=num_classes)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.embeddings_block(x)\n", + " x = self.encoder_blocks(x)\n", + " x = self.mlp_head(x[:, 0, :]) # Apply MLP on the CLS token only\n", + " return x\n", + "\n", + "model = ViT(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, HIDDEN_DIM, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)\n", + "x = torch.randn(512, 1, 28, 28).to(device)\n", + "print(model(x).shape) # BATCH_SIZE X NUM_CLASSES" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c911748f-1aba-4678-86bd-b542a432c51d", + "metadata": {}, + "outputs": [], + "source": [ + "train_df = pd.read_csv(\"give the path to the dataset\")\n", + "test_df = pd.read_csv(\"give the path to the dataset\")\n", + "submission_df = pd.read_csv(\"give the path to the dataset\")\n", + "print(\"warinign the dataset should be csv , if the ram of the pc is less than 16 make the batch_size is lesser than 12\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13bb1bf0-52fb-434e-9c90-dc113670edda", + "metadata": {}, + "outputs": [], + "source": [ + "train_df.head()\n", + "test_df.head()\n", + "submission_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba0c79f6-2efb-4a4f-8f9a-bd4c11b8bf58", + "metadata": {}, + "outputs": [], + "source": [ + "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=RANDOM_SEED, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "503ed1f8-6c4a-4e72-81b2-b9d455c6ae04", + "metadata": {}, + "outputs": [], + "source": [ + "class MNISTTrainDataset(Dataset):\n", + " def __init__(self, images, labels, indicies):\n", + " self.images = images\n", + " self.labels = labels\n", + " self.indicies = indicies\n", + " self.transform = transforms.Compose([\n", + " transforms.ToPILImage(),\n", + " transforms.RandomRotation(15),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.5], [0.5])\n", + " ])\n", + " \n", + " def __len__(self):\n", + " return len(self.images)\n", + " \n", + " def __getitem__(self, idx):\n", + " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n", + " label = self.labels[idx]\n", + " index = self.indicies[idx]\n", + " image = self.transform(image)\n", + " \n", + " return {\"image\": image, \"label\": label, \"index\": index}\n", + " \n", + "class MNISTValDataset(Dataset):\n", + " def __init__(self, images, labels, indicies):\n", + " self.images = images\n", + " self.labels = labels\n", + " self.indicies = indicies\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.5], [0.5])\n", + " ])\n", + " \n", + " def __len__(self):\n", + " return len(self.images)\n", + " \n", + " def __getitem__(self, idx):\n", + " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n", + " label = 
self.labels[idx]\n", + " index = self.indicies[idx]\n", + " image = self.transform(image)\n", + " \n", + " return {\"image\": image, \"label\": label, \"index\": index}\n", + " \n", + "class MNISTSubmitDataset(Dataset):\n", + " def __init__(self, images, indicies):\n", + " self.images = images\n", + " self.indicies = indicies\n", + " self.transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize([0.5], [0.5])\n", + " ])\n", + " \n", + " def __len__(self):\n", + " return len(self.images)\n", + " \n", + " def __getitem__(self, idx):\n", + " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n", + " index = self.indicies[idx]\n", + " image = self.transform(image)\n", + " \n", + " return {\"image\": image, \"index\": index}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09241273-f6e2-472a-b0ea-bcfaa91823b4", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure()\n", + "f, axarr = plt.subplots(1, 3)\n", + "\n", + "train_dataset = MNISTTrainDataset(train_df.iloc[:, 1:].values.astype(np.uint8), train_df.iloc[:, 0].values, train_df.index.values)\n", + "print(len(train_dataset))\n", + "print(train_dataset[0])\n", + "axarr[0].imshow(train_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n", + "axarr[0].set_title(\"Train Image\")\n", + "print(\"-\"*30)\n", + "\n", + "val_dataset = MNISTValDataset(val_df.iloc[:, 1:].values.astype(np.uint8), val_df.iloc[:, 0].values, val_df.index.values)\n", + "print(len(val_dataset))\n", + "print(val_dataset[0])\n", + "axarr[1].imshow(val_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n", + "axarr[1].set_title(\"Val Image\")\n", + "print(\"-\"*30)\n", + "\n", + "test_dataset = MNISTSubmitDataset(test_df.values.astype(np.uint8), test_df.index.values)\n", + "print(len(test_dataset))\n", + "print(test_dataset[0])\n", + "axarr[2].imshow(test_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n", + "axarr[2].set_title(\"Test Image\")\n", + "print(\"-\"*30)\n", + "\n", + "plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82c6ce1d-5494-483f-8459-c8cd992d3e36", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataloader = DataLoader(dataset=train_dataset,\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True)\n", + "\n", + "val_dataloader = DataLoader(dataset=val_dataset,\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=True)\n", + "\n", + "test_dataloader = DataLoader(dataset=test_dataset,\n", + " batch_size=BATCH_SIZE,\n", + " shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2f13d32-8d7f-4c7e-833a-8fd28ca92003", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)\n", + "\n", + "start = timeit.default_timer()\n", + "for epoch in tqdm(range(EPOCHS), position=0, leave=True):\n", + " model.train()\n", + " train_labels = []\n", + " train_preds = []\n", + " train_running_loss = 0\n", + " for idx, img_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):\n", + " img = img_label[\"image\"].float().to(device)\n", + " label = img_label[\"label\"].type(torch.uint8).to(device)\n", + " y_pred = model(img)\n", + " y_pred_label = torch.argmax(y_pred, dim=1)\n", + "\n", + " train_labels.extend(label.cpu().detach())\n", + " train_preds.extend(y_pred_label.cpu().detach())\n", + " \n", + " loss = criterion(y_pred, label)\n", + " \n", + " 
optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " train_running_loss += loss.item()\n", + " train_loss = train_running_loss / (idx + 1)\n", + "\n", + " model.eval()\n", + " val_labels = []\n", + " val_preds = []\n", + " val_running_loss = 0\n", + " with torch.no_grad():\n", + " for idx, img_label in enumerate(tqdm(val_dataloader, position=0, leave=True)):\n", + " img = img_label[\"image\"].float().to(device)\n", + " label = img_label[\"label\"].type(torch.uint8).to(device) \n", + " y_pred = model(img)\n", + " y_pred_label = torch.argmax(y_pred, dim=1)\n", + " \n", + " val_labels.extend(label.cpu().detach())\n", + " val_preds.extend(y_pred_label.cpu().detach())\n", + " \n", + " loss = criterion(y_pred, label)\n", + " val_running_loss += loss.item()\n", + " val_loss = val_running_loss / (idx + 1)\n", + "\n", + " print(\"-\"*30)\n", + " print(f\"Train Loss EPOCH {epoch+1}: {train_loss:.4f}\")\n", + " print(f\"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}\")\n", + " print(f\"Train Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(train_preds, train_labels) if x == y) / len(train_labels):.4f}\")\n", + " print(f\"Valid Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(val_preds, val_labels) if x == y) / len(val_labels):.4f}\")\n", + " print(\"-\"*30)\n", + "\n", + "stop = timeit.default_timer()\n", + "print(f\"Training Time: {stop-start:.2f}s\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "349ab99f-244a-46c1-a19a-09794dd1f871", + "metadata": {}, + "outputs": [], + "source": [ + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "110223e7-1c15-4959-8a45-70a84356461a", + "metadata": {}, + "outputs": [], + "source": [ + "labels = []\n", + "ids = []\n", + "imgs = []\n", + "model.eval()\n", + "with torch.no_grad():\n", + " for idx, sample in enumerate(tqdm(test_dataloader, position=0, leave=True)):\n", + " img = sample[\"image\"].to(device)\n", + " ids.extend([int(i)+1 for i in sample[\"index\"]])\n", + " \n", + " outputs = model(img)\n", + " \n", + " imgs.extend(img.detach().cpu())\n", + " labels.extend([int(i) for i in torch.argmax(outputs, dim=1)])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9a505ab-1de2-4420-a6da-f38c0b5988a7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure()\n", + "f, axarr = plt.subplots(2, 3)\n", + "counter = 0\n", + "for i in range(2):\n", + " for j in range(3):\n", + " axarr[i][j].imshow(imgs[counter].squeeze(), cmap=\"gray\")\n", + " axarr[i][j].set_title(f\"Predicted {labels[counter]}\")\n", + " counter += 1\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b868d223-37ee-4b34-a23e-28da702cfa3f", + "metadata": {}, + "outputs": [], + "source": [ + "submission_df = pd.DataFrame(list(zip(ids, labels)),\n", + " columns =[\"ImageId\", \"Label\"])\n", + "submission_df.to_csv(\"submission.csv\", index=False)\n", + "submission_df.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
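The CSVs the notebook reads must follow the layout implied by its indexing: `train_df.iloc[:, 0]` is the digit label and `train_df.iloc[:, 1:]` are the 784 pixel values of the flattened 28×28 image, while the test CSV used by `MNISTSubmitDataset` holds pixel columns only (the final cell then writes an `ImageId`/`Label` submission file). A hedged sketch of producing compatible files from `torchvision.datasets.MNIST` follows; the output file names are placeholders, not paths used by the notebook.

```python
import pandas as pd
from torchvision.datasets import MNIST

# Download MNIST and flatten each 28x28 uint8 image into 784 pixel columns.
train_set = MNIST(root="./data", train=True, download=True)
test_set = MNIST(root="./data", train=False, download=True)

train_pixels = train_set.data.numpy().reshape(len(train_set), -1)  # (60000, 784)
train_csv = pd.DataFrame(train_pixels)
train_csv.insert(0, "label", train_set.targets.numpy())            # label first, read back via iloc[:, 0]
train_csv.to_csv("train.csv", index=False)                         # placeholder path

# The submission-style test CSV carries pixels only (no label column),
# matching MNISTSubmitDataset, which treats every column as pixel data.
test_pixels = test_set.data.numpy().reshape(len(test_set), -1)     # (10000, 784)
pd.DataFrame(test_pixels).to_csv("test.csv", index=False)          # placeholder path
```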