Skip to content

Commit

Permalink
promote dual in master equations
Browse files Browse the repository at this point in the history
  • Loading branch information
apkille committed Aug 9, 2024
1 parent be847dd commit 247da73
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 28 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Expand Down
4 changes: 4 additions & 0 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function master_h(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, H, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_h!(drho, H, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -41,6 +42,7 @@ function master_nh(tspan, rho0::Operator, Hnh::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, Hnh, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, Hnh, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_nh!(drho, Hnh, Hnhdagger, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -86,6 +88,7 @@ function master(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const(H)
_check_const.(J)
_check_const.(Jdagger)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
isreducible = check_master(rho0, H, J, Jdagger, rates)
if !isreducible
tmp = copy(rho0)
Expand Down Expand Up @@ -124,6 +127,7 @@ function master(tspan, rho0::Operator, L::SuperOperator; fout=nothing, kwargs...
b = GenericBasis(dim)
rho_ = Ket(b,reshape(rho0.data, dim))
L_ = Operator(b,b,L.data)
tspan, rho_ = _promote_time_and_state(rho_, L_, tspan)
dmaster_(t,rho,drho) = dmaster_liouville!(drho,L_,rho)

# Rewrite into density matrix when saving
Expand Down
28 changes: 0 additions & 28 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,31 +122,3 @@ function check_schroedinger(psi::Bra, H)
check_multiplicable(psi, H)
check_samebases(H)
end


function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end
_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan), u0), tspan)

@inline function DiffEqBase.promote_u0(u0::T, p, t0) where {T<:Union{Bra,Ket}}
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = T(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Operator, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote)
return u0_promote
end
return u0
end
44 changes: 44 additions & 0 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,47 @@ macro skiptimechecks(ex)
end

Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end
function _promote_time_and_state(u0, H::AbstractOperator, J, tspan)
Ts = DiffEqBase.promote_dual(eltype(H), DiffEqBase.anyeltypedual(J))
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end

_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan)

@inline function DiffEqBase.promote_u0(u0::Ket, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Ket(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Bra, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Bra(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Operator, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote)
return u0_promote
end
return u0
end
33 changes: 33 additions & 0 deletions test/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using OrdinaryDiffEq, QuantumOptics
import ForwardDiff
import FiniteDiff

# for some caese ForwardDiff.jl returns NaN due to issue with DiffEq.jl. see https://github.com/SciML/DiffEqBase.jl/issues/861
# Here we test;
Expand Down Expand Up @@ -73,3 +74,35 @@ Ftdop(1.0)
@test ForwardDiff.derivative(Ftdop, 1.0) isa Any

end # testset


@testset "ForwardDiff with `master`" begin

b = SpinBasis(1//2)
psi0 = spindown(b)
rho0 = dm(psi0)
params = [10.0, -3.0]

# test to see if parameter propagates through Hamiltonian
H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # Hamiltonian
function cost1(p) #
tf, psif = timeevolution.master((0.0, pi), rho0, H(p), [sigmax(b)])
return 1 - norm(psif)
end

forwarddiff1 = ForwardDiff.gradient(cost1, params)
finitediff1 = FiniteDiff.finite_difference_gradient(cost1, params)
@test isapprox(forwarddiff1, finitediff1; atol=1e-5)

# test to see if parameter propagates through Jump operator
J(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # jump operator
function cost2(p)
tf, psif = timeevolution.master((0.0, pi), rho0, sigmax(b), [J(p)])
return 1 - norm(psif)
end

forwarddiff2 = ForwardDiff.gradient(cost2, params)
finitediff2 = FiniteDiff.finite_difference_gradient(cost2, params)
@test isapprox(forwarddiff2, finitediff2; atol=1e-5)

end

0 comments on commit 247da73

Please sign in to comment.