diff --git a/pyxai/examples/DT/base.py b/pyxai/examples/DT/base.py new file mode 100644 index 0000000..6accb03 --- /dev/null +++ b/pyxai/examples/DT/base.py @@ -0,0 +1,23 @@ +from pyxai import Learning, Explainer + +learner = Learning.Scikitlearn("pyxai/tests/dermatology.csv", learner_type=Learning.CLASSIFICATION) +model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.DT) +instance, prediction = learner.get_instances(model, n=1, correct=True, predictions=[0]) + +print("le:", learner.dict_labels) + +explainer = Explainer.initialize(model, instance) +print("instance:", instance) +print("binary representation:", explainer.binary_representation) + +sufficient_reason = explainer.sufficient_reason(n=1) +print("sufficient_reason:", sufficient_reason) +print("to_features:", explainer.to_features(sufficient_reason)) + +instance, prediction = learner.get_instances(model, n=1, correct=False) +explainer.set_instance(instance) +contrastive_reason = explainer.contrastive_reason() +print("contrastive reason", contrastive_reason) +print("to_features:", explainer.to_features(contrastive_reason, contrastive=True)) + +explainer.visualisation.screen(instance, contrastive_reason, contrastive=True) \ No newline at end of file diff --git a/pyxai/sources/learning/learner.py b/pyxai/sources/learning/learner.py index f0cb87d..4a3e192 100644 --- a/pyxai/sources/learning/learner.py +++ b/pyxai/sources/learning/learner.py @@ -166,11 +166,21 @@ def create_dict_labels(self, labels): index = 0 self.dict_labels = OrderedDict() self.inverse_dict_labels = OrderedDict() - for p in labels: - if str(p) not in self.dict_labels: - self.dict_labels[str(p)] = index - self.inverse_dict_labels[index] = str(p) - index += 1 + set_labels = set(labels) + check_type_int = all(numpy.issubdtype(x, numpy.integer) or isinstance(x, int) or (isinstance(x, str) and x.isnumeric()) for x in set_labels) + if check_type_int is True: + for p in labels: + if str(p) not in self.dict_labels: + self.dict_labels[str(p)] = int(p) + self.inverse_dict_labels[int(p)] = str(p) + else: + for p in labels: + if str(p) not in self.dict_labels: + self.dict_labels[str(p)] = index + self.inverse_dict_labels[index] = str(p) + index += 1 + + """