Skip to content

Commit

Permalink
🐛 fix dataset names
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasBouche committed Aug 23, 2022
1 parent 7df3dcd commit bda2e64
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 18 deletions.
43 changes: 27 additions & 16 deletions eurybia/core/smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import pickle
import shutil
import tempfile
import warnings
from pathlib import Path
from typing import Dict, Text

Expand Down Expand Up @@ -290,7 +289,7 @@ def compile(
loss_function=hyperparameter["loss_function"],
eval_metric=hyperparameter["eval_metric"],
task_type="CPU",
allow_writing_files=False
allow_writing_files=False,
)

datadrift_classifier = datadrift_classifier.fit(train_pool_cat, eval_set=test_pool_cat, silent=True)
Expand Down Expand Up @@ -318,8 +317,12 @@ def compile(
self.pb_cols, self.err_mods = pb_cols, err_mods
if self.deployed_model is not None:
self.js_divergence = compute_js_divergence(
self.df_predict.loc[lambda df: df["dataset"] == "Baseline dataset", :]["Score"].values,
self.df_predict.loc[lambda df: df["dataset"] == "Current dataset", :]["Score"].values,
self.df_predict.loc[lambda df: df["dataset"] == self.dataset_names["df_baseline"].values[0], :][
"Score"
].values,
self.df_predict.loc[lambda df: df["dataset"] == self.dataset_names["df_current"].values[0], :][
"Score"
].values,
n_bins=20,
)

Expand Down Expand Up @@ -405,22 +408,27 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()
new_cols = [c for c in self.df_baseline.columns if c not in self.df_current.columns]
removed_cols = [c for c in self.df_current.columns if c not in self.df_baseline.columns]
if len(new_cols) > 0:
print(f"""The following variables are no longer available in the
current dataset and will not be analyzed: \n {new_cols}""")
print(
f"""The following variables are no longer available in the
current dataset and will not be analyzed: \n {new_cols}"""
)
if len(removed_cols) > 0:
print(f"""The following variables are only available in the
current dataset and will not be analyzed: \n {removed_cols}""")
print(
f"""The following variables are only available in the
current dataset and will not be analyzed: \n {removed_cols}"""
)
common_cols = [c for c in self.df_current.columns if c in self.df_baseline.columns]
# dtypes
err_dtypes = [
c for c in common_cols if self.df_baseline.dtypes.map(str)[c] != self.df_current.dtypes.map(str)[c]
]
if len(err_dtypes) > 0:
print(f"""The following variables have mismatching dtypes
and will not be analyzed: \n {err_dtypes}""")
print(
f"""The following variables have mismatching dtypes
and will not be analyzed: \n {err_dtypes}"""
)
# Feature values
err_mods: Dict[Text, Dict] = {}
variables_mm_mods = []
if full_validation is True:
invalid_cols = ignore_cols + new_cols + removed_cols + err_dtypes
for column in self.df_baseline.columns:
Expand All @@ -433,8 +441,10 @@ def _analyze_consistency(self, full_validation=False, ignore_cols: list = list()
err_mods[column] = {}
err_mods[column]["New distinct values"] = new_mods
err_mods[column]["Removed distinct values"] = removed_mods
print(f"""The variable {column} has mismatching unique values:
{new_mods} | {removed_mods}\n""")
print(
f"""The variable {column} has mismatching unique values:
{new_mods} | {removed_mods}\n"""
)
return ({"New columns": new_cols, "Removed columns": removed_cols, "Type errors": err_dtypes}, err_mods)

def _predict(self, deployed_model=None, encoding=None):
Expand Down Expand Up @@ -716,13 +726,14 @@ def _compute_datadrift_stat_test(self, max_size=50000, categ_max=20):
test = ksmirnov_test(current[features].to_numpy(), baseline[features].to_numpy())
except BaseException as e:
raise Exception(
"""
"""
There is a problem with the format of {} column between the two datasets.
Error:
""".format(
str(features)
str(features)
)
+ str(e)
)
+ str(e))
test_results[features] = test

return pd.DataFrame.from_dict(test_results, orient="index")
Expand Down
4 changes: 2 additions & 2 deletions eurybia/core/smartplotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def generate_modeldrift_data(
template: Optional[str] = None,
title: Optional[str] = None,
xaxis_title: Optional[str] = None,
yaxis_title: Optional[str] = None,
yaxis_title: Optional[dict] = None,
xaxis: Optional[str] = None,
height: Optional[str] = None,
width: Optional[str] = None,
Expand All @@ -578,7 +578,7 @@ def generate_modeldrift_data(
Plot title
xaxis_title: str, optional
X axis title
yaxis_title: str, optional
yaxis_title: dict, optional
y axis title
xaxis: str, optional
X axis options (spike line, margin, range ...)
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/core/test_smartdrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ def test_compile_model_encoder(self):
smart_drift.compile()
assert isinstance(smart_drift.xpl, shapash.explainer.smart_explainer.SmartExplainer)

def test_compile_dataset_names(self):
"""
test compile() with a model and an encoder specified
"""
smart_drift = SmartDrift(
self.titanic_df_1,
self.titanic_df_2,
deployed_model=self.rf,
encoding=self.categ_encoding,
dataset_names={"df_current": "titanic 2", "df_baseline": "titanic 1"},
)
smart_drift.compile()
assert isinstance(smart_drift.xpl, shapash.explainer.smart_explainer.SmartExplainer)

def test_generate_report_fullvalid(self):
"""
test generate_report() with fullvalidation option specified to True
Expand Down

0 comments on commit bda2e64

Please sign in to comment.