Skip to content

Commit

Permalink
Merge pull request #6 from JuliaAI/stricter-tests-around-resampling
Browse files Browse the repository at this point in the history
Strengthen tests to catch more `selectrows`/`reformat` issues
  • Loading branch information
ablaom authored Feb 8, 2023
2 parents e988c4a + f192102 commit 702c7b9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJTestInterface"
uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.1.1"
version = "0.2.0"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand Down
9 changes: 8 additions & 1 deletion src/attemptors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ function fitted_machine(model, data...; throw=false, verbosity=1)
mach = model isa Static ? machine(model) :
machine(model, data...)
fit!(mach, verbosity=-1)
train, _ = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
fit!(mach, rows=train, verbosity=-1)
fit!(mach, rows=:, verbosity=-1)
MLJBase.report(mach)
MLJBase.fitted_params(mach)
mach
Expand All @@ -89,12 +92,17 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
attempt(finalize(message, verbosity); throw) do
operations = String[]
methods = MLJBase.implemented_methods(fitted_machine.model)
_, test = MLJBase.partition(1:MLJBase.nrows(first(data)), 0.5)
if :predict in methods
predict(fitted_machine, first(data))
predict(fitted_machine, rows=test)
predict(fitted_machine, rows=:)
push!(operations, "predict")
end
if :transform in methods
W = transform(fitted_machine, first(data))
transform(fitted_machine, rows=test)
transform(fitted_machine, rows=:)
push!(operations, "transform")
if :inverse_transform in methods
inverse_transform(fitted_machine, W)
Expand All @@ -104,4 +112,3 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
join(operations, ", ")
end
end

2 changes: 1 addition & 1 deletion test/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ expected_report2 = (
y;
mod=@__MODULE__,
level=2,
verbosity=0
verbosity=0,
)
@test isempty(fails)
@test report[1] == expected_report1
Expand Down

0 comments on commit 702c7b9

Please sign in to comment.