add NMS with indices
yarden-sony committed Feb 5, 2025
1 parent 498b0c9 commit f60d0a9
Showing 1 changed file with 193 additions and 36 deletions.
229 changes: 193 additions & 36 deletions tutorials/pytorch/multiclass_nms_custom_layer_example.ipynb
@@ -24,30 +24,61 @@
"id": "29309a0291ff4f41"
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:15.883254Z",
"start_time": "2025-02-05T16:04:06.966813Z"
}
},
"cell_type": "code",
"source": [
"!pip install -q torch\n",
"!pip install onnx\n",
"!pip install -q model_compression_toolkit"
],
"id": "initial_id",
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.0.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m25.0\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n",
"Requirement already satisfied: onnx in /Vols/vol_design/tools/swat/users/yardeny/projects/alg-llm-tools/llm_tools_venv/lib/python3.10/site-packages (1.17.0)\r\n",
"Requirement already satisfied: numpy>=1.20 in /Vols/vol_design/tools/swat/users/yardeny/projects/alg-llm-tools/llm_tools_venv/lib/python3.10/site-packages (from onnx) (1.26.4)\r\n",
"Requirement already satisfied: protobuf>=3.20.2 in /Vols/vol_design/tools/swat/users/yardeny/projects/alg-llm-tools/llm_tools_venv/lib/python3.10/site-packages (from onnx) (5.29.3)\r\n",
"\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.0.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m25.0\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n",
"\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.0.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m25.0\u001B[0m\r\n",
"\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n"
]
}
],
"execution_count": 78
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:15.894280Z",
"start_time": "2025-02-05T16:04:15.887336Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from typing import Iterator, List\n",
"import model_compression_toolkit as mct\n",
"from sony_custom_layers.pytorch.nms import multiclass_nms"
"from sony_custom_layers.pytorch.nms import multiclass_nms\n",
"from sony_custom_layers.pytorch.nms.nms_with_indices import multiclass_nms_with_indices"
],
"id": "3a7da9c475f95aa9",
"outputs": [],
"execution_count": null
"execution_count": 79
},
{
"metadata": {},
@@ -57,38 +88,77 @@
"\n",
"### Create Model Instance\n",
"\n",
"We will start with creating a simple object-detection model instance as an example. You can replace the model with your own model use a pre-trained model (Make sure the model is supported by MCT library). "
"We will start with creating a simple object-detection model instance as an example. You can replace the model with your own model, or use a pre-trained model (Make sure the model is supported by MCT library). "
],
"id": "cb2a4ee6cf3e5e98"
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:15.919644Z",
"start_time": "2025-02-05T16:04:15.896518Z"
}
},
"cell_type": "code",
"source": [
"class ObjectDetector(nn.Module):\n",
" def __init__(self):\n",
" super(ObjectDetector, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
" self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
" self.fc1 = nn.Linear(16 * 16 * 16, 128)\n",
" self.fc_bbox = nn.Linear(128, 4)\n",
" self.fc_class = nn.Linear(128, 2)\n",
" def __init__(self, num_classes=2, max_detections=20):\n",
" super().__init__()\n",
" self.max_detections = max_detections\n",
"\n",
" self.backbone = nn.Sequential(\n",
" nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2, 2),\n",
" nn.Conv2d(16, 32, kernel_size=3, padding=1),\n",
" nn.ReLU(),\n",
" nn.MaxPool2d(2, 2)\n",
" )\n",
"\n",
" self.bbox_reg = nn.Conv2d(32, 4 * max_detections, kernel_size=1)\n",
" self.class_reg = nn.Conv2d(32, num_classes * max_detections, kernel_size=1)\n",
"\n",
" def forward(self, x):\n",
" x = torch.relu(self.conv1(x))\n",
" x = self.pool1(x)\n",
" x = x.view(-1, 16 * 16 * 16)\n",
" x = torch.relu(self.fc1(x))\n",
" bbox = self.fc_bbox(x)\n",
" class_probs = torch.softmax(self.fc_class(x), dim=1)\n",
" batch_size = x.size(0)\n",
" features = self.backbone(x)\n",
" H_prime = features.shape[2]\n",
" W_prime = features.shape[3]\n",
" \n",
" bbox = self.bbox_reg(features)\n",
" bbox = bbox.view(batch_size, self.max_detections, 4, H_prime * W_prime).mean(dim=3)\n",
" class_probs = self.class_reg(features).view(batch_size, self.max_detections, -1, H_prime * W_prime)\n",
" class_probs = F.softmax(class_probs.mean(dim=2), dim=2)\n",
"\n",
" return bbox, class_probs\n",
"\n",
"model = ObjectDetector()\n",
"model.eval()"
],
"id": "cbe0031bb7f16986",
"outputs": [],
"execution_count": null
"outputs": [
{
"data": {
"text/plain": [
"ObjectDetector(\n",
" (backbone): Sequential(\n",
" (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU()\n",
" (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (4): ReLU()\n",
" (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (bbox_reg): Conv2d(32, 80, kernel_size=(1, 1), stride=(1, 1))\n",
" (class_reg): Conv2d(32, 40, kernel_size=(1, 1), stride=(1, 1))\n",
")"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 80
},
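A quick, optional shape check (not part of the commit) confirms the head emits one box and one class-score vector per detection slot, assuming the 64x64 input size used by the representative dataset below:

```python
import torch

# quick shape check on a random 64x64 image
with torch.no_grad():
    bbox, class_probs = model(torch.rand(1, 3, 64, 64))

print(bbox.shape)         # expected: torch.Size([1, 20, 4]) - (batch, max_detections, box coords)
print(class_probs.shape)  # expected: torch.Size([1, 20, 2]) - (batch, max_detections, num_classes)
```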
{
"metadata": {},
@@ -103,11 +173,16 @@
"id": "d653b898460b44e2"
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:19.230349Z",
"start_time": "2025-02-05T16:04:15.922470Z"
}
},
"cell_type": "code",
"source": [
"n_iters = 20\n",
"batch_size = 4\n",
"NUM_ITERS = 20\n",
"BATCH_SIZE = 32\n",
"\n",
"def get_representative_dataset(n_iter: int):\n",
" \"\"\"\n",
Expand All @@ -120,18 +195,62 @@
" \"\"\"\n",
" def representative_dataset() -> Iterator[List]:\n",
" for _ in range(n_iter):\n",
" yield [torch.rand(1, 3, 64, 64)]\n",
" yield [torch.rand(BATCH_SIZE, 3, 64, 64)]\n",
"\n",
" return representative_dataset\n",
"\n",
"representative_data_generator = get_representative_dataset(n_iter=n_iters)\n",
"representative_data_generator = get_representative_dataset(n_iter=NUM_ITERS)\n",
"\n",
"quant_model, _ = mct.ptq.pytorch_post_training_quantization(model, representative_data_gen=representative_data_generator)\n",
"print('Quantized model is ready')"
],
"id": "72d25144f573ead3",
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Statistics Collection: 20it [00:02, 6.80it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Running quantization parameters search. This process might take some time, depending on the model size and the selected quantization methods.\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Calculating quantization parameters: 100%|██████████| 14/14 [00:00<00:00, 50.77it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Weights_memory: 8880.0, Activation_memory: 65536.0, Total_memory: 74416.0, BOPS: 569083166720\n",
"\n",
"Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n",
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n",
"FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md\n",
"Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md\n",
"Quantized model is ready\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"execution_count": 81
},
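Before attaching NMS, you can optionally eyeball how closely the quantized network tracks the float one on a random batch. A minimal sketch (not in the commit; a crude sanity check, not a substitute for a real accuracy evaluation):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
sample = torch.rand(1, 3, 64, 64).to(device)

with torch.no_grad():
    float_boxes, float_scores = model.to(device)(sample)
    quant_boxes, quant_scores = quant_model(sample)

# maximum absolute drift introduced by quantization
print((float_boxes - quant_boxes).abs().max().item())
print((float_scores - quant_scores).abs().max().item())
```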
{
"metadata": {},
@@ -144,7 +263,12 @@
"id": "3a8d28fed3a87f65"
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:19.238306Z",
"start_time": "2025-02-05T16:04:19.231720Z"
}
},
"cell_type": "code",
"source": [
"class PostProcessWrapper(nn.Module):\n",
@@ -168,6 +292,10 @@
" scores = outputs[1]\n",
" nms = multiclass_nms(boxes=boxes, scores=scores, score_threshold=self.score_threshold,\n",
" iou_threshold=self.iou_threshold, max_detections=self.max_detections)\n",
" \"\"\"\n",
" In case you're interested in NMS with indices, you can replace the above with the following code:\n",
" nms = multiclass_nms_with_indices(boxes=boxes, scores=scores, score_threshold=self.score_threshold, iou_threshold=self.iou_threshold, max_detections=self.max_detections)\n",
" \"\"\"\n",
" return nms\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
@@ -178,8 +306,16 @@
"print('Quantized model with NMS is ready')"
],
"id": "baa386a04a8dd664",
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Quantized model with NMS is ready\n"
]
}
],
"execution_count": 82
},
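For reference, `multiclass_nms` returns a named tuple of (boxes, scores, labels, n_valid), and `multiclass_nms_with_indices` additionally returns the indices of the boxes that survived NMS. The field order below follows the sony_custom_layers docs as I recall them, so treat it as an assumption to verify. A usage sketch:

```python
import torch

with torch.no_grad():
    nms_out = quant_model_with_nms(torch.rand(1, 3, 64, 64).to(device))

# assumed field order: boxes, scores, labels, n_valid (check your sony_custom_layers version)
boxes, scores, labels, n_valid = nms_out
print(boxes.shape, scores.shape, labels.shape, n_valid)
```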
{
"metadata": {},
@@ -192,16 +328,37 @@
"id": "e7ca57539bdc7239"
},
{
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-05T16:04:19.528180Z",
"start_time": "2025-02-05T16:04:19.239579Z"
}
},
"cell_type": "code",
"source": [
"mct.exporter.pytorch_export_model(model=quant_model_with_nms,\n",
" save_model_path='./qmodel_with_nms.onnx',\n",
" repr_dataset=representative_data_generator)"
],
"id": "776a6f99bd0a6efe",
"outputs": [],
"execution_count": null
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Vols/vol_design/tools/swat/users/yardeny/projects/alg-llm-tools/llm_tools_venv/lib/python3.10/site-packages/mct_quantizers/pytorch/quantizers/weights_inferable_quantizers/weights_symmetric_inferable_quantizer.py:52: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
" threshold = torch.tensor(threshold, dtype=torch.float32).to(get_working_device())\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exporting onnx model with MCTQ quantizers: ./qmodel_with_nms.onnx\n"
]
}
],
"execution_count": 83
},
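The exported model contains the custom NMS op, so a plain ONNX Runtime session will not recognize it. The sony_custom_layers package ships a registration helper for this (`load_custom_ops` — the name and `load_ort` flag are taken from the library README as I recall it; verify against your installed version). A hedged inference sketch:

```python
import numpy as np
import onnxruntime as ort
from sony_custom_layers.pytorch import load_custom_ops  # assumed import path; verify locally

# expected to return ort.SessionOptions with the custom NMS op registered;
# the exact signature may differ between versions
so = load_custom_ops(load_ort=True)
session = ort.InferenceSession('./qmodel_with_nms.onnx', sess_options=so)

input_name = session.get_inputs()[0].name
x = np.random.rand(1, 3, 64, 64).astype(np.float32)
outputs = session.run(None, {input_name: x})
print([o.shape for o in outputs])
```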
{
"metadata": {},