diff --git a/README.md b/README.md index a199ff1c..081dd1ef 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ **Quickstart →** **[ml4ir Read the Docs](https://ml4ir.readthedocs.io/en/latest/)** | **[ml4ir pypi](https://pypi.org/project/ml4ir/)** | **[python ReadMe](python/)** + ml4ir is an open source library for training and deploying deep learning models for search applications. ml4ir is built on top of **python3** and **tensorflow 2.x** for training and evaluation. It also comes packaged with scala utilities for **JVM inference**. ml4ir is designed as modular subcomponents which can each be combined and customized to build a variety of search ML models such as: diff --git a/python/README.md b/python/README.md index 0627e88f..04f87503 100644 --- a/python/README.md +++ b/python/README.md @@ -149,6 +149,8 @@ To use ml4ir as a deep learning library to build relevance models, look at the f * **Text Classification** : The `EntityPredictionDemo` notebook walks you through training a model to predict entity type given a user context and query. +* **Ranking Explanations** : The `Ranking_Explanations` notebook walks you through per-query explanations for a trained ml4ir model + Enter the following command to spin up Jupyter notebook on your browser to run the above notebooks ``` cd path/to/ml4ir/python/ diff --git a/python/notebooks/Ranking_Explanations.ipynb b/python/notebooks/Ranking_Explanations.ipynb new file mode 100644 index 00000000..7b9c058b --- /dev/null +++ b/python/notebooks/Ranking_Explanations.ipynb @@ -0,0 +1,3395 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Learning to Rank Expanations Demo 2022" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "In this notebook, we will explore how to explain the scores of a Learning to Rank model using OmniXAI\n", + "\n", + "#### Key Takeaways\n", + "- How to install and get started with ml4ir as a script\n", + "- Explaining the rank scores using OmniXAI\n", + "\n", + "#### Learning to Rank\n", + "The goal of Learning to Rank(LTR) is to come up with a ranking function to generate an optimal ordering of a list of documents. In this notebook, we will learn a simple **pointwise ranking function** using a **listwise loss** which will predict the ranking scores for all records of a given query. These scores can then be used at inference to determine the optimal ordering.\n", + "\n", + "#### Per Query Valid Explanations \n", + "\n", + "We explore the per-query Valid explanations using Omnixai's ValidityRankingExplainer\n", + "\n", + "Reference for algorithm: Singh, J., Khosla, M., & Anand, A. (2020). Valid Explanations for Learning to Rank Models. ArXiv, abs/2004.13972." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install ml4ir and omnixai" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "scrolled": true + }, + "source": [ + "### Install the ml4ir from github as per the README" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "!pip install omnixai" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Installing visualization libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARNING: You are using pip version 22.0.4; however, version 22.2.2 is available.\r\n", + "You should consider upgrading via the '/Users/tlaud/ml4ir/python/venv/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\u001b[33m\r\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install --upgrade -q plotly nbformat" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Look at the data" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
query_idquery_textranktext_match_scorepage_views_scorequality_scoreclickeddomain_iddomain_namename_match
0query_2MHS7A7RJB1Y4BJT20.4737300.0000000.0000002domain_21
1query_2MHS7A7RJB1Y4BJT11.0631900.2053810.3010312domain_21
2query_5KNJNWV61.3681080.0306360.0000000domain_00
3query_5KNJNWV31.3706280.0412610.3010300domain_00
4query_5KNJNWV41.3667000.0825350.3010300domain_00
5query_5KNJNWV11.3338360.0425720.3010310domain_00
6query_5KNJNWV51.3250210.0464780.0000000domain_01
\n", + "
" + ], + "text/plain": [ + " query_id query_text rank text_match_score page_views_score \\\n", + "0 query_2 MHS7A7RJB1Y4BJT 2 0.473730 0.000000 \n", + "1 query_2 MHS7A7RJB1Y4BJT 1 1.063190 0.205381 \n", + "2 query_5 KNJNWV 6 1.368108 0.030636 \n", + "3 query_5 KNJNWV 3 1.370628 0.041261 \n", + "4 query_5 KNJNWV 4 1.366700 0.082535 \n", + "5 query_5 KNJNWV 1 1.333836 0.042572 \n", + "6 query_5 KNJNWV 5 1.325021 0.046478 \n", + "\n", + " quality_score clicked domain_id domain_name name_match \n", + "0 0.00000 0 2 domain_2 1 \n", + "1 0.30103 1 2 domain_2 1 \n", + "2 0.00000 0 0 domain_0 0 \n", + "3 0.30103 0 0 domain_0 0 \n", + "4 0.30103 0 0 domain_0 0 \n", + "5 0.30103 1 0 domain_0 0 \n", + "6 0.00000 0 0 domain_0 1 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df_train = pd.read_csv(\"../ml4ir/applications/ranking/tests/data/csv/train/file_0.csv\")\n", + "df_train.head(7)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the FeatureConfig\n", + "\n", + "**YAML File** -> configs/activate_2020/feature_config.yaml\n", + "\n", + "\n", + "\n", + "| Feature | Type | TFRecord Type | Usage |\n", + "| ---------------- | -------- | ------------- | ---------------------------------------- |\n", + "| query_text | Text | Context | Character Embeddings -> biLSTM Encoding |\n", + "| domain_name | Text | Context | VocabLookup -> Categorical Embedding |\n", + "| text_match_score | Numeric | Sequence | float |\n", + "| page_views_score | Numeric | Sequence | float |\n", + "| quality_score | Numeric | Sequence | float |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the ModelConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "architecture_key: dnn\n", + "layers:\n", + " - type: dense\n", + " name: first_dense\n", + " units: 256\n", + " activation: relu\n", + " - type: dropout\n", + " name: first_dropout\n", + " rate: 0.3\n", + " - type: dense\n", + " name: second_dense\n", + " units: 64\n", + " activation: relu\n", + " - type: dense\n", + " name: final_dense\n", + " units: 1\n", + " activation: null\n", + "\n" + ] + } + ], + "source": [ + "print(open(\"configs/activate_2020/model_config.yaml\").read())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using ml4ir as a script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "!python ../ml4ir/applications/ranking/pipeline.py \\\n", + "--data_format csv \\\n", + "--data_dir ../ml4ir/applications/ranking/tests/data/csv \\\n", + "--feature_config configs/activate_2020/feature_config.yaml \\\n", + "--model_config configs/activate_2020/model_config.yaml \\\n", + "--execution_mode train_inference_evaluate \\\n", + "--loss_key softmax_cross_entropy \\\n", + "--num_epochs 3 \\\n", + "--models_dir ../models/explain_demo_2022 \\\n", + "--logs_dir ../logs/explain_demo_2022 \\\n", + "--run_id activate_demo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Now, the model is saved and ready for inference" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_DIR = '../models/explain_demo_2022/activate_demo'" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import tensorflow as tf\n", + "import os\n", + "from ml4ir.base.io.local_io import LocalIO\n", + "from ml4ir.base.io.file_io import FileIO\n", + "from ml4ir.base.features.feature_config import FeatureConfig, SequenceExampleFeatureConfig\n", + "from ml4ir.base.model.relevance_model import RelevanceModel\n", + "from ml4ir.base.config.keys import TFRecordTypeKey" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training features\n", + "-----------------\n", + "text_match_score\n", + "page_views_score\n", + "quality_score\n", + "query_text\n", + "domain_name\n", + "text_match_score\n", + "page_views_score\n", + "quality_score\n", + "query_text\n", + "domain_name\n" + ] + } + ], + "source": [ + "# Set up file I/O handler\n", + "file_io : FileIO = LocalIO()\n", + " \n", + "\n", + "# Set up logger\n", + "logger = logging.getLogger()\n", + "\n", + "tf.get_logger().setLevel(\"INFO\")\n", + "tf.autograph.set_verbosity(3)\n", + "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", + "\n", + "feature_config: SequenceExampleFeatureConfig = FeatureConfig.get_instance(\n", + " tfrecord_type=TFRecordTypeKey.SEQUENCE_EXAMPLE,\n", + " feature_config_dict=file_io.read_yaml(\"configs/activate_2020/feature_config.yaml\"),\n", + " logger=logger)\n", + "print(\"Training features\\n-----------------\")\n", + "print(\"\\n\".join(feature_config.get_train_features(key=\"name\")))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Sanity check" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Retraining is not yet supported. Model is loaded with compile=False\n" + ] + } + ], + "source": [ + "relevance_model = RelevanceModel(\n", + " feature_config=feature_config,\n", + " tfrecord_type=TFRecordTypeKey.EXAMPLE,\n", + " model_file=os.path.join(MODEL_DIR, 'final/default/'),\n", + " logger=logger,\n", + " output_name=\"relevance_score\",\n", + " file_io=file_io\n", + ")\n", + "\n", + "logger.info(\"Is Keras model? {}\".format(isinstance(relevance_model.model, tf.keras.Model)))\n", + "logger.info(\"Is compiled? {}\".format(relevance_model.is_compiled))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow.keras import models as kmodels\n", + "from tensorflow import data\n", + "\n", + "model = kmodels.load_model(\n", + " os.path.join(MODEL_DIR, 'final/tfrecord/'),\n", + " compile=False)\n", + "infer_fn = model.signatures[\"serving_tfrecord\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from ml4ir.base.data.tfrecord_helper import get_sequence_example_proto\n", + "\n", + "def predict(features_df):\n", + " features_df[\"query_text\"] = features_df[\"query_text\"].fillna(\"\")\n", + " features_df = (features_df.copy()\n", + " .rename(columns={\n", + " feature[\"serving_info\"][\"name\"]: feature[\"name\"] for feature in\n", + " feature_config.context_features + feature_config.sequence_features\n", + " }))\n", + " #print(features_df)\n", + " context_feature_names = [feature[\"name\"] for feature in feature_config.context_features]\n", + " protos = features_df.groupby([\"query_id\",\"query_text\"]).apply(lambda g: get_sequence_example_proto(\n", + " group=g,\n", + " context_features=feature_config.context_features,\n", + " sequence_features=feature_config.sequence_features,\n", + " ))\n", + "\n", + "\n", + " \n", + " # Score the proto with the model\n", + " ranking_scores = protos.apply(lambda se: infer_fn(\n", + " tf.expand_dims(\n", + " tf.constant(se.SerializeToString()),\n", + " axis=-1))[\"ranking_score\"].numpy()[0])\n", + " # Check parity of scores\n", + " predicted_scores = (ranking_scores.reset_index(name=\"ranking_score\")\n", + " .set_index(\"query_id\")\n", + " .squeeze())\n", + " return predicted_scores[\"ranking_score\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's look at one of the queries" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
query_idquery_textranktext_match_scorepage_views_scorequality_scoreclickeddomain_iddomain_namename_match
2query_5KNJNWV61.3681080.0306360.0000000domain_00
3query_5KNJNWV31.3706280.0412610.3010300domain_00
4query_5KNJNWV41.3667000.0825350.3010300domain_00
5query_5KNJNWV11.3338360.0425720.3010310domain_00
6query_5KNJNWV51.3250210.0464780.0000000domain_01
7query_5KNJNWV21.3627200.0425720.3010300domain_00
\n", + "
" + ], + "text/plain": [ + " query_id query_text rank text_match_score page_views_score \\\n", + "2 query_5 KNJNWV 6 1.368108 0.030636 \n", + "3 query_5 KNJNWV 3 1.370628 0.041261 \n", + "4 query_5 KNJNWV 4 1.366700 0.082535 \n", + "5 query_5 KNJNWV 1 1.333836 0.042572 \n", + "6 query_5 KNJNWV 5 1.325021 0.046478 \n", + "7 query_5 KNJNWV 2 1.362720 0.042572 \n", + "\n", + " quality_score clicked domain_id domain_name name_match \n", + "2 0.00000 0 0 domain_0 0 \n", + "3 0.30103 0 0 domain_0 0 \n", + "4 0.30103 0 0 domain_0 0 \n", + "5 0.30103 1 0 domain_0 0 \n", + "6 0.00000 0 0 domain_0 1 \n", + "7 0.30103 0 0 domain_0 0 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_train[df_train[\"query_id\"]==\"query_5\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### And its corresponding model output scores" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tlaud/ml4ir/python/venv/lib/python3.7/site-packages/ipykernel_launcher.py:4: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " after removing the cwd from sys.path.\n" + ] + }, + { + "data": { + "text/plain": [ + "array([0.11998416, 0.19389412, 0.20375773, 0.17943792, 0.11195529,\n", + " 0.1909707 ], dtype=float32)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predict(df_train[df_train[\"query_id\"]==\"query_5\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now, let's create a Tabular instance which is a standard way to process datasets in OmniXAI" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
query_idquery_textranktext_match_scorepage_views_scorequality_scoreclickeddomain_iddomain_namename_match
0query_2MHS7A7RJB1Y4BJT20.4737300.0000000.0000002domain_21
1query_2MHS7A7RJB1Y4BJT11.0631900.2053810.3010312domain_21
2query_5KNJNWV61.3681080.0306360.0000000domain_00
3query_5KNJNWV31.3706280.0412610.3010300domain_00
4query_5KNJNWV41.3667000.0825350.3010300domain_00
.................................
5671query_1487QCZ4XHLN60.2276940.0000000.0000002domain_20
5672query_1487QCZ4XHLN21.0169540.0000000.0000002domain_21
5673query_1490WYNFF8920.4746000.1907350.0000000domain_00
5674query_1490WYNFF8910.6203550.1433100.0000010domain_00
5675query_1490WYNFF8930.5083620.1907350.0000000domain_01
\n", + "

5676 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " query_id query_text rank text_match_score page_views_score \\\n", + "0 query_2 MHS7A7RJB1Y4BJT 2 0.473730 0.000000 \n", + "1 query_2 MHS7A7RJB1Y4BJT 1 1.063190 0.205381 \n", + "2 query_5 KNJNWV 6 1.368108 0.030636 \n", + "3 query_5 KNJNWV 3 1.370628 0.041261 \n", + "4 query_5 KNJNWV 4 1.366700 0.082535 \n", + "... ... ... ... ... ... \n", + "5671 query_1487 QCZ4XHLN 6 0.227694 0.000000 \n", + "5672 query_1487 QCZ4XHLN 2 1.016954 0.000000 \n", + "5673 query_1490 WYNFF89 2 0.474600 0.190735 \n", + "5674 query_1490 WYNFF89 1 0.620355 0.143310 \n", + "5675 query_1490 WYNFF89 3 0.508362 0.190735 \n", + "\n", + " quality_score clicked domain_id domain_name name_match \n", + "0 0.00000 0 2 domain_2 1 \n", + "1 0.30103 1 2 domain_2 1 \n", + "2 0.00000 0 0 domain_0 0 \n", + "3 0.30103 0 0 domain_0 0 \n", + "4 0.30103 0 0 domain_0 0 \n", + "... ... ... ... ... ... \n", + "5671 0.00000 0 2 domain_2 0 \n", + "5672 0.00000 0 2 domain_2 1 \n", + "5673 0.00000 0 0 domain_0 0 \n", + "5674 0.00000 1 0 domain_0 0 \n", + "5675 0.00000 0 0 domain_0 1 \n", + "\n", + "[5676 rows x 10 columns]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from omnixai.data.tabular import Tabular\n", + "training_data = Tabular(\n", + " df_train,\n", + " target_column='clicked',\n", + ")\n", + "training_data.to_pd() #The tabular instance can always be converted back to pandas DataFrame" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Similarly for the query sample" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
query_idquery_textranktext_match_scorepage_views_scorequality_scoreclickeddomain_iddomain_namename_match
2query_5KNJNWV61.3681080.0306360.0000000domain_00
3query_5KNJNWV31.3706280.0412610.3010300domain_00
4query_5KNJNWV41.3667000.0825350.3010300domain_00
5query_5KNJNWV11.3338360.0425720.3010310domain_00
6query_5KNJNWV51.3250210.0464780.0000000domain_01
7query_5KNJNWV21.3627200.0425720.3010300domain_00
\n", + "
" + ], + "text/plain": [ + " query_id query_text rank text_match_score page_views_score \\\n", + "2 query_5 KNJNWV 6 1.368108 0.030636 \n", + "3 query_5 KNJNWV 3 1.370628 0.041261 \n", + "4 query_5 KNJNWV 4 1.366700 0.082535 \n", + "5 query_5 KNJNWV 1 1.333836 0.042572 \n", + "6 query_5 KNJNWV 5 1.325021 0.046478 \n", + "7 query_5 KNJNWV 2 1.362720 0.042572 \n", + "\n", + " quality_score clicked domain_id domain_name name_match \n", + "2 0.00000 0 0 domain_0 0 \n", + "3 0.30103 0 0 domain_0 0 \n", + "4 0.30103 0 0 domain_0 0 \n", + "5 0.30103 1 0 domain_0 0 \n", + "6 0.00000 0 0 domain_0 1 \n", + "7 0.30103 0 0 domain_0 0 " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_query = Tabular(\n", + " df_train[df_train[\"query_id\"]==\"query_5\"],\n", + " target_column='clicked',\n", + ")\n", + "sample_query.to_pd()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the features that you wish to analyze. These are sequence features in our case" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "sequence_features = [f['name'] for f in feature_config.sequence_features if f['trainable']]\n", + "columns = set(training_data.columns)\n", + "ignored_features = columns - set(sequence_features)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'clicked',\n", + " 'domain_id',\n", + " 'domain_name',\n", + " 'name_match',\n", + " 'query_id',\n", + " 'query_text',\n", + " 'rank'}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ignored_features" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize Explainer" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "from omnixai.explainers.ranking.agnostic.validity import ValidityRankingExplainer\n", + "\n", + "ranking_explainer = ValidityRankingExplainer(training_data=training_data,\n", + " ignored_features=ignored_features,\n", + " predict_function=lambda x: predict(x.to_pd()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get explanations in one call" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "explanation = ranking_explainer.explain(sample_query, # The tabular instance to be explained\n", + " k=3 # The maximum number of features to consider as explanation\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### The resulting order of feature importance:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['quality_score', 'text_match_score', 'page_views_score'])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explanation.get_explanations(0)[\"top_features\"].keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### We can determine the validity of our explanation" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "KendalltauResult(correlation=0.9999999999999999, pvalue=0.002777777777777778)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "explanation.get_explanations(0)['validity']['Tau']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Kendall Tau of 0.99 indicates that the feature importances are a valid explanation for the ranking.
We can also plot the features with importance grading:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "hoverinfo": "none", + "opacity": 0.75, + "showscale": false, + "type": "heatmap", + "z": [ + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1, + 0.75, + 0.5, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "annotations": [ + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "#Rank", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "quality_score", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "text_match_score", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "page_views_score", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_id", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_text", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "rank", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "clicked", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_id", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_name", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "name_match", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 0, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "5", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.3681", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0306", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "6", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 1, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "2", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.301", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.3706", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0413", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "3", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 2, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.301", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.3667", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0825", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "4", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 3, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "4", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.301", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.3338", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0426", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 4, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "6", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.325", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0465", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "5", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 5, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "3", + "x": -0.45, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.301", + "x": 0.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "1.3627", + "x": 1.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0.0426", + "x": 2.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "query_5", + "x": 3.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "KNJNWV", + "x": 4.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "2", + "x": 5.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 6.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 7.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "domain_0", + "x": 8.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + }, + { + "align": "left", + "font": { + "color": "#000000" + }, + "showarrow": false, + "text": "0", + "x": 9.55, + "xanchor": "left", + "xref": "x", + "y": 6, + "yref": "y" + } + ], + "autosize": false, + "height": 260, + "margin": { + "b": 0, + "l": 0, + "r": 0, + "t": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "width": 1800, + "xaxis": { + "dtick": 1, + "gridwidth": 2, + "showticklabels": false, + "tick0": -0.5, + "ticks": "", + "zeroline": false + }, + "yaxis": { + "autorange": "reversed", + "dtick": 1, + "gridwidth": 2, + "showticklabels": false, + "tick0": 0.5, + "ticks": "", + "zeroline": false + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = explanation._ipython_figure(0)\n", + "fig.update_layout(autosize=False, width=1800)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ml4ir", + "language": "python", + "name": "ml4ir" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/notebooks/configs/activate_2020/feature_config.yaml b/python/notebooks/configs/activate_2020/feature_config.yaml index dae19294..ce4a3873 100644 --- a/python/notebooks/configs/activate_2020/feature_config.yaml +++ b/python/notebooks/configs/activate_2020/feature_config.yaml @@ -127,7 +127,7 @@ features: shape: null fn: categorical_embedding_with_vocabulary_file args: - vocabulary_file: '../ml4ir/applications/ranking/tests/data/config/domain_name_vocab_no_id.csv' + vocabulary_file: '../ml4ir/applications/ranking/tests/data/configs/domain_name_vocab_no_id.csv' embedding_size: 64 default_value: -1 num_oov_buckets: 1 diff --git a/python/optional_requirements.yaml b/python/optional_requirements.yaml index 6cc7d493..e9fd9658 100644 --- a/python/optional_requirements.yaml +++ b/python/optional_requirements.yaml @@ -21,6 +21,9 @@ # To install ml4ir `all`, run pip install ml4ir[all] all: - pyspark==3.0.1 # required to run ml4ir.base.pipeline + - omnixai==1.1.4 # required for running explanations demo. Upgrade to 1.1.5 when it is available pyspark: - pyspark==3.0.1 # required to support pyspark data read +explainer: + - omnixai==1.1.4 # required for running explanations demo. Upgrade to 1.1.5 when it is available # Add other optional ml4ir dependencies here \ No newline at end of file