Skip to content

Commit

Permalink
Rectify
Browse files Browse the repository at this point in the history
  • Loading branch information
szczepanskiNicolas committed Feb 19, 2024
1 parent 75c9529 commit a225ef9
Show file tree
Hide file tree
Showing 33 changed files with 193 additions and 180 deletions.
File renamed without changes.
34 changes: 34 additions & 0 deletions pyxai/examples/DT/builder-rectify3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

from pyxai import Builder, Explainer

# Builder part

node_L_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_L_2 = Builder.DecisionNode(1, operator=Builder.GT, threshold=20, left=0, right=node_L_1)

node_R_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_R_2 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=node_R_1, right=1)

root = Builder.DecisionNode(1, operator=Builder.GT, threshold=30, left=node_L_2, right=node_R_2)
tree = Builder.DecisionTree(3, root, feature_names=["I", "PP", "R"])

print("base:", tree.raw_data_for_CPP())
loan_types = {
"numerical": ["I"],
"binary": ["PP", "R"],
}

print("bob = (25, 1, 1):")
bob = (25, 1, 1)
explainer = Explainer.initialize(tree, instance=bob, features_type=loan_types)

print("binary representation: ", explainer.binary_representation)
print("target_prediction:", explainer.target_prediction)
print("to_features:", explainer.to_features(explainer.binary_representation, eliminate_redundant_features=False))

#For him/her, the following classification rule must be obeyed:
#whenever the annual income of the client is lower than 30,
#the demand should be rejected
rectified_model = explainer.rectify(decision_rule=(-1, ), label=0)

assert (0, (1, 0, (4, (3, 0, 1), 1))) == rectified_model.raw_data_for_CPP(), "The rectified model is not good."
33 changes: 17 additions & 16 deletions pyxai/sources/core/explainer/explainerDT.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from pyxai.sources.core.tools.utils import compute_weight
from pyxai.sources.solvers.COMPILER.D4Solver import D4Solver
from pyxai.sources.solvers.MAXSAT.OPENWBOSolver import OPENWBOSolver

from pyxai.sources.solvers.SAT.glucoseSolver import GlucoseSolver

from pyxai import Tools

class ExplainerDT(Explainer):

