Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Clustering.DBSCAN to interface #11

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,11 @@ version = "0.1.4"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"

[compat]
Clustering = "0.14"
Distances = "0.9, 0.10"
MLJModelInterface = "0.3.6, 0.4, 1.0"
NearestNeighbors = "0.4"
julia = "1"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "MLJBase", "Random", "Test"]
101 changes: 99 additions & 2 deletions src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import MLJModelInterface: Continuous, Count, Finite, Multiclass, Table, OrderedF
@mlj_model, metadata_model, metadata_pkg

using Distances
using NearestNeighbors

# ===================================================================
## EXPORTS
export KMeans, KMedoids
export KMeans, KMedoids, DBSCAN

# ===================================================================
## CONSTANTS
Expand All @@ -34,13 +35,26 @@ K-Medoids algorithm: find K centroids corresponding to K clusters in the data.
Unlike K-Means, the centroids are found among data points themselves.
"""

# Shared docstring fragment: one-line description of the DBSCAN algorithm,
# interpolated into the `DBSCAN` docstring and `metadata_model` below.
const DBSCANDescription ="""
DBSCAN algorithm: find clusters through density-based expansion of seed points.
"""

const KMFields ="""
## Keywords

* `k=3` : number of centroids
* `metric` : distance metric to use
"""

# Shared docstring fragment: keyword documentation for the `DBSCAN` model,
# interpolated into its docstring (mirrors `KMFields` for KMeans/KMedoids).
const DBFields ="""
## Keywords

* `radius=1.0` : query radius
* `leafsize=20` : number of points binned in each leaf node
* `min_neighbors=1` : minimum number of core point neighbors
* `min_cluster_size=1` : minimum number of points in a valid cluster
"""

const PKG = "MLJClusteringInterface"

####
Expand All @@ -56,7 +70,6 @@ $KMFields
See also the
[package documentation](http://juliastats.github.io/Clustering.jl/latest/kmeans.html).
"""

@mlj_model mutable struct KMeans <: MMI.Unsupervised
k::Int = 3::(_ ≥ 2)
metric::SemiMetric = SqEuclidean()
Expand Down Expand Up @@ -162,6 +175,81 @@ function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
return cluster_labels[pred]
end

####
#### DBSCAN
####
"""
DBSCAN(; kwargs...)

$DBSCANDescription

$DBFields

See also the
[package documentation](https://juliastats.org/Clustering.jl/stable/dbscan.html).
"""
# Hyperparameter container for DBSCAN. The `@mlj_model` macro generates a
# keyword constructor; each `::(_ > 0)` guard rejects non-positive values at
# construction time. Fields mirror the keywords of `Clustering.dbscan`.
@mlj_model mutable struct DBSCAN <: MMI.Unsupervised
radius::Real = 1.0::(_ > 0)  # neighborhood query radius
leafsize::Int = 20::(_ > 0)  # points binned per KD-tree leaf node
min_neighbors::Int = 1::(_ > 0)  # neighbors required for a core point
min_cluster_size::Int = 1::(_ > 0)  # minimum points in a valid cluster
end

"""
    MMI.fit(model::DBSCAN, verbosity::Int, X)

Cluster the table `X` with `Clustering.dbscan` and return the triple
`(fitresult, cache, report)`, where `fitresult` bundles the transposed
data matrix, the per-point cluster assignments (`0` for noise) and the
per-point types (`1` = core, `0` = boundary/noise); `cache` and `report`
are both `nothing`.
"""
function MMI.fit(model::DBSCAN, verbosity::Int, X)
    # Clustering.jl expects observations as columns, hence the transpose.
    data = MMI.matrix(X, transpose=true)
    found = Cl.dbscan(
        data, model.radius;
        leafsize=model.leafsize,
        min_neighbors=model.min_neighbors,
        min_cluster_size=model.min_cluster_size,
    )

    # Flatten the vector of clusters into per-point labels and types.
    nobs = size(data, 2)
    assignments = zeros(Int, nobs)  # 0 remains for points in no cluster
    pointtypes = zeros(Int, nobs)
    for (label, cluster) in enumerate(found)
        # Core points first, then boundary points, as the clusterer reports them.
        assignments[cluster.core_indices] .= label
        pointtypes[cluster.core_indices] .= 1
        assignments[cluster.boundary_indices] .= label
        pointtypes[cluster.boundary_indices] .= 0
    end

    fitresult = (data, assignments, pointtypes)
    return fitresult, nothing, nothing
end

# Expose the learned assignments and point types as a named tuple.
# NOTE(review): the extra `[1]` implies `fitresult` here is the *full* 3-tuple
# returned by `MMI.fit` rather than its first element, which departs from the
# usual MMI convention — confirm against how callers invoke this.
function MMI.fitted_params(::DBSCAN, fitresult)
    _, assignments, pointtypes = fitresult[1]
    return (assignments=assignments, pointtypes=pointtypes)
