Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace AbstractDifferentiation with DifferentiationInterface #20

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions .github/workflows/tests_and_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
julia-version: ['1.7', '1.8', '1.9', '1.10']
julia-version: ['1.10', '1']
threads: ['1', '2']
fail-fast: false
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # if authenticating with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # if authenticating with SSH deploy key
BUILD_DOCS: ${{ matrix.julia-version == '1.7' && matrix.threads == '1'}} # only build/deploy docs from one version
BUILD_DOCS: ${{ matrix.julia-version == '1' && matrix.threads == '1'}} # only build/deploy docs from one version
JULIA_NUM_THREADS: ${{ matrix.threads }}
timeout-minutes: 30
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "MuseInference"
uuid = "43b88160-90c7-4f71-933b-9d65205cd921"
authors = ["Marius Millea <[email protected]>"]
version = "0.2.4"
version = "0.3.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Expand All @@ -29,12 +30,12 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractDifferentiation = "0.5"
ADTypes = "1"
ComponentArrays = "0.12.3, 0.13, 0.14, 0.15"
CovarianceEstimation = "0.2.7"
DifferentiationInterface = "0.6.41"
Distributions = "0.25.36"
FileIO = "1.11"
FiniteDifferences = "0.12.20"
Expand All @@ -46,5 +47,4 @@ Optim = "1.5"
ProgressMeter = "1.7"
Requires = "1.2"
Setfield = "0.8.1, 1"
UnPack = "1.0.2"
julia = "1.7"
julia = "1.10"
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ First, load up the packages we'll need:

```@example 1
using MuseInference, Turing
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
using ADTypes, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
Turing.setadbackend(:zygote)
using Logging # hide
Logging.disable_logging(Logging.Info) # hide
Expand Down Expand Up @@ -164,7 +164,7 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = AbstractDifferentiation.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)
nothing # hide
```
Expand Down
6 changes: 3 additions & 3 deletions src/MuseInference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module MuseInference

import AbstractDifferentiation as AD
using ADTypes: ADTypes
import DifferentiationInterface as DI
using Base.Iterators: repeated
using ComponentArrays
using CovarianceEstimation
Expand All @@ -25,7 +26,6 @@ using Setfield
using Statistics
using Test
using TOML
using UnPack

export SimpleMuseProblem, MuseResult, muse, muse!, get_J!, get_H!

Expand Down Expand Up @@ -56,4 +56,4 @@ end
end
end

end
end
12 changes: 3 additions & 9 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@


# some convenient type-piracy for scalars
AD.gradient(ad::AD.AbstractBackend, f, x::Real) = AD.derivative(ad, f, x)
AD.hessian(ad::AD.AbstractBackend, f, x::Real) = first.(AD.hessian(ad, f∘first, [x]))

