Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace AbstractDifferentiation with DifferentiationInterface #20

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "MuseInference"
uuid = "43b88160-90c7-4f71-933b-9d65205cd921"
authors = ["Marius Millea <[email protected]>"]
version = "0.2.4"
version = "0.3.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Expand All @@ -32,9 +33,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractDifferentiation = "0.5"
ADTypes = "1"
ComponentArrays = "0.12.3, 0.13, 0.14, 0.15"
CovarianceEstimation = "0.2.7"
DifferentiationInterface = "0.5"
Distributions = "0.25.36"
FileIO = "1.11"
FiniteDifferences = "0.12.20"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ First, load up the packages we'll need:

```@example 1
using MuseInference, Turing
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
using ADTypes, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
Turing.setadbackend(:zygote)
using Logging # hide
Logging.disable_logging(Logging.Info) # hide
Expand Down Expand Up @@ -164,7 +164,7 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = AbstractDifferentiation.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)
nothing # hide
```
Expand Down
5 changes: 3 additions & 2 deletions src/MuseInference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module MuseInference

import AbstractDifferentiation as AD
using ADTypes: ADTypes
import DifferentiationInterface as DI
using Base.Iterators: repeated
using ComponentArrays
using CovarianceEstimation
Expand Down Expand Up @@ -56,4 +57,4 @@ end
end
end

end
end
12 changes: 3 additions & 9 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@


# some convenient type-piracy for scalars
AD.gradient(ad::AD.AbstractBackend, f, x::Real) = AD.derivative(ad, f, x)
AD.hessian(ad::AD.AbstractBackend, f, x::Real) = first.(AD.hessian(ad, f∘first, [x]))

