Skip to content

Commit

Permalink
Merge pull request #270 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.13.2 release
  • Loading branch information
ablaom authored Apr 28, 2020
2 parents cb8e646 + 9e98055 commit aa0086a
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.13.1"
version = "0.13.2"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
2 changes: 1 addition & 1 deletion src/composition/networks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ function models(W::AbstractNode)
end

"""
sources(W::AbstractNode; kind=:any)
sources(N::AbstractNode; kind=:any)
A vector of all sources referenced by calls `N()` and `fit!(N)`. These
are the sources of the directed acyclic graph associated with the
Expand Down
22 changes: 17 additions & 5 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,25 @@ for operation in (:predict, :predict_mean, :predict_mode, :predict_median,
if isdefined(machine, :fitresult) || M <: Static
return $(operation)(machine.model, machine.fitresult, args...)
else
throw(error("$machine has not been trained."))
error("$machine has not been trained.")
end
end
$(operation)(machine::Machine; rows=:) =
$(operation)(machine, selectrows(machine.args[1], rows))
$(operation)(machine::NodalMachine, args::AbstractNode...) =
node($(operation), machine, args...)
function $(operation)(machine::Machine; rows=:)
isempty(machine.args) &&
throw(ArgumentError("Attempt to accesss non-existent data "*
"bound to a machine, "*
"probably because machine was "*
"deserialized. Specify data `X` "*
"with `$($operation)(mach, X)`. "))
return $(operation)(machine, selectrows(machine.args[1], rows))
end
function $(operation)(machine::NodalMachine, args::AbstractNode...)
length(args) > 0 ||
throw(ArgumentError("`args` in `$($operation)(mach, args...)`"*
" cannot be empty if `mach` is a "*
"`NodalMachine`. "))
return node($(operation), machine, args...)
end
end
eval(ex)
end
Expand Down
4 changes: 2 additions & 2 deletions test/composition/arrows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ using Random

fit!(ŷ, rows=train)

@test isapprox(rms((rows=test), ys(rows=test)), 0.627123, rtol=1e-4)
@test isapprox(rms((rows=test), ys(rows=test)), 0.627123, atol=0.07)

# shortcut to get and set hyperparameters of a node
ẑ[:lambda] = 5.0
fit!(ŷ, rows=train)
@test isapprox(rms((rows=test), ys(rows=test)), 0.62699, rtol=1e-4)
@test isapprox(rms((rows=test), ys(rows=test)), 0.62699, atol=0.07)
end

@testset "Auto-source" begin
Expand Down
21 changes: 21 additions & 0 deletions test/operations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module TestOperations

using Test
using MLJBase
using ..Models

@testset "errors for deserialized machines" begin
filename = joinpath(@__DIR__, "machine.jlso")
m = machine(filename)
@test_throws ArgumentError predict(m)
end

@testset "error for operations on nodes" begin
X = source()
m = machine(OneHotEncoder(), X)
@test_throws ArgumentError transform(m)
end

end

true
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ end
VERSION v"1.3.0-" && @test include("composition/arrows.jl")
end

@testset "operations.jl" begin
@test include("operations.jl")
end

@testset "hyperparam" begin
@test include("hyperparam/one_dimensional_ranges.jl")
@test include("hyperparam/one_dimensional_range_methods.jl")
Expand Down

0 comments on commit aa0086a

Please sign in to comment.