diff --git a/Project.toml b/Project.toml index 6889a7aa..60580da0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "0.14.5" +version = "0.14.6" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/operations.jl b/src/operations.jl index 890ddb3f..d49ae317 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -70,18 +70,27 @@ for operation in OPERATIONS ex = quote # 1. operations on machines, given *concrete* data: - function $operation(mach::Machine{M}, Xraw, Xraw_more...) where M - if mach.state > 0 || M <: Static + function $operation(mach::Machine, Xraw) + if mach.state > 0 return $(operation)(mach.model, mach.fitresult, - Xraw, Xraw_more...) + Xraw) else error("$mach has not been trained.") end end + + function $operation(mach::Machine{<:Static}, Xraw, Xraw_more...) + isdefined(mach, :fitresult) || (mach.fitresult = nothing) + return $(operation)(mach.model, mach.fitresult, + Xraw, Xraw_more...) + end # 2. operations on machines, given *dynamic* data (nodes): $operation(mach::Machine, X::AbstractNode) = node($(operation), mach, X) + + $operation(mach::Machine{<:Static}, X::AbstractNode, Xmore::AbstractNode...) = + node($(operation), mach, X, Xmore...) end eval(ex) end @@ -101,7 +110,7 @@ function predict_mode(m::Union{ProbabilisticComposite,ProbabilisticSurrogate}, fitresult, Xnew) if haskey(fitresult, :predict_mode) - return fitresult.predict_mode(X) + return fitresult.predict_mode(Xnew) end return mode.(predict(m, fitresult, Xnew)) end @@ -110,7 +119,7 @@ function predict_mean(m::Union{ProbabilisticComposite,ProbabilisticSurrogate}, fitresult, Xnew) if haskey(fitresult, :predict_mean) - return fitresult.predict_mean(X) + return fitresult.predict_mean(Xnew) end return mean.(predict(m, fitresult, Xnew)) end @@ -119,7 +128,7 @@ function predict_median(m::Union{ProbabilisticComposite,ProbabilisticSurrogate}, fitresult, Xnew) if haskey(fitresult, :predict_median) - return fitresult.predict_median(X) + return fitresult.predict_median(Xnew) end return median.(predict(m, fitresult, Xnew)) end