Skip to content

Commit

Permalink
Preventing custom clustering being overwritten with kmeans (#44)
Browse files Browse the repository at this point in the history
* preventing custom clustering being overwritten with kmeans

* bumping version number

* bug fix
  • Loading branch information
htjb authored Oct 16, 2023
1 parent 8e51f8c commit 91258ed
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction

:margarine: Marginal Bayesian Statistics
:Authors: Harry T.J. Bevins
:Version: 1.1.2
:Version: 1.1.3
:Homepage: https://github.com/htjb/margarine
:Documentation: https://margarine.readthedocs.io/

Expand Down
45 changes: 28 additions & 17 deletions margarine/clustered.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ def __init__(self, theta, **kwargs):

kmeans = KMeans(self.cluster_number, random_state=0, n_init='auto')
self.cluster_labels = kmeans.fit(self.theta).predict(self.theta)
self.custom_cluster = False
else:
self.custom_cluster = True

if self.cluster_number == 20:
warnings.warn("The number of clusters is 20. " +
Expand All @@ -184,24 +187,32 @@ def __init__(self, theta, **kwargs):
self.cluster_labels = self.cluster_labels.astype(int)
# count the number of times a cluster label appears in cluster_labels
self.cluster_count = np.bincount(self.cluster_labels)
# While loop to make sure clusters are not too small
while self.cluster_count.min() < 100:
warnings.warn("One or more clusters are too small " +

if self.custom_cluster:
if self.cluster_count.min() < 100:
warnings.warn("One or more clusters are too small " +
"(n_cluster < 100). " +
"Reducing the number of clusters by 1.")
minimum_index -= 1
self.cluster_number = ks[minimum_index]
kmeans = KMeans(self.cluster_number, random_state=0, n_init='auto')
self.cluster_labels = kmeans.fit(self.theta).predict(self.theta)
self.cluster_count = np.bincount(self.cluster_labels)
if self.cluster_number == 2:
# break if two clusters
warnings.warn("The number of clusters is 2. This is the " +
"minimum number of clusters that can be used. " +
"Some clusters may be too small and the " +
"train/test split may fail." +
"Try running without clusting. ")
break
"Since cluster_number was supplied margarine" +
"will continue but may crash.")
else:
# While loop to make sure clusters are not too small
while self.cluster_count.min() < 100:
warnings.warn("One or more clusters are too small " +
"(n_cluster < 100). " +
"Reducing the number of clusters by 1.")
minimum_index -= 1
self.cluster_number = ks[minimum_index]
kmeans = KMeans(self.cluster_number, random_state=0, n_init='auto')
self.cluster_labels = kmeans.fit(self.theta).predict(self.theta)
self.cluster_count = np.bincount(self.cluster_labels)
if self.cluster_number == 2:
# break if two clusters
warnings.warn("The number of clusters is 2. This is the " +
"minimum number of clusters that can be used. " +
"Some clusters may be too small and the " +
"train/test split may fail." +
"Try running without clusting. ")
break

split_theta = []
split_sample_weights = []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def readme(short=False):

setup(
name='margarine',
version='1.1.2',
version='1.1.3',
description='margarine: Posterior Sampling and Marginal Bayesian Statistics',
long_description=readme(),
author='Harry T. J. Bevins',
Expand Down

0 comments on commit 91258ed

Please sign in to comment.