Improve eval_model_on_test; packaging updates
Signed-off-by: Emanuele Ballarin <[email protected]>
emaballarin committed Jul 8, 2024
1 parent 67d30c7 commit 846e8fe
Showing 2 changed files with 34 additions and 21 deletions.
47 changes: 30 additions & 17 deletions ebtorch/nn/utils/evalutils.py
@@ -22,12 +22,14 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 from collections.abc import Callable
+from typing import Optional
 from typing import Tuple
 from typing import Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
+from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 __all__ = [
@@ -37,13 +39,16 @@
 
 def eval_model_on_test(  # NOSONAR
     model: Module,
-    model_is_classifier: bool,
-    test_data_loader,
+    test_data_loader: DataLoader,
     device: torch.device,
-    criterion_non_classifier: Callable,
+    model_is_classifier: bool = True,
+    criterion_non_classifier: Optional[Callable] = None,
     extract_z_non_classifier: bool = False,
+    verbose: bool = False,
 ) -> Union[Union[int, float], Tuple[Union[int, float], Tensor, Tensor]]:
     trackingmetric: Union[int, float]
-
+    if not model_is_classifier and criterion_non_classifier is None:
+        raise ValueError("Criterion must be provided for non-classifier models.")
+
     num_elem: int = 0
     if model_is_classifier:
@@ -54,40 +59,48 @@ def eval_model_on_test(  # NOSONAR
     model.eval()
 
     with torch.no_grad():
-        for batch_idx_e, batched_datapoint_e in tqdm(  # type: ignore
+        for batch_idx_e, batched_datapoint_e in tqdm(
            enumerate(test_data_loader),
            total=len(test_data_loader),
            desc="Testing batch",
            leave=False,
-            disable=True,
+            disable=not verbose,
        ):
+
+            # Explicitly type-hint `x_e` and `y_e`
+            x_e: Tensor
+            y_e: Tensor
+
            if model_is_classifier:
                x_e, y_e = batched_datapoint_e
                x_e, y_e = x_e.to(device), y_e.to(device)
-                modeltarget_e = model(x_e)
-                ypred_e = torch.argmax(modeltarget_e, dim=1, keepdim=True)
+                modeltarget_e: Tensor = model(x_e)
+                ypred_e: Tensor = torch.argmax(modeltarget_e, dim=1, keepdim=True)
                trackingmetric += ypred_e.eq(y_e.view_as(ypred_e)).sum().item()
-            else:
+
+            else:  # not model_is_classifier
                x_e, y_e = batched_datapoint_e
-                x_e = x_e.to(device)
-                modeltarget_e_tuple = model(x_e)
-                modeltarget_e = modeltarget_e_tuple[0]
+                x_e: Tensor = x_e.to(device)
+                modeltarget_e_tuple: Tuple[Tensor, Tensor] = model(x_e)
+                modeltarget_e: Tensor = modeltarget_e_tuple[0]
                if extract_z_non_classifier:
-                    z_to_cat = modeltarget_e_tuple[1]
-                    z = (
+                    z_to_cat: Tensor = modeltarget_e_tuple[1]
+                    z: Tensor = (
                        torch.cat(tensors=(z, z_to_cat), dim=0)
                        if batch_idx_e > 0
                        else z_to_cat
                    )
-                    y_e_to_cat = y_e
-                    y_e = (
+                    y_e_to_cat: Tensor = y_e
+                    y_e: Tensor = (
                        torch.cat(tensors=(y_e, y_e_to_cat), dim=0)
                        if batch_idx_e > 0
                        else y_e_to_cat
                    )
                trackingmetric += criterion_non_classifier(modeltarget_e, x_e).item()
+
            num_elem += x_e.shape[0]
-    if extract_z_non_classifier:
+
+    if extract_z_non_classifier and not model_is_classifier:
        return trackingmetric / num_elem, z, y_e
    else:
        return trackingmetric / num_elem
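
For orientation, a minimal usage sketch of the revised eval_model_on_test signature. The toy model, data, and import path below are illustrative assumptions, not part of this commit:

# Illustrative sketch of the updated eval_model_on_test API; the toy model,
# random data, and direct module import are assumptions for demonstration only.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ebtorch.nn.utils.evalutils import eval_model_on_test

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy 10-class classifier and random "test set", for illustration only
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
test_loader = DataLoader(
    TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,))),
    batch_size=64,
)

# Classifier path: the new defaults (model_is_classifier=True) cover the common
# case, and verbose=True now drives the tqdm progress bar (disable=not verbose).
accuracy = eval_model_on_test(model, test_loader, device, verbose=True)
print(f"Test accuracy: {accuracy:.3f}")

# Non-classifier path: criterion_non_classifier is now mandatory (the new guard
# raises ValueError otherwise); extract_z_non_classifier=True also returns (z, y).
# loss, z, y = eval_model_on_test(
#     model_returning_xhat_and_z,  # hypothetical model returning a (xhat, z) tuple
#     test_loader,
#     device,
#     model_is_classifier=False,
#     criterion_non_classifier=nn.MSELoss(reduction="sum"),
#     extract_z_non_classifier=True,
# )
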
8 changes: 4 additions & 4 deletions setup.py
@@ -26,7 +26,7 @@ def read(fname):
 
 setup(
     name=PACKAGENAME,
-    version="0.25.9",
+    version="0.25.10",
     author="Emanuele Ballarin",
     author_email="[email protected]",
     url="https://github.com/emaballarin/ebtorch",
@@ -45,7 +45,7 @@ def read(fname):
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: Apache Software License",
     ],
-    python_requires=">=3.10",
+    python_requires=">=3.11",
     install_requires=[
         "advertorch>=0.2.4",  # pip install git+https://github.com/BorealisAI/advertorch.git
         "matplotlib>=3.8",
@@ -59,6 +59,6 @@
         "torchvision>=0.15",
         "tqdm>=4.65",
     ],
-    include_package_data=True,
-    zip_safe=False,
+    include_package_data=False,
+    zip_safe=True,
 )
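
A quick sketch for checking the tightened packaging constraints at runtime; the version string and interpreter bound come from the diff above, the snippet itself is illustrative and not part of the commit:

# Minimal runtime check of the updated packaging constraints (illustrative only)
import sys
from importlib.metadata import version

assert sys.version_info >= (3, 11), "ebtorch 0.25.10 declares python_requires>=3.11"
print(version("ebtorch"))  # expected to print 0.25.10 once this release is installed
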
