From 9a747c5784465c7a03c6c39e9a6b92ff346edb87 Mon Sep 17 00:00:00 2001
From: Jan Blunk
Date: Mon, 28 Aug 2023 10:25:00 +0200
Subject: [PATCH] Python port of the CMIh estimator by Zan et al.

---
 README.md               |  13 +-
 mixed_cmiI_estimator.py | 405 ++++++++++++++++++++++++++++++++++++++++
 tensor_utils.py         | 107 +++++++++++
 3 files changed, 523 insertions(+), 2 deletions(-)
 create mode 100644 mixed_cmiI_estimator.py
 create mode 100644 tensor_utils.py

diff --git a/README.md b/README.md
index 02d5b83..f5be8ba 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
-# beyond-debiasing
+# Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization
 
-This repository provides code to use the method presented in our GCPR 2023 paper "Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization".
\ No newline at end of file
+This repository provides code to use the method presented in our GCPR 2023 paper "Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization". If you use this method, please cite:
+
+    @inproceedings{Blunk23:FS,
+        author = {Jan Blunk and Niklas Penzel and Paul Bodesheim and Joachim Denzler},
+        booktitle = {DAGM German Conference on Pattern Recognition (DAGM-GCPR)},
+        title = {Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization},
+        year = {2023},
+    }
+
+This repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234). The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022). CMIh was published under the MIT license.
\ No newline at end of file
diff --git a/mixed_cmiI_estimator.py b/mixed_cmiI_estimator.py
new file mode 100644
index 0000000..85bb5e4
--- /dev/null
+++ b/mixed_cmiI_estimator.py
@@ -0,0 +1,405 @@
+"""
+This document contains a Python port of the hybrid CMI estimator CMIh proposed
+by Zan et al.:
+
+    Zan, Lei; Meynaoui, Anouar; Assaad, Charles K.; Devijver, Emilie; Gaussier,
+    Eric (2022): A Conditional Mutual Information Estimator for Mixed Data and
+    an Associated Conditional Independence Test. In: Entropy (Basel, Switzerland)
+    24 (9). DOI: 10.3390/e24091234.
+
+The original R implementation can be found here:
+    https://github.com/leizan/CMIh2022
+It was published under the following license:
+
+
+MIT License
+
+Copyright (c) 2022 leizan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
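+
+
+Minimal usage sketch (hypothetical tensors; `mixed_cmi_model` is defined at
+the bottom of this file). Note that `_mixed_cmi_model` calls
+`data.get_device()`, so all inputs are assumed to live on a common CUDA
+device:
+
+    import torch
+    from mixed_cmiI_estimator import mixed_cmi_model
+
+    feature = torch.randn(1000).cuda()                    # quantitative
+    output = torch.randn(1000).cuda()                     # quantitative
+    target = torch.randint(0, 3, (1000,)).float().cuda()  # qualitative
+    cmi = mixed_cmi_model(feature, output, target,
+                          feature_is_categorical=False,
+                          target_is_categorical=True)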
+""" + +import torch +import math +from torch.special import digamma +from torchmetrics.functional import pairwise_manhattan_distance + +from tensor_utils import ( + intersection, + setdiff, + unsqueeze_to_1d, + unsqueeze_to_2d, + concatenate_1d, + cbind, +) + + +def get_dist_array(data): + """Returns pairwise distances for all columns of data measured with the + manhattan distance. + + Variables are represented by COLUMNS, observations by ROWS. + """ + # Works for 1D and 2d data. + + # Rotate if only a list so that it becomes a two-dimensional vector. + if data.size()[0] == 1: + data = data.transpose(dim0=0, dim1=1) + N = data.size()[0] + nDim = data.size()[1] + disArray = torch.zeros(size=(N, N, nDim)) + + # Attention: Our m is one smaller than the m of the original implementation + for m in range(nDim): # 1:nDim + # Get m-th column of data. + dataDim = unsqueeze_to_2d(data[:, m]) + + # Calculate pairwise manhattan distance of columns. + disArray[:, :, m] = pairwise_manhattan_distance(dataDim) + return disArray + + +def get_epsilon_distance(k, disArray): + """Based on a tensor of pairwise distances per observation (tensor is + three-dimensional, quadratic, symmetric). + """ + # data = torch.tensor([[0, 8, 2], [0, 8, 2], [0, 8, 2]]) + if disArray.size()[0] == 1: + disArray = disArray.transpose(dim0=0, dim1=1) + N = disArray.size()[0] + epsilonDisArray = torch.zeros(size=[N]) + + for point_i in range(N): + coord_dists = disArray[:, point_i, :] + # Compute maximum element per row. + # torch.max returns (values, indices). + dists, _ = torch.max(coord_dists, dim=1) + # torch.sort returns (sorted, indices). + ordered_dists, _ = torch.sort(dists) + # The original R implementation uses k+1 to access ordered_dists. However, + # python's indexing starts with 0 instead of 1. Therefore, the index k+1 in R + # corresponds to index k in python. + epsilonDisArray[point_i] = 2 * ordered_dists[k] + + return epsilonDisArray + + +def find_inter_cluster(eleInEachClass): + """Takes a list of scalars / lists of scalars and returns the intersection.""" + + interCluster = eleInEachClass[0] + for m in range(1, len(eleInEachClass)): + interCluster = intersection(interCluster, eleInEachClass[m]) + + # Output if zero if there is no intersection (= empty list). + if interCluster.nelement() == 0: + interCluster = torch.tensor(0.0) + + return interCluster + + +def con_entro_estimator(data, k, dN): + """Estimates entropy of quantitative data. + + Variables are represented by COLUMNS, observations by ROWS. + """ + + # If only one row, count the number of columns (= length of the vector). + if data.size()[0] == 1: + N = data.size()[1] + else: + N = data.size()[0] + if N == 1: + return 0 + + # Get distances. + distArray = get_dist_array(data) + epsilonDis = get_epsilon_distance(k, distArray) + + if 0 in epsilonDis: + epsilonDis = epsilonDis[epsilonDis != 0] + N = epsilonDis.nelement() + if N == 0: + return 0 + entropy = ( + -digamma(torch.tensor(k)) + + digamma(torch.tensor(N)) + + (dN * torch.sum(torch.log(epsilonDis))) / N + ) + return entropy + + # The maximum norm is used, so Cd=1, log(Cd)=0 + entropy = ( + -digamma(torch.tensor(k)) + + digamma(torch.tensor(N)) + + (dN * torch.sum(torch.log(epsilonDis))) / N + ) + return entropy + + +def mixed_entro_estimator(data, dimCon, dimDis, k=0.1): + """Estimates the entropie of the mixed variables with the given dimensions. + + Variables are represented by COLUMNS, observations by ROWS. 
+
+    Args:
+        k (float, optional): The neighborhood size is calculated as
+            max(1, round(k*#all_neighbors)). k should be < 1 and defaults to 0.1.
+    """
+
+    # Get the number of quantitative variables and the number of observations.
+    dN = len(dimCon)
+    N = data.size()[0]
+
+    # Split data into quantitative and qualitative data.
+    entroCon = 0
+    entroDis = 0
+    if len(dimCon) != 0:
+        dataCon = torch.index_select(data, 1, dimCon)
+    if len(dimDis) != 0:
+        dataDis = torch.index_select(data, 1, dimDis)
+    if len(dimCon) == 0 and len(dimDis) == 0:
+        # No input data; both entropy terms stay zero.
+        pass
+
+    # Calculate the entropy for the extracted data.
+    # If the data is purely quantitative:
+    if len(dimDis) == 0 and len(dimCon) != 0:
+        entroCon = con_entro_estimator(dataCon, max(1, round(k * N)), dN)
+
+    if len(dimDis) != 0:
+        # Histogram creation.
+        # We create a histogram that shows how often each possible combination
+        # of the values of the input variables X, Y, Z occurs. That is, we
+        # gather the unique values of X, Y and Z separately and then calculate
+        # all possible combinations of these values.
+        # Afterwards, we count the occurrence of each of these combinations
+        # and divide each count by the total number of observations.
+
+        # Unique elements -> Histogram bins.
+        classByDimList = [
+            torch.unique(column) for column in dataDis.transpose(dim0=0, dim1=1)
+        ]
+
+        # Create all possible combinations of the unique values of the columns
+        # of dataDis (we have already removed duplicate rows).
+        # Make sure that classList is a two-dimensional tensor, even for dataDis
+        # having just one column.
+        classList = torch.cartesian_prod(*classByDimList)
+        classList = unsqueeze_to_2d(classList)
+
+        # Create the histogram entries. That is, we count the occurrence of each
+        # of the histogram bins / possible row combinations of dataDis.
+        indexInClass = [0] * classList.size()[0]
+        for i in range(len(indexInClass)):
+            classElement = classList[i, :]
+
+            if classElement.size()[0] == 1:
+                # Get the indices where classElement is a row in dataDis.
+                # See: https://stackoverflow.com/questions/59705001/torch-find-indices-of-matching-rows-in-2-2d-tensors.
+                classElement = unsqueeze_to_2d(classElement)
+                indexInClass[i] = torch.where((classElement == dataDis).all(dim=1))[0]
+            else:
+                eleInEachClass = [0] * classElement.size()[0]
+                for m in range(len(eleInEachClass)):
+                    eleInEachClass[m] = torch.where(
+                        (
+                            unsqueeze_to_2d(classElement[m])
+                            == unsqueeze_to_2d(dataDis[:, m])
+                        ).all(dim=1)
+                    )[0]
+                indexInClass[i] = find_inter_cluster(eleInEachClass)
+
+        # Remove the empty bins and reverse the order (the authors of the
+        # original implementation reverse the order here as well).
+        indexInClass = [
+            hist_bin for hist_bin in indexInClass if not len(hist_bin.size()) == 0
+        ]
+        indexInClass.reverse()
+
+        # Entropy calculation.
+        # We calculate the entropy by treating the entries of the histogram
+        # as probabilities and using the following formula:
+        #     H = -\sum(p * log p)
+        # To treat the entries of the histogram as probabilities, we normalize
+        # with the total number of observations.
+        probBins = [hist_bin.nelement() / N for hist_bin in indexInClass]
+        for i in probBins:
+            entroDis = entroDis - i * math.log(i)
+        entroDis = torch.tensor(entroDis)
+
+    if len(dimDis) != 0 and len(dimCon) != 0:
+        # Unlike in the purely quantitative case, when both qualitative and
+        # quantitative variables are present, the calculation of the
+        # quantitative entropy depends on the previously calculated
+        # qualitative histogram.
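+        # Concretely, the loop below implements the decomposition
+        #     H(X) = H(X_dis) + \sum_b p_b * H(X_con | X_dis = b),
+        # where b ranges over the non-empty histogram bins computed above.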
+        # Since we want the result to be differentiable w.r.t. the quantitative
+        # variables, we have to make sure that this calculation is differentiable.
+        # Everything that interacts with dataCon has to be differentiable.
+        for i in range(len(probBins)):
+            data = unsqueeze_to_2d(dataCon)[indexInClass[i], :]
+            # The neighborhood size is not named k because that parameter already exists.
+            neighborhood_size = max(1, round(k * indexInClass[i].nelement()))
+            entroCon = (
+                entroCon
+                + con_entro_estimator(data=data, k=neighborhood_size, dN=dN)
+                * probBins[i]
+            )
+
+    # Put the quantitative and qualitative entropy together.
+    res = entroCon + entroDis
+
+    return res
+
+
+def _mixed_cmi_model(data, xind, yind, zind, is_categorical, k=0.1):
+    """Estimates the CMI from qualitative and quantitative data.
+
+    The conditional mutual information I(X;Y|Z) approaches 0 if X and Y are
+    conditionally independent given Z and grows large the stronger their
+    dependence is.
+
+    Method:
+        Zan, Lei; Meynaoui, Anouar; Assaad, Charles K.; Devijver, Emilie; Gaussier,
+        Eric (2022): A Conditional Mutual Information Estimator for Mixed Data and
+        an Associated Conditional Independence Test. In: Entropy (Basel, Switzerland)
+        24 (9). DOI: 10.3390/e24091234.
+    The implementation follows their R implementation licensed under the MIT license:
+        https://github.com/leizan/CMIh2022/blob/main/method.R
+
+    :param data: Observations of the variables x, y, and z. Variables are
+        represented by COLUMNS, observations by ROWS.
+    :param xind: One-dimensional tensor that contains a list of the indices
+        corresponding to the columns of data containing the observations of the
+        variable x.
+    :param yind: One-dimensional tensor that contains a list of the indices
+        corresponding to the columns of data containing the observations of the
+        variable y.
+    :param zind: One-dimensional tensor that contains a list of the indices
+        corresponding to the columns of data containing the observations of the
+        variable z.
+    :param is_categorical: One-dimensional tensor that contains a list of
+        the indices corresponding to the columns of data that contain qualitative
+        (= categorical) data. All other columns are expected to contain quantitative
+        data.
+    :param k: Neighborhood size for the KNN estimator. The neighborhood size
+        is calculated as max(1, round(k * #all_neighbors)).
+    :return: Estimate of the CMI I(X;Y|Z).
+    :rtype: torch.Tensor (scalar).
+    """
+
+    # data: Variables are represented by COLUMNS, observations by ROWS.
+    #
+    # xind: Tensor with the indices of the columns that contain the observations
+    #       of the variable x.
+    # yind: Tensor with the indices of the variable y.
+    # zind: Tensor with the indices of the variable z.
+    #
+    # is_categorical: List that contains the indices of all columns of data that
+    # contain qualitative (= categorical) data. All other columns contain
+    # quantitative values.
+
+    # All input variables describing indices should be 1D tensors of type int.
+    xind, yind, zind, is_categorical = [
+        torch.tensor(var) if isinstance(var, int) or isinstance(var, list) else var
+        for var in (xind, yind, zind, is_categorical)
+    ]
+    xind, yind, zind, is_categorical = [
+        var.unsqueeze(dim=0) if len(var.size()) == 0 else var
+        for var in (xind, yind, zind, is_categorical)
+    ]
+    xind, yind, zind, is_categorical = [
+        var.to(torch.int) for var in (xind, yind, zind, is_categorical)
+    ]
+
+    # Move the tensors to the correct device (note that data.get_device()
+    # assumes that data does not live on the CPU).
+    xind, yind, zind, is_categorical = [
+        var.to(data.get_device()) for var in (xind, yind, zind, is_categorical)
+    ]
+
+    # setdiff: setdiff(A, B) returns all elements of A that are not in B
+    # (without repetitions).
+    # ...Con = Indices of data for the quantitative columns of X
+    # ...Dis = Indices of data for the qualitative columns of X
+    xDimCon = setdiff(xind, is_categorical)
+    xDimDis = setdiff(xind, xDimCon)
+    yDimCon = setdiff(yind, is_categorical)
+    yDimDis = setdiff(yind, yDimCon)
+    zDimCon = setdiff(zind, is_categorical)
+    zDimDis = setdiff(zind, zDimCon)
+    xDimCon, xDimDis, yDimCon, yDimDis, zDimCon, zDimDis = unsqueeze_to_1d(
+        xDimCon, xDimDis, yDimCon, yDimDis, zDimCon, zDimDis
+    )
+
+    conXYZ = concatenate_1d(xDimCon, yDimCon, zDimCon)
+    disXYZ = concatenate_1d(xDimDis, yDimDis, zDimDis)
+    hXYZ = mixed_entro_estimator(data, conXYZ, disXYZ, k=k)
+
+    conXZ = concatenate_1d(xDimCon, zDimCon)
+    disXZ = concatenate_1d(xDimDis, zDimDis)
+    hXZ = mixed_entro_estimator(data, conXZ, disXZ, k=k)
+
+    conYZ = concatenate_1d(yDimCon, zDimCon)
+    disYZ = concatenate_1d(yDimDis, zDimDis)
+    hYZ = mixed_entro_estimator(data, conYZ, disYZ, k=k)
+
+    conZ = zDimCon
+    disZ = zDimDis
+    hZ = mixed_entro_estimator(data, conZ, disZ, k=k)
+
+    cmi = hXZ + hYZ - hXYZ - hZ
+
+    return cmi
+
+
+def mixed_cmi_model(
+    feature, output, target, feature_is_categorical, target_is_categorical
+):
+    """Estimates the CMI from qualitative and quantitative data.
+
+    The conditional mutual information I(X;Y|Z) approaches 0 if X and Y are
+    conditionally independent given Z and grows large the stronger their
+    dependence is.
+    Here, the resulting CMI is only differentiable w.r.t. the non-categorical
+    inputs (creating a histogram in a differentiable manner is not really
+    reasonable).
+    All input tensors have to be one-dimensional.
+
+    Method:
+        Zan, Lei; Meynaoui, Anouar; Assaad, Charles K.; Devijver, Emilie; Gaussier,
+        Eric (2022): A Conditional Mutual Information Estimator for Mixed Data and
+        an Associated Conditional Independence Test. In: Entropy (Basel, Switzerland)
+        24 (9). DOI: 10.3390/e24091234.
+    The implementation follows their R implementation and was adapted for
+    differentiability w.r.t. the quantitative variables:
+        https://github.com/leizan/CMIh2022/blob/main/method.R
+    """
+
+    if any([z.dim() != 1 for z in (feature, output, target)]):
+        raise ValueError("All input tensors have to be one-dimensional!")
+
+    # Merge into one 2D array.
+    # Variables are represented by columns.
+    data = cbind(feature, output, target)
+
+    # Prepare the indices and record which variables are categorical.
+    is_categorical = []
+    if feature_is_categorical:
+        is_categorical.append(0)
+    if target_is_categorical:
+        is_categorical.append(2)
+    xind, yind, zind = [0], [1], [2]
+
+    return _mixed_cmi_model(data, xind, yind, zind, is_categorical)
diff --git a/tensor_utils.py b/tensor_utils.py
new file mode 100644
index 0000000..cf95b1f
--- /dev/null
+++ b/tensor_utils.py
@@ -0,0 +1,107 @@
+import torch
+
+def difference(t1, t2):
+    """Returns all elements that occur in exactly one of t1 and t2
+    (symmetric difference).
+    """
+
+    t1, t2 = t1.unique(), t2.unique()
+    combined = torch.cat((t1, t2))
+    uniques, counts = combined.unique(return_counts=True)
+    diff = uniques[counts == 1]
+
+    return diff
+
+def intersection(t1, t2):
+    """Returns all elements that are in both t1 and t2.
+    """
+
+    t1, t2 = t1.unique(), t2.unique()
+    combined = torch.cat((t1, t2))
+    uniques, counts = combined.unique(return_counts=True)
+    intersec = uniques[counts > 1]
+
+    return intersec
+
+def setdiff(t1, t2):
+    """Returns all elements of tensor t1 that are not in tensor t2.
+    """
+
+    diff = difference(t1, t2)
+    diff_from_t1 = intersection(diff, t1)
+
+    return diff_from_t1
+
+def concatenate_1d(*tensors):
+    """Concatenates the given 1D tensors.
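+
+    For example, concatenate_1d(torch.tensor([0]), torch.tensor([1, 2]))
+    returns tensor([0, 1, 2]).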
+ """ + + for tensor in tensors: + if len(tensor.size()) != 1: + raise ValueError("Can only concatenate 1d tensors. Otherwise, use rbind / cbind.") + + return torch.cat(tensors, 0) + +def cbind(*tensors): + """Combines the given 2d tensors as columns. + """ + + # If a vector is one-dimensional, convert it to a two-dimensional column + # vector. + tensors = [unsqueeze_to_2d(var) if len(var.size()) == 1 else var for var in tensors] + + return torch.cat(tensors, 1) + +def rbind(*tensors): + """Combines the given 2d tensors as rows. + """ + + for tensor in tensors: + if len(tensor.size()) < 2: + raise ValueError("rbind only takes two-dimensional tensors as input") + + return torch.cat(tensors, 0) + +def unsqueeze_to_1d(*tensors): + """ Unsqueezes zero-dimensional tensors to one-dimensional tensors. + """ + + if len(tensors) > 1: + return [var.unsqueeze(dim=0) if len(var.size()) == 0 else var for var in tensors] + else: + return tensors[0].unsqueeze(dim=0) if len(tensors[0].size()) == 0 else tensors[0] + +def unsqueeze_to_2d(*tensors): + """ Unsqueezes one-dimensional tensors two-dimensional tensors. + """ + + if len(tensors) > 1: + return [var.unsqueeze(dim=1) if len(var.size()) == 1 else var for var in tensors] + else: + return tensors[0].unsqueeze(dim=1) if len(tensors[0].size()) == 1 else tensors[0] + +def convert_type(torch_type, *tensors): + """ Converts all given tensors to tensors of the given type. + """ + + if len(tensors) > 1: + return [var.to(torch_type) for var in tensors] + else: + return tensors[0].to(torch_type) + +def shuffle(t, mode=None): + """ Shuffles the rows of the given tensor. + """ + + if mode == "within_columns": + rand_indices = torch.randperm(len(t)) + t = t[rand_indices] + elif mode == "within_rows": + t = t.transpose(dim0=0, dim1=1) + rand_indices = torch.randperm(len(t)) + t = t[rand_indices] + t = t.transpose(dim0=0, dim1=1) + else: + rand_indices = torch.randperm(t.nelement()) + t = t.view(-1)[rand_indices].view(t.size()) + + return t