Skip to content

Commit

Permalink
Merge pull request #55 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
Towards 0.7.0
  • Loading branch information
ablaom authored Oct 14, 2019
2 parents 23506f6 + 29b050a commit 48e4346
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 238 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.6.0"
version = "0.7.0"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -17,8 +18,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CategoricalArrays = "<0.5.3, 0.7"
CSV = "0.5"
CategoricalArrays = "<0.5.3"
Requires = "^0.5.2"
ScientificTypes = "0.2.0"
Tables = "<0.1.19, >= 0.2"
Expand Down
3 changes: 3 additions & 0 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ export is_feature_dependent # measures.jl
export default_measure, value # measures.jl
export mav, mae, rms, rmsl, rmslp1, rmsp, l1, l2 # measures.jl
export misclassification_rate, cross_entropy # measures.jl
export BrierScore # measures.jl

# methods from other packages to be rexported:
export pdf, mean, mode
Expand Down Expand Up @@ -66,6 +67,8 @@ import ScientificTypes: trait

# to be extended:
import StatsBase: fit, predict, fit!
import Missings.levels


# from Standard Library:
using Statistics
Expand Down
26 changes: 10 additions & 16 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ CategoricalElement{U} = Union{CategoricalValue{<:Any,U},CategoricalString{U}}
"""
classes(x)
All the categorical elements with the same pool as `x` (including `x`),
All the categorical elements with in the same pool as `x` (including `x`),
returned as a list, with an ordering consistent with the pool. Here
`x` has `CategoricalValue` or `CategoricalString` type, and
`classes(x)` is a vector of the same eltype.
Expand All @@ -139,10 +139,8 @@ true, `x in x.pool.levels` is not true.
:c
"""
function classes(x::CategoricalElement)
p = x.pool
return [p.valindex[p.invindex[v]] for v in p.levels]
end
classes(p::CategoricalPool) = [p.valindex[p.invindex[v]] for v in p.levels]
classes(x::CategoricalElement) = classes(x.pool)

"""
int(x)
Expand Down Expand Up @@ -176,11 +174,11 @@ Broadcasted versions of `int`.
See also: [`decoder`](@ref).
"""
int(x::CategoricalElement) = x.pool.order[x.pool.invindex[x]]
int(A::AbstractArray{<:CategoricalElement}) = broadcast(int, A)
# workaround for CategoricalArrays issue
# https://github.com/JuliaData/CategoricalArrays.jl/issues/199:
# function int(X::CategoricalArray)
int(A::AbstractArray) = broadcast(int, A)

# get the integer representation of a level given pool (private
# method):
int(pool::CategoricalPool, level) = pool.order[pool.invindex[level]]

struct CategoricalDecoder{T,R} # <: MLJType
pool::CategoricalPool{T,R}
Expand All @@ -207,17 +205,13 @@ integer arrays, in which case `d` is broadcast over all elements.
julia> d(int(v)) == v
true
*Warning:* It is *not* true that `int(d(u)) == u` always holds.
See also: [`int`](@ref), [`classes`](@ref).
"""
decoder(element::CategoricalElement) =
CategoricalDecoder(element.pool, sortperm(element.pool.order))
## in the next lot need to skip the missing one
# decoder(X::CategoricalArray) = CategoricalDecoder(X.pool)
# function decoder(V::Array{<:CategoricalElement})
# isempty(V) && error("Unable to extract decoder from empty array. ")
# return X[1]
# end

(decoder::CategoricalDecoder{T,R})(i::Integer) where {T,R} =
CategoricalValue{T,R}(decoder.invorder[i], decoder.pool)
Expand Down
Loading

0 comments on commit 48e4346

Please sign in to comment.