Expand Down Expand Up @@ -318,32 +317,34 @@ def rectify(self, *, decision_rule, label):
Returns:
DecisionTree: The rectified tree.
"""
print("")
print("-------------- Rectify information:")
Tools.verbose("")
Tools.verbose("-------------- Rectify information:")
tree_decision_rule = self._tree.decision_rule_to_tree(decision_rule)
print("Desision Rule Number of nodes:", tree_decision_rule.n_nodes())
print("Model Number of nodes:", self._tree.n_nodes())
Tools.verbose("Desision Rule Number of nodes:", tree_decision_rule.n_nodes())
Tools.verbose("Model Number of nodes:", self._tree.n_nodes())
if label == 1:
#
# When label is 1, we have to inverse the decision rule and disjoint the two trees.
tree_decision_rule = tree_decision_rule.negating_tree()
tree_rectified = self._tree.disjoint_tree(tree_decision_rule)
print("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
Tools.verbose("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
tree_rectified = self.simplify_theory(tree_rectified)
print("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
Tools.verbose("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
tree_rectified.simplify()
print("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
else:
#
Tools.verbose("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
elif label == 0:
# When label is 0, we have to concatenate the two trees.
tree_rectified = self._tree.concatenate_tree(tree_decision_rule)
print("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
Tools.verbose("Model Number of nodes (after rectify):", tree_rectified.n_nodes())
self.simplify_theory(tree_rectified)
print("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
Tools.verbose("Model Number of nodes (symplify theory):", tree_rectified.n_nodes())
tree_rectified.simplify()
print("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
Tools.verbose("Model Number of nodes (symplify redundancy):", tree_rectified.n_nodes())
else:
raise NotImplementedError("Multiclasses is in progress.")
self._tree = tree_rectified
if self._instance is not None:
self.set_instance(self._instance)
print("--------------")
Tools.verbose("--------------")
return self._tree


Expand Down
4 changes: 2 additions & 2 deletions pyxai/tests/functionality/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

Tools.set_verbose(0)

class TestLearningScikitlearn(unittest.TestCase):
class TestMetrics(unittest.TestCase):
PRECISION = 1

def test_binary_classification(self):
Expand Down Expand Up @@ -188,5 +188,5 @@ def cross_validation(self, X, Y, learner_type, n_trees=100, n_forests=10):
forests.append((rf, index_training, index_test))
return forests
if __name__ == '__main__':
print("Tests: " + TestLearningScikitlearn.__name__ + ":")
print("Tests: " + TestMetrics.__name__ + ":")
unittest.main()
136 changes: 136 additions & 0 deletions pyxai/tests/functionality/Rectify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from pyxai import Builder, Learning, Explainer, Tools
import math


Tools.set_verbose(0)

import unittest
class TestRectify(unittest.TestCase):

def test_rectify_1(self):
nodeT1_3 = Builder.DecisionNode(3, left=0, right=1)
nodeT1_2 = Builder.DecisionNode(2, left=1, right=0)
nodeT1_1 = Builder.DecisionNode(1, left=nodeT1_2, right=nodeT1_3)
model = Builder.DecisionTree(3, nodeT1_1, force_features_equal_to_binaries=True)

loan_types = {
"binary": ["f1", "f2", "f3"],
}

explainer = Explainer.initialize(model, features_type=loan_types)

#Alice’s expertise can be represented by the formula T = ((x1 ∧ not x3) ⇒ y) ∧ (not x2 ⇒ not y) encoding her two decision rules
explainer.rectify(decision_rule=(1, -3), label=1) #(x1 ∧ not x3) ⇒ y
explainer.rectify(decision_rule=(-2, ), label=0) #not x2 ⇒ not y

rectified_model = explainer.get_model().raw_data_for_CPP()

self.assertEqual(rectified_model, (0, (1, 0, (2, 0, 1))))

def test_rectify_2(self):

node_v3_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_v2_1 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=0, right=node_v3_1)

node_v3_2 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_v2_2 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=0, right=node_v3_2)

node_v3_3 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_v2_3 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=0, right=node_v3_3)

node_v1_1 = Builder.DecisionNode(1, operator=Builder.GE, threshold=10, left=node_v2_1, right=node_v2_2)
node_v1_2 = Builder.DecisionNode(1, operator=Builder.GE, threshold=20, left=node_v1_1, right=node_v2_3)
node_v1_3 = Builder.DecisionNode(1, operator=Builder.GE, threshold=30, left=node_v1_2, right=1)
node_v1_4 = Builder.DecisionNode(1, operator=Builder.GE, threshold=40, left=node_v1_3, right=1)

tree = Builder.DecisionTree(3, node_v1_4)

loan_types = {
"numerical": ["f1"],
"binary": ["f2", "f3"],
}

bob = (20, 1, 0)
explainer = Explainer.initialize(tree, instance=bob, features_type=loan_types)


minimal = explainer.minimal_sufficient_reason()


explainer.rectify(decision_rule=minimal, label=1)
rectified_model = explainer.get_model().raw_data_for_CPP()
self.assertEqual(rectified_model, (0, (1, (2, (5, (6, 1, 0), 1), 1), 1)))

def test_rectify_4(self):

node_L_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_L_2 = Builder.DecisionNode(1, operator=Builder.GT, threshold=20, left=0, right=node_L_1)

node_R_1 = Builder.DecisionNode(3, operator=Builder.EQ, threshold=1, left=0, right=1)
node_R_2 = Builder.DecisionNode(2, operator=Builder.EQ, threshold=1, left=node_R_1, right=1)

root = Builder.DecisionNode(1, operator=Builder.GT, threshold=30, left=node_L_2, right=node_R_2)
tree = Builder.DecisionTree(3, root, feature_names=["I", "PP", "R"])

loan_types = {
"numerical": ["I"],
"binary": ["PP", "R"],
}

bob = (25, 1, 1)
explainer = Explainer.initialize(tree, instance=bob, features_type=loan_types)


#For him/her, the following classification rule must be obeyed:
#whenever the annual income of the client is lower than 30,
#the demand should be rejected
rectified_model = explainer.rectify(decision_rule=(-1, ), label=0)

self.assertEqual(rectified_model.raw_data_for_CPP(), (0, (1, 0, (4, (3, 0, 1), 1))))


def test_rectify_3(self):
learner = Learning.Scikitlearn("tests/compas.csv", learner_type=Learning.CLASSIFICATION)
model = learner.evaluate(method=Learning.HOLD_OUT, output=Learning.DT)

dict_information = learner.get_instances(model, n=1, indexes=Learning.TEST, correct=False, details=True)

all_dict_information = learner.get_instances(model, indexes=Learning.ALL, details=True)

instance = dict_information["instance"]
label = dict_information["label"]
prediction = dict_information["prediction"]


compas_types = {
"numerical": ["Number_of_Priors"],
"binary": ["Misdemeanor", "score_factor", "Age_Above_FourtyFive", "Age_Below_TwentyFive", "Female"],
"categorical": {"Origin*": ["African_American", "Asian", "Hispanic", "Native_American", "Other"]}
}


explainer = Explainer.initialize(model, instance=instance, features_type=compas_types)
minimal_reason = explainer.sufficient_reason(n=1)
#print("explanation:", minimal_reason)
#print("explanation:", explainer.to_features(minimal_reason))
model = explainer.rectify(decision_rule=minimal_reason, label=1)

self.assertEqual(model.predict_instance(instance), 1)

reason = set(minimal_reason)

for instance_dict in all_dict_information:
instance = instance_dict["instance"]
old_prediction = instance_dict["prediction"]
binary_representation = set(explainer._to_binary_representation(instance))
result = binary_representation.intersection(reason)
if len(result) == len(reason):
self.assertEqual(model.predict_instance(instance), 1)
else:
self.assertEqual(model.predict_instance(instance), old_prediction)



if __name__ == '__main__':
print("Tests: " + TestRectify.__name__ + ":")
unittest.main()
4 changes: 4 additions & 0 deletions pyxai/tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from pyxai.tests.functionality.GetInstances import *
from pyxai.tests.functionality.ToFeatures import *
from pyxai.tests.functionality.Metrics import *
from pyxai.tests.functionality.Rectify import *
from pyxai.tests.learning.ScikitLearn import *
from pyxai.tests.learning.LightGBM import *
from pyxai.tests.learning.XGBoost import *
Expand All @@ -26,6 +28,8 @@ def linux_tests():
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestToFeatures))
suite.addTest(unittest.makeSuite(TestGetInstances))
suite.addTest(unittest.makeSuite(TestMetrics))
suite.addTest(unittest.makeSuite(TestRectify))

suite.addTest(unittest.makeSuite(TestLearningScikitlearn))
suite.addTest(unittest.makeSuite(TestLearningXGBoost))
Expand Down
6 changes: 0 additions & 6 deletions tmp/wbo-10800331.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-13351277.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-14031925.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-15767817.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-15835756.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-16730192.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-17923783.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-18707833.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-19743337.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-19867073.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-20729516.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-21395586.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-23269759.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-23933993.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-24262394.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-24336094.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-24932345.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-25434562.wcnf

This file was deleted.

6 changes: 0 additions & 6 deletions tmp/wbo-26224984.wcnf

This file was deleted.

Loading

0 comments on commit a225ef9

Please sign in to comment.