Skip to content

Commit

Permalink
Merge pull request #518 from alan-turing-institute/dev
Browse files Browse the repository at this point in the history
For a 0.17.4 release
  • Loading branch information
ablaom authored Mar 5, 2021
2 parents 1cca7c5 + be28989 commit 4070381
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBase"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.17.3"
version = "0.17.4"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
1 change: 1 addition & 0 deletions src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import PrettyTables
using DelimitedFiles
using OrderedCollections
using CategoricalArrays
import CategoricalArrays.DataAPI.unwrap
import InvertedIndices: Not
import JLSO
import Dates
Expand Down
70 changes: 35 additions & 35 deletions src/univariate_finite/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ function Base.Broadcast.broadcasted(
f() = zeros(P, size(u)) #default caller function

return Base.Broadcast.Broadcasted(
identity,
(get(f, u.prob_given_ref, int(cv)),)
)
identity,
(get(f, u.prob_given_ref, int(cv)),)
)
end

# pdf.(u, v)
Expand Down Expand Up @@ -176,41 +176,41 @@ end
# logpdf.(u::UniFinArr{S,V,R,P,N}, raw::AbstractArray{V,N})
# logpdf.(u::UniFinArr{S,V,R,P,N}, raw::V)
for typ in (:CategoricalValue,
:(AbstractArray{<:CategoricalValue{V,R},N}),
:V,
:(AbstractArray{V,N}))
:(AbstractArray{<:CategoricalValue{V,R},N}),
:V,
:(AbstractArray{V,N}))
if typ == :CategoricalValue || typ == :V
eval(quote
function Base.Broadcast.broadcasted(
::typeof(logpdf),
u::UniFinArr{S,V,R,P,N},
c::$typ) where {S,V,R,P,N}

# Start with the pdf array
# take advantage of loop fusion
result = log.(pdf.(u, c))
return result
end
end)
eval(quote
function Base.Broadcast.broadcasted(
::typeof(logpdf),
u::UniFinArr{S,V,R,P,N},
c::$typ) where {S,V,R,P,N}

# Start with the pdf array
# take advantage of loop fusion
result = log.(pdf.(u, c))
return result
end
end)

else
eval(quote
function Base.Broadcast.broadcasted(
::typeof(logpdf),
u::UniFinArr{S,V,R,P,N},
c::$typ) where {S,V,R,P,N}

# Start with the pdf array
result = pdf.(u, c)

# Take the log of each entry in-place
@simd for j in eachindex(result)
@inbounds result[j] = log(result[j])
end

return result
end
end)
eval(quote
function Base.Broadcast.broadcasted(
::typeof(logpdf),
u::UniFinArr{S,V,R,P,N},
c::$typ) where {S,V,R,P,N}

# Start with the pdf array
result = pdf.(u, c)

# Take the log of each entry in-place
@simd for j in eachindex(result)
@inbounds result[j] = log(result[j])
end

return result
end
end)
end

end
Expand Down
13 changes: 7 additions & 6 deletions src/univariate_finite/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const UnivariateFiniteUnion =
Union{UnivariateFinite, UnivariateFiniteArray}

# NOT EXPORTED!!!!
# ABOVE NOT EXPORTED!!!!

"""
classes(d::UnivariateFinite)
Expand All @@ -22,19 +22,20 @@ MMI.classes(d::UnivariateFiniteUnion) = d.decoder.classes
levels(d::UnivariateFinite)
A list of the raw levels in the common pool of classes used to
construct `d`, equal to `get.(classes(d))`.
construct `d`, equal to
`CategoricalArrays.DataAPI.unwrap.(classes(d))`.
v = categorical(["yes", "maybe", "no", "yes"])
d = UnivariateFinite(v[1:2], [0.3, 0.7])
levels(d) # Array{String, 1}["maybe", "no", "yes"]
"""
levels(d::UnivariateFinite) = get.(classes(d))
levels(d::UnivariateFinite) = CategoricalArrays.DataAPI.unwrap.(classes(d))

function Distributions.params(d::UnivariateFinite)
raw = raw_support(d) # reflects order of pool at instantiation of d
pairs = tuple([get(d.decoder(r))=>d.prob_given_ref[r] for r in raw]...)
levs = get.(classes(d))
pairs = tuple([unwrap.(d.decoder(r))=>d.prob_given_ref[r] for r in raw]...)
levs = unwrap.(classes(d))
return (levels=levs, probs=pairs)
end

