"""Decision-tree classifier with a label-entropy helper, plus its unit test.

Reconstructed from the patch "added tree entropy calculation", which created
``src/tree/decision_tree_classifier.py`` and
``tests/tree/test_decision_tree_classifier.py`` and appended ``SciPy`` to
``requirements.txt``.  SciPy is used only by the test below, as the reference
implementation the classifier's entropy is compared against.
"""

from unittest import TestCase

import numpy as np


class DecisionTreeClassifier:
    """Decision-tree classifier; currently only exposes the entropy helper."""

    def __init__(self):
        """Create a classifier. No state is stored yet."""

    def get_entropy(self, y: np.ndarray) -> float:
        """Return the Shannon entropy (in bits) of the labels in ``y``.

        Args:
            y: Array-like of class labels. Any values ``np.unique`` can
                count are accepted (negative integers, strings, ...), not
                only the non-negative integers ``np.bincount`` requires.

        Returns:
            ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
            distinct label, as a Python ``float``. ``0.0`` when ``y`` has at
            most one element or contains a single distinct label.
        """
        labels = np.asarray(y)
        sample_count = labels.size

        # Zero or one sample carries no uncertainty.
        if sample_count <= 1:
            return 0.0

        # np.unique reports only labels that actually occur, so no zero
        # counts reach log2 and no max(label)+1 sized array is allocated.
        _, counts = np.unique(labels, return_counts=True)

        # A pure node (one distinct label) has zero entropy; returning early
        # also avoids producing -0.0 from the formula below.
        if counts.size <= 1:
            return 0.0

        probabilities = counts / sample_count
        return float(-np.sum(probabilities * np.log2(probabilities)))


class TestDecisionTreeClassifier(TestCase):
    """Checks ``get_entropy`` against SciPy's reference implementation."""

    def test_should_calculate_entropy(self):
        # SciPy is needed only for this reference value, so it is imported
        # here to keep the classifier itself importable with NumPy alone.
        from scipy.stats import entropy

        y_labels = np.asarray([1, 2, 1, 3, 4, 5, 1, 1, 4, 4, 5, 0])

        # scipy.stats.entropy normalises the raw counts to probabilities.
        expected = entropy(np.bincount(y_labels), base=2)

        actual = DecisionTreeClassifier().get_entropy(y_labels)

        self.assertAlmostEqual(expected, actual, places=4)