Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arrows #314

Merged
merged 2 commits into from
Nov 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ include("networks.jl") # for building learning networks
include("composites.jl") # composite models & exporting learning networks
include("pipelines.jl") # pipelines (exported linear learning networks)
include("operations.jl") # syntactic sugar for operations (predict, etc)

# Arrow (`|>`) syntactic sugar adds call methods to *abstract* model
# types (e.g. `(m::Unsupervised)(...)`), which Julia supports only from
# v1.3 onward — hence the conditional include.
if VERSION ≥ v"1.3.0-"
include("arrows.jl")
end

include("resampling.jl") # resampling strategies and model evaluation
include("parameters.jl") # hyperparameter ranges and grid generation
include("tuning.jl")
Expand Down
24 changes: 24 additions & 0 deletions src/arrows.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Syntactic sugar for arrow syntax
# we need version ≥ 1.3 in order to make use of multiple dispatch
# over abstract types


# This allows implicit: data |> machine
(mach::AbstractMachine{<:Unsupervised})(data) = transform(mach, data)
(mach::AbstractMachine{<:Supervised})(data) = predict(mach, data)

# Matrix data is wrapped as a table before being fed to the machine.
# NOTE: this is defined once per model supertype because a single
# `(mach::AbstractMachine)(data::AbstractMatrix)` method would be
# *ambiguous* with the untyped-`data` methods above (each method is
# more specific in a different argument), making any matrix call to a
# machine throw a MethodError.
(mach::AbstractMachine{<:Unsupervised})(data::AbstractMatrix) = data |> table |> mach
(mach::AbstractMachine{<:Supervised})(data::AbstractMatrix) = data |> table |> mach

# Implicit `data |> model` for unsupervised models.

# A node: bind a fresh machine to it and apply that machine to the node.
function (model::Unsupervised)(X::AbstractNode)
    mach = machine(model, X)
    return mach(X)
end

# Raw (non-node, non-matrix) data: promote it to a source node first.
(model::Unsupervised)(X) = model(source(X))

# A matrix: wrap it as a table first.
(model::Unsupervised)(X::AbstractMatrix) = model(table(X))

# Implicit `(X, y) |> model` for supervised models.

# Two nodes: bind a machine to both and apply it to the input node.
function (model::Supervised)(Xy::NTuple{2,AbstractNode})
    mach = machine(model, Xy...)
    return mach(Xy[1])
end

# Node input with a raw target: promote the target to a :target source.
(model::Supervised)(Xy::Tuple{AbstractNode,Any}) =
    model((Xy[1], source(Xy[2], kind=:target)))

# Raw input and target: promote the input to a source node.
(model::Supervised)(Xy::Tuple) = model((source(Xy[1]), Xy[2]))

# Matrix input: wrap it as a table first.
(model::Supervised)(Xy::Tuple{AbstractMatrix,Any}) =
    model((table(Xy[1]), Xy[2]))

# Implicit `data |> inverse_transform(node)`: calling `inverse_transform`
# on a node returns a closure that applies the inverse transform of the
# node's machine to whatever is piped into it.
function inverse_transform(node::Node{<:NodalMachine{<:Unsupervised}})
    return data -> inverse_transform(node.machine, data)
end
20 changes: 0 additions & 20 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,5 @@ function fitted_params(machine::AbstractMachine)
end
end


# Syntactic sugar for pipe syntax
# we need version ≥ 1.3 in order to make use of multiple dispatch
# over abstract types
if VERSION ≥ v"1.3.0-"

(mach::AbstractMachine{<:Unsupervised})(data) = transform(mach, data)
(mach::AbstractMachine{<:Supervised})(data) = predict(mach, data)

(m::Unsupervised)(data::AbstractNode) = data |> machine(m, data)
(m::Unsupervised)(data) = source(data) |> m

(m::Supervised)(data::NTuple{2,AbstractNode}) = data[1] |> machine(m, data...)
(m::Supervised)(data::Tuple) = source.(data) |> m

inverse_transform(node::Node{<:NodalMachine{<:Unsupervised}}) =
data->inverse_transform(node.machine, data)
end # version ≥ 1.3

# Syntactic sugar to read and write hyperparameters of the model behind
# a node's machine: `node[:param]` and `node[:param] = value`.
function getindex(node::Node{<:NodalMachine{<:Model}}, field::Symbol)
    return getproperty(node.machine.model, field)
end

function setindex!(node::Node{<:NodalMachine{<:Model}}, value, field::Symbol)
    return setproperty!(node.machine.model, field, value)
