diff --git a/pyxai/sources/core/tools/utils.py b/pyxai/sources/core/tools/utils.py index 5e361ea..901d9c2 100644 --- a/pyxai/sources/core/tools/utils.py +++ b/pyxai/sources/core/tools/utils.py @@ -12,10 +12,14 @@ from time import time from typing import Iterable +import sklearn from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error from pyxai.sources.core.structure.type import PreferredReasonMethod +from packaging import version +def check_sklearn_14(): + return version.parse(sklearn.__version__) > version.parse("1.4.0") def check_PyQt6(): ok, error = _check_PyQt6() @@ -52,7 +56,7 @@ class Metric: def compute_metrics_regression(labels, predictions): return { "mean_squared_error": mean_squared_error(labels, predictions), - "root_mean_squared_error": mean_squared_error(labels, predictions, squared=False), + "root_mean_squared_error": mean_squared_error(labels, predictions, squared=False) if check_sklearn_14() is False else sklearn.metrics.root_mean_squared_error(labels, predictions), "mean_absolute_error": mean_absolute_error(labels, predictions) } diff --git a/pyxai/sources/learning/learner.py b/pyxai/sources/learning/learner.py index 5b9cdb2..f201fd6 100644 --- a/pyxai/sources/learning/learner.py +++ b/pyxai/sources/learning/learner.py @@ -167,7 +167,7 @@ def create_dict_labels(self, labels): self.dict_labels = OrderedDict() self.inverse_dict_labels = OrderedDict() set_labels = set(labels) - check_type_int = all(numpy.issubdtype(x, numpy.integer) or isinstance(x, int) or (isinstance(x, str) and x.isnumeric()) for x in set_labels) + check_type_int = all((not isinstance(x, str) and numpy.issubdtype(x, numpy.integer)) or isinstance(x, int) or (isinstance(x, str) and x.isnumeric()) for x in set_labels) if check_type_int is True: for p in labels: if str(p) not in self.dict_labels: diff --git a/pyxai/tests/functionality/Metrics.py b/pyxai/tests/functionality/Metrics.py index 4719a50..286dff9 100644 --- a/pyxai/tests/functionality/Metrics.py +++ b/pyxai/tests/functionality/Metrics.py @@ -128,7 +128,7 @@ def do_import(self, dataset, learner_type): learner, models = Learning.import_models(sk_models) for i, model in enumerate(models): instances_details = learner.get_instances(model, dataset=dataset, indexes=Learning.TEST, test_indexes=test_indexes[i], details=True) - predictions = [learner.get_value_from_label(element["prediction"]) for element in instances_details] + predictions = [element["prediction"] for element in instances_details] true_values = [element["label"] for element in instances_details] metrics = learner.compute_metrics(true_values, predictions) self.assertTrue("micro_averaging_accuracy" in metrics.keys()) diff --git a/pyxai/tests/importing/ScikitLearn.py b/pyxai/tests/importing/ScikitLearn.py index d255294..0ec0863 100644 --- a/pyxai/tests/importing/ScikitLearn.py +++ b/pyxai/tests/importing/ScikitLearn.py @@ -59,9 +59,9 @@ def do_import(self, dataset, learner_type): instances = learner.get_instances(dataset=dataset, model=model, n=10) for (instance, prediction_classifier) in instances: - prediction_model_1 = learner.get_label_from_value(model.predict_instance(instance)) + prediction_model_1 = model.predict_instance(instance) implicant = model.instance_to_binaries(instance) - prediction_model_2 = learner.get_label_from_value(model.predict_implicant(implicant)) + prediction_model_2 = model.predict_implicant(implicant) self.assertEqual(str(prediction_model_1), str(prediction_model_2)) diff --git a/setup.py b/setup.py index e3f370a..3f01f07 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ classifiers=['Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Education'], packages=find_packages(), # exclude=["problems/g7_todo/"]), package_dir={'pyxai': 'pyxai'}, - install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex', 'ortools'], + install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex', 'ortools', 'packaging'], extras_require={ "gui": ['pyqt6'], },