Skip to content

Commit

Permalink
Test errors solved
Browse files Browse the repository at this point in the history
  • Loading branch information
szczepanskiNicolas committed Feb 22, 2024
1 parent d52da4f commit dad6deb
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
6 changes: 5 additions & 1 deletion pyxai/sources/core/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
from time import time
from typing import Iterable

import sklearn
from sklearn.metrics import confusion_matrix, mean_squared_error, mean_absolute_error

from pyxai.sources.core.structure.type import PreferredReasonMethod
from packaging import version

def check_sklearn_14():
return version.parse(sklearn.__version__) > version.parse("1.4.0")

def check_PyQt6():
ok, error = _check_PyQt6()
Expand Down Expand Up @@ -52,7 +56,7 @@ class Metric:
def compute_metrics_regression(labels, predictions):
return {
"mean_squared_error": mean_squared_error(labels, predictions),
"root_mean_squared_error": mean_squared_error(labels, predictions, squared=False),
"root_mean_squared_error": mean_squared_error(labels, predictions, squared=False) if check_sklearn_14() is False else sklearn.metrics.root_mean_squared_error(labels, predictions),
"mean_absolute_error": mean_absolute_error(labels, predictions)
}

Expand Down
2 changes: 1 addition & 1 deletion pyxai/sources/learning/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def create_dict_labels(self, labels):
self.dict_labels = OrderedDict()
self.inverse_dict_labels = OrderedDict()
set_labels = set(labels)
check_type_int = all(numpy.issubdtype(x, numpy.integer) or isinstance(x, int) or (isinstance(x, str) and x.isnumeric()) for x in set_labels)
check_type_int = all((not isinstance(x, str) and numpy.issubdtype(x, numpy.integer)) or isinstance(x, int) or (isinstance(x, str) and x.isnumeric()) for x in set_labels)
if check_type_int is True:
for p in labels:
if str(p) not in self.dict_labels:
Expand Down
2 changes: 1 addition & 1 deletion pyxai/tests/functionality/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def do_import(self, dataset, learner_type):
learner, models = Learning.import_models(sk_models)
for i, model in enumerate(models):
instances_details = learner.get_instances(model, dataset=dataset, indexes=Learning.TEST, test_indexes=test_indexes[i], details=True)
predictions = [learner.get_value_from_label(element["prediction"]) for element in instances_details]
predictions = [element["prediction"] for element in instances_details]
true_values = [element["label"] for element in instances_details]
metrics = learner.compute_metrics(true_values, predictions)
self.assertTrue("micro_averaging_accuracy" in metrics.keys())
Expand Down
4 changes: 2 additions & 2 deletions pyxai/tests/importing/ScikitLearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def do_import(self, dataset, learner_type):
instances = learner.get_instances(dataset=dataset, model=model, n=10)

for (instance, prediction_classifier) in instances:
prediction_model_1 = learner.get_label_from_value(model.predict_instance(instance))
prediction_model_1 = model.predict_instance(instance)
implicant = model.instance_to_binaries(instance)
prediction_model_2 = learner.get_label_from_value(model.predict_implicant(implicant))
prediction_model_2 = model.predict_implicant(implicant)

self.assertEqual(str(prediction_model_1), str(prediction_model_2))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
classifiers=['Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Education'],
packages=find_packages(), # exclude=["problems/g7_todo/"]),
package_dir={'pyxai': 'pyxai'},
install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex', 'ortools'],
install_requires=['lxml', 'numpy', 'wheel', 'pandas', 'termcolor', 'shap', 'wordfreq', 'python-sat[pblib,aiger]', 'xgboost==1.7.3', 'pycsp3', 'matplotlib', 'dill', 'lightgbm', 'docplex', 'ortools', 'packaging'],
extras_require={
"gui": ['pyqt6'],
},
Expand Down

0 comments on commit dad6deb

Please sign in to comment.