Skip to content

Commit

Permalink
Merge pull request #40 from adknudson/master
Browse files Browse the repository at this point in the history
Add `fit!(o, y, n)` to fit multiple observations of the same value
  • Loading branch information
joshday authored Apr 19, 2024
2 parents aabf395 + 96069a1 commit a217f23
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 16 deletions.
68 changes: 52 additions & 16 deletions src/OnlineStatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,39 @@ the type of a single observation for the provided `stat`, `fit!` will attempt to
through and `fit!` each item in `data`. Therefore, `fit!(Mean(), 1:10)` translates
roughly to:
o = Mean()
```julia
o = Mean()
for x in 1:10
fit!(o, x)
for x in 1:10
fit!(o, x)
end
```
"""
function fit!(o::OnlineStat{T}, y::T) where {T}
    # `y` already has the stat's observation type `T`: hand it to the stat's own
    # `_fit!` and return the stat itself so calls can be chained.
    _fit!(o, y)
    return o
end

# Fallback for input whose type `T` is not the stat's observation type `I`: treat `y`
# as a collection and fit its items one at a time.
function fit!(o::OnlineStat{I}, y::T) where {I, T}
    # When iterating `y` only yields more values of type `T` there is nothing to
    # unpack (presumably this guards against re-entering this same method
    # endlessly), so report the type mismatch instead.
    if T == eltype(y)
        error("The input for $(name(o,false,false)) is $I. Found $T.")
    end
    foreach(yi -> fit!(o, yi), y)
    return o
end

"""
fit!(o::OnlineStat{T}, yi::T) where {T} = (_fit!(o, yi); return o)
fit!(stat::OnlineStat, y, n)
Update the "sufficient statistics" of a `stat` with multiple observations of a single value.
Unless a specialized formula is used, `fit!(Mean(), 10, 5)` is equivalent to:
```julia
o = Mean()
for _ in 1:5
fit!(o, 10)
end
```
"""
function fit!(o::OnlineStat{T}, y::T, n::Integer) where {T}
    # The documented meaning of this call is `for _ in 1:n; fit!(o, y); end`, and
    # `1:n` is empty when `n <= 0`. Return the stat unchanged in that case instead of
    # dispatching, so that specialized `_fit!(o, y, n)` formulas (which may divide by
    # `n` or add `n` to a count) never see an empty or negative batch.
    n > 0 || return o
    _fit!(o, y, n)
    return o
end

"""
fit!(stat1::OnlineStat, stat2::OnlineStat)
Expand All @@ -121,23 +147,33 @@ Useful for reductions of OnlineStats using `fit!`.
# Example
julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
```julia-repl
julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
```
"""
function fit!(o::OnlineStat, o2::OnlineStat)
    # Fitting a stat with another stat is defined as merging `o2` into `o`.
    return merge!(o, o2)
end

function fit!(o::OnlineStat{I}, y::T) where {I, T}
T == eltype(y) && error("The input for $(name(o,false,false)) is $I. Found $T.")
for yi in y
fit!(o, yi)
# Catch-all `_fit!(o, y)`: every stat is expected to define its own method, so landing
# here means the implementation is missing.
function _fit!(o::OnlineStat{T}, y) where {T}
    statname = name(o, false, false)
    msg = string(
        "_fit!(o, y) is not implemented for $(statname). If you are writing ",
        "a new statistic, then this must be implemented. If you are a user, then please ",
        "submit a bug report.",
    )
    error(msg)
end

# Default `_fit!(o, y, n)`: a stat without its own batch formula simply applies the
# single-observation `_fit!` once per repetition. Overriding this method is optional.
function _fit!(o::OnlineStat{T}, y, n) where {T}
    foreach(_ -> _fit!(o, y), 1:n)
    return o
end

#-----------------------------------------------------------------------# utils
Expand Down
23 changes: 23 additions & 0 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ mutable struct Counter{T} <: OnlineStat{T}
end
# Build a `Counter` whose observation type is `T` (`Number` unless specified).
function Counter(T = Number)
    return Counter{T}()
end

# One observation: bump the count by one; the value `y` itself is not used.
function _fit!(o::Counter{T}, y) where {T}
    o.n += 1
end

# `n` observations of one value: bump the count by `n` in a single step.
function _fit!(o::Counter{T}, y, n) where {T}
    o.n += n
end

# Fold the count held by `b` into `a`.
function _merge!(a::Counter, b::Counter)
    a.n += b.n
end

#-----------------------------------------------------------------------# CountMap
Expand Down Expand Up @@ -122,6 +123,10 @@ function _fit!(o::CountMap{T}, xy::Pair{<:T, <:Integer}) where {T}
o.n += y
o.value[x] = get!(o.value, x, 0) + y
end
# Batch update: record `n` occurrences of the key `x` in one step.
function _fit!(o::CountMap{T}, x, n) where {T}
    counts = o.value
    o.n += n
    # `get!` inserts a zero entry for a key that has not been seen yet.
    seen = get!(counts, x, 0)
    counts[x] = seen + n
end

function _merge!(o::CountMap, o2::CountMap)
    # Add the per-key counts of `o2` into `o`, then fold in its observation count.
    merge!(+, o.value, o2.value)
    o.n += o2.n
end
function probs(o::CountMap, kys = keys(o.value))
Expand Down Expand Up @@ -251,6 +256,18 @@ function _fit!(o::Extrema, y)
y == o.min && (o.nmin += 1)
y == o.max && (o.nmax += 1)
end
# Batch update: record `n` observations that all equal `y`.
function _fit!(o::Extrema, y, n)
    # Nothing is observed when `n <= 0` (the generic fallback loops over `1:n`, which
    # is empty then). Without this guard a zero count would still overwrite
    # `min`/`max` with `y` and reset the tie counters.
    n > 0 || return nothing
    # When the running count equals `n` after the update, this is the first batch,
    # which defines both extremes.
    (o.n += n) == n && (o.min = o.max = y)
    if y < o.min
        o.min = y
        o.nmin = 0
    elseif y > o.max
        o.max = y
        o.nmax = 0
    end
    # Each of the `n` copies ties with whichever extreme it equals.
    y == o.min && (o.nmin += n)
    y == o.max && (o.nmax += n)
end
function _merge!(a::Extrema, b::Extrema)
if a.min == b.min
a.nmin += b.nmin
Expand Down Expand Up @@ -451,6 +468,10 @@ Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight
# Single-observation update of the running mean.
function _fit!(o::Mean{T}, x) where {T}
    o.n += 1
    # The weight is a function of the updated observation count.
    w = o.weight(o.n)
    o.μ = smooth(o.μ, x, w)
end
# Batch update for equally weighted means: fold in `n` observations equal to `y`.
function _fit!(o::Mean{T, W}, y, n) where {T, W<:EqualWeight}
    # `n <= 0` means no observations (the generic fallback loops over the empty range
    # `1:n`). Skip it: on an empty stat `o.n / n` would be `0 / 0`, i.e. `NaN`, and
    # that value would be handed to the weight and on to `smooth`.
    n > 0 || return o.μ
    o.n += n
    # NOTE(review): passing `o.n / n` in place of the running count presumably turns
    # a 1/count weight into `n / o.n`, the batch's share of all observations so far —
    # confirm against the definition of `EqualWeight`.
    o.μ = smooth(o.μ, y, o.weight(o.n / n))
end
function _merge!(o::Mean, o2::Mean)
o.n += o2.n
o.μ = smooth(o.μ, o2.μ, o2.n / o.n)
Expand Down Expand Up @@ -524,6 +545,8 @@ Sum(T::Type = Float64) = Sum(T(0), 0)
# The accumulated total.
function Base.sum(o::Sum)
    return o.sum
end

# Floating-point accumulator: convert the observation to `T` before adding it.
function _fit!(o::Sum{T}, x::Real) where {T<:AbstractFloat}
    o.sum += convert(T, x)
    o.n += 1
end

# Integer accumulator: round the observation to `T` before adding it.
function _fit!(o::Sum{T}, x::Real) where {T<:Integer}
    o.sum += round(T, x)
    o.n += 1
end
# Batch update, floating-point accumulator. The value is converted to `T` before it is
# scaled by `n`, just as the single-observation method converts each value, so an
# integer product `x * n` cannot overflow before the conversion takes place.
function _fit!(o::Sum{T}, x::Real, n) where {T<:AbstractFloat}
    o.sum += convert(T, x) * n
    o.n += n
end

# Batch update, integer accumulator. Rounding is applied per observation, as in the
# single-observation method: five copies of 0.4 add 0, whereas rounding the product
# `0.4 * 5` would add 2 and disagree with fitting the value five times.
function _fit!(o::Sum{T}, x::Real, n) where {T<:Integer}
    o.sum += round(T, x) * n
    o.n += n
end
# Merge two sums of the same concrete type: add totals and observation counts.
function _merge!(o::T, o2::T) where {T <: Sum}
    o.sum += o2.sum
    o.n += o2.n
    return o
end

#-----------------------------------------------------------------------# Variance
Expand Down
55 changes: 55 additions & 0 deletions test/test_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ println(" > CircBuff")
fit!(b, 3:11)
@test b[end] == 7
@test b[1] == 11

# Multiple obs method: five copies of 5 in one call, followed by a single 10.
buf = CircBuff(Int, 5)
fit!(buf, 5, 5)
fit!(buf, 10)
@test buf[1] == 5
@test buf[end] == 10
end

#-----------------------------------------------------------------------# Counter
Expand All @@ -50,6 +57,10 @@ println(" > Counter")
o2 = fit!(Counter(Int), 1)
@test value(merge!(o, o2)) == 11
==(mergevals(Counter(), y, y2)...)

# Multiple obs method: ten observations supplied in one call count as ten.
batch = Counter(Int)
fit!(batch, 1, 10)
@test value(batch) == 10
end

#-----------------------------------------------------------------------# CountMap
Expand Down Expand Up @@ -82,6 +93,13 @@ println(" > CountMap")
# Pair method
@test ==(mergevals(CountMap(Bool), Pair.(x,z), Pair.(x2,z2); nobs_equals_length=false)...)
@test ==(mergevals(CountMap(Int), Pair.(z,z), Pair.(z2,z2); nobs_equals_length=false)...)

# Multiple obs method: batch counts accumulate per key and in the overall total.
cm = CountMap(Bool)
fit!(cm, true, 10)
fit!(cm, false, 5)
@test nobs(cm) == 15
@test cm[true] == 10
@test cm[false] == 5
end
#-----------------------------------------------------------------------# CountMissing
println(" > CountMissing")
Expand Down Expand Up @@ -143,6 +161,15 @@ println(" > Extrema")
o = fit!(Extrema(), x)
@test o.nmin == length(x) - sum(x)
@test o.nmax == sum(x)

# Multiple obs method: a batch beyond the current extreme replaces it, and the tie
# count equals the batch size.
# NOTE(review): assumes every value in `y` lies strictly between -20 and 20 — confirm.
ex = fit!(Extrema(), y)
fit!(ex, 20, 5)
fit!(ex, -20, 7)
@test maximum(ex) == 20
@test minimum(ex) == -20
@test ex.nmax == 5
@test ex.nmin == 7
end
#-----------------------------------------------------------------------------# ExtremeValues
println(" > ExtremeValues")
Expand Down Expand Up @@ -227,6 +254,12 @@ println(" > Mean")
# The streaming result must agree with the batch statistic. `≈` is used because the
# accumulated floating-point value need not match bit for bit.
@test value(o) ≈ mean(y)
@test mean(o) ≈ mean(y)
@test ≈(mergevals(Mean(), y, y2)...)

# Multiple obs method: fitting 1.0 four times in one call must match the mean of the
# data with four extra 1.0 entries appended.
o = fit!(Mean(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test mean(o) ≈ mean(v)
end
#-----------------------------------------------------------------------# Moments
println(" > Moments")
Expand All @@ -241,6 +274,17 @@ println(" > Moments")
# Merged and directly fitted moments must agree element by element (approximately,
# since they are floating-point accumulations).
for (v1, v2) in zip(mergevals(Moments(), y, y2)...)
    @test v1 ≈ v2
end

# Multiple obs method: fitting 1.0 four times in one call must match the statistics
# of the data with four extra 1.0 entries appended.
o = fit!(Moments(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test value(o) ≈ [mean(v), mean(v .^ 2), mean(v .^ 3), mean(v .^ 4)]
@test mean(o) ≈ mean(v)
@test var(o) ≈ var(v)
@test std(o) ≈ std(v)
@test skewness(o) ≈ skewness(v)
@test kurtosis(o) ≈ kurtosis(v)
end

#-----------------------------------------------------------------------# Series
Expand Down Expand Up @@ -281,6 +325,9 @@ println(" > Sum")
# Integer sums must merge exactly; the default `Sum()` over `y` is compared
# approximately.
@test ==(mergevals(Sum(Int), x, x2)...)
@test ≈(mergevals(Sum(), y, y2)...)
@test ==(mergevals(Sum(Int), z, z2)...)

# Multiple obs method: five observations of 10 add 50.
@test sum(fit!(Sum(Int), 10, 5)) == 50
end

#-----------------------------------------------------------------------------# TryCatch
Expand Down Expand Up @@ -312,6 +359,14 @@ println(" > Variance")
@test std(fit!(Variance(), [1, 2])) == sqrt(.5)
# https://github.com/joshday/OnlineStats.jl/issues/217
@test value(fit!(Variance(Float32), randn(Float32, 10))) isa Float32

# Multiple obs method: fitting 1.0 four times in one call must match the statistics
# of the data with four extra 1.0 entries appended.
o = fit!(Variance(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test mean(o) ≈ mean(v)
@test var(o) ≈ var(v)
@test std(o) ≈ std(v)
end

end # end "Test Stats"

0 comments on commit a217f23

Please sign in to comment.