
Commit a72ae42

Merge pull request #15 from ranamihir/improvements
Metric improvements
2 parents 284aa97 + a01b699 commit a72ae42

File tree: 8 files changed, +221 / -60 lines


pytorch_common/config.py

Lines changed: 28 additions & 13 deletions
@@ -105,18 +105,8 @@ def set_pytorch_config(config: _Config) -> None:
     # Fix seed
     set_seed(config.seed)

-    # Check for model and classification type
-    assert (
-        config.model_type == "classification" and config.classification_type in ["binary", "multiclass", "multilabel"]
-    ) or (config.model_type == "regression" and not hasattr(config, "classification_type"))
-    if config.model_type == "regression":
-        config.classification_type = None
-
-    # TODO: Remove this after extending FocalLoss
-    if config.model_type == "classification" and config.loss_criterion == "focal-loss":
-        assert (
-            config.classification_type == "binary"
-        ), "FocalLoss is currently only supported for binary classification."
+    # Check miscellaneous configurations
+    check_and_set_misc_config(config)

     # Ensure GPU availability as some models are prohibitively slow on CPU
     if config.assert_gpu:
@@ -186,7 +176,6 @@ def _check_loss_and_eval_criteria(config: _Config) -> None:
     assert config.get("eval_criteria") and isinstance(config.eval_criteria, list)

     loss_criteria = CLASSIFICATION_LOSS_CRITERIA if config.model_type == "classification" else REGRESSION_LOSS_CRITERIA
-
     assert config.loss_criterion in loss_criteria, (
         f"Loss criterion ('{config.loss_criterion}') "
         f"for `model_type=='classification' must be one"
@@ -306,3 +295,29 @@ def set_mode_batch_size(mode: str, batch_size_per_gpu: int) -> None:

     # Set per-GPU and total batch size for mode
     set_mode_batch_size(mode, batch_size_to_set)
+
+
+def check_and_set_misc_config(config: _Config) -> None:
+    """
+    Check all miscellaneous configurations, e.g.:
+      - model_type
+      - classification_type
+    """
+    # Check for model and classification type
+    assert (
+        config.model_type == "classification" and config.classification_type in ["binary", "multiclass", "multilabel"]
+    ) or (config.model_type == "regression" and not hasattr(config, "classification_type"))
+
+    # Set classification_type to None if regression
+    if config.model_type == "regression":
+        config.classification_type = None
+
+    # TODO: Remove this after extending FocalLoss
+    if config.model_type == "classification" and config.loss_criterion == "focal-loss":
+        assert (
+            config.classification_type == "binary"
+        ), "FocalLoss is currently only supported for binary classification."
+
+    # Used for dataloader sampling
+    config.num_batches = config.get("num_batches", None)
+    config.percentage = config.get("percentage", None)
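
For orientation, a minimal sketch (not part of this commit) of a config object that would pass the new `check_and_set_misc_config()` checks. The Munch-based construction and the specific field values are assumptions; any object with attribute access and a dict-style `get()` behaves the same way.

from munch import Munch  # assumption: the real _Config only needs attribute access and .get()

config = Munch(
    model_type="classification",
    classification_type="multiclass",
    loss_criterion="cross-entropy",  # assumed value; anything other than "focal-loss" avoids the binary-only assert
)
check_and_set_misc_config(config)
assert config.num_batches is None and config.percentage is None  # filled in for dataloader sampling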

pytorch_common/datasets.py

Lines changed: 50 additions & 0 deletions
@@ -1,5 +1,12 @@
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+
 from .additional_configs import BaseDatasetConfig
 from .datasets_dl import BasePyTorchDataset, DummyMultiClassDataset, DummyMultiLabelDataset, DummyRegressionDataset
+from .types import Optional, _Config
+from .utils import setup_logging
+
+logger = setup_logging(__name__)


 def create_dataset(dataset_name: str, config: BaseDatasetConfig) -> BasePyTorchDataset:
@@ -16,3 +23,46 @@ def create_dataset(dataset_name: str, config: BaseDatasetConfig) -> BasePyTorchDataset:
     else:
         raise RuntimeError(f"Unknown dataset name {dataset_name}.")
     return dataset
+
+
+def create_dataloader(dataset: Dataset, config: _Config, is_train: Optional[bool] = True) -> DataLoader:
+    """
+    Create a dataloader wrapped
+    around the given dataset.
+
+    Option to sample a subset of the data:
+    During development, you can just set
+    `num_batches` or `percentage` to a small
+    number to run quickly on a sample dataset.
+    """
+    if is_train:
+        shuffle = True
+        batch_size = config.train_batch_size
+    else:
+        shuffle = False
+        batch_size = config.eval_batch_size
+
+    num_batches, percentage = config.num_batches, config.percentage
+
+    if num_batches and percentage:
+        raise ValueError(
+            f"At most one of `num_batches` ({num_batches}) or `percentage` ({percentage}) may be specified."
+        )
+
+    elif num_batches or percentage:
+        n = len(dataset)
+        if num_batches:
+            assert num_batches <= np.ceil(n / batch_size).astype(int)
+            logger.info(f"Sampling {num_batches} batches from whole dataloader.")
+            sampled_indices = np.random.choice(range(n), size=min(num_batches * batch_size, n))
+        else:
+            assert percentage <= 100.0
+            logger.info(f"Sampling {percentage}% of whole dataset.")
+            sampled_indices = np.random.choice(range(n), size=int(percentage * n))
+
+        dataset.data = dataset.data.iloc[sampled_indices]
+        logger.info("Done.")
+
+    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8, pin_memory=True)
+
+    return dataloader
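
A hedged usage sketch of the new `create_dataloader()` for quick development runs. It is not from the repo: the batch-size values and the existence of `dataset` are assumptions, and the only requirement visible in the diff is that `dataset.data` supports `.iloc` indexing (i.e. behaves like a pandas DataFrame).

config.train_batch_size = 32   # field names match the diff; values are assumed
config.eval_batch_size = 64
config.num_batches, config.percentage = 5, None  # at most one of the two may be set

train_loader = create_dataloader(dataset, config, is_train=True)  # shuffled; samples roughly 5 * 32 rows
for batch in train_loader:
    ...  # quick smoke test on the sampled subset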

pytorch_common/decorators.py

Lines changed: 3 additions & 3 deletions
@@ -43,9 +43,9 @@ def timing_with_param(*parameter_names) -> Callable:
     def decorator(func: Callable) -> Callable:
         @wraps(func)
         def wrapper(*args, **kwargs):
-            start = time.time()
+            start = time.perf_counter()
             result = func(*args, **kwargs)
-            end = time.time()
+            end = time.perf_counter()
             elapsed = end - start
             elapsed_human = human_time_interval(elapsed)

@@ -54,7 +54,7 @@ def wrapper(*args, **kwargs):
             logged_param_str = f" {logged_param}" if logged_param else ""

             module_name, function_name = func.__module__, func.__qualname__
-            PRINT_FUNC(f"Function '{module_name}.{function_name}{logged_param_str}' took {elapsed_human}")
+            PRINT_FUNC(f"Function '{module_name}.{function_name}{logged_param_str}' took {elapsed_human}.")
             return result

         return wrapper
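
The switch to `time.perf_counter()` matters because it is a monotonic, high-resolution clock, so measured intervals are not distorted by system clock adjustments the way `time.time()` can be. A standalone sketch of the same pattern (my names, not the repo's):

import time
from functools import wraps


def timed(func):
    """Minimal illustration of the timing pattern above; standalone, not the repo's decorator."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()  # monotonic, high-resolution; unaffected by wall-clock changes
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"Function '{func.__module__}.{func.__qualname__}' took {elapsed:.4f}s.")
        return result

    return wrapper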

pytorch_common/metric_utils.py

Lines changed: 44 additions & 1 deletion
@@ -58,6 +58,7 @@ def correct_argument_order(func: Callable) -> Callable:
     return corrected_func


+@torch.no_grad()
 def get_mse_loss(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     """
     Compute MSE loss.
@@ -75,14 +76,56 @@ def auc_score(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     return auc(fpr, tpr)


+@torch.no_grad()
+def top_k_accuracy_scores(
+    eval_metrics: List[object], results: _StringDict, y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs
+) -> None:
+    """
+    A batch implementation of `top_k_accuracy_score()`.
+
+    It takes all top_k_accuracy-based metrics as input,
+    and computes the respective metrics efficiently.
+    If done separately, the top-k indices would need
+    to be computed separately for each k, while here
+    it happens only once.
+    """
+    assert len(y_predicted) == len(y_true)
+
+    ks = []
+    for eval_metric in eval_metrics:
+        k = int(match(eval_metric.criterion, TOP_K_ACCURACY_REGEX))  # Extract k
+        ks.append(k)
+
+    max_k = max(ks)
+    _, top_indices = torch.topk(y_predicted, max_k, dim=1)  # Compute the top `max_k` predicted classes
+    top_indices = top_indices.t()  # Transpose for mathematical convenience
+    correct_max_k = top_indices.eq(
+        y_true.long().view(1, -1).expand_as(top_indices)
+    )  # Get correct predictions in top max_k
+
+    # Compute top-k accuracy for all k's
+    for i, k in enumerate(ks):
+        correct_k = correct_max_k[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # Get correct predictions in top k
+        top_k_accuracy = correct_k / len(y_true)  # Divide by batch size (because of transpose earlier)
+        results[eval_metrics[i].criterion] = top_k_accuracy.item()
+
+
+@torch.no_grad()
 def top_k_accuracy_score(y_predicted: torch.Tensor, y_true: torch.Tensor, **kwargs) -> float:
     """
     Compute the top-k accuracy score
     in a multi-class setting.

     Conversion to numpy is expensive in this
     case. Stick to using PyTorch tensors.
+
+    Note: This function is not recommended if you have
+    more than one k that this is to be computed
+    for. Please use the much more efficient
+    `top_k_accuracy_scores()` in that case.
     """
+    assert len(y_predicted) == len(y_true)
+
     k = int(match(kwargs["criterion"], TOP_K_ACCURACY_REGEX))  # Extract k
     _, topk_indices = torch.topk(y_predicted, k, dim=1)  # Compute the top-k predicted classes
     correct_examples = torch.eq(y_true[..., None, ...].long(), topk_indices).any(dim=1)
@@ -191,7 +234,7 @@ def __repr__(self):
     },
     "auc": {"preprocess_fn": prob_class1, "eval_fn": auc_score, "model_type": "classification"},
     "top_k_accuracy": {
-        "eval_fn": top_k_accuracy_score,
+        "eval_fn": top_k_accuracy_score,  # Not actually used (in favor of `top_k_accuracy_scores()` for efficiency)
        "regex": TOP_K_ACCURACY_REGEX,
         "model_type": "classification",
     },
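
The speedup in `top_k_accuracy_scores()` comes from calling `torch.topk` once with the largest k and slicing that result for every smaller k. A self-contained illustration with made-up tensors (not repo code):

import torch

y_predicted = torch.randn(4, 10)  # logits for a batch of 4 examples, 10 classes
y_true = torch.tensor([3, 1, 7, 0])
ks = [1, 3, 5]

_, top_indices = torch.topk(y_predicted, max(ks), dim=1)  # one top-k call, for the largest k
top_indices = top_indices.t()                             # shape: (max_k, batch_size)
correct = top_indices.eq(y_true.long().view(1, -1).expand_as(top_indices))

for k in ks:                                              # slice instead of re-running topk
    top_k_acc = correct[:k].reshape(-1).float().sum() / len(y_true)
    print(f"top-{k} accuracy: {top_k_acc.item():.2f}")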

pytorch_common/metrics.py

Lines changed: 26 additions & 4 deletions
@@ -5,7 +5,15 @@
 import torch
 import torch.nn as nn

-from .metric_utils import EVAL_METRIC_FUNCTIONS, LOSS_CRITERIA, PREPROCESSING_FUNCTIONS, FocalLoss, canonicalize, match
+from .metric_utils import (
+    EVAL_METRIC_FUNCTIONS,
+    LOSS_CRITERIA,
+    PREPROCESSING_FUNCTIONS,
+    FocalLoss,
+    canonicalize,
+    match,
+    top_k_accuracy_scores,
+)
 from .types import *


@@ -232,12 +240,26 @@ def compute_per_class(self, y_predicted: torch.Tensor, y_true: torch.Tensor) ->
         """
         results = {}
         for preprocess_fn, supported_metrics in PREPROCESSING_FUNCTIONS.items():
-            metrics_to_compute = [metric_name for metric_name in supported_metrics if metric_name in self.criteria]
-            if metrics_to_compute:
+            metrics_in_group = [metric_name for metric_name in supported_metrics if metric_name in self.criteria]
+            if metrics_in_group:
                 preprocessed_input = preprocess_fn(y_predicted, y_true)  # Share preprocessing
-                for metric_name in metrics_to_compute:
+                for metric_name in metrics_in_group:
+                    # Get eval metrics to be computed in this group
                     eval_metrics = [criterion for criterion in self.criteria if criterion.name == metric_name]
+
+                    # Separate out all top_k_accuracy-based metrics
+                    top_k_accuracy_metrics, other_metrics = [], []
                     for eval_metric in eval_metrics:
+                        top_k_accuracy_metrics.append(
+                            eval_metric
+                        ) if eval_metric.name == "top_k_accuracy" else other_metrics.append(eval_metric)
+
+                    # Compute all top_k_accuracy-based metrics together (for efficiency)
+                    for eval_metric in top_k_accuracy_metrics:
+                        top_k_accuracy_scores(top_k_accuracy_metrics, results, *preprocessed_input)
+
+                    # Compute all other metrics as normal
+                    for eval_metric in other_metrics:
                         results[eval_metric.criterion] = eval_metric(*preprocessed_input)
         return results
