Skip to content

Commit

Permalink
Merge pull request #25 from JuliaAI/clustering-0-15-compat
Browse files Browse the repository at this point in the history
Bump compat for Clustering to 0.15
  • Loading branch information
ablaom authored Apr 26, 2023
2 parents 7bae312 + 35e8d7e commit fb174d9
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
name = "MLJClusteringInterface"
uuid = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
version = "0.1.9"
version = "0.1.10"

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"

[compat]
Clustering = "0.14"
Clustering = "0.15"
Distances = "0.9, 0.10"
MLJModelInterface = "1.4"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function MMI.predict(model::DBSCAN, ::Nothing, X)
leafsize=model.leafsize,
min_neighbors=model.min_neighbors,
min_cluster_size=model.min_cluster_size,
)
).clusters
nclusters = length(clusters)

# assignments and point types
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
36 changes: 33 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import LinearAlgebra: norm
using MLJBase
using MLJTestInterface
using MLJClusteringInterface
using Random: seed!
using StableRNGs
using Random
using Test

seed!(132442)
srng(n) = StableRNGs.StableRNG(n)
X, y = @load_crabs


Expand Down Expand Up @@ -59,6 +60,35 @@ end

@testset "DBSCAN" begin

# test interface is faithful on some synthetic data:
X, _ = make_moons(500, rng=srng(123))
dbscan = DBSCAN(radius=0.15, min_neighbors=3)
yhat3, _report = predict(dbscan, nothing, X)
Xarray = MLJBase.matrix(X)'
clusters = Clustering.dbscan(
Xarray,
dbscan.radius;
leafsize=dbscan.leafsize,
min_neighbors=dbscan.min_neighbors,
min_cluster_size=dbscan.min_cluster_size,
).clusters
noisy_indices = Set(1:length(y)) # initialization
for (k, cluster) in enumerate(clusters)
for i in cluster.core_indices
delete!(noisy_indices, i)
@test _report.point_types[i] == 'C'
@test yhat3[i] == k
end
for i in cluster.boundary_indices
delete!(noisy_indices, i)
@test _report.point_types[i] == 'B'
@test yhat3[i] == k
end
end
@test all(noisy_indices) do i
_report.point_types[i] == 'N' && yhat3[i] == 0
end

# five spot pattern
X = [
0.0 0.0
Expand All @@ -72,7 +102,7 @@ end
dbscan = DBSCAN(radius=0.1)
yhat1, report1 = predict(dbscan, nothing, X)
@test report1.nclusters == 5
@test report1.point_types == fill('B', 5)
@test report1.point_types == fill('C', 5)
@test Set(yhat1) == Set(unique(yhat1))
@test Set(report1.cluster_labels) == Set(unique(yhat1))

Expand Down

0 comments on commit fb174d9

Please sign in to comment.