From 8f76d87d430f070d9e8ca9ecf52dba1c5e5ee097 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:13:17 +0100 Subject: [PATCH 1/2] :art: Refactor `ModelABC` to Help Use Default Torch Models Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tiatoolbox/models/architecture/vanilla.py | 103 +++++++--------------- tiatoolbox/models/engine/engine_abc.py | 3 +- tiatoolbox/models/models_abc.py | 24 ----- 3 files changed, 36 insertions(+), 94 deletions(-) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index e7b956411..410702d17 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -78,6 +78,40 @@ def _get_architecture( return model.features +def infer_batch( + model: nn.Module, + batch_data: torch.Tensor, + *, + device: str = "cpu", +) -> dict[str, np.ndarray]: + """Run inference on an input batch. + + Contains logic for forward operation as well as i/o aggregation. + + Args: + model (nn.Module): + PyTorch defined model. + batch_data (torch.Tensor): + A batch of data generated by + `torch.utils.data.DataLoader`. + device (str): + Transfers model to the specified device. Default is "cpu". + + """ + img_patches_device = batch_data.to(device).type( + torch.float32, + ) # to NCHW + img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() + + # Inference mode + model.eval() + # Do not compute the gradient (not training) + with torch.inference_mode(): + output = model(img_patches_device) + # Output should be a single tensor or scalar + return {"probabilities": output.cpu().numpy()} + + class CNNModel(ModelABC): """Retrieve the model backbone and attach an extra FCN to perform classification. @@ -137,40 +171,6 @@ def postproc(image: np.ndarray) -> np.ndarray: """ return np.argmax(image, axis=-1) - @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - device: str = "cpu", - ) -> dict[str, np.ndarray]: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o aggregation. - - Args: - model (nn.Module): - PyTorch defined model. - batch_data (torch.Tensor): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - - """ - img_patches_device = batch_data.to(device).type( - torch.float32, - ) # to NCHW - img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() - - # Inference mode - model.eval() - # Do not compute the gradient (not training) - with torch.inference_mode(): - output = model(img_patches_device) - # Output should be a single tensor or scalar - return {"probabilities": output.cpu().numpy()} - class CNNBackbone(ModelABC): """Retrieve the model backbone and strip the classification layer. @@ -233,38 +233,3 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: feat = self.feat_extract(imgs) gap_feat = self.pool(feat) return torch.flatten(gap_feat, 1) - - @staticmethod - def infer_batch( - model: nn.Module, - batch_data: torch.Tensor, - *, - device: str, - ) -> dict[str, np.ndarray]: - """Run inference on an input batch. - - Contains logic for forward operation as well as i/o aggregation. - - Args: - model (nn.Module): - PyTorch defined model. - batch_data (torch.Tensor): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - - """ - img_patches_device = batch_data.to(device).type( - torch.float32, - ) # to NCHW - img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous() - - # Inference mode - model.eval() - # Do not compute the gradient (not training) - with torch.inference_mode(): - output = model(img_patches_device) - - # Output should be a single tensor or scalar - return {"probabilities": output.cpu().numpy()} diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 465230116..82a5fd856 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -17,6 +17,7 @@ from tiatoolbox import DuplicateFilter, logger from tiatoolbox.models.architecture import get_pretrained_model +from tiatoolbox.models.architecture.vanilla import infer_batch from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( @@ -573,7 +574,7 @@ def infer_patches( zarr_group = zarr.open(save_path, mode="w") for _, batch_data in enumerate(dataloader): - batch_output = self.model.infer_batch( + batch_output = infer_batch( self.model, batch_data["image"], device=self.device, diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a27eb670c..c7a58f1be 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -73,30 +73,6 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None: """Torch method, this contains logic for using layers defined in init.""" ... # pragma: no cover - @staticmethod - @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict: - """Run inference on an input batch. - - Contains logic for forward operation as well as I/O aggregation. - - Args: - model (nn.Module): - PyTorch defined model. - batch_data (np.ndarray): - A batch of data generated by - `torch.utils.data.DataLoader`. - device (str): - Transfers model to the specified device. Default is "cpu". - - Returns: - dict: - Returns a dictionary of predictions and other expected outputs - depending on the network architecture. - - """ - ... # pragma: no cover - @staticmethod def preproc(image: np.ndarray) -> np.ndarray: """Define the pre-processing of this class of model.""" From 76c7972060250e3c78a05973599697d25293ade5 Mon Sep 17 00:00:00 2001 From: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:52:47 +0100 Subject: [PATCH 2/2] :white_check_mark: Fix test Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com> --- tests/models/test_arch_vanilla.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index cfae665b2..3d37439a1 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -4,7 +4,7 @@ import pytest import torch -from tiatoolbox.models.architecture.vanilla import CNNModel +from tiatoolbox.models.architecture.vanilla import CNNModel, infer_batch from tiatoolbox.models.models_abc import model_to from tiatoolbox.utils.misc import select_device @@ -46,7 +46,7 @@ def test_functional() -> None: for backbone in backbones: model = CNNModel(backbone, num_classes=1) model_ = model_to(device=device, model=model) - model.infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) + infer_batch(model_, samples, device=select_device(on_gpu=ON_GPU)) except ValueError as exc: msg = f"Model {backbone} failed." raise AssertionError(msg) from exc