diff --git a/shapiq/explainer/tree/base.py b/shapiq/explainer/tree/base.py index 89cedc09..d852a598 100644 --- a/shapiq/explainer/tree/base.py +++ b/shapiq/explainer/tree/base.py @@ -87,8 +87,10 @@ def __post_init__(self) -> None: self.leaf_mask = np.asarray(self.children_left == -1) # sanitize features self.features = np.where(self.leaf_mask, -2, self.features) + self.features = self.features.astype(int) # make features integer type # sanitize thresholds self.thresholds = np.where(self.leaf_mask, np.nan, self.thresholds) + # self.thresholds = np.round(self.thresholds, 4) # round thresholds # setup empty prediction if self.empty_prediction is None: self.compute_empty_prediction() @@ -118,6 +120,13 @@ def __post_init__(self) -> None: # setup new feature mapping if self.feature_map_internal_original is None: self.feature_map_internal_original = {i: i for i in unique_features} + # flatten values if necessary + if self.values.ndim > 1: + if self.values.shape[1] != 1: + raise ValueError("Values array has more than one column.") + self.values = self.values.flatten() + # set all values of non leaf nodes to zero + self.values[~self.leaf_mask] = 0 def reduce_feature_complexity(self) -> None: """Reduces the feature complexity of the tree model. diff --git a/shapiq/explainer/tree/conversion/xgboost.py b/shapiq/explainer/tree/conversion/xgboost.py index abc70c99..0e29aea8 100644 --- a/shapiq/explainer/tree/conversion/xgboost.py +++ b/shapiq/explainer/tree/conversion/xgboost.py @@ -1,6 +1,7 @@ """Functions for converting xgboost decision trees to the format used by shapiq.""" +import warnings from typing import Optional import numpy as np @@ -25,20 +26,23 @@ def convert_xgboost_booster( Returns: The converted xgboost booster. 
""" + try: + intercept = tree_booster.base_score + if intercept is None: + intercept = float(tree_booster.intercept_[0]) + tree_booster = tree_booster.get_booster() + except AttributeError: + intercept = 0.0 + warnings.warn( + "The model does not have a valid base score. Setting the intercept to 0.0." + "Baseline values of the interaction models might be different." + ) + # https://github.com/shap/shap/blob/77e92c3c110e816b768a0ec2acfbf4cc08ee13db/shap/explainers/_tree.py#L1992 scaling = 1.0 booster_df = tree_booster.trees_to_dataframe() output_type = "raw" - if len(booster_df["Tree"].unique()) > tree_booster.num_boosted_rounds(): - # choose only trees for the selected class (xgboost grows n_estimators*n_class trees) - # approximation for the number of classes in xgboost - n_class = int(len(booster_df["Tree"].unique()) / tree_booster.num_boosted_rounds()) - if class_label is None: - class_label = 0 - idc = booster_df["Tree"] % n_class == class_label - booster_df = booster_df.loc[idc, :] - # if tree_booster.feature_names: feature_names = tree_booster.feature_names else: @@ -52,14 +56,29 @@ def convert_xgboost_booster( booster_df.loc[:, "Feature"] = booster_df.loc[:, "Feature"].replace( convert_feature_str_to_int ) + + if len(booster_df["Tree"].unique()) > tree_booster.num_boosted_rounds(): + # choose only trees for the selected class (xgboost grows n_estimators*n_class trees) + # approximation for the number of classes in xgboost + n_class = int(len(booster_df["Tree"].unique()) / tree_booster.num_boosted_rounds()) + if class_label is None: + class_label = 0 + idc = booster_df["Tree"] % n_class == class_label + booster_df = booster_df.loc[idc, :] + + n_trees = len(booster_df["Tree"].unique()) + intercept /= n_trees return [ - _convert_xgboost_tree_as_df(tree_df=tree_df, output_type=output_type, scaling=scaling) + _convert_xgboost_tree_as_df( + tree_df=tree_df, intercept=intercept, output_type=output_type, scaling=scaling + ) for _, tree_df in 
booster_df.groupby("Tree") ] def _convert_xgboost_tree_as_df( tree_df: Model, + intercept: float, output_type: str, scaling: float = 1.0, ) -> TreeModel: @@ -77,7 +96,8 @@ def _convert_xgboost_tree_as_df( # pandas can't chill https://stackoverflow.com/q/77900971 with pd.option_context("future.no_silent_downcasting", True): - return TreeModel( + values = tree_df["Gain"].values * scaling + intercept # add intercept to all values + tree_model = TreeModel( children_left=tree_df["Yes"] .replace(convert_node_str_to_int) .infer_objects(copy=False) @@ -92,8 +112,10 @@ def _convert_xgboost_tree_as_df( .values, features=tree_df["Feature"].values, thresholds=tree_df["Split"].values, - values=tree_df["Gain"].values * scaling, # values in non-leaf nodes are not used + values=values, # values in non-leaf nodes are not used node_sample_weight=tree_df["Cover"].values, empty_prediction=None, original_output_type=output_type, ) + + return tree_model diff --git a/shapiq/explainer/tree/explainer.py b/shapiq/explainer/tree/explainer.py index 836e3fc4..f8e93a2d 100644 --- a/shapiq/explainer/tree/explainer.py +++ b/shapiq/explainer/tree/explainer.py @@ -71,11 +71,7 @@ def __init__( self._treeshapiq_explainers: list[TreeSHAPIQ] = [ TreeSHAPIQ(model=_tree, max_order=self._max_order, index=index) for _tree in self._trees ] - - # TODO: for the current implementation this is correct for other trees this may vary - self.baseline_value = sum( - [treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers] - ) + self.baseline_value = self._compute_baseline_value() def explain(self, x: np.ndarray) -> InteractionValues: # run treeshapiq for all trees @@ -90,3 +86,17 @@ def explain(self, x: np.ndarray) -> InteractionValues: for i in range(1, len(interaction_values)): final_explanation += interaction_values[i] return final_explanation + + def _compute_baseline_value(self) -> float: + """Computes the baseline value for the explainer. 
+ + The baseline value is the sum of the empty predictions of all trees in the ensemble. + + Returns: + The baseline value for the explainer. + """ + + baseline_value = sum( + [treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers] + ) + return baseline_value diff --git a/shapiq/explainer/tree/treeshapiq.py b/shapiq/explainer/tree/treeshapiq.py index 4375e63b..07d7bfc2 100644 --- a/shapiq/explainer/tree/treeshapiq.py +++ b/shapiq/explainer/tree/treeshapiq.py @@ -102,9 +102,13 @@ def __init__( self._edge_tree: EdgeTree = copy.deepcopy(edge_tree) # compute the empty prediction - self.empty_prediction: float = float( + computed_empty_prediction = float( np.sum(self._edge_tree.empty_predictions[self._tree.leaf_mask]) ) + tree_empty_prediction = self._tree.empty_prediction + if tree_empty_prediction is None: + tree_empty_prediction = computed_empty_prediction + self.empty_prediction: float = tree_empty_prediction # stores the interaction scores up to a given order self.subset_ancestors_store: dict = {} diff --git a/shapiq/explainer/tree/validation.py b/shapiq/explainer/tree/validation.py index 79cda334..09f5cdf4 100644 --- a/shapiq/explainer/tree/validation.py +++ b/shapiq/explainer/tree/validation.py @@ -42,6 +42,8 @@ def validate_tree_model( # tree model (is already in the correct format) if type(model).__name__ == "TreeModel": tree_model = model + elif isinstance(model, list) and all([type(m).__name__ == "TreeModel" for m in model]): + tree_model = model # dict as model is parsed to TreeModel (the dict needs to have the correct format and names) elif type(model).__name__ == "dict": tree_model = TreeModel(**model) @@ -73,8 +75,6 @@ def validate_tree_model( elif safe_isinstance(model, "xgboost.sklearn.XGBRegressor") or safe_isinstance( model, "xgboost.sklearn.XGBClassifier" ): - tree_model = convert_xgboost_booster(model.get_booster(), class_label=class_label) - elif safe_isinstance(model, "xgboost.core.Booster"): tree_model = 
convert_xgboost_booster(model, class_label=class_label) # unsupported model else: diff --git a/shapiq/explainer/utils.py b/shapiq/explainer/utils.py index 1c328126..c5e6ba30 100644 --- a/shapiq/explainer/utils.py +++ b/shapiq/explainer/utils.py @@ -90,6 +90,9 @@ def get_predict_function_and_model_type(model, model_class): elif isinstance(model, tree.TreeModel): # test scenario _predict_function = model.compute_empty_prediction _model_type = "tree" + elif isinstance(model, list) and all([isinstance(m, tree.TreeModel) for m in model]): + _predict_function = model[0].compute_empty_prediction + _model_type = "tree" elif _predict_function is None: raise TypeError( f"`model` is of unsupported type: {model_class}.\n" diff --git a/tests/conftest.py b/tests/conftest.py index 2fa77af4..b9143f50 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,6 +85,17 @@ def rf_clf_model() -> RandomForestClassifier: return model +@pytest.fixture +def xgb_reg_model(): + """Return a simple xgboost regression model.""" + from xgboost import XGBRegressor + + X, y = make_regression(n_samples=100, n_features=7, random_state=42) + model = XGBRegressor(random_state=42, n_estimators=3) + model.fit(X, y) + return model + + @pytest.fixture def rf_clf_binary_model() -> RandomForestClassifier: """Return a simple random forest model.""" @@ -102,6 +113,25 @@ def rf_clf_binary_model() -> RandomForestClassifier: return model +@pytest.fixture +def xgb_clf_model(): + """Return a simple xgboost classification model.""" + from xgboost import XGBClassifier + + X, y = make_classification( + n_samples=100, + n_features=7, + random_state=42, + n_classes=3, + n_informative=7, + n_repeated=0, + n_redundant=0, + ) + model = XGBClassifier(random_state=42, n_estimators=3) + model.fit(X, y) + return model + + @pytest.fixture def background_reg_data() -> np.ndarray: """Return a simple background dataset.""" diff --git a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py 
b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py index 8fb37316..c7654425 100644 --- a/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py +++ b/tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py @@ -1,5 +1,7 @@ """This test module contains all tests for the tree explainer module of the shapiq package.""" +import copy + import numpy as np import pytest @@ -121,3 +123,122 @@ def test_against_shap_implementation(): with pytest.warns(UserWarning): _ = TreeExplainer(model=tree_model, max_order=2, min_order=1, index="SV") + + +def test_xgboost_reg(xgb_reg_model, background_reg_data): + """Tests the shapiq implementation of TreeSHAP against SHAP's implementation for XGBoost.""" + + explanation_instance = 0 + + # the following code is used to get the shap values from the SHAP implementation + # import shap + # explainer_shap = shap.TreeExplainer(model=xgb_reg_model) + # x_explain_shap = background_reg_data[explanation_instance].reshape(1, -1) + # sv_shap = explainer_shap.shap_values(x_explain_shap)[0] + sv_shap = [-2.555832, 28.50987, 1.7708225, -7.8653603, 10.7955885, -0.1877861, 4.549199] + sv_shap = np.asarray(sv_shap) + + # compute with shapiq + explainer_shapiq = TreeExplainer(model=xgb_reg_model, max_order=1, index="SV") + x_explain_shapiq = background_reg_data[explanation_instance] + sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq) + sv_shapiq_values = sv_shapiq.get_n_order_values(1) + baseline_shapiq = sv_shapiq.baseline_value + + assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5) + + # get prediction of the model + prediction = xgb_reg_model.predict(x_explain_shapiq.reshape(1, -1)) + assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5) + + +def test_xgboost_clf(xgb_clf_model, background_clf_data): + """Tests the shapiq implementation of TreeSHAP against SHAP's implementation for XGBoost.""" + + explanation_instance = 1 + class_label = 1 + + # the following code is used
to get the shap values from the SHAP implementation + # import shap + # model_copy = copy.deepcopy(xgb_clf_model) + # explainer_shap = shap.TreeExplainer(model=model_copy) + # baseline_shap = float(explainer_shap.expected_value[class_label]) + # print(baseline_shap) + # x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1)) + # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap) + # sv_shap = sv_shap_all_classes[0][:, class_label] + # print(sv_shap) + sv = [-0.00545454, -0.15837783, -0.17675081, -0.24213657, 0.00247543, 0.00988865, -0.01564346] + sv_shap = np.array(sv) + + # compute with shapiq + explainer_shapiq = TreeExplainer( + model=xgb_clf_model, max_order=1, index="SV", class_label=class_label + ) + x_explain_shapiq = copy.deepcopy(background_clf_data[explanation_instance]) + sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq) + sv_shapiq_values = sv_shapiq.get_n_order_values(1) + baseline_shapiq = sv_shapiq.baseline_value + + # assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4) + assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5) + + # get prediction of the model (as the log odds) + prediction = xgb_clf_model.predict(x_explain_shapiq.reshape(1, -1), output_margin=True)[0][ + class_label + ] + assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5) + + +def test_xgboost_shap_error(xgb_clf_model, background_clf_data): + """Tests for the strange behavior of SHAP's XGBoost implementation. + + The test is used to show that the shapiq implementation is correct and the SHAP implementation + is doing something weird. For some instances (e.g. the one used in this test) the SHAP values + are different from the shapiq values. However, when we round the `thresholds` of the xgboost + trees in shapiq, then the computed explanations match. 
This is a strange behavior as rounding + the thresholds makes the model less true to the original model but only then the explanations + match. + """ + + explanation_instance = 0 + class_label = 1 + + # get the shap explanations (the following code is used to get SVs from SHAP) + # import shap + # model_copy = copy.deepcopy(xgb_clf_model) + # explainer_shap = shap.TreeExplainer(model=model_copy) + # baseline_shap = float(explainer_shap.expected_value[class_label]) + # x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1)) + # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap) + # sv_shap = sv_shap_all_classes[0][:, class_label] + # print(sv_shap) + # print(baseline_shap) + sv = [-0.00163636, 0.05099502, -0.13182959, -0.44538185, 0.00428653, -0.04872373, -0.01370917] + sv_shap = np.array(sv) + + # setup shapiq TreeSHAP + explainer_shapiq = TreeExplainer( + model=xgb_clf_model, max_order=1, index="SV", class_label=class_label + ) + x_explain_shapiq = copy.deepcopy(background_clf_data[explanation_instance]) + sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq) + sv_shapiq_values = sv_shapiq.get_n_order_values(1) + + # the SHAP sv values should be different from the shapiq values + assert not np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5) + + # when we round the model thresholds of the xgb model (thresholds decide whether a feature is + # used or not) -> then suddenly the shap and shapiq values are the same, which points to the + # fact that the shapiq implementation is correct + explainer_shapiq_rounded = TreeExplainer( + model=xgb_clf_model, max_order=1, index="SV", class_label=class_label + ) + for tree_explainer in explainer_shapiq_rounded._treeshapiq_explainers: + tree_explainer._tree.thresholds = np.round(tree_explainer._tree.thresholds, 4) + x_explain_shapiq_rounded = copy.deepcopy(background_clf_data[explanation_instance]) + sv_shapiq_rounded = explainer_shapiq_rounded.explain(x=x_explain_shapiq_rounded) +
sv_shapiq_rounded_values = sv_shapiq_rounded.get_n_order_values(1) + + # now the values surprisingly are the same + assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-5)