
Commit

Typos fixed
szczepanskiNicolas committed Feb 21, 2024
1 parent e93f1c9 commit d8d61a2
Showing 8 changed files with 136 additions and 44 deletions.
17 changes: 12 additions & 5 deletions pyxai/sources/core/explainer/Explainer.py
@@ -260,12 +260,12 @@ def set_features_type(self, features_types):
nCategorical = len(self._reg_exp_categorical_features.keys())

Tools.verbose("--------- Theory Feature Types -----------")
Tools.verbose("Before the encoding (without one hot encoded features), we have:")
Tools.verbose("Before the one-hot encoding of categorical features:")
Tools.verbose("Numerical features:", nNumerical)
Tools.verbose("Categorical features:", nCategorical)
Tools.verbose("Binary features:", nBinaries)
Tools.verbose("Number of features:", nNumerical + nCategorical + nBinaries)
Tools.verbose("Values of categorical features:", self._values_categorical_features)
Tools.verbose("Characteristics of categorical features:", self._values_categorical_features)
used_features = set()
used_features_without_one_hot_encoded = set()
for key in model.map_features_to_id_binaries.keys():
@@ -274,9 +274,9 @@ def set_features_type(self, features_types):
used_features.add(key[0])
used_features_without_one_hot_encoded.add(self.map_indexes[key[0]])
Tools.verbose("")
Tools.verbose("Number of used features in the model (before the encoding):",
Tools.verbose("Number of used features in the model (before the encoding of categorical features):",
len(used_features_without_one_hot_encoded))
Tools.verbose("Number of used features in the model (after the encoding):", len(used_features))
Tools.verbose("Number of used features in the model (after the encoding of categorical features):", len(used_features))
Tools.verbose("----------------------------------------------")

def _theory_clauses(self):
@@ -598,7 +598,14 @@ def _most_anchored_reason(self, *, n_variables, cnf, time_limit=None, check=Fals
self.last_n_anchors = 0

return None if previous_reason is None else Explainer.format(previous_reason)


def simplify_theory(self, tree):
if self._theory is True:
solver = GlucoseSolver()
theory_cnf = self.get_model().get_theory(None)
return solver.symplify_theory(tree, theory_cnf)
return tree

@property
def visualisation(self):
"""This object allows to open gui, save images, and so on
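For context on the "Theory Feature Types" verbose block above: the counts before and after the one-hot encoding come from the feature-type mapping handed to the explainer. A minimal sketch, reusing the compas mapping and calls from test_rectify_5 added at the end of this commit (the column names are specific to that dataset):

from pyxai import Learning, Explainer

# Feature-type declaration: the theory needs to know which binary columns
# come from the same one-hot encoded categorical feature ("Origin*", "Age*").
compas_types = {
    "numerical": ["Number_of_Priors"],
    "binary": ["Misdemeanor", "score_factor", "Female"],
    "categorical": {"Origin*": ["African_American", "Asian", "Hispanic", "Native_American", "Other"],
                    "Age*": ["Above_FourtyFive", "Below_TwentyFive"]}
}

learner = Learning.Scikitlearn("tests/compas.csv", learner_type=Learning.CLASSIFICATION)
model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.RF)
dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)

# Passing features_type here is what is expected to trigger set_features_type
# and therefore the "Theory Feature Types" output shown above.
explainer = Explainer.initialize(model, instance=dict_information["instance"], features_type=compas_types)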
36 changes: 11 additions & 25 deletions pyxai/sources/core/explainer/explainerDT.py
@@ -303,18 +303,7 @@ def is_reason(self, reason, *, n_samples=-1):
return self._tree.is_implicant(reason, self.target_prediction)


def simplify_theory(self, tree_rectified):
if self._theory is True:
solver = GlucoseSolver()
#max_id_binary_cnf = CNFencoding.compute_max_id_variable(tree_rectified)
#theory_cnf, _ = self._tree.get_theory(
# self._binary_representation,
# theory_type=TypeTheory.NEW_VARIABLES,
# id_new_var=max_id_binary_cnf)
theory_cnf = self._tree.get_theory(None)
tree_rectified = solver.symplify_theory(tree_rectified, theory_cnf)
return tree_rectified
return tree_rectified


def rectify(self, *, conditions, label):
"""
@@ -328,29 +317,26 @@ def rectify(self, *, conditions, label):
DecisionTree: The rectified tree.
"""
Tools.verbose("")
Tools.verbose("-------------- Rectify information:")
Tools.verbose("-------------- Rectification information:")
tree_decision_rule = self._tree.decision_rule_to_tree(conditions)
Tools.verbose("Desision Rule Number of nodes:", tree_decision_rule.n_nodes())
Tools.verbose("Model Number of nodes:", self._tree.n_nodes())
Tools.verbose("Classification Rule - Number of nodes:", tree_decision_rule.n_nodes())
Tools.verbose("Model - Number of nodes:", self._tree.n_nodes())
if label == 1:
# When label is 1, we have to negate the decision rule and take the disjunction of the two trees.
tree_decision_rule = tree_decision_rule.negating_tree()
tree_rectified = self._tree.disjoint_tree(tree_decision_rule)
Tools.verbose("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
tree_rectified = self.simplify_theory(tree_rectified)
Tools.verbose("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
tree_rectified.simplify()
Tools.verbose("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
elif label == 0:
# When label is 0, we have to concatenate the two trees.
tree_rectified = self._tree.concatenate_tree(tree_decision_rule)
Tools.verbose("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
self.simplify_theory(tree_rectified)
Tools.verbose("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
tree_rectified.simplify()
Tools.verbose("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
else:
raise NotImplementedError("Multiclasses is in progress.")

Tools.verbose("Model - Number of nodes (after rectification):", tree_rectified.n_nodes())
tree_rectified = self.simplify_theory(tree_rectified)
Tools.verbose("Model - Number of nodes (after simplification using the theory):", tree_rectified.n_nodes())
tree_rectified.simplify()
Tools.verbose("Model - Number of nodes (after elimination of redundant nodes):", tree_rectified.n_nodes())

self._tree = tree_rectified
if self._instance is not None:
self.set_instance(self._instance)
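To make the new "Rectification information" messages above concrete, here is a hedged usage sketch for the decision-tree case, continuing from the compas_types example shown after Explainer.py; Learning.DT and sufficient_reason() are assumed from the PyXAI API and do not appear in this diff.

# Hedged sketch: rectify a decision tree so that every instance covered by
# the explanation `reason` is classified 1 (mirrors test_rectify_5 below,
# which exercises the random forest variant).
model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.DT)
dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)
instance = dict_information["instance"]

explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)
reason = explainer.sufficient_reason()

# rectify() prints the verbose block above and returns the rectified,
# theory-simplified tree.
model = explainer.rectify(conditions=reason, label=1)
assert model.predict_instance(instance) == 1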
48 changes: 47 additions & 1 deletion pyxai/sources/core/explainer/explainerRF.py
@@ -12,7 +12,7 @@
from pyxai.sources.solvers.MUS.MUSERSolver import MUSERSolver
from pyxai.sources.solvers.MUS.OPTUXSolver import OPTUXSolver
from pyxai.sources.solvers.SAT.glucoseSolver import GlucoseSolver

from pyxai import Tools

class ExplainerRF(Explainer):

@@ -537,3 +537,49 @@ def most_anchored_reason(self, *, time_limit=None, check=False, type_references=
raise NotImplementedError("The anchored_reason() method for RF works only with binary-class datasets.")


def rectify(self, *, conditions, label):
"""
Rectify the Decision Tree (self._tree) of the explainer according to a `conditions` and a `label`.
Simplify the model (the theory can help to eliminate some nodes).
Args:
decision_rule (list or tuple): A decision rule in the form of list of literals (binary variables representing the conditions of the tree).
label (int): The label of the decision rule.
Returns:
DecisionTree: The rectified tree.
"""
Tools.verbose("")
Tools.verbose("-------------- Rectification information:")
tree_decision_rule = self._random_forest.forest[0].decision_rule_to_tree(conditions)

Tools.verbose("Classification Rule - Number of nodes:", tree_decision_rule.n_nodes())
Tools.verbose("Model - Number of nodes:", self._random_forest.n_nodes())

for i, tree in enumerate(self._random_forest.forest):
Tools.verbose("Model - Number of nodes "+str(i)+":"+ str(tree.n_nodes()))
for i, tree in enumerate(self._random_forest.forest):
if label == 1:
# When label is 1, we have to negate the decision rule and take the disjunction of the two trees.
# Bug here: the decision rule tree is negated again at every iteration of this loop.
tree_decision_rule = tree_decision_rule.negating_tree()
self._random_forest.forest[i] = tree.disjoint_tree(tree_decision_rule)
elif label == 0:
# When label is 0, we have to concatenate the two trees.
self._random_forest.forest[i] = tree.concatenate_tree(tree_decision_rule)
else:
raise NotImplementedError("Multiclasses is in progress.")

Tools.verbose("Model - Number of nodes (after rectification):", self._random_forest.n_nodes())

#for i, tree in enumerate(self._random_forest.forest):
# self._random_forest.forest[i] = self.simplify_theory(tree)
#Tools.verbose("Model - Number of nodes (after simplification using the theory):", self._random_forest.n_nodes())

#for i, tree in enumerate(self._random_forest.forest):
# tree.simplify()
#Tools.verbose("Model - Number of nodes (after elimination of redundant nodes):", self._random_forest.n_nodes())

if self._instance is not None:
self.set_instance(self._instance)
Tools.verbose("--------------")
return self._random_forest
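Since rectify() is applied to every tree of the forest, the intent is that the ensemble's majority vote (implemented by RandomForest.predict_instance in the next file) also yields the requested label on covered instances. A small, self-contained sketch of that voting step; majority_vote is a hypothetical helper, not part of PyXAI:

import numpy

def majority_vote(forest, instance, n_classes=2):
    # Each (rectified) tree votes for a class; the most voted class wins,
    # mirroring RandomForest.predict_instance below.
    n_votes = numpy.zeros(n_classes)
    for tree in forest:
        n_votes[tree.predict_instance(instance)] += 1
    return int(numpy.argmax(n_votes))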
26 changes: 16 additions & 10 deletions pyxai/sources/core/structure/decisionTree.py
@@ -75,11 +75,11 @@ def to_tuples(self, node, for_cpp=False):


def simplify(self):
while self._simplify(self.root):
while self._simplify(self.root, self.root):
pass


def _simplify(self, node, path=[], come_from=None, previous_node=None, previous_previous_node=None):
def _simplify(self, root, node, path=[], come_from=None, previous_node=None, previous_previous_node=None):
res_1 = False
res_2 = False
change = False
@@ -89,24 +89,30 @@ def _simplify(self, node, path=[], come_from=None, previous_node=None, previous_
new_tuple = (self.get_id_variable(previous_node), come_from)
if new_tuple in path:
if path[-1][1] == 0:
previous_previous_node.left = node
if previous_previous_node is not None:
previous_previous_node.left = node
change = True
elif path[-1][1] == 1:
previous_previous_node.right = node
change = True
if previous_previous_node is not None:
previous_previous_node.right = node
change = True
path.append(new_tuple)

# print("path:", path)
if not node.is_leaf():
raw = self.to_tuples(node)
if raw[1] == raw[2]:
if come_from == 0:
previous_node.left = node.left
if previous_node is not None:
previous_node.left = node.left
change = True
if come_from == 1:
previous_node.right = node.right
change = True
if previous_node is not None:
previous_node.right = node.right
change = True
pp = previous_node
res_1 = self._simplify(node.left, copy.deepcopy(path), come_from=0, previous_node=node, previous_previous_node=pp)
res_2 = self._simplify(node.right, copy.deepcopy(path), come_from=1, previous_node=node, previous_previous_node=pp)
res_1 = self._simplify(root, node.left, copy.deepcopy(path), come_from=0, previous_node=node, previous_previous_node=pp)
res_2 = self._simplify(root, node.right, copy.deepcopy(path), come_from=1, previous_node=node, previous_previous_node=pp)
return res_1 or res_2 or change

"""
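For intuition, _simplify above eliminates two kinds of redundancy: a decision node whose two subtrees are identical (raw[1] == raw[2]) and a node repeating a (variable, branch) pair already fixed on the current path. Below is a standalone, hedged illustration of the first rule, using a hypothetical Node class rather than PyXAI's, and collapsing only twin leaves whereas _simplify compares whole subtrees via to_tuples:

class Node:
    def __init__(self, variable=None, left=None, right=None, leaf=None):
        self.variable, self.left, self.right, self.leaf = variable, left, right, leaf

    def is_leaf(self):
        return self.leaf is not None

def collapse_identical_children(node):
    # Post-order traversal: simplify the subtrees first, then check whether
    # the test stored in this node still matters.
    if node.is_leaf():
        return node
    node.left = collapse_identical_children(node.left)
    node.right = collapse_identical_children(node.right)
    if node.left.is_leaf() and node.right.is_leaf() and node.left.leaf == node.right.leaf:
        return node.left  # both branches agree: the test on node.variable is useless
    return node

# Example: the test on variable 2 is redundant because both of its branches predict 0.
tree = Node(variable=1, left=Node(variable=2, left=Node(leaf=0), right=Node(leaf=0)), right=Node(leaf=1))
tree = collapse_identical_children(tree)
assert tree.left.is_leaf() and tree.left.leaf == 0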
1 change: 1 addition & 0 deletions pyxai/sources/core/structure/randomForest.py
@@ -34,6 +34,7 @@ def predict_instance(self, instance):
n_votes = numpy.zeros(self.n_classes)
for tree in self.forest:
n_votes[tree.predict_instance(instance)] += 1
print("predict: ",n_votes)
return numpy.argmax(n_votes)


2 changes: 2 additions & 0 deletions pyxai/sources/core/structure/treeEnsembles.py
@@ -38,6 +38,8 @@ def __init__(self, forest, learner_information):
tree.map_id_binaries_to_features = self.map_id_binaries_to_features
tree.map_features_to_id_binaries = self.map_features_to_id_binaries

def n_nodes(self):
return sum(tree.n_nodes() for tree in self.forest)

def redundancy_analysis(self):
n_variables = len(self.map_features_to_id_binaries)
3 changes: 2 additions & 1 deletion pyxai/sources/learning/learner.py
@@ -598,7 +598,8 @@ def get_instances(self,
else:
for j in original_indexes:
current_index = possible_indexes[j]
prediction_solver = learner.predict(numpy.ascontiguousarray(data[j].reshape(1, -1)))[0]
#prediction_solver = learner.predict(numpy.ascontiguousarray(data[j].reshape(1, -1)))[0]
prediction_solver = model.predict_instance(data[j])

# Luckily (a priori), the predict functions of xgboost and scikit-learn happen to have the same signature!
# How to handle this properly remains to be decided, perhaps at the level of 'extras', if it ever becomes a problem.
47 changes: 45 additions & 2 deletions pyxai/tests/functionality/Rectify.py
@@ -6,7 +6,50 @@

import unittest
class TestRectify(unittest.TestCase):
def test_rectify_5(self):
learner = Learning.Scikitlearn("tests/compas.csv", learner_type=Learning.CLASSIFICATION)
model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.RF)

dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)

all_dict_information = learner.get_instances(model, indexes=Learning.ALL, details=True)  # needed by the loop over all instances below

instance = dict_information["instance"]
label = dict_information["label"]
prediction = dict_information["prediction"]
print("prediction:", prediction)
print("before:", model.predict_instance(instance))

compas_types = {
"numerical": ["Number_of_Priors"],
"binary": ["Misdemeanor", "score_factor", "Female"],
"categorical": {"Origin*": ["African_American", "Asian", "Hispanic", "Native_American", "Other"],
"Age*": ["Above_FourtyFive", "Below_TwentyFive"]}
}


explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)
reason = explainer.majoritary_reason(n=1)

print("explanation:", reason)
print("explanation:", explainer.to_features(reason))
model = explainer.rectify(conditions=reason, label=1)
print("after:", model.predict_instance(instance))

self.assertEqual(model.predict_instance(instance), 1)

reason = set(reason)

for instance_dict in all_dict_information:
instance = instance_dict["instance"]
old_prediction = instance_dict["prediction"]
binary_representation = set(explainer._to_binary_representation(instance))
result = binary_representation.intersection(reason)
if len(result) == len(reason):
self.assertEqual(model.predict_instance(instance), 1)
else:
self.assertEqual(model.predict_instance(instance), old_prediction)

def test_rectify_1(self):
nodeT1_3 = Builder.DecisionNode(3, left=0, right=1)
nodeT1_2 = Builder.DecisionNode(2, left=1, right=0)
@@ -129,8 +172,8 @@ def test_rectify_3(self):
self.assertEqual(model.predict_instance(instance), 1)
else:
self.assertEqual(model.predict_instance(instance), old_prediction)



if __name__ == '__main__':
print("Tests: " + TestRectify.__name__ + ":")
