diff --git a/pipelineservice/explain/_explain.py b/pipelineservice/explain/_explain.py index 791993f..94cb00c 100644 --- a/pipelineservice/explain/_explain.py +++ b/pipelineservice/explain/_explain.py @@ -15,9 +15,11 @@ def __init__(self, self.y = y def get_estimator(self): try: - return self.pipeline[-1] + pipeline = self.pipeline[-1] except: - return self.pipeline + pipeline = self.pipeline + sklearn.utils.validation.check_is_fitted(pipeline) + return pipeline def get_xtransform(self): try: X_transform = self.pipeline[:-1].transform(self.X) @@ -51,8 +53,8 @@ def plot_feature_importance(self, top_k = None): def get_permutation_importance(self, n_jobs = -1, **kwargs): - permutation = permutation_importance(estimator = self.pipeline, - X = self.X, + permutation = permutation_importance(estimator = self.get_estimator(), + X = self.get_xtransform(), y = self.y, n_jobs = n_jobs, **kwargs @@ -75,13 +77,16 @@ def plot_permutation_importance(self, top_k = None): return facegrid.ax def plot_partial_dependence(self, features, + n_jobs = -1, **kwargs): fig = plot_partial_dependence(estimator = self.get_estimator(), X = self.get_xtransform(), - features = features + features = features, + n_jobs = n_jobs, + **kwargs ) return fig - def shapley_importance(self): + def plot_shap_importance(self): explainer = shap.TreeExplainer(self.get_estimator()) shap_values = explainer.shap_values(self.get_xtransform()) return shap.summary_plot(shap_values, self.get_xtransform(), plot_type = "bar")