Possible type instability with Mean, Moments, Sum, Variance #35

Closed
nic-barbara opened this issue Aug 1, 2023 · 3 comments

Comments

@nic-barbara

nic-barbara commented Aug 1, 2023

Why are some statistics types subtyped with OnlineStat{Number}? For example:

mutable struct Mean{T,W} <: OnlineStat{Number}
    μ::T
    weight::W
    n::Int
end
Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight, 0)

Is there a reason we can't have mutable struct Mean{T,W} <: OnlineStat{T} instead? With the current definition, calling input() on a statistic like Mean() always returns Number rather than the actual input type (e.g. Float32). The same is true for Moments, Sum, and Variance.


I noticed this while playing around with a Mean/Stdev filter. My original code is below. I'm new to this package, so feel free to suggest better or more efficient ways to do this.

using BenchmarkTools
using OnlineStatsBase

mutable struct MeanStdFilter{T}
    nu::Int
    tracker::OnlineStat
end

function MeanStdFilter(nu::Int; T::DataType=Float32)
    s = [Series(Mean(T), Variance(T)) for _ in 1:nu]
    return MeanStdFilter{T}(nu, Group(s...))
end

function _get_mean_var(m::MeanStdFilter{T}) where T
    vals = value.(value(m.tracker))
    return reinterpret(reshape, T, collect(vals))
end

function (m::MeanStdFilter)(x::AbstractVector)
    fit!(m.tracker, x)
    μσ2 = _get_mean_var(m)
    return (x .- μσ2[1,:]) ./ sqrt.(μσ2[2,:])
end

# Test runtime
nu = 4
T = Float32
m = MeanStdFilter(nu; T)

# @btime m(randn(T,nu));
@btime _get_mean_var(m);

Running with T = Float32 I get:

1.014 μs (18 allocations: 608 bytes)

and with T = Float64 the runtime drops to:

549.342 ns (6 allocations: 480 bytes)

I suspect this has to do with converting Float64 to Float32 somewhere in the pipeline because of the issue raised above.

Thanks in advance for any help!

@nic-barbara
Author

(Copy-pasted from joshday/OnlineStats.jl#265)

@joshday
Owner

joshday commented Aug 1, 2023

Why are some statistics types subtyped with OnlineStat{Number}?

The T in Mean{T,W} indicates how the type will be stored, not what the allowed inputs are. By leaving the supertype OnlineStat{Number}, you can do things like this:

julia> o = Mean(Complex{Float64})
Mean: n=0 | value=0.0+0.0im

# fit with non-Complex{Float64} data
julia> fit!(o, 1:10)
Mean: n=10 | value=5.5+0.0im

If the supertype here were OnlineStat{Complex{Float64}}, fit!(o, 1:10) would be an error (due to how the OnlineStatsBase interface works) and users would need to change this to something like:

# this is kinda annoying to do
fit!(o, Complex{Float64}(i) for i in 1:10)

In other words, the S in OnlineStat{S} is intended to be "wide". It doesn't hurt inference/cause type instability to do so.
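
To make the storage-versus-input distinction concrete, here is a minimal sketch (not from the thread itself) that fits integer data into a Float32-backed Mean. It assumes only the OnlineStatsBase names already used in this thread (Mean, fit!, value, OnlineStat); the variable name o32 is illustrative.

```julia
using OnlineStatsBase

# Storage type is Float32, but the declared input type stays the wide `Number`.
o32 = Mean(Float32)

# Integer observations are accepted because Int <: Number.
fit!(o32, 1:10)

# The running value lives in the Float32-typed field, whatever the input type was.
@show typeof(value(o32))

# The supertype is OnlineStat{Number}, matching the struct definition quoted earlier.
@show o32 isa OnlineStat{Number}
```

The type parameter T only controls the field type, so the stored mean stays Float32 even though the observations were integers.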


I started holding a baby at this exact moment, so the rest will be brief 😄

Note:

mutable struct MeanStdFilter{T}
    nu::Int
    tracker::OnlineStat  # avoid abstract types here
end
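
One way to act on the "avoid abstract types" comment is to lift the tracker's concrete type into a type parameter. The following is a sketch under assumptions, not code from the thread: the parameter name S and the restructured constructor are illustrative, the struct is made immutable because its fields are never reassigned, and the mean/variance extraction is simplified relative to the original _get_mean_var.

```julia
using OnlineStatsBase

# Hypothetical variant: S captures the concrete type of the tracker,
# so the field is no longer the abstract `OnlineStat`.
struct MeanStdFilter{T, S<:OnlineStat}
    nu::Int
    tracker::S
end

function MeanStdFilter(nu::Int; T::Type{<:Number} = Float32)
    # One (Mean, Variance) pair per input dimension, as in the original post.
    tracker = Group([Series(Mean(T), Variance(T)) for _ in 1:nu]...)
    return MeanStdFilter{T, typeof(tracker)}(nu, tracker)
end

function (m::MeanStdFilter)(x::AbstractVector)
    fit!(m.tracker, x)
    # value(Group) yields the Series; value.(...) yields (mean, variance) pairs.
    pairs = value.(value(m.tracker))
    μ  = [p[1] for p in pairs]
    σ2 = [p[2] for p in pairs]
    return (x .- μ) ./ sqrt.(σ2)
end

m = MeanStdFilter(4; T = Float32)
y = m(randn(Float32, 4))

# The field type is now concrete, which is what the compiler needs to specialize.
@show isconcretetype(fieldtype(typeof(m), :tracker))
```

With the field typed as S, methods that touch m.tracker are compiled for the exact Group type instead of dispatching dynamically on an abstract OnlineStat.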

If I'm following your code correctly, here's a simpler implementation:

julia> o = Group([Variance() for i in 1:5]...)
Group
├─ Variance: n=0 | value=1.0
├─ Variance: n=0 | value=1.0
├─ Variance: n=0 | value=1.0
├─ Variance: n=0 | value=1.0
└─ Variance: n=0 | value=1.0


julia> fit!(o, randn(5) for _ in 1:10^6)
Group
├─ Variance: n=1_000_000 | value=0.999672
├─ Variance: n=1_000_000 | value=0.998411
├─ Variance: n=1_000_000 | value=0.999115
├─ Variance: n=1_000_000 | value=0.999463
└─ Variance: n=1_000_000 | value=1.0014


julia> x = 1:5
1:5

julia> (x .- mean.(value(o))) ./ std.(value(o))
5-element Vector{Float64}:
 0.9999604301846247
 2.0024759146329214
 3.0005499202160566
 3.999616355074072
 4.995462949480359

@nic-barbara
Author

Thanks for the suggestions and the explanation. This actually solved my issue!