end
109 changes: 109 additions & 0 deletions test/arrows.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Tests for the arrow (`|>`) syntactic sugar defined in src/arrows.jl.
# NOTE(review): assertions pin exact values produced under fixed RNG
# seeds, so the statement order within each testset must not change.
module TestArrows

using MLJ
using MLJBase
using Test
using Random

# End-to-end learning network built purely with `|>`, including the
# `node[:param] = value` hyperparameter shortcut.
@testset "|> syntax for pipelines" begin
    Random.seed!(142)
    @load RidgeRegressor pkg="MultivariateStats"
    @load KNNRegressor pkg="NearestNeighbors"
    X = MLJBase.table(randn(500, 5))
    y = abs.(randn(500))
    train, test = partition(eachindex(y), 0.7)

    # Feeding data directly to a supervised model
    knn = KNNRegressor(K=10)
    ŷ = (X, y) |> knn
    fit!(ŷ, rows=train)

    # Describing a full pipeline using |> syntax.
    Xs, ys = source.((X, y))

    # "first layer"
    W = Xs |> Standardizer()
    z = ys |> UnivariateBoxCoxTransformer()
    # "second layer"
    ẑ = (W, z) |> RidgeRegressor(lambda=0.1)
    # "output layer"
    ŷ = ẑ |> inverse_transform(z)

    fit!(ŷ, rows=train)

    @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.627123, rtol=1e-4)

    # shortcut to get and set hyperparameters of a node
    ẑ[:lambda] = 5.0
    fit!(ŷ, rows=train)
    @test isapprox(rms(ŷ(rows=test), ys(rows=test)), 0.62699, rtol=1e-4)
end

# Piping raw (non-node) data should auto-wrap it in a source node.
@testset "Auto-source" begin
    @load PCA
    @load RidgeRegressor pkg="MultivariateStats"
    Random.seed!(5615151)

    X = MLJBase.table(randn(500, 5))
    y = abs.(randn(500))

    pca = X |> Standardizer() |> PCA(maxoutdim=2)
    fit!(pca)

    W = pca()
    sch = schema(W)
    @test sch.names == (:x1, :x2)
    @test sch.scitypes == (Continuous, Continuous)
    @test sch.nrows == 500

    pipe = (pca, y) |> RidgeRegressor()
    fit!(pipe)

    ŷ = pipe()
    @test ŷ isa Vector{Float64}
    @test length(ŷ) == 500
end

# Piping a raw matrix should additionally auto-wrap it as a table.
@testset "Auto-table" begin
    @load PCA
    @load RidgeRegressor pkg="MultivariateStats"
    Random.seed!(5615151)

    X = randn(500, 5)
    y = abs.(randn(500))

    pca = X |> Standardizer() |> PCA(maxoutdim=2)
    pipe = (pca, y) |> RidgeRegressor()
    fit!(pipe)

    ŷ = pipe()
    @test ŷ isa Vector{Float64}
    @test length(ŷ) == 500
end

# A two-model stack expressed with |>: two "level-1" learners whose
# predictions are concatenated and fed to a second-level learner.
@testset "Stacking" begin
    @load PCA
    @load RidgeRegressor pkg=MultivariateStats
    @load DecisionTreeRegressor pkg=DecisionTree
    Random.seed!(5615151)

    X = randn(500, 5)
    y = abs.(randn(500))

    W = X |> Standardizer() |> PCA(maxoutdim=3)
    z = y |> UnivariateBoxCoxTransformer()
    ẑ₁ = (W, z) |> RidgeRegressor()
    ẑ₂ = (W, z) |> DecisionTreeRegressor()
    R = hcat(ẑ₁, ẑ₂)
    ẑ = (R, z) |> DecisionTreeRegressor()
    ŷ = ẑ |> inverse_transform(z)

    fit!(ŷ)

    p̂ = ŷ()
    @test p̂ isa Vector{Float64}
    @test length(p̂) == 500
end

end
# return value checked by `@test include("arrows.jl")` in runtests.jl
true
45 changes: 0 additions & 45 deletions test/operations.jl

This file was deleted.

6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ end
@test include("networks.jl")
end

@testset "operations" begin
@test include("operations.jl")
# Arrow syntax requires call methods on abstract types (Julia ≥ 1.3),
# so the corresponding tests are skipped on older Julia versions.
if VERSION ≥ v"1.3.0-"
    @testset "arrows" begin
        @test include("arrows.jl")
    end
end

@testset "composites" begin
Expand Down