end

"""
    MMI.transform(::DBSCAN, fitresult, X)

Return a two-column table: cluster assignments in the first column and
point types (core = 1, boundary/noise = 0) in the second.
"""
function MMI.transform(::DBSCAN, fitresult, X)
    _, assignments, pointtypes = fitresult[1]
    return MMI.table(hcat(assignments, pointtypes), prototype=X)
end

"""
    MMI.predict(::DBSCAN, fitresult, Xnew)

Label each observation of `Xnew` with the (categorical) cluster assignment
of its nearest neighbor among the training points, found via a KD-tree.
"""
function MMI.predict(::DBSCAN, fitresult, Xnew)
    train, assignments, _ = fitresult[1]
    query = MMI.matrix(Xnew, transpose=true)

    # Categorical pool over the training assignments; each query point
    # inherits the label of its nearest training point.
    labels = MMI.categorical(assignments)
    kdtree = KDTree(train, Euclidean())
    nearest, _ = nn(kdtree, query)
    return labels[nearest]
end

####
#### METADATA
####
Expand Down Expand Up @@ -194,5 +282,14 @@ metadata_model(
path = "$(PKG).KMedoids"
)

# Register trait metadata for DBSCAN with MLJ.
metadata_model(
    DBSCAN,
    input = MMI.Table(Continuous),
    # `transform` returns integer assignment and point-type columns, whose
    # scitype is Count — declaring Continuous here would misreport the output.
    output = MMI.Table(Count),
    weights = false,
    descr = DBSCANDescription,
    path = "$(PKG).DBSCAN"
)

end # module

7 changes: 7 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
52 changes: 43 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@ import LinearAlgebra: norm

using MLJBase
using MLJClusteringInterface
using Random:seed!
using Random: seed!
using Test

const Dist = Distances

seed!(132442)
X, y = @load_crabs

####
#### KMEANS
####

@testset "Kmeans" begin
@testset "KMeans" begin
barekm = KMeans()
fitresult, cache, report = fit(barekm, 1, X)
R = matrix(transform(barekm, fitresult, X))
Expand All @@ -28,25 +26,61 @@ X, y = @load_crabs
p = predict(barekm, fitresult, X)
@test argmin(R[1, :]) == p[1]
@test argmin(R[10, :]) == p[10]


end

####
#### KMEDOIDS
####

@testset "Kmedoids" begin
@testset "KMedoids" begin
barekm = KMedoids()
fitresult, cache, report = fit(barekm, 1, X)
X_array = matrix(X)
R = matrix(transform(barekm, fitresult, X))
@test R[1, 2] ≈ Dist.evaluate(
@test R[1, 2] ≈ Distances.evaluate(
barekm.metric, view(X_array, 1, :), view(fitresult[1], :, 2)
)
@test R[10, 3] ≈ Dist.evaluate(
@test R[10, 3] ≈ Distances.evaluate(
barekm.metric, view(X_array, 10, :), view(fitresult[1], :, 3)
)
p = predict(barekm, fitresult, X)
@test all(report.assignments .== p)
end

@testset "DBSCAN" begin
# five spot pattern
X = [
0.0 0.0
1.0 0.0
1.0 1.0
0.0 1.0
0.5 0.5
]

# NOTE(review): `fit` is not destructured here — the whole
# `(fitresult, cache, report)` tuple is forwarded to transform/predict,
# which only works because those methods index `fitresult[1]`; this
# departs from the MMI convention used by the KMedoids test — confirm.

# radius < √2 ==> 5 clusters
dbscan = DBSCAN(radius=0.1)
fitresult = fit(dbscan, 1, X)
A = transform(dbscan, fitresult, X)
p = predict(dbscan, fitresult, X)
@test size(matrix(A)) == (5, 2)
# second column = point types: no point has enough neighbors to be core
@test A.x2 == [0,0,0,0,0]
# every point lands in its own cluster, so predictions are all distinct
@test Set(p) == Set(unique(p))

# radius > √2 ==> 1 cluster
dbscan = DBSCAN(radius=√2+eps())
fitresult = fit(dbscan, 1, X)
A = transform(dbscan, fitresult, X)
p = predict(dbscan, fitresult, X)
@test size(matrix(A)) == (5, 2)
# all points are core points of the single cluster
@test A.x2 == [1,1,1,1,1]
@test unique(p) == [1]

# radius < √2 && min_cluster_size = 2 ==> all points are noise
dbscan = DBSCAN(radius=0.1, min_cluster_size=2)
fitresult = fit(dbscan, 1, X)
A = transform(dbscan, fitresult, X)
p = predict(dbscan, fitresult, X)
@test size(matrix(A)) == (5, 2)
@test A.x2 == [0,0,0,0,0]
# assignment 0 is the noise label
@test unique(p) == [0]
end