Skip to content

Commit

Permalink
added tree entropy calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
SalatielBairros committed Jan 3, 2024
1 parent c674dd8 commit f8205d8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numpy
scikit-learn
pandas
polars
polars
SciPy
21 changes: 21 additions & 0 deletions src/tree/decision_tree_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np

class DecisionTreeClassifier:
def __init__(self):
pass

def get_entropy(self, y: np.array):
y_size = len(y)

if y_size <= 1:
return 0

count_per_index = np.bincount(y)
total_count = count_per_index[np.nonzero(count_per_index)]
probabilities = total_count / y_size

if len(probabilities) <= 1:
return 0

return - np.sum(probabilities * np.log2(probabilities))

19 changes: 19 additions & 0 deletions tests/tree/test_decision_tree_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from unittest import TestCase
import numpy as np
from scipy.stats import entropy
from src.tree.decision_tree_classifier import DecisionTreeClassifier

class TestDecisionTreeClassifier(TestCase):
def test_should_calculate_entropy(self):
y_labels = np.asarray([
1, 2, 1, 3, 4, 5, 1, 1, 4, 4, 5, 0
])
total_count = np.bincount(y_labels)

expected = entropy(total_count, base=2)

classifier = DecisionTreeClassifier()
actual = classifier.get_entropy(y_labels)

assert round(expected, 4) == round(actual, 4)

0 comments on commit f8205d8

Please sign in to comment.