Expand Down Expand Up @@ -106,7 +107,7 @@ show_prefix(u::UnivariateFiniteArray) = join(size(u),'x')

# function Base.show(io::IO, m::MIME"text/plain",
# u::UnivariateFiniteArray{S,V,R,P,1}) where {S,V,R,P}
# support = get.(Dist.support(u))
# support = unwrap.(Dist.support(u))
# print(io, show_prefix(u), " UnivariateFiniteArray{$S,$V,$R,$P,1}")
# if !isempty(u)
# println(io, ":")
Expand Down
9 changes: 5 additions & 4 deletions src/univariate_finite/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,17 @@ function MMI.UnivariateFinite(::FI, d::AbstractDict{V,<:Prob};
c in raw_support
end

prob_given_class = LittleDict([c=>d[get(c)] for c in support])
prob_given_class =
LittleDict([c=>d[CategoricalArrays.DataAPI.unwrap(c)] for c in support])

return UnivariateFinite(FI(), prob_given_class)
end


## CONSTRUCTORS - FROM ARRAYS

# example: _get(A, 4) = A[:, :, 4] if A has 3 dims:
_get(probs::AbstractArray{<:Any,N}, i) where N = probs[fill(:,N-1)..., i]
# example: _get_on_last(A, 4) = A[:, :, 4] if A has 3 dims:
_get_on_last(probs::AbstractArray{<:Any,N}, i) where N = probs[fill(:,N-1)..., i]

# 1. Univariate Finite from a vector of classes or raw labels and
# array of probs; first, a dispatcher:
Expand Down Expand Up @@ -237,7 +238,7 @@ function _UnivariateFinite(support::AbstractVector{CategoricalValue{V,R}},
LittleDict{CategoricalValue{V,R}, AbstractArray{P,N}}()
end
for i in eachindex(support)
prob_given_class[support[i]] = _get(_probs, i)
prob_given_class[support[i]] = _get_on_last(_probs, i)
end

# calls dictionary constructor above:
Expand Down
2 changes: 1 addition & 1 deletion test/_models/Transformers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ end
# inverse transforming a categorical value:
function MLJBase.inverse_transform(
transformer::UnivariateDiscretizer, result, e::CategoricalValue)
k = get(e)
k = MLJBase.unwrap(e)
return inverse_transform(transformer, result, k)
end

Expand Down
8 changes: 4 additions & 4 deletions test/univariate_finite/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ u = UnivariateFinite(all_classes, P, augment=true) #uni_fin_arr
@test isequal(logpdf(u, ["yes", "no"]), log.(hcat(P, 1 .- P)))
@test pdf(u, reverse(all_classes)) == hcat(P, 1 .- P)
@test isequal(logpdf(u, reverse(all_classes)), log.(hcat(P, 1 .- P)))

# test pdf(::Array{UnivariateFinite, 1}, labels) and
# logpdf(::Array{UnivariateFinite, labels)
@test pdf([u...], ["yes", "no"]) == hcat(P, 1 .- P)
Expand All @@ -129,11 +129,11 @@ end
@testset "broadcasting: pdf.(uni_fin_arr, array_same_shape) and logpdf.(uni_fin_arr, array_same_shape)" begin
v = rand(classes(u), n)
@test broadcast(pdf, u, v) == [pdf(u[i], v[i]) for i in 1:length(u)]
@test isequal(broadcast(logpdf, u, v),
@test isequal(broadcast(logpdf, u, v),
[logpdf(u[i], v[i]) for i in 1:length(u)])
@test broadcast(pdf, u, get.(v)) ==
@test broadcast(pdf, u, MLJBase.unwrap.(v)) ==
[pdf(u[i], v[i]) for i in 1:length(u)]
@test isequal(broadcast(logpdf, u, get.(v)),
@test isequal(broadcast(logpdf, u, MLJBase.unwrap.(v)),
[logpdf(u[i], v[i]) for i in 1:length(u)])
end

Expand Down

0 comments on commit 4070381

Please sign in to comment.