Skip to content

Commit

Permalink
Merge pull request #379 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.14.6 release
  • Loading branch information
ablaom authored Jul 26, 2020
2 parents 359635d + ec7b97c commit 11f1ab1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.14.5"
version = "0.14.6"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
21 changes: 15 additions & 6 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,27 @@ for operation in OPERATIONS

ex = quote
# 1. operations on machines, given *concrete* data:
function $operation(mach::Machine{M}, Xraw, Xraw_more...) where M
if mach.state > 0 || M <: Static
function $operation(mach::Machine, Xraw)
if mach.state > 0
return $(operation)(mach.model, mach.fitresult,
Xraw, Xraw_more...)
Xraw)
else
error("$mach has not been trained.")
end
end

function $operation(mach::Machine{<:Static}, Xraw, Xraw_more...)
isdefined(mach, :fitresult) || (mach.fitresult = nothing)
return $(operation)(mach.model, mach.fitresult,
Xraw, Xraw_more...)
end

# 2. operations on machines, given *dynamic* data (nodes):
$operation(mach::Machine, X::AbstractNode) =
node($(operation), mach, X)

$operation(mach::Machine{<:Static}, X::AbstractNode, Xmore::AbstractNode...) =
node($(operation), mach, X, Xmore...)
end
eval(ex)
end
Expand All @@ -101,7 +110,7 @@ function predict_mode(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_mode)
return fitresult.predict_mode(X)
return fitresult.predict_mode(Xnew)
end
return mode.(predict(m, fitresult, Xnew))
end
Expand All @@ -110,7 +119,7 @@ function predict_mean(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_mean)
return fitresult.predict_mean(X)
return fitresult.predict_mean(Xnew)
end
return mean.(predict(m, fitresult, Xnew))
end
Expand All @@ -119,7 +128,7 @@ function predict_median(m::Union{ProbabilisticComposite,ProbabilisticSurrogate},
fitresult,
Xnew)
if haskey(fitresult, :predict_median)
return fitresult.predict_median(X)
return fitresult.predict_median(Xnew)
end
return median.(predict(m, fitresult, Xnew))
end

0 comments on commit 11f1ab1

Please sign in to comment.