diff --git a/Project.toml b/Project.toml index e7d630e..6d77bea 100644 --- a/Project.toml +++ b/Project.toml @@ -17,3 +17,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" UnitfulAstro = "6112ee07-acf9-5e0f-b108-d242c714bf9f" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/autodiff_example.jl b/examples/autodiff_example.jl new file mode 100644 index 0000000..65527b5 --- /dev/null +++ b/examples/autodiff_example.jl @@ -0,0 +1,103 @@ +using GWecc +using Zygote +using PyPlot + +year = 365.25 * 24 * 3600 +toas = LinRange(0, 10 * year, 5) + +theta = π / 3 +phi = π / 4 +psrdist = 1.0 +cos_gwtheta = 0.3 +gwphi = π / 5 +psi = 0.0 +cos_inc = 0.5 +log10_M = 8.5 +eta = 0.2 +log10_F = -8.0 +e0 = 0.3 +gamma0 = gammap = 0.0 +l0 = lp = 0.0 +tref = maximum(toas) +log10_A = -8.0 + +psrTerm = false +spline = false + +eccentric_pta_signal_for_ad(toas, tref, psrTerm, spline) = + ( + theta, + phi, + psrdist, + cos_gwtheta, + gwphi, + psi, + cos_inc, + log10_M, + eta, + log10_F, + e0, + gamma0, + l0, + log10_A, + ) -> eccentric_pta_signal( + toas, + theta, + phi, + psrdist, + cos_gwtheta, + gwphi, + psi, + cos_inc, + log10_M, + eta, + log10_F, + e0, + gamma0, + gamma0, + l0, + l0, + tref, + log10_A, + psrTerm, + spline, + ) + +gwe = eccentric_pta_signal_for_ad(toas, tref, psrTerm, spline) +Rs = gwe( + theta, + phi, + psrdist, + cos_gwtheta, + gwphi, + psi, + cos_inc, + log10_M, + eta, + log10_F, + e0, + gamma0, + l0, + log10_A, +) + +PyPlot.plot(toas, Rs) +PyPlot.show() + +J = jacobian( + gwe, + theta, + phi, + psrdist, + cos_gwtheta, + gwphi, + psi, + cos_inc, + log10_M, + eta, + log10_F, + e0, + gamma0, + l0, + log10_A, +) diff --git a/src/paramutils.jl b/src/paramutils.jl index fdca03f..3c77de4 100644 --- a/src/paramutils.jl +++ b/src/paramutils.jl @@ -12,8 +12,13 @@ using PhysicalConstants.CODATA2018 const GMsun = UnitfulAstro.GMsun const c_0 = CODATA2018.c_0 +const Msun_to_s = uconvert(u"s", GMsun / c_0^3).val +const kpc_to_s = uconvert(u"s", u"kpc" / c_0).val +const Mpc_to_s = uconvert(u"s", u"Mpc" / c_0).val +const year_to_s = 365.25 * 24 * 3600 + function mass_from_log10_mass(log10_M::Float64, eta::Float64)::Mass - M::Float64 = uconvert(u"s", (10.0^log10_M * GMsun) / c_0^3).val + M::Float64 = 10.0^log10_M * Msun_to_s return Mass(M, eta) end @@ -22,18 +27,17 @@ function mean_motion_from_log10_freq(log10_F::Float64)::MeanMotion end function psrdist_from_kpc(pdist::Float64)::Distance - dp = uconvert(u"s", pdist * u"kpc" / c_0).val + dp = pdist * kpc_to_s return Distance(dp) end function dl_from_gwdist(gwdist::Float64)::Distance - dp = uconvert(u"s", gwdist * u"Mpc" / c_0).val - return Distance(dp) + dgw = gwdist * Mpc_to_s + return Distance(dgw) end function Δp_from_deltap(deltap::Float64)::Time - year = 365.25 * 24 * 3600 - return Time(-deltap * year) + return Time(-deltap * year_to_s) end # function mean_motion_from_log10_sidereal_freq( diff --git a/test/runtests.jl b/test/runtests.jl index 8838603..26be081 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using LinearAlgebra using UnPack using NumericalIntegration using Statistics +using Zygote e_from_τ_from_e(ecc::Float64)::Float64 = e_from_τ(τ_from_e(Eccentricity(ecc))).e @@ -911,4 +912,15 @@ e_from_τ_from_e(ecc::Float64)::Float64 = e_from_τ(τ_from_e(Eccentricity(ecc)) @test m1.m ≈ m2.m atol = 1e-6 end + @testset "autodiff" begin + m = 1000.0 + η = 0.2 + + @testset "parameters and conversions" begin + chirp_mass = (m, η) -> Mass(m, η).Mch + Mch = chirp_mass(m, η) + ∂Mch_∂m, ∂Mch_∂η = gradient(chirp_mass, m, η) + @test ∂Mch_∂m ≈ Mch / m && ∂Mch_∂η ≈ (3 / 5) * Mch / η + end + end end