Skip to content

Commit d8798d5

Browse files
committed
Add code for a multivariate discrete rv
1 parent 9334727 commit d8798d5

File tree

3 files changed

+128
-26
lines changed

3 files changed

+128
-26
lines changed

src/QuantEcon.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export
4242

4343
# discrete_rv
4444
DiscreteRV,
45+
MVDiscreteRV,
4546
draw,
4647

4748
# mc_tools

src/discrete_rv.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,69 @@ function Base.rand!(out::AbstractArray{T}, d::DiscreteRV) where T<:Integer
7878
end
7979

8080
@deprecate draw Base.rand
81+
82+
83+
struct MVDiscreteRV{TV1<:AbstractArray,TV2<:AbstractVector,K,TI<:Integer}
84+
q::TV1
85+
Q::TV2
86+
dims::NTuple{K,TI}
87+
88+
function MVDiscreteRV{TV1,TV2,K,TI}(q::TV1, Q::TV2, dims::NTuple{K,TI}) where {TV1,TV2,K,TI}
89+
abs(sum(q) - 1.0) > 1e-10 && error("q should sum to 1")
90+
abs(Q[end] - 1.0) > 1e-10 && error("Q[end] should be 1")
91+
length(Q) != prod(dims) && error("Number of elements is inconsistent")
92+
93+
new{TV1,TV2,K,TI}(q, Q, dims)
94+
end
95+
end
96+
97+
98+
function MVDiscreteRV(q::TV1) where TV1<:AbstractArray
99+
Q = cumsum(vec(q))
100+
dims = size(q)
101+
102+
return MVDiscreteRV{typeof(q),typeof(Q),length(dims),eltype(dims)}(q, Q, dims)
103+
end
104+
105+
106+
"""
107+
Make a single draw from the multivariate discrete distribution.
108+
109+
##### Arguments
110+
111+
- `d::MVDiscreteRV`: The `MVDiscreteRV` type represetning the distribution
112+
113+
##### Returns
114+
115+
- `out::NTuple{Int}`: One draw from the discrete distribution
116+
"""
117+
function Base.rand(d::MVDiscreteRV)
118+
x = rand()
119+
i = searchsortedfirst(d.Q, x)
120+
121+
return ind2sub(d.dims, i)
122+
end
123+
124+
"""
125+
Make multiple draws from the discrete distribution represented by a
126+
`MVDiscreteRV` instance
127+
128+
##### Arguments
129+
130+
- `d::MVDiscreteRV`: The `DiscreteRV` type representing the distribution
131+
- `k::Int`
132+
133+
##### Returns
134+
135+
- `out::Vector{NTuple{Int}}`: `k` draws from `d`
136+
"""
137+
Base.rand(d::MVDiscreteRV{T1,T2,K,TI}, k::V) where {T1,T2,K,TI,V} =
138+
NTuple{K,TI}[rand(d) for i in 1:k]
139+
140+
function Base.rand!(out::AbstractArray{NTuple{K,TI}}, d::MVDiscreteRV) where {K,TI}
141+
@inbounds for I in eachindex(out)
142+
out[I] = rand(d)
143+
end
144+
145+
return out
146+
end

test/test_discrete_rv.jl

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,69 @@
11
@testset "Testing discrete_rv.jl" begin
22

3-
# set up
4-
n = 10
5-
x = rand(n)
6-
x ./= sum(x)
7-
drv = DiscreteRV(x)
8-
9-
# test Q sums to 1
10-
@test drv.Q[end] 1.0
11-
12-
# test lln
13-
draws = rand(drv, 100_000)
14-
c = counter(draws)
15-
counts = Array{Float64}(n)
16-
for i=1:n
17-
counts[i] = c[i]
18-
end
19-
counts ./= sum(counts)
3+
@testset "Testing univariate discrete rv" begin
4+
# set up
5+
n = 10
6+
x = rand(n)
7+
x ./= sum(x)
8+
drv = DiscreteRV(x)
9+
10+
# test Q sums to 1
11+
@test drv.Q[end] 1.0
12+
13+
# test lln
14+
draws = rand(drv, 100_000)
15+
c = counter(draws)
16+
counts = Array{Float64}(n)
17+
for i=1:n
18+
counts[i] = c[i]
19+
end
20+
counts ./= sum(counts)
2021

21-
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
22+
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
2223

23-
draws = Array{Int}(100_000)
24-
rand!(draws, drv)
25-
c = counter(draws)
26-
counts = Array{Float64}(n)
27-
for i=1:n
28-
counts[i] = c[i]
24+
draws = Array{Int}(100_000)
25+
rand!(draws, drv)
26+
c = counter(draws)
27+
counts = Array{Float64}(n)
28+
for i=1:n
29+
counts[i] = c[i]
30+
end
31+
counts ./= sum(counts)
32+
33+
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
2934
end
30-
counts ./= sum(counts)
3135

32-
@test isapprox(Base.maximum(abs, counts - drv.q), 0.0; atol=1e-2)
36+
@testset "Testing multivariate discrete rv" begin
37+
# Do tests for various sizes
38+
for dims in [(5, 3), (5, 10, 3), (5, 7, 5, 10)]
39+
# How many dimensions
40+
n = length(dims)
41+
42+
# Make some distributional matrix
43+
q = rand(dims...)
44+
q ./= sum(q) # Normalize to sum to 1
45+
46+
# Create mv rv
47+
rv = MVDiscreteRV(q)
48+
49+
# Make sure it doesn't draw numbers that don't make sense... Must
50+
# be between 1 and n
51+
for i in 1:n
52+
@test rand(rv)[i] >= 1
53+
@test rand(rv)[i] <= dims[i]
54+
end
55+
56+
ndraws = 1_000_000
57+
draws = rand(rv, ndraws)
58+
counter = zeros(dims...)
59+
for i in 1:ndraws
60+
draw = draws[i]
61+
counter[draw...] += 1.0
62+
end
63+
counter ./= ndraws
64+
@test isapprox(Base.maximum(abs, counter - rv.q), 0.0; atol=1e-2)
65+
end
66+
67+
end
3368

3469
end # testset

0 commit comments

Comments
 (0)