diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index d001436a3ad..eb54f086c0a 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -19,6 +19,7 @@
 from ignite.metrics.precision import Precision
 from ignite.metrics.psnr import PSNR
 from ignite.metrics.recall import Recall
+from ignite.metrics.recsys.ndcg import NDCG
 from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
 from ignite.metrics.running_average import RunningAverage
 from ignite.metrics.ssim import SSIM
diff --git a/ignite/metrics/recsys/__init__.py b/ignite/metrics/recsys/__init__.py
new file mode 100644
index 00000000000..71e737cc0bd
--- /dev/null
+++ b/ignite/metrics/recsys/__init__.py
@@ -0,0 +1,5 @@
+from ignite.metrics.recsys.ndcg import NDCG
+
+__all__ = [
+    "NDCG",
+]
diff --git a/ignite/metrics/recsys/ndcg.py b/ignite/metrics/recsys/ndcg.py
new file mode 100644
index 00000000000..0545055e5e4
--- /dev/null
+++ b/ignite/metrics/recsys/ndcg.py
@@ -0,0 +1,122 @@
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["NDCG"]
+
+
+def _tie_averaged_dcg(
+    y_pred: torch.Tensor,
+    y_true: torch.Tensor,
+    discount_cumsum: torch.Tensor,
+    device: Union[str, torch.device] = torch.device("cpu"),
+) -> torch.Tensor:
+
+    _, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True)
+    ranked = torch.zeros(counts.shape[0]).to(device)
+    ranked.index_put_([inv], y_true, accumulate=True)
+    ranked /= counts
+    groups = torch.cumsum(counts, dim=-1) - 1
+    discount_sums = torch.empty(counts.shape[0]).to(device)
+    discount_sums[0] = discount_cumsum[groups[0]]
+    discount_sums[1:] = torch.diff(discount_cumsum[groups])
+
+    return torch.sum(torch.mul(ranked, discount_sums))
+
+
+def _dcg_sample_scores(
+    y_pred: torch.Tensor,
+    y_true: torch.Tensor,
+    k: Optional[int] = None,
+    log_base: Union[int, float] = 2,
+    ignore_ties: bool = False,
+    device: Union[str, torch.device] = torch.device("cpu"),
+) -> torch.Tensor:
+
+    discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2)
+    discount = discount.to(device)
+
+    if k is not None:
+        discount[k:] = 0.0
+
+    if ignore_ties:
+        ranking = torch.argsort(y_pred, descending=True)
+        ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device)
+        discounted_gains = torch.mm(ranked, discount.reshape(-1, 1))
+
+    else:
+        discount_cumsum = torch.cumsum(discount, dim=-1)
+        discounted_gains = torch.tensor(
+            [_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device
+        )
+
+    return discounted_gains
+
+
+def _ndcg_sample_scores(
+    y_pred: torch.Tensor,
+    y_true: torch.Tensor,
+    k: Optional[int] = None,
+    log_base: Union[int, float] = 2,
+    ignore_ties: bool = False,
+) -> torch.Tensor:
+
+    device = y_true.device
+    gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device)
+    if not ignore_ties:
+        gain = gain.unsqueeze(dim=-1)
+    normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device)
+    all_relevant = normalizing_gain != 0
+    normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant]
+    return normalized_gain
+
+
+class NDCG(Metric):
+    def __init__(
+        self,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        k: Optional[int] = None,
+        log_base: Union[int, float] = 2,
+        exponential: bool = False,
+        ignore_ties: bool = False,
+    ):
+
+        if log_base == 1 or log_base <= 0:
+            raise ValueError(f"Argument log_base should be positive and not equal one, but got {log_base}")
+        self.log_base = log_base
+        self.k = k
+        self.exponential = exponential
+        self.ignore_ties = ignore_ties
+        super(NDCG, self).__init__(output_transform=output_transform, device=device)
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+
+        self.num_examples = 0
+        self.ndcg = torch.tensor(0.0, device=self._device)
+
+    @reinit__is_reduced
+    def update(self, output: Sequence[torch.Tensor]) -> None:
+
+        y_pred, y_true = output[0].detach(), output[1].detach()
+
+        y_pred = y_pred.to(torch.float32).to(self._device)
+        y_true = y_true.to(torch.float32).to(self._device)
+
+        if self.exponential:
+            y_true = 2 ** y_true - 1
+
+        gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties)
+        self.ndcg += torch.sum(gain)
+        self.num_examples += y_pred.shape[0]
+
+    @sync_all_reduce("ndcg", "num_examples")
+    def compute(self) -> float:
+        if self.num_examples == 0:
+            raise NotComputableError("NDCG must have at least one example before it can be computed.")
+
+        return (self.ndcg / self.num_examples).item()
diff --git a/tests/ignite/metrics/test_ndcg.py b/tests/ignite/metrics/test_ndcg.py
new file mode 100644
index 00000000000..2429b5bc4a6
--- /dev/null
+++ b/tests/ignite/metrics/test_ndcg.py
@@ -0,0 +1,266 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+from sklearn.metrics import ndcg_score
+from sklearn.metrics._ranking import _dcg_sample_scores
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.recsys.ndcg import NDCG
+
+
+@pytest.fixture(params=[item for item in range(6)])
+def test_case(request):
+
+    return [
+        (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])),
+        (
+            torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]),
+            torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]),
+        ),
+    ][request.param % 2]
+
+
+@pytest.mark.parametrize("k", [None, 2, 3])
+@pytest.mark.parametrize("exponential", [True, False])
+@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)])
+def test_output_cpu(test_case, k, exponential, ignore_ties, replacement):
+
+    device = "cpu"
+    y_pred_distribution, y_true = test_case
+
+    y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement)
+
+    ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties)
+    ndcg.update([y_pred, y_true])
+    result_ignite = ndcg.compute()
+
+    if exponential:
+        y_true = 2 ** y_true - 1
+
+    result_sklearn = ndcg_score(y_true.numpy(), y_pred.numpy(), k=k, ignore_ties=ignore_ties)
+
+    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)
+
+
+@pytest.mark.parametrize("k", [None, 2, 3])
+@pytest.mark.parametrize("exponential", [True, False])
+@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)])
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_output_cuda(test_case, k, exponential, ignore_ties, replacement):
+
+    device = "cuda"
+    y_pred_distribution, y_true = test_case
+
+    y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement)
+
+    y_pred = y_pred.to(device)
+    y_true = y_true.to(device)
+
+    ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties)
+    ndcg.update([y_pred, y_true])
+    result_ignite = ndcg.compute()
+
+    if exponential:
+        y_true = 2 ** y_true - 1
+
+    result_sklearn = ndcg_score(y_true.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties)
+
+    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)
+
+
+def test_reset():
+
+    y_true = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
+    y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]])
+    ndcg = NDCG()
+    ndcg.update([y_pred, y_true])
+    ndcg.reset()
+
+    with pytest.raises(NotComputableError, match=r"NDCG must have at least one example before it can be computed."):
+        ndcg.compute()
+
+
+def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False):
+
+    gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties)
+    normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True)
+    all_irrelevant = normalizing_gain == 0
+    gain[all_irrelevant] = 0
+    gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant]
+    return gain
+
+
+@pytest.mark.parametrize("log_base", [2, 3, 10])
+def test_log_base(log_base):
+    def ndcg_score_with_log_base(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2):
+
+        gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties)
+        return np.average(gain, weights=sample_weight)
+
+    y_true = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]])
+    y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])
+
+    ndcg = NDCG(log_base=log_base)
+    ndcg.update([y_pred, y_true])
+
+    result_ignite = ndcg.compute()
+    result_sklearn = ndcg_score_with_log_base(y_true.numpy(), y_pred.numpy(), log_base=log_base)
+
+    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)
+
+
+def test_update(test_case):
+
+    y_pred, y_true = test_case
+
+    y_pred = y_pred
+    y_true = y_true
+
+    y1_pred = torch.multinomial(y_pred, 5, replacement=True)
+    y1_true = torch.multinomial(y_true, 5, replacement=True)
+
+    y2_pred = torch.multinomial(y_pred, 5, replacement=True)
+    y2_true = torch.multinomial(y_true, 5, replacement=True)
+
+    y_pred_combined = torch.cat((y1_pred, y2_pred))
+    y_true_combined = torch.cat((y1_true, y2_true))
+
+    ndcg = NDCG()
+
+    ndcg.update([y1_pred, y1_true])
+    ndcg.update([y2_pred, y2_true])
+
+    result_ignite = ndcg.compute()
+
+    result_sklearn = ndcg_score(y_true_combined.numpy(), y_pred_combined.numpy())
+
+    np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)
+
+
+def _test_distrib_output(device):
+
+    rank = idist.get_rank()
+
+    def _test(n_epochs, metric_device):
+
+        metric_device = torch.device(metric_device)
+
+        n_iters = 5
+        batch_size = 8
+        n_items = 5
+
+        torch.manual_seed(12 + rank)
+
+        y_true = torch.rand((n_iters * batch_size, n_items)).to(device)
+        y_preds = torch.rand((n_iters * batch_size, n_items)).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * batch_size : (i + 1) * batch_size, ...],
+                y_true[i * batch_size : (i + 1) * batch_size, ...],
+            )
+
+        engine = Engine(update)
+
+        ndcg = NDCG(device=metric_device)
+        ndcg.attach(engine, "ndcg")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        y_true = idist.all_gather(y_true)
+        y_preds = idist.all_gather(y_preds)
+
+        assert (
+            ndcg._device == metric_device
+        ), f"{type(ndcg._device)}:{ndcg._device} vs {type(metric_device)}:{metric_device}"
+
+        assert "ndcg" in engine.state.metrics
+        res = engine.state.metrics["ndcg"]
+        if isinstance(res, torch.Tensor):
+            res = res.cpu().numpy()
+
+        true_res = ndcg_score(y_true.cpu().numpy(), y_preds.cpu().numpy())
+        assert pytest.approx(res) == true_res
+
+    metric_devices = ["cpu"]
+    if device.type != "xla":
+        metric_devices.append(idist.device())
+    for metric_device in metric_devices:
+        for _ in range(2):
+            _test(n_epochs=1, metric_device=metric_device)
+            _test(n_epochs=2, metric_device=metric_device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_output, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+def _test_distrib_xla_nprocs(index):
+
+    device = idist.device()
+    _test_distrib_output(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
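For reviewers, here is a minimal usage sketch of the metric introduced by this patch (not part of the diff; the tensors below are illustrative values only). It follows the API as defined above: `update` takes an `[y_pred, y_true]` pair of shape `(n_samples, n_items)`, `compute` returns the NDCG averaged over the accumulated samples, and the metric can also be attached to an `Engine` as in the distributed test.

```python
import torch

from ignite.metrics import NDCG

# Toy predicted scores and true relevances: two samples, five items each (illustrative only).
y_pred = torch.tensor([[0.1, 0.4, 0.3, 0.2, 0.9], [0.7, 0.2, 0.5, 0.9, 0.1]])
y_true = torch.tensor([[1.0, 0.0, 2.0, 0.0, 3.0], [2.0, 1.0, 0.0, 3.0, 0.0]])

# NDCG@3 with exponential gain (2 ** relevance - 1), assuming no ties in y_pred.
ndcg = NDCG(k=3, exponential=True, ignore_ties=True)
ndcg.update([y_pred, y_true])
print(ndcg.compute())  # averaged NDCG over the two samples

# Alternatively, attach to an Engine whose output is a (y_pred, y_true) pair:
# ndcg.attach(engine, "ndcg")
```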