Skip to content

Commit

Permalink
Merge pull request #49 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
Merge dev branch for a 0.5.1 release
  • Loading branch information
ablaom authored Sep 24, 2019
2 parents 3634e1b + 00f4f31 commit 443a58e
Show file tree
Hide file tree
Showing 11 changed files with 446 additions and 24 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.5.0"
version = "0.5.1"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand All @@ -27,9 +27,10 @@ julia = "1"
[extras]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["CSV", "DataFrames", "LossFunctions", "Test", "TypedTables"]
test = ["CSV", "DataFrames", "Distances", "LossFunctions", "Test", "TypedTables"]
17 changes: 13 additions & 4 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export selectrows, selectcols, select, nrows # data.jl
export table, levels_seen, matrix, container_type # data.jl
export partition, unpack # data.jl
export @set_defaults # utilities.jl
export @mlj_model # mlj_model_macro.jl
export metadata_model, metadata_pkg # metadata_utilities
export HANDLE_GIVEN_ID, @more, @constant # show.jl
export color_on, color_off # show.jl
export UnivariateFinite, average # distributions.jl
Expand Down Expand Up @@ -50,6 +52,7 @@ export pdf, mode, median, mean, shuffle!, categorical, shuffle, levels, levels!
export std

import Base.==
import Base: @__doc__

using Tables
using OrderedCollections # already a dependency of StatsBase
Expand Down Expand Up @@ -83,7 +86,7 @@ const DEFAULT_SHOW_DEPTH = 0
include("utilities.jl")


## BASE TYPES
## BASE TYPES

abstract type MLJType end
include("equality.jl") # equality for MLJType objects
Expand Down Expand Up @@ -116,7 +119,7 @@ abstract type UnsupervisedNetwork <: Unsupervised end
## THE MODEL INTERFACE

# every model interface must implement a `fit` method of the form
# `fit(model, verbosity::Integer, training_args...) -> fitresult, cache, report`
# `fit(model, verbosity::Integer, training_args...) -> fitresult, cache, report`
# or, one of the simplified versions
# `fit(model, training_args...) -> fitresult`
# `fit(model, X, ys...) -> fitresult`
Expand Down Expand Up @@ -169,14 +172,14 @@ clean!(model::Model) = ""

## TRAITS

"""
"""
info(object)
List the traits of an object, such as a model or a performance measure.
"""
info(object) = info(object, Val(ScientificTypes.trait(object)))
info(object) = info(object, Val(ScientificTypes.trait(object)))


include("model_traits.jl")
Expand All @@ -199,6 +202,12 @@ include("datasets.jl") # importing CSV will also load datasets_requires.jl
include("tasks.jl")
include("measures.jl")

# mlj model macro to help define models
include("mlj_model_macro.jl")

# metadata utils
include("metadata_utilities.jl")

# __init__() function:
include("init.jl")

Expand Down
12 changes: 6 additions & 6 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ function partition(rows::AbstractVector{Int}, fractions...; shuffle::Bool=false,
end

"""
t1, t2, ...., tk = unnpack(table, c1, c2, ... ck; wrap_singles=false)
t1, t2, ...., tk = unpack(table, t1, t2, ... tk; wrap_singles=false)
Split any Tables.jl compatible `table` into smaller tables (or
vectors) `t1, t2, ..., tk` by making selections *without replacement*
from the column names defined by the conditionals `c1`, `c2`, ...,
`ck`. A *conditional* is any object `c` such that `c(name)` is `true`
from the column names defined by the tests `t1`, `t2`, ...,
`tk`. A *test* is any object `t` such that `t(name)` is `true`
or `false` for each column `name::Symbol` of `table`.
Whenever a returned table contains a single column, it is converted to
Expand All @@ -59,7 +59,7 @@ a vector unless `wrap_singles=true`.
Scientific type conversions can be optionally specified (note
semicolon):
unpack(table, c...; wrap_singles=false, col1=>scitype1, col2=>scitype2, ... )
unpack(table, t...; wrap_singles=false, col1=>scitype1, col2=>scitype2, ... )
### Example
Expand All @@ -82,7 +82,7 @@ julia> Z
```
"""
function unpack(X, conditionals...; wrap_singles=false, pairs...)
function unpack(X, tests...; wrap_singles=false, pairs...)

if isempty(pairs)
Xfixed = X
Expand All @@ -94,7 +94,7 @@ function unpack(X, conditionals...; wrap_singles=false, pairs...)
names_left = schema(Xfixed).names |> collect
history = ""
counter = 1
for c in conditionals
for c in tests
names = filter(c, names_left)
filter!(!in(names), names_left)
history *= "selection $counter: $names\n remaining: $names_left\n"
Expand Down
32 changes: 29 additions & 3 deletions src/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

const Dist = Distributions


## EQUALITY OF DISTRIBUTIONS (minor type piracy)

# TODO: We should get rid of this. I think it is used only in
# MLJModels/test.

function ==(d1::D, d2::D) where D<:Dist.Sampleable
ret = true
for fld in fieldnames(D)
Expand Down Expand Up @@ -108,8 +110,10 @@ end

function UnivariateFinite(classes::AbstractVector{L},
p::AbstractVector{<:Real}) where L
L <: CategoricalElement || error("classes must have CategoricalValue or "*
"CategoricalString type.")
L <: CategoricalElement ||
error("`classes` must have type `AbstractVector{T}` where "*
"`T <: Union{CategoricalValue,CategoricalString}. "*
"Perhaps you have `T=Any`? ")
Dist.@check_args(UnivariateFinite, length(classes)==length(p))
prob_given_class = LittleDict([classes[i]=>p[i] for i in eachindex(p)])
return UnivariateFinite(prob_given_class)
Expand Down Expand Up @@ -138,6 +142,28 @@ function Base.show(stream::IO, d::UnivariateFinite)
print(stream, str)
end