function optim_only_fg!(func, autodiff)
function optim_only_fg!(func, backend::ADTypes.AbstractADType)
Optim.only_fg!() do F, G, z
if G != nothing
f, g = AD.value_and_gradient(autodiff, func, z)
f, g = DI.value_and_gradient(func, backend, z)
G .= first(g)
return f
end
if F != nothing
return func(z)
end
end
end
end
1 change: 0 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,4 +228,3 @@ function check_self_consistency(
@test ∇θ_logLike(prob, x, z, θ, UnTransformedθ()) ≈ J(θ)' * ∇θ_logLike(prob, x, z, transform_θ(prob, θ), Transformedθ()) .+ ∇θ_V(θ) atol=atol
end
end

57 changes: 34 additions & 23 deletions src/muse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ function muse!(
ẑs = getindex.(gẑs, :ẑ)

g_like′ = g_like_dat′ .- mean(g_like_sims′)
g_prior′ = AD.gradient(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
g_prior′ = DI.gradient(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
g_post′ = g_like′ .+ g_prior′

# Jacobian
Expand All @@ -204,7 +206,9 @@ function muse!(
end
end

H_prior′ = AD.hessian(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
H_prior′ = DI.hessian(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
H⁻¹_post′ = inv(inv(H⁻¹_like′) + H_prior′)

t = now() - t₀
Expand Down Expand Up @@ -288,8 +292,8 @@ Keyword arguments:
differentiation, rather than finite differences. Will require 2nd
order AD through your `logLike` so pay close attention to your
`prob.autodiff`. Either
`AD.HigherOrderBackend((AD.ForwardDiffBackend(),
AD.ZygoteBackend()))` or `AD.ForwardDiffBackend()` are recommended
`DifferentiationInterface.SecondOrder(ADTypes.AutoForwardDiff(),
ADTypes.AutoZygote())` or `ADTypes.AutoForwardDiff()` are recommended
(default: `false`)

"""
Expand Down Expand Up @@ -347,35 +351,40 @@ function get_H!(
end
T = eltype(z_start)

ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff)
ad_fwd, ad_rev = if prob.autodiff isa DI.SecondOrder
# assume forward-over-reverse is provided
DI.outer(prob.autodiff), DI.inner(prob.autodiff)
else
prob.autodiff, prob.autodiff
end

## non-implicit-diff term
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), first(AD.jacobian(θ₀, backend=ad_fwd) do θ
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(θ₀, backend=ad_rev) do θ′
DI.gradient(ad_rev, θ₀) do θ′
logLike(prob, x, ẑ, θ′, UnTransformedθ())
end)
end))
end
end)

## term involving dzMAP/dθ via implicit-diff (w/ conjugate-gradient linear solve)
dFdθ = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
first(AD.gradient(ẑ, backend=ad_rev) do z
dFdθ = DI.jacobian(ad_fwd, θ₀) do θ
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ, UnTransformedθ())
end)
end)
dFdθ1 = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
end
end
dFdθ1 = DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(ẑ, backend=ad_rev) do z
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
# A is the operation of the Hessian of logLike w.r.t. z
A = LinearMap{T}(length(z_start), isposdef=true, issymmetric=true, ishermitian=true) do w
first(AD.jacobian(0, backend=ad_fwd) do α
first(AD.gradient(ẑ + α * w, backend=ad_rev) do z
DI.jacobian(ad_fwd, 0) do α
DI.gradient(ad_rev, ẑ + α * w) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
end
A⁻¹_dFdθ1 = pmap(pool_jac, eachcol(dFdθ1)) do w
A⁻¹_w = cg(A, w; implicit_diff_cg_kwargs..., log=true)
Expand Down Expand Up @@ -533,10 +542,12 @@ end


function finalize_result!(result::MuseResult, prob::AbstractMuseProblem)
@unpack H, J, θ = result
(; H, J, θ) = result
if H != nothing && J != nothing && θ != nothing
𝟘 = zero(J) # if θ::ComponentArray, helps keep component labels
H_prior = -AD.hessian(AD.ForwardDiffBackend(), θ -> logPriorθ(prob, θ, UnTransformedθ()), result.θ)[1]
H_prior = -DI.hessian(ADTypes.AutoForwardDiff(), result.θ) do θ
logPriorθ(prob, θ, UnTransformedθ())
end
result.Σ⁻¹ = H' * inv(J) * H + H_prior + 𝟘
result.Σ = inv(result.Σ⁻¹) + 𝟘
if length(result.θ) == 1
Expand Down
14 changes: 7 additions & 7 deletions src/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct SimpleMuseProblem{X,S,L,Gθ,Pθ,GZ,A} <: AbstractMuseProblem
end

@doc doc"""
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=AD.ForwardDiffBackend())
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=ADTypes.AutoForwardDiff())

Specify a MUSE problem by providing the simulation and posterior
evaluation code by-hand. The argument `x` should be the observed data.
Expand Down Expand Up @@ -44,8 +44,8 @@ end

and should return the prior $\log\mathcal{P}(\theta)$ for your
problem. The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differentiation through `logLike`.


Expand All @@ -69,20 +69,20 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = MuseInference.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)

# get solution
muse(prob, (θ=1,))
```
"""
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::AD.AbstractBackend=AD.ForwardDiffBackend())
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::ADTypes.AbstractADType=ADTypes.AutoForwardDiff())
SimpleMuseProblem(
x,
sample_x_z,
logLike,
(x,z,θ) -> first(AD.gradient(autodiff, θ -> logLike(x,z,θ), θ)),
(x,z,θ) -> first.(AD.value_and_gradient(autodiff, z -> logLike(x,z,θ), z)),
(x,z,θ) -> DI.gradient(θ -> logLike(x,z,θ), autodiff, θ),
(x,z,θ) -> DI.value_and_gradient(z -> logLike(x,z,θ), autodiff, z),
logPriorθ,
autodiff
)
Expand Down
22 changes: 13 additions & 9 deletions src/soss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import .Soss.SimpleGraphs as SG

export SossMuseProblem

struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
struct SossMuseProblem{A<:ADTypes.AbstractADType, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
autodiff :: A
model :: M
model_for_prior :: MP
Expand All @@ -18,7 +18,7 @@ struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.Ab
end

@doc doc"""
SossMuseProblem(model; params, autodiff = ForwardDiffBackend())
SossMuseProblem(model; params, autodiff = AutoForwardDiff())

Specify a MUSE problem with a
[Soss](https://github.com/cscherrer/Soss.jl) model.
Expand All @@ -31,8 +31,8 @@ as a list of symbols. All other non-conditioned and non-`params`
variables will be considered the latent space.

The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differentiation.

## Example
Expand Down Expand Up @@ -64,7 +64,7 @@ result = muse(prob, (θ=0,))
function SossMuseProblem(
model::Soss.ConditionalModel;
params = leaf_params(model),
autodiff = ForwardDiffBackend()
autodiff = ADTypes.AutoForwardDiff()
)
x = model.obs
!isempty(x) || error("Model must be conditioned on observed data.")
Expand Down Expand Up @@ -110,16 +110,20 @@ end

function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::ComponentVector, ::UnTransformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(θ)), θ))
DI.gradient(θ -> Soss.logdensityof(like, _namedtuple(θ)), prob.autodiff, θ)
end
function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::AbstractVector, ::Transformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ))), θ))
DI.gradient(prob.autodiff, θ) do θ
Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ)))
end
end


function logLike_and_∇z_logLike(prob::SossMuseProblem, x, z, θ)
first.(AD.value_and_gradient(prob.autodiff, z -> Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z)), z))
DI.value_and_gradient(prob.autodiff, z) do z
Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z))
end
end

function sample_x_z(prob::SossMuseProblem, rng::AbstractRNG, θ)
Expand Down Expand Up @@ -150,4 +154,4 @@ function get_J!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result
end
function get_H!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result.θ; kwargs...)
get_H!(result, SossMuseProblem(model, params=_params_from_θ₀(θ₀)), θ₀; kwargs...)
end
end
Loading