Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial commit for implementation of Affinity Propagation #28

Merged
merged 13 commits into from
Dec 10, 2024
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ version = "0.1.11"
[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Clustering = "0.15"
Distances = "0.9, 0.10"
LinearAlgebra = "1.11.0"
MLJModelInterface = "1.4"
StatsBase = "0.34.3"
Yuan-Ru-Lin marked this conversation as resolved.
Show resolved Hide resolved
julia = "1.6"
139 changes: 136 additions & 3 deletions src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
@mlj_model, metadata_model, metadata_pkg

using Distances
using LinearAlgebra
using StatsBase

# ===================================================================
## EXPORTS
export KMeans, KMedoids, DBSCAN, HierarchicalClustering
export KMeans, KMedoids, AffinityPropagation, DBSCAN, HierarchicalClustering

# ===================================================================
## CONSTANTS
Expand Down Expand Up @@ -95,7 +97,6 @@
return MMI.table(X̃, prototype=X)
end


# # PREDICT FOR K_MEANS AND K_MEDOIDS

function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
Expand Down Expand Up @@ -208,10 +209,66 @@

MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)

# # AFFINITY_PROPAGATION

@mlj_model mutable struct AffinityPropagation <: MMI.Static
damp::Float64 = 0.5::(0.0 ≤ _ < 1.0)
maxiter::Int = 200::(_ > 0)
tol::Float64 = 1e-6::(_ > 0)
preference::Union{Nothing,Float64} = nothing
metric::SemiMetric = SqEuclidean()
end

function MMI.predict(model::AffinityPropagation, ::Nothing, X)
Xarray = MMI.matrix(X)'

Check warning on line 223 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L222-L223

Added lines #L222 - L223 were not covered by tests

# Compute similarity matrix using negative pairwise distances
S = -pairwise(model.metric, Xarray, dims=2)

Check warning on line 226 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L226

Added line #L226 was not covered by tests

diagonal_element = if !isnothing(model.preference)
model.preference

Check warning on line 229 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L228-L229

Added lines #L228 - L229 were not covered by tests
else
# Get the median out of all pairs of similarity, that is, values above
# the diagonal line.
# Such default choice is mentioned in the algorithm's wiki article
iuppertri = triu!(trues(size(S)),1)
median(S[iuppertri])

Check warning on line 235 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L234-L235

Added lines #L234 - L235 were not covered by tests
end

fill!(view(S, diagind(S)), diagonal_element)

Check warning on line 238 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L238

Added line #L238 was not covered by tests

result = Cl.affinityprop(

Check warning on line 240 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L240

Added line #L240 was not covered by tests
S,
maxiter=model.maxiter,
tol=model.tol,
damp=model.damp
)

# Get number of clusters and labels
exemplars = result.exemplars
k = length(exemplars)
cluster_labels = MMI.categorical(1:k)

Check warning on line 250 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L248-L250

Added lines #L248 - L250 were not covered by tests

# Store exemplar points as centers (similar to KMeans/KMedoids)
centers = view(Xarray, :, exemplars)

Check warning on line 253 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L253

Added line #L253 was not covered by tests

report = (

Check warning on line 255 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L255

Added line #L255 was not covered by tests
exemplars=exemplars,
centers=centers,
cluster_labels=cluster_labels,
iterations=result.iterations,
converged=result.converged
)

return MMI.categorical(result.assignments), report

Check warning on line 263 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L263

Added line #L263 was not covered by tests
end

MMI.reporting_operations(::Type{<:AffinityPropagation}) = (:predict,)

Check warning on line 266 in src/MLJClusteringInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/MLJClusteringInterface.jl#L266

Added line #L266 was not covered by tests

# # METADATA

metadata_pkg.(
(KMeans, KMedoids, DBSCAN, HierarchicalClustering),
(KMeans, KMedoids, DBSCAN, HierarchicalClustering, AffinityPropagation),
name="Clustering",
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
url="https://github.com/JuliaStats/Clustering.jl",
Expand Down Expand Up @@ -251,6 +308,13 @@
path = "$(PKG).HierarchicalClustering"
)

metadata_model(
AffinityPropagation,
human_name = "Affinity Propagation clusterer",
input_scitype = MMI.Table(Continuous),
path = "$(PKG).AffinityPropagation"
)

"""
$(MMI.doc_header(KMeans))

Expand Down Expand Up @@ -614,4 +678,73 @@
"""
HierarchicalClustering

"""
$(MMI.doc_header(AffinityPropagation))

[Affinity Propagation](https://en.wikipedia.org/wiki/Affinity_propagation) is a clustering algorithm based on the concept of "message passing" between data points. More information is available at the [Clustering.jl documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to get cluster assignments. Indices of the exemplars, their values, etc, are accessed from the machine report (see below).

This is a static implementation, i.e., it does not generalize to new data instances, and
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
[`KMedoids`](@ref).

In MLJ or MLJBase, create a machine with

mach = machine(model)

# Hyper-parameters

- `damp = 0.5`: damping factor

- `maxiter = 200`: maximum number of iteration

- `tol = 1e-6`: tolerance for converenge

- `preference = nothing`: the value of the diagonal elements of the similarity matrix
Yuan-Ru-Lin marked this conversation as resolved.
Show resolved Hide resolved

- `metric = SqEuclidean`: metric (see `Distances.jl` for available metrics)
Yuan-Ru-Lin marked this conversation as resolved.
Show resolved Hide resolved

# Operations

- `predict(mach, X)`: return cluster label assignments, as an unordered
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
columns are of scitype `Continuous`; check column scitypes with `schema(X)`.

# Report

After calling `predict(mach)`, the fields of `report(mach)` are:

- exemplars: indices of the data picked as exemplars in `X`

- centers: positions of the exemplars in the feature space

- cluster_labels: labels of clusters given to each datum in `X`

- iterations: the number of iteration run by the algorithm

- converged: whether or not the algorithm converges by the maximum iteration

# Examples

```
using MLJ, MLJClusteringInterface
Yuan-Ru-Lin marked this conversation as resolved.
Show resolved Hide resolved

X, labels = make_moons(400, noise=0.9, rng=1)

AffinityPropagation = @load AffinityPropagation pkg=Clustering
model = AffinityPropagation(preference=-10.0)
mach = machine(model)

# compute and output cluster assignments for observations in `X`:
yhat = predict(mach, X)

# Get the positions of the exemplars
report(mach).centers

# Plot clustering result
using GLMakie
scatter(MLJ.matrix(X)', color=yhat.refs)
```
"""
AffinityPropagation

end # module
Loading