From 3931a52747fe4f9bdffb0261cc432fc4819d6600 Mon Sep 17 00:00:00 2001 From: aknudson Date: Sat, 6 Apr 2024 16:30:19 -0700 Subject: [PATCH 1/4] Added initial interface for `fit!(o, y, n)` --- src/OnlineStatsBase.jl | 68 ++++++++++++++++++++++++++++++++---------- test/test_stats.jl | 55 ++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 16 deletions(-) diff --git a/src/OnlineStatsBase.jl b/src/OnlineStatsBase.jl index dc57b32..0bd63f8 100644 --- a/src/OnlineStatsBase.jl +++ b/src/OnlineStatsBase.jl @@ -104,13 +104,39 @@ the type of a single observation for the provided `stat`, `fit!` will attempt to through and `fit!` each item in `data`. Therefore, `fit!(Mean(), 1:10)` translates roughly to: - o = Mean() +```julia +o = Mean() - for x in 1:10 - fit!(o, x) +for x in 1:10 + fit!(o, x) +end +``` +""" +fit!(o::OnlineStat{T}, y::T) where {T} = (_fit!(o, y); return o) + +function fit!(o::OnlineStat{I}, y::T) where {I, T} + T == eltype(y) && error("The input for $(name(o,false,false)) is $I. Found $T.") + for yi in y + fit!(o, yi) end + o +end + """ -fit!(o::OnlineStat{T}, yi::T) where {T} = (_fit!(o, yi); return o) + fit!(stat::OnlineStat, y, n) + +Update the "sufficient statistics" of a `stat` with multiple observations of a single value. +Unless a specialized formula is used, `fit!(Mean(), 10, 5)` is equivalent to: + +```julia +o = Mean() + +for _ in 1:5 + fit!(o, 10) +end +``` +""" +fit!(o::OnlineStat{T}, y::T, n::Integer) where {T} = (_fit!(o, y, n); return o) """ fit!(stat1::OnlineStat, stat2::OnlineStat) @@ -121,23 +147,33 @@ Useful for reductions of OnlineStats using `fit!`. # Example - julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3] - 3-element Vector{Mean{Float64, EqualWeight}}: - Mean: n=3 | value=2.0 - Mean: n=3 | value=2.0 - Mean: n=3 | value=2.0 +```julia-repl +julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3] +3-element Vector{Mean{Float64, EqualWeight}}: +Mean: n=3 | value=2.0 +Mean: n=3 | value=2.0 +Mean: n=3 | value=2.0 - julia> reduce(fit!, v, init=Mean()) - Mean: n=9 | value=2.0 +julia> reduce(fit!, v, init=Mean()) +Mean: n=9 | value=2.0 +``` """ fit!(o::OnlineStat, o2::OnlineStat) = merge!(o, o2) -function fit!(o::OnlineStat{I}, y::T) where {I, T} - T == eltype(y) && error("The input for $(name(o,false,false)) is $I. Found $T.") - for yi in y - fit!(o, yi) +# general fallback for _fit!(o, y) that each stat must implement +function _fit!(o::OnlineStat{T}, ::S) where {T, S} + error("_fit!(o, y) is not implemented for $(name(o,false,false)). If you are writing " * + "a new statistic, then this must be implemented. If you are a user, then please " * + "submit a bug report.") +end + +# general fallback for _fit!(o, y, n) that is optional to implement +function _fit!(o::OnlineStat{T}, y::S, n::Integer) where {T, S} + for _ in 1:n + _fit!(o, y) end - o + + return o end #-----------------------------------------------------------------------# utils diff --git a/test/test_stats.jl b/test/test_stats.jl index efda3f6..40bcb3f 100644 --- a/test/test_stats.jl +++ b/test/test_stats.jl @@ -40,6 +40,13 @@ println(" > CircBuff") fit!(b, 3:11) @test b[end] == 7 @test b[1] == 11 + + # Multiple obs method + c = CircBuff(Int, 5) + fit!(c, 5, 5) + fit!(c, 10) + @test c[1] == 5 + @test c[end] == 10 end #-----------------------------------------------------------------------# Counter @@ -50,6 +57,10 @@ println(" > Counter") o2 = fit!(Counter(Int), 1) @test value(merge!(o, o2)) == 11 ==(mergevals(Counter(), y, y2)...) + + # Multiple obs method + o3 = fit!(Counter(Int), 1, 10) + @test (value(o3)) == 10 end #-----------------------------------------------------------------------# CountMap @@ -82,6 +93,13 @@ println(" > CountMap") # Pair method @test ==(mergevals(CountMap(Bool), Pair.(x,z), Pair.(x2,z2); nobs_equals_length=false)...) @test ==(mergevals(CountMap(Int), Pair.(z,z), Pair.(z2,z2); nobs_equals_length=false)...) + + # Multiple obs method + c = fit!(CountMap(Bool), true, 10) + fit!(c, false, 5) + @test nobs(c) == 15 + @test c[true] == 10 + @test c[false] == 5 end #-----------------------------------------------------------------------# CountMissing println(" > CountMissing") @@ -143,6 +161,15 @@ println(" > Extrema") o = fit!(Extrema(), x) @test o.nmin == length(x) - sum(x) @test o.nmax == sum(x) + + # Multiple obs method + o = fit!(Extrema(), y) + fit!(o, 20, 5) + fit!(o, -20, 7) + @test o.nmax == 5 + @test o.nmin == 7 + @test maximum(o) == 20 + @test minimum(o) == -20 end #-----------------------------------------------------------------------------# ExtremeValues println(" > ExtremeValues") @@ -227,6 +254,12 @@ println(" > Mean") @test value(o) ≈ mean(y) @test mean(o) ≈ mean(y) @test ≈(mergevals(Mean(), y, y2)...) + + # Multiple obs method + o = fit!(Mean(), y) + fit!(o, 1.0, 4) + v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0]) + @test mean(o) ≈ mean(v) end #-----------------------------------------------------------------------# Moments println(" > Moments") @@ -241,6 +274,17 @@ println(" > Moments") for (v1,v2) in zip(mergevals(Moments(), y, y2)...) @test v1 ≈ v2 end + + # Multiple obs method + o = fit!(Moments(), y) + fit!(o, 1.0, 4) + v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0]) + @test value(o) ≈ [mean(v), mean(v .^ 2), mean(v .^ 3), mean(v .^ 4)] + @test mean(o) ≈ mean(v) + @test var(o) ≈ var(v) + @test std(o) ≈ std(v) + @test skewness(o) ≈ skewness(v) + @test kurtosis(o) ≈ kurtosis(v) end #-----------------------------------------------------------------------# Series @@ -281,6 +325,9 @@ println(" > Sum") @test ==(mergevals(Sum(Int), x, x2)...) @test ≈(mergevals(Sum(), y, y2)...) @test ==(mergevals(Sum(Int), z, z2)...) + + # Multiple obs method + @test sum(fit!(Sum(Int), 10, 5)) == 50 end #-----------------------------------------------------------------------------# TryCatch @@ -312,6 +359,14 @@ println(" > Variance") @test std(fit!(Variance(), [1, 2])) == sqrt(.5) # https://github.com/joshday/OnlineStats.jl/issues/217 @test value(fit!(Variance(Float32), randn(Float32, 10))) isa Float32 + + # Multiple obs method + o = fit!(Variance(), y) + fit!(o, 1.0, 4) + v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0]) + @test mean(o) ≈ mean(v) + @test var(o) ≈ var(v) + @test std(o) ≈ std(v) end end # end "Test Stats" From 5ee5b96695ce392ba9cc7e014873d8a0516db909 Mon Sep 17 00:00:00 2001 From: aknudson Date: Sat, 6 Apr 2024 16:45:26 -0700 Subject: [PATCH 2/4] relaxed type constraint on internal _fit!(o, y, n) --- src/OnlineStatsBase.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/OnlineStatsBase.jl b/src/OnlineStatsBase.jl index 0bd63f8..124b42a 100644 --- a/src/OnlineStatsBase.jl +++ b/src/OnlineStatsBase.jl @@ -161,14 +161,14 @@ Mean: n=9 | value=2.0 fit!(o::OnlineStat, o2::OnlineStat) = merge!(o, o2) # general fallback for _fit!(o, y) that each stat must implement -function _fit!(o::OnlineStat{T}, ::S) where {T, S} +function _fit!(o::OnlineStat{T}, y) where {T} error("_fit!(o, y) is not implemented for $(name(o,false,false)). If you are writing " * "a new statistic, then this must be implemented. If you are a user, then please " * "submit a bug report.") end # general fallback for _fit!(o, y, n) that is optional to implement -function _fit!(o::OnlineStat{T}, y::S, n::Integer) where {T, S} +function _fit!(o::OnlineStat{T}, y, n) where {T} for _ in 1:n _fit!(o, y) end From 490bff16b720e77091e50dfefec97bafb76b6cae Mon Sep 17 00:00:00 2001 From: aknudson Date: Sat, 6 Apr 2024 16:58:53 -0700 Subject: [PATCH 3/4] Added specialized methods to fit multiple obs of a single value --- src/stats.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/stats.jl b/src/stats.jl index 8080e69..794c749 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -78,6 +78,7 @@ mutable struct Counter{T} <: OnlineStat{T} end Counter(T = Number) = Counter{T}() _fit!(o::Counter{T}, y) where {T} = (o.n += 1) +_fit!(o::Counter{T}, y, n) where {T} = (o.n += n) _merge!(a::Counter, b::Counter) = (a.n += b.n) #-----------------------------------------------------------------------# CountMap @@ -122,6 +123,10 @@ function _fit!(o::CountMap{T}, xy::Pair{<:T, <:Integer}) where {T} o.n += y o.value[x] = get!(o.value, x, 0) + y end +function _fit!(o::CountMap{T}, x, n) where {T} + o.n += n + o.value[x] = get!(o.value, x, 0) + n +end _merge!(o::CountMap, o2::CountMap) = (merge!(+, o.value, o2.value); o.n += o2.n) function probs(o::CountMap, kys = keys(o.value)) @@ -251,6 +256,18 @@ function _fit!(o::Extrema, y) y == o.min && (o.nmin += 1) y == o.max && (o.nmax += 1) end +function _fit!(o::Extrema, y, n) + (o.n += n) == n && (o.min = o.max = y) + if y < o.min + o.min = y + o.nmin = 0 + elseif y > o.max + o.max = y + o.nmax = 0 + end + y == o.min && (o.nmin += n) + y == o.max && (o.nmax += n) +end function _merge!(a::Extrema, b::Extrema) if a.min == b.min a.nmin += b.nmin @@ -451,6 +468,10 @@ Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight function _fit!(o::Mean{T}, x) where {T} o.μ = smooth(o.μ, x, o.weight(o.n += 1)) end +function _fit!(o::Mean{T}, y, n) where {T} + o.n += n + o.μ = smooth(o.μ, y, n / o.n) +end function _merge!(o::Mean, o2::Mean) o.n += o2.n o.μ = smooth(o.μ, o2.μ, o2.n / o.n) @@ -524,6 +545,8 @@ Sum(T::Type = Float64) = Sum(T(0), 0) Base.sum(o::Sum) = o.sum _fit!(o::Sum{T}, x::Real) where {T<:AbstractFloat} = (o.sum += convert(T, x); o.n += 1) _fit!(o::Sum{T}, x::Real) where {T<:Integer} = (o.sum += round(T, x); o.n += 1) +_fit!(o::Sum{T}, x::Real, n) where {T<:AbstractFloat} = (o.sum += convert(T, x * n); o.n += n) +_fit!(o::Sum{T}, x::Real, n) where {T<:Integer} = (o.sum += round(T, x * n); o.n += n) _merge!(o::T, o2::T) where {T <: Sum} = (o.sum += o2.sum; o.n += o2.n; o) #-----------------------------------------------------------------------# Variance From 96069a171d72a1e61ed5286255b85085e39a1387 Mon Sep 17 00:00:00 2001 From: aknudson Date: Wed, 10 Apr 2024 18:45:34 -0700 Subject: [PATCH 4/4] Added stricter type for specialization on Mean --- src/stats.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/stats.jl b/src/stats.jl index 794c749..3a0cce0 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -468,9 +468,9 @@ Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight function _fit!(o::Mean{T}, x) where {T} o.μ = smooth(o.μ, x, o.weight(o.n += 1)) end -function _fit!(o::Mean{T}, y, n) where {T} +function _fit!(o::Mean{T, W}, y, n) where {T, W<:EqualWeight} o.n += n - o.μ = smooth(o.μ, y, n / o.n) + o.μ = smooth(o.μ, y, o.weight(o.n / n)) end function _merge!(o::Mean, o2::Mean) o.n += o2.n