diff --git a/Vision Transformer/README.md b/Vision Transformer/README.md
new file mode 100644
index 000000000..79a8a355f
--- /dev/null
+++ b/Vision Transformer/README.md
@@ -0,0 +1,7 @@
+# vision-transformers-from-scratch
+
+![Vision Transformer architecture](images/Vision_transforemer.png)
+
+
+This repo contains the code for all the layers of a Vision Transformer (ViT): the notebook `vits.ipynb` builds the patch embedding, stacks transformer encoder blocks, adds a classification head, and trains the model on 28x28 grayscale digit images loaded from CSV files.
+
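+A minimal usage sketch, assuming the `ViT` class from `vits.ipynb` has already been defined (e.g., by running the notebook cells); the hyperparameter values below mirror the notebook's defaults, so adjust them for your own data:
+
+```python
+import torch
+
+# 28x28 grayscale images split into 4x4 patches -> 49 patches, embed_dim = 4*4*1 = 16
+model = ViT(
+    num_patches=49, img_size=28, num_classes=10, patch_size=4,
+    embed_dim=16, num_encoders=4, num_heads=8, hidden_dim=768,
+    dropout=0.001, activation="gelu", in_channels=1,
+)
+x = torch.randn(8, 1, 28, 28)  # a dummy batch of 8 images
+print(model(x).shape)          # torch.Size([8, 10])
+```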
diff --git a/Vision Transformer/images/Vision_transforemer.png b/Vision Transformer/images/Vision_transforemer.png
new file mode 100644
index 000000000..622e60f49
Binary files /dev/null and b/Vision Transformer/images/Vision_transforemer.png differ
diff --git a/Vision Transformer/vits.ipynb b/Vision Transformer/vits.ipynb
new file mode 100644
index 000000000..e21eeeac1
--- /dev/null
+++ b/Vision Transformer/vits.ipynb
@@ -0,0 +1,438 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "425cd4b8-36ab-49a3-8262-6df42f12178b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import pandas as pd\n",
+ "from torch import nn\n",
+ "from torch import optim\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from torchvision import transforms\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import random\n",
+ "import timeit\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "915d570a-4178-4e50-ab78-0979003d8b40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "RANDOM_SEED = 42\n",
+ "BATCH_SIZE = 512\n",
+ "EPOCHS = 40\n",
+ "LEARNING_RATE = 1e-4\n",
+ "NUM_CLASSES = 10\n",
+ "PATCH_SIZE = 4\n",
+ "IMG_SIZE = 28\n",
+ "IN_CHANNELS = 1\n",
+ "NUM_HEADS = 8\n",
+ "DROPOUT = 0.001\n",
+ "HIDDEN_DIM = 768\n",
+ "ADAM_WEIGHT_DECAY = 0\n",
+ "ADAM_BETAS = (0.9, 0.999)\n",
+    "ACTIVATION = \"gelu\"\n",
+ "NUM_ENCODERS = 4\n",
+ "EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # 16\n",
+ "NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2 # 49\n",
+ "\n",
+ "random.seed(RANDOM_SEED)\n",
+ "np.random.seed(RANDOM_SEED)\n",
+ "torch.manual_seed(RANDOM_SEED)\n",
+ "torch.cuda.manual_seed(RANDOM_SEED)\n",
+ "torch.cuda.manual_seed_all(RANDOM_SEED)\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "torch.backends.cudnn.benchmark = False\n",
+ "\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "31a145f8-6cf8-449b-9a0c-64a56342239e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class PatchEmbedding(nn.Module):\n",
+ " def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):\n",
+ " super().__init__()\n",
+ " self.patcher = nn.Sequential(\n",
+ " nn.Conv2d(\n",
+ " in_channels=in_channels,\n",
+ " out_channels=embed_dim,\n",
+ " kernel_size=patch_size,\n",
+ " stride=patch_size,\n",
+ " ), \n",
+ " nn.Flatten(2))\n",
+ "\n",
+    "        # The CLS token is a single learnable embedding (1 x 1 x embed_dim), independent of the number of input channels\n",
+    "        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)\n",
+ " self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+1, embed_dim)), requires_grad=True)\n",
+ " self.dropout = nn.Dropout(p=dropout)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " cls_token = self.cls_token.expand(x.shape[0], -1, -1)\n",
+ "\n",
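+    "        # (B, C, H, W) -> Conv2d patchify -> (B, E, H/P, W/P) -> Flatten(2) -> (B, E, N) -> permute -> (B, N, E)\n",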
+ " x = self.patcher(x).permute(0, 2, 1)\n",
+ " x = torch.cat([cls_token, x], dim=1)\n",
+ " x = self.position_embeddings + x \n",
+ " x = self.dropout(x)\n",
+ " return x\n",
+ " \n",
+ "model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)\n",
+ "x = torch.randn(512, 1, 28, 28).to(device)\n",
+    "print(model(x).shape)  # torch.Size([512, 50, 16]): 49 patches + 1 CLS token, each a 16-dim embedding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e7d778e-b831-40a2-9d8e-f057e3d64b9f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ViT(nn.Module):\n",
+ " def __init__(self, num_patches, img_size, num_classes, patch_size, embed_dim, num_encoders, num_heads, hidden_dim, dropout, activation, in_channels):\n",
+ " super().__init__()\n",
+ " self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)\n",
+ " \n",
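+    "        # Pre-LayerNorm encoder blocks (norm_first=True); hidden_dim sets the width of the feed-forward MLP inside each block\n",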
+    "        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, activation=activation, batch_first=True, norm_first=True)\n",
+ " self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)\n",
+ "\n",
+ " self.mlp_head = nn.Sequential(\n",
+ " nn.LayerNorm(normalized_shape=embed_dim),\n",
+ " nn.Linear(in_features=embed_dim, out_features=num_classes)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.embeddings_block(x)\n",
+ " x = self.encoder_blocks(x)\n",
+ " x = self.mlp_head(x[:, 0, :]) # Apply MLP on the CLS token only\n",
+ " return x\n",
+ "\n",
+ "model = ViT(NUM_PATCHES, IMG_SIZE, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, HIDDEN_DIM, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)\n",
+ "x = torch.randn(512, 1, 28, 28).to(device)\n",
+ "print(model(x).shape) # BATCH_SIZE X NUM_CLASSES"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c911748f-1aba-4678-86bd-b542a432c51d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "train_df = pd.read_csv(\"path/to/train.csv\")  # TODO: set the path to the training CSV\n",
+    "test_df = pd.read_csv(\"path/to/test.csv\")  # TODO: set the path to the test CSV\n",
+    "submission_df = pd.read_csv(\"path/to/sample_submission.csv\")  # TODO: set the path to the sample submission CSV\n",
+    "print(\"Warning: the dataset files must be CSV. If your machine has less than 16 GB of RAM, use a BATCH_SIZE lower than 12.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "13bb1bf0-52fb-434e-9c90-dc113670edda",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# display() is needed so all three previews render, not only the cell's last expression\n",
+    "display(train_df.head())\n",
+    "display(test_df.head())\n",
+    "display(submission_df.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ba0c79f6-2efb-4a4f-8f9a-bd4c11b8bf58",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=RANDOM_SEED, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "503ed1f8-6c4a-4e72-81b2-b9d455c6ae04",
+ "metadata": {},
+ "outputs": [],
+ "source": [
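+    "# Each CSV row stores a flattened 28x28 grayscale image; in the training CSV the first column is the label\n",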
+ "class MNISTTrainDataset(Dataset):\n",
+    "    def __init__(self, images, labels, indices):\n",
+ " self.images = images\n",
+ " self.labels = labels\n",
+    "        self.indices = indices\n",
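+    "        # Light augmentation (random rotation up to 15 degrees) for training; the val/test datasets only normalize\n",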
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToPILImage(),\n",
+ " transforms.RandomRotation(15),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.5], [0.5])\n",
+ " ])\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.images)\n",
+ " \n",
+ " def __getitem__(self, idx):\n",
+ " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n",
+ " label = self.labels[idx]\n",
+    "        index = self.indices[idx]\n",
+ " image = self.transform(image)\n",
+ " \n",
+ " return {\"image\": image, \"label\": label, \"index\": index}\n",
+ " \n",
+ "class MNISTValDataset(Dataset):\n",
+    "    def __init__(self, images, labels, indices):\n",
+ " self.images = images\n",
+ " self.labels = labels\n",
+    "        self.indices = indices\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.5], [0.5])\n",
+ " ])\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.images)\n",
+ " \n",
+ " def __getitem__(self, idx):\n",
+ " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n",
+ " label = self.labels[idx]\n",
+    "        index = self.indices[idx]\n",
+ " image = self.transform(image)\n",
+ " \n",
+ " return {\"image\": image, \"label\": label, \"index\": index}\n",
+ " \n",
+ "class MNISTSubmitDataset(Dataset):\n",
+    "    def __init__(self, images, indices):\n",
+ " self.images = images\n",
+    "        self.indices = indices\n",
+ " self.transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize([0.5], [0.5])\n",
+ " ])\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.images)\n",
+ " \n",
+ " def __getitem__(self, idx):\n",
+ " image = self.images[idx].reshape((28, 28)).astype(np.uint8)\n",
+    "        index = self.indices[idx]\n",
+ " image = self.transform(image)\n",
+ " \n",
+ " return {\"image\": image, \"index\": index}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "09241273-f6e2-472a-b0ea-bcfaa91823b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "f, axarr = plt.subplots(1, 3)  # one row of three panels: a train, a val, and a test sample\n",
+ "\n",
+ "train_dataset = MNISTTrainDataset(train_df.iloc[:, 1:].values.astype(np.uint8), train_df.iloc[:, 0].values, train_df.index.values)\n",
+ "print(len(train_dataset))\n",
+ "print(train_dataset[0])\n",
+ "axarr[0].imshow(train_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n",
+ "axarr[0].set_title(\"Train Image\")\n",
+ "print(\"-\"*30)\n",
+ "\n",
+ "val_dataset = MNISTValDataset(val_df.iloc[:, 1:].values.astype(np.uint8), val_df.iloc[:, 0].values, val_df.index.values)\n",
+ "print(len(val_dataset))\n",
+ "print(val_dataset[0])\n",
+ "axarr[1].imshow(val_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n",
+ "axarr[1].set_title(\"Val Image\")\n",
+ "print(\"-\"*30)\n",
+ "\n",
+ "test_dataset = MNISTSubmitDataset(test_df.values.astype(np.uint8), test_df.index.values)\n",
+ "print(len(test_dataset))\n",
+ "print(test_dataset[0])\n",
+ "axarr[2].imshow(test_dataset[0][\"image\"].squeeze(), cmap=\"gray\")\n",
+ "axarr[2].set_title(\"Test Image\")\n",
+ "print(\"-\"*30)\n",
+ "\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82c6ce1d-5494-483f-8459-c8cd992d3e36",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataloader = DataLoader(dataset=train_dataset,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=True)\n",
+ "\n",
+ "val_dataloader = DataLoader(dataset=val_dataset,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=True)\n",
+ "\n",
+ "test_dataloader = DataLoader(dataset=test_dataset,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2f13d32-8d7f-4c7e-833a-8fd28ca92003",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)\n",
+ "\n",
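+    "# One training pass and one validation pass per epoch; the reported loss is the running batch average\n",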
+ "start = timeit.default_timer()\n",
+ "for epoch in tqdm(range(EPOCHS), position=0, leave=True):\n",
+ " model.train()\n",
+ " train_labels = []\n",
+ " train_preds = []\n",
+ " train_running_loss = 0\n",
+ " for idx, img_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):\n",
+ " img = img_label[\"image\"].float().to(device)\n",
+    "        label = img_label[\"label\"].type(torch.long).to(device)  # CrossEntropyLoss expects int64 class indices\n",
+ " y_pred = model(img)\n",
+ " y_pred_label = torch.argmax(y_pred, dim=1)\n",
+ "\n",
+ " train_labels.extend(label.cpu().detach())\n",
+ " train_preds.extend(y_pred_label.cpu().detach())\n",
+ " \n",
+ " loss = criterion(y_pred, label)\n",
+ " \n",
+ " optimizer.zero_grad()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " train_running_loss += loss.item()\n",
+ " train_loss = train_running_loss / (idx + 1)\n",
+ "\n",
+ " model.eval()\n",
+ " val_labels = []\n",
+ " val_preds = []\n",
+ " val_running_loss = 0\n",
+ " with torch.no_grad():\n",
+ " for idx, img_label in enumerate(tqdm(val_dataloader, position=0, leave=True)):\n",
+ " img = img_label[\"image\"].float().to(device)\n",
+    "            label = img_label[\"label\"].type(torch.long).to(device)  # CrossEntropyLoss expects int64 class indices\n",
+ " y_pred = model(img)\n",
+ " y_pred_label = torch.argmax(y_pred, dim=1)\n",
+ " \n",
+ " val_labels.extend(label.cpu().detach())\n",
+ " val_preds.extend(y_pred_label.cpu().detach())\n",
+ " \n",
+ " loss = criterion(y_pred, label)\n",
+ " val_running_loss += loss.item()\n",
+ " val_loss = val_running_loss / (idx + 1)\n",
+ "\n",
+ " print(\"-\"*30)\n",
+ " print(f\"Train Loss EPOCH {epoch+1}: {train_loss:.4f}\")\n",
+ " print(f\"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}\")\n",
+ " print(f\"Train Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(train_preds, train_labels) if x == y) / len(train_labels):.4f}\")\n",
+ " print(f\"Valid Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(val_preds, val_labels) if x == y) / len(val_labels):.4f}\")\n",
+ " print(\"-\"*30)\n",
+ "\n",
+ "stop = timeit.default_timer()\n",
+ "print(f\"Training Time: {stop-start:.2f}s\")\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "349ab99f-244a-46c1-a19a-09794dd1f871",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.cuda.empty_cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "110223e7-1c15-4959-8a45-70a84356461a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels = []\n",
+ "ids = []\n",
+ "imgs = []\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ " for idx, sample in enumerate(tqdm(test_dataloader, position=0, leave=True)):\n",
+ " img = sample[\"image\"].to(device)\n",
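+    "        # ImageId in the submission file is 1-based, hence the +1 offset on the 0-based DataFrame index\n",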
+ " ids.extend([int(i)+1 for i in sample[\"index\"]])\n",
+ " \n",
+ " outputs = model(img)\n",
+ " \n",
+ " imgs.extend(img.detach().cpu())\n",
+ " labels.extend([int(i) for i in torch.argmax(outputs, dim=1)])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a9a505ab-1de2-4420-a6da-f38c0b5988a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "f, axarr = plt.subplots(2, 3)  # 2x3 grid of test images with their predicted labels\n",
+ "counter = 0\n",
+ "for i in range(2):\n",
+ " for j in range(3):\n",
+ " axarr[i][j].imshow(imgs[counter].squeeze(), cmap=\"gray\")\n",
+ " axarr[i][j].set_title(f\"Predicted {labels[counter]}\")\n",
+ " counter += 1\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b868d223-37ee-4b34-a23e-28da702cfa3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "submission_df = pd.DataFrame(list(zip(ids, labels)), columns=[\"ImageId\", \"Label\"])\n",
+ "submission_df.to_csv(\"submission.csv\", index=False)\n",
+ "submission_df.head()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}