"""
    isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)

Return `true` if and only if the classes of `d1` and `d2` coincide as
sets (every element of `classes(d1)` is in `classes(d2)` and vice
versa) and the corresponding probabilities are approximately equal.
Return `false` otherwise.

The keyword arguments `kwargs` are passed through to each call of
`isapprox` on probability pairs.
"""
function Base.isapprox(d1::UnivariateFinite, d2::UnivariateFinite; kwargs...)

    classes1 = classes(d1)
    classes2 = classes(d2)

    # Containment must hold in *both* directions. The loop below only
    # establishes `classes1 ⊆ classes2`, so without this check the
    # result would depend on the order of the arguments.
    all(c -> c in classes1, classes2) || return false

    for c in classes1
        c in classes2 || return false
        isapprox(pdf(d1, c), pdf(d2, c); kwargs...) ||
            return false # pdf defined below
    end
    return true
end

function average(dvec::AbstractVector{UnivariateFinite{L,U,T}};
weights=nothing) where {L,U,T}

Expand Down
57 changes: 57 additions & 0 deletions src/metadata_utilities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
    docstring_ext(T; descr="")

Helper function assembling the docstring text for the model type `T`.

The returned string starts with `descr` and is followed by three
"→" lines: one linking to the package (name and URL), one showing
the `@load` command for the model, and one pointing to the model's
`?` help entry.
"""
function docstring_ext(T; descr::String="")
    pkg = MLJBase.package_name(T)
    url = MLJBase.package_url(T)
    model = MLJBase.name(T)
    # assemble the message in a single pass
    return string(
        descr,
        "\n→ based on [", pkg, "](", url, ")",
        "\n→ do `@load ", model, "` to use the model",
        "\n→ do `?", model, "` for documentation.",
    )
end

"""
    metadata_pkg(T; name="unknown", uuid="unknown", url="unknown",
                 julia=missing, license="unknown", is_wrapper=false)

Helper function to write the package metadata for `T`.

Evaluates, at module level, method definitions of `package_name`,
`package_uuid`, `package_url`, `is_pure_julia`, `package_license` and
`is_wrapper` on `Type{<:T}`, each returning the corresponding keyword
value (`julia` supplies the value for `is_pure_julia`).

The return value is whatever evaluating that block of definitions
returns.
"""
function metadata_pkg(T;
                      name::String="unknown",
                      uuid::String="unknown",
                      url::String="unknown",
                      julia::Union{Missing,Bool}=missing,
                      license::String="unknown",
                      is_wrapper::Bool=false)
    # `$`-interpolation splices the *values* of the keyword arguments
    # (and the type `T`) into the definitions before they are evaluated.
    return @eval begin
        package_name(::Type{<:$T}) = $name
        package_uuid(::Type{<:$T}) = $uuid
        package_url(::Type{<:$T}) = $url
        is_pure_julia(::Type{<:$T}) = $julia
        package_license(::Type{<:$T}) = $license
        is_wrapper(::Type{<:$T}) = $is_wrapper
    end
end

"""
    metadata_model(T; input=Unknown, target=Unknown, output=Unknown,
                   weights=false, descr="", path="")

Helper function to write the metadata for a single model of a package
(complements [`metadata_pkg`](@ref)).

Evaluates, at module level, method definitions of `input_scitype`,
`output_scitype`, `target_scitype`, `supports_weights`, `docstring`
and `load_path` on `Type{<:T}`. The `docstring` method delegates to
`docstring_ext(T, descr=descr)`.

If `path` is empty, the load path defaults to
`"MLJModels.\$(package_name(T))_.\$(name(T))"`.
"""
function metadata_model(T;
                        input=Unknown,
                        target=Unknown,
                        output=Unknown,
                        weights::Bool=false,
                        descr::String="",
                        path::String="")
    # fall back on the default location when no path is supplied
    resolved_path = isempty(path) ?
        "MLJModels.$(package_name(T))_.$(name(T))" : path
    return @eval begin
        input_scitype(::Type{<:$T}) = $input
        output_scitype(::Type{<:$T}) = $output
        target_scitype(::Type{<:$T}) = $target
        supports_weights(::Type{<:$T}) = $weights
        docstring(::Type{<:$T}) = docstring_ext($T, descr=$descr)
        load_path(::Type{<:$T}) = $resolved_path
    end
end
Loading

0 comments on commit 443a58e

Please sign in to comment.