Skip to content

Commit

Permalink
Merge pull request #23 from JuliaAI/adjoint-not-transpose
Browse files Browse the repository at this point in the history
Use adjoints not transposes
  • Loading branch information
ablaom authored Dec 6, 2022
2 parents 1a2c1e7 + 8fd563a commit 7bae312
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJClusteringInterface"
uuid = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
authors = ["Anthony D. Blaom <[email protected]>", "Thibaut Lienart <[email protected]>", "Okon Samuel <[email protected]>"]
version = "0.1.8"
version = "0.1.9"

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Expand Down
12 changes: 4 additions & 8 deletions src/MLJClusteringInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ const PKG = "MLJClusteringInterface"
end

function MMI.fit(model::KMeans, verbosity::Int, X)
# NOTE: using transpose here to get a LinearAlgebra.Transpose object
# which Kmeans can handle.
Xarray = transpose(MMI.matrix(X))
Xarray = MMI.matrix(X)'
result = Cl.kmeans(Xarray, model.k; distance=model.metric, init=model.init)
cluster_labels = MMI.categorical(1:model.k)
fitresult = (result.centers, cluster_labels) # centers (p x k)
Expand All @@ -55,7 +53,7 @@ function MMI.transform(model::KMeans, fitresult, X)
# pairwise distance from samples to centers
= pairwise(
model.metric,
transpose(MMI.matrix(X)),
MMI.matrix(X)',
fitresult[1],
dims=2
)
Expand All @@ -71,9 +69,7 @@ end
end

function MMI.fit(model::KMedoids, verbosity::Int, X)
# NOTE: using transpose=true will materialize the transpose (~ permutedims), KMedoids
# does not yet accept LinearAlgebra.Transpose
Xarray = MMI.matrix(X, transpose=true)
Xarray = MMI.matrix(X)'
# cost matrix: all the pairwise distances
cost_array = pairwise(model.metric, Xarray, dims=2) # n x n
result = Cl.kmedoids(cost_array, model.k, init = model.init)
Expand All @@ -93,7 +89,7 @@ function MMI.transform(model::KMedoids, fitresult, X)
# pairwise distance from samples to medoids
= pairwise(
model.metric,
MMI.matrix(X, transpose=true),
MMI.matrix(X)',
fitresult[1], dims=2
)
return MMI.table(X̃, prototype=X)
Expand Down
5 changes: 1 addition & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestIntegration = "697918b4-fdc1-4f9e-8ff9-929724cee270"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
MLJTestIntegration = "0.2.2"
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Distances
import LinearAlgebra: norm

using MLJBase
using MLJTestIntegration
using MLJTestInterface
using MLJClusteringInterface
using Random: seed!
using Test
Expand Down Expand Up @@ -122,11 +122,11 @@ end

@testset "MLJ interface" begin
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
failures, summary = MLJTestIntegration.test(
failures, summary = MLJTestInterface.test(
models,
X;
mod=@__MODULE__,
verbosity=0,
verbosity=0, # bump to debug
throw=false, # set to true to debug
)
@test isempty(failures)
Expand Down

0 comments on commit 7bae312

Please sign in to comment.