Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎨 Refactor ModelABC to Help Use Default Torch Models #867

Draft
wants to merge 6 commits into
base: dev-define-engines-abc
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/models/test_arch_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel, infer_batch
from tiatoolbox.models.models_abc import model_to

ON_GPU = False
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_functional() -> None:
for backbone in backbones:
model = CNNModel(backbone, num_classes=1)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand All @@ -72,7 +72,7 @@ def test_timm_functional() -> None:
for backbone in backbones:
model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
model_ = model_to(device=device, model=model)
model.infer_batch(model_, samples, device=device)
infer_batch(model_, samples, device=device)
except ValueError as exc:
msg = f"Model {backbone} failed."
raise AssertionError(msg) from exc
Expand Down
34 changes: 34 additions & 0 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,40 @@ def _get_architecture(
return model.features


def infer_batch(
    model: nn.Module,
    batch_data: torch.Tensor,
    *,
    device: str = "cpu",
) -> dict[str, np.ndarray]:
    """Forward one batch through `model` and package the result.

    The batch is copied to `device`, cast to 32-bit float and its axes
    are reordered with ``permute(0, 3, 1, 2)`` before the forward pass.
    The model is switched to evaluation mode and the forward pass runs
    under `torch.inference_mode`, so no gradients are tracked.

    Args:
        model (nn.Module):
            PyTorch model to evaluate. It is put into eval mode by this
            call; this function does not move it to `device`.
        batch_data (torch.Tensor):
            A 4-dimensional batch, e.g. as generated by
            `torch.utils.data.DataLoader`. Assumed to be channels-last
            (NHWC), since the axes are permuted to NCHW.
        device (str):
            Device the batch is transferred to. Default is "cpu".

    Returns:
        dict[str, np.ndarray]:
            A dictionary with the single key "probabilities" holding
            the model output converted to a NumPy array.

    """
    on_device = batch_data.to(device)
    as_float = on_device.type(torch.float32)
    # NHWC -> NCHW, made contiguous for the model.
    nchw_batch = as_float.permute(0, 3, 1, 2).contiguous()

    # Evaluation mode, and no gradient bookkeeping (not training).
    model.eval()
    with torch.inference_mode():
        prediction = model(nchw_batch)

    # The model output is expected to be a single tensor.
    host_array = prediction.cpu().numpy()
    return {"probabilities": host_array}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the current develop branch, neither CNNModel nor CNNBackbone returns a dictionary from its infer_batch() method. Also, CNNModel currently returns an array, while CNNBackbone returns a list containing the array. It might be fine; I just wanted to highlight this.

CNNModel

return output.cpu().numpy()

CNNBackbone

return [output.cpu().numpy()]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. We are aware of this. Our preference is to use torch nn models, but to generalise for multi-modal output we may need dictionaries. This PR is to check whether we can move to generic torch models or whether we will need a subclass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.



def _get_timm_architecture(
arch_name: str,
*,
Expand Down
3 changes: 2 additions & 1 deletion tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tiatoolbox import DuplicateFilter, logger, rcParam
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.architecture.vanilla import infer_batch
from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
from tiatoolbox.models.models_abc import load_torch_model
from tiatoolbox.utils.misc import (
Expand Down Expand Up @@ -573,7 +574,7 @@ def infer_patches(
zarr_group = zarr.open(save_path, mode="w")

for _, batch_data in enumerate(dataloader):
batch_output = self.model.infer_batch(
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
Expand Down
24 changes: 0 additions & 24 deletions tiatoolbox/models/models_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,30 +77,6 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
"""Torch method, this contains logic for using layers defined in init."""
... # pragma: no cover

@staticmethod
@abstractmethod
def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict:
"""Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation.

Args:
model (nn.Module):
PyTorch defined model.
batch_data (np.ndarray):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Transfers model to the specified device. Default is "cpu".

Returns:
dict:
Returns a dictionary of predictions and other expected outputs
depending on the network architecture.

"""
... # pragma: no cover

@staticmethod
def preproc(image: np.ndarray) -> np.ndarray:
"""Define the pre-processing of this class of model."""
Expand Down
Loading