From 5b6d0a324572e71699cb3a0e5c31638261cdc64f Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:32:07 -0700 Subject: [PATCH 01/14] Fix #41, add tests --- CondaPkg.toml | 1 - Project.toml | 8 ++++- src/MLJCatBoostInterface.jl | 6 ++-- test/mlj_interface.jl | 58 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++ 5 files changed, 71 insertions(+), 5 deletions(-) diff --git a/CondaPkg.toml b/CondaPkg.toml index fe2aabd..e9eca30 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -1,3 +1,2 @@ [deps] -numpy = ">=1,<2" catboost = ">=1.1" diff --git a/Project.toml b/Project.toml index cdd53f3..0392299 100644 --- a/Project.toml +++ b/Project.toml @@ -15,8 +15,11 @@ DataFrames = "1.6" MLJBase = "1" MLJModelInterface = "1.7" MLJTestInterface = "0.2.6" +MLJTuning = "0.8" OrderedCollections = "1.6" PythonCall = "0.9" +StatisticalMeasures = "0.1" +Statistics = "1.6" Tables = "1.10" Test = "1.6" julia = "1.6" @@ -26,7 +29,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" +MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f" +StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "Test"] +test = ["Aqua", "DataFrames", "MLJBase", "MLJTestInterface", "MLJTuning", "Statistics", "StatisticalMeasures", "Test"] diff --git a/src/MLJCatBoostInterface.jl b/src/MLJCatBoostInterface.jl index 919a9d8..aea39b9 100644 --- a/src/MLJCatBoostInterface.jl +++ b/src/MLJCatBoostInterface.jl @@ -140,9 +140,9 @@ end ##### function MMI.update(mlj_model::CatBoostModels, verbosity::Integer, fitresult, cache, - data_pool) + data_pool) if mlj_model.iterations > cache.mlj_model.iterations && - MMI.is_same_except(mlj_model, cache.mlj_model, :iterations) + MMI.is_same_except(mlj_model, cache.mlj_model, :iterations) iterations = mlj_model.iterations - cache.mlj_model.iterations verbose = verbosity <= 1 ? false : true new_model = model_init(mlj_model; verbose, iterations) @@ -150,7 +150,7 @@ function MMI.update(mlj_model::CatBoostModels, verbosity::Integer, fitresult, ca report = (feature_importances=feature_importance(new_model),) cache = (; mlj_model=mlj_model) else - new_model, cache, report = fit(mlj_model, verbosity, data_pool) + new_model, cache, report = MMI.fit(mlj_model, verbosity, data_pool) end return new_model, cache, report diff --git a/test/mlj_interface.jl b/test/mlj_interface.jl index 0c374a1..622e58d 100644 --- a/test/mlj_interface.jl +++ b/test/mlj_interface.jl @@ -57,6 +57,64 @@ restored_fitresult = MLJBase.restore(mach, serializable_fitresult) end + @testset "evaluate" begin + using Statistics + + X = (a = [1, 2, 3, 4], b = [4, 5, 6, 7]) + y = [1.0, 2.0, 3.0, 4.0] + model = CatBoostRegressor(; iterations = 5) + mach = machine(model, X, y) + e = evaluate!(mach; resampling = Holdout(fraction_train = 0.7), + measure = l1, verbosity = 0) + @test e.measurement[1] >= 0.0 # l1 should be non-negative + + y_class = categorical(["cat", "dog", "cat", "dog"]) + model_class = CatBoostClassifier(; iterations = 5) + mach_class = machine(model_class, X, y_class) + e_class = evaluate!(mach_class; resampling = CV(nfolds = 2), + measure = accuracy, verbosity = 0) + @test 0.0 <= e_class.measurement[1] <= 1.0 # accuracy between 0 and 1 + end + + + @testset "MLJ GridSearch" begin + X = (a=[1, 2, 3, 4, 5, 6], b=[4, 5, 6, 7, 8, 9]) + y = [1.0, 2.0, 3.0, 4.0, 5, 6] + model = CatBoostRegressor() + r = range(model, :iterations, lower=2, upper=5) + tuning = Grid(resolution=3) + tuned_model = TunedModel(model=model, tuning=tuning, + resampling=Holdout(fraction_train=0.7), + range=r, measure=l1) + mach = machine(tuned_model, X, y) + e = evaluate!(mach, + measure=l1, + resampling=Holdout(fraction_train=0.7), + verbosity=0, + acceleration=CPU1(), + per_observation=false) + + @test e.measurement[1] >= 0.0 # l1 should be non-negative + + y_class = categorical(["cat", "dog", "cat", "dog", "cat", "dog"]) + model_class = CatBoostClassifier() + r1 = range(model_class, :iterations, lower=2, upper=5) + r2 = range(model_class, :depth, lower=1, upper=3) + tuning_class = Grid(resolution=2) + tuned_model_class = TunedModel(model=model_class, tuning=tuning_class, + resampling=CV(nfolds=2), + range=[r1, r2], measure=accuracy) + + mach_class = machine(tuned_model_class, X, y_class) + e_class = evaluate!(mach_class, + measure=accuracy, + resampling=CV(nfolds=2), + verbosity=0, + acceleration=CPU1(), + per_observation=false) + @test 0.0 <= e_class.measurement[1] <= 1.0 # accuracy between 0 and 1 + end + @testset "generic interface tests" begin @testset "CatBoostRegressor" begin data = MLJTestInterface.make_regression() diff --git a/test/runtests.jl b/test/runtests.jl index c1578fc..492566d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,10 @@ using Aqua using DataFrames using MLJBase using MLJTestInterface +using MLJTuning using PythonCall +using StatisticalMeasures +using Statistics using Test const WRAPPER_EXAMPLES_DIR = joinpath(@__DIR__, "..", "examples/wrapper") From c4a97cefe2ef9667877fe2f1cbaff593a3b89d66 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:33:29 -0700 Subject: [PATCH 02/14] fix formatting --- src/MLJCatBoostInterface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MLJCatBoostInterface.jl b/src/MLJCatBoostInterface.jl index aea39b9..b6ba9e7 100644 --- a/src/MLJCatBoostInterface.jl +++ b/src/MLJCatBoostInterface.jl @@ -140,7 +140,7 @@ end ##### function MMI.update(mlj_model::CatBoostModels, verbosity::Integer, fitresult, cache, - data_pool) + data_pool) if mlj_model.iterations > cache.mlj_model.iterations && MMI.is_same_except(mlj_model, cache.mlj_model, :iterations) iterations = mlj_model.iterations - cache.mlj_model.iterations From 7ad0cabcb655f7b61828eab1b49b11489c82e93d Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:35:30 -0700 Subject: [PATCH 03/14] fix bounds for Test, Statistics --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0392299..0ac24e0 100644 --- a/Project.toml +++ b/Project.toml @@ -19,9 +19,9 @@ MLJTuning = "0.8" OrderedCollections = "1.6" PythonCall = "0.9" StatisticalMeasures = "0.1" -Statistics = "1.6" +Statistics = "<0.0.1, 1" Tables = "1.10" -Test = "1.6" +Test = "<0.0.1, 1" julia = "1.6" [extras] From 22ca0092a851d1f6450cfe469528c2f18d2e99e0 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:36:39 -0700 Subject: [PATCH 04/14] run formatter --- format/Manifest.toml | 201 +++++++++++------------------------- format/README.md | 4 +- src/MLJCatBoostInterface.jl | 2 +- test/mlj_interface.jl | 53 +++++----- 4 files changed, 92 insertions(+), 168 deletions(-) diff --git a/format/Manifest.toml b/format/Manifest.toml index 0773aa6..d584a02 100644 --- a/format/Manifest.toml +++ b/format/Manifest.toml @@ -1,196 +1,121 @@ # This file is machine-generated - editing it directly is not advised -[[ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +julia_version = "1.11.2" +manifest_format = "2.0" +project_hash = "30b405be1c677184b7703a9bfb3d2100029ccad0" -[[Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[Base64]] +[[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" -[[CSTParser]] +[[deps.CSTParser]] deps = ["Tokenize"] -git-tree-sha1 = "3ddd48d200eb8ddf9cb3e0189fc059fd49b97c1f" +git-tree-sha1 = "0157e592151e39fa570645e2b2debcdfb8a0f112" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "3.3.6" +version = "3.4.3" -[[CommonMark]] -deps = ["Crayons", "JSON", "PrecompileTools", "URIs"] -git-tree-sha1 = "532c4185d3c9037c0237546d817858b23cf9e071" +[[deps.CommonMark]] +deps = ["Crayons", "PrecompileTools"] +git-tree-sha1 = "3faae67b8899797592335832fccf4b3c80bb04fa" uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" -version = "0.8.12" +version = "0.8.15" -[[Compat]] -deps = ["Dates", "LinearAlgebra", "UUIDs"] -git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d" +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.1" +version = "4.16.0" + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" -[[Crayons]] + [deps.Compat.weakdeps] + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" version = "4.1.1" -[[DataStructures]] +[[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed" +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.16" +version = "0.18.20" -[[Dates]] +[[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" -[[Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[Glob]] +[[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" version = "1.3.1" -[[InteractiveUtils]] +[[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" -[[JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[JuliaFormatter]] -deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"] -git-tree-sha1 = "8f5295e46f594ad2d8652f1098488a77460080cd" +[[deps.JuliaFormatter]] +deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] +git-tree-sha1 = "59cf7ad64f1b0708a4fa4369879d33bad3239b56" uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.45" - -[[LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "1.0.62" -[[LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[Markdown]] +[[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" -[[MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +[[deps.OrderedCollections]] +git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" - -[[Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" +version = "1.7.0" -[[Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[PrecompileTools]] +[[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.0" +version = "1.2.1" -[[Preferences]] +[[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "00805cd429dcb4870060ff49ef443486c262e38e" +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.1" +version = "1.4.3" -[[Printf]] +[[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] +[[deps.Random]] +deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" -[[SHA]] +[[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[TOML]] +[[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" -[[Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[Tokenize]] -git-tree-sha1 = "3ac1ac11b09e8033ec93a7993acdb9b68252be6d" +[[deps.Tokenize]] +git-tree-sha1 = "468b4685af4abe0e9fd4d7bf495a6554a6276e75" uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.27" - -[[URIs]] -git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.1" +version = "0.5.29" -[[UUIDs]] +[[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" -[[Unicode]] +[[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "1.11.0" diff --git a/format/README.md b/format/README.md index e74280f..e347cb3 100644 --- a/format/README.md +++ b/format/README.md @@ -1,6 +1,6 @@ # format -Run `julia --project=format format/run.jl` with Julia 1.6 to run JuliaFormatter. +Run `julia --project=format format/run.jl` with Julia 1.11 to run JuliaFormatter. If you update the version of Julia used to generate the `Manifest.toml` make sure to also -update the version in `.github/workflows/format_check.yml` to match. \ No newline at end of file +update the version in `.github/workflows/format_check.yml` to match. diff --git a/src/MLJCatBoostInterface.jl b/src/MLJCatBoostInterface.jl index b6ba9e7..cbf4659 100644 --- a/src/MLJCatBoostInterface.jl +++ b/src/MLJCatBoostInterface.jl @@ -142,7 +142,7 @@ end function MMI.update(mlj_model::CatBoostModels, verbosity::Integer, fitresult, cache, data_pool) if mlj_model.iterations > cache.mlj_model.iterations && - MMI.is_same_except(mlj_model, cache.mlj_model, :iterations) + MMI.is_same_except(mlj_model, cache.mlj_model, :iterations) iterations = mlj_model.iterations - cache.mlj_model.iterations verbose = verbosity <= 1 ? false : true new_model = model_init(mlj_model; verbose, iterations) diff --git a/test/mlj_interface.jl b/test/mlj_interface.jl index 622e58d..c687185 100644 --- a/test/mlj_interface.jl +++ b/test/mlj_interface.jl @@ -60,55 +60,54 @@ @testset "evaluate" begin using Statistics - X = (a = [1, 2, 3, 4], b = [4, 5, 6, 7]) + X = (a=[1, 2, 3, 4], b=[4, 5, 6, 7]) y = [1.0, 2.0, 3.0, 4.0] - model = CatBoostRegressor(; iterations = 5) + model = CatBoostRegressor(; iterations=5) mach = machine(model, X, y) - e = evaluate!(mach; resampling = Holdout(fraction_train = 0.7), - measure = l1, verbosity = 0) + e = evaluate!(mach; resampling=Holdout(; fraction_train=0.7), + measure=l1, verbosity=0) @test e.measurement[1] >= 0.0 # l1 should be non-negative y_class = categorical(["cat", "dog", "cat", "dog"]) - model_class = CatBoostClassifier(; iterations = 5) + model_class = CatBoostClassifier(; iterations=5) mach_class = machine(model_class, X, y_class) - e_class = evaluate!(mach_class; resampling = CV(nfolds = 2), - measure = accuracy, verbosity = 0) + e_class = evaluate!(mach_class; resampling=CV(; nfolds=2), + measure=accuracy, verbosity=0) @test 0.0 <= e_class.measurement[1] <= 1.0 # accuracy between 0 and 1 end - @testset "MLJ GridSearch" begin X = (a=[1, 2, 3, 4, 5, 6], b=[4, 5, 6, 7, 8, 9]) y = [1.0, 2.0, 3.0, 4.0, 5, 6] model = CatBoostRegressor() - r = range(model, :iterations, lower=2, upper=5) - tuning = Grid(resolution=3) - tuned_model = TunedModel(model=model, tuning=tuning, - resampling=Holdout(fraction_train=0.7), - range=r, measure=l1) + r = range(model, :iterations; lower=2, upper=5) + tuning = Grid(; resolution=3) + tuned_model = TunedModel(; model=model, tuning=tuning, + resampling=Holdout(; fraction_train=0.7), + range=r, measure=l1) mach = machine(tuned_model, X, y) - e = evaluate!(mach, - measure=l1, - resampling=Holdout(fraction_train=0.7), - verbosity=0, - acceleration=CPU1(), - per_observation=false) + e = evaluate!(mach; + measure=l1, + resampling=Holdout(; fraction_train=0.7), + verbosity=0, + acceleration=CPU1(), + per_observation=false) @test e.measurement[1] >= 0.0 # l1 should be non-negative y_class = categorical(["cat", "dog", "cat", "dog", "cat", "dog"]) model_class = CatBoostClassifier() - r1 = range(model_class, :iterations, lower=2, upper=5) - r2 = range(model_class, :depth, lower=1, upper=3) - tuning_class = Grid(resolution=2) - tuned_model_class = TunedModel(model=model_class, tuning=tuning_class, - resampling=CV(nfolds=2), - range=[r1, r2], measure=accuracy) + r1 = range(model_class, :iterations; lower=2, upper=5) + r2 = range(model_class, :depth; lower=1, upper=3) + tuning_class = Grid(; resolution=2) + tuned_model_class = TunedModel(; model=model_class, tuning=tuning_class, + resampling=CV(; nfolds=2), + range=[r1, r2], measure=accuracy) mach_class = machine(tuned_model_class, X, y_class) - e_class = evaluate!(mach_class, + e_class = evaluate!(mach_class; measure=accuracy, - resampling=CV(nfolds=2), + resampling=CV(; nfolds=2), verbosity=0, acceleration=CPU1(), per_observation=false) From 8b5a165a81c4b795aad628d4fcf5638d6411fd85 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:39:47 -0700 Subject: [PATCH 05/14] Change formatting default to new lts, 1.10 --- .github/workflows/format_check.yml | 2 +- format/Manifest.toml | 10 +--------- format/README.md | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/.github/workflows/format_check.yml b/.github/workflows/format_check.yml index 0e016bd..92c9d1e 100644 --- a/.github/workflows/format_check.yml +++ b/.github/workflows/format_check.yml @@ -15,7 +15,7 @@ jobs: steps: - uses: julia-actions/setup-julia@latest with: - version: 1.6.0 + version: 1.10.7 - uses: actions/checkout@v4 - name: Instantiate `format` environment and format run: | diff --git a/format/Manifest.toml b/format/Manifest.toml index d584a02..9333133 100644 --- a/format/Manifest.toml +++ b/format/Manifest.toml @@ -1,12 +1,11 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.11.2" +julia_version = "1.10.7" manifest_format = "2.0" project_hash = "30b405be1c677184b7703a9bfb3d2100029ccad0" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -version = "1.11.0" [[deps.CSTParser]] deps = ["Tokenize"] @@ -47,7 +46,6 @@ version = "0.18.20" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -version = "1.11.0" [[deps.Glob]] git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" @@ -57,7 +55,6 @@ version = "1.3.1" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -version = "1.11.0" [[deps.JuliaFormatter]] deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "PrecompileTools", "TOML", "Tokenize"] @@ -68,7 +65,6 @@ version = "1.0.62" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -version = "1.11.0" [[deps.OrderedCollections]] git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad" @@ -90,12 +86,10 @@ version = "1.4.3" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -version = "1.11.0" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -version = "1.11.0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -114,8 +108,6 @@ version = "0.5.29" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" -version = "1.11.0" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" -version = "1.11.0" diff --git a/format/README.md b/format/README.md index e347cb3..269fdbe 100644 --- a/format/README.md +++ b/format/README.md @@ -1,6 +1,6 @@ # format -Run `julia --project=format format/run.jl` with Julia 1.11 to run JuliaFormatter. +Run `julia --project=format format/run.jl` with Julia 1.10 to run JuliaFormatter. If you update the version of Julia used to generate the `Manifest.toml` make sure to also update the version in `.github/workflows/format_check.yml` to match. From db92c637053769f5d7dfadce3cca4bb749a0a160 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:42:19 -0700 Subject: [PATCH 06/14] bump lts to 1.10, increase StatisticalMeasures lowerbound --- .github/workflows/CI.yml | 2 +- .github/workflows/docs.yml | 2 +- Project.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index fa96113..a4d1897 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,7 +26,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1.10' - '1' - 'nightly' os: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 1f9cf14..d411f60 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -21,7 +21,7 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest with: - version: 1.6 # earliest supported version + version: 1.10 # earliest supported version - uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache - uses: julia-actions/julia-docdeploy@releases/v1 env: diff --git a/Project.toml b/Project.toml index 0ac24e0..c825dd8 100644 --- a/Project.toml +++ b/Project.toml @@ -18,11 +18,11 @@ MLJTestInterface = "0.2.6" MLJTuning = "0.8" OrderedCollections = "1.6" PythonCall = "0.9" -StatisticalMeasures = "0.1" +StatisticalMeasures = "0.1.7" Statistics = "<0.0.1, 1" Tables = "1.10" Test = "<0.0.1, 1" -julia = "1.6" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" From 66105e13e91424d398ad547ead8b8cbe2a6779f0 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:47:29 -0700 Subject: [PATCH 07/14] Revert docs version bump. Leaving it for a different PR --- .github/workflows/docs.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d411f60..1810c9e 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -21,8 +21,8 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest with: - version: 1.10 # earliest supported version - - uses: julia-actions/cache@v1 # https://github.com/julia-actions/cache + version: 1.6 # earliest supported version + - uses: julia-actions/cache@v2 # https://github.com/julia-actions/cache - uses: julia-actions/julia-docdeploy@releases/v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token From db765c826202fcae0b35ea9455fd2668987713e0 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:51:43 -0700 Subject: [PATCH 08/14] add arm mac to test matrix --- .github/workflows/CI.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a4d1897..08631d0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,6 +40,9 @@ jobs: - os: macos-latest version: '1' arch: x64 + - os: macos-latest + version: '1' + arch: x64 steps: - uses: actions/checkout@v4 with: From 88e29c80440885a2c6a909b07ba80865195759a3 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 12:52:22 -0700 Subject: [PATCH 09/14] Fix arch for arm mac tests --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 08631d0..30e91c3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -42,7 +42,7 @@ jobs: arch: x64 - os: macos-latest version: '1' - arch: x64 + arch: arm64 steps: - uses: actions/checkout@v4 with: From a9e89ae2c4c003ba0c2775f22dcac5483a66438c Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Sun, 15 Dec 2024 13:03:51 -0700 Subject: [PATCH 10/14] bump Catboost.jl to v0.3.6 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c825dd8..1113998 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CatBoost" uuid = "e2e10f9a-a85d-4fa9-b6b2-639a32100a12" authors = ["Beacon Biosignals, Inc."] -version = "0.3.5" +version = "0.3.6" [deps] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" From cece1f0d1789323a8aa4999b1560f1e9f1a184a8 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Wed, 25 Dec 2024 11:32:29 -0800 Subject: [PATCH 11/14] Make sure `y_first` is a CategoricalValue --- Project.toml | 2 ++ src/MLJCatBoostInterface.jl | 1 + src/mlj_catboostclassifier.jl | 9 ++++++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 1113998..b00f6a5 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Beacon Biosignals, Inc."] version = "0.3.6" [deps] +CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -11,6 +12,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Aqua = "0.8.4" +CategoricalArrays = "0.10.8" DataFrames = "1.6" MLJBase = "1" MLJModelInterface = "1.7" diff --git a/src/MLJCatBoostInterface.jl b/src/MLJCatBoostInterface.jl index cbf4659..c3d7578 100644 --- a/src/MLJCatBoostInterface.jl +++ b/src/MLJCatBoostInterface.jl @@ -9,6 +9,7 @@ using Tables using MLJModelInterface: MLJModelInterface const MMI = MLJModelInterface using MLJModelInterface: Table, Continuous, Count, Finite, OrderedFactor, Multiclass +using CategoricalArrays: CategoricalArray, CategoricalValue const PKG = "CatBoost" """ diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index baad80b..0b21db7 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -72,8 +72,12 @@ function model_init(mlj_model::CatBoostClassifier; kw...) return catboost.CatBoostClassifier(; mlj_to_kwargs(mlj_model)..., kw...) end +prepare_single_class(y_first) = CategoricalArray([y_first])[1] +prepare_single_class(y_first::CategoricalValue) = y_first + function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_first) - # Check if y_first has only one unique value + # Check if the data pool has only one unique value + y_first = prepare_single_class(y_first) unique_classes = pyconvert(Vector, numpy.unique(data_pool.get_label())) if length(unique_classes) == 1 # Skip training and store the single class @@ -103,8 +107,7 @@ function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - pool = MMI.categorical([fitresult.y_first]) - return MMI.UnivariateFinite(classes, probs; pool=pool) + return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first) end model, y_first = fitresult From 7a52755f148ec5a01aebf6de24de01c0a3092d5b Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Tue, 31 Dec 2024 10:59:03 -0700 Subject: [PATCH 12/14] Pass the pool object to `UnivariateFinite` --- src/mlj_catboostclassifier.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index 0b21db7..7ad4d67 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -102,18 +102,18 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,) MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) - if fitresult[1] === nothing + model, y_first = fitresult + if model === nothing # Always predict the single class n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first) + return MMI.UnivariateFinite(classes, probs; pool=y_first.pool) end - model, y_first = fitresult classes = pyconvert(Array, model.classes_.tolist()) py_preds = predict_proba(model, X_pool) - preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first) + preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first.pool) return preds end From e9a7907a5305bc33a0855c2a511c1f488e5d5542 Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Tue, 31 Dec 2024 11:22:58 -0700 Subject: [PATCH 13/14] Fix `MMI.predict`, ensure `fitresult` is a named tuple --- src/mlj_catboostclassifier.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index 7ad4d67..ed31def 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -93,7 +93,7 @@ function MMI.fit(mlj_model::CatBoostClassifier, verbosity::Int, data_pool, y_fir cache = (; mlj_model=deepcopy(mlj_model)) report = (feature_importances=feature_importance(model),) - fitresult = (model, y_first) + fitresult = (model=model, y_first=y_first) return (fitresult, cache, report) end @@ -102,18 +102,17 @@ MMI.fitted_params(::CatBoostClassifier, model) = (model=model,) MMI.reports_feature_importances(::Type{<:CatBoostClassifier}) = true function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) - model, y_first = fitresult - if model === nothing + if fitresult.model === nothing # Always predict the single class n = pyconvert(Int, X_pool.shape[0]) classes = [fitresult.single_class] probs = ones(n, 1) - return MMI.UnivariateFinite(classes, probs; pool=y_first.pool) + return MMI.UnivariateFinite(classes, probs; pool=fitresult.y_first.pool) end - classes = pyconvert(Array, model.classes_.tolist()) - py_preds = predict_proba(model, X_pool) - preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=y_first.pool) + classes = pyconvert(Array, fitresult.model.classes_.tolist()) + py_preds = predict_proba(fitresult.model, X_pool) + preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=fitresult.y_first.pool) return preds end From 52d957f5ee5e17e5ce04bd1fe9737220aa79473d Mon Sep 17 00:00:00 2001 From: Tyler Thomas Date: Tue, 31 Dec 2024 11:24:01 -0700 Subject: [PATCH 14/14] reformat --- src/mlj_catboostclassifier.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mlj_catboostclassifier.jl b/src/mlj_catboostclassifier.jl index ed31def..cf6c369 100644 --- a/src/mlj_catboostclassifier.jl +++ b/src/mlj_catboostclassifier.jl @@ -112,7 +112,8 @@ function MMI.predict(mlj_model::CatBoostClassifier, fitresult, X_pool) classes = pyconvert(Array, fitresult.model.classes_.tolist()) py_preds = predict_proba(fitresult.model, X_pool) - preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); pool=fitresult.y_first.pool) + preds = MMI.UnivariateFinite(classes, pyconvert(Array, py_preds); + pool=fitresult.y_first.pool) return preds end