Skip to content

Commit

Permalink
Rectification for RF
Browse files Browse the repository at this point in the history
  • Loading branch information
szczepanskiNicolas committed Feb 22, 2024
1 parent d8d61a2 commit 17af2bc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 22 deletions.
19 changes: 8 additions & 11 deletions pyxai/sources/core/explainer/explainerRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,14 @@ def rectify(self, *, conditions, label):
Tools.verbose("")
Tools.verbose("-------------- Rectification information:")
tree_decision_rule = self._random_forest.forest[0].decision_rule_to_tree(conditions)

if label == 1:
tree_decision_rule = tree_decision_rule.negating_tree()
Tools.verbose("Classification Rule - Number of nodes:", tree_decision_rule.n_nodes())
Tools.verbose("Model - Number of nodes:", self._random_forest.n_nodes())

for i, tree in enumerate(self._random_forest.forest):
Tools.verbose("Model - Number of nodes "+str(i)+":"+ str(tree.n_nodes()))
for i, tree in enumerate(self._random_forest.forest):
if label == 1:
# When label is 1, we have to inverse the decision rule and disjoint the two trees.
# Bug ici
tree_decision_rule = tree_decision_rule.negating_tree()
self._random_forest.forest[i] = tree.disjoint_tree(tree_decision_rule)
elif label == 0:
# When label is 0, we have to concatenate the two trees.
Expand All @@ -571,13 +568,13 @@ def rectify(self, *, conditions, label):

Tools.verbose("Model - Number of nodes (after rectification):", self._random_forest.n_nodes())

#for i, tree in enumerate(self._random_forest.forest):
# self._random_forest.forest[i] = self.simplify_theory(tree)
#Tools.verbose("Model - Number of nodes (after simplification using the theory):", self._random_forest.n_nodes())
for i, tree in enumerate(self._random_forest.forest):
self._random_forest.forest[i] = self.simplify_theory(tree)
Tools.verbose("Model - Number of nodes (after simplification using the theory):", self._random_forest.n_nodes())

#for i, tree in enumerate(self._random_forest.forest):
# tree.simplify()
#Tools.verbose("Model - Number of nodes (after elimination of redundant nodes):", self._random_forest.n_nodes())
for i, tree in enumerate(self._random_forest.forest):
tree.simplify()
Tools.verbose("Model - Number of nodes (after elimination of redundant nodes):", self._random_forest.n_nodes())

if self._instance is not None:
self.set_instance(self._instance)
Expand Down
1 change: 0 additions & 1 deletion pyxai/sources/core/structure/randomForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def predict_instance(self, instance):
n_votes = numpy.zeros(self.n_classes)
for tree in self.forest:
n_votes[tree.predict_instance(instance)] += 1
print("predict: ",n_votes)
return numpy.argmax(n_votes)


Expand Down
13 changes: 3 additions & 10 deletions pyxai/tests/functionality/Rectify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math


#Tools.set_verbose(0)
Tools.set_verbose(0)

import unittest
class TestRectify(unittest.TestCase):
Expand All @@ -12,14 +12,12 @@ def test_rectify_5(self):

dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)

#all_dict_information = learner.get_instances(model, indexes=Learning.ALL, details=True)
all_dict_information = learner.get_instances(model, indexes=Learning.ALL, details=True)

instance = dict_information["instance"]
label = dict_information["label"]
prediction = dict_information["prediction"]
print("prediction:", prediction)
print("before:", model.predict_instance(instance))


compas_types = {
"numerical": ["Number_of_Priors"],
"binary": ["Misdemeanor", "score_factor", "Female"],
Expand All @@ -31,10 +29,7 @@ def test_rectify_5(self):
explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)
reason = explainer.majoritary_reason(n=1)

print("explanation:", reason)
print("explanation:", explainer.to_features(reason))
model = explainer.rectify(conditions=reason, label=1)
print("after:", model.predict_instance(instance))

self.assertEqual(model.predict_instance(instance), 1)

Expand Down Expand Up @@ -155,8 +150,6 @@ def test_rectify_3(self):

explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)
minimal_reason = explainer.minimal_sufficient_reason(n=1)
print("explanation:", minimal_reason)
print("explanation:", explainer.to_features(minimal_reason))
model = explainer.rectify(conditions=minimal_reason, label=1)

self.assertEqual(model.predict_instance(instance), 1)
Expand Down

0 comments on commit 17af2bc

Please sign in to comment.