diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ee4ef71e..97bfc55a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: version: - '1.6' - '1.7' - - 'nightly' + - '1.8' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 1c3f5c85..1c9139e4 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,9 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,6 +23,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -33,8 +36,10 @@ Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" FillArrays = "0.12, 0.13" +FunctionChains = "0.1" IfElse = "0.1" -InverseFunctions = "0.1.7" +IntervalSets = "0.7" +InverseFunctions = "0.1.8" IrrationalConstants = "0.1" LogExpFunctions = "0.3" LogarithmicNumbers = "1" @@ -42,6 +47,7 @@ MappedArrays = "0.4" NaNMath = "0.3, 1" PrettyPrinting = "0.3, 0.4" Reexport = "1" +SpecialFunctions = "2" Static = "0.5, 0.6" Tricks = "0.1" julia = "1.3" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index fbaa84d6..dcfe7464 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -8,9 +8,16 @@ 
import Random: gentype using Statistics using LinearAlgebra +import IntervalSets +# This seems harder than it should be to get `IntervalSets.:(..)` +@eval (using IntervalSets: $(Symbol(IntervalSets.:(..)))) + +using IntervalSets: Interval, width + import DensityInterface: logdensityof import DensityInterface: densityof import DensityInterface: DensityKind +using DensityInterface: FuncDensity, LogFuncDensity using DensityInterface using InverseFunctions @@ -19,6 +26,7 @@ using ChangesOfVariables import Base.iterate import ConstructionBase using ConstructionBase: constructorof +using IntervalSets using PrettyPrinting const Pretty = PrettyPrinting @@ -26,6 +34,7 @@ const Pretty = PrettyPrinting using ChainRulesCore using FillArrays using Static +using FunctionChains export ≪ export gentype @@ -108,6 +117,7 @@ using Compat using IrrationalConstants +include("smf.jl") include("getdof.jl") include("transport.jl") include("schema.jl") @@ -115,10 +125,10 @@ include("splat.jl") include("proxies.jl") include("kernel.jl") include("parameterized.jl") -include("combinators/half.jl") include("domains.jl") include("primitive.jl") include("utils.jl") +include("mass-interface.jl") # include("absolutecontinuity.jl") include("primitives/counting.jl") @@ -144,9 +154,11 @@ include("standard/stdmeasure.jl") include("standard/stduniform.jl") include("standard/stdexponential.jl") include("standard/stdlogistic.jl") -include("latent-joint.jl") +include("standard/stdnormal.jl") +include("combinators/half.jl") include("rand.jl") +include("fixedrng.jl") include("density.jl") include("density-core.jl") diff --git a/src/combinators/half.jl b/src/combinators/half.jl index 32dec91d..c713193a 100644 --- a/src/combinators/half.jl +++ b/src/combinators/half.jl @@ -26,4 +26,18 @@ logdensity_def(μ::Half, x) = logdensity_def(unhalf(μ), x) insupport(unhalf(d), x) end -testvalue(::Half) = 1.0 +testvalue(::Type{T}, ::Half) where {T} = one(T) + +massof(μ::Half) = massof(unhalf(μ)) + +function smf(μ::Half, 
x) + 2 * smf(μ.parent, max(x, zero(x))) - 1 +end + +function invsmf(μ::Half, p) + @assert zero(p) ≤ p ≤ one(p) + invsmf(μ.parent, (p + 1) / 2) +end + +transport_def(μ::Half, ::StdUniform, p) = invsmf(μ, p) +transport_def(::StdUniform, μ::Half, x) = smf(μ, x) diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index d6da7ee1..6dfd164f 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -202,4 +202,6 @@ more efficient than logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) """ -likelihood_ratio(ℓ::Likelihood, p, q) = exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) \ No newline at end of file +function likelihood_ratio(ℓ::Likelihood, p, q) + exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) +end diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 94447844..570e8b18 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -128,3 +128,5 @@ end function checked_arg(μ::PowerMeasure, x::Any) throw(ArgumentError("Size of variate doesn't match size of power measure")) end + +massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes) diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 975ebcd1..97db166b 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -16,6 +16,8 @@ function Pretty.tile(μ::AbstractProductMeasure) result *= Pretty.literal(")") end +massof(m::AbstractProductMeasure) = prod(massof, marginals(m)) + export marginals function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure) @@ -159,9 +161,11 @@ marginals(μ::ProductMeasure) = μ.marginals # TODO: Better `map` support in MappedArrays _map(f, args...) = map(f, args...) 
-_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(f ∘ x.f, x.data) +_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(fchain((x.f, f)), x.data) -testvalue(d::AbstractProductMeasure) = _map(testvalue, marginals(d)) +function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} + _map(m -> testvalue(T, m), marginals(d)) +end export ⊗ @@ -220,3 +224,16 @@ end end return true end + +getdof(d::AbstractProductMeasure) = mapreduce(getdof, +, marginals(d)) + +function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N} + map(checked_arg, marginals(μ), x) +end + +function checked_arg( + μ::ProductMeasure{<:NamedTuple{names}}, + x::NamedTuple{names}, +) where {names} + NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x))) +end diff --git a/src/combinators/smart-constructors.jl b/src/combinators/smart-constructors.jl index 775b1d93..11d7bbbb 100644 --- a/src/combinators/smart-constructors.jl +++ b/src/combinators/smart-constructors.jl @@ -85,7 +85,7 @@ superpose(nt::NamedTuple) = SuperpositionMeasure(nt) function superpose(μ::T, ν::T) where {T<:AbstractMeasure} if μ == ν - return weightedmeasure(logtwo, μ) + return weightedmeasure(static(float(logtwo)), μ) else return superpose((μ, ν)) end diff --git a/src/combinators/spikemixture.jl b/src/combinators/spikemixture.jl index 6168907a..b1e3cf46 100644 --- a/src/combinators/spikemixture.jl +++ b/src/combinators/spikemixture.jl @@ -37,6 +37,6 @@ function Base.rand(rng::AbstractRNG, T::Type, μ::SpikeMixture) return (rand(rng, T) < μ.w) * rand(rng, T, μ.m) end -testvalue(μ::SpikeMixture) = testvalue(μ.m) +testvalue(::Type{T}, μ::SpikeMixture) where {T} = zero(T) insupport(μ::SpikeMixture, x) = dynamic(insupport(μ.m, x)) || iszero(x) diff --git a/src/combinators/superpose.jl b/src/combinators/superpose.jl index 53d49674..2d385636 100644 --- a/src/combinators/superpose.jl +++ b/src/combinators/superpose.jl @@ -4,6 +4,8 @@ using LogExpFunctions export SuperpositionMeasure 
+abstract type AbstractSuperpositionMeasure <: AbstractMeasure end + @doc raw""" struct SuperpositionMeasure{NT} <: AbstractMeasure components :: NT @@ -24,17 +26,19 @@ Superposition measures satisfy \end{aligned} ``` """ -struct SuperpositionMeasure{C} <: AbstractMeasure +struct SuperpositionMeasure{C} <: AbstractSuperpositionMeasure components::C end +massof(m::SuperpositionMeasure) = sum(massof, m.components) + function Pretty.tile(d::SuperpositionMeasure) result = Pretty.literal("SuperpositionMeasure(") result *= Pretty.list_layout([Pretty.tile.(d.components)...]) result *= Pretty.literal(")") end -testvalue(μ::SuperpositionMeasure) = testvalue(first(μ.components)) +testvalue(::Type{T}, μ::SuperpositionMeasure) where {T} = testvalue(T, first(μ.components)) # SuperpositionMeasure(ms :: AbstractMeasure...) = SuperpositionMeasure{X,length(ms)}(ms) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 35baf32d..b00065a1 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -1,5 +1,7 @@ # TODO: Compare with ChangesOfVariables.jl +using InverseFunctions: FunctionWithInverse + abstract type AbstractTransformedMeasure <: AbstractMeasure end abstract type AbstractPushforward <: AbstractTransformedMeasure end @@ -17,16 +19,19 @@ function parent(::AbstractTransformedMeasure) end export PushforwardMeasure """ - struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward - f :: FF - inv_f :: IF - origin :: MU + struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward + f :: F + finv :: I + origin :: M volcorr :: VC end + + Users should not call `PushforwardMeasure` directly. Instead call or add + methods to `pushfwd`. 
""" -struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward - f::FF - inv_f::IF +struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward + f::F + finv::I origin::M volcorr::VC end @@ -35,14 +40,25 @@ gettransform(ν::PushforwardMeasure) = ν.f parent(ν::PushforwardMeasure) = ν.origin function Pretty.tile(ν::PushforwardMeasure) - Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure) + Pretty.list_layout(Pretty.tile.([ν.f, ν.origin]); prefix = :PushforwardMeasure) end -@inline function logdensity_def( - ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, - y, -) where {FF,IF,M} - x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) +# TODO: THIS IS ALMOST CERTAINLY WRONG +# @inline function logdensity_rel( +# ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr}, +# β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr}, +# y, +# ) where {FF1,IF1,M1,FF2,IF2,M2} +# x = β.inv_f(y) +# f = ν.inv_f ∘ β.f +# inv_f = β.inv_f ∘ ν.f +# logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr()), β.origin, x) +# end + +@inline function logdensity_def(ν::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M} + f = ν.f + finv = ν.finv + x_orig, inv_ladj = with_logabsdet_jacobian(finv, y) logd_orig = logdensity_def(ν.origin, x_orig) logd = float(logd_orig + inv_ladj) neginf = oftype(logd, -Inf) @@ -57,49 +73,87 @@ end ) end -@inline function logdensity_def( - ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, - y, -) where {FF,IF,M} - x_orig = to_origin(ν, y) - return logdensity_def(ν.origin, x_orig) +@inline function logdensity_def(ν::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M} + x = ν.finv(y) + return logdensity_def(ν.origin, x) end -insupport(ν::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y)) +insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y)) -testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν))) +function testvalue(::Type{T}, 
ν::PushforwardMeasure) where {T} + ν.f(testvalue(T, parent(ν))) +end @inline function basemeasure(ν::PushforwardMeasure) - PushforwardMeasure(ν.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr()) + pushfwd(ν.f, basemeasure(parent(ν)), NoVolCorr()) end _pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}() _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof -# Assume that DOF are preserved if with_logabsdet_jacobian is functional: -@inline function getdof(ν::MU) where {MU<:PushforwardMeasure} - T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)}) - R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f),T}) - _pushfwd_dof(MU, R, getdof(ν.origin)) -end +@inline getdof(ν::MU) where {MU<:PushforwardMeasure} = getdof(ν.origin) # Bypass `checked_arg`, would require potentially costly transformation: @inline checked_arg(::PushforwardMeasure, x) = x @inline transport_origin(ν::PushforwardMeasure) = ν.origin @inline from_origin(ν::PushforwardMeasure, x) = ν.f(x) -@inline to_origin(ν::PushforwardMeasure, y) = ν.inv_f(y) +@inline to_origin(ν::PushforwardMeasure, y) = ν.finv(y) function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where {T} - return from_origin(ν, rand(rng, T, transport_origin(ν))) + return ν.f(rand(rng, T, parent(ν))) end +############################################################################### +# pushfwd + export pushfwd """ pushfwd(f, μ, volcorr = WithVolCorr()) -Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) -from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. +Return the [pushforward +measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the +[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. + +To manually specify an inverse, call +`pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`. 
""" -pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr) +function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(f, inverse(f), μ, volcorr) +end + +function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr()) + _pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr) +end + +# Either both WithVolCorr or both NoVolCorr, so we can merge them +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V} + pushfwd(fchain((μ.f, f)), μ.origin, v) +end + +function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v) + PushforwardMeasure(f, inverse(f), μ, v) +end + +############################################################################### +# pullback + +""" + pullback(f, μ, volcorr = WithVolCorr()) + +A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a +map _from_ the support of a measure, a pullback requires a map _into_ the +support of a measure. The log-density is then computed through function +composition, together with a volume correction as needed. + +This can be useful, since the log-density of a `PushforwardMeasure` is computing +in terms of the inverse function; the "forward" function is not used at all. In +some cases, we may be focusing on log-density (and not, for example, sampling). + +To manually specify an inverse, call +`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`. 
+""" +function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + pushfwd(setinverse(inverse(f), f), μ, volcorr) +end diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index aef9dbee..db239b50 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -20,6 +20,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::AbstractWeightedMeasure) whe rand(rng, T, basemeasure(μ)) end +testvalue(::Type{T}, μ::AbstractWeightedMeasure) where {T} = testvalue(T, basemeasure(μ)) + ############################################################################### struct WeightedMeasure{R,M} <: AbstractWeightedMeasure @@ -27,6 +29,8 @@ struct WeightedMeasure{R,M} <: AbstractWeightedMeasure base::M end +massof(w::WeightedMeasure) = exp(w.logweight) * massof(w.base) + _logweight(μ::WeightedMeasure) = μ.logweight basemeasure(μ::AbstractWeightedMeasure) = μ.base @@ -49,6 +53,8 @@ gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) +# TODO: Transports must preserve mass transport_origin(ν::WeightedMeasure) = ν.base -to_origin(::WeightedMeasure, y) = y -from_origin(::WeightedMeasure, x) = x + +to_origin(w::WeightedMeasure, y) = y +from_origin(w::WeightedMeasure, x) = x diff --git a/src/density-core.jl b/src/density-core.jl index fa7dfd2a..5b75f2f9 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -1,3 +1,13 @@ +export logdensityof +export logdensity_rel +export logdensity_def + +export unsafe_logdensityof +export unsafe_logdensity_rel + +export densityof +export density_rel +export density_def """ logdensityof(m::AbstractMeasure, x) @@ -20,7 +30,18 @@ To compute a log-density relative to a specific base-measure, see """ @inline function logdensityof(μ::AbstractMeasure, x) result = dynamic(unsafe_logdensityof(μ, x)) - ifelse(insupport(μ, x) == true, result, oftype(result, -Inf)) + _checksupport(insupport(μ, x), result) +end + +_checksupport(cond, result) = ifelse(cond == true, 
result, oftype(result, -Inf)) + +import ChainRulesCore +@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) + y = _checksupport(cond, result) + function _checksupport_pullback(ȳ) + return NoTangent(), ZeroTangent(), one(ȳ) + end + y, _checksupport_pullback end export unsafe_logdensityof @@ -39,9 +60,11 @@ See also `logdensityof`. b_0 = μ Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number b_{i} = basemeasure(b_{i - 1}, x) - if b_{i} isa typeof(b_{i - 1}) - return ℓ_{i - 1} - end + + # The below makes the evaluated code shorter, but screws up Zygote + # if b_{i} isa typeof(b_{i - 1}) + # return ℓ_{i - 1} + # end ℓ_{i} = let Δℓ_{i} = logdensity_def(b_{i}, x) ℓ_{i - 1} + Δℓ_{i} end @@ -49,12 +72,6 @@ See also `logdensityof`. return ℓ_10 end -export density_rel - -@inline density_rel(μ, ν, x) = exp(logdensity_rel(μ, ν, x)) - -export logdensity_rel - """ logdensity_rel(m1, m2, x) @@ -117,7 +134,9 @@ function logdensity_def(μ::T, ν::T, x) where {T} if μ === ν return zero(logdensity_def(μ, x)) else - return logdensity_def(μ, x) - logdensity_def(ν, x) + α = basemeasure(μ) + β = basemeasure(ν) + return logdensity_def(μ, x) - logdensity_def(ν, x) + logdensity_rel(α, β, x) end end @@ -151,10 +170,8 @@ end return q end -export densityof -export logdensityof - -export density_def +@inline density_rel(μ, ν, x) = exp(logdensity_rel(μ, ν, x)) +# TODO: Do we need this method? 
density_def(μ, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x)) density_def(μ, x) = exp(logdensity_def(μ, x)) diff --git a/src/density.jl b/src/density.jl index 180906e0..4862dcb1 100644 --- a/src/density.jl +++ b/src/density.jl @@ -1,53 +1,122 @@ -abstract type AbstractDensity end +################################################################### +# Abstract types and methods + +abstract type AbstractDensity <: Function end @inline DensityKind(::AbstractDensity) = IsDensity() +import DensityInterface + +#################################################################################### +# Density + """ - struct Density{M,B} + struct Density{M,B} <: AbstractDensity μ::M base::B end -For measures μ and ν with μ≪ν, the density of μ with respect to ν (also called -the Radon-Nikodym derivative dμ/dν) is a function f defined on the support of ν -with the property that for any measurable a ⊂ supp(ν), μ(a) = ∫ₐ f dν. - -Because this function is often difficult to express in closed form, there are -many different ways of computing it. We therefore provide a formal -representation to allow comptuational flexibilty. +For measures `μ` and `ν`, `Density(μ,ν)` represents the _density function_ +`dμ/dν`, also called the _Radon-Nikodym derivative_: +https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative + +Instead of calling this directly, users should call `density_rel(μ, ν)` or +its abbreviated form, `𝒹(μ,ν)`. """ struct Density{M,B} <: AbstractDensity μ::M base::B end +Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) + +Base.log(d::Density) = log ∘ d + export 𝒹 """ - 𝒹(μ::AbstractMeasure, base::AbstractMeasure) + 𝒹(μ, base) + +Compute the density (Radon-Nikodym derivative) of μ with respect to `base`. This +is a shorthand form for `density_rel(μ, base)`. 
+""" +𝒹(μ, base) = density_rel(μ, base) + +density_rel(μ, base) = Density(μ, base) + +(f::Density)(x) = density_rel(f.μ, f.base, x) + +DensityInterface.logfuncdensity(d::Density) = throw(MethodError(logfuncdensity, (d,))) + +#################################################################################### +# LogDensity + +""" + struct LogDensity{M,B} <: AbstractDensity + μ::M + base::B + end -Compute the Radom-Nikodym derivative of μ with respect to `base`. +For measures `μ` and `ν`, `LogDensity(μ,ν)` represents the _log-density function_ +`log(dμ/dν)`, the logarithm of the _Radon-Nikodym derivative_: +https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative + +Instead of calling this directly, users should call `logdensity_rel(μ, ν)` or +its abbreviated form, `log𝒹(μ,ν)`. """ -function 𝒹(μ::AbstractMeasure, base::AbstractMeasure) - return Density(μ, base) +struct LogDensity{M,B} <: AbstractDensity + μ::M + base::B end -logdensityof(d::Density, x) = logdensity_rel(d.μ, d.base, x) +Base.:∘(::typeof(exp), d::LogDensity) = density_rel(d.μ, d.base) + +Base.exp(d::LogDensity) = exp ∘ d -logdensity_def(d::Density, x) = logdensityof(d, x) +export log𝒹 + +""" + log𝒹(μ, base) +Compute the log-density (log of the Radon-Nikodym derivative) of μ with respect to `base`. 
+This is a shorthand form for `logdensity_rel(μ, base)` """ - struct DensityMeasure{F,B} <: AbstractMeasure +log𝒹(μ, base) = logdensity_rel(μ, base) + +logdensity_rel(μ, base) = LogDensity(μ, base) + +(f::LogDensity)(x) = logdensity_rel(f.μ, f.base, x) + +DensityInterface.funcdensity(d::LogDensity) = throw(MethodError(funcdensity, (d,))) + +####################################################################################### +# DensityMeasure + +""" + struct DensityMeasure{F,B} <: AbstractDensityMeasure density :: F base :: B end -A `DensityMeasure` is a measure defined by a density with respect to some other -"base" measure +A `DensityMeasure` is a measure defined by a density or log-density with respect +to some other "base" measure. + +Users should not call `DensityMeasure` directly, but should instead call `∫(f, +base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or +`∫exp(f, base)` (if `f` is a log-density function). """ struct DensityMeasure{F,B} <: AbstractMeasure f::F base::B + + function DensityMeasure(f::F, base::B) where {F,B} + @assert DensityKind(f) isa IsDensity + new{F,B}(f, base) + end +end + +@inline function insupport(d::DensityMeasure, x) + insupport(d.base, x) == true && isfinite(logdensityof(getfield(d, :f), x)) end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} @@ -56,28 +125,6 @@ function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} result *= Pretty.literal(")") end -densitymeasure(f, base) = _densitymeasure(f, base, DensityKind(f)) - -_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) - -function _densitymeasure(f, base, _) - @error """ - The first argument of `DensityMeasure`" must be `::IsDensity`. To pass a - function, first wrap it in `DensityInterface.funcdensity` or - `DensityInterface.logfuncdensity`. 
- """ -end - -@inline function insupport(d::DensityMeasure, x) - insupport(d.base, x) == true && isfinite(logdensityof(d.f, x)) -end - -basemeasure(μ::DensityMeasure) = μ.base - -logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) - -density_def(μ::DensityMeasure, x) = densityof(μ.f, x) - export ∫ """ @@ -85,11 +132,13 @@ export ∫ Define a new measure in terms of a density `f` over some measure `base`. """ -∫(f::Function, base::AbstractMeasure) = DensityMeasure(funcdensity(f), base) +∫(f, base) = _densitymeasure(f, base, DensityKind(f)) -∫(f, base::AbstractMeasure) = _densitymeasure(f, base, DensityKind(f)) - -# ∫(μ::AbstractMeasure, base::AbstractMeasure) = ∫(𝒹(μ, base), base) +_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) +function _densitymeasure(f, base, ::HasDensity) + @error "`∫(f, base)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`." +end +_densitymeasure(f, base, ::NoDensity) = DensityMeasure(funcdensity(f), base) export ∫exp @@ -98,8 +147,21 @@ export ∫exp Define a new measure in terms of a log-density `f` over some measure `base`. """ -∫exp(f::Function, μ) = ∫(logfuncdensity(f), μ) +∫exp(f, base) = _logdensitymeasure(f, base, DensityKind(f)) +function _logdensitymeasure(f, base, ::IsDensity) + @error "`∫exp(f, base)` is not valid when `DensityKind(f) == IsDensity()`. Use `∫(f, base)` instead." +end +function _logdensitymeasure(f, base, ::HasDensity) + @error "`∫exp(f, base)` is not valid when `DensityKind(f) == HasDensity()`." 
+end +_logdensitymeasure(f, base, ::NoDensity) = DensityMeasure(logfuncdensity(f), base) + +basemeasure(μ::DensityMeasure) = μ.base + +logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) + +density_def(μ::DensityMeasure, x) = densityof(μ.f, x) """ rebase(μ, ν) diff --git a/src/domains.jl b/src/domains.jl index ed4617b2..e4c7f5e7 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -32,9 +32,9 @@ Base.maximum(b::BoundedReals) = b.upper Base.show(io::IO, ::typeof(ℝ₊)) = print(io, "ℝ₊") Base.show(io::IO, ::typeof(𝕀)) = print(io, "𝕀") -testvalue(::typeof(ℝ)) = 0.0 -testvalue(::typeof(ℝ₊)) = 1.0 -testvalue(::typeof(𝕀)) = 0.5 +testvalue(::Type{T}, ::typeof(ℝ)) where {T} = zero(T) +testvalue(::Type{T}, ::typeof(ℝ₊)) where {T} = one(T) +testvalue(::Type{T}, ::typeof(𝕀)) where {T} = one(T) / 2 abstract type IntegerDomain <: AbstractDomain end diff --git a/src/fixedrng.jl b/src/fixedrng.jl new file mode 100644 index 00000000..232b0891 --- /dev/null +++ b/src/fixedrng.jl @@ -0,0 +1,19 @@ +export FixedRNG +struct FixedRNG <: AbstractRNG end + +Base.rand(::FixedRNG) = one(Float64) / 2 +Random.randn(::FixedRNG) = zero(Float64) +Random.randexp(::FixedRNG) = one(Float64) + +Base.rand(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) / 2 +Random.randn(::FixedRNG, ::Type{T}) where {T<:Real} = zero(T) +Random.randexp(::FixedRNG, ::Type{T}) where {T<:Real} = one(T) + +# We need concrete type parameters to avoid ambiguity for these cases +for T in [Float16, Float32, Float64] + @eval begin + Base.rand(::FixedRNG, ::Type{$T}) = one($T) / 2 + Random.randn(::FixedRNG, ::Type{$T}) = zero($T) + Random.randexp(::FixedRNG, ::Type{$T}) = one($T) + end +end diff --git a/src/interface.jl b/src/interface.jl index 08a4b369..06b836ba 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -11,9 +11,12 @@ using MeasureBase: transport_to, NoTransport using DensityInterface: logdensityof using InverseFunctions: inverse using ChangesOfVariables: with_logabsdet_jacobian +using Tricks: 
static_hasmethod +using IntervalSets: Interval export test_interface export test_transport +export test_smf export basemeasure_depth export proxy export insupport @@ -22,7 +25,10 @@ export commonbase using Test -function dynamic_basemeasure_depth(μ) +function dynamic_basemeasure_depth(μ::M) where {M} + if static_hasmethod(proxy, Tuple{M}) + return dynamic_basemeasure_depth(proxy(μ)) + end β = basemeasure(μ) depth = 0 while μ ≠ β @@ -54,7 +60,7 @@ function test_interface(μ::M) where {M} ########################################################################### # testvalue, logdensityof - x = @inferred testvalue(μ) + x = @inferred testvalue(Float64, μ) β = @inferred basemeasure(μ, x) ℓμ = @inferred logdensityof(μ, x) @@ -62,30 +68,65 @@ function test_interface(μ::M) where {M} @test ℓμ ≈ logdensity_def(μ, x) + ℓβ - @test logdensity_def(μ, testvalue(μ)) isa Real + @test logdensity_def(μ, testvalue(Float64, μ)) isa Real end end end function test_transport(ν, μ) + supertype(x) = Any supertype(x::Real) = Real supertype(x::AbstractArray{<:Real,N}) where {N} = AbstractArray{<:Real,N} + structisapprox(a, b) = isapprox(a, b) + function structisapprox(a::NTuple{N,Any}, b::NTuple{N,Any}) where {N} + all(map(structisapprox, a, b)) + end + function structisapprox(a::NamedTuple{names}, b::NamedTuple{names}) where {names} + all(map(structisapprox, values(a), values(b))) + end + @testset "transport_to $μ to $ν" begin x = rand(μ) @test !(@inferred(transport_to(ν, μ)(x)) isa NoTransport) f = transport_to(ν, μ) y = f(x) - @test @inferred(inverse(f)(y)) ≈ x + @test structisapprox(@inferred(inverse(f)(y)), x) @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real} y2, ladj_fwd = with_logabsdet_jacobian(f, x) x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) - @test x ≈ x2 - @test y ≈ y2 - @test ladj_fwd ≈ -ladj_inv + @test structisapprox(x, x2) + @test structisapprox(y, y2) + 
@test isapprox(ladj_fwd, -ladj_inv, atol = 1e-10) @test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y) end end +function test_smf(μ, n = 100) + # Get `n` sorted uniforms in O(n) time + p = rand(n) + p .+= 0:n-1 + p .*= inv(n) + + F(x) = smf(μ, x) + Finv(p) = invsmf(μ, p) + + @assert issorted(p) + x = invsmf.(μ, p) + @test issorted(x) + @test all(insupport(μ), x) + + @test all((Finv ∘ F).(x) .≈ x) + + for j in 1:n + a = rand() + b = rand() + a, b = minmax(a, b) + x = Finv(a) + y = Finv(b) + @test μ(Interval{:open,:closed}(x, y)) ≈ (F(y) - F(x)) + end +end + end # module Interface diff --git a/src/latent-joint.jl b/src/latent-joint.jl deleted file mode 100644 index 988c27cc..00000000 --- a/src/latent-joint.jl +++ /dev/null @@ -1,62 +0,0 @@ -_LATENT_DOCSTRING = """ -Some Probabilistic Programming Languages (PPLs) like Tilde.jl make a distinction -between a _latent space_, often a namespace represented as a named tuple, and -the space containing the return value, which we refer to as the _maifest space_. -The distinction is that computations are done in terms of the latent space, -while the resulting value is in the manifest space. - -To simplify many manipulations involving these concepts, we introduce the -concept of a _joint space_. For example, suppose `m()` is a measure with latent -space `NamedTuple{(:a, :b)}` that returns `a - b`, so the latent value `(a = 3, -b = 4)` is mapped to the manifest value `0.75`. Then the corresponding value in -the joint space is the pair `(a = 3, b = 4) => 0.75`. - -One of the many goals of probabilistic programming is to blur the line between -"built in" measures like `Normal()` and those defined in terms of a model from a -PPL. To accommodate this, we extend these concepts to general measures. - -For many measures, it's convenient to work directly in the manifest space, and -there's no need for such separation. However, it's important to be able to -manipulate measures programmatically, with minimal special cases. 
Because of -this, we introduce fall-back methods - - latentof(m) = m - manifestof(m) = m - -The default implementation of `jointof` is then a push-forward through the -function `x -> (x => x)`. For example, - - julia> rand(MeasureBase.jointof(StdUniform())) - 0.346439=>0.346439 -""" - -""" - latentof(m) - -$_LATENT_DOCSTRING -""" -latentof(m) = m - -""" - manifestof(m) - -$_LATENT_DOCSTRING -""" -manifestof(m) = m - -""" - jointof(m) - -$_LATENT_DOCSTRING -""" -function jointof(m) - fwd(x) = x => x - - function back(p::Pair) - x,y = p - @assert x === y - return x - end - - PushforwardMeasure(fwd, back, m, NoVolCorr()) -end diff --git a/src/mass-interface.jl b/src/mass-interface.jl new file mode 100644 index 00000000..7b0518f9 --- /dev/null +++ b/src/mass-interface.jl @@ -0,0 +1,119 @@ +import LinearAlgebra: normalize + +import Base + +abstract type AbstractUnknownMass <: Number end + +""" + struct UnknownFiniteMass <: AbstractUnknownMass end + +See `massof` +""" +struct UnknownFiniteMass <: AbstractUnknownMass end + +""" + struct UnknownMass <: AbstractUnknownMass end + +See `massof` +""" +struct UnknownMass <: AbstractUnknownMass end + +for T in (:UnknownFiniteMass, :UnknownMass) + @eval begin + Base.:+(::$T, ::$T) = $T() + Base.:*(::$T, ::$T) = $T() + Base.:^(::$T, k::Number) = isfinite(k) ? $T() : UnknownMass() + end +end + +for op in (:+, :*) + let + U = :UnknownMass + UF = :UnknownFiniteMass + @eval begin + Base.$op(::$U, ::$UF) = $U() + Base.$op(::$UF, ::$U) = $U() + end + end +end + +export massof + +""" + massof(m) + +Get the _mass_ of a measure - that is, integrate the measure over its support. + +`massof` + +---------- + + massof(m, dom) + +Integrate the measure `m` over the "domain" `dom`. Note that domains are not +defined universally, but may be specific to a given measure. If `m` is +`<:AbstractMeasure`, users can also write `m(dom)`. For new measures, users +should *not* add new "call" methods, but instead extend `MeasureBase.massof`. 
+ +For example, for many univariate measures `m` with `rootmeasure(m) == +LebesgueBase()`, users can call `massof(m, a_b)` where +`a_b::IntervalSets.Interval`. + +`massof` often returns a `Real`. But in many cases we may only know the mass is +finite, or we may know nothing at all about it. For these cases, it will return +`UnknownFiniteMass` or `UnknownMass`, respectively. When no `massof` method +exists, it defaults to `UnknownMass`. +""" +massof(m::AbstractMeasure) = UnknownMass() + +struct NormalizedMeasure{P,M} <: AbstractMeasure + parent::P + parent_mass::M +end + +massof(m::NormalizedMeasure) = static(1.0) + +normalize(m::AbstractMeasure) = _normalize(m, massof(m)) + +_normalize(m::AbstractMeasure, mass::AbstractUnknownMass) = NormalizedMeasure(m, mass) + +function _normalize(m::AbstractMeasure, mass) + isinf(mass) && error("Measure cannot be normalized: $m") + inv(mass) * m +end + +export isnormalized + +""" + isnormalized(m::AbstractMeasure) + +Checks whether the measure m is normalized, that is, whether `massof(m) == 1`. + +For convenience, we also provide a method on non-measures that only depends on +`norm`. +""" +isnormalized(m::AbstractMeasure) = isone(massof(m)) + +""" + isnormalized(x, p::Real=2) + +Check whether `norm(x, p) == 1`. +""" +isnormalized(x, p::Real = 2) = isone(norm(x, p)) + +Base.isone(::AbstractUnknownMass) = false + +function massof(m, s) + _massof(m, s, rootmeasure(m)) +end + +""" + (m::AbstractMeasure)(s) + +Convenience method for `massof(m, s)`. To make a user-defined measure callable +in this way, users should add the corresponding `massof` method.
+""" +(m::AbstractMeasure)(s) = massof(m, s) + +massof(μ, a_b::AbstractInterval) = smf(μ, rightendpoint(a_b)) - smf(μ, leftendpoint(a_b)) diff --git a/src/primitives/counting.jl b/src/primitives/counting.jl index b9715f34..f101ca96 100644 --- a/src/primitives/counting.jl +++ b/src/primitives/counting.jl @@ -1,10 +1,10 @@ # Counting measure -export Counting, CountingMeasure +export Counting, CountingBase -struct CountingMeasure <: PrimitiveMeasure end +struct CountingBase <: PrimitiveMeasure end -insupport(::CountingMeasure, x) = true +insupport(::CountingBase, x) = true struct Counting{T} <: AbstractMeasure support::T @@ -16,18 +16,22 @@ function logdensity_def(μ::Counting, x) insupport(μ, x) ? 0.0 : -Inf end -basemeasure(::Counting) = CountingMeasure() +basemeasure(::Counting) = CountingBase() Counting() = Counting(ℤ) -testvalue(d::Counting) = testvalue(d.support) +testvalue(::Type{T}, d::Counting) where {T} = testvalue(T, d.support) -proxy(d::Counting) = restrict(in(d.support), CountingMeasure()) +proxy(d::Counting) = restrict(in(d.support), CountingBase()) -Base.:∘(::typeof(basemeasure), ::Type{Counting}) = CountingMeasure() +Base.:∘(::typeof(basemeasure), ::Type{Counting}) = CountingBase() Base.show(io::IO, d::Counting) = print(io, "Counting(", d.support, ")") insupport(μ::Counting, x) = x ∈ μ.support insupport(μ::Counting{T}, x) where {T<:Type} = x isa μ.support + +massof(c::Counting, s::Set) = massof(CountingBase(), filter(insupport(c), s)) + +massof(::CountingBase, s::Set) = length(s) diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 0c10e723..77727605 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -16,7 +16,9 @@ function (μ::Dirac{X})(s) where {X} return 0 end -basemeasure(d::Dirac) = CountingMeasure() +basemeasure(d::Dirac) = CountingBase() + +massof(::Dirac) = static(1.0) logdensity_def(μ::Dirac, x) = 0.0 @@ -26,8 +28,6 @@ export dirac dirac(d::AbstractMeasure) = Dirac(rand(d)) -testvalue(d::Dirac) = d.x - 
insupport(d::Dirac, x) = x == d.x @inline getdof(::Dirac) = static(0) diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 52844708..8c42766f 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -2,14 +2,40 @@ export Lebesgue -struct LebesgueMeasure <: PrimitiveMeasure end +struct LebesgueBase <: PrimitiveMeasure end -testvalue(::LebesgueMeasure) = 0.0 +massof(::LebesgueBase, s::Interval) = width(s) -insupport(::LebesgueMeasure, x) = true +testvalue(::LebesgueBase) = 0.0 -insupport(::LebesgueMeasure) = Returns(true) +insupport(::LebesgueBase, x) = true +insupport(::LebesgueBase) = Returns(true) + +logdensity_def(::LebesgueBase, ::CountingBase, x) = -Inf + +logdensity_def(::CountingBase, ::LebesgueBase, x) = Inf + +@inline getdof(::LebesgueBase) = static(1) + +@inline checked_arg(::LebesgueBase, x::Real) = x + +@propagate_inbounds function checked_arg(::LebesgueBase, x::Any) + @boundscheck throw(ArgumentError("Invalid variate type for measure")) +end + +massof(::LebesgueBase) = static(Inf) + +function _massof(m, s::Interval, ::LebesgueBase) + mass = massof(m) + nu = mass * StdUniform() + f = transport_to(nu, m) + a = f(minimum(s)) + b = f(maximum(s)) + return mass * abs(b - a) +end + +########################################################## struct Lebesgue{T} <: AbstractMeasure support::T end @@ -22,14 +48,14 @@ gentype(::Lebesgue) = Float64 Lebesgue() = Lebesgue(ℝ) -# basemeasure(::Lebesgue) = LebesgueMeasure() +testvalue(::Type{T}, d::Lebesgue) where {T} = testvalue(T, d.support)::T -testvalue(d::Lebesgue) = testvalue(d.support) +proxy(d::Lebesgue) = restrict(in(d.support), LebesgueBase()) +proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase() -proxy(d::Lebesgue) = restrict(in(d.support), LebesgueMeasure()) @useproxy Lebesgue -Base.:∘(::typeof(basemeasure), ::Type{Lebesgue}) = LebesgueMeasure() +Base.:∘(::typeof(basemeasure), ::Type{Lebesgue}) = LebesgueBase() Base.show(io::IO, d::Lebesgue) = print(io, 
"Lebesgue(", d.support, ")") @@ -37,14 +63,30 @@ insupport(μ::Lebesgue, x) = x ∈ μ.support insupport(::Lebesgue{RealNumbers}, ::Real) = true -logdensity_def(::LebesgueMeasure, ::CountingMeasure, x) = -Inf - -logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf +massof(::Lebesgue{RealNumbers}, s::Interval) = width(s) + +# Example: +# julia> Lebesgue(𝕀)(0.2..5) +# 0.8 +function massof(μ::Lebesgue{<:BoundedReals}, s::Interval) + a = μ.support.lower + b = μ.support.upper + left = max(s.left, a) + right = min(s.right, b) + w = right - left + max(w, zero(w)) +end -@inline getdof(::Lebesgue) = static(1) +function smf(μ::Lebesgue{<:BoundedReals}, x) + clamp(x, μ.support.lower, μ.support.upper) +end -@inline checked_arg(::Lebesgue, x::Real) = x +smf(::Lebesgue{RealNumbers}, x) = x +smf(::Lebesgue{RealNumbers}) = identity +invsmf(::Lebesgue{RealNumbers}, x) = x +invsmf(::Lebesgue{RealNumbers}) = identity -@propagate_inbounds function checked_arg(::Lebesgue, x::Any) - @boundscheck throw(ArgumentError("Invalid variate type for measure")) -end +smf(::LebesgueBase, x) = x +smf(::LebesgueBase) = identity +invsmf(::LebesgueBase, x) = x +invsmf(::LebesgueBase) = identity diff --git a/src/primitives/trivial.jl b/src/primitives/trivial.jl index e82b4ffc..f224ccd0 100644 --- a/src/primitives/trivial.jl +++ b/src/primitives/trivial.jl @@ -5,3 +5,5 @@ struct TrivialMeasure <: PrimitiveMeasure end gentype(::TrivialMeasure) = Nothing insupport(::TrivialMeasure, x) = False + +massof(::TrivialMeasure) = static(0.0) diff --git a/src/proxies.jl b/src/proxies.jl index ae3bd6d2..8dc4d3a3 100644 --- a/src/proxies.jl +++ b/src/proxies.jl @@ -19,5 +19,13 @@ macro useproxy(M) @inline $MeasureBase.basemeasure(μ::$M) = basemeasure(proxy(μ)) @inline $MeasureBase.basemeasure_depth(μ::$M) = basemeasure_depth(proxy(μ)) + + @inline $MeasureBase.transport_origin(μ::$M) = transport_origin(proxy(μ)) + @inline $MeasureBase.to_origin(μ::$M, y) = to_origin(proxy(μ), y) + @inline 
$MeasureBase.from_origin(μ::$M, x) = from_origin(proxy(μ), x) + + @inline $MeasureBase.massof(μ::$M) = massof(proxy(μ)) + @inline $MeasureBase.massof(μ::$M, s) = massof(proxy(μ), s) + (μ::$M)(s) = proxy(μ)(s) end end diff --git a/src/smf.jl b/src/smf.jl new file mode 100644 index 00000000..6ffa6f41 --- /dev/null +++ b/src/smf.jl @@ -0,0 +1,40 @@ +@doc raw""" + smf(μ, x::Real) ::Real + +Compute the _Stieltjes measure function (SMF)_ of the measure `μ` at the point +`x`. + +The SMF is the measure-theoretic generalization of the _cumulative distribution +function (CDF)_ from probability theory. An SMF `F(x) = smf(μ, x)` must have the +following properties: + +1. F is _nondecreasing_ +2. F is _right-continuous_: `F(x)` should be the same as `lim_{δ→0} F(x + |δ|)`. +3. μ((a,b]) = F(b) - F(a) + +Note that unlike the CDF, an SMF is only determined up to addition by a +constant. For many applications, this leads to a need to evaluate an SMF at -∞. +It's therefore important that `smf(μ, -Inf)` be fast. In practice, this will +usually be called as `smf(μ, static(-Inf))`. It's then easy to ensure speed and +avoid complex control flow by adding a method `smf(μ::M, ::StaticFloat64{-Inf})`. + +Users who pronounce `sinh` as "sinch" are advised to pronounce `smf` as "smurf". 
+""" +function smf end + +export smf + +function invsmf end + +export invsmf + +struct NoSMF end + +struct NoSMFInverse end + +smf(μ, x) = NoSMF() + +invsmf(μ, p) = NoSMFInverse() + +smf(μ) = Base.Fix1(smf, μ) +invsmf(μ) = Base.Fix1(invsmf, μ) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index b038c587..a02c5765 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -5,7 +5,7 @@ export StdExponential insupport(d::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x -@inline basemeasure(::StdExponential) = Lebesgue() +@inline basemeasure(::StdExponential) = LebesgueBase() @inline transport_def(::StdUniform, μ::StdExponential, x) = -expm1(-x) @inline transport_def(::StdExponential, μ::StdUniform, x) = -log1p(-x) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 31aa70f7..0d502ec6 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -5,11 +5,17 @@ export StdLogistic @inline insupport(d::StdLogistic, x) = true @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u)) -@inline basemeasure(::StdLogistic) = Lebesgue() +@inline basemeasure(::StdLogistic) = LebesgueBase() @inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x) -@inline transport_def(::StdLogistic, μ::StdUniform, x) = logit(x) +@inline transport_def(::StdLogistic, μ::StdUniform, p) = logit(p) @inline function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} logit(rand(rng, T)) end + +smf(::StdLogistic, x) = logistic(x) +smf(::StdLogistic) = logistic + +invsmf(::StdLogistic, p) = logit(p) +invsmf(::StdLogistic) = logit diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 5af72106..2b50c1e5 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -2,6 +2,7 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) 
= StdExponential() +StdMeasure(::typeof(randn)) = StdNormal() @inline check_dof(::StdMeasure, ::StdMeasure) = nothing @@ -52,3 +53,71 @@ end @inline function transport_def(ν::PowerMeasure{<:StdMeasure}, ::Dirac, ::Any) Zeros{Bool}(map(_ -> 0, ν.axes)) end + +# Helpers for product transforms and similar: + +struct _TransportToStd{NU<:StdMeasure} <: Function end +_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) + +struct _TransportFromStd{MU<:StdMeasure} <: Function end +_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) + +function _tuple_transport_def( + ν::PowerMeasure{NU}, + μs::Tuple, + xs::Tuple, +) where {NU<:StdMeasure} + reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:Tuple}, + x, +) where {NU<:StdMeasure} + _tuple_transport_def(ν, marginals(μ), x) +end + +function transport_def( + ν::PowerMeasure{NU}, + μ::ProductMeasure{<:NamedTuple{names}}, + x, +) where {NU<:StdMeasure,names} + _tuple_transport_def(ν, values(marginals(μ)), values(x)) +end + +@inline _offset_cumsum(s, x, y, rest...) = (s, _offset_cumsum(s + x, y, rest...)...) +@inline _offset_cumsum(s, x) = (s,) +@inline _offset_cumsum(s) = () + +function _stdvar_viewranges(μs::Tuple, startidx::Integer) + N = map(getdof, μs) + offs = _offset_cumsum(startidx, N...) 
+ map((o, n) -> o:o+n-1, offs, N) +end + +function _tuple_transport_def( + νs::Tuple, + μ::PowerMeasure{MU}, + x::AbstractArray{<:Real}, +) where {MU<:StdMeasure} + vrs = _stdvar_viewranges(νs, firstindex(x)) + xs = map(r -> view(x, r), vrs) + map(_TransportFromStd{MU}, νs, xs) +end + +function transport_def( + ν::ProductMeasure{<:Tuple}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure} + _tuple_transport_def(marginals(ν), μ, x) +end + +function transport_def( + ν::ProductMeasure{<:NamedTuple{names}}, + μ::PowerMeasure{MU}, + x, +) where {MU<:StdMeasure,names} + NamedTuple{names}(_tuple_transport_def(values(marginals(ν)), μ, x)) +end diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl new file mode 100644 index 00000000..dc9cac74 --- /dev/null +++ b/src/standard/stdnormal.jl @@ -0,0 +1,30 @@ +using SpecialFunctions: erfc, erfcinv +using IrrationalConstants: invsqrt2 + +struct StdNormal <: StdMeasure end + +export StdNormal + +@inline insupport(d::StdNormal, x) = true + +@inline logdensity_def(::StdNormal, x) = -x^2 / 2 +@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase()) + +@inline getdof(::StdNormal) = static(1) + +@inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdNormal) where {T} = randn(rng, T) + +Φ(z) = erfc(-z * invsqrt2) / 2 +Φinv(p) = -erfcinv(2 * p) * sqrt2 + +InverseFunctions.inverse(::typeof(Φ)) = Φinv +InverseFunctions.inverse(::typeof(Φinv)) = Φ + +smf(::StdNormal, x) = Φ(x) +invsmf(::StdNormal, p) = Φinv(p) + +smf(::StdNormal) = Φ +invsmf(::StdNormal) = Φinv + +transport_def(::StdNormal, ::StdUniform, p) = Φinv(p) +transport_def(::StdUniform, ::StdNormal, x) = Φ(x) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index d29dce80..8817561e 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -5,6 +5,15 @@ export StdUniform insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) -@inline 
basemeasure(::StdUniform) = Lebesgue() +@inline basemeasure(::StdUniform) = LebesgueBase() Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = rand(rng, T) + +massof(::StdUniform, s::Interval) = massof(Lebesgue(𝕀), s::Interval) + +smf(::StdUniform, x) = clamp(x, zero(x), one(x)) + +function invsmf(::StdUniform, p) + @assert zero(p) ≤ p ≤ one(p) + p +end diff --git a/src/transport.jl b/src/transport.jl index 08a6721c..cc73c07d 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -1,12 +1,12 @@ """ - struct MeasureBase.NoTransformOrigin{NU} + struct MeasureBase.NoTransportOrigin{NU} Indicates that no (default) pullback measure is available for measures of type `NU`. See [`MeasureBase.transport_origin`](@ref). """ -struct NoTransformOrigin{NU} end +struct NoTransportOrigin{NU} end """ MeasureBase.transport_origin(ν) @@ -16,7 +16,7 @@ between `ν` and another measure. """ function transport_origin end -transport_origin(ν::NU) where {NU} = NoTransformOrigin{NU}() +transport_origin(ν::NU) where {NU} = NoTransportOrigin{NU}() """ MeasureBase.from_origin(ν, x) @@ -25,7 +25,7 @@ Push `x` from `MeasureBase.transport_origin(μ)` forward to `ν`. """ function from_origin end -from_origin(ν::NU, ::Any) where {NU} = NoTransformOrigin{NU}() +from_origin(ν::NU, ::Any) where {NU} = NoTransportOrigin{NU}() """ MeasureBase.to_origin(ν, y) @@ -34,7 +34,7 @@ Pull `y` from `ν` back to `MeasureBase.transport_origin(ν)`. """ function to_origin end -to_origin(ν::NU, ::Any) where {NU} = NoTransformOrigin{NU}(ν) +to_origin(ν::NU, ::Any) where {NU} = NoTransportOrigin{NU}() """ struct MeasureBase.NoTransport{NU,MU} end @@ -62,7 +62,7 @@ The resulting function `f` should support so that densities of `ν` can be derived from densities of `μ` via `f` (using appropriate base measures). -Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from +Returns NoTransportOrigin{typeof(ν),typeof(μ)} if no transformation from `μ` to `ν` can be found. 
To add transformation rules for a measure type `MyMeasure`, specialize @@ -92,6 +92,13 @@ distribution itself or a power of it (e.g. `StdUniform()` or """ function transport_to end +""" + transport_to(ν, μ, x) + +Transport `x` from the measure `μ` to the measure `ν` +""" +transport_to(ν, μ, x) = transport_to(ν, μ)(x) + """ transport_def(ν, μ, x) @@ -113,61 +120,83 @@ See [`transport_to`](@ref). """ function transport_def end -transport_def(::Any, ::Any, x::NoTransformOrigin) = x -transport_def(::Any, ::Any, x::NoTransport) = x - function transport_def(ν, μ, x) - _transport_with_intermediate( - ν, - _checked_transport_origin(ν), - _checked_transport_origin(μ), - μ, - x, - ) + _transport_between_origins(ν, _origin_depth(ν), _origin_depth(μ), μ, x) end -@inline _origin_must_have_separate_type(::Type{MU}, μ_o) where {MU} = μ_o -function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where {MU} - throw(ArgumentError("Measure of type $MU and its origin must have separate types")) -end - -@inline function _checked_transport_origin(μ::MU) where {MU} - μ_o = transport_origin(μ) - _origin_must_have_separate_type(MU, μ_o) -end - -function _transport_with_intermediate(ν, ν_o, μ_o, μ, x) - x_o = to_origin(μ, x) - # If μ is a pushforward then checked_arg may have been bypassed, so check now: - y_o = transport_def(ν_o, μ_o, checked_arg(μ_o, x_o)) - y = from_origin(ν, y_o) - return y +@inline function _origin_depth(ν::NU) where {NU} + ν_0 = ν + Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number + ν_{i} = transport_origin(ν_{i - 1}) + if ν_{i} isa NoTransportOrigin + return static(i - 1) + end + end + return static(10) end -function _transport_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) - y_o = transport_def(ν_o, μ, x) - y = from_origin(ν, y_o) - return y -end +_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent() +ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback -function 
_transport_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) - x_o = to_origin(μ, x) - # If μ is a pushforward then checked_arg may have been bypassed, so check now: - y = transport_def(ν, μ_o, checked_arg(μ_o, x_o)) - return y +# If both measures have no origin: +function _transport_between_origins(ν, ::StaticInt{0}, ::StaticInt{0}, μ, x) + _transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x) end -function _transport_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) - _transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x) +@generated function _transport_between_origins( + ν, + ::StaticInt{n_ν}, + ::StaticInt{n_μ}, + μ, + x, +) where {n_ν,n_μ} + prog = quote + μ0 = μ + x0 = x + ν0 = ν + end + for i in 1:n_μ + μ_i = Symbol(:μ, i) + μ_last = Symbol(:μ, i - 1) + push!(prog.args, :($μ_i = transport_origin($μ_last))) + end + for i in 1:n_μ + x_i = Symbol(:x, i) + x_last = Symbol(:x, i - 1) + μ_last = Symbol(:μ, i - 1) + push!(prog.args, :($x_i = to_origin($μ_last, $x_last))) + end + for i in 1:(n_ν) + ν_i = Symbol(:ν, i) + ν_last = Symbol(:ν, i - 1) + push!(prog.args, :($ν_i = transport_origin($ν_last))) + end + μ_im = Symbol(:μ, n_μ) + x_im = Symbol(:x, n_μ) + ν_im = Symbol(:ν, n_ν) + y_im = Symbol(:y, n_ν) + push!(prog.args, :($y_im = transport_def($ν_im, $μ_im, $x_im))) + for i in (n_ν-1):-1:0 + y_i = Symbol(:y, i) + y_last = Symbol(:y, i + 1) + ν_last = Symbol(:ν, i) + push!(prog.args, :($y_i = from_origin($ν_last, $y_last))) + end + push!(prog.args, :(return y0)) + return prog end @inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ)) @inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ @inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() +_call_transport_def(ν, μ, x) = transport_def(ν, μ, x) +_call_transport_def(::Any, ::Any, x::NoTransportOrigin) = x +_call_transport_def(::Any, ::Any, x::NoTransport) = x + function
_transport_with_intermediate(ν, m, μ, x) - z = transport_def(m, μ, x) - y = transport_def(ν, m, z) + z = _call_transport_def(m, μ, x) + y = _call_transport_def(ν, m, z) return y end @@ -208,7 +237,7 @@ function Base.:(==)(a::TransportFunction, b::TransportFunction) end Base.@propagate_inbounds function (f::TransportFunction)(x) - return transport_def(f.ν, f.μ, checked_arg(f.μ, x)) + return _call_transport_def(f.ν, f.μ, checked_arg(f.μ, x)) end @inline function InverseFunctions.inverse(f::TransportFunction{NU,MU}) where {NU,MU} diff --git a/src/utils.jl b/src/utils.jl index fedca36c..b5db90c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,11 @@ showparams(io::IO, ::EmptyNamedTuple) = print(io, "()") showparams(io::IO, nt::NamedTuple) = print(io, nt) export testvalue -testvalue(μ::AbstractMeasure) = testvalue(basemeasure(μ)) + +@inline testvalue(μ) = rand(FixedRNG(), μ) + +@inline testvalue(::Type{T}, μ) where {T} = rand(FixedRNG(), T, μ) + testvalue(::Type{T}) where {T} = zero(T) export rootmeasure @@ -149,6 +153,7 @@ function rmap(f, nt::NamedTuple{N,T}) where {N,T} NamedTuple{N}(map(x -> rmap(f, x), values(nt))) end +insupport(m::AbstractMeasure) = Base.Fix1(insupport, m) @inline return_type(f, args::Tuple) = Core.Compiler.return_type(f, Tuple{typeof.(args)...}) diff --git a/test/combinators/superpose.jl b/test/combinators/superpose.jl index a77ff847..ed4c6996 100644 --- a/test/combinators/superpose.jl +++ b/test/combinators/superpose.jl @@ -10,13 +10,13 @@ using MeasureBase: superpose @test μs isa SuperpositionMeasure{<:Tuple{Dirac,Dirac}} @test μs == SuperpositionMeasure((μ, ν)) == superpose(μ, ν) @test density_def(μs, 0) == 1.0 - @test basemeasure(μs) == CountingMeasure() + CountingMeasure() + @test basemeasure(μs) == CountingBase() + CountingBase() μs = SuperpositionMeasure([μ, ν]) @test μs isa SuperpositionMeasure{<:AbstractVector{<:AbstractMeasure}} @test_throws ErrorException density_def(μs, 0) @test basemeasure(μs).components == - 
SuperpositionMeasure([CountingMeasure(), CountingMeasure()]).components + SuperpositionMeasure([CountingBase(), CountingBase()]).components μ2 = μ + μ @test μ2 isa WeightedMeasure diff --git a/test/combinators/transformedmeasure.jl b/test/combinators/transformedmeasure.jl index 4cdf7f0f..342edbed 100644 --- a/test/combinators/transformedmeasure.jl +++ b/test/combinators/transformedmeasure.jl @@ -1,21 +1,80 @@ using Test +using MeasureBase using MeasureBase: pushfwd, StdUniform, StdExponential, StdLogistic using MeasureBase: pushfwd, PushforwardMeasure -using MeasureBase: transport_to -using Statistics: var +using MeasureBase: transport_to, unsafe_logdensityof +import Statistics: var using DensityInterface: logdensityof +using LogExpFunctions +using SpecialFunctions: erfc, erfcinv +import InverseFunctions: inverse, FunctionWithInverse +using IrrationalConstants: invsqrt2, sqrt2 +import ChangesOfVariables: with_logabsdet_jacobian +using MeasureBase.Interface: transport_to, test_transport -@testset "transformedmeasure.jl" begin - μ = StdUniform() - @test @inferred(pushfwd((-) ∘ log1p ∘ (-), μ)) isa PushforwardMeasure - ν = pushfwd((-) ∘ log1p ∘ (-), μ) - ν_ref = StdExponential() +Φ(z) = erfc(-z * invsqrt2) / 2 +Φinv(p) = -erfcinv(2 * p) * sqrt2 + +with_logabsdet_jacobian(f::FunctionWithInverse, x) = with_logabsdet_jacobian(f.f, x) + +with_logabsdet_jacobian(::typeof(Φ), z) = (Φ(z), logdensityof(StdNormal(), z)) + +function with_logabsdet_jacobian(::typeof(Φinv), p) + z = Φinv(p) + (z, -logdensityof(StdNormal(), z)) +end + +inverse(::typeof(Φ)) = Φinv +inverse(::typeof(Φinv)) = Φ + +var(::StdNormal) = 1.0 +var(::StdExponential) = 1.0 +var(::StdUniform) = 1 / 12 +var(::StdLogistic) = π^2 / 3 - y = rand(ν_ref) - @test @inferred(logdensityof(ν, y)) ≈ logdensityof(ν_ref, y) +function test_pushfwd(f, μ, ν_ref) + @testset "pushfwd($f, $μ)" begin + @inferred(pushfwd(f, μ)) + ν = pushfwd(f, μ) + + test_transport(ν, ν_ref) + test_transport(ν_ref, ν) + + y = rand(ν_ref) + @test 
isapprox(@inferred(logdensityof(ν, y)), logdensityof(ν_ref, y), atol = 1e-10) + @test isapprox( + @inferred(unsafe_logdensityof(ν, y)), + unsafe_logdensityof(ν_ref, y), + atol = 1e-10, + ) + + @test isapprox(var(rand(ν^(10^5))), var(ν_ref), rtol = 0.05) + + @test abs(transport_to(StdLogistic(), ν)(y)) ≈ + abs(transport_to(StdLogistic(), ν_ref)(y)) + end +end + +@testset "transformedmeasure.jl" begin + # (f, μ, ν_ref), so that pushfwd(f, μ) ≅ ν_ref + triples = [ + ((-) ∘ log, StdUniform(), StdExponential()) + (exp ∘ (-), StdExponential(), StdUniform()) + (logit, StdUniform(), StdLogistic()) + (logistic, StdLogistic(), StdUniform()) + (Φ, StdNormal(), StdUniform()) + (Φinv, StdUniform(), StdNormal()) + ] - @test isapprox(var(rand(ν^(10^5))), 1, rtol = 0.05) + for (f, μ, ν_ref) in triples + test_pushfwd(f, μ, ν_ref) + end - @test transport_to(StdLogistic(), ν)(y) ≈ transport_to(StdLogistic(), ν)(y) + @testset "Pushforward-of-pushforward" begin + for (f, μ, ν_ref) in triples + finv = inverse(f) + test_pushfwd(finv, pushfwd(f, μ), μ) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 3578043d..f10bde45 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using LinearAlgebra import LogarithmicNumbers using MeasureBase -using MeasureBase: test_interface +using MeasureBase: test_interface, test_smf using Aqua Aqua.test_all(MeasureBase; ambiguities = false) @@ -76,7 +76,7 @@ test_measures = [ testbroken_measures = [ # InverseGamma(2) # Not defined yet # MvNormal(I(3)) # Entirely broken for now - CountingMeasure() + CountingBase() Likelihood TrivialMeasure() ] @@ -138,7 +138,7 @@ end # end @testset "broadcasting" begin - @test logdensityof.(Dirac(2), [1,2,3]) isa Vector{Float64} + @test logdensityof.(Dirac(2), [1, 2, 3]) isa Vector{Float64} end @testset "powers" begin @@ -146,8 +146,9 @@ end @test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3, 1), (2, 0)) end +Normal() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ)) + @testset "Half" begin 
- Normal() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ)) HalfNormal() = Half(Normal()) @test logdensityof(HalfNormal(), -0.2) == -Inf @test logdensity_def(HalfNormal(), 0.2) == logdensity_def(Normal(), 0.2) @@ -233,23 +234,20 @@ end @testset "Density measures and Radon-Nikodym" begin x = randn() f(x) = x^2 - @test logdensityof(𝒹(∫exp(f, Lebesgue()), Lebesgue()), x) ≈ f(x) + @test log(𝒹(∫exp(f, Lebesgue()), Lebesgue())(x)) ≈ f(x) let f = 𝒹(∫exp(x -> x^2, Lebesgue()), Lebesgue()) - @test logdensityof(f, x) ≈ x^2 + @test log(f(x)) ≈ x^2 end - # let d = ∫exp(log𝒹(Cauchy(), Normal()), Normal()) - # @test logdensity_def(d, x) ≈ logdensity_def(Cauchy(), x) - # end - - # let f = log𝒹(∫exp(x -> x^2, Normal()), Normal()) - # @test f(x) ≈ x^2 - # end + let f = log𝒹(∫exp(x -> x^2, Normal()), Normal()) + @test f(x) ≈ x^2 + end end include("getdof.jl") include("transport.jl") +include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") diff --git a/test/smf.jl b/test/smf.jl new file mode 100644 index 00000000..e4b0905e --- /dev/null +++ b/test/smf.jl @@ -0,0 +1,14 @@ +smf_measures = [ + StdNormal() + StdLogistic() + StdUniform() + Half(StdNormal()) + Half(StdLogistic()) + Half(StdUniform()) +] + +@testset "smf" begin + for μ in smf_measures + test_smf(μ) + end +end diff --git a/test/transport.jl b/test/transport.jl index da6377e5..874cb88b 100644 --- a/test/transport.jl +++ b/test/transport.jl @@ -1,21 +1,35 @@ using Test using MeasureBase.Interface: transport_to, test_transport -using MeasureBase: StdUniform, StdExponential, StdLogistic +using MeasureBase: StdUniform, StdExponential, StdLogistic, StdNormal using MeasureBase: Dirac +using LogExpFunctions: logit + +using ChainRulesTestUtils @testset "transport_to" begin - for μ0 in [StdUniform(), StdExponential(), StdLogistic()], - ν0 in [StdUniform(), StdExponential(), StdLogistic()] + test_rrule(MeasureBase._origin_depth, pushfwd(exp, StdUniform())) + + for (f, μ) in [ + (logit, StdUniform()) + (log, 
StdExponential()) + (exp, StdNormal()) + ] + test_transport(μ, pushfwd(f, μ)) + test_transport(pushfwd(f, μ), μ) + end + + for μ0 in [StdUniform(), StdExponential(), StdLogistic(), StdNormal()], + ν0 in [StdUniform(), StdExponential(), StdLogistic(), StdNormal()] @testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin test_transport(ν0, μ0) - test_transport(2.2 * ν0, 3 * μ0) + test_transport(2.2 * ν0, 2.2 * μ0) test_transport(ν0, μ0^1) test_transport(ν0^1, μ0) test_transport(ν0^3, μ0^3) test_transport(ν0^(2, 3, 2), μ0^(3, 4)) - test_transport(2.2 * ν0^(2, 3, 2), 3 * μ0^(3, 4)) + test_transport(2.2 * ν0^(2, 3, 2), 2.2 * μ0^(3, 4)) @test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12)) @test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3, 4))) end @@ -41,4 +55,24 @@ using MeasureBase: Dirac @test @inferred(transport_to(StdUniform()^(2, 3), StdExponential)) == transport_to(StdUniform()^(2, 3), StdExponential()^6) end + + @testset "transport for products" begin + test_transport( + StdUniform()^(2, 2), + productmeasure((StdExponential(), StdLogistic()^3)), + ) + test_transport( + productmeasure((StdExponential(), StdLogistic()^3)), + StdUniform()^(2, 2), + ) + + test_transport( + StdUniform()^(2, 2), + productmeasure((a = StdExponential(), b = StdLogistic()^3)), + ) + test_transport( + productmeasure((a = StdExponential(), b = StdLogistic()^3)), + StdUniform()^(2, 2), + ) + end end