diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Classic.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Classic.ipynb new file mode 100644 index 0000000..7a62700 --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Classic.ipynb @@ -0,0 +1,440 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "de52de89-cc57-4b32-bd13-e51cb510b730", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math\n", + "import os, glob, tarfile\n", + "from pathlib import Path\n", + "\n", + "# Absolute paths on the cluster (the leading '/' was missing in the first two)\n", + "tar_path = \"/global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M.tar\"\n", + "extract_dir = \"/global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M\"\n", + "\n", + "files = glob.glob(os.path.join(Path(\"/global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M/val_5M\"), \"*.root\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8a81a422-578a-472f-b13b-6c3271c334f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/torchvision-0.21.0+7af6987-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/setuptools-75.8.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/pillow-11.1.0-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. 
Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: awkward in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (2.8.7)\n", + "Requirement already satisfied: uproot in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (5.6.4)\n", + "Requirement already satisfied: vector in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (1.6.3)\n", + "Requirement already satisfied: awkward-cpp==48 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from awkward) (48)\n", + "Requirement already satisfied: fsspec>=2022.11.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (2024.12.0)\n", + "Requirement already satisfied: numpy>=1.18.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (2.2.3)\n", + "Requirement already satisfied: packaging in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (24.2)\n", + "Requirement already satisfied: cramjam>=2.5.0 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from uproot) (2.11.0)\n", + "Requirement already satisfied: xxhash in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from uproot) (3.5.0)\n", + "torch.Size([5000, 4, 8]) torch.Size([5000, 10])\n" + ] + } + ], + "source": [ + "from torch.utils.data import Dataset\n", + "!pip install awkward uproot vector\n", + "from particle_transformer.dataloader import read_file\n", + "\n", + "all_x_parts = []\n", + "all_ys = []\n", + "\n", + "for file in files:\n", + " x_part, x_jets, y = read_file(\n", + " file,\n", + " max_num_particles=8,\n", + " particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],\n", + " jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],\n", + " labels=[\n", + " 'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',\n", + " 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',\n", + " ]\n", + " )\n", + " all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:100,:,:])\n", + " all_ys.append(torch.tensor(y, dtype=torch.float32)[:100,:])\n", + "\n", + "x_all = torch.cat(all_x_parts, dim=0)\n", + "y_all = torch.cat(all_ys, dim=0)\n", + "print(x_all.shape, y_all.shape)\n", + "\n", + "class JetDataset(Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __len__(self):\n", + " return self.x.shape[0]\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.x[idx], self.y[idx]\n", + "\n", + "dataset = JetDataset(x_all, y_all)\n", + "\n", + "from torch.utils.data import DataLoader\n", + "dataloader = DataLoader(dataset, batch_size=64, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7375769d-465d-4ad4-b859-37a3d892104e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class_token | 10 params | shape (1, 1, 10)\n", + "tokenizer.proj.weight | 40 params | shape (10, 4)\n", + "tokenizer.proj.bias | 10 params | shape (10,)\n", + "U_encoder.net.0.weight | 256 params | shape (64, 4, 1, 1)\n", + "U_encoder.net.1.weight | 64 params | shape (64,)\n", + "U_encoder.net.1.bias | 64 params | shape (64,)\n", + "U_encoder.net.3.weight | 4,096 params | shape (64, 64, 1, 1)\n", + "U_encoder.net.4.weight | 64 params | shape (64,)\n", + 
"U_encoder.net.4.bias | 64 params | shape (64,)\n", + "U_encoder.net.6.weight | 4,096 params | shape (64, 64, 1, 1)\n", + "U_encoder.net.7.weight | 64 params | shape (64,)\n", + "U_encoder.net.7.bias | 64 params | shape (64,)\n", + "U_encoder.net.9.weight | 128 params | shape (2, 64, 1, 1)\n", + "blocks.0.ln1.weight | 10 params | shape (10,)\n", + "blocks.0.ln1.bias | 10 params | shape (10,)\n", + "blocks.0.attn.q.weight | 100 params | shape (10, 10)\n", + "blocks.0.attn.k.weight | 100 params | shape (10, 10)\n", + "blocks.0.attn.v.weight | 100 params | shape (10, 10)\n", + "blocks.0.attn.o.weight | 100 params | shape (10, 10)\n", + "blocks.0.ln2.weight | 10 params | shape (10,)\n", + "blocks.0.ln2.bias | 10 params | shape (10,)\n", + "blocks.0.mlp.net.0.weight | 400 params | shape (40, 10)\n", + "blocks.0.mlp.net.0.bias | 40 params | shape (40,)\n", + "blocks.0.mlp.net.3.weight | 400 params | shape (10, 40)\n", + "blocks.0.mlp.net.3.bias | 10 params | shape (10,)\n", + "blocks.1.ln1.weight | 10 params | shape (10,)\n", + "blocks.1.ln1.bias | 10 params | shape (10,)\n", + "blocks.1.attn.q.weight | 100 params | shape (10, 10)\n", + "blocks.1.attn.k.weight | 100 params | shape (10, 10)\n", + "blocks.1.attn.v.weight | 100 params | shape (10, 10)\n", + "blocks.1.attn.o.weight | 100 params | shape (10, 10)\n", + "blocks.1.ln2.weight | 10 params | shape (10,)\n", + "blocks.1.ln2.bias | 10 params | shape (10,)\n", + "blocks.1.mlp.net.0.weight | 400 params | shape (40, 10)\n", + "blocks.1.mlp.net.0.bias | 40 params | shape (40,)\n", + "blocks.1.mlp.net.3.weight | 400 params | shape (10, 40)\n", + "blocks.1.mlp.net.3.bias | 10 params | shape (10,)\n", + "cls_blocks.0.ln1.weight | 10 params | shape (10,)\n", + "cls_blocks.0.ln1.bias | 10 params | shape (10,)\n", + "cls_blocks.0.attn.q_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.0.attn.k_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.0.attn.v_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.0.attn.o_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.0.ln2.weight | 10 params | shape (10,)\n", + "cls_blocks.0.ln2.bias | 10 params | shape (10,)\n", + "cls_blocks.0.mlp.net.0.weight | 400 params | shape (40, 10)\n", + "cls_blocks.0.mlp.net.0.bias | 40 params | shape (40,)\n", + "cls_blocks.0.mlp.net.3.weight | 400 params | shape (10, 40)\n", + "cls_blocks.0.mlp.net.3.bias | 10 params | shape (10,)\n", + "cls_blocks.1.ln1.weight | 10 params | shape (10,)\n", + "cls_blocks.1.ln1.bias | 10 params | shape (10,)\n", + "cls_blocks.1.attn.q_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.1.attn.k_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.1.attn.v_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.1.attn.o_proj.weight | 100 params | shape (10, 10)\n", + "cls_blocks.1.ln2.weight | 10 params | shape (10,)\n", + "cls_blocks.1.ln2.bias | 10 params | shape (10,)\n", + "cls_blocks.1.mlp.net.0.weight | 400 params | shape (40, 10)\n", + "cls_blocks.1.mlp.net.0.bias | 40 params | shape (40,)\n", + "cls_blocks.1.mlp.net.3.weight | 400 params | shape (10, 40)\n", + "cls_blocks.1.mlp.net.3.bias | 10 params | shape (10,)\n", + "head.weight | 100 params | shape (10, 10)\n", + "head.bias | 10 params | shape (10,)\n", + "Trainable parameters: 14,290\n" + ] + } + ], + "source": [ + "from ParT import ParT\n", + "\n", + "model = ParT(\n", + " in_dim=4, # part_pt, eta, phi, energy\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2,\n", + " class_depth=2,\n", + " num_classes=10\n", + ")\n", 
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "model = model.to(device)\n", + "\n", + "for name, p in model.named_parameters():\n", + " if p.requires_grad:\n", + " print(f\"{name:40s} | {p.numel():,} params | shape {tuple(p.shape)}\")\n", + "\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "total_params = count_parameters(model)\n", + "print(f\"Trainable parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ea684b3c-b3c3-4b43-a4df-160f8e5aef61", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/100, Loss: 2.3037\n", + "Epoch 2/100, Loss: 2.2782\n", + "Epoch 3/100, Loss: 2.2344\n", + "Epoch 4/100, Loss: 2.2259\n", + "Epoch 5/100, Loss: 2.2165\n", + "Epoch 6/100, Loss: 2.2124\n", + "Epoch 7/100, Loss: 2.2069\n", + "Epoch 8/100, Loss: 2.1887\n", + "Epoch 9/100, Loss: 2.1701\n", + "Epoch 10/100, Loss: 2.1473\n", + "Epoch 11/100, Loss: 2.1491\n", + "Epoch 12/100, Loss: 2.1332\n", + "Epoch 13/100, Loss: 2.1287\n", + "Epoch 14/100, Loss: 2.1211\n", + "Epoch 15/100, Loss: 2.1196\n", + "Epoch 16/100, Loss: 2.1174\n", + "Epoch 17/100, Loss: 2.1169\n", + "Epoch 18/100, Loss: 2.1101\n", + "Epoch 19/100, Loss: 2.1107\n", + "Epoch 20/100, Loss: 2.1052\n", + "Epoch 21/100, Loss: 2.1068\n", + "Epoch 22/100, Loss: 2.1155\n", + "Epoch 23/100, Loss: 2.1041\n", + "Epoch 24/100, Loss: 2.1064\n", + "Epoch 25/100, Loss: 2.1022\n", + "Epoch 26/100, Loss: 2.1045\n", + "Epoch 27/100, Loss: 2.1052\n", + "Epoch 28/100, Loss: 2.1044\n", + "Epoch 29/100, Loss: 2.0975\n", + "Epoch 30/100, Loss: 2.0979\n", + "Epoch 31/100, Loss: 2.1028\n", + "Epoch 32/100, Loss: 2.0966\n", + "Epoch 33/100, Loss: 2.0997\n", + "Epoch 34/100, Loss: 2.0963\n", + "Epoch 35/100, Loss: 2.0945\n", + "Epoch 36/100, Loss: 2.0965\n", + "Epoch 37/100, Loss: 2.0935\n", + "Epoch 38/100, Loss: 2.0842\n", + "Epoch 39/100, Loss: 2.0776\n", + "Epoch 40/100, Loss: 2.0724\n", + "Epoch 41/100, Loss: 2.0593\n", + "Epoch 42/100, Loss: 2.0583\n", + "Epoch 43/100, Loss: 2.0442\n", + "Epoch 44/100, Loss: 2.0370\n", + "Epoch 45/100, Loss: 2.0250\n", + "Epoch 46/100, Loss: 2.0140\n", + "Epoch 47/100, Loss: 2.0099\n", + "Epoch 48/100, Loss: 2.0036\n", + "Epoch 49/100, Loss: 1.9880\n", + "Epoch 50/100, Loss: 1.9932\n", + "Epoch 51/100, Loss: 1.9876\n", + "Epoch 52/100, Loss: 1.9781\n", + "Epoch 53/100, Loss: 1.9807\n", + "Epoch 54/100, Loss: 1.9629\n", + "Epoch 55/100, Loss: 1.9610\n", + "Epoch 56/100, Loss: 1.9652\n", + "Epoch 57/100, Loss: 1.9672\n", + "Epoch 58/100, Loss: 1.9633\n", + "Epoch 59/100, Loss: 1.9556\n", + "Epoch 60/100, Loss: 1.9504\n", + "Epoch 61/100, Loss: 1.9493\n", + "Epoch 62/100, Loss: 1.9435\n", + "Epoch 63/100, Loss: 1.9460\n", + "Epoch 64/100, Loss: 1.9413\n", + "Epoch 65/100, Loss: 1.9413\n", + "Epoch 66/100, Loss: 1.9421\n", + "Epoch 67/100, Loss: 1.9475\n", + "Epoch 68/100, Loss: 1.9430\n", + "Epoch 69/100, Loss: 1.9403\n", + "Epoch 70/100, Loss: 1.9430\n", + "Epoch 71/100, Loss: 1.9335\n", + "Epoch 72/100, Loss: 1.9326\n", + "Epoch 73/100, Loss: 1.9372\n", + "Epoch 74/100, Loss: 1.9294\n", + "Epoch 75/100, Loss: 1.9336\n", + "Epoch 76/100, Loss: 1.9304\n", + "Epoch 77/100, Loss: 1.9338\n", + "Epoch 78/100, Loss: 1.9241\n", + "Epoch 79/100, Loss: 1.9221\n", + "Epoch 80/100, Loss: 1.9112\n", + "Epoch 81/100, Loss: 1.9221\n", + "Epoch 82/100, Loss: 1.9199\n", + "Epoch 83/100, Loss: 1.9222\n", + "Epoch 84/100, Loss: 1.9182\n", + 
"Epoch 85/100, Loss: 1.9113\n", + "Epoch 86/100, Loss: 1.9158\n", + "Epoch 87/100, Loss: 1.9170\n", + "Epoch 88/100, Loss: 1.9078\n", + "Epoch 89/100, Loss: 1.9076\n", + "Epoch 90/100, Loss: 1.9154\n", + "Epoch 91/100, Loss: 1.9009\n", + "Epoch 92/100, Loss: 1.9082\n", + "Epoch 93/100, Loss: 1.8940\n", + "Epoch 94/100, Loss: 1.9011\n", + "Epoch 95/100, Loss: 1.8925\n", + "Epoch 96/100, Loss: 1.9030\n", + "Epoch 97/100, Loss: 1.8953\n", + "Epoch 98/100, Loss: 1.8908\n", + "Epoch 99/100, Loss: 1.9007\n", + "Epoch 100/100, Loss: 1.9037\n" + ] + } + ], + "source": [ + "loss_fn = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "num_epochs = 100\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " epoch_loss = 0.0\n", + " for batch_idx, (x, y) in enumerate(dataloader):\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(x) # shape [batch, 10]\n", + "\n", + " loss = loss_fn(outputs, y)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epoch_loss += loss.item()\n", + "\n", + " avg_loss = epoch_loss / len(dataloader)\n", + " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9807aebd-af5b-4c3f-8155-7164404ace49", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy on full dataset: 0.3058\n" + ] + } + ], + "source": [ + "from torch.nn.functional import sigmoid, softmax\n", + "\n", + "model.eval()\n", + "correct = 0\n", + "total = 0\n", + "\n", + "with torch.no_grad():\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " outputs = model(x)\n", + " labels = torch.argmax(y, dim=1) # convert one-hot to class id\n", + " preds = torch.argmax(outputs, dim=1) # predicted class\n", + " correct += (preds == labels).sum().item()\n", + " total += y.size(0)\n", + "accuracy = correct / total\n", + "print(f\"Accuracy on full dataset: {accuracy:.4f}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4f4c5c87-7697-460d-b201-d50ea9bf2f97", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Macro-Averaged AUC: 0.7360\n" + ] + } + ], + "source": [ + "from sklearn.metrics import roc_auc_score\n", + "import numpy as np\n", + "\n", + "model.eval()\n", + "all_outputs = []\n", + "all_targets = []\n", + "\n", + "with torch.no_grad():\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " outputs = model(x)\n", + " all_outputs.append(outputs.cpu())\n", + " all_targets.append(y.cpu())\n", + "\n", + "# Concatenate batches\n", + "all_outputs = torch.cat(all_outputs, dim=0)\n", + "all_targets = torch.cat(all_targets, dim=0)\n", + "\n", + "# Apply sigmoid (if using BCEWithLogitsLoss)\n", + "probs = sigmoid(all_outputs).numpy() # shape: (N, C)\n", + "true = all_targets.numpy() # shape: (N, C)\n", + "\n", + "# Compute AUC for each class and average\n", + "try:\n", + " auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')\n", + " print(f\"Macro-Averaged AUC: {auc_macro:.4f}\")\n", + "except ValueError as e:\n", + " print(\"AUC could not be computed:\", e)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a99b52da-2a78-40f6-9c20-646b200e6707", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch-2.6.0", + "language": "python", + 
"name": "pytorch-2.6.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Quantum.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Quantum.ipynb new file mode 100644 index 0000000..462c7fe --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Full_Training_Quantum.ipynb @@ -0,0 +1,352 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d96fc207-2edd-4acf-9ea7-f17c33067964", + "metadata": {}, + "outputs": [], + "source": [ + "#hard comparisoin, 400 vs 5160 parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b13262b5-ea4b-45d6-b694-5766582351cd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/torchvision-0.21.0+7af6987-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/setuptools-75.8.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/pillow-11.1.0-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. 
Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: awkward in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (2.8.7)\n", + "Requirement already satisfied: uproot in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (5.6.4)\n", + "Requirement already satisfied: vector in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (1.6.3)\n", + "Requirement already satisfied: awkward-cpp==48 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from awkward) (48)\n", + "Requirement already satisfied: fsspec>=2022.11.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (2024.12.0)\n", + "Requirement already satisfied: numpy>=1.18.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (2.2.3)\n", + "Requirement already satisfied: packaging in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from awkward) (24.2)\n", + "Requirement already satisfied: cramjam>=2.5.0 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from uproot) (2.11.0)\n", + "Requirement already satisfied: xxhash in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from uproot) (3.5.0)\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math\n", + "import os, glob, tarfile\n", + "from pathlib import Path \n", + "\n", + "tar_path = \"global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M.tar\"\n", + "extract_dir = \"global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M\"\n", + "\n", + "files = glob.glob(os.path.join(Path(\"/global/homes/a/aletesi/PaarT/Data/JetClass_Pythia_val_5M/val_5M\"), \"*.root\"))\n", + "from torch.utils.data import Dataset\n", + "!pip install awkward uproot vector\n", + "from particle_transformer.dataloader import read_file" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "488b83b7-cae7-44fc-b7b1-f6492e6d3954", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1000, 4, 8]) torch.Size([1000, 10])\n" + ] + } + ], + "source": [ + "all_x_parts = []\n", + "all_ys = []\n", + "\n", + "for file in files:\n", + " x_part, x_jets, y = read_file(\n", + " file,\n", + " max_num_particles=8,\n", + " particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],\n", + " jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],\n", + " labels=[\n", + " 'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',\n", + " 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',\n", + " ]\n", + " )\n", + " all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:20,:,:])\n", + " all_ys.append(torch.tensor(y, dtype=torch.float32)[:20,:])\n", + "\n", + "x_all = torch.cat(all_x_parts, dim=0)\n", + "y_all = torch.cat(all_ys, dim=0)\n", + "print(x_all.shape, y_all.shape)\n", + "\n", + "class JetDataset(Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __len__(self):\n", + " return self.x.shape[0]\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.x[idx], self.y[idx]\n", + "\n", + "dataset = JetDataset(x_all, y_all)\n", + "\n", + "from torch.utils.data import DataLoader\n", + "dataloader = 
DataLoader(dataset, batch_size=64, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3550ba7a-862d-4371-a2c5-7ecd4a4715f9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/torchvision-0.21.0+7af6987-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/setuptools-75.8.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/pillow-11.1.0-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "Requirement already satisfied: tensorcircuit in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (0.12.0)\n", + "Requirement already satisfied: numpy in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (2.2.3)\n", + "Requirement already satisfied: scipy in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (1.15.2)\n", + "Requirement already satisfied: tensornetwork-ng in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensorcircuit) (0.5.1)\n", + "Requirement already satisfied: networkx in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (3.4.2)\n", + "Requirement already satisfied: graphviz>=0.11.1 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (0.21)\n", + "Requirement already satisfied: opt-einsum>=2.3.0 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (3.4.0)\n", + "Requirement already satisfied: h5py>=2.9.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (3.13.0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Please first ``pip install -U qiskit`` to enable related functionality in translation module\n", + "Please first ``pip install -U cirq`` to enable related functionality in translation module\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class_token | 10 params | shape (1, 1, 10)\n", + "tokenizer.proj.weight | 40 params | shape (10, 4)\n", + "tokenizer.proj.bias | 10 params | shape (10,)\n", + "U_encoder.net.0.weight | 256 params | shape (64, 4, 1, 1)\n", + "U_encoder.net.1.weight | 64 params | shape (64,)\n", + "U_encoder.net.1.bias | 64 params | shape (64,)\n", + "U_encoder.net.3.weight | 4,096 params | shape (64, 64, 1, 1)\n", + 
"U_encoder.net.4.weight | 64 params | shape (64,)\n", + "U_encoder.net.4.bias | 64 params | shape (64,)\n", + "U_encoder.net.6.weight | 4,096 params | shape (64, 64, 1, 1)\n", + "U_encoder.net.7.weight | 64 params | shape (64,)\n", + "U_encoder.net.7.bias | 64 params | shape (64,)\n", + "U_encoder.net.9.weight | 128 params | shape (2, 64, 1, 1)\n", + "blocks.0.ln1.weight | 10 params | shape (10,)\n", + "blocks.0.ln1.bias | 10 params | shape (10,)\n", + "blocks.0.attn.q_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.0.attn.k_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.0.attn.v_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.0.attn.o_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.0.ln2.weight | 10 params | shape (10,)\n", + "blocks.0.ln2.bias | 10 params | shape (10,)\n", + "blocks.0.mlp.fc1.q.w | 10 params | shape (1, 10)\n", + "blocks.0.mlp.fc2.q.w | 10 params | shape (1, 10)\n", + "blocks.1.ln1.weight | 10 params | shape (10,)\n", + "blocks.1.ln1.bias | 10 params | shape (10,)\n", + "blocks.1.attn.q_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.1.attn.k_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.1.attn.v_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.1.attn.o_proj.q.w | 10 params | shape (1, 10)\n", + "blocks.1.ln2.weight | 10 params | shape (10,)\n", + "blocks.1.ln2.bias | 10 params | shape (10,)\n", + "blocks.1.mlp.fc1.q.w | 10 params | shape (1, 10)\n", + "blocks.1.mlp.fc2.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.ln1.weight | 10 params | shape (10,)\n", + "cls_blocks.0.ln1.bias | 10 params | shape (10,)\n", + "cls_blocks.0.attn.q_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.attn.k_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.attn.v_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.attn.o_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.ln2.weight | 10 params | shape (10,)\n", + "cls_blocks.0.ln2.bias | 10 params | shape (10,)\n", + "cls_blocks.0.mlp.fc1.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.0.mlp.fc2.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.ln1.weight | 10 params | shape (10,)\n", + "cls_blocks.1.ln1.bias | 10 params | shape (10,)\n", + "cls_blocks.1.attn.q_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.attn.k_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.attn.v_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.attn.o_proj.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.ln2.weight | 10 params | shape (10,)\n", + "cls_blocks.1.ln2.bias | 10 params | shape (10,)\n", + "cls_blocks.1.mlp.fc1.q.w | 10 params | shape (1, 10)\n", + "cls_blocks.1.mlp.fc2.q.w | 10 params | shape (1, 10)\n", + "head.weight | 100 params | shape (10, 10)\n", + "head.bias | 10 params | shape (10,)\n", + "Trainable parameters: 9,530\n" + ] + } + ], + "source": [ + "!pip install tensorcircuit\n", + "from QParT import ParT\n", + "\n", + "model = ParT(\n", + " in_dim=4, # part_pt, eta, phi, energy\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2,\n", + " class_depth=2,\n", + " num_classes=10\n", + ")\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "model = model.to(device)\n", + "\n", + "for name, p in model.named_parameters():\n", + " if p.requires_grad:\n", + " print(f\"{name:40s} | {p.numel():,} params | shape {tuple(p.shape)}\")\n", + "\n", + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "total_params = count_parameters(model)\n", + "print(f\"Trainable 
parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "040c03fb-1e38-4939-bc01-e8e81c5d42d3", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "num_epochs = 100\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " epoch_loss = 0.0\n", + " for batch_idx, (x, y) in enumerate(dataloader):\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " optimizer.zero_grad()\n", + " outputs = model(x) # shape [batch, 10]\n", + "\n", + " loss = loss_fn(outputs, y) # CrossEntropyLoss accepts one-hot (soft) targets in PyTorch >= 1.10\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " epoch_loss += loss.item()\n", + "\n", + " avg_loss = epoch_loss / len(dataloader)\n", + " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a083374-38d4-4124-85fd-03e0e15dc720", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.nn.functional import sigmoid\n", + "\n", + "model.eval()\n", + "correct = 0\n", + "total = 0\n", + "\n", + "with torch.no_grad():\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " outputs = model(x)\n", + " labels = torch.argmax(y, dim=1) # convert one-hot to class id\n", + " preds = torch.argmax(outputs, dim=1) # predicted class\n", + " correct += (preds == labels).sum().item()\n", + " total += y.size(0)\n", + "accuracy = correct / total\n", + "print(f\"Accuracy on full dataset: {accuracy:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "260babbd-d034-4c5a-bf97-c115c2e637db", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.metrics import roc_auc_score\n", + "import numpy as np\n", + "\n", + "model.eval()\n", + "all_outputs = []\n", + "all_targets = []\n", + "\n", + "with torch.no_grad():\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " outputs = model(x)\n", + " all_outputs.append(outputs.cpu())\n", + " all_targets.append(y.cpu())\n", + "\n", + "# Concatenate batches\n", + "all_outputs = torch.cat(all_outputs, dim=0)\n", + "all_targets = torch.cat(all_targets, dim=0)\n", + "\n", + "# Map logits to per-class scores with sigmoid; it is monotonic per class,\n", + "# so one-vs-rest ROC-AUC matches ranking on the raw logits\n", + "probs = sigmoid(all_outputs).numpy() # shape: (N, C)\n", + "true = all_targets.numpy() # shape: (N, C)\n", + "\n", + "# Compute AUC for each class and average\n", + "try:\n", + " auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')\n", + " print(f\"Macro-Averaged AUC: {auc_macro:.4f}\")\n", + "except ValueError as e:\n", + " print(\"AUC could not be computed:\", e)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12d7ce5-d081-472e-9176-9741908e602c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch-2.6.0", + "language": "python", + "name": "pytorch-2.6.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT.py b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT.py new file mode 100644 index 0000000..0571c45 --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT.py @@ -0,0 
+1,328 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +class InteractionEncoder(nn.Module): + """ + ParT interaction-feature encoder. + + Args + ---- + n_heads : output channels d′ (one interaction channel per attention head) + hidden_channels : tuple[int, ...] for intermediate 1×1 conv layers + eps : numerical guard for log + """ + + def __init__(self, + n_heads: int = 8, + hidden_channels: tuple[int, ...] = (64, 64, 64), + eps: float = 1e-8): + super().__init__() + self.eps = eps + + layers: list[nn.Module] = [] + in_ch = 4 # lnΔ, ln kT, ln z, ln m² + for h in hidden_channels: + layers += [ + nn.Conv2d(in_ch, h, 1, bias=False), + nn.BatchNorm2d(h), + nn.GELU() + ] + in_ch = h + layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False)) + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x : (B, 4, N) where the 4 dims are (E, px, py, pz) + returns + ------ + U : (B, n_heads, N, N) interaction embedding + """ + B, four, N = x.shape + assert four == 4, "input must have 4 features: E, px, py, pz" + + # Split components + E, px, py, pz = x.unbind(dim=1) # each (B, N) + + # Basic kinematics ------------------------------------------------ + pT = torch.sqrt(px**2 + py**2) + self.eps + phi = torch.atan2(py, px) # (−π, π] + num = (E + pz).clamp(min=self.eps) # clamp to keep the log argument positive + den = (E - pz).clamp(min=self.eps) + y = 0.5 * torch.log(num / den) + + # Expand to (B, N, N) + y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N) + phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1) + pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1) + E_a, E_b = E.unsqueeze(2), E.unsqueeze(1) + px_a, px_b = px.unsqueeze(2), px.unsqueeze(1) + py_a, py_b = py.unsqueeze(2), py.unsqueeze(1) + pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1) + + # ΔR, kT, z (Δφ is used unwrapped; adequate for collimated jets away from φ = ±π) + delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps + kT = torch.minimum(pT_a, pT_b) * delta + z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps) + + # m² of pair + E_sum = E_a + E_b + px_sum = px_a + px_b + py_sum = py_a + py_b + pz_sum = pz_a + pz_b + m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps + m2 = torch.clamp(m2, min=self.eps) # avoid negatives + + # Stack → (B, 4, N, N) + feats = torch.stack([ + torch.log(delta), + torch.log(kT), + torch.log(z), + torch.log(m2) + ], dim=1) + + # conv + U = self.net(feats) # (B, n_heads, N, N) + return U + + +class ParticleTokenizer(nn.Module): + def __init__(self, in_dim=4, out_dim=6): + super().__init__() + self.proj = nn.Linear(in_dim, out_dim) + + def forward(self, x): + """ + x: tensor of shape (B, in_dim, n_particles) + returns: (B, n_particles, out_dim) + """ + x = x.transpose(1, 2) # (B, in_dim, n_particles) → (B, n_particles, in_dim) + return self.proj(x) + +class MLP(nn.Module): + def __init__(self, dim, expansion=1, dropout=0.): + super().__init__() + hidden = dim * expansion + self.net = nn.Sequential( + nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout), + nn.Linear(hidden, dim), nn.Dropout(dropout) + ) + def forward(self, x): return self.net(x) + +class ParticleMHA(nn.Module): + """ + Multi-head self-attention with additive interaction bias U. 
+ + Input + ----- + x : (B, N, d) token / particle embeddings + U : (broadcast → B, H, N, N) or None + + Returns + ------- + out : (B, N, d) attention output + attn_map : (B, H, N, N) attention weights (returned if + return_attn=True) + """ + def __init__(self, d: int, heads: int = 8, + dropout: float = 0.1, return_attn: bool = False): + super().__init__() + assert d % heads == 0, "`d` must be divisible by `heads`" + + self.d = d + self.h = heads + self.d_head = d // heads + self.scale = 1 / math.sqrt(self.d_head) + self.return_attn = return_attn + + # Projections + self.q = nn.Linear(d, d, bias=False) + self.k = nn.Linear(d, d, bias=False) + self.v = nn.Linear(d, d, bias=False) + self.o = nn.Linear(d, d, bias=False) + + self.drop = nn.Dropout(dropout) + + def _split(self, t: torch.Tensor): + # (B, N, d) -> (B, H, N, d_head) + B, N, _ = t.shape + return ( + t.view(B, N, self.h, self.d_head) # (B, N, H, d_head) + .transpose(1, 2) # (B, H, N, d_head) + ) + + def forward(self, x: torch.Tensor, + U: torch.Tensor | None = None): + B, N, _ = x.shape + + Q = self._split(self.q(x)) + K = self._split(self.k(x)) + V = self._split(self.v(x)) + + logits = (Q @ K.transpose(-2, -1)) * self.scale # (B, H, N, N) + + if U is not None: + logits = logits + U + + attn = F.softmax(logits, dim=-1) + attn = self.drop(attn) + + context = attn @ V # (B, H, N, d_h) + + context = ( + context.transpose(1, 2) # (B, N, H, d_h) + .contiguous() + .view(B, N, self.d) # (B, N, d) + ) + out = self.o(context) + + if self.return_attn: + return out, attn # (B, N, d), (B, H, N, N) + else: + return out + +class MHA(nn.Module): + """ + Multi-head attention (batch_first) implemented explicitly. + + Args + ---- + d_model : int embedding dim + n_heads : int + dropout: float + bias : bool use bias in projections + """ + def __init__(self, d_model: int, n_heads: int, dropout: float = 0., bias: bool = False): + super().__init__() + assert d_model % n_heads == 0, "`d_model` must be divisible by `n_heads`" + self.d_model = d_model + self.h = n_heads + self.d_head = d_model // n_heads + self.scale = self.d_head ** -0.5 + + self.q_proj = nn.Linear(d_model, d_model, bias=bias) + self.k_proj = nn.Linear(d_model, d_model, bias=bias) + self.v_proj = nn.Linear(d_model, d_model, bias=bias) + self.o_proj = nn.Linear(d_model, d_model, bias=bias) + + self.drop = nn.Dropout(dropout) + + def _split_heads(self, x: torch.Tensor): + # (B, L, d_model) -> (B, h, L, d_head) + B, L, _ = x.shape + return x.view(B, L, self.h, self.d_head).transpose(1, 2) + + def _merge_heads(self, x: torch.Tensor): + # (B, h, L, d_head) -> (B, L, d_model) + B, H, L, Dh = x.shape + return x.transpose(1, 2).contiguous().view(B, L, H * Dh) + + def forward( + self, + q: torch.Tensor, # (B, Lq, d_model) + k: torch.Tensor, # (B, Lk, d_model) + v: torch.Tensor, # (B, Lk, d_model) + need_weights: bool = False + ): + B, Lq, _ = q.shape + _, Lk, _ = k.shape + + Q = self._split_heads(self.q_proj(q)) # (B,h,Lq,d_h) + K = self._split_heads(self.k_proj(k)) # (B,h,Lk,d_h) + V = self._split_heads(self.v_proj(v)) # (B,h,Lk,d_h) + + logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk) + + attn = F.softmax(logits, dim=-1) + attn = self.drop(attn) + + context = torch.matmul(attn, V) # (B,h,Lq,d_h) + + # merge heads + output proj + out = self.o_proj(self._merge_heads(context)) # (B,Lq,d_model) + + if need_weights: + avg_weights = attn.mean(dim=1) # (B,Lq,Lk) + return out, avg_weights + return out, None + + +# Particle attention block (NormFormer style + U-bias) +class 
ParticleAttentionBlock(nn.Module): + def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = ParticleMHA(dim, heads, dropout) + self.ln2 = nn.LayerNorm(dim) + self.mlp = MLP(dim, mlp_ratio, dropout) + def forward(self, x, U): + x = x + self.attn(self.ln1(x), U) # bias-aware MHSA + x = x + self.mlp(self.ln2(x)) # feed-forward + return x + +# Class attention block (CaiT style, no U) +class ClassAttentionBlock(nn.Module): + def __init__(self, dim, heads, mlp_ratio=4, dropout=0.): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = MHA(dim, heads, dropout) + self.ln2 = nn.LayerNorm(dim) + self.mlp = MLP(dim, mlp_ratio, dropout) + def forward(self, tokens, cls): # tokens: (B,N,d), cls: (B,1,d) + z = torch.cat([cls, tokens], dim=1) # (B,1+N,d) + q = self.ln1(cls) + kv = self.ln1(z) + cls = cls + self.attn(q, kv, kv, need_weights=False)[0] + cls = cls + self.mlp(self.ln2(cls)) + return cls # (B,1,d) + +# Complete Particle Transformer +class ParT(nn.Module): + def __init__(self, + in_dim=4, # (E,px,py,pz) + embed_dim=10, + n_heads=2, + depth=2, # particle blocks + class_depth=2, # class-attention blocks + mlp_ratio=4, + num_classes=10, + dropout=0.1): + super().__init__() + + self.tokenizer = ParticleTokenizer(in_dim, embed_dim) + self.U_encoder = InteractionEncoder(n_heads=n_heads) + + self.blocks = nn.ModuleList([ + ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout) + for _ in range(depth) + ]) + + self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_blocks = nn.ModuleList([ + ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0) + for _ in range(class_depth) + ]) + + self.head = nn.Linear(embed_dim, num_classes) + + nn.init.trunc_normal_(self.class_token, std=0.02) + nn.init.trunc_normal_(self.head.weight, std=0.02) + nn.init.zeros_(self.head.bias) + + def forward(self, x): # x: (B,4,N) + B, _, N = x.shape + + tokens = self.tokenizer(x) # (B,N,d) + U = self.U_encoder(x) # (B,H,N,N) + + for blk in self.blocks: + tokens = blk(tokens, U) # (B,N,d) + + cls = self.class_token.expand(B, -1, -1) # (B,1,d) + for blk in self.cls_blocks: + cls = blk(tokens, cls) # (B,1,d) + + logits = self.head(cls.squeeze(1)) # (B,10) + return logits \ No newline at end of file diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_Baseline.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_Baseline.ipynb deleted file mode 100644 index b4c5805..0000000 --- a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_Baseline.ipynb +++ /dev/null @@ -1,1163 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "machine_shape": "hm", - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "smpEEHlypz0p" - }, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import math" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Toy Data" - ], - "metadata": { - "id": "PAgmRCdTvyg6" - } - }, - { - "cell_type": "code", - "source": [ - "x = torch.Tensor([[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", - " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", - " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", - " 6.1878175e-01, 
6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", - " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", - " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", - " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", - " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]])\n", - "\n", - "x_batch = torch.Tensor([[[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", - " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", - " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", - " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", - " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", - " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", - " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", - " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]],\n", - "\n", - " [[ 8.2484528e+01, 5.2682617e+01, 5.1243843e+01, 3.6217686e+01,\n", - " 2.8948278e+01, 2.6579512e+01, 2.1946012e+01, 2.1011120e+01],\n", - " [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,\n", - " -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],\n", - " [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,\n", - " -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],\n", - " [ 9.0437065e+01, 7.4070679e+01, 5.6495895e+01, 4.3069641e+01,\n", - " 3.2367134e+01, 3.3371326e+01, 2.6115334e+01, 2.4602446e+01]],\n", - "\n", - " [[ 8.6492935e+01, 7.0192978e+01, 5.8423912e+01, 5.6638733e+01,\n", - " 4.9270725e+01, 4.1237038e+01, 3.6133625e+01, 3.5519596e+01],\n", - " [ 1.4010678e-01, 2.7912292e-01, 1.4376265e-01, 3.4672296e-01,\n", - " 3.4966472e-01, 1.0524009e-01, 1.2958543e-01, 3.3264065e-01],\n", - " [ 1.9334941e+00, 1.6967584e+00, 1.9219695e+00, 1.6735281e+00,\n", - " 1.6587850e+00, 1.8386338e+00, 1.9120301e+00, 1.6680365e+00],\n", - " [ 8.7343246e+01, 7.2945129e+01, 5.9030762e+01, 6.0084766e+01,\n", - " 5.2313778e+01, 4.1465607e+01, 3.6440781e+01, 3.7506149e+01]]])" - ], - "metadata": { - "id": "2jsmzKcFvx4L" - }, - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "x = x.unsqueeze(0) # batch dimension\n", - "x.shape, x_batch.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Po7Fb1v37EJT", - "outputId": "5a9e6be0-869f-4811-c3b7-4d372bc69462" - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))" - ] - }, - "metadata": {}, - "execution_count": 3 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Toy interaction matrix" - ], - "metadata": { - "id": "X6d5fzsV4kt7" - } - }, - { - "cell_type": "code", - "source": [ - "class InteractionEncoder(nn.Module):\n", - " \"\"\"\n", - " ParT interaction-feature encoder.\n", - "\n", - " Args\n", - " ----\n", - " n_heads per mhsa: output channels d′\n", - " hidden_channels : list[int] for intermediate 1×1 conv layers\n", - " eps : numerical guard for log\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " n_heads: int = 8,\n", - " hidden_channels: list[int] = (64, 64, 64),\n", - " eps: float = 1e-8):\n", - " super().__init__()\n", - " self.eps = eps\n", - "\n", - " layers: list[nn.Module] = []\n", - " in_ch = 4 # lnΔ, ln kT, ln z, ln m²\n", - " for h in hidden_channels:\n", - " layers += [\n", - " nn.Conv2d(in_ch, h, 1, bias=False),\n", - " nn.BatchNorm2d(h),\n", - " nn.GELU()\n", - " 
]\n", - " in_ch = h\n", - " layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))\n", - " self.net = nn.Sequential(*layers)\n", - "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " x : (B, 4, N) where the 4 dims are (E, px, py, pz)\n", - " returns\n", - " ------\n", - " U : (B, n_heads, N, N) interaction embedding\n", - " \"\"\"\n", - " B, four, N = x.shape\n", - " assert four == 4, \"input must have 4 features: E, px, py, pz\"\n", - "\n", - " # Split components\n", - " E, px, py, pz = x.unbind(dim=1) # each (B, N)\n", - "\n", - " # Basic kinematics ------------------------------------------------\n", - " pT = torch.sqrt(px**2 + py**2) + self.eps\n", - " phi = torch.atan2(py, px) # (−π, π]\n", - " num = (E + pz).clamp(min=self.eps) #need to avoid negative numbers\n", - " den = (E - pz).clamp(min=self.eps)\n", - " y = 0.5 * torch.log(num / den)\n", - "\n", - " # Expand to (B, N, N)\n", - " y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N)\n", - " phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)\n", - " pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)\n", - " E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)\n", - " px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)\n", - " py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)\n", - " pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)\n", - "\n", - " # ΔR, kT, z\n", - " delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps\n", - " kT = torch.minimum(pT_a, pT_b) * delta\n", - " z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps)\n", - "\n", - " # m² of pair\n", - " E_sum = E_a + E_b\n", - " px_sum = px_a + px_b\n", - " py_sum = py_a + py_b\n", - " pz_sum = pz_a + pz_b\n", - " m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps\n", - " m2 = torch.clamp(m2, min=self.eps) # avoid negatives\n", - "\n", - " # Stack → (B, 4, N, N)\n", - " feats = torch.stack([\n", - " torch.log(delta),\n", - " torch.log(kT),\n", - " torch.log(z),\n", - " torch.log(m2)\n", - " ], dim=1)\n", - "\n", - " # conv\n", - " U = self.net(feats) # (B, n_heads, N, N)\n", - " return U\n", - "\n", - "\n", - "\n", - "B, _, N = x.shape\n", - "n_heads = 2 # d′\n", - "enc = InteractionEncoder(n_heads=n_heads)\n", - "U = enc(x)\n", - "print(\"U.shape:\", U.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "sOjDJYr14zDr", - "outputId": "c062e171-de41-4c9f-c97a-5d639620b233" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "U.shape: torch.Size([1, 2, 8, 8])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Particle Transformer" - ], - "metadata": { - "id": "5CwQyoPYwaR7" - } - }, - { - "cell_type": "code", - "source": [ - "class ParticleTokenizer(nn.Module):\n", - " def __init__(self, in_dim=4, out_dim=6):\n", - " super().__init__()\n", - " self.proj = nn.Linear(in_dim, out_dim)\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\n", - " x: tensor of shape (B, n_particles, in_dim)\n", - " returns: (B, n_particles, out_dim)\n", - " \"\"\"\n", - " x = x.transpose(1, 2) # Input shape: (B, n_particles, in_dim) → (B, in_dim, n_particles)\n", - " return self.proj(x)\n", - "\n", - "tokenizer = ParticleTokenizer(4, 10)\n", - "output = tokenizer(x)\n", - "output.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ya5SvxTJwE3r", - "outputId": "6e347893-293b-4d76-d7d7-926c5382a670" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "execute_result", - "data": { - 
"text/plain": [ - "torch.Size([1, 8, 10])" - ] - }, - "metadata": {}, - "execution_count": 5 - } - ] - }, - { - "cell_type": "code", - "source": [ - "class MLP(nn.Module):\n", - " def __init__(self, dim, expansion=1, dropout=0.):\n", - " super().__init__()\n", - " hidden = dim * expansion\n", - " self.net = nn.Sequential(\n", - " nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout),\n", - " nn.Linear(hidden, dim), nn.Dropout(dropout)\n", - " )\n", - " def forward(self, x): return self.net(x)\n", - "\n", - "mlp = MLP(10, expansion=1, dropout=0.1)\n", - "output = mlp(output)\n", - "output.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "S5f8DTRY07PD", - "outputId": "e553e047-ead1-43da-eee4-001c715f7258" - }, - "execution_count": 6, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "torch.Size([1, 8, 10])" - ] - }, - "metadata": {}, - "execution_count": 6 - } - ] - }, - { - "cell_type": "code", - "source": [ - "class ParticleMHA(nn.Module):\n", - " \"\"\"\n", - " Multi-head self-attention with additive interaction bias U.\n", - "\n", - " Input\n", - " -----\n", - " x : (B, N, d) token / particle embeddings\n", - " U : (broadcast → B, H, N, N) or None\n", - "\n", - " Returns\n", - " -------\n", - " out : (B, N, d) attention output\n", - " attn_map : (B, H, N, N) attention weights (returned if\n", - " return_attn=True)\n", - " \"\"\"\n", - " def __init__(self, d: int, heads: int = 8,\n", - " dropout: float = 0.1, return_attn: bool = False):\n", - " super().__init__()\n", - " assert d % heads == 0, \"`d` must be divisible by `heads`\"\n", - "\n", - " self.d = d\n", - " self.h = heads\n", - " self.d_head = d // heads\n", - " self.scale = 1 / math.sqrt(self.d_head)\n", - " self.return_attn = return_attn\n", - "\n", - " # Projections\n", - " self.q = nn.Linear(d, d, bias=False)\n", - " self.k = nn.Linear(d, d, bias=False)\n", - " self.v = nn.Linear(d, d, bias=False)\n", - " self.o = nn.Linear(d, d, bias=False)\n", - "\n", - " self.drop = nn.Dropout(dropout)\n", - "\n", - " def _split(self, t: torch.Tensor):\n", - " # (B, N, d) -> (B, H, N, d_head)\n", - " B, N, _ = t.shape\n", - " return (\n", - " t.view(B, N, self.h, self.d_head) # (B, N, H, d_head)\n", - " .transpose(1, 2) # (B, H, N, d_head)\n", - " )\n", - "\n", - " def forward(self, x: torch.Tensor,\n", - " U: torch.Tensor | None = None):\n", - " B, N, _ = x.shape\n", - "\n", - " Q = self._split(self.q(x))\n", - " K = self._split(self.k(x))\n", - " V = self._split(self.v(x))\n", - "\n", - " logits = (Q @ K.transpose(-2, -1)) * self.scale # (B, H, N, N)\n", - "\n", - " if U is not None:\n", - " logits = logits + U\n", - "\n", - " attn = F.softmax(logits, dim=-1)\n", - " attn = self.drop(attn)\n", - "\n", - " context = attn @ V # (B, H, N, d_h)\n", - "\n", - " context = (\n", - " context.transpose(1, 2) # (B, N, H, d_h)\n", - " .contiguous()\n", - " .view(B, N, self.d) # (B, N, d)\n", - " )\n", - " out = self.o(context)\n", - "\n", - " if self.return_attn:\n", - " return out, attn # (B, N, d), (B, H, N, N)\n", - " else:\n", - " return out\n", - "\n", - "B, N, d = output.shape\n", - "U = torch.randn(1, 2, N, N) # broadcast to (B, H, N, N)\n", - "\n", - "pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)\n", - "output, A = pmha(output, U) # out: (B, N, d) A: (B, n_heads, N, N)\n", - "\n", - "print(output.shape, A.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zb5nYOsSyFRk", - "outputId": 
"0b34f9c7-b47b-4000-856e-e14ecf6c120b" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## transformer" - ], - "metadata": { - "id": "4650yyFtdjaR" - } - }, - { - "cell_type": "code", - "source": [ - "class MHA(nn.Module):\n", - " \"\"\"\n", - " Multi-head attention (batch_first) implemented explicitly.\n", - "\n", - " Args\n", - " ----\n", - " d_model : int embedding dim\n", - " n_heads : int\n", - " dropout: float\n", - " bias : bool use bias in projections\n", - " \"\"\"\n", - " def __init__(self, d_model: int, n_heads: int, dropout: float = 0., bias: bool = False):\n", - " super().__init__()\n", - " assert d_model % n_heads == 0, \"`d_model` must be divisible by `n_heads`\"\n", - " self.d_model = d_model\n", - " self.h = n_heads\n", - " self.d_head = d_model // n_heads\n", - " self.scale = self.d_head ** -0.5\n", - "\n", - " self.q_proj = nn.Linear(d_model, d_model, bias=bias)\n", - " self.k_proj = nn.Linear(d_model, d_model, bias=bias)\n", - " self.v_proj = nn.Linear(d_model, d_model, bias=bias)\n", - " self.o_proj = nn.Linear(d_model, d_model, bias=bias)\n", - "\n", - " self.drop = nn.Dropout(dropout)\n", - "\n", - " def _split_heads(self, x: torch.Tensor):\n", - " # (B, L, d_model) -> (B, h, L, d_head)\n", - " B, L, _ = x.shape\n", - " return x.view(B, L, self.h, self.d_head).transpose(1, 2)\n", - "\n", - " def _merge_heads(self, x: torch.Tensor):\n", - " # (B, h, L, d_head) -> (B, L, d_model)\n", - " B, H, L, Dh = x.shape\n", - " return x.transpose(1, 2).contiguous().view(B, L, H * Dh)\n", - "\n", - " def forward(\n", - " self,\n", - " q: torch.Tensor, # (B, Lq, d_model)\n", - " k: torch.Tensor, # (B, Lk, d_model)\n", - " v: torch.Tensor, # (B, Lk, d_model)\n", - " need_weights: bool = False\n", - " ):\n", - " B, Lq, _ = q.shape\n", - " _, Lk, _ = k.shape\n", - "\n", - " Q = self._split_heads(self.q_proj(q)) # (B,h,Lq,d_h)\n", - " K = self._split_heads(self.k_proj(k)) # (B,h,Lk,d_h)\n", - " V = self._split_heads(self.v_proj(v)) # (B,h,Lk,d_h)\n", - "\n", - " logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk)\n", - "\n", - " attn = F.softmax(logits, dim=-1)\n", - " attn = self.drop(attn)\n", - "\n", - " context = torch.matmul(attn, V) # (B,h,Lq,d_h)\n", - "\n", - " # merge heads + output proj\n", - " out = self.o_proj(self._merge_heads(context)) # (B,Lq,d_model)\n", - "\n", - " if need_weights:\n", - " avg_weights = attn.mean(dim=1) # (B,Lq,Lk)\n", - " return out, avg_weights\n", - " return out, None" - ], - "metadata": { - "id": "uHgMbZTo1jkU" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# Particle attention block (NormFormer style + U-bias)\n", - "class ParticleAttentionBlock(nn.Module):\n", - " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):\n", - " super().__init__()\n", - " self.ln1 = nn.LayerNorm(dim)\n", - " self.attn = ParticleMHA(dim, heads, dropout)\n", - " self.ln2 = nn.LayerNorm(dim)\n", - " self.mlp = MLP(dim, mlp_ratio, dropout)\n", - " def forward(self, x, U):\n", - " x = x + self.attn(self.ln1(x), U) # bias-aware MHSA\n", - " x = x + self.mlp(self.ln2(x)) # feed-forward\n", - " return x\n", - "\n", - "# Class attention block (CaiT style, no U)\n", - "class ClassAttentionBlock(nn.Module):\n", - " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):\n", - " super().__init__()\n", - " self.ln1 = 
nn.LayerNorm(dim)\n", - " self.attn = MHA(dim, heads, dropout)\n", - " self.ln2 = nn.LayerNorm(dim)\n", - " self.mlp = MLP(dim, mlp_ratio, dropout)\n", - " def forward(self, tokens, cls): # tokens: (B,N,d), cls: (B,1,d)\n", - " z = torch.cat([cls, tokens], dim=1) # (B,1+N,d)\n", - " q = self.ln1(cls)\n", - " kv = self.ln1(z)\n", - " cls = cls + self.attn(q, kv, kv, need_weights=False)[0]\n", - " cls = cls + self.mlp(self.ln2(cls))\n", - " return cls # (B,1,d)\n", - "\n", - "# Complete Particle Transformer\n", - "class ParT(nn.Module):\n", - " def __init__(self,\n", - " in_dim=4, # (E,px,py,pz)\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2, # particle blocks\n", - " class_depth=2, # class-attention blocks\n", - " mlp_ratio=4,\n", - " num_classes=10,\n", - " dropout=0.1):\n", - " super().__init__()\n", - "\n", - " self.tokenizer = ParticleTokenizer(in_dim, embed_dim)\n", - " self.U_encoder = InteractionEncoder(n_heads=n_heads)\n", - "\n", - " self.blocks = nn.ModuleList([\n", - " ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)\n", - " for _ in range(depth)\n", - " ])\n", - "\n", - " self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", - " self.cls_blocks = nn.ModuleList([\n", - " ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)\n", - " for _ in range(class_depth)\n", - " ])\n", - "\n", - " self.head = nn.Linear(embed_dim, num_classes)\n", - "\n", - " nn.init.trunc_normal_(self.class_token, std=0.02)\n", - " nn.init.trunc_normal_(self.head.weight, std=0.02)\n", - " nn.init.zeros_(self.head.bias)\n", - "\n", - " def forward(self, x): # x: (B,4,N)\n", - " B, _, N = x.shape\n", - "\n", - " tokens = self.tokenizer(x) # (B,N,d)\n", - " U = self.U_encoder(x) # (B,H,N,N)\n", - "\n", - " for blk in self.blocks:\n", - " tokens = blk(tokens, U) # (B,N,d)\n", - "\n", - " cls = self.class_token.expand(B, -1, -1) # (B,1,d)\n", - " for blk in self.cls_blocks:\n", - " cls = blk(tokens, cls) # (B,1,d)\n", - "\n", - " logits = self.head(cls.squeeze(1)) # (B,10)\n", - " return logits" - ], - "metadata": { - "id": "iQveEBhldizq" - }, - "execution_count": 15, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "B, _, N = x_batch.shape # (3,4,8)\n", - "model = ParT(in_dim=4,\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10)\n", - "\n", - "logits = model(x_batch) # forward pass\n", - "print(\"logits:\", logits.shape) # -> torch.Size([3, 10])\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0-g5CQ2hfh2x", - "outputId": "ef81a92c-0aa7-4a94-8c19-72ff3d20a178" - }, - "execution_count": 16, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "logits: torch.Size([3, 10])\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "x_train = x_batch # (3, 4, 8)\n", - "y_train = torch.tensor([0, 1, 2]) # dummy class labels for testing\n", - "\n", - "model = ParT(in_dim=4,\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10)\n", - "\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", - "\n", - "n_epochs = 250\n", - "for epoch in range(n_epochs):\n", - " model.train()\n", - " logits = model(x_train) # (3, 10)\n", - " loss = criterion(logits, y_train)\n", - "\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # print every 50 epochs\n", - " if (epoch+1) % 50 == 0 or epoch == 0:\n", - " preds 
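One property of the forward pass above worth noting: nothing in `ParT` fixes the particle multiplicity. The tokenizer acts per particle, the interaction encoder uses 1x1 convolutions over the NxN pair grid, and the class token performs the readout, so the same weights accept any N (a sketch, assuming `model` from the cell above is in scope):

```python
import torch

x8, x16 = torch.randn(2, 4, 8), torch.randn(2, 4, 16)   # 8 vs 16 particles
print(model(x8).shape, model(x16).shape)                 # both torch.Size([2, 10])
```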
= logits.argmax(1)\n", - " acc = (preds == y_train).float().mean().item()\n", - " print(f\"epoch {epoch+1:3d} loss {loss.item():.4f} acc {acc:.3f}\")\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " probs = torch.softmax(model(x_train), dim=1)\n", - "print(\"softmax-probs\\n\", probs)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mOr6fyLvfjgv", - "outputId": "093278c3-6069-4f62-d8b9-efbb3db8caac" - }, - "execution_count": 18, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "epoch 1 loss 2.3060 acc 0.000\n", - "epoch 50 loss 1.1657 acc 0.333\n", - "epoch 100 loss 1.0977 acc 0.667\n", - "epoch 150 loss 0.8846 acc 1.000\n", - "epoch 200 loss 0.0219 acc 1.000\n", - "epoch 250 loss 0.0040 acc 1.000\n", - "softmax-probs\n", - " tensor([[9.9536e-01, 2.0938e-04, 3.2167e-03, 1.9185e-04, 1.4622e-04, 1.9616e-04,\n", - " 1.6393e-04, 1.4288e-04, 2.2374e-04, 1.5239e-04],\n", - " [1.3350e-04, 9.9691e-01, 2.1701e-03, 1.5398e-04, 1.0111e-04, 1.5174e-04,\n", - " 7.6620e-05, 1.2269e-04, 7.1940e-05, 1.1269e-04],\n", - " [1.4062e-03, 1.9185e-03, 9.9667e-01, 1.0295e-06, 7.4998e-07, 8.4284e-07,\n", - " 6.6928e-07, 5.7153e-07, 5.3092e-07, 3.6395e-07]])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load official Data" - ], - "metadata": { - "id": "Z7bm2U4EYwY1" - } - }, - { - "cell_type": "code", - "source": [ - "# 1) Clone repo\n", - "!git clone https://github.com/jet-universe/particle_transformer.git\n", - "!cd particle_transformer\n", - "!cd /content/particle_transformer\n", - "!touch env.sh\n", - "!chmod +x env.sh" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ES5wefiuY5ca", - "outputId": "29c76f98-c96e-4d03-f009-2cc7f87cf313" - }, - "execution_count": 19, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'particle_transformer'...\n", - "remote: Enumerating objects: 101, done.\u001b[K\n", - "remote: Counting objects: 100% (52/52), done.\u001b[K\n", - "remote: Compressing objects: 100% (25/25), done.\u001b[K\n", - "remote: Total 101 (delta 38), reused 27 (delta 27), pack-reused 49 (from 1)\u001b[K\n", - "Receiving objects: 100% (101/101), 28.08 MiB | 12.26 MiB/s, done.\n", - "Resolving deltas: 100% (46/46), done.\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!/content/particle_transformer/get_datasets.py JetClass -d ./datasets\n", - "!source env.sh\n", - "import os, glob, tarfile\n", - "os.environ['DATADIR_JetClass'] = os.path.abspath('./datasets/JetClass')\n", - "data_dir = os.environ['DATADIR_JetClass']\n", - "!pip install awkward uproot vector\n", - "from particle_transformer.dataloader import read_file\n", - "\n", - "# Path to the one and only thing downloaded\n", - "tar_path = \"/content/datasets/JetClass/JetClass_Pythia_val_5M.tar\"\n", - "extract_dir = \"/content/datasets/JetClass/JetClass_Pythia_val_5M\"\n", - "os.makedirs(extract_dir, exist_ok=True)" - ], - "metadata": { - "id": "5IVg4AkAK1xz", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "070415ca-05ae-4e98-b1c2-50116ff2e38c" - }, - "execution_count": 20, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://zenodo.org/record/6619768/files/JetClass_Pythia_val_5M.tar to ./datasets/JetClass/JetClass_Pythia_val_5M.tar\n", - "./datasets/JetClass/JetClass_Pythia_val_5M.tar: 100% 7.07G/7.07G [06:36<00:00, 19.1MiB/s]\n", - "Updated dataset path in env.sh to 
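A Colab detail worth flagging about the clone cell above: every `!` command runs in its own subshell, so neither `!cd` nor `!source env.sh` affects the notebook process, which is why the cell sets `DATADIR_JetClass` through `os.environ` by hand instead of relying on `source`. The persistent equivalents are `%cd` or `os.chdir`:

```python
import os

os.chdir("/content/particle_transformer")   # persists, unlike `!cd`
print(os.getcwd())
```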
\"DATADIR_JetClass=./datasets/JetClass\".\n", - "Collecting awkward\n", - " Downloading awkward-2.8.5-py3-none-any.whl.metadata (6.9 kB)\n", - "Collecting uproot\n", - " Downloading uproot-5.6.3-py3-none-any.whl.metadata (33 kB)\n", - "Collecting vector\n", - " Downloading vector-1.6.3-py3-none-any.whl.metadata (16 kB)\n", - "Collecting awkward-cpp==47 (from awkward)\n", - " Downloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (2.1 kB)\n", - "Requirement already satisfied: fsspec>=2022.11.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (2025.3.0)\n", - "Requirement already satisfied: importlib-metadata>=4.13.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (8.7.0)\n", - "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (2.0.2)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from awkward) (25.0)\n", - "Requirement already satisfied: cramjam>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from uproot) (2.10.0)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from uproot) (3.5.0)\n", - "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata>=4.13.0->awkward) (3.23.0)\n", - "Downloading awkward-2.8.5-py3-none-any.whl (886 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m886.8/886.8 kB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (638 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m638.8/638.8 kB\u001b[0m \u001b[31m44.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading uproot-5.6.3-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.8/382.8 kB\u001b[0m \u001b[31m29.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading vector-1.6.3-py3-none-any.whl (179 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.6/179.6 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: vector, awkward-cpp, awkward, uproot\n", - "Successfully installed awkward-2.8.5 awkward-cpp-47 uproot-5.6.3 vector-1.6.3\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "if not any(fname.endswith(\".root\") for fname in os.listdir(extract_dir)):\n", - " print(\"⏬ extracting test-set…\")\n", - " with tarfile.open(tar_path) as tar:\n", - " tar.extractall(path=extract_dir)\n", - "\n", - "# Point glob at the real ROOT files\n", - "pattern = os.path.join(extract_dir, 'val_5M', \"*.root\")\n", - "files = sorted(glob.glob(pattern))\n", - "print(f\"Found {len(files)} ROOT files\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4QtS-37tY8wa", - "outputId": "19c7abc3-375b-415a-dd20-08464323b74d" - }, - "execution_count": 21, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "⏬ extracting test-set…\n", - "Found 50 ROOT files\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from torch.utils.data import Dataset\n", - "all_x_parts = []\n", - "all_ys = []\n", - "\n", - "num_file = 1\n", - "for file in files:\n", - " num_file += 1\n", - " if num_file % 5 == 0:\n", 
- " x_part, x_jets, y = read_file(\n", - " file,\n", - " max_num_particles=8,\n", - " particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],\n", - " jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],\n", - " labels=[\n", - " 'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',\n", - " 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',\n", - " ]\n", - " )\n", - " all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:100,:,:])\n", - " all_ys.append(torch.tensor(y, dtype=torch.float32)[:100,:])\n", - "\n", - "x_all = torch.cat(all_x_parts, dim=0)\n", - "y_all = torch.cat(all_ys, dim=0)\n", - "print(x_all.shape, y_all.shape)\n", - "\n", - "class JetDataset(Dataset):\n", - " def __init__(self, x, y):\n", - " self.x = x\n", - " self.y = y\n", - "\n", - " def __len__(self):\n", - " return self.x.shape[0]\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.x[idx], self.y[idx]\n", - "\n", - "dataset = JetDataset(x_all, y_all)\n", - "\n", - "from torch.utils.data import DataLoader\n", - "dataloader = DataLoader(dataset, batch_size=64, shuffle=True)" - ], - "metadata": { - "id": "-q3rmopucyqr", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1091b217-02ad-4d0e-83b1-a73e231d8426" - }, - "execution_count": 22, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([1000, 4, 8]) torch.Size([1000, 10])\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "model = ParT(\n", - " in_dim=4, # part_pt, eta, phi, energy\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10\n", - ")\n", - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "model = model.to(device)" - ], - "metadata": { - "id": "4bNN4pfUfldh" - }, - "execution_count": 28, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "loss_fn = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)" - ], - "metadata": { - "id": "cGmJ9BJbfzL5" - }, - "execution_count": 29, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "num_epochs = 100\n", - "\n", - "for epoch in range(num_epochs):\n", - " model.train()\n", - " epoch_loss = 0.0\n", - " for batch_idx, (x, y) in enumerate(dataloader):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(x) # shape [batch, 10]\n", - "\n", - " loss = loss_fn(outputs, y)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " epoch_loss += loss.item()\n", - "\n", - " avg_loss = epoch_loss / len(dataloader)\n", - " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "jJZR1RgvgqXZ", - "outputId": "e9b7efdc-1cb0-4a3f-cb23-ca54004f0304" - }, - "execution_count": 30, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch 1/100, Loss: 2.3047\n", - "Epoch 2/100, Loss: 2.3033\n", - "Epoch 3/100, Loss: 2.3028\n", - "Epoch 4/100, Loss: 2.3024\n", - "Epoch 5/100, Loss: 2.3011\n", - "Epoch 6/100, Loss: 2.2968\n", - "Epoch 7/100, Loss: 2.2857\n", - "Epoch 8/100, Loss: 2.2636\n", - "Epoch 9/100, Loss: 2.2503\n", - "Epoch 10/100, Loss: 2.2394\n", - "Epoch 11/100, Loss: 2.2311\n", - "Epoch 12/100, Loss: 2.2308\n", - "Epoch 13/100, Loss: 2.2296\n", - "Epoch 14/100, Loss: 2.2269\n", - "Epoch 15/100, Loss: 2.2260\n", - "Epoch 16/100, Loss: 2.2261\n", - "Epoch 17/100, Loss: 
2.2195\n", - "Epoch 18/100, Loss: 2.2221\n", - "Epoch 19/100, Loss: 2.2174\n", - "Epoch 20/100, Loss: 2.2180\n", - "Epoch 21/100, Loss: 2.2088\n", - "Epoch 22/100, Loss: 2.2136\n", - "Epoch 23/100, Loss: 2.2127\n", - "Epoch 24/100, Loss: 2.2058\n", - "Epoch 25/100, Loss: 2.2002\n", - "Epoch 26/100, Loss: 2.1939\n", - "Epoch 27/100, Loss: 2.1861\n", - "Epoch 28/100, Loss: 2.1853\n", - "Epoch 29/100, Loss: 2.1678\n", - "Epoch 30/100, Loss: 2.1570\n", - "Epoch 31/100, Loss: 2.1857\n", - "Epoch 32/100, Loss: 2.1693\n", - "Epoch 33/100, Loss: 2.1595\n", - "Epoch 34/100, Loss: 2.1459\n", - "Epoch 35/100, Loss: 2.1359\n", - "Epoch 36/100, Loss: 2.1423\n", - "Epoch 37/100, Loss: 2.1434\n", - "Epoch 38/100, Loss: 2.1444\n", - "Epoch 39/100, Loss: 2.1384\n", - "Epoch 40/100, Loss: 2.1368\n", - "Epoch 41/100, Loss: 2.1260\n", - "Epoch 42/100, Loss: 2.1322\n", - "Epoch 43/100, Loss: 2.1340\n", - "Epoch 44/100, Loss: 2.1186\n", - "Epoch 45/100, Loss: 2.1181\n", - "Epoch 46/100, Loss: 2.1201\n", - "Epoch 47/100, Loss: 2.1211\n", - "Epoch 48/100, Loss: 2.1304\n", - "Epoch 49/100, Loss: 2.1275\n", - "Epoch 50/100, Loss: 2.1129\n", - "Epoch 51/100, Loss: 2.1206\n", - "Epoch 52/100, Loss: 2.1085\n", - "Epoch 53/100, Loss: 2.1188\n", - "Epoch 54/100, Loss: 2.1111\n", - "Epoch 55/100, Loss: 2.1189\n", - "Epoch 56/100, Loss: 2.1186\n", - "Epoch 57/100, Loss: 2.1015\n", - "Epoch 58/100, Loss: 2.1065\n", - "Epoch 59/100, Loss: 2.1056\n", - "Epoch 60/100, Loss: 2.0990\n", - "Epoch 61/100, Loss: 2.1043\n", - "Epoch 62/100, Loss: 2.1118\n", - "Epoch 63/100, Loss: 2.1060\n", - "Epoch 64/100, Loss: 2.1024\n", - "Epoch 65/100, Loss: 2.1093\n", - "Epoch 66/100, Loss: 2.1084\n", - "Epoch 67/100, Loss: 2.0977\n", - "Epoch 68/100, Loss: 2.1020\n", - "Epoch 69/100, Loss: 2.0936\n", - "Epoch 70/100, Loss: 2.0856\n", - "Epoch 71/100, Loss: 2.0964\n", - "Epoch 72/100, Loss: 2.1131\n", - "Epoch 73/100, Loss: 2.0991\n", - "Epoch 74/100, Loss: 2.1072\n", - "Epoch 75/100, Loss: 2.0974\n", - "Epoch 76/100, Loss: 2.0862\n", - "Epoch 77/100, Loss: 2.0810\n", - "Epoch 78/100, Loss: 2.0973\n", - "Epoch 79/100, Loss: 2.0945\n", - "Epoch 80/100, Loss: 2.0898\n", - "Epoch 81/100, Loss: 2.0858\n", - "Epoch 82/100, Loss: 2.0721\n", - "Epoch 83/100, Loss: 2.0817\n", - "Epoch 84/100, Loss: 2.0805\n", - "Epoch 85/100, Loss: 2.0800\n", - "Epoch 86/100, Loss: 2.0816\n", - "Epoch 87/100, Loss: 2.0824\n", - "Epoch 88/100, Loss: 2.0869\n", - "Epoch 89/100, Loss: 2.0843\n", - "Epoch 90/100, Loss: 2.0825\n", - "Epoch 91/100, Loss: 2.0839\n", - "Epoch 92/100, Loss: 2.0743\n", - "Epoch 93/100, Loss: 2.0742\n", - "Epoch 94/100, Loss: 2.0663\n", - "Epoch 95/100, Loss: 2.0720\n", - "Epoch 96/100, Loss: 2.0858\n", - "Epoch 97/100, Loss: 2.0731\n", - "Epoch 98/100, Loss: 2.0816\n", - "Epoch 99/100, Loss: 2.0783\n", - "Epoch 100/100, Loss: 2.0721\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from torch.nn.functional import sigmoid, softmax\n", - "\n", - "model.eval()\n", - "correct = 0\n", - "total = 0\n", - "\n", - "with torch.no_grad():\n", - " for x, y in dataloader:\n", - " x, y = x.to(device), y.to(device)\n", - " outputs = model(x)\n", - " labels = torch.argmax(y, dim=1) # convert one-hot to class id\n", - " preds = torch.argmax(outputs, dim=1) # predicted class\n", - " correct += (preds == labels).sum().item()\n", - " total += y.size(0)\n", - "accuracy = correct / total\n", - "print(f\"Accuracy on full dataset: {accuracy:.4f}\")\n" - ], - "metadata": { - "id": "ogHJm4wFhBMN", - "colab": { - "base_uri": "https://localhost:8080/" - }, 
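Two caveats on the evaluation here. First, both the accuracy above and the AUC cell below are computed on the same 1000 jets the model was trained on, so they measure fit rather than generalisation; a held-out split is cheap to add (a sketch with `random_split`, assuming `dataset` from the loading cell). Second, since the model was trained with `CrossEntropyLoss`, `softmax` is the scoring consistent with the objective; the `sigmoid` used below is harmless for per-class one-vs-rest AUC (it is a monotone map of each logit) but is really the `BCEWithLogitsLoss` convention.

```python
from torch.utils.data import DataLoader, random_split

n_val = len(dataset) // 5                        # hold out 20%
train_ds, val_ds = random_split(dataset, [len(dataset) - n_val, n_val])
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)
```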
- "outputId": "801b47a5-d2db-4e0f-c07e-5028d0dd048f" - }, - "execution_count": 31, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Accuracy on full dataset: 0.2280\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from sklearn.metrics import roc_auc_score\n", - "import numpy as np\n", - "\n", - "model.eval()\n", - "all_outputs = []\n", - "all_targets = []\n", - "\n", - "with torch.no_grad():\n", - " for x, y in dataloader:\n", - " x, y = x.to(device), y.to(device)\n", - " outputs = model(x)\n", - " all_outputs.append(outputs.cpu())\n", - " all_targets.append(y.cpu())\n", - "\n", - "# Concatenate batches\n", - "all_outputs = torch.cat(all_outputs, dim=0)\n", - "all_targets = torch.cat(all_targets, dim=0)\n", - "\n", - "# Apply sigmoid (if using BCEWithLogitsLoss)\n", - "probs = sigmoid(all_outputs).numpy() # shape: (N, C)\n", - "true = all_targets.numpy() # shape: (N, C)\n", - "\n", - "# Compute AUC for each class and average\n", - "try:\n", - " auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')\n", - " print(f\"Macro-Averaged AUC: {auc_macro:.4f}\")\n", - "except ValueError as e:\n", - " print(\"AUC could not be computed:\", e)\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "6ehi3CwmaDhL", - "outputId": "657e3d47-e294-4f45-fc83-2cb5e6fad282" - }, - "execution_count": 32, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Macro-Averaged AUC: 0.6726\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "TPgU9VB2ajLT" - }, - "execution_count": 32, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_components.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_components.ipynb new file mode 100644 index 0000000..074cacd --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/ParT_components.ipynb @@ -0,0 +1,706 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "smpEEHlypz0p" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PAgmRCdTvyg6" + }, + "source": [ + "## Toy Data" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "id": "2jsmzKcFvx4L" + }, + "outputs": [], + "source": [ + "x = torch.Tensor([[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", + " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", + " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", + " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", + " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", + " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", + " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", + " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]])\n", + "\n", + "x_batch = torch.Tensor([[[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", + " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", + " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", + " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", + " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", + " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", + " [ 
8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", + " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]],\n", + "\n", + " [[ 8.2484528e+01, 5.2682617e+01, 5.1243843e+01, 3.6217686e+01,\n", + " 2.8948278e+01, 2.6579512e+01, 2.1946012e+01, 2.1011120e+01],\n", + " [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,\n", + " -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],\n", + " [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,\n", + " -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],\n", + " [ 9.0437065e+01, 7.4070679e+01, 5.6495895e+01, 4.3069641e+01,\n", + " 3.2367134e+01, 3.3371326e+01, 2.6115334e+01, 2.4602446e+01]],\n", + "\n", + " [[ 8.6492935e+01, 7.0192978e+01, 5.8423912e+01, 5.6638733e+01,\n", + " 4.9270725e+01, 4.1237038e+01, 3.6133625e+01, 3.5519596e+01],\n", + " [ 1.4010678e-01, 2.7912292e-01, 1.4376265e-01, 3.4672296e-01,\n", + " 3.4966472e-01, 1.0524009e-01, 1.2958543e-01, 3.3264065e-01],\n", + " [ 1.9334941e+00, 1.6967584e+00, 1.9219695e+00, 1.6735281e+00,\n", + " 1.6587850e+00, 1.8386338e+00, 1.9120301e+00, 1.6680365e+00],\n", + " [ 8.7343246e+01, 7.2945129e+01, 5.9030762e+01, 6.0084766e+01,\n", + " 5.2313778e+01, 4.1465607e+01, 3.6440781e+01, 3.7506149e+01]]])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Po7Fb1v37EJT", + "outputId": "5a9e6be0-869f-4811-c3b7-4d372bc69462" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = x.unsqueeze(0) # batch dimension\n", + "x.shape, x_batch.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X6d5fzsV4kt7" + }, + "source": [ + "## Toy interaction matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sOjDJYr14zDr", + "outputId": "c062e171-de41-4c9f-c97a-5d639620b233" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "U.shape: torch.Size([1, 2, 8, 8])\n" + ] + } + ], + "source": [ + "class InteractionEncoder(nn.Module):\n", + " \"\"\"\n", + " ParT interaction-feature encoder.\n", + "\n", + " Args\n", + " ----\n", + " n_heads per mhsa: output channels d′\n", + " hidden_channels : list[int] for intermediate 1×1 conv layers\n", + " eps : numerical guard for log\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " n_heads: int = 8,\n", + " hidden_channels: list[int] = (64, 64, 64),\n", + " eps: float = 1e-8):\n", + " super().__init__()\n", + " self.eps = eps\n", + "\n", + " layers: list[nn.Module] = []\n", + " in_ch = 4 # lnΔ, ln kT, ln z, ln m²\n", + " for h in hidden_channels:\n", + " layers += [\n", + " nn.Conv2d(in_ch, h, 1, bias=False),\n", + " nn.BatchNorm2d(h),\n", + " nn.GELU()\n", + " ]\n", + " in_ch = h\n", + " layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))\n", + " self.net = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " x : (B, 4, N) where the 4 dims are (E, px, py, pz)\n", + " returns\n", + " ------\n", + " U : (B, n_heads, N, N) interaction embedding\n", + " \"\"\"\n", + " B, four, N = x.shape\n", + " assert four == 4, \"input must have 4 features: E, px, py, pz\"\n", + "\n", + " # Split components\n", + " E, px, py, pz = x.unbind(dim=1) # each (B, N)\n", 
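+ "\n",
+ "        # Interaction features built below (cf. the ParT paper):\n",
+ "        #   y     = 0.5*ln((E+pz)/(E-pz))                rapidity, eps-guarded\n",
+ "        #   delta = sqrt((y_a-y_b)^2 + (phi_a-phi_b)^2)  pairwise distance\n",
+ "        #   kT    = min(pT_a, pT_b) * delta\n",
+ "        #   z     = min(pT_a, pT_b) / (pT_a + pT_b)\n",
+ "        #   m2    = (E_a+E_b)^2 - |p_a+p_b|^2            pair invariant mass\n",
+ "        # NB: phi differences are taken raw, not wrapped into (-pi, pi],\n",
+ "        # so pairs straddling the phi = +/-pi boundary get an inflated delta.\n",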
+ "\n", + " # Basic kinematics ------------------------------------------------\n", + " pT = torch.sqrt(px**2 + py**2) + self.eps\n", + " phi = torch.atan2(py, px) # (−π, π]\n", + " num = (E + pz).clamp(min=self.eps) #need to avoid negative numbers\n", + " den = (E - pz).clamp(min=self.eps)\n", + " y = 0.5 * torch.log(num / den)\n", + "\n", + " # Expand to (B, N, N)\n", + " y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N)\n", + " phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)\n", + " pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)\n", + " E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)\n", + " px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)\n", + " py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)\n", + " pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)\n", + "\n", + " # ΔR, kT, z\n", + " delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps\n", + " kT = torch.minimum(pT_a, pT_b) * delta\n", + " z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps)\n", + "\n", + " # m² of pair\n", + " E_sum = E_a + E_b\n", + " px_sum = px_a + px_b\n", + " py_sum = py_a + py_b\n", + " pz_sum = pz_a + pz_b\n", + " m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps\n", + " m2 = torch.clamp(m2, min=self.eps) # avoid negatives\n", + "\n", + " # Stack → (B, 4, N, N)\n", + " feats = torch.stack([\n", + " torch.log(delta),\n", + " torch.log(kT),\n", + " torch.log(z),\n", + " torch.log(m2)\n", + " ], dim=1)\n", + "\n", + " # conv\n", + " U = self.net(feats) # (B, n_heads, N, N)\n", + " return U\n", + "\n", + "\n", + "\n", + "B, _, N = x.shape\n", + "n_heads = 2 # d′\n", + "enc = InteractionEncoder(n_heads=n_heads)\n", + "U = enc(x)\n", + "print(\"U.shape:\", U.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5CwQyoPYwaR7" + }, + "source": [ + "## Particle Transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ya5SvxTJwE3r", + "outputId": "6e347893-293b-4d76-d7d7-926c5382a670" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 10])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class ParticleTokenizer(nn.Module):\n", + " def __init__(self, in_dim=4, out_dim=6):\n", + " super().__init__()\n", + " self.proj = nn.Linear(in_dim, out_dim)\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " x: tensor of shape (B, n_particles, in_dim)\n", + " returns: (B, n_particles, out_dim)\n", + " \"\"\"\n", + " x = x.transpose(1, 2) # Input shape: (B, n_particles, in_dim) → (B, in_dim, n_particles)\n", + " return self.proj(x)\n", + "\n", + "tokenizer = ParticleTokenizer(4, 10)\n", + "output = tokenizer(x)\n", + "output.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S5f8DTRY07PD", + "outputId": "e553e047-ead1-43da-eee4-001c715f7258" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 10])" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class MLP(nn.Module):\n", + " def __init__(self, dim, expansion=1, dropout=0.):\n", + " super().__init__()\n", + " hidden = dim * expansion\n", + " self.net = nn.Sequential(\n", + " nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout),\n", + " nn.Linear(hidden, dim), nn.Dropout(dropout)\n", + " )\n", + " def forward(self, x): return self.net(x)\n", + "\n", + "mlp = 
MLP(10, expansion=1, dropout=0.1)\n", + "output = mlp(output)\n", + "output.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zb5nYOsSyFRk", + "outputId": "0b34f9c7-b47b-4000-856e-e14ecf6c120b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])\n" + ] + } + ], + "source": [ + "class ParticleMHA(nn.Module):\n", + " \"\"\"\n", + " Multi-head self-attention with additive interaction bias U.\n", + "\n", + " Input\n", + " -----\n", + " x : (B, N, d) token / particle embeddings\n", + " U : (broadcast → B, H, N, N) or None\n", + "\n", + " Returns\n", + " -------\n", + " out : (B, N, d) attention output\n", + " attn_map : (B, H, N, N) attention weights (returned if\n", + " return_attn=True)\n", + " \"\"\"\n", + " def __init__(self, d: int, heads: int = 8,\n", + " dropout: float = 0.1, return_attn: bool = False):\n", + " super().__init__()\n", + " assert d % heads == 0, \"`d` must be divisible by `heads`\"\n", + "\n", + " self.d = d\n", + " self.h = heads\n", + " self.d_head = d // heads\n", + " self.scale = 1 / math.sqrt(self.d_head)\n", + " self.return_attn = return_attn\n", + "\n", + " # Projections\n", + " self.q = nn.Linear(d, d, bias=False)\n", + " self.k = nn.Linear(d, d, bias=False)\n", + " self.v = nn.Linear(d, d, bias=False)\n", + " self.o = nn.Linear(d, d, bias=False)\n", + "\n", + " self.drop = nn.Dropout(dropout)\n", + "\n", + " def _split(self, t: torch.Tensor):\n", + " # (B, N, d) -> (B, H, N, d_head)\n", + " B, N, _ = t.shape\n", + " return (\n", + " t.view(B, N, self.h, self.d_head) # (B, N, H, d_head)\n", + " .transpose(1, 2) # (B, H, N, d_head)\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor,\n", + " U: torch.Tensor | None = None):\n", + " B, N, _ = x.shape\n", + "\n", + " Q = self._split(self.q(x))\n", + " K = self._split(self.k(x))\n", + " V = self._split(self.v(x))\n", + "\n", + " logits = (Q @ K.transpose(-2, -1)) * self.scale # (B, H, N, N)\n", + "\n", + " if U is not None:\n", + " logits = logits + U\n", + "\n", + " attn = F.softmax(logits, dim=-1)\n", + " attn = self.drop(attn)\n", + "\n", + " context = attn @ V # (B, H, N, d_h)\n", + "\n", + " context = (\n", + " context.transpose(1, 2) # (B, N, H, d_h)\n", + " .contiguous()\n", + " .view(B, N, self.d) # (B, N, d)\n", + " )\n", + " out = self.o(context)\n", + "\n", + " if self.return_attn:\n", + " return out, attn # (B, N, d), (B, H, N, N)\n", + " else:\n", + " return out\n", + "\n", + "B, N, d = output.shape\n", + "U = torch.randn(1, 2, N, N) # broadcast to (B, H, N, N)\n", + "\n", + "pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)\n", + "output, A = pmha(output, U) # out: (B, N, d) A: (B, n_heads, N, N)\n", + "\n", + "print(output.shape, A.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4650yyFtdjaR" + }, + "source": [ + "## transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "id": "uHgMbZTo1jkU" + }, + "outputs": [], + "source": [ + "class MHA(nn.Module):\n", + " \"\"\"\n", + " Multi-head attention (batch_first) implemented explicitly.\n", + "\n", + " Args\n", + " ----\n", + " d_model : int embedding dim\n", + " n_heads : int\n", + " dropout: float\n", + " bias : bool use bias in projections\n", + " \"\"\"\n", + " def __init__(self, d_model: int, n_heads: int, dropout: float = 0., bias: bool = False):\n", + " super().__init__()\n", + " assert 
d_model % n_heads == 0, \"`d_model` must be divisible by `n_heads`\"\n", + " self.d_model = d_model\n", + " self.h = n_heads\n", + " self.d_head = d_model // n_heads\n", + " self.scale = self.d_head ** -0.5\n", + "\n", + " self.q_proj = nn.Linear(d_model, d_model, bias=bias)\n", + " self.k_proj = nn.Linear(d_model, d_model, bias=bias)\n", + " self.v_proj = nn.Linear(d_model, d_model, bias=bias)\n", + " self.o_proj = nn.Linear(d_model, d_model, bias=bias)\n", + "\n", + " self.drop = nn.Dropout(dropout)\n", + "\n", + " def _split_heads(self, x: torch.Tensor):\n", + " # (B, L, d_model) -> (B, h, L, d_head)\n", + " B, L, _ = x.shape\n", + " return x.view(B, L, self.h, self.d_head).transpose(1, 2)\n", + "\n", + " def _merge_heads(self, x: torch.Tensor):\n", + " # (B, h, L, d_head) -> (B, L, d_model)\n", + " B, H, L, Dh = x.shape\n", + " return x.transpose(1, 2).contiguous().view(B, L, H * Dh)\n", + "\n", + " def forward(\n", + " self,\n", + " q: torch.Tensor, # (B, Lq, d_model)\n", + " k: torch.Tensor, # (B, Lk, d_model)\n", + " v: torch.Tensor, # (B, Lk, d_model)\n", + " need_weights: bool = False\n", + " ):\n", + " B, Lq, _ = q.shape\n", + " _, Lk, _ = k.shape\n", + "\n", + " Q = self._split_heads(self.q_proj(q)) # (B,h,Lq,d_h)\n", + " K = self._split_heads(self.k_proj(k)) # (B,h,Lk,d_h)\n", + " V = self._split_heads(self.v_proj(v)) # (B,h,Lk,d_h)\n", + "\n", + " logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk)\n", + "\n", + " attn = F.softmax(logits, dim=-1)\n", + " attn = self.drop(attn)\n", + "\n", + " context = torch.matmul(attn, V) # (B,h,Lq,d_h)\n", + "\n", + " # merge heads + output proj\n", + " out = self.o_proj(self._merge_heads(context)) # (B,Lq,d_model)\n", + "\n", + " if need_weights:\n", + " avg_weights = attn.mean(dim=1) # (B,Lq,Lk)\n", + " return out, avg_weights\n", + " return out, None" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "id": "iQveEBhldizq" + }, + "outputs": [], + "source": [ + "# Particle attention block (NormFormer style + U-bias)\n", + "class ParticleAttentionBlock(nn.Module):\n", + " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):\n", + " super().__init__()\n", + " self.ln1 = nn.LayerNorm(dim)\n", + " self.attn = ParticleMHA(dim, heads, dropout)\n", + " self.ln2 = nn.LayerNorm(dim)\n", + " self.mlp = MLP(dim, mlp_ratio, dropout)\n", + " def forward(self, x, U):\n", + " x = x + self.attn(self.ln1(x), U) # bias-aware MHSA\n", + " x = x + self.mlp(self.ln2(x)) # feed-forward\n", + " return x\n", + "\n", + "# Class attention block (CaiT style, no U)\n", + "class ClassAttentionBlock(nn.Module):\n", + " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):\n", + " super().__init__()\n", + " self.ln1 = nn.LayerNorm(dim)\n", + " self.attn = MHA(dim, heads, dropout)\n", + " self.ln2 = nn.LayerNorm(dim)\n", + " self.mlp = MLP(dim, mlp_ratio, dropout)\n", + " def forward(self, tokens, cls): # tokens: (B,N,d), cls: (B,1,d)\n", + " z = torch.cat([cls, tokens], dim=1) # (B,1+N,d)\n", + " q = self.ln1(cls)\n", + " kv = self.ln1(z)\n", + " cls = cls + self.attn(q, kv, kv, need_weights=False)[0]\n", + " cls = cls + self.mlp(self.ln2(cls))\n", + " return cls # (B,1,d)\n", + "\n", + "# Complete Particle Transformer\n", + "class ParT(nn.Module):\n", + " def __init__(self,\n", + " in_dim=4, # (E,px,py,pz)\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2, # particle blocks\n", + " class_depth=2, # class-attention blocks\n", + " mlp_ratio=4,\n", + " num_classes=10,\n", + " dropout=0.1):\n", + " 
super().__init__()\n", + "\n", + " self.tokenizer = ParticleTokenizer(in_dim, embed_dim)\n", + " self.U_encoder = InteractionEncoder(n_heads=n_heads)\n", + "\n", + " self.blocks = nn.ModuleList([\n", + " ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)\n", + " for _ in range(depth)\n", + " ])\n", + "\n", + " self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", + " self.cls_blocks = nn.ModuleList([\n", + " ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)\n", + " for _ in range(class_depth)\n", + " ])\n", + "\n", + " self.head = nn.Linear(embed_dim, num_classes)\n", + "\n", + " nn.init.trunc_normal_(self.class_token, std=0.02)\n", + " nn.init.trunc_normal_(self.head.weight, std=0.02)\n", + " nn.init.zeros_(self.head.bias)\n", + "\n", + " def forward(self, x): # x: (B,4,N)\n", + " B, _, N = x.shape\n", + "\n", + " tokens = self.tokenizer(x) # (B,N,d)\n", + " U = self.U_encoder(x) # (B,H,N,N)\n", + "\n", + " for blk in self.blocks:\n", + " tokens = blk(tokens, U) # (B,N,d)\n", + "\n", + " cls = self.class_token.expand(B, -1, -1) # (B,1,d)\n", + " for blk in self.cls_blocks:\n", + " cls = blk(tokens, cls) # (B,1,d)\n", + "\n", + " logits = self.head(cls.squeeze(1)) # (B,10)\n", + " return logits" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0-g5CQ2hfh2x", + "outputId": "ef81a92c-0aa7-4a94-8c19-72ff3d20a178" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logits: torch.Size([3, 10])\n" + ] + } + ], + "source": [ + "B, _, N = x_batch.shape # (3,4,8)\n", + "model = ParT(in_dim=4,\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2,\n", + " class_depth=2,\n", + " num_classes=10)\n", + "\n", + "logits = model(x_batch) # forward pass\n", + "print(\"logits:\", logits.shape) # -> torch.Size([3, 10])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mOr6fyLvfjgv", + "outputId": "093278c3-6069-4f62-d8b9-efbb3db8caac" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 1 loss 2.3153 acc 0.000\n", + "epoch 50 loss 1.1357 acc 0.333\n", + "epoch 100 loss 1.0939 acc 0.667\n", + "epoch 150 loss 0.8085 acc 0.667\n", + "epoch 200 loss 0.0948 acc 1.000\n", + "epoch 250 loss 0.0151 acc 1.000\n", + "softmax-probs\n", + " tensor([[9.8189e-01, 2.2029e-03, 1.5892e-02, 2.0028e-06, 2.2383e-06, 1.8703e-06,\n", + " 2.4472e-06, 2.0614e-06, 1.7592e-06, 2.5751e-06],\n", + " [8.7813e-03, 9.9055e-01, 6.0938e-07, 4.5406e-05, 1.3593e-04, 9.3723e-05,\n", + " 6.2527e-05, 7.2221e-05, 9.5722e-05, 1.6248e-04],\n", + " [7.9069e-03, 4.2077e-07, 9.9208e-01, 3.3477e-06, 1.5375e-06, 1.7166e-06,\n", + " 3.2025e-06, 2.4078e-06, 1.5781e-06, 1.4262e-06]])\n" + ] + } + ], + "source": [ + "x_train = x_batch # (3, 4, 8)\n", + "y_train = torch.tensor([0, 1, 2]) # dummy class labels for testing\n", + "\n", + "model = ParT(in_dim=4,\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2,\n", + " class_depth=2,\n", + " num_classes=10)\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "n_epochs = 250\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " logits = model(x_train) # (3, 10)\n", + " loss = criterion(logits, y_train)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # print every 50 
epochs\n", + " if (epoch+1) % 50 == 0 or epoch == 0:\n", + " preds = logits.argmax(1)\n", + " acc = (preds == y_train).float().mean().item()\n", + " print(f\"epoch {epoch+1:3d} loss {loss.item():.4f} acc {acc:.3f}\")\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " probs = torch.softmax(model(x_train), dim=1)\n", + "print(\"softmax-probs\\n\", probs)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "pytorch-2.6.0", + "language": "python", + "name": "pytorch-2.6.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/QParT.py b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/QParT.py new file mode 100644 index 0000000..7d62bda --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/QParT.py @@ -0,0 +1,504 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +import numpy as np +np.ComplexWarning = Warning + +from typing import Callable + +# --- monekypatch for tc--- +import jax +try: + _ = jax.tree_map # JAX < 0.6 has this +except AttributeError: + # JAX ≥ 0.6 moved it here + from jax import tree_util as _jtu + jax.tree_map = _jtu.tree_map + +import tensorcircuit as tc + +import tensorcircuit as tc +import jax.numpy as jnp +import flax.linen + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import tensorcircuit as tc + +K = tc.set_backend("jax") + + +def angle_embedding(c: tc.Circuit, inputs): + num_qubits = inputs.shape[-1] + + for j in range(num_qubits): + c.rx(j, theta=inputs[j]) + + +def basic_vqc(c: tc.Circuit, inputs, weights): + num_qubits = inputs.shape[-1] + num_qlayers = weights.shape[-2] + + for i in range(num_qlayers): + for j in range(num_qubits): + c.rx(j, theta=weights[i, j]) + if num_qubits == 2: + c.cnot(0, 1) + elif num_qubits > 2: + for j in range(num_qubits): + c.cnot(j, (j + 1) % num_qubits) + + +def get_quantum_layer_circuit(inputs, weights, + embedding: Callable = angle_embedding, vqc: Callable = basic_vqc): + """ + Equivalent to the following PennyLane circuit: + def circuit(inputs, weights): + qml.templates.AngleEmbedding(inputs, wires=range(num_qubits)) + qml.templates.BasicEntanglerLayers(weights, wires=range(num_qubits)) + """ + + num_qubits = inputs.shape[-1] + + c = tc.Circuit(num_qubits) + embedding(c, inputs) + vqc(c, inputs, weights) + + return c + + +def get_circuit(embedding: Callable = angle_embedding, vqc: Callable = basic_vqc, + torch_interface: bool = False): + def qpred(inputs, weights): + c = get_quantum_layer_circuit(inputs, weights, embedding, vqc) + return K.real(jnp.array([c.expectation_ps(z=[i]) for i in range(weights.shape[1])])) + + qpred_batch = K.vmap(qpred, vectorized_argnums=0) + if torch_interface: + qpred_batch = tc.interfaces.torch_interface(qpred_batch, jit=True) + + return qpred_batch + + +class QuantumLayer(flax.linen.Module): + circuit: Callable + num_qubits: int + w_shape: tuple = (1,) + + @flax.linen.compact + def __call__(self, x): + shape = x.shape + x = jnp.reshape(x, (-1, shape[-1])) + w = self.param('w', flax.linen.initializers.xavier_normal(), self.w_shape + (self.num_qubits,)) + x = self.circuit(x, w) + 
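+        # `circuit` returns the per-qubit <Z> expectations for each row of
+        # angles, shape (batch, num_qubits); the reshape below restores the
+        # caller's original shape. (This Flax module is the JAX-side
+        # counterpart; the PyTorch path below goes through TCTorchLayer.)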
x = jnp.concatenate(x, axis=-1) + x = jnp.reshape(x, tuple(shape)) + return x + +NUM_QUBITS = 8 +NUM_Q_LAYERS = 1 +torch_layer_fn = get_circuit(torch_interface=True) + + +class TCTorchLayer(nn.Module): + """ + A thin PyTorch wrapper around the TensorCircuit/TC quantum layer. + Stores the circuit's trainable parameters as an nn.Parameter so + they appear in .parameters() and get updated by any torch optimizer. + """ + def __init__(self, num_qubits=NUM_QUBITS, num_qlayers=NUM_Q_LAYERS): + super().__init__() + init_w = 0.01 * torch.randn(num_qlayers, num_qubits) + self.w = nn.Parameter(init_w) + self.num_qubits = num_qubits + + def forward(self, x): + """ + x: (batch, num_qubits) – already pre-scaled into rotation angles. + Returns expectation values ⟨Z_i⟩ for every qubit i, shape identical + to the input (batch, num_qubits). + """ + return torch_layer_fn(x, self.w) + + +class QuantumLinear(nn.Module): + """ + Linear -> angle map -> TCTorchLayer -> Linear + Works on tensors shaped (..., din) and returns (..., dout). + """ + def __init__(self, din, dout, num_qubits): + super().__init__() + self.din = din + self.dout = dout + self.nq = num_qubits + + #self.to_q = nn.Linear(din, self.nq, bias=False) + #self.from_q = nn.Linear(self.nq, dout, bias=False) + self.q = TCTorchLayer(self.nq) + + @staticmethod + def _to_angles(x): + return torch.tanh(x) * math.pi + + def forward(self, x): + # x: (..., din) + *prefix, _ = x.shape + x = x.reshape(-1, self.din) + + #x = self.to_q(x) + x = self._to_angles(x) + x = self.q(x).float() + #x = self.from_q(x) + + x = x.reshape(*prefix, self.dout) + return x + +class InteractionEncoder(nn.Module): + """ + ParT interaction-feature encoder. + + Args + ---- + n_heads per mhsa: output channels d′ + hidden_channels : list[int] for intermediate 1×1 conv layers + eps : numerical guard for log + """ + + def __init__(self, + n_heads: int = 8, + hidden_channels: list[int] = (64, 64, 64), + eps: float = 1e-8): + super().__init__() + self.eps = eps + + layers: list[nn.Module] = [] + in_ch = 4 # lnΔ, ln kT, ln z, ln m² + for h in hidden_channels: + layers += [ + nn.Conv2d(in_ch, h, 1, bias=False), + nn.BatchNorm2d(h), + nn.GELU() + ] + in_ch = h + layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False)) + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x : (B, 4, N) where the 4 dims are (E, px, py, pz) + returns + ------ + U : (B, n_heads, N, N) interaction embedding + """ + B, four, N = x.shape + assert four == 4, "input must have 4 features: E, px, py, pz" + + # Split components + E, px, py, pz = x.unbind(dim=1) # each (B, N) + + # Basic kinematics ------------------------------------------------ + pT = torch.sqrt(px**2 + py**2) + self.eps + phi = torch.atan2(py, px) # (−π, π] + num = (E + pz).clamp(min=self.eps) #need to avoid negative numbers + den = (E - pz).clamp(min=self.eps) + y = 0.5 * torch.log(num / den) + + # Expand to (B, N, N) + y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N) + phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1) + pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1) + E_a, E_b = E.unsqueeze(2), E.unsqueeze(1) + px_a, px_b = px.unsqueeze(2), px.unsqueeze(1) + py_a, py_b = py.unsqueeze(2), py.unsqueeze(1) + pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1) + + # ΔR, kT, z + delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps + kT = torch.minimum(pT_a, pT_b) * delta + z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps) + + # m² of pair + E_sum = E_a + E_b + px_sum = px_a + px_b 
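+        # The remaining component sums complete the pair four-momentum;
+        # its invariant mass m2 = (E_a+E_b)^2 - |p_a+p_b|^2 is the fourth
+        # interaction feature, log-scaled below.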
+        py_sum = py_a + py_b
+        pz_sum = pz_a + pz_b
+        m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps
+        m2 = torch.clamp(m2, min=self.eps)          # avoid negatives
+
+        # Stack → (B, 4, N, N)
+        feats = torch.stack([
+            torch.log(delta),
+            torch.log(kT),
+            torch.log(z),
+            torch.log(m2)
+        ], dim=1)
+
+        # conv
+        U = self.net(feats)                        # (B, n_heads, N, N)
+        return U
+
+
+class ParticleTokenizer(nn.Module):
+    def __init__(self, in_dim=4, out_dim=6):
+        super().__init__()
+        self.proj = nn.Linear(in_dim, out_dim)
+
+    def forward(self, x):
+        """
+        x: tensor of shape (B, in_dim, n_particles), e.g. (B, 4, N)
+        returns: (B, n_particles, out_dim)
+        """
+        x = x.transpose(1, 2)  # (B, in_dim, n_particles) → (B, n_particles, in_dim)
+        return self.proj(x)
+
+class MLP(nn.Module):
+    """
+    Same interface as the classical MLP block, but nn.Linear -> QuantumLinear.
+    Works for inputs shaped (..., dim).
+
+    Args:
+        dim        : feature size
+        dropout    : dropout prob
+        num_qubits : qubits per QuantumLinear block (defaults to dim)
+    """
+    def __init__(self, dim, dropout=0., num_qubits=None):
+        super().__init__()
+        nq = num_qubits if num_qubits is not None else dim
+
+        self.fc1 = QuantumLinear(dim, dim, nq)
+        self.fc2 = QuantumLinear(dim, dim, nq)
+
+        self.act = nn.GELU()
+        self.do1 = nn.Dropout(dropout)
+        self.do2 = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.do1(x)
+
+        x = self.fc2(x)
+        x = self.do2(x)
+        return x
+
+class ParticleMHA(nn.Module):
+    """
+    Multi-head self-attention with quantum projections (q, k, v, o).
+
+    Args
+    ----
+    d           : embedding dim
+    heads       : number of attention heads
+    dropout     : dropout prob on attn weights
+    return_attn : return attention maps?
+    num_qubits  : qubits per quantum block (defaults to d)
+    """
+    def __init__(self, d: int, heads: int = 8,
+                 dropout: float = 0.1, return_attn: bool = False,
+                 num_qubits: int | None = None):
+        super().__init__()
+        assert d % heads == 0, "`d` must be divisible by `heads`"
+
+        self.d = d
+        self.h = heads
+        self.d_head = d // heads
+        self.scale = 1 / math.sqrt(self.d_head)
+        self.return_attn = return_attn
+
+        nq = num_qubits if num_qubits is not None else d
+
+        # quantum projections
+        self.q_proj = QuantumLinear(d, d, nq)
+        self.k_proj = QuantumLinear(d, d, nq)
+        self.v_proj = QuantumLinear(d, d, nq)
+        self.o_proj = QuantumLinear(d, d, nq)
+
+        self.drop = nn.Dropout(dropout)
+
+    def _split(self, t: torch.Tensor):
+        # (B, N, d) -> (B, H, N, d_head)
+        B, N, _ = t.shape
+        return t.view(B, N, self.h, self.d_head).transpose(1, 2)
+
+    def forward(self, x: torch.Tensor, U: torch.Tensor | None = None):
+        B, N, _ = x.shape
+
+        Q = self._split(self.q_proj(x))
+        K = self._split(self.k_proj(x))
+        V = self._split(self.v_proj(x))
+
+        logits = (Q @ K.transpose(-2, -1)) * self.scale   # (B, H, N, N)
+
+        if U is not None:
+            logits = logits + U
+
+        attn = F.softmax(logits, dim=-1)
+        attn = self.drop(attn)
+
+        context = attn @ V                                # (B, H, N, d_head)
+
+        context = (
+            context.transpose(1, 2)                       # (B, N, H, d_head)
+            .contiguous()
+            .view(B, N, self.d)
+        )
+        out = self.o_proj(context)
+
+        if self.return_attn:
+            return out, attn
+        else:
+            return out
+
+class MHA(nn.Module):
+    """
+    Multi-head attention (batch_first) with QuantumLinear projections. 
+ + Args + ---- + d_model : int embedding dim + n_heads : int + dropout: float + bias : bool (ignored here, QuantumLinear has no bias) + num_qubits : int|None qubits per quantum block (defaults to d_model) + """ + def __init__(self, d_model: int, n_heads: int, + dropout: float = 0., bias: bool = False, + num_qubits: int | None = None): + super().__init__() + assert d_model % n_heads == 0, "`d_model` must be divisible by `n_heads`" + self.d_model = d_model + self.h = n_heads + self.d_head = d_model // n_heads + self.scale = self.d_head ** -0.5 + + nq = num_qubits if num_qubits is not None else d_model + + # Quantum projections replace nn.Linear + self.q_proj = QuantumLinear(d_model, d_model, nq) + self.k_proj = QuantumLinear(d_model, d_model, nq) + self.v_proj = QuantumLinear(d_model, d_model, nq) + self.o_proj = QuantumLinear(d_model, d_model, nq) + + self.drop = nn.Dropout(dropout) + + def _split_heads(self, x: torch.Tensor): + # (B, L, d_model) -> (B, h, L, d_head) + B, L, _ = x.shape + return x.view(B, L, self.h, self.d_head).transpose(1, 2) + + def _merge_heads(self, x: torch.Tensor): + # (B, h, L, d_head) -> (B, L, d_model) + B, H, L, Dh = x.shape + return x.transpose(1, 2).contiguous().view(B, L, H * Dh) + + def forward( + self, + q: torch.Tensor, # (B, Lq, d_model) + k: torch.Tensor, # (B, Lk, d_model) + v: torch.Tensor, # (B, Lk, d_model) + attn_mask: torch.Tensor | None = None, + key_padding_mask: torch.Tensor | None = None, + need_weights: bool = False + ): + B, Lq, _ = q.shape + _, Lk, _ = k.shape + + Q = self._split_heads(self.q_proj(q)) # (B,h,Lq,d_h) + K = self._split_heads(self.k_proj(k)) # (B,h,Lk,d_h) + V = self._split_heads(self.v_proj(v)) # (B,h,Lk,d_h) + + logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk) + + attn = F.softmax(logits, dim=-1) + attn = self.drop(attn) + + context = torch.matmul(attn, V) # (B,h,Lq,d_h) + + out = self.o_proj(self._merge_heads(context)) # (B,Lq,d_model) + + if need_weights: + return out, attn.mean(dim=1) # (B,Lq,Lk) + return out, None + +# Particle attention block (NormFormer style + U-bias) +class ParticleAttentionBlock(nn.Module): + def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = ParticleMHA(dim, heads, dropout) + self.ln2 = nn.LayerNorm(dim) + self.mlp = MLP(dim, dropout) + def forward(self, x, U): + x = x + self.attn(self.ln1(x), U) # bias-aware MHSA + x = x + self.mlp(self.ln2(x)) # feed-forward + return x + +# Class attention block (CaiT style, no U) +class ClassAttentionBlock(nn.Module): + def __init__(self, dim, heads, mlp_ratio=4, dropout=0.): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = MHA(dim, heads, dropout) + self.ln2 = nn.LayerNorm(dim) + self.mlp = MLP(dim, dropout) + def forward(self, tokens, cls): # tokens: (B,N,d), cls: (B,1,d) + z = torch.cat([cls, tokens], dim=1) # (B,1+N,d) + q = self.ln1(cls) + kv = self.ln1(z) + cls = cls + self.attn(q, kv, kv, need_weights=False)[0] + cls = cls + self.mlp(self.ln2(cls)) + return cls # (B,1,d) + +# Complete Particle Transformer +class ParT(nn.Module): + def __init__(self, + in_dim=4, # (E,px,py,pz) + embed_dim=10, + n_heads=2, + depth=2, # particle blocks + class_depth=2, # class-attention blocks + mlp_ratio=4, + num_classes=10, + dropout=0.1): + super().__init__() + + self.tokenizer = ParticleTokenizer(in_dim, embed_dim) + self.U_encoder = InteractionEncoder(n_heads=n_heads) + + self.blocks = nn.ModuleList([ + ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, 
dropout) + for _ in range(depth) + ]) + + self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.cls_blocks = nn.ModuleList([ + ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0) + for _ in range(class_depth) + ]) + + self.head = nn.Linear(embed_dim, num_classes) + + # weight init + nn.init.trunc_normal_(self.class_token, std=0.02) + nn.init.trunc_normal_(self.head.weight, std=0.02) + nn.init.zeros_(self.head.bias) + + def forward(self, x): # x: (B,4,N) + B, _, N = x.shape + + tokens = self.tokenizer(x) # (B,N,d) + U = self.U_encoder(x) # (B,H,N,N) + + for blk in self.blocks: + tokens = blk(tokens, U) # (B,N,d) + + cls = self.class_token.expand(B, -1, -1) # (B,1,d) + for blk in self.cls_blocks: + cls = blk(tokens, cls) # (B,1,d) + + logits = self.head(cls.squeeze(1)) # (B,10) + return logits \ No newline at end of file diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT.ipynb deleted file mode 100644 index f78ca48..0000000 --- a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT.ipynb +++ /dev/null @@ -1,1431 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "machine_shape": "hm", - "gpuType": "T4" - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "smpEEHlypz0p" - }, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import math" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Toy Data" - ], - "metadata": { - "id": "PAgmRCdTvyg6" - } - }, - { - "cell_type": "code", - "source": [ - "x = torch.Tensor([[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", - " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", - " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", - " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", - " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", - " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", - " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", - " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]])\n", - "\n", - "x_batch = torch.Tensor([[[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", - " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", - " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", - " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", - " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", - " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", - " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", - " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]],\n", - "\n", - " [[ 8.2484528e+01, 5.2682617e+01, 5.1243843e+01, 3.6217686e+01,\n", - " 2.8948278e+01, 2.6579512e+01, 2.1946012e+01, 2.1011120e+01],\n", - " [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,\n", - " -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],\n", - " [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,\n", - " -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],\n", - " [ 9.0437065e+01, 7.4070679e+01, 5.6495895e+01, 4.3069641e+01,\n", - " 3.2367134e+01, 
3.3371326e+01, 2.6115334e+01, 2.4602446e+01]],\n", - "\n", - " [[ 8.6492935e+01, 7.0192978e+01, 5.8423912e+01, 5.6638733e+01,\n", - " 4.9270725e+01, 4.1237038e+01, 3.6133625e+01, 3.5519596e+01],\n", - " [ 1.4010678e-01, 2.7912292e-01, 1.4376265e-01, 3.4672296e-01,\n", - " 3.4966472e-01, 1.0524009e-01, 1.2958543e-01, 3.3264065e-01],\n", - " [ 1.9334941e+00, 1.6967584e+00, 1.9219695e+00, 1.6735281e+00,\n", - " 1.6587850e+00, 1.8386338e+00, 1.9120301e+00, 1.6680365e+00],\n", - " [ 8.7343246e+01, 7.2945129e+01, 5.9030762e+01, 6.0084766e+01,\n", - " 5.2313778e+01, 4.1465607e+01, 3.6440781e+01, 3.7506149e+01]]])" - ], - "metadata": { - "id": "2jsmzKcFvx4L" - }, - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "x = x.unsqueeze(0) # batch dimension\n", - "x.shape, x_batch.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Po7Fb1v37EJT", - "outputId": "2571d4f4-e9db-4e2b-ee7f-13a7e406e526" - }, - "execution_count": 3, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))" - ] - }, - "metadata": {}, - "execution_count": 3 - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Quantum Function" - ], - "metadata": { - "id": "-ybT64WoWQ71" - } - }, - { - "cell_type": "code", - "source": [ - "import numpy as np\n", - "np.ComplexWarning = Warning\n", - "\n", - "!pip install tensorcircuit\n", - "from typing import Callable\n", - "\n", - "import tensorcircuit as tc\n", - "import jax.numpy as jnp\n", - "import flax.linen\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "import tensorcircuit as tc\n", - "\n", - "K = tc.set_backend(\"jax\")\n", - "\n", - "\n", - "def angle_embedding(c: tc.Circuit, inputs):\n", - " num_qubits = inputs.shape[-1]\n", - "\n", - " for j in range(num_qubits):\n", - " c.rx(j, theta=inputs[j])\n", - "\n", - "\n", - "def basic_vqc(c: tc.Circuit, inputs, weights):\n", - " num_qubits = inputs.shape[-1]\n", - " num_qlayers = weights.shape[-2]\n", - "\n", - " for i in range(num_qlayers):\n", - " for j in range(num_qubits):\n", - " c.rx(j, theta=weights[i, j])\n", - " if num_qubits == 2:\n", - " c.cnot(0, 1)\n", - " elif num_qubits > 2:\n", - " for j in range(num_qubits):\n", - " c.cnot(j, (j + 1) % num_qubits)\n", - "\n", - "\n", - "def get_quantum_layer_circuit(inputs, weights,\n", - " embedding: Callable = angle_embedding, vqc: Callable = basic_vqc):\n", - " \"\"\"\n", - " Equivalent to the following PennyLane circuit:\n", - " def circuit(inputs, weights):\n", - " qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))\n", - " qml.templates.BasicEntanglerLayers(weights, wires=range(num_qubits))\n", - " \"\"\"\n", - "\n", - " num_qubits = inputs.shape[-1]\n", - "\n", - " c = tc.Circuit(num_qubits)\n", - " embedding(c, inputs)\n", - " vqc(c, inputs, weights)\n", - "\n", - " return c\n", - "\n", - "\n", - "def get_circuit(embedding: Callable = angle_embedding, vqc: Callable = basic_vqc,\n", - " torch_interface: bool = False):\n", - " def qpred(inputs, weights):\n", - " c = get_quantum_layer_circuit(inputs, weights, embedding, vqc)\n", - " return K.real(jnp.array([c.expectation_ps(z=[i]) for i in range(weights.shape[1])]))\n", - "\n", - " qpred_batch = K.vmap(qpred, vectorized_argnums=0)\n", - " if torch_interface:\n", - " qpred_batch = tc.interfaces.torch_interface(qpred_batch, jit=True)\n", - "\n", - " return qpred_batch\n", - "\n", - "\n", - "class 
QuantumLayer(flax.linen.Module):\n", - " circuit: Callable\n", - " num_qubits: int\n", - " w_shape: tuple = (1,)\n", - "\n", - " @flax.linen.compact\n", - " def __call__(self, x):\n", - " shape = x.shape\n", - " x = jnp.reshape(x, (-1, shape[-1]))\n", - " w = self.param('w', flax.linen.initializers.xavier_normal(), self.w_shape + (self.num_qubits,))\n", - " x = self.circuit(x, w)\n", - " x = jnp.concatenate(x, axis=-1)\n", - " x = jnp.reshape(x, tuple(shape))\n", - " return x\n", - "\n", - "\n", - "\n", - "NUM_QUBITS = 8\n", - "NUM_Q_LAYERS = 1\n", - "torch_layer_fn = get_circuit(torch_interface=True)\n", - "\n", - "\n", - "class TCTorchLayer(nn.Module):\n", - " \"\"\"\n", - " A thin PyTorch wrapper around the TensorCircuit/TC quantum layer.\n", - " Stores the circuit's trainable parameters as an nn.Parameter so\n", - " they appear in .parameters() and get updated by any torch optimizer.\n", - " \"\"\"\n", - " def __init__(self, num_qubits=NUM_QUBITS, num_qlayers=NUM_Q_LAYERS):\n", - " super().__init__()\n", - " init_w = 0.01 * torch.randn(num_qlayers, num_qubits)\n", - " self.w = nn.Parameter(init_w)\n", - " self.num_qubits = num_qubits\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\n", - " x: (batch, num_qubits) – already pre-scaled into rotation angles.\n", - " Returns expectation values ⟨Z_i⟩ for every qubit i, shape identical\n", - " to the input (batch, num_qubits).\n", - " \"\"\"\n", - " return torch_layer_fn(x, self.w)\n", - "\n", - "\n", - "class QuantumLinear(nn.Module):\n", - " \"\"\"\n", - " Linear -> angle map -> TCTorchLayer -> Linear\n", - " Works on tensors shaped (..., din) and returns (..., dout).\n", - " \"\"\"\n", - " def __init__(self, din, dout, num_qubits):\n", - " super().__init__()\n", - " self.din = din\n", - " self.dout = dout\n", - " self.nq = num_qubits\n", - "\n", - " self.to_q = nn.Linear(din, self.nq, bias=False)\n", - " self.from_q = nn.Linear(self.nq, dout, bias=False)\n", - " self.q = TCTorchLayer(self.nq)\n", - "\n", - " @staticmethod\n", - " def _to_angles(x):\n", - " return torch.tanh(x) * math.pi\n", - "\n", - " def forward(self, x):\n", - " # x: (..., din)\n", - " *prefix, _ = x.shape\n", - " x = x.reshape(-1, self.din)\n", - "\n", - " x = self.to_q(x)\n", - " x = self._to_angles(x)\n", - " x = self.q(x).float()\n", - " x = self.from_q(x)\n", - "\n", - " x = x.reshape(*prefix, self.dout)\n", - " return x" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "gm-q21_IWUCU", - "outputId": "cfa85478-65d3-4074-fce1-dcd09285b7de" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting tensorcircuit\n", - " Downloading tensorcircuit-0.12.0-py3-none-any.whl.metadata (29 kB)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from tensorcircuit) (2.0.2)\n", - "Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from tensorcircuit) (1.16.0)\n", - "Collecting tensornetwork-ng (from tensorcircuit)\n", - " Downloading tensornetwork_ng-0.5.1-py3-none-any.whl.metadata (7.0 kB)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from tensorcircuit) (3.5)\n", - "Requirement already satisfied: graphviz>=0.11.1 in /usr/local/lib/python3.11/dist-packages (from tensornetwork-ng->tensorcircuit) (0.21)\n", - "Requirement already satisfied: opt-einsum>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from tensornetwork-ng->tensorcircuit) (3.4.0)\n", - "Requirement 
already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.11/dist-packages (from tensornetwork-ng->tensorcircuit) (3.14.0)\n", - "Downloading tensorcircuit-0.12.0-py3-none-any.whl (342 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m342.0/342.0 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading tensornetwork_ng-0.5.1-py3-none-any.whl (244 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m244.1/244.1 kB\u001b[0m \u001b[31m10.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: tensornetwork-ng, tensorcircuit\n", - "Successfully installed tensorcircuit-0.12.0 tensornetwork-ng-0.5.1\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "WARNING:tensorcircuit.translation:Please first ``pip install -U qiskit`` to enable related functionality in translation module\n", - "WARNING:tensorcircuit.translation:Please first ``pip install -U cirq`` to enable related functionality in translation module\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Toy interaction matrix" - ], - "metadata": { - "id": "X6d5fzsV4kt7" - } - }, - { - "cell_type": "code", - "source": [ - "class InteractionEncoder(nn.Module):\n", - " \"\"\"\n", - " ParT interaction-feature encoder.\n", - "\n", - " Args\n", - " ----\n", - " n_heads per mhsa: output channels d′\n", - " hidden_channels : list[int] for intermediate 1×1 conv layers\n", - " eps : numerical guard for log\n", - " \"\"\"\n", - "\n", - " def __init__(self,\n", - " n_heads: int = 8,\n", - " hidden_channels: list[int] = (64, 64, 64),\n", - " eps: float = 1e-8):\n", - " super().__init__()\n", - " self.eps = eps\n", - "\n", - " layers: list[nn.Module] = []\n", - " in_ch = 4 # lnΔ, ln kT, ln z, ln m²\n", - " for h in hidden_channels:\n", - " layers += [\n", - " nn.Conv2d(in_ch, h, 1, bias=False),\n", - " nn.BatchNorm2d(h),\n", - " nn.GELU()\n", - " ]\n", - " in_ch = h\n", - " layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))\n", - " self.net = nn.Sequential(*layers)\n", - "\n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " x : (B, 4, N) where the 4 dims are (E, px, py, pz)\n", - " returns\n", - " ------\n", - " U : (B, n_heads, N, N) interaction embedding\n", - " \"\"\"\n", - " B, four, N = x.shape\n", - " assert four == 4, \"input must have 4 features: E, px, py, pz\"\n", - "\n", - " # Split components\n", - " E, px, py, pz = x.unbind(dim=1) # each (B, N)\n", - "\n", - " # Basic kinematics ------------------------------------------------\n", - " pT = torch.sqrt(px**2 + py**2) + self.eps\n", - " phi = torch.atan2(py, px) # (−π, π]\n", - " num = (E + pz).clamp(min=self.eps) #need to avoid negative numbers\n", - " den = (E - pz).clamp(min=self.eps)\n", - " y = 0.5 * torch.log(num / den)\n", - "\n", - " # Expand to (B, N, N)\n", - " y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N)\n", - " phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)\n", - " pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)\n", - " E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)\n", - " px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)\n", - " py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)\n", - " pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)\n", - "\n", - " # ΔR, kT, z\n", - " delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps\n", - " kT = torch.minimum(pT_a, pT_b) * delta\n", - " z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + 
self.eps)\n", - "\n", - " # m² of pair\n", - " E_sum = E_a + E_b\n", - " px_sum = px_a + px_b\n", - " py_sum = py_a + py_b\n", - " pz_sum = pz_a + pz_b\n", - " m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps\n", - " m2 = torch.clamp(m2, min=self.eps) # avoid negatives\n", - "\n", - " # Stack → (B, 4, N, N)\n", - " feats = torch.stack([\n", - " torch.log(delta),\n", - " torch.log(kT),\n", - " torch.log(z),\n", - " torch.log(m2)\n", - " ], dim=1)\n", - "\n", - " # conv\n", - " U = self.net(feats) # (B, n_heads, N, N)\n", - " return U\n", - "\n", - "\n", - "\n", - "B, _, N = x.shape\n", - "n_heads = 2 # d′\n", - "enc = InteractionEncoder(n_heads=n_heads)\n", - "U = enc(x)\n", - "print(\"U.shape:\", U.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "sOjDJYr14zDr", - "outputId": "aa8b4aef-a5c1-4e3e-a125-6e86079468cd" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "U.shape: torch.Size([1, 2, 8, 8])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Particle Transformer" - ], - "metadata": { - "id": "5CwQyoPYwaR7" - } - }, - { - "cell_type": "code", - "source": [ - "class ParticleTokenizer(nn.Module):\n", - " def __init__(self, in_dim=4, out_dim=6):\n", - " super().__init__()\n", - " self.proj = nn.Linear(in_dim, out_dim)\n", - "\n", - " def forward(self, x):\n", - " \"\"\"\n", - " x: tensor of shape (B, n_particles, in_dim)\n", - " returns: (B, n_particles, out_dim)\n", - " \"\"\"\n", - " x = x.transpose(1, 2) # Input shape: (B, n_particles, in_dim) → (B, in_dim, n_particles)\n", - " return self.proj(x)\n", - "\n", - "tokenizer = ParticleTokenizer(4, 10)\n", - "output = tokenizer(x)\n", - "output.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ya5SvxTJwE3r", - "outputId": "27b1b1d0-4863-41b2-99f9-4daadd27688a" - }, - "execution_count": 6, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "torch.Size([1, 8, 10])" - ] - }, - "metadata": {}, - "execution_count": 6 - } - ] - }, - { - "cell_type": "code", - "source": [ - "class MLP(nn.Module):\n", - " \"\"\"\n", - " Same interface as your tiny MLP, but nn.Linear -> QuantumLinear.\n", - " Works for inputs shaped (..., dim).\n", - "\n", - " Args:\n", - " dim : feature size\n", - " dropout : dropout prob\n", - " num_qubits : qubits per QuantumLinear block (defaults to dim)\n", - " \"\"\"\n", - " def __init__(self, dim, dropout=0., num_qubits=None):\n", - " super().__init__()\n", - " nq = num_qubits if num_qubits is not None else dim\n", - "\n", - " self.fc1 = QuantumLinear(dim, dim, nq)\n", - " self.fc2 = QuantumLinear(dim, dim, nq)\n", - "\n", - " self.act = nn.GELU()\n", - " self.do1 = nn.Dropout(dropout)\n", - " self.do2 = nn.Dropout(dropout)\n", - "\n", - " def forward(self, x):\n", - " x = self.fc1(x)\n", - " x = self.act(x)\n", - " x = self.do1(x)\n", - "\n", - " x = self.fc2(x)\n", - " x = self.do2(x)\n", - " return x\n", - "\n", - "# usage\n", - "mlp = MLP(10, dropout=0.1)\n", - "output = mlp(output)\n", - "output.shape" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "S5f8DTRY07PD", - "outputId": "edb84881-2ccc-456a-bdc7-59ad23397fd5" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "torch.Size([1, 8, 10])" - ] - }, - "metadata": {}, - "execution_count": 7 - } - ] - }, - { - "cell_type": "code", - "source": [ - "class 
ParticleMHA(nn.Module):\n", - " \"\"\"\n", - " Multi-head self-attention with quantum projections (q, k, v, o).\n", - "\n", - " Args\n", - " ----\n", - " d : embedding dim\n", - " heads : number of attention heads\n", - " dropout : dropout prob on attn weights\n", - " return_attn : return attention maps?\n", - " num_qubits : qubits per quantum block (defaults to d)\n", - " \"\"\"\n", - " def __init__(self, d: int, heads: int = 8,\n", - " dropout: float = 0.1, return_attn: bool = False,\n", - " num_qubits: int | None = None):\n", - " super().__init__()\n", - " assert d % heads == 0, \"`d` must be divisible by `heads`\"\n", - "\n", - " self.d = d\n", - " self.h = heads\n", - " self.d_head = d // heads\n", - " self.scale = 1 / math.sqrt(self.d_head)\n", - " self.return_attn = return_attn\n", - "\n", - " nq = num_qubits if num_qubits is not None else d\n", - "\n", - " # quantum projections\n", - " self.q_proj = QuantumLinear(d, d, nq)\n", - " self.k_proj = QuantumLinear(d, d, nq)\n", - " self.v_proj = QuantumLinear(d, d, nq)\n", - " self.o_proj = QuantumLinear(d, d, nq)\n", - "\n", - " self.drop = nn.Dropout(dropout)\n", - "\n", - " def _split(self, t: torch.Tensor):\n", - " # (B, N, d) -> (B, H, N, d_head)\n", - " B, N, _ = t.shape\n", - " return t.view(B, N, self.h, self.d_head).transpose(1, 2)\n", - "\n", - " def forward(self, x: torch.Tensor, U: torch.Tensor | None = None):\n", - " B, N, _ = x.shape\n", - "\n", - " Q = self._split(self.q_proj(x))\n", - " K = self._split(self.k_proj(x))\n", - " V = self._split(self.v_proj(x))\n", - "\n", - " logits = (Q @ K.transpose(-2, -1)) * self.scale # (B, H, N, N)\n", - "\n", - " if U is not None:\n", - " logits = logits + U\n", - "\n", - " attn = F.softmax(logits, dim=-1)\n", - " attn = self.drop(attn)\n", - "\n", - " context = attn @ V # (B, H, N, d_head)\n", - "\n", - " context = (\n", - " context.transpose(1, 2) # (B, N, H, d_head)\n", - " .contiguous()\n", - " .view(B, N, self.d)\n", - " )\n", - " out = self.o_proj(context)\n", - "\n", - " if self.return_attn:\n", - " return out, attn\n", - " else:\n", - " return out\n", - "\n", - "B, N, d = output.shape\n", - "\n", - "U = torch.randn(1, 2, N, N) # broadcast to (B, H, N, N)\n", - "\n", - "pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)\n", - "output, A = pmha(output, U) # out: (B, N, d) A: (B, 8, N, N)\n", - "\n", - "print(output.shape, A.shape)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zb5nYOsSyFRk", - "outputId": "7335d6f6-9642-45ad-90b4-981175a47f6d" - }, - "execution_count": 8, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## transformer" - ], - "metadata": { - "id": "4650yyFtdjaR" - } - }, - { - "cell_type": "code", - "source": [ - "class MHA(nn.Module):\n", - " \"\"\"\n", - " Multi-head attention (batch_first) with QuantumLinear projections.\n", - "\n", - " Args\n", - " ----\n", - " d_model : int embedding dim\n", - " n_heads : int\n", - " dropout: float\n", - " bias : bool (ignored here, QuantumLinear has no bias)\n", - " num_qubits : int|None qubits per quantum block (defaults to d_model)\n", - " \"\"\"\n", - " def __init__(self, d_model: int, n_heads: int,\n", - " dropout: float = 0., bias: bool = False,\n", - " num_qubits: int | None = None):\n", - " super().__init__()\n", - " assert d_model % n_heads == 0, \"`d_model` must be divisible by `n_heads`\"\n", - " self.d_model = 
d_model\n", - " self.h = n_heads\n", - " self.d_head = d_model // n_heads\n", - " self.scale = self.d_head ** -0.5\n", - "\n", - " nq = num_qubits if num_qubits is not None else d_model\n", - "\n", - " # Quantum projections replace nn.Linear\n", - " self.q_proj = QuantumLinear(d_model, d_model, nq)\n", - " self.k_proj = QuantumLinear(d_model, d_model, nq)\n", - " self.v_proj = QuantumLinear(d_model, d_model, nq)\n", - " self.o_proj = QuantumLinear(d_model, d_model, nq)\n", - "\n", - " self.drop = nn.Dropout(dropout)\n", - "\n", - " def _split_heads(self, x: torch.Tensor):\n", - " # (B, L, d_model) -> (B, h, L, d_head)\n", - " B, L, _ = x.shape\n", - " return x.view(B, L, self.h, self.d_head).transpose(1, 2)\n", - "\n", - " def _merge_heads(self, x: torch.Tensor):\n", - " # (B, h, L, d_head) -> (B, L, d_model)\n", - " B, H, L, Dh = x.shape\n", - " return x.transpose(1, 2).contiguous().view(B, L, H * Dh)\n", - "\n", - " def forward(\n", - " self,\n", - " q: torch.Tensor, # (B, Lq, d_model)\n", - " k: torch.Tensor, # (B, Lk, d_model)\n", - " v: torch.Tensor, # (B, Lk, d_model)\n", - " attn_mask: torch.Tensor | None = None,\n", - " key_padding_mask: torch.Tensor | None = None,\n", - " need_weights: bool = False\n", - " ):\n", - " B, Lq, _ = q.shape\n", - " _, Lk, _ = k.shape\n", - "\n", - " Q = self._split_heads(self.q_proj(q)) # (B,h,Lq,d_h)\n", - " K = self._split_heads(self.k_proj(k)) # (B,h,Lk,d_h)\n", - " V = self._split_heads(self.v_proj(v)) # (B,h,Lk,d_h)\n", - "\n", - " logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale # (B,h,Lq,Lk)\n", - "\n", - " attn = F.softmax(logits, dim=-1)\n", - " attn = self.drop(attn)\n", - "\n", - " context = torch.matmul(attn, V) # (B,h,Lq,d_h)\n", - "\n", - " out = self.o_proj(self._merge_heads(context)) # (B,Lq,d_model)\n", - "\n", - " if need_weights:\n", - " return out, attn.mean(dim=1) # (B,Lq,Lk)\n", - " return out, None" - ], - "metadata": { - "id": "CKYo7jEiyYx1" - }, - "execution_count": 9, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# Particle attention block (NormFormer style + U-bias)\n", - "class ParticleAttentionBlock(nn.Module):\n", - " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):\n", - " super().__init__()\n", - " self.ln1 = nn.LayerNorm(dim)\n", - " self.attn = ParticleMHA(dim, heads, dropout)\n", - " self.ln2 = nn.LayerNorm(dim)\n", - " self.mlp = MLP(dim, dropout)\n", - " def forward(self, x, U):\n", - " x = x + self.attn(self.ln1(x), U) # bias-aware MHSA\n", - " x = x + self.mlp(self.ln2(x)) # feed-forward\n", - " return x\n", - "\n", - "# Class attention block (CaiT style, no U)\n", - "class ClassAttentionBlock(nn.Module):\n", - " def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):\n", - " super().__init__()\n", - " self.ln1 = nn.LayerNorm(dim)\n", - " self.attn = MHA(dim, heads, dropout)\n", - " self.ln2 = nn.LayerNorm(dim)\n", - " self.mlp = MLP(dim, dropout)\n", - " def forward(self, tokens, cls): # tokens: (B,N,d), cls: (B,1,d)\n", - " z = torch.cat([cls, tokens], dim=1) # (B,1+N,d)\n", - " q = self.ln1(cls)\n", - " kv = self.ln1(z)\n", - " cls = cls + self.attn(q, kv, kv, need_weights=False)[0]\n", - " cls = cls + self.mlp(self.ln2(cls))\n", - " return cls # (B,1,d)\n", - "\n", - "# Complete Particle Transformer\n", - "class ParT(nn.Module):\n", - " def __init__(self,\n", - " in_dim=4, # (E,px,py,pz)\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2, # particle blocks\n", - " class_depth=2, # class-attention blocks\n", - " mlp_ratio=4,\n", - " num_classes=10,\n", - " 
dropout=0.1):\n", - " super().__init__()\n", - "\n", - " self.tokenizer = ParticleTokenizer(in_dim, embed_dim)\n", - " self.U_encoder = InteractionEncoder(n_heads=n_heads)\n", - "\n", - " self.blocks = nn.ModuleList([\n", - " ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)\n", - " for _ in range(depth)\n", - " ])\n", - "\n", - " self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n", - " self.cls_blocks = nn.ModuleList([\n", - " ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)\n", - " for _ in range(class_depth)\n", - " ])\n", - "\n", - " self.head = nn.Linear(embed_dim, num_classes)\n", - "\n", - " # weight init\n", - " nn.init.trunc_normal_(self.class_token, std=0.02)\n", - " nn.init.trunc_normal_(self.head.weight, std=0.02)\n", - " nn.init.zeros_(self.head.bias)\n", - "\n", - " def forward(self, x): # x: (B,4,N)\n", - " B, _, N = x.shape\n", - "\n", - " tokens = self.tokenizer(x) # (B,N,d)\n", - " U = self.U_encoder(x) # (B,H,N,N)\n", - "\n", - " for blk in self.blocks:\n", - " tokens = blk(tokens, U) # (B,N,d)\n", - "\n", - " cls = self.class_token.expand(B, -1, -1) # (B,1,d)\n", - " for blk in self.cls_blocks:\n", - " cls = blk(tokens, cls) # (B,1,d)\n", - "\n", - " logits = self.head(cls.squeeze(1)) # (B,10)\n", - " return logits" - ], - "metadata": { - "id": "iQveEBhldizq" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "B, _, N = x_batch.shape # (3,4,8)\n", - "model = ParT(in_dim=4,\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10)\n", - "\n", - "logits = model(x_batch) # forward pass\n", - "print(\"logits:\", logits.shape) # torch.Size([3, 10])\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0-g5CQ2hfh2x", - "outputId": "5a6e20d4-2855-46ad-a949-a19f9af36776" - }, - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "logits: torch.Size([3, 10])\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "x_train = x_batch # (3, 4, 8)\n", - "y_train = torch.tensor([0, 1, 2]) # dummy class labels for testing\n", - "\n", - "model = ParT(in_dim=4,\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10)\n", - "\n", - "criterion = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", - "\n", - "n_epochs = 250\n", - "for epoch in range(n_epochs):\n", - " model.train()\n", - " logits = model(x_train) # (3, 10)\n", - " loss = criterion(logits, y_train)\n", - "\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " # print every 5 epochs\n", - " if (epoch+1) % 5 == 0 or epoch == 0:\n", - " preds = logits.argmax(1)\n", - " acc = (preds == y_train).float().mean().item()\n", - " print(f\"epoch {epoch+1:3d} loss {loss.item():.4f} acc {acc:.3f}\")\n", - "\n", - "model.eval()\n", - "with torch.no_grad():\n", - " probs = torch.softmax(model(x_train), dim=1)\n", - "print(\"softmax-probs\\n\", probs)\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mOr6fyLvfjgv", - "outputId": "bc187fcb-b348-47ce-d71c-e52e55889367" - }, - "execution_count": 12, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "epoch 1 loss 2.3297 acc 0.000\n", - "epoch 5 loss 2.2604 acc 0.000\n", - "epoch 10 loss 2.1752 acc 0.333\n", - "epoch 15 loss 2.0782 acc 0.333\n", - "epoch 20 loss 1.9541 acc 0.333\n", - 
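An editorial aside on the plateau in this log: with three training jets, labels 0–2, and ten output classes, the run settles into a near-uniform distribution over the first three classes, whose cross-entropy is ln 3 ≈ 1.0986 — precisely the ≈1.10 value the remaining epochs below flatten onto, and consistent with the softmax table printed at the end of this cell. A one-line check (standard library only):

```python
import math
print(math.log(3))  # ≈ 1.0986: the loss of a uniform guess over 3 classes
```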
"epoch 25 loss 1.8064 acc 0.333\n", - "epoch 30 loss 1.6569 acc 0.333\n", - "epoch 35 loss 1.5183 acc 0.333\n", - "epoch 40 loss 1.3999 acc 0.333\n", - "epoch 45 loss 1.3081 acc 0.333\n", - "epoch 50 loss 1.2424 acc 0.333\n", - "epoch 55 loss 1.1977 acc 0.333\n", - "epoch 60 loss 1.1684 acc 0.333\n", - "epoch 65 loss 1.1496 acc 0.333\n", - "epoch 70 loss 1.1375 acc 0.333\n", - "epoch 75 loss 1.1293 acc 0.333\n", - "epoch 80 loss 1.1235 acc 0.333\n", - "epoch 85 loss 1.1192 acc 0.333\n", - "epoch 90 loss 1.1162 acc 0.333\n", - "epoch 95 loss 1.1138 acc 0.333\n", - "epoch 100 loss 1.1119 acc 0.333\n", - "epoch 105 loss 1.1105 acc 0.333\n", - "epoch 110 loss 1.1093 acc 0.333\n", - "epoch 115 loss 1.1083 acc 0.333\n", - "epoch 120 loss 1.1075 acc 0.333\n", - "epoch 125 loss 1.1067 acc 0.333\n", - "epoch 130 loss 1.1061 acc 0.333\n", - "epoch 135 loss 1.1056 acc 0.333\n", - "epoch 140 loss 1.1051 acc 0.333\n", - "epoch 145 loss 1.1046 acc 0.333\n", - "epoch 150 loss 1.1043 acc 0.333\n", - "epoch 155 loss 1.1039 acc 0.333\n", - "epoch 160 loss 1.1036 acc 0.333\n", - "epoch 165 loss 1.1033 acc 0.333\n", - "epoch 170 loss 1.1030 acc 0.333\n", - "epoch 175 loss 1.1028 acc 0.333\n", - "epoch 180 loss 1.1026 acc 0.333\n", - "epoch 185 loss 1.1024 acc 0.333\n", - "epoch 190 loss 1.1022 acc 0.333\n", - "epoch 195 loss 1.1020 acc 0.333\n", - "epoch 200 loss 1.1018 acc 0.333\n", - "epoch 205 loss 1.1017 acc 0.333\n", - "epoch 210 loss 1.1015 acc 0.333\n", - "epoch 215 loss 1.1014 acc 0.333\n", - "epoch 220 loss 1.1013 acc 0.333\n", - "epoch 225 loss 1.1012 acc 0.333\n", - "epoch 230 loss 1.1011 acc 0.333\n", - "epoch 235 loss 1.1010 acc 0.333\n", - "epoch 240 loss 1.1009 acc 0.333\n", - "epoch 245 loss 1.1008 acc 0.333\n", - "epoch 250 loss 1.1007 acc 0.333\n", - "softmax-probs\n", - " tensor([[3.3254e-01, 3.3260e-01, 3.3277e-01, 2.6419e-04, 2.7729e-04, 2.5344e-04,\n", - " 3.0712e-04, 3.1026e-04, 3.2025e-04, 3.5913e-04],\n", - " [3.3253e-01, 3.3260e-01, 3.3278e-01, 2.6400e-04, 2.7708e-04, 2.5325e-04,\n", - " 3.0691e-04, 3.1005e-04, 3.2001e-04, 3.5888e-04],\n", - " [3.3253e-01, 3.3260e-01, 3.3278e-01, 2.6388e-04, 2.7696e-04, 2.5314e-04,\n", - " 3.0678e-04, 3.0992e-04, 3.1987e-04, 3.5873e-04]])\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load official Data" - ], - "metadata": { - "id": "Z7bm2U4EYwY1" - } - }, - { - "cell_type": "code", - "source": [ - "!git clone https://github.com/jet-universe/particle_transformer.git\n", - "!cd particle_transformer\n", - "!cd /content/particle_transformer\n", - "!touch env.sh\n", - "!chmod +x env.sh" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ES5wefiuY5ca", - "outputId": "99ac4161-bab8-4015-81d0-ae740123a0e5" - }, - "execution_count": 12, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'particle_transformer'...\n", - "remote: Enumerating objects: 101, done.\u001b[K\n", - "remote: Counting objects: 100% (52/52), done.\u001b[K\n", - "remote: Compressing objects: 100% (25/25), done.\u001b[K\n", - "remote: Total 101 (delta 38), reused 27 (delta 27), pack-reused 49 (from 1)\u001b[K\n", - "Receiving objects: 100% (101/101), 28.08 MiB | 14.94 MiB/s, done.\n", - "Resolving deltas: 100% (46/46), done.\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!/content/particle_transformer/get_datasets.py JetClass -d ./datasets\n", - "!source env.sh\n", - "import os, glob, tarfile\n", - "os.environ['DATADIR_JetClass'] = 
os.path.abspath('./datasets/JetClass')\n", - "data_dir = os.environ['DATADIR_JetClass']\n", - "!pip install awkward uproot vector\n", - "from particle_transformer.dataloader import read_file\n", - "\n", - "tar_path = \"/content/datasets/JetClass/JetClass_Pythia_val_5M.tar\"\n", - "\n", - "extract_dir = \"/content/datasets/JetClass/JetClass_Pythia_val_5M\"\n", - "os.makedirs(extract_dir, exist_ok=True)" - ], - "metadata": { - "id": "5IVg4AkAK1xz", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "75cdcc78-2bd5-4eab-c180-f748fbc65e19" - }, - "execution_count": 14, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://zenodo.org/record/6619768/files/JetClass_Pythia_val_5M.tar to ./datasets/JetClass/JetClass_Pythia_val_5M.tar\n", - "./datasets/JetClass/JetClass_Pythia_val_5M.tar: 100% 7.07G/7.07G [10:19<00:00, 12.3MiB/s]\n", - "Updated dataset path in env.sh to \"DATADIR_JetClass=./datasets/JetClass\".\n", - "Collecting awkward\n", - " Downloading awkward-2.8.5-py3-none-any.whl.metadata (6.9 kB)\n", - "Collecting uproot\n", - " Downloading uproot-5.6.3-py3-none-any.whl.metadata (33 kB)\n", - "Collecting vector\n", - " Downloading vector-1.6.3-py3-none-any.whl.metadata (16 kB)\n", - "Collecting awkward-cpp==47 (from awkward)\n", - " Downloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (2.1 kB)\n", - "Requirement already satisfied: fsspec>=2022.11.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (2025.3.0)\n", - "Requirement already satisfied: importlib-metadata>=4.13.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (8.7.0)\n", - "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.11/dist-packages (from awkward) (2.0.2)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from awkward) (25.0)\n", - "Requirement already satisfied: cramjam>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from uproot) (2.10.0)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from uproot) (3.5.0)\n", - "Requirement already satisfied: zipp>=3.20 in /usr/local/lib/python3.11/dist-packages (from importlib-metadata>=4.13.0->awkward) (3.23.0)\n", - "Downloading awkward-2.8.5-py3-none-any.whl (886 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m886.8/886.8 kB\u001b[0m \u001b[31m54.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (638 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m638.8/638.8 kB\u001b[0m \u001b[31m48.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading uproot-5.6.3-py3-none-any.whl (382 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.8/382.8 kB\u001b[0m \u001b[31m35.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading vector-1.6.3-py3-none-any.whl (179 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.6/179.6 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: vector, awkward-cpp, awkward, uproot\n", - "Successfully installed awkward-2.8.5 awkward-cpp-47 uproot-5.6.3 vector-1.6.3\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "if not 
any(fname.endswith(\".root\") for fname in os.listdir(extract_dir)):\n", - " print(\"⏬ extracting test-set…\")\n", - " with tarfile.open(tar_path) as tar:\n", - " tar.extractall(path=extract_dir)\n", - "pattern = os.path.join(extract_dir, 'val_5M', \"*.root\")\n", - "files = sorted(glob.glob(pattern))\n", - "print(f\"Found {len(files)} ROOT files\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4QtS-37tY8wa", - "outputId": "09cebca9-fb6c-459e-b75b-685679784bf0" - }, - "execution_count": 15, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "⏬ extracting test-set…\n", - "Found 50 ROOT files\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "import torch\n", - "from torch.utils.data import Dataset\n", - "\n", - "all_x_parts = []\n", - "all_ys = []\n", - "\n", - "num_file = 1\n", - "for file in files:\n", - " num_file += 1\n", - " if num_file % 5 == 0:\n", - " x_part, x_jets, y = read_file(\n", - " file,\n", - " max_num_particles=8,\n", - " particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],\n", - " jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],\n", - " labels=[\n", - " 'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',\n", - " 'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',\n", - " ]\n", - " )\n", - " all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:100,:,:])\n", - " all_ys.append(torch.tensor(y, dtype=torch.float32)[:100,:])\n", - "\n", - "x_all = torch.cat(all_x_parts, dim=0)\n", - "y_all = torch.cat(all_ys, dim=0)\n", - "print(x_all.shape, y_all.shape)\n", - "\n", - "class JetDataset(Dataset):\n", - " def __init__(self, x, y):\n", - " self.x = x\n", - " self.y = y\n", - "\n", - " def __len__(self):\n", - " return self.x.shape[0]\n", - "\n", - " def __getitem__(self, idx):\n", - " return self.x[idx], self.y[idx]\n", - "\n", - "dataset = JetDataset(x_all, y_all)\n", - "\n", - "from torch.utils.data import DataLoader\n", - "dataloader = DataLoader(dataset, batch_size=64, shuffle=True)" - ], - "metadata": { - "id": "-q3rmopucyqr", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "84d9b813-ea09-4a33-e54c-222b66ce5e7a" - }, - "execution_count": 17, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([1000, 4, 8]) torch.Size([1000, 10])\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "model = ParT(\n", - " in_dim=4, # part_pt, eta, phi, energy\n", - " embed_dim=10,\n", - " n_heads=2,\n", - " depth=2,\n", - " class_depth=2,\n", - " num_classes=10\n", - ")\n", - "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", - "model = model.to(device)" - ], - "metadata": { - "id": "4bNN4pfUfldh" - }, - "execution_count": 23, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "loss_fn = nn.CrossEntropyLoss()\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)" - ], - "metadata": { - "id": "cGmJ9BJbfzL5" - }, - "execution_count": 24, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "num_epochs = 100\n", - "\n", - "for epoch in range(num_epochs):\n", - " model.train()\n", - " epoch_loss = 0.0\n", - " for batch_idx, (x, y) in enumerate(dataloader):\n", - " x, y = x.to(device), y.to(device)\n", - "\n", - " optimizer.zero_grad()\n", - " outputs = model(x) # shape [batch, 10]\n", - "\n", - " loss = loss_fn(outputs, y)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " epoch_loss += loss.item()\n", - 
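A note on the target format used in this loop: the rows of `y` are one-hot floats, and `nn.CrossEntropyLoss` accepts such probability targets only on PyTorch ≥ 1.10; on older versions, or simply to be explicit, convert to class indices first. A minimal, equivalent sketch under that assumption:

```python
# Version-safe alternative: feed class indices rather than one-hot rows.
loss = loss_fn(outputs, y.argmax(dim=1))
```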
"\n", - " avg_loss = epoch_loss / len(dataloader)\n", - " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "jJZR1RgvgqXZ", - "outputId": "7b88db76-d9df-4dbf-a0af-ae13a4fbb81d" - }, - "execution_count": 25, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Epoch 1/100, Loss: 2.3056\n", - "Epoch 2/100, Loss: 2.3032\n", - "Epoch 3/100, Loss: 2.3018\n", - "Epoch 4/100, Loss: 2.2987\n", - "Epoch 5/100, Loss: 2.2935\n", - "Epoch 6/100, Loss: 2.2955\n", - "Epoch 7/100, Loss: 2.2838\n", - "Epoch 8/100, Loss: 2.2749\n", - "Epoch 9/100, Loss: 2.2692\n", - "Epoch 10/100, Loss: 2.2611\n", - "Epoch 11/100, Loss: 2.2566\n", - "Epoch 12/100, Loss: 2.2437\n", - "Epoch 13/100, Loss: 2.2414\n", - "Epoch 14/100, Loss: 2.2341\n", - "Epoch 15/100, Loss: 2.2255\n", - "Epoch 16/100, Loss: 2.2330\n", - "Epoch 17/100, Loss: 2.2221\n", - "Epoch 18/100, Loss: 2.2226\n", - "Epoch 19/100, Loss: 2.2205\n", - "Epoch 20/100, Loss: 2.2116\n", - "Epoch 21/100, Loss: 2.2099\n", - "Epoch 22/100, Loss: 2.2131\n", - "Epoch 23/100, Loss: 2.2063\n", - "Epoch 24/100, Loss: 2.2195\n", - "Epoch 25/100, Loss: 2.2166\n", - "Epoch 26/100, Loss: 2.2140\n", - "Epoch 27/100, Loss: 2.2045\n", - "Epoch 28/100, Loss: 2.2064\n", - "Epoch 29/100, Loss: 2.2087\n", - "Epoch 30/100, Loss: 2.2087\n", - "Epoch 31/100, Loss: 2.2096\n", - "Epoch 32/100, Loss: 2.2064\n", - "Epoch 33/100, Loss: 2.2051\n", - "Epoch 34/100, Loss: 2.2003\n", - "Epoch 35/100, Loss: 2.2138\n", - "Epoch 36/100, Loss: 2.2038\n", - "Epoch 37/100, Loss: 2.1994\n", - "Epoch 38/100, Loss: 2.1936\n", - "Epoch 39/100, Loss: 2.2029\n", - "Epoch 40/100, Loss: 2.2045\n", - "Epoch 41/100, Loss: 2.2124\n", - "Epoch 42/100, Loss: 2.1949\n", - "Epoch 43/100, Loss: 2.1994\n", - "Epoch 44/100, Loss: 2.1981\n", - "Epoch 45/100, Loss: 2.2010\n", - "Epoch 46/100, Loss: 2.2036\n", - "Epoch 47/100, Loss: 2.1906\n", - "Epoch 48/100, Loss: 2.1960\n", - "Epoch 49/100, Loss: 2.1875\n", - "Epoch 50/100, Loss: 2.1938\n", - "Epoch 51/100, Loss: 2.1916\n", - "Epoch 52/100, Loss: 2.1815\n", - "Epoch 53/100, Loss: 2.1875\n", - "Epoch 54/100, Loss: 2.1931\n", - "Epoch 55/100, Loss: 2.1950\n", - "Epoch 56/100, Loss: 2.1820\n", - "Epoch 57/100, Loss: 2.1887\n", - "Epoch 58/100, Loss: 2.1857\n", - "Epoch 59/100, Loss: 2.1884\n", - "Epoch 60/100, Loss: 2.1908\n", - "Epoch 61/100, Loss: 2.1738\n", - "Epoch 62/100, Loss: 2.1753\n", - "Epoch 63/100, Loss: 2.1776\n", - "Epoch 64/100, Loss: 2.1713\n", - "Epoch 65/100, Loss: 2.1603\n", - "Epoch 66/100, Loss: 2.1666\n", - "Epoch 67/100, Loss: 2.1701\n", - "Epoch 68/100, Loss: 2.1823\n", - "Epoch 69/100, Loss: 2.1682\n", - "Epoch 70/100, Loss: 2.1808\n", - "Epoch 71/100, Loss: 2.1734\n", - "Epoch 72/100, Loss: 2.1667\n", - "Epoch 73/100, Loss: 2.1597\n", - "Epoch 74/100, Loss: 2.1635\n", - "Epoch 75/100, Loss: 2.1517\n", - "Epoch 76/100, Loss: 2.1617\n", - "Epoch 77/100, Loss: 2.1848\n", - "Epoch 78/100, Loss: 2.1804\n", - "Epoch 79/100, Loss: 2.1683\n", - "Epoch 80/100, Loss: 2.1565\n", - "Epoch 81/100, Loss: 2.1649\n", - "Epoch 82/100, Loss: 2.1577\n", - "Epoch 83/100, Loss: 2.1343\n", - "Epoch 84/100, Loss: 2.1675\n", - "Epoch 85/100, Loss: 2.1410\n", - "Epoch 86/100, Loss: 2.1302\n", - "Epoch 87/100, Loss: 2.1299\n", - "Epoch 88/100, Loss: 2.1272\n", - "Epoch 89/100, Loss: 2.1290\n", - "Epoch 90/100, Loss: 2.1543\n", - "Epoch 91/100, Loss: 2.1378\n", - "Epoch 92/100, Loss: 2.1691\n", - "Epoch 93/100, Loss: 2.1342\n", - "Epoch 
94/100, Loss: 2.1216\n", - "Epoch 95/100, Loss: 2.1249\n", - "Epoch 96/100, Loss: 2.1066\n", - "Epoch 97/100, Loss: 2.1149\n", - "Epoch 98/100, Loss: 2.1009\n", - "Epoch 99/100, Loss: 2.1167\n", - "Epoch 100/100, Loss: 2.0952\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from torch.nn.functional import sigmoid, softmax\n", - "\n", - "model.eval()\n", - "correct = 0\n", - "total = 0\n", - "\n", - "with torch.no_grad():\n", - " for x, y in dataloader:\n", - " x, y = x.to(device), y.to(device)\n", - " outputs = model(x)\n", - " labels = torch.argmax(y, dim=1) # convert one-hot to class id\n", - " preds = torch.argmax(outputs, dim=1) # predicted class\n", - " correct += (preds == labels).sum().item()\n", - " total += y.size(0)\n", - "accuracy = correct / total\n", - "print(f\"Accuracy on full dataset: {accuracy:.4f}\")\n" - ], - "metadata": { - "id": "ogHJm4wFhBMN", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "bffea42e-0f4b-4364-9dff-0978122331d1" - }, - "execution_count": 26, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Accuracy on full dataset: 0.2010\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "from sklearn.metrics import roc_auc_score\n", - "import numpy as np\n", - "\n", - "model.eval()\n", - "all_outputs = []\n", - "all_targets = []\n", - "\n", - "with torch.no_grad():\n", - " for x, y in dataloader:\n", - " x, y = x.to(device), y.to(device)\n", - " outputs = model(x)\n", - " all_outputs.append(outputs.cpu())\n", - " all_targets.append(y.cpu())\n", - "\n", - "# Concatenate batches\n", - "all_outputs = torch.cat(all_outputs, dim=0)\n", - "all_targets = torch.cat(all_targets, dim=0)\n", - "\n", - "# Apply sigmoid (if using BCEWithLogitsLoss)\n", - "probs = sigmoid(all_outputs).numpy() # shape: (N, C)\n", - "true = all_targets.numpy() # shape: (N, C)\n", - "\n", - "# Compute AUC for each class and average\n", - "try:\n", - " auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')\n", - " print(f\"Macro-Averaged AUC: {auc_macro:.4f}\")\n", - "except ValueError as e:\n", - " print(\"AUC could not be computed:\", e)\n" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "6ehi3CwmaDhL", - "outputId": "dc360542-831d-4cc6-e063-982c60499e9a" - }, - "execution_count": 27, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Macro-Averaged AUC: 0.6697\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "TPgU9VB2ajLT" - }, - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT_components.ipynb b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT_components.ipynb new file mode 100644 index 0000000..07855fa --- /dev/null +++ b/Quantum_Transformers_Alessandro_Tesi/Quantum ParT/Quantum_ParT_components.ipynb @@ -0,0 +1,1155 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "smpEEHlypz0p" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import math" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PAgmRCdTvyg6" + }, + "source": [ + "## Toy Data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "2jsmzKcFvx4L" + }, + "outputs": [], + "source": [ + "x = torch.Tensor([[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 
5.5106983e+01,\n", + " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", + " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", + " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", + " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", + " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", + " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", + " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]])\n", + "\n", + "x_batch = torch.Tensor([[[ 8.7541382e+01, 7.4368027e+01, 6.8198807e+01, 5.5106983e+01,\n", + " 3.8173203e+01, 2.3811323e+01, 1.8535612e+01, 1.8095873e+01],\n", + " [-5.1816605e-02, 4.1964740e-01, 4.4769207e-01, -6.5377988e-02,\n", + " 6.1878175e-01, 6.2997532e-01, -8.1161506e-02, -7.9315454e-02],\n", + " [ 2.4468634e+00, 2.2728169e+00, 2.2340939e+00, 2.6091626e+00,\n", + " 1.8117365e+00, 1.7758595e+00, 2.3642292e+00, 2.5458102e+00],\n", + " [ 8.7658936e+01, 8.1014442e+01, 7.5148209e+01, 5.5224800e+01,\n", + " 4.5717682e+01, 2.8694658e+01, 1.8596695e+01, 1.8153358e+01]],\n", + "\n", + " [[ 8.2484528e+01, 5.2682617e+01, 5.1243843e+01, 3.6217686e+01,\n", + " 2.8948278e+01, 2.6579512e+01, 2.1946012e+01, 2.1011120e+01],\n", + " [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,\n", + " -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],\n", + " [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,\n", + " -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],\n", + " [ 9.0437065e+01, 7.4070679e+01, 5.6495895e+01, 4.3069641e+01,\n", + " 3.2367134e+01, 3.3371326e+01, 2.6115334e+01, 2.4602446e+01]],\n", + "\n", + " [[ 8.6492935e+01, 7.0192978e+01, 5.8423912e+01, 5.6638733e+01,\n", + " 4.9270725e+01, 4.1237038e+01, 3.6133625e+01, 3.5519596e+01],\n", + " [ 1.4010678e-01, 2.7912292e-01, 1.4376265e-01, 3.4672296e-01,\n", + " 3.4966472e-01, 1.0524009e-01, 1.2958543e-01, 3.3264065e-01],\n", + " [ 1.9334941e+00, 1.6967584e+00, 1.9219695e+00, 1.6735281e+00,\n", + " 1.6587850e+00, 1.8386338e+00, 1.9120301e+00, 1.6680365e+00],\n", + " [ 8.7343246e+01, 7.2945129e+01, 5.9030762e+01, 6.0084766e+01,\n", + " 5.2313778e+01, 4.1465607e+01, 3.6440781e+01, 3.7506149e+01]]])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Po7Fb1v37EJT", + "outputId": "2571d4f4-e9db-4e2b-ee7f-13a7e406e526" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = x.unsqueeze(0) # batch dimension\n", + "x.shape, x_batch.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ybT64WoWQ71" + }, + "source": [ + "## Quantum Function" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gm-q21_IWUCU", + "outputId": "cfa85478-65d3-4074-fce1-dcd09285b7de" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/pty.py:95: RuntimeWarning: os.fork() was called. 
os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/torchvision-0.21.0+7af6987-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/setuptools-75.8.0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages/pillow-11.1.0-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: tensorcircuit in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (0.12.0)\n", + "Requirement already satisfied: numpy in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (2.2.3)\n", + "Requirement already satisfied: scipy in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (1.15.2)\n", + "Requirement already satisfied: tensornetwork-ng in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensorcircuit) (0.5.1)\n", + "Requirement already satisfied: networkx in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensorcircuit) (3.4.2)\n", + "Requirement already satisfied: graphviz>=0.11.1 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (0.21)\n", + "Requirement already satisfied: opt-einsum>=2.3.0 in /global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (3.4.0)\n", + "Requirement already satisfied: h5py>=2.9.0 in /global/common/software/nersc9/pytorch/2.6.0/lib/python3.12/site-packages (from tensornetwork-ng->tensorcircuit) (3.13.0)\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "np.ComplexWarning = Warning\n", + "\n", + "!pip install tensorcircuit\n", + "from typing import Callable\n", + "\n", + "# --- monkeypatch for tc ---\n", + "import jax\n", + "try:\n", + " _ = jax.tree_map # JAX < 0.6 has this\n", + "except AttributeError:\n", + " # JAX ≥ 0.6 moved it here\n", + " from jax import tree_util as _jtu\n", + " jax.tree_map = _jtu.tree_map\n", + "\n", + "import tensorcircuit as tc\n", + "import jax.numpy as jnp\n", + "import flax.linen\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "K = tc.set_backend(\"jax\")\n", + "\n", + "\n", + "def angle_embedding(c: 
tc.Circuit, inputs):\n", + " num_qubits = inputs.shape[-1]\n", + "\n", + " for j in range(num_qubits):\n", + " c.rx(j, theta=inputs[j])\n", + "\n", + "\n", + "def basic_vqc(c: tc.Circuit, inputs, weights):\n", + " num_qubits = inputs.shape[-1]\n", + " num_qlayers = weights.shape[-2]\n", + "\n", + " for i in range(num_qlayers):\n", + " for j in range(num_qubits):\n", + " c.rx(j, theta=weights[i, j])\n", + " if num_qubits == 2:\n", + " c.cnot(0, 1)\n", + " elif num_qubits > 2:\n", + " for j in range(num_qubits):\n", + " c.cnot(j, (j + 1) % num_qubits)\n", + "\n", + "\n", + "def get_quantum_layer_circuit(inputs, weights,\n", + " embedding: Callable = angle_embedding, vqc: Callable = basic_vqc):\n", + " \"\"\"\n", + " Equivalent to the following PennyLane circuit:\n", + " def circuit(inputs, weights):\n", + " qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))\n", + " qml.templates.BasicEntanglerLayers(weights, wires=range(num_qubits))\n", + " \"\"\"\n", + "\n", + " num_qubits = inputs.shape[-1]\n", + "\n", + " c = tc.Circuit(num_qubits)\n", + " embedding(c, inputs)\n", + " vqc(c, inputs, weights)\n", + "\n", + " return c\n", + "\n", + "\n", + "def get_circuit(embedding: Callable = angle_embedding, vqc: Callable = basic_vqc,\n", + " torch_interface: bool = False):\n", + " def qpred(inputs, weights):\n", + " c = get_quantum_layer_circuit(inputs, weights, embedding, vqc)\n", + " return K.real(jnp.array([c.expectation_ps(z=[i]) for i in range(weights.shape[1])]))\n", + "\n", + " qpred_batch = K.vmap(qpred, vectorized_argnums=0)\n", + " if torch_interface:\n", + " qpred_batch = tc.interfaces.torch_interface(qpred_batch, jit=True)\n", + "\n", + " return qpred_batch\n", + "\n", + "\n", + "class QuantumLayer(flax.linen.Module):\n", + " circuit: Callable\n", + " num_qubits: int\n", + " w_shape: tuple = (1,)\n", + "\n", + " @flax.linen.compact\n", + " def __call__(self, x):\n", + " shape = x.shape\n", + " x = jnp.reshape(x, (-1, shape[-1]))\n", + " w = self.param('w', flax.linen.initializers.xavier_normal(), self.w_shape + (self.num_qubits,))\n", + " x = self.circuit(x, w)\n", + " x = jnp.concatenate(x, axis=-1)\n", + " x = jnp.reshape(x, tuple(shape))\n", + " return x\n", + "\n", + "\n", + "\n", + "NUM_QUBITS = 8\n", + "NUM_Q_LAYERS = 1\n", + "torch_layer_fn = get_circuit(torch_interface=True)\n", + "\n", + "\n", + "class TCTorchLayer(nn.Module):\n", + " \"\"\"\n", + " A thin PyTorch wrapper around the TensorCircuit/TC quantum layer.\n", + " Stores the circuit's trainable parameters as an nn.Parameter so\n", + " they appear in .parameters() and get updated by any torch optimizer.\n", + " \"\"\"\n", + " def __init__(self, num_qubits=NUM_QUBITS, num_qlayers=NUM_Q_LAYERS):\n", + " super().__init__()\n", + " init_w = 0.01 * torch.randn(num_qlayers, num_qubits)\n", + " self.w = nn.Parameter(init_w)\n", + " self.num_qubits = num_qubits\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " x: (batch, num_qubits) – already pre-scaled into rotation angles.\n", + " Returns expectation values ⟨Z_i⟩ for every qubit i, shape identical\n", + " to the input (batch, num_qubits).\n", + " \"\"\"\n", + " return torch_layer_fn(x, self.w)\n", + "\n", + "\n", + "class QuantumLinear(nn.Module):\n", + " \"\"\"\n", + " Linear -> angle map -> TCTorchLayer -> Linear\n", + " Works on tensors shaped (..., din) and returns (..., dout).\n", + " \"\"\"\n", + " def __init__(self, din, dout, num_qubits):\n", + " super().__init__()\n", + " self.din = din\n", + " self.dout = dout\n", + " self.nq = num_qubits\n", + "\n", 
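+ "        # NOTE: the classical projections below are commented out on purpose in\n", + "        # this components notebook, so the block is a purely quantum map; that\n", + "        # is only shape-consistent when din == num_qubits == dout, which holds\n", + "        # for every QuantumLinear(d, d, d) constructed further down.\n",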
+ " #self.to_q = nn.Linear(din, self.nq, bias=False)\n", + " #self.from_q = nn.Linear(self.nq, dout, bias=False)\n", + " self.q = TCTorchLayer(self.nq)\n", + "\n", + " @staticmethod\n", + " def _to_angles(x):\n", + " return torch.tanh(x) * math.pi\n", + "\n", + " def forward(self, x):\n", + " # x: (..., din)\n", + " *prefix, _ = x.shape\n", + " x = x.reshape(-1, self.din)\n", + "\n", + " #x = self.to_q(x)\n", + " x = self._to_angles(x)\n", + " x = self.q(x).float()\n", + " #x = self.from_q(x)\n", + "\n", + " x = x.reshape(*prefix, self.dout)\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X6d5fzsV4kt7" + }, + "source": [ + "## Toy interaction matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sOjDJYr14zDr", + "outputId": "aa8b4aef-a5c1-4e3e-a125-6e86079468cd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "U.shape: torch.Size([1, 2, 8, 8])\n" + ] + } + ], + "source": [ + "class InteractionEncoder(nn.Module):\n", + " \"\"\"\n", + " ParT interaction-feature encoder.\n", + "\n", + " Args\n", + " ----\n", + " n_heads per mhsa: output channels d′\n", + " hidden_channels : list[int] for intermediate 1×1 conv layers\n", + " eps : numerical guard for log\n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " n_heads: int = 8,\n", + " hidden_channels: list[int] = (64, 64, 64),\n", + " eps: float = 1e-8):\n", + " super().__init__()\n", + " self.eps = eps\n", + "\n", + " layers: list[nn.Module] = []\n", + " in_ch = 4 # lnΔ, ln kT, ln z, ln m²\n", + " for h in hidden_channels:\n", + " layers += [\n", + " nn.Conv2d(in_ch, h, 1, bias=False),\n", + " nn.BatchNorm2d(h),\n", + " nn.GELU()\n", + " ]\n", + " in_ch = h\n", + " layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))\n", + " self.net = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " x : (B, 4, N) where the 4 dims are (E, px, py, pz)\n", + " returns\n", + " ------\n", + " U : (B, n_heads, N, N) interaction embedding\n", + " \"\"\"\n", + " B, four, N = x.shape\n", + " assert four == 4, \"input must have 4 features: E, px, py, pz\"\n", + "\n", + " # Split components\n", + " E, px, py, pz = x.unbind(dim=1) # each (B, N)\n", + "\n", + " # Basic kinematics ------------------------------------------------\n", + " pT = torch.sqrt(px**2 + py**2) + self.eps\n", + " phi = torch.atan2(py, px) # (−π, π]\n", + " num = (E + pz).clamp(min=self.eps) #need to avoid negative numbers\n", + " den = (E - pz).clamp(min=self.eps)\n", + " y = 0.5 * torch.log(num / den)\n", + "\n", + " # Expand to (B, N, N)\n", + " y_a, y_b = y.unsqueeze(2), y.unsqueeze(1) # (B,N,1),(B,1,N)\n", + " phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)\n", + " pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)\n", + " E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)\n", + " px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)\n", + " py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)\n", + " pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)\n", + "\n", + " # ΔR, kT, z\n", + " delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps\n", + " kT = torch.minimum(pT_a, pT_b) * delta\n", + " z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps)\n", + "\n", + " # m² of pair\n", + " E_sum = E_a + E_b\n", + " px_sum = px_a + px_b\n", + " py_sum = py_a + py_b\n", + " pz_sum = pz_a + pz_b\n", + " m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps\n", + " m2 = 
torch.clamp(m2, min=self.eps) # avoid negatives\n", + "\n", + " # Stack → (B, 4, N, N)\n", + " feats = torch.stack([\n", + " torch.log(delta),\n", + " torch.log(kT),\n", + " torch.log(z),\n", + " torch.log(m2)\n", + " ], dim=1)\n", + "\n", + " # conv\n", + " U = self.net(feats) # (B, n_heads, N, N)\n", + " return U\n", + "\n", + "\n", + "\n", + "B, _, N = x.shape\n", + "n_heads = 2 # d′\n", + "enc = InteractionEncoder(n_heads=n_heads)\n", + "U = enc(x)\n", + "print(\"U.shape:\", U.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5CwQyoPYwaR7" + }, + "source": [ + "## Particle Transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ya5SvxTJwE3r", + "outputId": "27b1b1d0-4863-41b2-99f9-4daadd27688a" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 10])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class ParticleTokenizer(nn.Module):\n", + " def __init__(self, in_dim=4, out_dim=6):\n", + " super().__init__()\n", + " self.proj = nn.Linear(in_dim, out_dim)\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " x: tensor of shape (B, in_dim, n_particles)\n", + " returns: (B, n_particles, out_dim)\n", + " \"\"\"\n", + " x = x.transpose(1, 2) # (B, in_dim, n_particles) → (B, n_particles, in_dim)\n", + " return self.proj(x)\n", + "\n", + "tokenizer = ParticleTokenizer(4, 10)\n", + "output = tokenizer(x)\n", + "output.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S5f8DTRY07PD", + "outputId": "edb84881-2ccc-456a-bdc7-59ad23397fd5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:2025-09-10 05:32:57,887:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()\n", + "Traceback (most recent call last):\n", + " File \"/global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages/jax/_src/xla_bridge.py\", line 485, in discover_pjrt_plugins\n", + " plugin_module.initialize()\n", + " File \"/global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py\", line 328, in initialize\n", + " _check_cuda_versions(raise_on_first_error=True)\n", + " File \"/global/homes/a/aletesi/.local/perlmutter/pytorch2.6.0/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py\", line 285, in _check_cuda_versions\n", + " local_device_count = cuda_versions.cuda_device_count()\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE\n" + ] + }, + { + "data": { + "text/plain": [ + "torch.Size([1, 8, 10])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class MLP(nn.Module):\n", + " \"\"\"\n", + " Two-layer feed-forward block with each nn.Linear swapped for QuantumLinear.\n", + " Works for inputs shaped (..., dim).\n", + "\n", + " Args:\n", + " dim : feature size\n", + " dropout : dropout prob\n", + " num_qubits : qubits per QuantumLinear block (defaults to dim)\n", + " \"\"\"\n", + " def __init__(self, dim, dropout=0., num_qubits=None):\n", + " super().__init__()\n", + " nq = num_qubits if num_qubits is not None else dim\n", + "\n", + " self.fc1 = QuantumLinear(dim, dim, 
nq)\n", + " self.fc2 = QuantumLinear(dim, dim, nq)\n", + "\n", + " self.act = nn.GELU()\n", + " self.do1 = nn.Dropout(dropout)\n", + " self.do2 = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.act(x)\n", + " x = self.do1(x)\n", + "\n", + " x = self.fc2(x)\n", + " x = self.do2(x)\n", + " return x\n", + "\n", + "# usage\n", + "mlp = MLP(10, dropout=0.1)\n", + "output = mlp(output)\n", + "output.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "zb5nYOsSyFRk", + "outputId": "7335d6f6-9642-45ad-90b4-981175a47f6d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])\n" + ] + } + ], + "source": [ + "class ParticleMHA(nn.Module):\n", + " \"\"\"\n", + " Multi-head self-attention with quantum projections (q, k, v, o).\n", + "\n", + " Args\n", + " ----\n", + " d : embedding dim\n", + " heads : number of attention heads\n", + " dropout : dropout prob on attn weights\n", + " return_attn : return attention maps?\n", + " num_qubits : qubits per quantum block (defaults to d)\n", + " \"\"\"\n", + " def __init__(self, d: int, heads: int = 8,\n", + " dropout: float = 0.1, return_attn: bool = False,\n", + " num_qubits: int | None = None):\n", + " super().__init__()\n", + " assert d % heads == 0, \"`d` must be divisible by `heads`\"\n", + "\n", + " self.d = d\n", + " self.h = heads\n", + " self.d_head = d // heads\n", + " self.scale = 1 / math.sqrt(self.d_head)\n", + " self.return_attn = return_attn\n", + "\n", + " nq = num_qubits if num_qubits is not None else d\n", + "\n", + " # quantum projections\n", + " self.q_proj = QuantumLinear(d, d, nq)\n", + " self.k_proj = QuantumLinear(d, d, nq)\n", + " self.v_proj = QuantumLinear(d, d, nq)\n", + " self.o_proj = QuantumLinear(d, d, nq)\n", + "\n", + " self.drop = nn.Dropout(dropout)\n", + "\n", + " def _split(self, t: torch.Tensor):\n", + " # (B, N, d) -> (B, H, N, d_head)\n", + " B, N, _ = t.shape\n", + " return t.view(B, N, self.h, self.d_head).transpose(1, 2)\n", + "\n", + " def forward(self, x: torch.Tensor, U: torch.Tensor | None = None):\n", + " B, N, _ = x.shape\n", + "\n", + " Q = self._split(self.q_proj(x))\n", + " K = self._split(self.k_proj(x))\n", + " V = self._split(self.v_proj(x))\n", + "\n", + " logits = (Q @ K.transpose(-2, -1)) * self.scale # (B, H, N, N)\n", + "\n", + " if U is not None:\n", + " logits = logits + U\n", + "\n", + " attn = F.softmax(logits, dim=-1)\n", + " attn = self.drop(attn)\n", + "\n", + " context = attn @ V # (B, H, N, d_head)\n", + "\n", + " context = (\n", + " context.transpose(1, 2) # (B, N, H, d_head)\n", + " .contiguous()\n", + " .view(B, N, self.d)\n", + " )\n", + " out = self.o_proj(context)\n", + "\n", + " if self.return_attn:\n", + " return out, attn\n", + " else:\n", + " return out\n", + "\n", + "B, N, d = output.shape\n", + "\n", + "U = torch.randn(1, 2, N, N) # broadcast to (B, H, N, N)\n", + "\n", + "pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)\n", + "output, A = pmha(output, U) # out: (B, N, d) A: (B, 8, N, N)\n", + "\n", + "print(output.shape, A.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4650yyFtdjaR" + }, + "source": [ + "## transformer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "CKYo7jEiyYx1" + }, + "outputs": [], + "source": [ + "class MHA(nn.Module):\n", + " \"\"\"\n", + " Multi-head 
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "4650yyFtdjaR"
+   },
+   "source": [
+    "## Transformer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "id": "CKYo7jEiyYx1"
+   },
+   "outputs": [],
+   "source": [
+    "class MHA(nn.Module):\n",
+    "    \"\"\"\n",
+    "    Multi-head attention (batch_first) with QuantumLinear projections.\n",
+    "\n",
+    "    Args\n",
+    "    ----\n",
+    "    d_model    : int         embedding dim\n",
+    "    n_heads    : int\n",
+    "    dropout    : float\n",
+    "    bias       : bool        (ignored here, QuantumLinear has no bias)\n",
+    "    num_qubits : int | None  qubits per quantum block (defaults to d_model)\n",
+    "\n",
+    "    Note: attn_mask and key_padding_mask are accepted for interface\n",
+    "    compatibility with nn.MultiheadAttention but are not applied here.\n",
+    "    \"\"\"\n",
+    "    def __init__(self, d_model: int, n_heads: int,\n",
+    "                 dropout: float = 0., bias: bool = False,\n",
+    "                 num_qubits: int | None = None):\n",
+    "        super().__init__()\n",
+    "        assert d_model % n_heads == 0, \"`d_model` must be divisible by `n_heads`\"\n",
+    "        self.d_model = d_model\n",
+    "        self.h = n_heads\n",
+    "        self.d_head = d_model // n_heads\n",
+    "        self.scale = self.d_head ** -0.5\n",
+    "\n",
+    "        nq = num_qubits if num_qubits is not None else d_model\n",
+    "\n",
+    "        # Quantum projections replace nn.Linear\n",
+    "        self.q_proj = QuantumLinear(d_model, d_model, nq)\n",
+    "        self.k_proj = QuantumLinear(d_model, d_model, nq)\n",
+    "        self.v_proj = QuantumLinear(d_model, d_model, nq)\n",
+    "        self.o_proj = QuantumLinear(d_model, d_model, nq)\n",
+    "\n",
+    "        self.drop = nn.Dropout(dropout)\n",
+    "\n",
+    "    def _split_heads(self, x: torch.Tensor):\n",
+    "        # (B, L, d_model) -> (B, h, L, d_head)\n",
+    "        B, L, _ = x.shape\n",
+    "        return x.view(B, L, self.h, self.d_head).transpose(1, 2)\n",
+    "\n",
+    "    def _merge_heads(self, x: torch.Tensor):\n",
+    "        # (B, h, L, d_head) -> (B, L, d_model)\n",
+    "        B, H, L, Dh = x.shape\n",
+    "        return x.transpose(1, 2).contiguous().view(B, L, H * Dh)\n",
+    "\n",
+    "    def forward(\n",
+    "        self,\n",
+    "        q: torch.Tensor,                                # (B, Lq, d_model)\n",
+    "        k: torch.Tensor,                                # (B, Lk, d_model)\n",
+    "        v: torch.Tensor,                                # (B, Lk, d_model)\n",
+    "        attn_mask: torch.Tensor | None = None,          # unused, see docstring\n",
+    "        key_padding_mask: torch.Tensor | None = None,   # unused, see docstring\n",
+    "        need_weights: bool = False\n",
+    "    ):\n",
+    "        B, Lq, _ = q.shape\n",
+    "        _, Lk, _ = k.shape\n",
+    "\n",
+    "        Q = self._split_heads(self.q_proj(q))           # (B,h,Lq,d_h)\n",
+    "        K = self._split_heads(self.k_proj(k))           # (B,h,Lk,d_h)\n",
+    "        V = self._split_heads(self.v_proj(v))           # (B,h,Lk,d_h)\n",
+    "\n",
+    "        logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B,h,Lq,Lk)\n",
+    "\n",
+    "        attn = F.softmax(logits, dim=-1)\n",
+    "        attn = self.drop(attn)\n",
+    "\n",
+    "        context = torch.matmul(attn, V)                 # (B,h,Lq,d_h)\n",
+    "\n",
+    "        out = self.o_proj(self._merge_heads(context))   # (B,Lq,d_model)\n",
+    "\n",
+    "        if need_weights:\n",
+    "            return out, attn.mean(dim=1)                # (B,Lq,Lk)\n",
+    "        return out, None"
+   ]
+  },
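+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal usage sketch of `MHA` in the class-attention pattern used in the blocks below (CaiT style): a single class token queries the concatenated sequence `[cls; tokens]`. The `tokens_demo` / `cls_demo` tensors are made-up placeholders, not data from this notebook, and the cell assumes `QuantumLinear` from earlier is available."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Usage sketch: class attention with MHA (shapes assume d_model=10, 2 heads).\n",
+    "mha_demo = MHA(d_model=10, n_heads=2)\n",
+    "tokens_demo = torch.randn(1, 8, 10)    # (B, N, d_model) dummy particle tokens\n",
+    "cls_demo = torch.zeros(1, 1, 10)       # (B, 1, d_model) class token\n",
+    "\n",
+    "z_demo = torch.cat([cls_demo, tokens_demo], dim=1)   # (B, 1+N, d_model)\n",
+    "out_demo, w_demo = mha_demo(cls_demo, z_demo, z_demo, need_weights=True)\n",
+    "print(out_demo.shape, w_demo.shape)    # (1, 1, 10) (1, 1, 9)"
+   ]
+  },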
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "id": "iQveEBhldizq"
+   },
+   "outputs": [],
+   "source": [
+    "# Particle attention block (NormFormer style + U-bias)\n",
+    "# Note: mlp_ratio is kept for interface symmetry but unused, since the\n",
+    "# quantum MLP above maps dim -> dim without a hidden expansion.\n",
+    "class ParticleAttentionBlock(nn.Module):\n",
+    "    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):\n",
+    "        super().__init__()\n",
+    "        self.ln1 = nn.LayerNorm(dim)\n",
+    "        self.attn = ParticleMHA(dim, heads, dropout)\n",
+    "        self.ln2 = nn.LayerNorm(dim)\n",
+    "        self.mlp = MLP(dim, dropout)\n",
+    "\n",
+    "    def forward(self, x, U):\n",
+    "        x = x + self.attn(self.ln1(x), U)    # bias-aware MHSA\n",
+    "        x = x + self.mlp(self.ln2(x))        # feed-forward\n",
+    "        return x\n",
+    "\n",
+    "# Class attention block (CaiT style, no U)\n",
+    "class ClassAttentionBlock(nn.Module):\n",
+    "    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):\n",
+    "        super().__init__()\n",
+    "        self.ln1 = nn.LayerNorm(dim)\n",
+    "        self.attn = MHA(dim, heads, dropout)\n",
+    "        self.ln2 = nn.LayerNorm(dim)\n",
+    "        self.mlp = MLP(dim, dropout)\n",
+    "\n",
+    "    def forward(self, tokens, cls):          # tokens: (B,N,d), cls: (B,1,d)\n",
+    "        z = torch.cat([cls, tokens], dim=1)              # (B,1+N,d)\n",
+    "        q = self.ln1(cls)\n",
+    "        kv = self.ln1(z)\n",
+    "        cls = cls + self.attn(q, kv, kv, need_weights=False)[0]\n",
+    "        cls = cls + self.mlp(self.ln2(cls))\n",
+    "        return cls                                       # (B,1,d)\n",
+    "\n",
+    "# Complete Particle Transformer\n",
+    "class ParT(nn.Module):\n",
+    "    def __init__(self,\n",
+    "                 in_dim=4,        # (pt, eta, phi, energy) per particle\n",
+    "                 embed_dim=10,\n",
+    "                 n_heads=2,\n",
+    "                 depth=2,         # particle blocks\n",
+    "                 class_depth=2,   # class-attention blocks\n",
+    "                 mlp_ratio=4,\n",
+    "                 num_classes=10,\n",
+    "                 dropout=0.1):\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.tokenizer = ParticleTokenizer(in_dim, embed_dim)\n",
+    "        self.U_encoder = InteractionEncoder(n_heads=n_heads)\n",
+    "\n",
+    "        self.blocks = nn.ModuleList([\n",
+    "            ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)\n",
+    "            for _ in range(depth)\n",
+    "        ])\n",
+    "\n",
+    "        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n",
+    "        self.cls_blocks = nn.ModuleList([\n",
+    "            ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)\n",
+    "            for _ in range(class_depth)\n",
+    "        ])\n",
+    "\n",
+    "        self.head = nn.Linear(embed_dim, num_classes)\n",
+    "\n",
+    "        # weight init\n",
+    "        nn.init.trunc_normal_(self.class_token, std=0.02)\n",
+    "        nn.init.trunc_normal_(self.head.weight, std=0.02)\n",
+    "        nn.init.zeros_(self.head.bias)\n",
+    "\n",
+    "    def forward(self, x):                        # x: (B,4,N)\n",
+    "        B, _, N = x.shape\n",
+    "\n",
+    "        tokens = self.tokenizer(x)               # (B,N,d)\n",
+    "        U = self.U_encoder(x)                    # (B,H,N,N)\n",
+    "\n",
+    "        for blk in self.blocks:\n",
+    "            tokens = blk(tokens, U)              # (B,N,d)\n",
+    "\n",
+    "        cls = self.class_token.expand(B, -1, -1) # (B,1,d)\n",
+    "        for blk in self.cls_blocks:\n",
+    "            cls = blk(tokens, cls)               # (B,1,d)\n",
+    "\n",
+    "        logits = self.head(cls.squeeze(1))       # (B,num_classes)\n",
+    "        return logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "0-g5CQ2hfh2x",
+    "outputId": "5a6e20d4-2855-46ad-a949-a19f9af36776"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "logits: torch.Size([3, 10])\n"
+     ]
+    }
+   ],
+   "source": [
+    "B, _, N = x_batch.shape            # (3, 4, 8)\n",
+    "model = ParT(in_dim=4,\n",
+    "             embed_dim=10,\n",
+    "             n_heads=2,\n",
+    "             depth=2,\n",
+    "             class_depth=2,\n",
+    "             num_classes=10)\n",
+    "\n",
+    "logits = model(x_batch)            # forward pass\n",
+    "print(\"logits:\", logits.shape)     # torch.Size([3, 10])\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "mOr6fyLvfjgv",
+    "outputId": "bc187fcb-b348-47ce-d71c-e52e55889367"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch 1 loss 2.3316 acc 0.000\n",
+      "epoch 5 loss 2.2308 acc 0.000\n",
+      "epoch 10 loss 2.1054 acc 0.333\n",
+      "epoch 15 loss 1.9857 acc 0.333\n",
+      "epoch 20 loss 1.8778 acc 0.333\n",
+      "epoch 25 loss 1.7586 acc 0.333\n",
+      "epoch 30 loss 1.6916 acc 0.333\n",
+      "epoch 35 loss 1.6236 acc 0.333\n",
+      "epoch 40 loss 1.5450 acc 0.333\n",
+      "epoch 45 loss 1.4880 acc 0.333\n",
+      "epoch 50 loss 1.4353 acc 0.333\n",
+      "epoch 55 loss 1.4027 acc 0.333\n",
+      "epoch 60 loss 1.4613 acc 0.333\n",
+      "epoch 65 loss 1.3617 acc 0.333\n",
+      "epoch 70 loss 1.3581 acc 0.333\n",
+      "epoch 75 loss 1.2670 acc 0.333\n",
+      "epoch 80 loss 1.2401 acc 0.333\n",
+      "epoch 85 loss 1.2122 acc 0.333\n",
+      "epoch 90 loss 1.1957 acc 0.333\n", +
"epoch 95 loss 1.1852 acc 0.333\n", + "epoch 100 loss 1.1740 acc 0.333\n", + "epoch 105 loss 1.1652 acc 0.333\n", + "epoch 110 loss 1.1514 acc 0.333\n", + "epoch 115 loss 1.1471 acc 0.333\n", + "epoch 120 loss 1.1424 acc 0.333\n", + "epoch 125 loss 1.1397 acc 0.333\n", + "epoch 130 loss 1.1355 acc 0.333\n", + "epoch 135 loss 1.1337 acc 0.333\n", + "epoch 140 loss 1.1302 acc 0.333\n", + "epoch 145 loss 1.1279 acc 0.333\n", + "epoch 150 loss 1.1240 acc 0.333\n", + "epoch 155 loss 1.1214 acc 0.333\n", + "epoch 160 loss 1.1180 acc 0.333\n", + "epoch 165 loss 1.1163 acc 0.333\n", + "epoch 170 loss 1.1139 acc 0.333\n", + "epoch 175 loss 1.1097 acc 0.333\n", + "epoch 180 loss 1.1100 acc 0.333\n", + "epoch 185 loss 1.1121 acc 0.333\n", + "epoch 190 loss 1.1089 acc 0.333\n", + "epoch 195 loss 1.1108 acc 0.333\n", + "epoch 200 loss 1.1086 acc 0.333\n", + "epoch 205 loss 1.1066 acc 0.333\n", + "epoch 210 loss 1.1056 acc 0.333\n", + "epoch 215 loss 1.1020 acc 0.333\n", + "epoch 220 loss 1.1009 acc 0.333\n", + "epoch 225 loss 1.1007 acc 0.333\n", + "epoch 230 loss 1.1007 acc 0.333\n", + "epoch 235 loss 1.1055 acc 0.333\n", + "epoch 240 loss 1.1003 acc 0.333\n", + "epoch 245 loss 1.0947 acc 0.333\n", + "epoch 250 loss 1.0965 acc 0.333\n", + "epoch 255 loss 1.0997 acc 0.333\n", + "epoch 260 loss 1.0980 acc 0.333\n", + "epoch 265 loss 1.0971 acc 0.333\n", + "epoch 270 loss 1.0947 acc 0.333\n", + "epoch 275 loss 1.0946 acc 0.333\n", + "epoch 280 loss 1.0967 acc 0.333\n", + "epoch 285 loss 1.0937 acc 0.333\n", + "epoch 290 loss 1.0973 acc 0.333\n", + "epoch 295 loss 1.0912 acc 0.333\n", + "epoch 300 loss 1.0900 acc 0.333\n", + "epoch 305 loss 1.0939 acc 0.333\n", + "epoch 310 loss 1.0940 acc 0.333\n", + "epoch 315 loss 1.0840 acc 0.333\n", + "epoch 320 loss 1.0876 acc 0.333\n", + "epoch 325 loss 1.0823 acc 0.333\n", + "epoch 330 loss 1.0891 acc 0.333\n", + "epoch 335 loss 1.0800 acc 0.333\n", + "epoch 340 loss 1.0744 acc 0.333\n", + "epoch 345 loss 1.0699 acc 0.333\n", + "epoch 350 loss 1.0774 acc 0.333\n", + "epoch 355 loss 1.0752 acc 0.333\n", + "epoch 360 loss 1.0614 acc 0.333\n", + "epoch 365 loss 1.0602 acc 0.333\n", + "epoch 370 loss 1.0550 acc 0.333\n", + "epoch 375 loss 1.0498 acc 0.333\n", + "epoch 380 loss 1.0475 acc 0.333\n", + "epoch 385 loss 1.0468 acc 0.333\n", + "epoch 390 loss 1.0278 acc 0.333\n", + "epoch 395 loss 1.0297 acc 0.333\n", + "epoch 400 loss 1.0267 acc 0.333\n", + "epoch 405 loss 1.0167 acc 0.333\n", + "epoch 410 loss 1.0112 acc 0.333\n", + "epoch 415 loss 1.0093 acc 0.333\n", + "epoch 420 loss 1.0064 acc 0.333\n", + "epoch 425 loss 1.0078 acc 0.333\n", + "epoch 430 loss 0.9952 acc 0.333\n", + "epoch 435 loss 0.9849 acc 0.333\n", + "epoch 440 loss 0.9791 acc 0.333\n", + "epoch 445 loss 0.9811 acc 0.333\n", + "epoch 450 loss 0.9727 acc 0.333\n", + "epoch 455 loss 0.9896 acc 0.333\n", + "epoch 460 loss 0.9642 acc 0.333\n", + "epoch 465 loss 0.9644 acc 0.333\n", + "epoch 470 loss 0.9621 acc 0.333\n", + "epoch 475 loss 0.9535 acc 0.333\n", + "epoch 480 loss 0.9475 acc 0.333\n", + "epoch 485 loss 0.9359 acc 0.333\n", + "epoch 490 loss 0.9262 acc 0.333\n", + "epoch 495 loss 0.9167 acc 0.333\n", + "epoch 500 loss 0.9068 acc 0.333\n", + "epoch 505 loss 0.9054 acc 0.333\n", + "epoch 510 loss 0.8907 acc 0.667\n", + "epoch 515 loss 0.8850 acc 0.667\n", + "epoch 520 loss 0.8911 acc 0.333\n", + "epoch 525 loss 0.8699 acc 0.667\n", + "epoch 530 loss 0.9836 acc 0.667\n", + "epoch 535 loss 0.8612 acc 0.667\n", + "epoch 540 loss 0.8495 acc 0.667\n", + "epoch 545 loss 0.8412 acc 0.667\n", + "epoch 
550 loss 0.8383 acc 0.667\n", + "epoch 555 loss 0.8294 acc 0.667\n", + "epoch 560 loss 0.8248 acc 0.667\n", + "epoch 565 loss 0.8084 acc 1.000\n", + "epoch 570 loss 0.8243 acc 0.667\n", + "epoch 575 loss 0.9476 acc 0.333\n", + "epoch 580 loss 0.9564 acc 0.333\n", + "epoch 585 loss 0.8961 acc 0.333\n", + "epoch 590 loss 0.8742 acc 0.667\n", + "epoch 595 loss 0.9468 acc 0.333\n", + "epoch 600 loss 0.8692 acc 1.000\n", + "epoch 605 loss 0.8769 acc 1.000\n", + "epoch 610 loss 0.8851 acc 0.667\n", + "epoch 615 loss 0.9639 acc 0.333\n", + "epoch 620 loss 0.8225 acc 0.667\n", + "epoch 625 loss 0.8487 acc 1.000\n", + "epoch 630 loss 0.8182 acc 0.667\n", + "epoch 635 loss 0.8045 acc 0.667\n", + "epoch 640 loss 0.7901 acc 0.667\n", + "epoch 645 loss 0.7905 acc 0.667\n", + "epoch 650 loss 0.7570 acc 0.667\n", + "epoch 655 loss 0.7791 acc 0.667\n", + "epoch 660 loss 0.7679 acc 0.667\n", + "epoch 665 loss 0.7704 acc 0.667\n", + "epoch 670 loss 0.7302 acc 1.000\n", + "epoch 675 loss 0.7271 acc 1.000\n", + "epoch 680 loss 0.7225 acc 0.667\n", + "epoch 685 loss 0.7620 acc 0.667\n", + "epoch 690 loss 0.7139 acc 0.667\n", + "epoch 695 loss 0.7024 acc 1.000\n", + "epoch 700 loss 0.7003 acc 1.000\n", + "epoch 705 loss 0.6757 acc 1.000\n", + "epoch 710 loss 0.6819 acc 1.000\n", + "epoch 715 loss 0.6731 acc 0.667\n", + "epoch 720 loss 0.6665 acc 1.000\n", + "epoch 725 loss 0.6693 acc 0.667\n", + "epoch 730 loss 0.6495 acc 1.000\n", + "epoch 735 loss 0.6726 acc 0.667\n", + "epoch 740 loss 0.6353 acc 1.000\n", + "epoch 745 loss 0.6218 acc 1.000\n", + "epoch 750 loss 0.6278 acc 1.000\n", + "epoch 755 loss 0.6223 acc 1.000\n", + "epoch 760 loss 0.6610 acc 1.000\n", + "epoch 765 loss 0.6178 acc 1.000\n", + "epoch 770 loss 0.6261 acc 0.667\n", + "epoch 775 loss 0.5861 acc 1.000\n", + "epoch 780 loss 0.5800 acc 1.000\n", + "epoch 785 loss 0.5715 acc 1.000\n", + "epoch 790 loss 0.5731 acc 1.000\n", + "epoch 795 loss 0.5497 acc 1.000\n", + "epoch 800 loss 0.5470 acc 1.000\n", + "epoch 805 loss 0.5386 acc 1.000\n", + "epoch 810 loss 0.5647 acc 1.000\n", + "epoch 815 loss 0.5224 acc 1.000\n", + "epoch 820 loss 0.5323 acc 1.000\n", + "epoch 825 loss 0.5144 acc 1.000\n", + "epoch 830 loss 0.5173 acc 1.000\n", + "epoch 835 loss 0.5020 acc 1.000\n", + "epoch 840 loss 0.4899 acc 1.000\n", + "epoch 845 loss 0.4940 acc 1.000\n", + "epoch 850 loss 0.4891 acc 1.000\n", + "epoch 855 loss 0.4936 acc 1.000\n", + "epoch 860 loss 0.5026 acc 1.000\n", + "epoch 865 loss 0.4797 acc 1.000\n", + "epoch 870 loss 0.4636 acc 1.000\n", + "epoch 875 loss 0.4686 acc 1.000\n", + "epoch 880 loss 0.4484 acc 1.000\n", + "epoch 885 loss 0.4472 acc 1.000\n", + "epoch 890 loss 0.4440 acc 1.000\n", + "epoch 895 loss 0.4381 acc 1.000\n", + "epoch 900 loss 0.4589 acc 1.000\n", + "epoch 905 loss 0.4373 acc 1.000\n", + "epoch 910 loss 0.4210 acc 1.000\n", + "epoch 915 loss 0.4409 acc 1.000\n", + "epoch 920 loss 0.4291 acc 1.000\n", + "epoch 925 loss 0.4108 acc 1.000\n", + "epoch 930 loss 0.4230 acc 1.000\n", + "epoch 935 loss 0.3981 acc 1.000\n", + "epoch 940 loss 0.3675 acc 1.000\n", + "epoch 945 loss 0.3853 acc 1.000\n", + "epoch 950 loss 0.3950 acc 1.000\n", + "epoch 955 loss 0.3614 acc 1.000\n", + "epoch 960 loss 0.3742 acc 1.000\n", + "epoch 965 loss 0.3547 acc 1.000\n", + "epoch 970 loss 0.3848 acc 1.000\n", + "epoch 975 loss 0.3478 acc 1.000\n", + "epoch 980 loss 0.3503 acc 1.000\n", + "epoch 985 loss 0.3486 acc 1.000\n", + "epoch 990 loss 0.3257 acc 1.000\n", + "epoch 995 loss 0.3348 acc 1.000\n", + "epoch 1000 loss 0.3489 acc 1.000\n", + 
"softmax-probs\n", + " tensor([[8.6662e-01, 1.2480e-04, 1.3326e-01, 1.6081e-08, 1.7857e-08, 1.5209e-08,\n", + " 1.4145e-08, 2.3633e-08, 3.4979e-08, 1.2713e-08],\n", + " [1.7590e-02, 6.8803e-01, 1.7758e-01, 1.6722e-02, 1.6682e-02, 1.6407e-02,\n", + " 1.6124e-02, 1.7075e-02, 1.8375e-02, 1.5415e-02],\n", + " [2.6015e-01, 9.3295e-02, 6.4522e-01, 1.7904e-04, 1.8575e-04, 1.7249e-04,\n", + " 1.6769e-04, 2.1287e-04, 2.6556e-04, 1.5600e-04]])\n" + ] + } + ], + "source": [ + "x_train = x_batch # (3, 4, 8)\n", + "y_train = torch.tensor([0, 1, 2]) # dummy class labels for testing\n", + "\n", + "model = ParT(in_dim=4,\n", + " embed_dim=10,\n", + " n_heads=2,\n", + " depth=2,\n", + " class_depth=2,\n", + " num_classes=10)\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "n_epochs = 1000\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " logits = model(x_train) # (3, 10)\n", + " loss = criterion(logits, y_train)\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # print every 5 epochs\n", + " if (epoch+1) % 5 == 0 or epoch == 0:\n", + " preds = logits.argmax(1)\n", + " acc = (preds == y_train).float().mean().item()\n", + " print(f\"epoch {epoch+1:3d} loss {loss.item():.4f} acc {acc:.3f}\")\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " probs = torch.softmax(model(x_train), dim=1)\n", + "print(\"softmax-probs\\n\", probs)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "pytorch-2.6.0", + "language": "python", + "name": "pytorch-2.6.0" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}