euxhenh/ConstrainedKMeans

Latest commit: f9b2bfe by ferrocactus, Aug 11, 2020 (7 commits)

Constrained KMeans

A modified version of the KMeans algorithm that takes partial label information about the data into account.

Given a partial list of known labels init_labels, Constrained KMeans finds a cluster configuration that complies with init_labels. The array init_labels has length x.shape[0], one entry per sample, so a second array can_change of the same length marks which labels are known and must stay fixed (can_change == 0) and which are free to change (1 in the example below). Formally, the output of the algorithm is an array labels such that np.all(labels[can_change == 0] == init_labels[can_change == 0]) is True.
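The guarantee above can be checked with plain NumPy. The snippet below is a minimal sketch that does not use the package itself; the helper name satisfies_constraints and the toy arrays are hypothetical and only illustrate the mask semantics.

```python
import numpy as np

def satisfies_constraints(labels, init_labels, can_change):
    """Return True if every fixed point (can_change == 0) kept its initial label.

    Hypothetical helper for illustration; it is not part of ConstrainedKMeans.
    """
    labels = np.asarray(labels)
    init_labels = np.asarray(init_labels)
    fixed = np.asarray(can_change) == 0
    return bool(np.all(labels[fixed] == init_labels[fixed]))

# Toy example: the first three points are fixed, the last two are free.
init_labels = np.array([0, 0, 1, 5, 5])  # 5 is a placeholder for "unknown"
can_change = np.array([0, 0, 0, 1, 1])

ok_labels = np.array([0, 0, 1, 1, 0])    # only free points were reassigned
bad_labels = np.array([0, 1, 1, 1, 0])   # a fixed point changed its label

print(satisfies_constraints(ok_labels, init_labels, can_change))
print(satisfies_constraints(bad_labels, init_labels, can_change))
```

The values stored in init_labels at free positions never enter the check, which is why a placeholder label can be used there.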

Install with pip (requires Python>=3.7):

pip install ConstrainedKMeans

Example basic usage:

import numpy as np
from matplotlib import pyplot as plt

from ConstrainedKMeans import ConstrainedKMeans as CKM

def run_test():
    # Generate a random dataset of 5 Gaussian clusters whose labels are known
    ppt = 100    # labeled points per cluster
    free = 1000  # number of unlabeled (free) points
    x = []
    init_labels = []
    can_change = []
    var = np.eye(2)
    for i, mu in enumerate([[1, 1], [8, 8], [1, 8], [8, 1], [4, 4]]):
        x.append(np.random.multivariate_normal(mu, var, ppt))
        init_labels.append(np.full(ppt, i))
        can_change.append(np.zeros(ppt))

    # Add some free points (shown in yellow in the plot); label 5 is only a
    # placeholder, and can_change = 1 marks these points as free to be relabeled
    x.append(np.random.multivariate_normal([4, 4], np.eye(2) * 8, free))
    init_labels.append(np.full(free, 5))
    can_change.append(np.ones(free))

    x = np.vstack(x)
    init_labels = np.hstack(init_labels)
    can_change = np.hstack(can_change)

    fig, axes = plt.subplots(1, 2, sharey=True, squeeze=False)

    # Plot before
    axes[0, 0].scatter(x[:, 0], x[:, 1], c=init_labels, s=10)
    axes[0, 0].set_title("Before clustering (yellow points are free)")

    # Fit labels
    ckm = CKM(n_clusters=5)
    ckm.fit(x, can_change, init_labels)
    labels = ckm.get_labels()

    # Plot after
    axes[0, 1].scatter(x[:, 0], x[:, 1], c=labels, s=10)
    axes[0, 1].set_title("After clustering")
    plt.show()

if __name__ == '__main__':
    run_test()
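For intuition about how fixed labels can be honored during clustering, here is one possible approach written in plain NumPy. It is an illustrative sketch under stated assumptions, not the package's implementation: the function name constrained_kmeans_sketch, the centroid initialization from the fixed points, and the fixed iteration count are all hypothetical choices.

```python
import numpy as np

def constrained_kmeans_sketch(x, can_change, init_labels, n_clusters, n_iter=20):
    """Lloyd-style iterations in which fixed points never change label.

    Illustrative only; assumes every cluster has at least one fixed point.
    """
    x = np.asarray(x, dtype=float)
    init_labels = np.asarray(init_labels)
    free = np.asarray(can_change) != 0
    labels = init_labels.copy()

    # Assumption: initial centroids are the means of the fixed points per cluster.
    centroids = np.vstack([
        x[~free & (init_labels == k)].mean(axis=0) for k in range(n_clusters)
    ])

    x_free = x[free]
    for _ in range(n_iter):
        # Assign each free point to its nearest centroid; fixed points are skipped.
        dists = np.linalg.norm(x_free[:, None, :] - centroids[None, :, :], axis=2)
        labels[free] = dists.argmin(axis=1)
        # Recompute centroids from all points (fixed and free) in each cluster.
        centroids = np.vstack([
            x[labels == k].mean(axis=0) for k in range(n_clusters)
        ])
    return labels

# Toy data: two labeled blobs plus unlabeled points spread between them.
rng = np.random.default_rng(0)
fixed_a = rng.normal([0, 0], 0.5, size=(20, 2))
fixed_b = rng.normal([5, 5], 0.5, size=(20, 2))
free_pts = rng.normal([2.5, 2.5], 2.0, size=(50, 2))

x = np.vstack([fixed_a, fixed_b, free_pts])
init_labels = np.hstack([np.full(20, 0), np.full(20, 1), np.full(50, 2)])  # 2 = placeholder
can_change = np.hstack([np.zeros(20), np.zeros(20), np.ones(50)])

labels = constrained_kmeans_sketch(x, can_change, init_labels, n_clusters=2)
```

Because only the entries selected by the free mask are ever reassigned, the constraint on the fixed labels holds by construction, and the placeholder label disappears after the first assignment step.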
