-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* test * main implementation * PR fixes * add NMS with indices * clear outputs * add package installation * final changes --------- Co-authored-by: yarden-sony <[email protected]>
- Loading branch information
1 parent
4a2bba6
commit 9105133
Showing
1 changed file
with
270 additions
and
0 deletions.
There are no files selected for viewing
270 changes: 270 additions & 0 deletions
270
tutorials/pytorch/multiclass_nms_custom_layer_example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "27d28731857f2991", | ||
"metadata": {}, | ||
"source": [ | ||
"# Custom Layer Usage Example - multiclass_nms Layer\n", | ||
"\n", | ||
"[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/custom_layers/blob/main/tutorials/pytorch/multiclass_nms_custom_layer_example.ipynb)\n", | ||
" \n", | ||
"\n", | ||
"## Overview\n", | ||
"\n", | ||
"In this tutorial we will illustrate how to integrate a custom layer with model quantization using the MCT library.\n", | ||
"Using a simple object detection model as an example, we will apply post-training quantization, then incorporate a custom NMS layer into the quantized model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "29309a0291ff4f41", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setup\n", | ||
"\n", | ||
"### Install & import relevant packages" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "initial_id", | ||
"metadata": {}, | ||
"source": [ | ||
"!pip install -q torch\n", | ||
"!pip install -q onnx\n", | ||
"!pip install -q model_compression_toolkit\n", | ||
"!pip install -q sony-custom-layers" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "3a7da9c475f95aa9", | ||
"metadata": {}, | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"from typing import Iterator, List\n", | ||
"import model_compression_toolkit as mct\n", | ||
"from sony_custom_layers.pytorch.nms import multiclass_nms" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cb2a4ee6cf3e5e98", | ||
"metadata": {}, | ||
"source": [ | ||
"## Model Quantization\n", | ||
"\n", | ||
"### Create Model Instance\n", | ||
"\n", | ||
"We will start with creating a simple object-detection model instance as an example. You can replace the model with your own pre-trained model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "cbe0031bb7f16986", | ||
"metadata": {}, | ||
"source": [ | ||
"class ObjectDetector(nn.Module):\n", | ||
" def __init__(self, num_classes=2, max_detections=20):\n", | ||
" super().__init__()\n", | ||
" self.max_detections = max_detections\n", | ||
"\n", | ||
" self.backbone = nn.Sequential(\n", | ||
" nn.Conv2d(3, 16, kernel_size=3, padding=1),\n", | ||
" nn.ReLU(),\n", | ||
" nn.MaxPool2d(2, 2),\n", | ||
" nn.Conv2d(16, 32, kernel_size=3, padding=1),\n", | ||
" nn.ReLU(),\n", | ||
" nn.MaxPool2d(2, 2)\n", | ||
" )\n", | ||
"\n", | ||
" self.bbox_reg = nn.Conv2d(32, 4 * max_detections, kernel_size=1)\n", | ||
" self.class_reg = nn.Conv2d(32, num_classes * max_detections, kernel_size=1)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" batch_size = x.size(0)\n", | ||
" features = self.backbone(x)\n", | ||
" H_prime = features.shape[2]\n", | ||
" W_prime = features.shape[3]\n", | ||
" \n", | ||
" bbox = self.bbox_reg(features)\n", | ||
" bbox = bbox.view(batch_size, self.max_detections, 4, H_prime * W_prime).mean(dim=3)\n", | ||
" class_probs = self.class_reg(features).view(batch_size, self.max_detections, -1, H_prime * W_prime)\n", | ||
" class_probs = F.softmax(class_probs.mean(dim=2), dim=2)\n", | ||
"\n", | ||
" return bbox, class_probs\n", | ||
"\n", | ||
"model = ObjectDetector()\n", | ||
"model.eval()" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d653b898460b44e2", | ||
"metadata": {}, | ||
"source": [ | ||
"### Post-Training Quantization using Model Compression Toolkit\n", | ||
"\n", | ||
"We're all set to use MCT's post-training quantization. \n", | ||
"To begin, we'll define a representative dataset generator. Please note that for demonstration purposes, we will generate random data of the desired image shape instead of using real images. \n", | ||
"Then, we will apply PTQ on our model using the dataset generator we have created." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "72d25144f573ead3", | ||
"metadata": { | ||
"jupyter": { | ||
"is_executing": true | ||
} | ||
}, | ||
"source": [ | ||
"NUM_ITERS = 20\n", | ||
"BATCH_SIZE = 32\n", | ||
"\n", | ||
"def get_representative_dataset(n_iter: int):\n", | ||
" \"\"\"\n", | ||
" This function creates a representative dataset generator. The generator yields numpy\n", | ||
" arrays of batches of shape: [Batch, C, H, W].\n", | ||
" Args:\n", | ||
" n_iter: number of iterations for MCT to calibrate on\n", | ||
" Returns:\n", | ||
" A representative dataset generator\n", | ||
" \"\"\"\n", | ||
" def representative_dataset() -> Iterator[List]:\n", | ||
" for _ in range(n_iter):\n", | ||
" yield [torch.rand(BATCH_SIZE, 3, 64, 64)]\n", | ||
"\n", | ||
" return representative_dataset\n", | ||
"\n", | ||
"representative_data_generator = get_representative_dataset(n_iter=NUM_ITERS)\n", | ||
"\n", | ||
"quant_model, _ = mct.ptq.pytorch_post_training_quantization(model, representative_data_gen=representative_data_generator)\n", | ||
"print('Quantized model is ready')" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3a8d28fed3a87f65", | ||
"metadata": {}, | ||
"source": [ | ||
"## Custom Layer Stitching\n", | ||
"\n", | ||
"Now that we have a quantized model, we can add it a custom layer. In our example we will add NMS layer by creating a model wrapper that applies NMS over the quantized model output. You can use this wrapper for your own model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "baa386a04a8dd664", | ||
"metadata": { | ||
"jupyter": { | ||
"is_executing": true | ||
} | ||
}, | ||
"source": [ | ||
"class PostProcessWrapper(nn.Module):\n", | ||
" def __init__(self,\n", | ||
" model: nn.Module,\n", | ||
" score_threshold: float = 0.001,\n", | ||
" iou_threshold: float = 0.7,\n", | ||
" max_detections: int = 300):\n", | ||
"\n", | ||
" super(PostProcessWrapper, self).__init__()\n", | ||
" self.model = model\n", | ||
" self.score_threshold = score_threshold\n", | ||
" self.iou_threshold = iou_threshold\n", | ||
" self.max_detections = max_detections\n", | ||
"\n", | ||
" def forward(self, images):\n", | ||
" # model inference\n", | ||
" outputs = self.model(images)\n", | ||
"\n", | ||
" boxes = outputs[0]\n", | ||
" scores = outputs[1]\n", | ||
" nms = multiclass_nms(boxes=boxes, scores=scores, score_threshold=self.score_threshold,\n", | ||
" iou_threshold=self.iou_threshold, max_detections=self.max_detections)\n", | ||
" return nms\n", | ||
"\n", | ||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
"quant_model_with_nms = PostProcessWrapper(model=quant_model,\n", | ||
" score_threshold=0.001,\n", | ||
" iou_threshold=0.7,\n", | ||
" max_detections=300).to(device=device)\n", | ||
"print('Quantized model with NMS is ready')" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e7ca57539bdc7239", | ||
"metadata": {}, | ||
"source": [ | ||
"### Model Export\n", | ||
"\n", | ||
"Finally, we can export the quantized model into a .onnx format file. Please ensure that the save_model_path has been set correctly." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "776a6f99bd0a6efe", | ||
"metadata": { | ||
"jupyter": { | ||
"is_executing": true | ||
} | ||
}, | ||
"source": [ | ||
"mct.exporter.pytorch_export_model(model=quant_model_with_nms,\n", | ||
" save_model_path='./qmodel_with_nms.onnx',\n", | ||
" repr_dataset=representative_data_generator)" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bb7c13e41a012f3", | ||
"metadata": {}, | ||
"source": [ | ||
"Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.\n", | ||
"\n", | ||
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", | ||
"\n", | ||
"http://www.apache.org/licenses/LICENSE-2.0\n", | ||
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |