Skip to content

Commit

Permalink
Merge pull request #15 from audemard/main
Browse files Browse the repository at this point in the history
fix bug for DT and excluded features
  • Loading branch information
szczepanskiNicolas committed Feb 22, 2024
2 parents d8d61a2 + 68cb3e1 commit 8f30eb8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pyxai/sources/core/explainer/explainerDT.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ def sufficient_reason(self, *, n=1, time_limit=None):
SATsolver.add_clauses(prime_implicant_cnf.cnf)

# Remove excluded features
SATsolver.add_clauses([[-prime_implicant_cnf.from_original_to_new(lit)] for lit in self._excluded_literals])
SATsolver.add_clauses([[-prime_implicant_cnf.from_original_to_new(lit)]
for lit in self._excluded_literals
if prime_implicant_cnf.from_original_to_new(lit) is not None])

sufficient_reasons = []
while True:
Expand Down Expand Up @@ -211,7 +213,8 @@ def preferred_sufficient_reason(self, *, method, n=1, time_limit=None, weights=N

# Remove excluded features
for lit in self._excluded_literals:
solver.add_hard_clause([-prime_implicant_cnf.from_original_to_new(lit)])
if prime_implicant_cnf.from_original_to_new(lit) is not None:
solver.add_hard_clause([-prime_implicant_cnf.from_original_to_new(lit)])

# Solving
time_used = 0
Expand Down Expand Up @@ -275,7 +278,8 @@ def n_sufficient_reasons_per_attribute(self, *, time_limit=None):
# Remove excluded features
cnf = list(prime_implicant_cnf.cnf)
for lit in self._excluded_literals:
cnf.append([-prime_implicant_cnf.from_original_to_new(lit)])
if prime_implicant_cnf.from_original_to_new(lit) is not None:
cnf.append([-prime_implicant_cnf.from_original_to_new(lit)])

compiler.add_cnf(cnf, prime_implicant_cnf.n_literals - 1)
compiler.add_count_model_query(cnf, prime_implicant_cnf.n_literals - 1, prime_implicant_cnf.n_literals_mapping)
Expand Down
2 changes: 2 additions & 0 deletions pyxai/sources/core/tools/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def new_mapping_id_variables(self):


def from_original_to_new(self, lit):
if abs(lit) >= len(self.mapping_original_to_new) or self.mapping_original_to_new[abs(lit)] == 0:
return None
return self.mapping_original_to_new[abs(lit)]


Expand Down

0 comments on commit 8f30eb8

Please sign in to comment.