function optim_only_fg!(func, autodiff)
function optim_only_fg!(func, backend::ADTypes.AbstractADType)
Optim.only_fg!() do F, G, z
if G != nothing
f, g = AD.value_and_gradient(autodiff, func, z)
f, g = DI.value_and_gradient(func, backend, z)
G .= first(g)
return f
end
if F != nothing
return func(z)
end
end
end
end
1 change: 0 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,3 @@ function check_self_consistency(
@test ∇θ_logLike(prob, x, z, θ, UnTransformedθ()) ≈ J(θ)' * ∇θ_logLike(prob, x, z, transform_θ(prob, θ), Transformedθ()) .+ ∇θ_V(θ) atol=atol
end
end

55 changes: 33 additions & 22 deletions src/muse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ function muse!(
ẑs = getindex.(gẑs, :ẑ)

g_like′ = g_like_dat′ .- mean(g_like_sims′)
g_prior′ = AD.gradient(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
g_prior′ = DI.gradient(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
g_post′ = g_like′ .+ g_prior′

# Jacobian
Expand All @@ -204,7 +206,9 @@ function muse!(
end
end

H_prior′ = AD.hessian(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
H_prior′ = DI.hessian(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
H⁻¹_post′ = inv(inv(H⁻¹_like′) + H_prior′)

t = now() - t₀
Expand Down Expand Up @@ -288,8 +292,8 @@ Keyword arguments:
differentiation, rather than finite differences. Will require 2nd
order AD through your `logLike` so pay close attention to your
`prob.autodiff`. Either
`AD.HigherOrderBackend((AD.ForwardDiffBackend(),
AD.ZygoteBackend()))` or `AD.ForwardDiffBackend()` are recommended
`DifferentiationInterface.SecondOrder(ADTypes.AutoForwardDiff(),
ADTypes.AutoZygote())` or `ADTypes.AutoForwardDiff()` are recommended
(default: `false`)

"""
Expand Down Expand Up @@ -347,35 +351,40 @@ function get_H!(
end
T = eltype(z_start)

ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff)
ad_fwd, ad_rev = if prob.autodiff isa DI.SecondOrder
# assume forward-over-reverse is provided
DI.outer(prob.autodiff), DI.inner(prob.autodiff)
else
prob.autodiff, prob.autodiff
end

## non-implicit-diff term
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), first(AD.jacobian(θ₀, backend=ad_fwd) do θ
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(θ₀, backend=ad_rev) do θ′
DI.gradient(ad_rev, θ₀) do θ′
logLike(prob, x, ẑ, θ′, UnTransformedθ())
end)
end))
end
end)

## term involving dzMAP/dθ via implicit-diff (w/ conjugate-gradient linear solve)
dFdθ = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
first(AD.gradient(ẑ, backend=ad_rev) do z
dFdθ = DI.jacobian(ad_fwd, θ₀) do θ
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ, UnTransformedθ())
end)
end)
dFdθ1 = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
end
end
dFdθ1 = DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(ẑ, backend=ad_rev) do z
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
# A is the operation of the Hessian of logLike w.r.t. z
A = LinearMap{T}(length(z_start), isposdef=true, issymmetric=true, ishermitian=true) do w
first(AD.jacobian(0, backend=ad_fwd) do α
first(AD.gradient(ẑ + α * w, backend=ad_rev) do z
DI.jacobian(ad_fwd, 0) do α
DI.gradient(ad_rev, ẑ + α * w) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
end
A⁻¹_dFdθ1 = pmap(pool_jac, eachcol(dFdθ1)) do w
A⁻¹_w = cg(A, w; implicit_diff_cg_kwargs..., log=true)
Expand Down Expand Up @@ -536,7 +545,9 @@ function finalize_result!(result::MuseResult, prob::AbstractMuseProblem)
@unpack H, J, θ = result
if H != nothing && J != nothing && θ != nothing
𝟘 = zero(J) # if θ::ComponentArray, helps keep component labels
H_prior = -AD.hessian(AD.ForwardDiffBackend(), θ -> logPriorθ(prob, θ, UnTransformedθ()), result.θ)[1]
H_prior = -DI.hessian(ADTypes.AutoForwardDiff(), result.θ) do θ
logPriorθ(prob, θ, UnTransformedθ())
end
result.Σ⁻¹ = H' * inv(J) * H + H_prior + 𝟘
result.Σ = inv(result.Σ⁻¹) + 𝟘
if length(result.θ) == 1
Expand Down
14 changes: 7 additions & 7 deletions src/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct SimpleMuseProblem{X,S,L,Gθ,Pθ,GZ,A} <: AbstractMuseProblem
end

@doc doc"""
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=AD.ForwardDiffBackend())
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=ADTypes.AutoForwardDiff())

Specify a MUSE problem by providing the simulation and posterior
evaluation code by-hand. The argument `x` should be the observed data.
Expand Down Expand Up @@ -44,8 +44,8 @@ end

and should return the prior $\log\mathcal{P}(\theta)$ for your
problem. The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differenation through `logLike`.


Expand All @@ -69,20 +69,20 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = MuseInference.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)

# get solution
muse(prob, (θ=1,))
```
"""
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::AD.AbstractBackend=AD.ForwardDiffBackend())
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::ADTypes.AbstractADType=ADTypes.AutoForwardDiff())
SimpleMuseProblem(
x,
sample_x_z,
logLike,
(x,z,θ) -> first(AD.gradient(autodiff, θ -> logLike(x,z,θ), θ)),
(x,z,θ) -> first.(AD.value_and_gradient(autodiff, z -> logLike(x,z,θ), z)),
(x,z,θ) -> DI.gradient(θ -> logLike(x,z,θ), autodiff, θ),
(x,z,θ) -> DI.value_and_gradient(z -> logLike(x,z,θ), autodiff, z),
logPriorθ,
autodiff
)
Expand Down
22 changes: 13 additions & 9 deletions src/soss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import .Soss.SimpleGraphs as SG

export SossMuseProblem

struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
struct SossMuseProblem{A<:ADTypes.AbstractADType, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
autodiff :: A
model :: M
model_for_prior :: MP
Expand All @@ -18,7 +18,7 @@ struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.Ab
end

@doc doc"""
SossMuseProblem(model; params, autodiff = ForwardDiffBackend())
SossMuseProblem(model; params, autodiff = AutoForwardDiff())

Specify a MUSE problem with a
[Soss](https://github.com/cscherrer/Soss.jl) model.
Expand All @@ -31,8 +31,8 @@ as a list of symbols. All other non-conditioned and non-`params`
variables will be considered the latent space.

The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differenation.

## Example
Expand Down Expand Up @@ -64,7 +64,7 @@ result = muse(prob, (θ=0,))
function SossMuseProblem(
model::Soss.ConditionalModel;
params = leaf_params(model),
autodiff = ForwardDiffBackend()
autodiff = ADTypes.AutoForwardDiff()
)
x = model.obs
!isempty(x) || error("Model must be conditioned on observed data.")
Expand Down Expand Up @@ -110,16 +110,20 @@ end

function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::ComponentVector, ::UnTransformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(θ)), θ))
DI.gradient(θ -> Soss.logdensityof(like, _namedtuple(θ)), prob.autodiff, θ)
end
function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::AbstractVector, ::Transformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ))), θ))
DI.gradient(prob.autodiff, θ) do θ
Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ)))
end
end


function logLike_and_∇z_logLike(prob::SossMuseProblem, x, z, θ)
first.(AD.value_and_gradient(prob.autodiff, z -> Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z)), z))
DI.value_and_gradient(prob.autodiff, z) do z
Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z))
end
end

function sample_x_z(prob::SossMuseProblem, rng::AbstractRNG, θ)
Expand Down Expand Up @@ -150,4 +154,4 @@ function get_J!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result
end
function get_H!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result.θ; kwargs...)
get_H!(result, SossMuseProblem(model, params=_params_from_θ₀(θ₀)), θ₀; kwargs...)
end
end
16 changes: 9 additions & 7 deletions src/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function DynPPL.maybe_invlink_before_eval!!(vi::DynPPL.SimpleVarInfo{NT,T,<:Part
end


struct TuringMuseProblem{A<:AD.AbstractBackend, M<:Turing.Model} <: AbstractMuseProblem
struct TuringMuseProblem{A<:ADTypes.AbstractADType, M<:Turing.Model} <: AbstractMuseProblem

autodiff :: A
model :: M
Expand Down Expand Up @@ -51,8 +51,8 @@ as a list of symbols. All other non-conditioned and non-`params`
variables will be considered the latent space.

The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differenation. The default uses whatever the global
`Turing.ADBACKEND` is currently set to.

Expand Down Expand Up @@ -118,14 +118,14 @@ function TuringMuseProblem(
# set backend based on Turing's by default
if autodiff == nothing
if Turing.ADBACKEND[] == :zygote
autodiff = AD.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
elseif Turing.ADBACKEND[] == :forwarddiff
autodiff = AD.ForwardDiffBackend()
autodiff = ADTypes.AutoForwardDiff()
else
error("Unsupposed backend from Turing: $(Turing.ADBACKEND)")
end
end
if (Threads.nthreads() > 1) && hasmethod(AD.ZygoteBackend,Tuple{}) && (autodiff isa typeof(AD.ZygoteBackend()))
if (Threads.nthreads() > 1) && (autodiff isa typeof(ADTypes.AutoZygote()))
error("Turing doesn't support using the Zygote backend when Threads.nthreads()>1. Use a different backend or a single-thread.")
end

Expand Down Expand Up @@ -202,7 +202,9 @@ function logPriorθ(prob::TuringMuseProblem, θ, θ_space)
end

function ∇θ_logLike(prob::TuringMuseProblem, x, z, θ, θ_space)
first(AD.gradient(prob.autodiff, θ -> logLike(prob, x, z, θ, θ_space), θ))
DI.gradient(prob.autodiff, θ) do θ
logLike(prob, x, z, θ, θ_space)
end
end

function ẑ_at_θ(prob::TuringMuseProblem, x, z₀, θ; ∇z_logLike_atol)
Expand Down
7 changes: 1 addition & 6 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ function Random.randn!(rng::AbstractRNG, A::Array{<:ForwardDiff.Dual})
A .= randn!(rng, ForwardDiff.value.(A))
end

# type-piracy bc these make code much clearer to read. could be removed if
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/62 is merged
AD.gradient(f, args...; backend::AD.AbstractBackend) = AD.gradient(backend, f, args...)
AD.jacobian(f, args...; backend::AD.AbstractBackend) = AD.jacobian(backend, f, args...)

# worker pool which just falls back to map
struct LocalWorkerPool <: AbstractWorkerPool end
Distributed.pmap(f, ::LocalWorkerPool, args...) = map(f, args...)
Expand All @@ -95,4 +90,4 @@ versionof(pkg::Module) = Pkg.dependencies()[Base.PkgId(pkg).uuid].version

# allow using InverseMap as an IterativeSolvers preconditioner
LinearAlgebra.ldiv!(dst::AbstractVector, A::InverseMap, src::AbstractVector) = mul!(dst, A.A, src)
LinearAlgebra.ldiv!(A::InverseMap, vec::AbstractVector) = copyto!(vec, mul!(A.A, vec))
LinearAlgebra.ldiv!(A::InverseMap, vec::AbstractVector) = copyto!(vec, mul!(A.A, vec))
6 changes: 3 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -11,8 +11,8 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -23,4 +23,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Soss = "0.21.2"
Turing = "0.28"
Turing = "0.28"
Loading
Loading