diff --git a/.gitignore b/.gitignore
index 4328f217e..54f7f1c71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -119,4 +119,5 @@ cythonize.dat
 doc/_build/
 doc/auto_examples/
 doc/generated/
-doc/bibtex/auto
\ No newline at end of file
+doc/bibtex/auto
+/imblearn/tree/*.c
diff --git a/.travis.yml b/.travis.yml
index 00b73b723..144e42623 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -33,13 +33,13 @@ matrix:
     - env: DISTRIB="ubuntu" TEST_DOC="true" TEST_NUMPYDOC="false"
     # Latest release
     - env: DISTRIB="conda" PYTHON_VERSION="3.7"
-           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master"
+           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master" CYTHON_VERSION="*"
            OPTIONAL_DEPS="keras" TEST_DOC="true" TEST_NUMPYDOC="false"
     - env: DISTRIB="conda" PYTHON_VERSION="3.7"
-           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master"
+           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master" CYTHON_VERSION="*"
            OPTIONAL_DEPS="tensorflow" TEST_DOC="true" TEST_NUMPYDOC="false"
     - env: DISTRIB="conda" PYTHON_VERSION="3.7"
-           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master"
+           NUMPY_VERSION="*" SCIPY_VERSION="*" SKLEARN_VERSION="master" CYTHON_VERSION="*"
            OPTIONAL_DEPS="false" TEST_DOC="false" TEST_NUMPYDOC="true"
 
 install: source build_tools/travis/install.sh
diff --git a/MANIFEST.in b/MANIFEST.in
index 192436787..0637c065e 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,6 +1,7 @@
-
 recursive-include doc *
 recursive-include examples *
+recursive-include imblearn/tree *.pyx
+recursive-include imblearn/tree *.pyd
 include AUTHORS.rst
 include CONTRIBUTING.ms
 include LICENSE
diff --git a/appveyor.yml b/appveyor.yml
index b09063a6a..42a7e97b0 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -37,6 +37,7 @@ install:
   - activate testenv
   - conda install scipy numpy joblib -y -q
   - pip install --pre -f https://sklearn-nightly.scdn8.secure.raxcdn.com scikit-learn
+  - conda install -c anaconda cython -y -q
   - conda install %OPTIONAL_DEP% -y -q
   - conda install pytest pytest-cov -y -q
   - pip install codecov
diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh
index a11a76ad2..e0c3c6961 100755
--- a/build_tools/travis/install.sh
+++ b/build_tools/travis/install.sh
@@ -33,7 +33,7 @@ if [[ "$DISTRIB" == "conda" ]]; then
     # provided versions
    conda create -n testenv --yes python=$PYTHON_VERSION pip
    source activate testenv
-    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION
+    conda install --yes numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION cython=$CYTHON_VERSION
 
    if [[ "$OPTIONAL_DEPS" == "keras" ]]; then
        conda install --yes pandas keras tensorflow=1
@@ -66,12 +66,14 @@ elif [[ "$DISTRIB" == "ubuntu" ]]; then
    pip install --pre -f https://sklearn-nightly.scdn8.secure.raxcdn.com scikit-learn
    pip3 install pandas
    pip3 install pytest pytest-cov codecov sphinx numpydoc
+    pip3 install cython
 fi
 
 python --version
 python -c "import numpy; print('numpy %s' % numpy.__version__)"
 python -c "import scipy; print('scipy %s' % scipy.__version__)"
+python -c "import Cython; print('Cython %s' % Cython.__version__)"
 
 pip install -e .
 ccache --show-stats
diff --git a/build_tools/travis/test_script.sh b/build_tools/travis/test_script.sh
index 325ffc7e8..5564f39b2 100755
--- a/build_tools/travis/test_script.sh
+++ b/build_tools/travis/test_script.sh
@@ -19,6 +19,7 @@ run_tests(){
     python --version
     python -c "import numpy; print('numpy %s' % numpy.__version__)"
     python -c "import scipy; print('scipy %s' % scipy.__version__)"
+    python -c "import Cython; print('Cython %s' % Cython.__version__)"
     python -c "import multiprocessing as mp; print('%d CPUs' % mp.cpu_count())"
 
     pytest --cov=$MODULE -r sx --pyargs $MODULE
diff --git a/doc/api.rst b/doc/api.rst
index b83396643..1105db652 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -196,6 +196,23 @@ Imbalance-learn provides some fast-prototyping tools.
 
+.. _tree_ref:
+
+:mod:`imblearn.tree`: Tree split criterion
+==========================================
+
+.. automodule:: imblearn.tree
+   :no-members:
+   :no-inherited-members:
+
+.. currentmodule:: imblearn
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   tree.criterion.HellingerDistanceCriterion
+
 .. _metrics_ref:
 
 :mod:`imblearn.metrics`: Metrics
 ================================
diff --git a/doc/miscellaneous.rst b/doc/miscellaneous.rst
index 768d02cf8..412f2b822 100644
--- a/doc/miscellaneous.rst
+++ b/doc/miscellaneous.rst
@@ -169,4 +169,4 @@ will be passed to ``fit_generator``::
 
 .. topic:: References
 
-   * :ref:`sphx_glr_auto_examples_applications_porto_seguro_keras_under_sampling.py`
\ No newline at end of file
+   * :ref:`sphx_glr_auto_examples_applications_porto_seguro_keras_under_sampling.py`
diff --git a/doc/tree.rst b/doc/tree.rst
new file mode 100644
index 000000000..ceb3dc305
--- /dev/null
+++ b/doc/tree.rst
@@ -0,0 +1,28 @@
+.. _tree-split:
+
+==============
+Tree-split
+==============
+
+.. currentmodule:: imblearn.tree
+
+.. _hellinger_distance:
+
+
+Hellinger Distance split
+========================
+
+The Hellinger distance quantifies the similarity between two probability
+distributions. Used as the split criterion of a decision tree classifier,
+it makes the tree skew-insensitive and helps tackle the imbalance problem.
+
+   >>> import numpy as np
+   >>> from sklearn.ensemble import RandomForestClassifier
+   >>> from imblearn.tree.criterion import HellingerDistanceCriterion
+
+   >>> hdc = HellingerDistanceCriterion(1, np.array([2], dtype='int64'))
+   >>> clf = RandomForestClassifier(criterion=hdc)
+
+:class:`HellingerDistanceCriterion` is a Cython implementation of the
+Hellinger distance as a decision tree split criterion, compatible with
+scikit-learn tree-based classification models.
diff --git a/doc/user_guide.rst b/doc/user_guide.rst
index 6dbd575de..00a1e8c7c 100644
--- a/doc/user_guide.rst
+++ b/doc/user_guide.rst
@@ -12,6 +12,7 @@ User Guide
    introduction.rst
    over_sampling.rst
    under_sampling.rst
+   tree.rst
    combine.rst
    ensemble.rst
    miscellaneous.rst
diff --git a/doc/whats_new/v0.4.rst b/doc/whats_new/v0.4.rst
index b42fffddd..c615a6538 100644
--- a/doc/whats_new/v0.4.rst
+++ b/doc/whats_new/v0.4.rst
@@ -129,6 +129,9 @@ Enhancement
   :class:`BorderlineSMOTE` and :class:`SVMSMOTE`.
   :issue:`440` by :user:`Guillaume Lemaitre `.
 
+- Add support for the Hellinger distance as a split criterion for
+  scikit-learn classification trees. By :user:`Evgeni Dubov`.
+
 - Allow :class:`imblearn.over_sampling.RandomOverSampler` can return indices
   using the attributes ``return_indices``.
   :issue:`439` by :user:`Hugo Gascon` and
diff --git a/examples/tree/README.txt b/examples/tree/README.txt
new file mode 100644
index 000000000..a39794dbf
--- /dev/null
+++ b/examples/tree/README.txt
@@ -0,0 +1,9 @@
+.. _tree_examples:
+
+Example using Hellinger Distance as tree split criterion
+========================================================
+
+The Hellinger distance quantifies the similarity between two probability
+distributions. Used as a decision tree split criterion, it makes the tree
+skew-insensitive and helps tackle the imbalance problem. These examples use
+a Cython implementation compatible with scikit-learn tree-based classifiers.
diff --git a/examples/tree/train_model_with_hellinger_distance_criterion.py b/examples/tree/train_model_with_hellinger_distance_criterion.py
new file mode 100644
index 000000000..3bfebe33d
--- /dev/null
+++ b/examples/tree/train_model_with_hellinger_distance_criterion.py
@@ -0,0 +1,26 @@
+"""
+====================================================
+Train a model with the Hellinger distance criterion
+====================================================
+
+Fit a random forest on an imbalanced dataset using the Hellinger distance
+as the tree split criterion, and report its test accuracy.
+"""
+
+import numpy as np
+
+from sklearn.datasets import make_classification
+from sklearn.model_selection import train_test_split
+from sklearn.ensemble import RandomForestClassifier
+
+from imblearn.tree.criterion import HellingerDistanceCriterion
+
+X, y = make_classification(
+    n_samples=10000, n_features=40, n_informative=5,
+    n_classes=2, weights=[0.05, 0.95], random_state=1)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4)
+
+hdc = HellingerDistanceCriterion(1, np.array([2], dtype='int64'))
+clf = RandomForestClassifier(criterion=hdc, max_depth=4, n_estimators=100)
+clf.fit(X_train, y_train)
+print(clf.score(X_test, y_test))
diff --git a/imblearn/tree/__init__.py b/imblearn/tree/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/imblearn/tree/criterion.pyx b/imblearn/tree/criterion.pyx
new file mode 100644
index 000000000..1325ec2e9
--- /dev/null
+++ b/imblearn/tree/criterion.pyx
@@ -0,0 +1,96 @@
+# Author: Evgeni Dubov
+#
+# License: BSD 3 clause
+
+# cython: language_level=3, boundscheck=False
+
+from libc.math cimport sqrt, pow
+
+import numpy as np
+
+from sklearn.tree._criterion cimport ClassificationCriterion
+from sklearn.tree._criterion cimport SIZE_t
+
+cdef double INFINITY = np.inf
+
+
+cdef class HellingerDistanceCriterion(ClassificationCriterion):
+    """Hellinger distance criterion.
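+
+    Uses the Hellinger distance between the class distributions of the
+    left and right child nodes to measure split quality, which makes the
+    resulting tree skew-insensitive and suited to imbalanced data. The
+    implementation supports binary classification problems.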
+    """
+
+    cdef double proxy_impurity_improvement(self) nogil:
+        cdef:
+            double impurity_left
+            double impurity_right
+
+        self.children_impurity(&impurity_left, &impurity_right)
+
+        return impurity_right + impurity_left
+
+    cdef double impurity_improvement(self, double impurity) nogil:
+        cdef:
+            double impurity_left
+            double impurity_right
+
+        self.children_impurity(&impurity_left, &impurity_right)
+
+        return impurity_right + impurity_left
+
+    cdef double node_impurity(self) nogil:
+        # The split quality is fully computed in children_impurity; this
+        # returns a constant (the average number of classes per output).
+        cdef:
+            SIZE_t* n_classes = self.n_classes
+            double hellinger = 0.0
+            SIZE_t k, c
+
+        for k in range(self.n_outputs):
+            for c in range(n_classes[k]):
+                hellinger += 1.0
+
+        return hellinger / self.n_outputs
+
+    cdef void children_impurity(self, double* impurity_left,
+                                double* impurity_right) nogil:
+        cdef:
+            SIZE_t* n_classes = self.n_classes
+            double* sum_left = self.sum_left
+            double* sum_right = self.sum_right
+            double hellinger_left = 0.0
+            double hellinger_right = 0.0
+            double count_k1 = 0.0
+            double count_k2 = 0.0
+            SIZE_t k, c
+
+        # stop splitting in case we reached a pure node with 0 samples
+        # of the second class
+        if sum_left[1] + sum_right[1] == 0:
+            impurity_left[0] = -INFINITY
+            impurity_right[0] = -INFINITY
+            return
+
+        for k in range(self.n_outputs):
+            # squared Hellinger term of the left child:
+            # (sqrt(P(left | class 0)) - sqrt(P(left | class 1)))**2
+            if sum_left[0] + sum_right[0] > 0:
+                count_k1 = sqrt(sum_left[0] / (sum_left[0] + sum_right[0]))
+            if sum_left[1] + sum_right[1] > 0:
+                count_k2 = sqrt(sum_left[1] / (sum_left[1] + sum_right[1]))
+
+            hellinger_left += pow((count_k1 - count_k2), 2)
+
+            # and the symmetric term of the right child
+            if sum_left[0] + sum_right[0] > 0:
+                count_k1 = sqrt(sum_right[0] / (sum_left[0] + sum_right[0]))
+            if sum_left[1] + sum_right[1] > 0:
+                count_k2 = sqrt(sum_right[1] / (sum_left[1] + sum_right[1]))
+
+            hellinger_right += pow((count_k1 - count_k2), 2)
+
+        impurity_left[0] = hellinger_left / self.n_outputs
+        impurity_right[0] = hellinger_right / self.n_outputs
diff --git a/imblearn/tree/setup.py b/imblearn/tree/setup.py
new file mode 100644
index 000000000..311c3e0d0
--- /dev/null
+++ b/imblearn/tree/setup.py
@@ -0,0 +1,19 @@
+import numpy
+
+
+def configuration(parent_package='', top_path=None):
+    from numpy.distutils.misc_util import Configuration
+    config = Configuration('tree', parent_package, top_path)
+    libraries = []
+    # criterion.c is generated from criterion.pyx by tools/cythonize.py
+    config.add_extension('criterion',
+                         sources=['criterion.c'],
+                         include_dirs=[numpy.get_include()],
+                         libraries=libraries)
+
+    return config
+
+
+if __name__ == "__main__":
+    from numpy.distutils.core import setup
+    setup(**configuration().todict())
diff --git a/tools/cythonize.py b/tools/cythonize.py
new file mode 100644
index 000000000..38fcd94fd
--- /dev/null
+++ b/tools/cythonize.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python
+""" cythonize
+
+Cythonize pyx files into C files as needed.
+
+Usage: cythonize [root_dir]
+
+Checks pyx files to see if they have been changed relative to their
+corresponding C files. If they have, then runs cython on these files to
+recreate the C files. Changes are detected by comparing file hashes stored
+in a database file.
+
+It is called by ./setup.py sdist so that the sdist package can be installed
+without cython.
+
+Originally written by Dag Sverre Seljebotn, adapted from statsmodels 0.6.1
+(Modified BSD 3-clause) and copied here from scikit-learn.
+
+Note: this script does not check any of the dependent C libraries; it only
+operates on the Cython .pyx files or their corresponding Cython header (.pxd)
+files.
+"""
+# Author: Arthur Mensch
+# Author: Raghav R V
+#
+# License: BSD 3 clause
+# see http://github.com/scikit-learn/scikit-learn
+
+
+from __future__ import division, print_function, absolute_import
+
+import os
+import re
+import sys
+import hashlib
+import subprocess
+
+HASH_FILE = 'cythonize.dat'
+
+
+# WindowsError is not defined on unix systems
+try:
+    WindowsError
+except NameError:
+    WindowsError = None
+
+
+def cythonize(cython_file, gen_file):
+    try:
+        from Cython.Compiler.Version import version as cython_version
+        from distutils.version import LooseVersion
+        if LooseVersion(cython_version) < LooseVersion('0.21'):
+            raise Exception('Building imbalanced-learn requires Cython >= 0.21')
+
+    except ImportError:
+        pass
+
+    flags = ['--fast-fail']
+    if gen_file.endswith('.cpp'):
+        flags += ['--cplus']
+
+    try:
+        try:
+            rc = subprocess.call(['cython'] +
+                                 flags + ["-o", gen_file, cython_file])
+            if rc != 0:
+                raise Exception('Cythonizing %s failed' % cython_file)
+        except OSError:
+            # There are ways of installing Cython that don't result in a
+            # cython executable on the path, see scipy issue gh-2397.
+            rc = subprocess.call(
+                [sys.executable, '-c',
+                 'import sys; from Cython.Compiler.Main '
+                 'import setuptools_main as main; sys.exit(main())'] +
+                flags + ["-o", gen_file, cython_file])
+            if rc != 0:
+                raise Exception('Cythonizing %s failed' % cython_file)
+    except OSError:
+        raise OSError('Cython needs to be installed')
+
+
+def load_hashes(filename):
+    """Load the hashes dict from the hashfile"""
+    # { filename : (sha1 of header if available or 'NA',
+    #               sha1 of input,
+    #               sha1 of output) }
+
+    hashes = {}
+    try:
+        with open(filename, 'r') as cython_hash_file:
+            for hash_record in cython_hash_file:
+                (filename, header_hash,
+                 cython_hash, gen_file_hash) = hash_record.split()
+                hashes[filename] = (header_hash, cython_hash, gen_file_hash)
+    except (KeyError, ValueError, AttributeError, IOError):
+        hashes = {}
+    return hashes
+
+
+def save_hashes(hashes, filename):
+    """Save the hashes dict to the hashfile"""
+    with open(filename, 'w') as cython_hash_file:
+        for key, value in hashes.items():
+            cython_hash_file.write("%s %s %s %s\n"
+                                   % (key, value[0], value[1], value[2]))
+
+
+def sha1_of_file(filename):
+    h = hashlib.sha1()
+    with open(filename, "rb") as f:
+        h.update(f.read())
+    return h.hexdigest()
+
+
+def clean_path(path):
+    """Clean the path"""
+    path = path.replace(os.sep, '/')
+    if path.startswith('./'):
+        path = path[2:]
+    return path
+
+
+def get_hash_tuple(header_path, cython_path, gen_file_path):
+    """Get the hashes from the given files"""
+
+    header_hash = (sha1_of_file(header_path)
+                   if os.path.exists(header_path) else 'NA')
+    from_hash = sha1_of_file(cython_path)
+    to_hash = (sha1_of_file(gen_file_path)
+               if os.path.exists(gen_file_path) else 'NA')
+
+    return header_hash, from_hash, to_hash
+
+
+def cythonize_if_unchanged(path, cython_file, gen_file, hashes):
+    full_cython_path = os.path.join(path, cython_file)
+    full_header_path = full_cython_path.replace('.pyx', '.pxd')
+    full_gen_file_path = os.path.join(path, gen_file)
+
+    current_hash = get_hash_tuple(full_header_path, full_cython_path,
+                                  full_gen_file_path)
+
+    if current_hash == hashes.get(clean_path(full_cython_path)):
+        print('%s has not changed' % full_cython_path)
+        return
+
+    print('Processing %s' % full_cython_path)
+    cythonize(full_cython_path,
+              full_gen_file_path)
+
+    # changed target file, recompute hash
+    current_hash = get_hash_tuple(full_header_path, full_cython_path,
+                                  full_gen_file_path)
+
+    # Update the hashes dict with the new hash
+    hashes[clean_path(full_cython_path)] = current_hash
+
+
+def check_and_cythonize(root_dir):
+    print(root_dir)
+    hashes = load_hashes(HASH_FILE)
+
+    for cur_dir, dirs, files in os.walk(root_dir):
+        for filename in files:
+            if filename.endswith('.pyx'):
+                gen_file_ext = '.c'
+                # Cython files with libcpp imports should be compiled to cpp
+                with open(os.path.join(cur_dir, filename), 'rb') as f:
+                    data = f.read()
+                    m = re.search(b"libcpp", data, re.I | re.M)
+                    if m:
+                        gen_file_ext = ".cpp"
+                cython_file = filename
+                gen_file = filename.replace('.pyx', gen_file_ext)
+                cythonize_if_unchanged(cur_dir, cython_file, gen_file,
+                                       hashes)
+
+                # Save hashes once per module. This prevents cythonizing
+                # previous files again when debugging broken code in a
+                # single file.
+                save_hashes(hashes, HASH_FILE)
+
+
+def main(root_dir):
+    check_and_cythonize(root_dir)
+
+
+if __name__ == '__main__':
+    try:
+        root_dir_arg = sys.argv[1]
+    except IndexError:
+        raise ValueError("Usage: python cythonize.py <root_dir>")
+    main(root_dir_arg)
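
Since HellingerDistanceCriterion subclasses scikit-learn's ClassificationCriterion, it should also plug into a plain DecisionTreeClassifier, not only the ensemble models shown above. A minimal sketch, assuming the extension has been built (e.g. `python tools/cythonize.py imblearn` followed by `pip install -e .`):

```python
import numpy as np

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from imblearn.tree.criterion import HellingerDistanceCriterion

# An imbalanced binary problem, similar to the shipped example.
X, y = make_classification(n_samples=1000, n_classes=2,
                           weights=[0.1, 0.9], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Constructor arguments mirror sklearn's ClassificationCriterion:
# n_outputs=1 and one entry per output giving the number of classes (2).
hdc = HellingerDistanceCriterion(1, np.array([2], dtype='int64'))

clf = DecisionTreeClassifier(criterion=hdc, max_depth=4, random_state=0)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))
```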