From 6ff9238f5c7096530bba883ebd85a21e240d691e Mon Sep 17 00:00:00 2001 From: Arvind Srikantan Date: Fri, 29 Jul 2022 09:37:58 -0700 Subject: [PATCH] Adding RankMatchFailure metric (#184) * Adding RankMatchFailure metric * Adding tests and fixing a few bugs * Adding missing variable --- .../ml4ir/applications/ranking/config/keys.py | 1 + .../ranking/model/losses/listwise_losses.py | 21 +- .../ranking/model/metrics/metric_factory.py | 4 +- .../ranking/model/metrics/metrics_impl.py | 281 +++++++++++++++++- .../ranking/tests/test_auxiliary_loss.py | 153 +++++++--- .../applications/ranking/tests/test_losses.py | 11 +- .../ranking/tests/test_rank_match_failure.py | 228 ++++++++++++++ python/ml4ir/base/features/feature_layer.py | 1 + .../ml4ir/base/model/metrics/metrics_impl.py | 13 +- python/ml4ir/base/model/relevance_model.py | 154 ++++++---- .../ml4ir/base/model/scoring/scoring_model.py | 6 +- 11 files changed, 743 insertions(+), 130 deletions(-) create mode 100644 python/ml4ir/applications/ranking/tests/test_rank_match_failure.py diff --git a/python/ml4ir/applications/ranking/config/keys.py b/python/ml4ir/applications/ranking/config/keys.py index 3c8f79f9..7a76dd0a 100644 --- a/python/ml4ir/applications/ranking/config/keys.py +++ b/python/ml4ir/applications/ranking/config/keys.py @@ -27,6 +27,7 @@ class MetricKey(Key): ACR = "ACR" NDCG = "NDCG" PRECISION = "Precision" + RankMatchFailure = "RankMatchFailure" CATEGORICAL_ACCURACY = "categorical_accuracy" TOP_5_CATEGORICAL_ACCURACY = "top_5_categorical_accuracy" diff --git a/python/ml4ir/applications/ranking/model/losses/listwise_losses.py b/python/ml4ir/applications/ranking/model/losses/listwise_losses.py index fc0dc56d..f3961812 100644 --- a/python/ml4ir/applications/ranking/model/losses/listwise_losses.py +++ b/python/ml4ir/applications/ranking/model/losses/listwise_losses.py @@ -36,18 +36,24 @@ def _loss_fn(y_true, y_pred): mask : [batch_size, num_classes] """ - #Fixme + # Fixme """ Queries with ties in the highest scores would have multiple one's in the 1-hot vector. Queries with all zeros for y_true would have all ones as their 1-hot vector. A simple remedy is to scale down the loss by the number of ties per query. 
""" if is_aux_loss: # converting y-true to 1-hot for cce - y_true_1_hot = tf.equal(y_true, tf.expand_dims(tf.math.reduce_max(y_true, axis=1), axis=1)) + y_true_1_hot = tf.equal( + y_true, tf.expand_dims(tf.math.reduce_max(y_true, axis=1), axis=1) + ) y_true_1_hot = tf.cast(y_true_1_hot, dtype=tf.float32) # scaling down the loss of a query by 1/(number of ties) - sample_weights = tf.math.divide(tf.constant(1, dtype=tf.float32), tf.reduce_sum(y_true_1_hot, axis=1)) - return cce(y_true_1_hot, tf.math.multiply(y_pred, mask), sample_weight=sample_weights) + sample_weights = tf.math.divide( + tf.constant(1, dtype=tf.float32), tf.reduce_sum(y_true_1_hot, axis=1) + ) + return cce( + y_true_1_hot, tf.math.multiply(y_pred, mask), sample_weight=sample_weights + ) else: return cce(y_true, tf.math.multiply(y_pred, mask)) @@ -133,7 +139,10 @@ def _loss_fn(y_true, y_pred): y_pred_non_zero = tf.boolean_mask(y_pred, non_zero) # retain values in y_true corresponding to non zero values in y_pred y_true_softmax_masked = tf.boolean_mask(y_true_softmax, non_zero) - return tf.math.divide(-tf.reduce_sum(y_true_softmax_masked * tf.math.log(y_pred_non_zero)), tf.constant(batch_size, dtype=tf.float32)) + return tf.math.divide( + -tf.reduce_sum(y_true_softmax_masked * tf.math.log(y_pred_non_zero)), + tf.constant(batch_size, dtype=tf.float32), + ) else: return -tf.reduce_sum(y_true * tf.math.log(tf.math.multiply(y_pred, mask)), 1) @@ -224,4 +233,4 @@ def _loss_fn(y_true, y_pred): # Scale the sum of losses down by number of queries in the batch return tf.math.divide(bce(y_true, y_pred), batch_size) - return _loss_fn \ No newline at end of file + return _loss_fn diff --git a/python/ml4ir/applications/ranking/model/metrics/metric_factory.py b/python/ml4ir/applications/ranking/model/metrics/metric_factory.py index 3add392d..46346fbc 100644 --- a/python/ml4ir/applications/ranking/model/metrics/metric_factory.py +++ b/python/ml4ir/applications/ranking/model/metrics/metric_factory.py @@ -1,7 +1,7 @@ from tensorflow.keras.metrics import Metric from ml4ir.applications.ranking.config.keys import MetricKey -from ml4ir.applications.ranking.model.metrics.metrics_impl import MRR, ACR +from ml4ir.applications.ranking.model.metrics.metrics_impl import MRR, ACR, RankMatchFailure from ml4ir.applications.classification.model.metrics.metrics_impl import CategoricalAccuracy @@ -25,6 +25,8 @@ def get_metric(metric_key: str) -> Metric: return ACR elif metric_key == MetricKey.NDCG: raise NotImplementedError + elif metric_key == MetricKey.RankMatchFailure: + return RankMatchFailure elif metric_key == MetricKey.CATEGORICAL_ACCURACY: return CategoricalAccuracy else: diff --git a/python/ml4ir/applications/ranking/model/metrics/metrics_impl.py b/python/ml4ir/applications/ranking/model/metrics/metrics_impl.py index afe82d2b..af8e3a68 100644 --- a/python/ml4ir/applications/ranking/model/metrics/metrics_impl.py +++ b/python/ml4ir/applications/ranking/model/metrics/metrics_impl.py @@ -1,14 +1,14 @@ -import tensorflow as tf -from tensorflow.keras import metrics -from tensorflow.python.ops import math_ops +from typing import Optional, Dict + import numpy as np +import tensorflow as tf from tensorflow import Tensor from tensorflow import dtypes +from tensorflow.keras import metrics +from tensorflow.python.ops import math_ops -from ml4ir.base.model.metrics.metrics_impl import MetricState from ml4ir.base.features.feature_config import FeatureConfig - -from typing import Optional, Dict +from ml4ir.base.model.metrics.metrics_impl import MetricState, 
CombinationMetric class MeanMetricWrapper(metrics.Mean): @@ -68,10 +68,15 @@ def update_state(self, y_true, y_pred, sample_weight=None): `y_true` and `y_pred` should have the same shape. """ query_scores: Tensor = self._fn(y_true, y_pred, **self._fn_kwargs) + if not sample_weight: + sample_weight = self.get_sample_weight(query_scores) return super(MeanMetricWrapper, self).update_state( query_scores, sample_weight=sample_weight ) + def get_sample_weight(self, query_scores): + return None + class MeanRankMetric(MeanMetricWrapper): def __init__( @@ -81,7 +86,7 @@ def __init__( state: str = MetricState.NEW, name="MeanRankMetric", dtype: Optional[dtypes.DType] = None, - **kwargs + **kwargs, ): """ Creates a `MeanRankMetric` instance to compute mean of rank @@ -191,7 +196,7 @@ def __init__( metadata_features: Dict, name="MRR", state=MetricState.NEW, - **kwargs + **kwargs, ): """ Creates a `MRR` instance to compute mean of reciprocal rank @@ -214,7 +219,7 @@ def __init__( metadata_features=metadata_features, name=name, state=state, - **kwargs + **kwargs, ) def _get_matches_hook(self, y_pred_click_ranks): @@ -256,7 +261,7 @@ def __init__( metadata_features: Dict, name="ACR", state=MetricState.NEW, - **kwargs + **kwargs, ): """ Creates a `ACR` instance to compute mean of rank @@ -279,7 +284,7 @@ def __init__( metadata_features=metadata_features, name=name, state=state, - **kwargs + **kwargs, ) def _get_matches_hook(self, y_pred_click_ranks): @@ -297,3 +302,257 @@ def _get_matches_hook(self, y_pred_click_ranks): Ranks tensor cast to float """ return tf.cast(y_pred_click_ranks, tf.float32) + + +class RankMatchFailure(MeanMetricWrapper, CombinationMetric): + def __init__( + self, + feature_config: FeatureConfig, + metadata_features: Dict, + state: str = MetricState.NEW, + name="RankMatchFailure", + dtype: Optional[dtypes.DType] = None, + **kwargs, + ): + """ + Creates a `RankMatchFailure` instance to compute mean of rank + + Parameters + ---------- + name : str + string name of the metric instance. + dtype : str, optional + data type of the metric result. + rank : Tensor object + 2D tensor representing ranks/rankitions of records in a query + mask : Tensor object + 2D tensor representing 0/1 mask for padded records + + Notes + ----- + rank and mask should be same shape as y_pred and y_true + + This metric creates two local variables, `total` and `count` that are used to + compute the frequency with which `y_pred` matches `y_true`. This frequency is + ultimately returned as `categorical accuracy`: an idempotent operation that + simply divides `total` by `count`. + `y_pred` and `y_true` should be passed in as vectors of probabilities, rather + than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector. + If `sample_weight` is `None`, weights default to 1. + Use `sample_weight` of 0 to mask values. 
+ """ + name = "{}_{}".format(state, name) + # TODO: Handle Example dataset without mask and rank fields + rank = metadata_features[feature_config.get_rank("node_name")] + mask = metadata_features[feature_config.get_mask("node_name")] + if not feature_config.aux_label: + raise ValueError( + f"{self.__class__.__qualname__} needs an aux label in the feature config" + ) + y_aux = metadata_features[feature_config.get_aux_label("node_name")] + + super().__init__(self._compute, name, dtype=dtype, y_aux=y_aux, rank=rank, mask=mask) + self.state = state + + def get_sample_weight(self, query_scores): + mask = tf.ones_like(query_scores) + return tf.where( + query_scores == tf.constant(-np.inf, dtype=tf.float32), + tf.constant(0, dtype=tf.float32), + mask, + ) + + def _compute(self, y_true, y_pred, y_aux, rank, mask): + """ + Compute mean rank metric + + Parameters + ---------- + y_true : Tensor object + Tensor object that contains the true label values + y_pred : Tensor object + Tensor object containing the predicted scores + rank : Tensor object + Tensor object that contains the rank of each record for a query + mask : Tensor object + Tensor object that contains 0/1 flag to identify which + records were padded and thus should be excluded from metric computation + """ + if self.state == "new": + """Rerank using trained model""" + # Convert y_pred for the masked records to -inf + y_pred = tf.where(tf.equal(mask, 0), tf.constant(-np.inf), y_pred) + + # Convert predicted ranking scores into ranks for each record per query + # TODO: Currently these ranks are defined below the clicked document too. Scores below the clicked document shouldn't affect the final rank for NDCG + y_pred_ranks = tf.add( + tf.argsort( + tf.argsort(y_pred, axis=-1, direction="DESCENDING", stable=True), stable=True + ), + tf.constant(1), + ) + click_ranks = tf.reduce_sum( + tf.where(tf.equal(tf.cast(y_true, tf.int32), tf.constant(1)), y_pred_ranks, 0), + axis=-1, + ) + + y_true_click_rank = tf.reduce_sum( + tf.where(tf.equal(tf.cast(y_true, tf.int32), tf.constant(1)), rank, 0), axis=-1 + ) + ranks = y_pred_ranks + + else: + """Compute mean rank metric for existing data""" + y_true_click_rank = tf.reduce_sum( + tf.where(tf.equal(tf.cast(y_true, tf.int32), tf.constant(1)), rank, 0), axis=-1 + ) + click_ranks = y_true_click_rank + ranks = rank + # Mask ranks with max possible value so that they are ignored downstream + ranks = tf.where(tf.equal(mask, 0), tf.constant(np.inf), tf.cast(ranks, tf.float32)) + + return self._compute_match_failure( + tf.cast(ranks, tf.float32), + tf.cast(y_true_click_rank, tf.float32), + tf.cast(click_ranks, tf.float32), + tf.cast(y_aux, tf.float32), + ) + + @staticmethod + def _compute_match_failure(ranks, y_true_click_rank, metric_click_ranks, y_aux): + """ + Compute match failure metric for a batch + + Parameters + ---------- + ranks : Tensor object + Tensor object that contains scores for various documents + y_true_click_rank : Tensor object + Tensor object that contains scores for various documents + metric_click_ranks : Tensor object + Tensor object that contains scores for various documents + y_aux : Tensor object + Tensor object that contains scores for various documents + + Returns + ------- + Tensor object + Tensor of Match Failure scores for each query + """ + # Mask all values of y_aux which are below the clicked rank + scores = tf.where( + ranks <= tf.expand_dims(metric_click_ranks, axis=-1), y_aux, tf.constant(-np.inf) + ) + rank_scores = RankMatchFailure.convert_to_rank_scores(scores) + 
ranks_above_click = tf.where( + ranks <= tf.expand_dims(metric_click_ranks, axis=-1), ranks, tf.constant(np.inf) + ) + num_match = tf.cast(tf.math.count_nonzero(scores > 0, axis=-1), tf.float32) + match_failure = 1 - tf.cast( + RankMatchFailure.normalized_discounted_cumulative_gain(rank_scores, ranks_above_click), + tf.float32, + ) + # If all records<=click have a name match, then it is not an NMF + # If number of scores>0 is same as the clicked rank, all ranks have a name match + match_failure = tf.where( + tf.equal(metric_click_ranks, num_match), + tf.constant(0, dtype=tf.float32), + match_failure, + ) + # No Match Failure when there is no match on the clicked rank + idxs = tf.expand_dims(tf.range(tf.shape(ranks)[0]), -1) + y_true_click_rank = tf.expand_dims(tf.cast(y_true_click_rank, tf.int32), axis=-1) + y_true_click_idx = tf.where(y_true_click_rank > 0, y_true_click_rank - 1, 0) + clicked_records_score = tf.gather_nd( + y_aux, indices=tf.concat([idxs, y_true_click_idx], axis=-1) + ) + match_failure = tf.where( + clicked_records_score == 0.0, tf.constant(-np.inf, dtype=tf.float32), match_failure + ) + return match_failure + + @staticmethod + def convert_to_rank_scores(scores): + """ + Maps each score -> 1/rank for standardizing the score ranges across queries + Parameters + ---------- + scores : Tensor object + Tensor object that contains scores for various documents + + Returns + ------- + Tensor object + Tensor of 1/rank(score) + """ + # Note: 2 scores with same value can get different rank scores. + # There is no sane TF way to handle this today + scores = tf.cast(scores, dtype=tf.float32) + score_rank = tf.add( + tf.argsort( + tf.argsort(scores, axis=-1, direction="DESCENDING", stable=True), stable=True + ), + tf.constant(1), + ) + rank_scores = 1 / score_rank + rank_scores = tf.cast(rank_scores, dtype=tf.float32) + rank_scores = tf.where( + scores == tf.constant(0.0, dtype=tf.float32), + tf.constant(0.0, dtype=tf.float32), + rank_scores, + ) + # -inf is used as mask + rank_scores = tf.where( + scores == tf.constant(-np.inf, dtype=tf.float32), + tf.constant(-np.inf, dtype=tf.float32), + rank_scores, + ) + return rank_scores + + @staticmethod + def discounted_cumulative_gain(relevance_grades, ranks): + """ + Compute the discounted cumulative gain + + Parameters + ---------- + relevance_grades : Tensor object + Tensor object that contains scores in the order of ranks (predicted or ideal) + + Returns + ------- + Tensor object + Tensor of DCG scores along 0th axis + """ + dcg_unmasked = (tf.cast(tf.math.pow(2.0, relevance_grades) - 1, dtype=tf.float32)) / ( + tf.math.log(tf.cast(ranks, dtype=tf.float32) + 1) / tf.math.log(2.0) + ) + # Remove DCG where relevance grade was -inf (mask) + dcg_masked = tf.where(dcg_unmasked < 0, tf.constant(0.0, dtype=tf.float32), dcg_unmasked) + return tf.reduce_sum(dcg_masked, axis=-1) + + @staticmethod + def normalized_discounted_cumulative_gain(relevance_grades, ranks): + """ + Compute the normalized discounted cumulative gain + + Parameters + ---------- + relevance_grades : Tensor object + Tensor object that contains scores in the order of predicted ranks + + Returns + ------- + Tensor object + Tensor of NDCG scores along 0th axis + """ + ideal_ranks = 1 + tf.range(tf.shape(relevance_grades)[1]) + sorted_relevance_grades = tf.sort(relevance_grades, direction="DESCENDING", axis=-1) + dcg_score = RankMatchFailure.discounted_cumulative_gain(relevance_grades, ranks) + idcg_score = RankMatchFailure.discounted_cumulative_gain( + sorted_relevance_grades, 
ideal_ranks + ) + ndcg_raw = dcg_score / idcg_score + return tf.where( + idcg_score == 0, tf.constant(-np.inf, dtype=tf.float32), ndcg_raw + ) # Handle invalid condition, return -inf diff --git a/python/ml4ir/applications/ranking/tests/test_auxiliary_loss.py b/python/ml4ir/applications/ranking/tests/test_auxiliary_loss.py index 1a326be1..bbc2828b 100644 --- a/python/ml4ir/applications/ranking/tests/test_auxiliary_loss.py +++ b/python/ml4ir/applications/ranking/tests/test_auxiliary_loss.py @@ -1,16 +1,16 @@ -import unittest -import warnings -import pandas as pd -import numpy as np -import pathlib -from testfixtures import TempDirectory import gc import os +import pathlib +import unittest +import warnings +import numpy as np +import pandas as pd import tensorflow.keras.backend as K +from testfixtures import TempDirectory -from ml4ir.applications.ranking.pipeline import RankingPipeline from ml4ir.applications.ranking.config.parse_args import get_args +from ml4ir.applications.ranking.pipeline import RankingPipeline warnings.filterwarnings("ignore") @@ -18,39 +18,64 @@ def train_ml4ir(data_dir, feature_config, model_config, logs_dir, aux_loss): - argv = ["--data_dir", data_dir, - "--feature_config", feature_config, - "--loss_type", "listwise", - "--scoring_type", "listwise", - "--run_id", "test_aux_loss", - "--data_format", "tfrecord", - "--execution_mode", "train_evaluate", - "--loss_key", "softmax_cross_entropy", - "--aux_loss_key", aux_loss, - "--primary_loss_weight", "0.8", - "--aux_loss_weight", "0.2", - "--num_epochs", "1", - "--model_config", model_config, - "--batch_size", "32", - "--logs_dir", logs_dir, - "--max_sequence_size", "25", - "--train_pcent_split", "0.7", - "--val_pcent_split", "0.15", - "--test_pcent_split", "0.15", - "--early_stopping_patience", "25", - "--metrics_keys", "MRR", "categorical_accuracy", - "--monitor_metric", "categorical_accuracy"] + argv = [ + "--data_dir", + data_dir, + "--feature_config", + feature_config, + "--loss_type", + "listwise", + "--scoring_type", + "listwise", + "--run_id", + "test_aux_loss", + "--data_format", + "tfrecord", + "--execution_mode", + "train_evaluate", + "--loss_key", + "softmax_cross_entropy", + "--aux_loss_key", + aux_loss, + "--primary_loss_weight", + "0.8", + "--aux_loss_weight", + "0.2", + "--num_epochs", + "1", + "--model_config", + model_config, + "--batch_size", + "32", + "--logs_dir", + logs_dir, + "--max_sequence_size", + "25", + "--train_pcent_split", + "0.7", + "--val_pcent_split", + "0.15", + "--test_pcent_split", + "0.15", + "--early_stopping_patience", + "25", + "--metrics_keys", + "MRR", + "RankMatchFailure", + "categorical_accuracy", + "--monitor_metric", + "categorical_accuracy", + ] args = get_args(argv) rp = RankingPipeline(args=args) rp.run() class TestDualObjectiveTraining(unittest.TestCase): - def setUp(self): self.dir = pathlib.Path(__file__).parent self.working_dir = TempDirectory() - self.log_dir = self.working_dir.makedir('logs') + self.log_dir = self.working_dir.makedir("logs") def tearDown(self): TempDirectory.cleanup_all() @@ -60,37 +85,79 @@ def tearDown(self): K.clear_session() def test_E2E_aux_softmax_CE(self): - feature_config_path = os.path.join(ROOT_DATA_DIR, "configs", "feature_config_aux_loss.yaml") + feature_config_path = os.path.join( + ROOT_DATA_DIR, "configs", "feature_config_aux_loss.yaml" + ) model_config_path = os.path.join(ROOT_DATA_DIR, "configs", "model_config_cyclic_lr.yaml") data_dir = os.path.join(ROOT_DATA_DIR, "tfrecord") aux_loss = "softmax_cross_entropy" train_ml4ir(data_dir, 
feature_config_path, model_config_path, self.log_dir, aux_loss) - ml4ir_results = pd.read_csv(os.path.join(self.log_dir, 'test_aux_loss', '_SUCCESS'), header=None) - primary_training_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'train_ranking_score_loss'][1]) + ml4ir_results = pd.read_csv( + os.path.join(self.log_dir, "test_aux_loss", "_SUCCESS"), header=None + ) + primary_training_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_ranking_score_loss"][1] + ) assert np.isclose(primary_training_loss, 1.1877643, atol=0.0001) - aux_training_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'train_aux_ranking_score_loss'][1]) + aux_training_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_aux_ranking_score_loss"][1] + ) assert np.isclose(aux_training_loss, 1.2242277, atol=0.0001) - primary_val_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'val_ranking_score_loss'][1]) + primary_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_ranking_score_loss"][1] + ) assert np.isclose(primary_val_loss, 1.2086908, atol=0.0001) - aux_val_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'val_aux_ranking_score_loss'][1]) + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_aux_ranking_score_loss"][1] + ) + # RankMatchFailure metric comparisons assert np.isclose(aux_val_loss, 1.2806674, atol=0.0001) + # RankMatchFailure metric comparisons + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_ranking_score_old_RankMatchFailure"][1] + ) + assert np.isclose(aux_val_loss, 0.00082716695, atol=0.00001) + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_ranking_score_new_RankMatchFailure"][1] + ) + assert np.isclose(aux_val_loss, 0.00012763393, atol=0.00001) + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_ranking_score_old_RankMatchFailure"][1] + ) + assert np.isclose(aux_val_loss, 0.0011314502, atol=0.00001) + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_ranking_score_new_RankMatchFailure"][1] + ) + assert np.isclose(aux_val_loss, 0.0002121307, atol=0.00001) def test_E2E_aux_basic_CE(self): - feature_config_path = os.path.join(ROOT_DATA_DIR, "configs", "feature_config_aux_loss.yaml") + feature_config_path = os.path.join( + ROOT_DATA_DIR, "configs", "feature_config_aux_loss.yaml" + ) model_config_path = os.path.join(ROOT_DATA_DIR, "configs", "model_config_cyclic_lr.yaml") data_dir = os.path.join(ROOT_DATA_DIR, "tfrecord") aux_loss = "basic_cross_entropy" train_ml4ir(data_dir, feature_config_path, model_config_path, self.log_dir, aux_loss) - ml4ir_results = pd.read_csv(os.path.join(self.log_dir, 'test_aux_loss', '_SUCCESS'), header=None) - primary_training_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'train_ranking_score_loss'][1]) + ml4ir_results = pd.read_csv( + os.path.join(self.log_dir, "test_aux_loss", "_SUCCESS"), header=None + ) + primary_training_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_ranking_score_loss"][1] + ) assert np.isclose(primary_training_loss, 1.1911143, atol=0.0001) - aux_training_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'train_aux_ranking_score_loss'][1]) + aux_training_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "train_aux_ranking_score_loss"][1] + ) assert np.isclose(aux_training_loss, 0.3824733, atol=0.0001) - primary_val_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'val_ranking_score_loss'][1]) + primary_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_ranking_score_loss"][1] + ) assert 
np.isclose(primary_val_loss, 1.2130133, atol=0.0001) - aux_val_loss = float(ml4ir_results.loc[ml4ir_results[0] == 'val_aux_ranking_score_loss'][1]) + aux_val_loss = float( + ml4ir_results.loc[ml4ir_results[0] == "val_aux_ranking_score_loss"][1] + ) assert np.isclose(aux_val_loss, 0.3906489, atol=0.0001) diff --git a/python/ml4ir/applications/ranking/tests/test_losses.py b/python/ml4ir/applications/ranking/tests/test_losses.py index f6809994..4949cb8e 100644 --- a/python/ml4ir/applications/ranking/tests/test_losses.py +++ b/python/ml4ir/applications/ranking/tests/test_losses.py @@ -8,7 +8,6 @@ class RankingModelTest(RankingTestBase): - def setUp(self): super().setUp() @@ -55,7 +54,7 @@ def test_softmax_cross_entropy(self): y_pred = activation_op(logits=self.logits, mask=self.mask) assert np.isclose(y_pred[0][0].numpy(), 0.19868991, atol=1e-5) - assert np.isclose(y_pred[2][4].numpy(), 0., atol=1e-5) + assert np.isclose(y_pred[2][4].numpy(), 0.0, atol=1e-5) assert np.isclose(loss_fn(self.y_true, y_pred), 1.306335, atol=1e-5) @@ -68,7 +67,7 @@ def test_basic_softmax_cross_entropy(self): y_pred = activation_op(logits=self.logits, mask=self.mask) assert np.isclose(y_pred[0][0].numpy(), 0.19868991, atol=1e-5) - assert np.isclose(y_pred[2][4].numpy(), 0., atol=1e-5) + assert np.isclose(y_pred[2][4].numpy(), 0.0, atol=1e-5) assert np.isclose(loss_fn(self.y_true_aux, y_pred), 0.75868917, atol=1e-5) @@ -81,7 +80,7 @@ def test_softmax_cross_entropy_auxiliary(self): y_pred = activation_op(logits=self.logits, mask=self.mask) assert np.isclose(y_pred[0][0].numpy(), 0.19868991, atol=1e-5) - assert np.isclose(y_pred[2][4].numpy(), 0., atol=1e-5) + assert np.isclose(y_pred[2][4].numpy(), 0.0, atol=1e-5) assert np.isclose(loss_fn(self.y_true_aux, y_pred), 0.5249801, atol=1e-5) @@ -94,7 +93,7 @@ def test_softmax_cross_entropy_auxiliary_ties(self): y_pred = activation_op(logits=self.logits, mask=self.mask) assert np.isclose(y_pred[0][0].numpy(), 0.19868991, atol=1e-5) - assert np.isclose(y_pred[2][4].numpy(), 0., atol=1e-5) + assert np.isclose(y_pred[2][4].numpy(), 0.0, atol=1e-5) assert np.isclose(loss_fn(self.y_true_aux_ties, y_pred), 4.117315, atol=1e-5) @@ -107,6 +106,6 @@ def test_rank_one_list_net(self): y_pred = activation_op(logits=self.logits, mask=self.mask) assert np.isclose(y_pred[0][0].numpy(), 0.19868991, atol=1e-5) - assert np.isclose(y_pred[2][4].numpy(), 0., atol=1e-5) + assert np.isclose(y_pred[2][4].numpy(), 0.0, atol=1e-5) assert np.isclose(loss_fn(self.y_true, y_pred), 2.1073625, atol=1e-5) diff --git a/python/ml4ir/applications/ranking/tests/test_rank_match_failure.py b/python/ml4ir/applications/ranking/tests/test_rank_match_failure.py new file mode 100644 index 00000000..b6c16079 --- /dev/null +++ b/python/ml4ir/applications/ranking/tests/test_rank_match_failure.py @@ -0,0 +1,228 @@ +import logging +import unittest +from collections import defaultdict + +import numpy as np +import tensorflow as tf +import yaml + +from ml4ir.applications.ranking.model.metrics.metrics_impl import RankMatchFailure +from ml4ir.base.features.feature_config import FeatureConfig +from ml4ir.base.model.metrics.metrics_impl import MetricState + + +class RankMachFailureTest(tf.test.TestCase): + def test_convert_to_rank_scores(self): + scores = tf.constant([ + [5., 4., 3., 2., 1., -np.inf, -np.inf], + # Duplicate scores + [1., 4., 3., 2., 1., -np.inf, -np.inf], + [5., 4., 3., 2., 1., 6., -np.inf], + [1., 2., 3., 4., 5., -np.inf, -np.inf], + [3., 2., 1., 5., 4., -np.inf, -np.inf], + # 0 scores are retained as 0 
rank_scorers + [3., 2., 1., 5., 4., 0., 0.], + ]) + actual_rank_scores = RankMatchFailure.convert_to_rank_scores(scores) + expected_rank_scores = tf.constant([ + [1 / 1., 1 / 2., 1 / 3., 1 / 4., 1 / 5., -np.inf, -np.inf], + # Note duplicate scores don't get the same rank score + [1 / 4., 1 / 1., 1 / 2., 1 / 3., 1 / 5., -np.inf, -np.inf], + [1 / 2., 1 / 3., 1 / 4., 1 / 5., 1 / 6., 1 / 1., -np.inf], + [1 / 5., 1 / 4., 1 / 3., 1 / 2., 1 / 1., -np.inf, -np.inf], + [1 / 3., 1 / 4., 1 / 5., 1 / 1., 1 / 2., -np.inf, -np.inf], + # 0 scores are retained as 0 rank_scorers + [1 / 3., 1 / 4., 1 / 5., 1 / 1., 1 / 2., 0, 0], + ]) + tf.debugging.assert_equal(actual_rank_scores, expected_rank_scores) + + def test_normalized_discounted_cumulative_gain(self): + relevance_grades = tf.constant([ + [1 / 1., 1 / 2., 1 / 3., 1 / 4., 1 / 5., -np.inf, -np.inf], + [1 / 4., 1 / 1., 1 / 2., 1 / 3., 1 / 5., -np.inf, -np.inf], + [1 / 2., 1 / 3., 1 / 4., 1 / 5., 1 / 6., 1 / 1., -np.inf], + [1 / 5., 1 / 4., 1 / 3., 1 / 2., 1 / 1., -np.inf, -np.inf], + [1 / 3., 1 / 4., 1 / 5., 1 / 1., 1 / 2., 0, 0], + # This should return -inf as a mask -> not defined in context of rank match failure + [0., 0., 0., 0., 0., -np.inf, -np.inf], + [0., 0., 0., 0., 1., -np.inf, -np.inf], + ]) + ranks = 1 + tf.range(tf.shape(relevance_grades)[1]) + actual_ndcg = RankMatchFailure.normalized_discounted_cumulative_gain( + relevance_grades, ranks=ranks + ) + expected_ndcg = tf.constant( + [1.0, 0.78200406, 0.7245744, 0.62946665, 0.6825818, -np.inf, 0.3868528] + ) + tf.debugging.assert_equal(actual_ndcg, expected_ndcg) + + def test_discounted_cumulative_gain(self): + relevance_grades = tf.constant([ + [1 / 1., 1 / 2., 1 / 3., 1 / 4., 1 / 5., -np.inf, -np.inf], + [1 / 4., 1 / 1., 1 / 2., 1 / 3., 1 / 5., -np.inf, -np.inf], + [1 / 2., 1 / 3., 1 / 4., 1 / 5., 1 / 6., 1 / 1., -np.inf], + [1 / 5., 1 / 4., 1 / 3., 1 / 2., 1 / 1., -np.inf, -np.inf], + [1 / 3., 1 / 4., 1 / 5., 1 / 1., 1 / 2., 0, 0], + [0., 0., 0., 0., 0., -np.inf, -np.inf], + [0., 0., 0., 0., 1., -np.inf, -np.inf], + ]) + ranks = 1 + tf.range(tf.shape(relevance_grades)[1]) + actual_dcg = RankMatchFailure.discounted_cumulative_gain(relevance_grades, ranks=ranks) + expected_dcg = tf.constant( + [1.5303116, 1.1967099, 1.1404319, 0.9632801, 1.0445628, 0.0, 0.3868528] + ) + tf.debugging.assert_equal(actual_dcg, expected_dcg) + + def test__compute_match_failure(self): + y_true_click_rank = tf.constant( + [ + # Clicked record stays in same position + 4.0, + 4.0, + 4.0, + 4.0, + # RR improves/degrades + 4.0, + 4.0, + 4.0, + ] + ) + y_pred_click_ranks = tf.constant( + [ + # Clicked record stays in same position + 4.0, + 4.0, + 4.0, + 4.0, + # RR improves/degrades + 2.0, + 1.0, + 2.0, + ] + ) + y_pred_doc_ranks = tf.constant( + [ + # Clicked record stays in same position + [1.0, 2.0, 3.0, 4.0, 5.0, np.inf, np.inf], + [1.0, 2.0, 5.0, 4.0, 3.0, np.inf, np.inf], + [3.0, 2.0, 5.0, 4.0, 1.0, np.inf, np.inf], + [3.0, 5.0, 2.0, 4.0, 1.0, np.inf, np.inf], + # MRR improves/degrades + [1.0, 4.0, 5.0, 2.0, 3.0, np.inf, np.inf], + [2.0, 4.0, 5.0, 1.0, 3.0, np.inf, np.inf], + [1.0, 4.0, 5.0, 2.0, 3.0, np.inf, np.inf], + ] + ) + y_aux = tf.constant( + [ + # Clicked record stays in same position + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 1.0, 3.0, 4.0, 5.0, 0.0, 0.0], + # MRR improves/degrades + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + ] + ) + actual_rmfs = 
RankMatchFailure._compute_match_failure( + y_pred_doc_ranks, y_true_click_rank, y_pred_click_ranks, y_aux + ) + expected_rmfs = tf.constant( + [0.42372233, 0.3945347, 0.03515863, 0.03515863, 0.36907023, 0.0, 0.36907023] + ) + self.assertAllClose(actual_rmfs, expected_rmfs, atol=1e-04) + + def test__compute(self): + ranks = tf.constant( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + [1.0, 2.0, 3.0, 4.0, 5.0, -np.inf, -np.inf], + ] + ) + y_true = tf.constant( + [ + # Clicked record stays in same position + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + # RR improves/degrades + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + ] + ) + y_pred = tf.constant( + [ + # Clicked record stays in same position + [0.7, 0.15, 0.07, 0.05, 0.03, 0.0, 0.0], + [0.7, 0.15, 0.03, 0.05, 0.07, 0.0, 0.0], + [0.07, 0.15, 0.03, 0.05, 0.7, 0.0, 0.0], + [0.07, 0.03, 0.15, 0.05, 0.7, 0.0, 0.0], + # MRR improves/degrades + [0.7, 0.05, 0.03, 0.15, 0.07, 0.0, 0.0], + [0.15, 0.05, 0.03, 0.7, 0.07, 0.0, 0.0], + [0.07, 0.05, 0.03, 0.15, 0.7, 0.0, 0.0], + ] + ) + y_aux = tf.constant( + [ + # Clicked record stays in same position + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 1.0, 3.0, 4.0, 5.0, 0.0, 0.0], + # MRR improves/degrades + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + [0.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0], + ] + ) + mask = tf.constant( + [ + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0], + ] + ) + with open( + "ml4ir/applications/ranking/tests/data/configs/feature_config_aux_loss.yaml" + ) as feature_config_file: + feature_config: FeatureConfig = FeatureConfig.get_instance( + tfrecord_type="sequence", + feature_config_dict=yaml.safe_load(feature_config_file), + logger=logging.Logger("test_logger"), + ) + rmf = RankMatchFailure( + feature_config, + metadata_features=defaultdict(lambda: tf.constant([1])), + state=MetricState.NEW, + ) + actual_rmfs = rmf._compute(y_true, y_pred, y_aux, ranks, mask) + expected_rmfs = tf.constant( + [ + 0.42372233, + 0.3945347, + 0.03515863, + 0.03515863, + 0.36907023, + 0.0, + 0.0, + ] + ) + self.assertAllClose(actual_rmfs, expected_rmfs, atol=1e-04) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/ml4ir/base/features/feature_layer.py b/python/ml4ir/base/features/feature_layer.py index 05d74697..cecd135f 100644 --- a/python/ml4ir/base/features/feature_layer.py +++ b/python/ml4ir/base/features/feature_layer.py @@ -155,6 +155,7 @@ def feature_layer_op(inputs): tf.expand_dims(tf.gather(inputs["mask"], indices=0), axis=0) ) + # Note we exclude only label. 
Aux label is included as metadata for feature_info in feature_config.get_all_features(include_label=False): feature_node_name = feature_info.get("node_name", feature_info["name"]) feature_layer_info = feature_info["feature_layer_info"] diff --git a/python/ml4ir/base/model/metrics/metrics_impl.py b/python/ml4ir/base/model/metrics/metrics_impl.py index 6ce8b439..05ce90fd 100644 --- a/python/ml4ir/base/model/metrics/metrics_impl.py +++ b/python/ml4ir/base/model/metrics/metrics_impl.py @@ -1,20 +1,26 @@ +from typing import Dict from typing import Type, List, Union + from tensorflow.keras.metrics import Metric from ml4ir.base.features.feature_config import FeatureConfig -from typing import Dict - class MetricState: OLD = "old" NEW = "new" +class CombinationMetric: + # Metrics from combinations of outputs + pass + + def get_metrics_impl( metrics: List[Union[str, Type[Metric]]], feature_config: FeatureConfig, metadata_features: Dict, + for_aux_output: bool = False, **kwargs ) -> List[Union[Metric, str]]: """ @@ -44,6 +50,9 @@ def get_metrics_impl( if isinstance(metric, str): # If metric is specified as a string, then do nothing metrics_impl.append(metric) + if issubclass(metric, CombinationMetric) and for_aux_output: + # Combination metrics are only defined for main output + continue else: # If metric is a class of type Metric try: diff --git a/python/ml4ir/base/model/relevance_model.py b/python/ml4ir/base/model/relevance_model.py index dee36060..7a5872f9 100644 --- a/python/ml4ir/base/model/relevance_model.py +++ b/python/ml4ir/base/model/relevance_model.py @@ -20,9 +20,10 @@ from ml4ir.base.model.serving import define_serving_signatures from ml4ir.base.model.scoring.prediction_helper import get_predict_fn from ml4ir.base.model.callbacks.debugging import DebuggingCallback -from ml4ir.base.model.calibration.temperature_scaling import temperature_scale,\ - TemperatureScalingLayer -from ml4ir.applications.ranking.config.keys import PositionalBiasHandler +from ml4ir.base.model.calibration.temperature_scaling import ( + temperature_scale, + TemperatureScalingLayer, +) from ml4ir.base.config.keys import LearningRateScheduleKey @@ -132,31 +133,42 @@ def __init__( if self.feature_config.get_aux_label(): # Create model with functional Keras API - self.model = Model(inputs=inputs, outputs={self.output_name: scores, self.aux_output_name: scores}) + self.model = Model( + inputs=inputs, outputs={self.output_name: scores, self.aux_output_name: scores} + ) self.model.output_names = [self.output_name, self.aux_output_name] # Get loss fn loss_fn = scorer.loss[self.output_name].get_loss_fn(**metadata_features) - metadata_features['is_aux_loss'] = True - metadata_features['batch_size'] = self.batch_size + metadata_features["is_aux_loss"] = True + metadata_features["batch_size"] = self.batch_size loss_fn_aux = scorer.loss[self.aux_output_name].get_loss_fn(**metadata_features) - losses = { - self.output_name: loss_fn, - self.aux_output_name: loss_fn_aux} + losses = {self.output_name: loss_fn, self.aux_output_name: loss_fn_aux} - lossWeights = {self.output_name: primary_loss_weight, self.aux_output_name: aux_loss_weight} + loss_weights = { + self.output_name: primary_loss_weight, + self.aux_output_name: aux_loss_weight, + } # Get metric objects metrics_impl: List[Union[str, kmetrics.Metric]] = get_metrics_impl( - metrics=metrics, feature_config=feature_config, metadata_features=metadata_features + metrics=metrics, + feature_config=feature_config, + metadata_features=metadata_features, ) metrics_impl_aux: 
List[Union[str, kmetrics.Metric]] = get_metrics_impl( - metrics=metrics, feature_config=feature_config, metadata_features=metadata_features + metrics=metrics, + feature_config=feature_config, + metadata_features=metadata_features, + for_aux_output=True, ) self.model.compile( optimizer=optimizer, loss=losses, - loss_weights=lossWeights, - metrics=[metrics_impl, metrics_impl_aux], + loss_weights=loss_weights, + metrics={ + self.output_name: metrics_impl, + self.aux_output_name: metrics_impl_aux, + }, experimental_run_tf_function=False, ) else: @@ -169,7 +181,9 @@ def __init__( # Get metric objects metrics_impl: List[Union[str, kmetrics.Metric]] = get_metrics_impl( - metrics=metrics, feature_config=feature_config, metadata_features=metadata_features + metrics=metrics, + feature_config=feature_config, + metadata_features=metadata_features, ) """ NOTE: @@ -406,25 +420,29 @@ def define_scheduler_as_callback(self, monitor_metric, model_config): The created scheduler callback object. """ - if model_config and 'lr_schedule' in model_config: - lr_schedule = model_config['lr_schedule'] - lr_schedule_key = lr_schedule['key'] + if model_config and "lr_schedule" in model_config: + lr_schedule = model_config["lr_schedule"] + lr_schedule_key = lr_schedule["key"] if lr_schedule_key == LearningRateScheduleKey.REDUCE_LR_ON_PLATEAU: if monitor_metric is None: - reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(factor=lr_schedule.get('factor', 0.5), - patience=lr_schedule.get('patience', 5), - min_lr=lr_schedule.get('min_lr', 0.0001), - mode=lr_schedule.get('mode', 'auto'), - verbose=1) + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + factor=lr_schedule.get("factor", 0.5), + patience=lr_schedule.get("patience", 5), + min_lr=lr_schedule.get("min_lr", 0.0001), + mode=lr_schedule.get("mode", "auto"), + verbose=1, + ) else: if not monitor_metric.startswith("val_"): monitor_metric = "val_{}".format(monitor_metric) - reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor=monitor_metric, - factor=lr_schedule.get('factor', 0.5), - patience=lr_schedule.get('patience', 5), - min_lr=lr_schedule.get('min_lr', 0.0001), - mode=lr_schedule.get('mode', 'auto'), - verbose=1) + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + monitor=monitor_metric, + factor=lr_schedule.get("factor", 0.5), + patience=lr_schedule.get("patience", 5), + min_lr=lr_schedule.get("min_lr", 0.0001), + mode=lr_schedule.get("mode", "auto"), + verbose=1, + ) return reduce_lr def fit( @@ -469,8 +487,8 @@ def fit( where key is metric name and value is floating point metric value. This dictionary will be used for experiment tracking for each ml4ir run """ - if self.feature_config.aux_label: #val_ranking_score_new_MRR - monitor_metric = 'val_' + self.output_name + '_' + monitor_metric + if self.feature_config.aux_label: # val_ranking_score_new_MRR + monitor_metric = "val_" + self.output_name + "_" + monitor_metric if not monitor_metric.startswith("val_"): monitor_metric = "val_{}".format(monitor_metric) callbacks_list: list = self._build_callback_hooks( @@ -574,9 +592,12 @@ def predict( if logs_dir: np.set_printoptions( - formatter={"all": lambda x: str(x.decode("utf-8")) - if isinstance(x, bytes) else str(x)}, - linewidth=sys.maxsize, threshold=sys.maxsize) # write the full line in the csv not the truncated version. + formatter={ + "all": lambda x: str(x.decode("utf-8")) if isinstance(x, bytes) else str(x) + }, + linewidth=sys.maxsize, + threshold=sys.maxsize, + ) # write the full line in the csv not the truncated version. 
# Decode bytes features to strings for col in predictions_df.columns: @@ -691,7 +712,7 @@ def save( pad_sequence: bool = False, sub_dir: str = "final", dataset: Optional[RelevanceDataset] = None, - experiment_details: Optional[dict] = None + experiment_details: Optional[dict] = None, ): """ Save the RelevanceModel as a tensorflow SavedModel to the `models_dir` @@ -774,7 +795,10 @@ def save( ) except FileNotFoundError: self.logger.warning( - "Error saving layer: {} due to FileNotFoundError. Skipping...".format(layer.name)) + "Error saving layer: {} due to FileNotFoundError. Skipping...".format( + layer.name + ) + ) self.logger.info("Final model saved to : {}".format(model_file)) @@ -908,7 +932,9 @@ def _build_callback_hooks( callbacks_list.append(DebuggingCallback(self.logger, logging_frequency)) # Adding lr scheduler as a callback; used for `ReduceLROnPlateau` which we treat today as a callback - scheduler_callback = self.define_scheduler_as_callback(monitor_metric, self.scorer.model_config) + scheduler_callback = self.define_scheduler_as_callback( + monitor_metric, self.scorer.model_config + ) if scheduler_callback: callbacks_list.append(scheduler_callback) @@ -916,8 +942,9 @@ def _build_callback_hooks( return callbacks_list - def calibrate(self, relevance_dataset, logger, logs_dir_local, **kwargs)\ - -> Tuple[np.ndarray, ...]: + def calibrate( + self, relevance_dataset, logger, logs_dir_local, **kwargs + ) -> Tuple[np.ndarray, ...]: """Calibrate model with temperature scaling Parameters ---------- @@ -936,15 +963,19 @@ def calibrate(self, relevance_dataset, logger, logs_dir_local, **kwargs)\ """ logger.info("=" * 50) logger.info("Calibrating the model with temperature scaling") - return temperature_scale(model=self.model, - scorer=self.scorer, - dataset=relevance_dataset, - logger=logger, - logs_dir_local=logs_dir_local, - file_io=self.file_io, - **kwargs) - - def add_temperature_layer(self, temperature: float = 1.0, layer_name: str = 'temperature_layer'): + return temperature_scale( + model=self.model, + scorer=self.scorer, + dataset=relevance_dataset, + logger=logger, + logs_dir_local=logs_dir_local, + file_io=self.file_io, + **kwargs, + ) + + def add_temperature_layer( + self, temperature: float = 1.0, layer_name: str = "temperature_layer" + ): """Add temperature layer to the input of last activation (softmax) layer Parameters ---------- @@ -962,25 +993,32 @@ def add_temperature_layer(self, temperature: float = 1.0, layer_name: str = 'tem """ # get last layer's output --> MUST **NOT** BE AN ACTIVATION (e.g. 
SOFTMAX) LAYER - final_layer_name = self.scorer.model_config['layers'][-1]['name'] + final_layer_name = self.scorer.model_config["layers"][-1]["name"] final_layer = self.model.get_layer(name=final_layer_name).output - temperature_layer = TemperatureScalingLayer(name=layer_name, - temperature=temperature)(final_layer) + temperature_layer = TemperatureScalingLayer(name=layer_name, temperature=temperature)( + final_layer + ) # using the `last layer` as final activation function before computing loss idx_activation = -1 - if len(self.model.layers) > 0 and isinstance(self.model.layers[idx_activation], - tf.keras.layers.Activation): + if len(self.model.layers) > 0 and isinstance( + self.model.layers[idx_activation], tf.keras.layers.Activation + ): # creating new activation layer activation_layer_name = self.model.get_layer(index=idx_activation).name activation_function = self.model.get_layer(index=idx_activation).activation activation_layer = tf.keras.layers.Activation( - activation_function, name=activation_layer_name)(temperature_layer) + activation_function, name=activation_layer_name + )(temperature_layer) # creating new keras Functional API model self.model = Model(self.model.inputs, activation_layer) - self.logger.info(f'Temperature Scaling layer added and new Functional API model' - f' replaced; temperature = {temperature}.') + self.logger.info( + f"Temperature Scaling layer added and new Functional API model" + f" replaced; temperature = {temperature}." + ) else: - self.logger.info("Skipping adding Temperature Scaling layer because no activation " - "exist in the last layer of Keras original model!") + self.logger.info( + "Skipping adding Temperature Scaling layer because no activation " + "exist in the last layer of Keras original model!" + ) diff --git a/python/ml4ir/base/model/scoring/scoring_model.py b/python/ml4ir/base/model/scoring/scoring_model.py index 80e224e4..5597989e 100644 --- a/python/ml4ir/base/model/scoring/scoring_model.py +++ b/python/ml4ir/base/model/scoring/scoring_model.py @@ -7,7 +7,7 @@ from ml4ir.base.io.file_io import FileIO from logging import Logger -from typing import Dict, Optional +from typing import Dict, Optional, Union, List class ScorerBase(object): @@ -29,7 +29,7 @@ def __init__( model_config: dict, feature_config: FeatureConfig, interaction_model: InteractionModel, - loss: RelevanceLossBase, + loss: Union[RelevanceLossBase, List[RelevanceLossBase]], file_io: FileIO, output_name: str = "score", logger: Optional[Logger] = None, @@ -110,7 +110,7 @@ def from_model_config_file( loss=loss, file_io=file_io, output_name=output_name, - logger=logger + logger=logger, ) def __call__(self, inputs: Dict[str, Input]):
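
For reference, below is a minimal, illustrative sketch (not part of the diff) of how the static helpers introduced by RankMatchFailure can be exercised on a single toy query, mirroring the values used in test_rank_match_failure.py above. It assumes ml4ir is importable with this change applied.

# Illustrative usage of the RankMatchFailure helpers added in this patch.
import numpy as np
import tensorflow as tf

from ml4ir.applications.ranking.model.metrics.metrics_impl import RankMatchFailure

# Aux-label scores for one query; -inf marks padded records.
scores = tf.constant([[5., 4., 3., 2., 1., -np.inf, -np.inf]])

# Each score is mapped to 1/rank(score); zeros stay 0 and -inf is preserved as a mask.
rank_scores = RankMatchFailure.convert_to_rank_scores(scores)
# -> [[1.0, 0.5, 0.3333, 0.25, 0.2, -inf, -inf]]

# NDCG of the predicted ordering; -inf relevance grades contribute nothing to the DCG sum.
ranks = 1 + tf.range(tf.shape(rank_scores)[1])
ndcg = RankMatchFailure.normalized_discounted_cumulative_gain(rank_scores, ranks)
# -> [1.0] for this already-ideal ordering; the metric itself reports 1 - NDCG,
#    computed only over records ranked at or above the clicked record, as the match failure.

print(rank_scores.numpy(), ndcg.numpy())

In the training pipeline the metric is enabled by adding "RankMatchFailure" to --metrics_keys (as done in test_auxiliary_loss.py); it requires an aux_label entry in the feature config and raises a ValueError otherwise.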