Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable autograd using Zygote #65

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
103 changes: 103 additions & 0 deletions examples/autodiff_example.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using GWecc
using Zygote
using PyPlot

year = 365.25 * 24 * 3600
toas = LinRange(0, 10 * year, 5)

theta = π / 3
phi = π / 4
psrdist = 1.0
cos_gwtheta = 0.3
gwphi = π / 5
psi = 0.0
cos_inc = 0.5
log10_M = 8.5
eta = 0.2
log10_F = -8.0
e0 = 0.3
gamma0 = gammap = 0.0
l0 = lp = 0.0
tref = maximum(toas)
log10_A = -8.0

psrTerm = false
spline = false

eccentric_pta_signal_for_ad(toas, tref, psrTerm, spline) =
(
theta,
phi,
psrdist,
cos_gwtheta,
gwphi,
psi,
cos_inc,
log10_M,
eta,
log10_F,
e0,
gamma0,
l0,
log10_A,
) -> eccentric_pta_signal(
toas,
theta,
phi,
psrdist,
cos_gwtheta,
gwphi,
psi,
cos_inc,
log10_M,
eta,
log10_F,
e0,
gamma0,
gamma0,
l0,
l0,
tref,
log10_A,
psrTerm,
spline,
)

gwe = eccentric_pta_signal_for_ad(toas, tref, psrTerm, spline)
Rs = gwe(
theta,
phi,
psrdist,
cos_gwtheta,
gwphi,
psi,
cos_inc,
log10_M,
eta,
log10_F,
e0,
gamma0,
l0,
log10_A,
)

PyPlot.plot(toas, Rs)
PyPlot.show()

J = jacobian(
gwe,
theta,
phi,
psrdist,
cos_gwtheta,
gwphi,
psi,
cos_inc,
log10_M,
eta,
log10_F,
e0,
gamma0,
l0,
log10_A,
)
16 changes: 10 additions & 6 deletions src/paramutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@ using PhysicalConstants.CODATA2018
const GMsun = UnitfulAstro.GMsun
const c_0 = CODATA2018.c_0

const Msun_to_s = uconvert(u"s", GMsun / c_0^3).val
const kpc_to_s = uconvert(u"s", u"kpc" / c_0).val
const Mpc_to_s = uconvert(u"s", u"Mpc" / c_0).val
const year_to_s = 365.25 * 24 * 3600

function mass_from_log10_mass(log10_M::Float64, eta::Float64)::Mass
M::Float64 = uconvert(u"s", (10.0^log10_M * GMsun) / c_0^3).val
M::Float64 = 10.0^log10_M * Msun_to_s
return Mass(M, eta)
end

Expand All @@ -22,18 +27,17 @@ function mean_motion_from_log10_freq(log10_F::Float64)::MeanMotion
end

function psrdist_from_kpc(pdist::Float64)::Distance
dp = uconvert(u"s", pdist * u"kpc" / c_0).val
dp = pdist * kpc_to_s
return Distance(dp)
end

function dl_from_gwdist(gwdist::Float64)::Distance
dp = uconvert(u"s", gwdist * u"Mpc" / c_0).val
return Distance(dp)
dgw = gwdist * Mpc_to_s
return Distance(dgw)
end

function Δp_from_deltap(deltap::Float64)::Time
year = 365.25 * 24 * 3600
return Time(-deltap * year)
return Time(-deltap * year_to_s)
end

# function mean_motion_from_log10_sidereal_freq(
Expand Down
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using LinearAlgebra
using UnPack
using NumericalIntegration
using Statistics
using Zygote


e_from_τ_from_e(ecc::Float64)::Float64 = e_from_τ(τ_from_e(Eccentricity(ecc))).e
Expand Down Expand Up @@ -911,4 +912,15 @@ e_from_τ_from_e(ecc::Float64)::Float64 = e_from_τ(τ_from_e(Eccentricity(ecc))
@test m1.m ≈ m2.m atol = 1e-6
end

@testset "autodiff" begin
m = 1000.0
η = 0.2

@testset "parameters and conversions" begin
chirp_mass = (m, η) -> Mass(m, η).Mch
Mch = chirp_mass(m, η)
∂Mch_∂m, ∂Mch_∂η = gradient(chirp_mass, m, η)
@test ∂Mch_∂m ≈ Mch / m && ∂Mch_∂η ≈ (3 / 5) * Mch / η
end
end
end