diff --git a/Project.toml b/Project.toml index c6080e37..c10469c0 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" diff --git a/src/master.jl b/src/master.jl index 0cf63ff6..e402d611 100644 --- a/src/master.jl +++ b/src/master.jl @@ -14,6 +14,7 @@ function master_h(tspan, rho0::Operator, H::AbstractOperator, J; _check_const.(J) _check_const.(Jdagger) check_master(rho0, H, J, Jdagger, rates) + tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan) tmp = copy(rho0) dmaster_(t, rho, drho) = dmaster_h!(drho, H, J, Jdagger, rates, rho, tmp) integrate_master(tspan, dmaster_, rho0, fout; kwargs...) @@ -41,6 +42,7 @@ function master_nh(tspan, rho0::Operator, Hnh::AbstractOperator, J; _check_const.(J) _check_const.(Jdagger) check_master(rho0, Hnh, J, Jdagger, rates) + tspan, rho0 = _promote_time_and_state(rho0, Hnh, J, tspan) tmp = copy(rho0) dmaster_(t, rho, drho) = dmaster_nh!(drho, Hnh, Hnhdagger, J, Jdagger, rates, rho, tmp) integrate_master(tspan, dmaster_, rho0, fout; kwargs...) @@ -86,6 +88,7 @@ function master(tspan, rho0::Operator, H::AbstractOperator, J; _check_const(H) _check_const.(J) _check_const.(Jdagger) + tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan) isreducible = check_master(rho0, H, J, Jdagger, rates) if !isreducible tmp = copy(rho0) @@ -124,6 +127,7 @@ function master(tspan, rho0::Operator, L::SuperOperator; fout=nothing, kwargs... b = GenericBasis(dim) rho_ = Ket(b,reshape(rho0.data, dim)) L_ = Operator(b,b,L.data) + tspan, rho_ = _promote_time_and_state(rho_, L_, tspan) dmaster_(t,rho,drho) = dmaster_liouville!(drho,L_,rho) # Rewrite into density matrix when saving diff --git a/src/schroedinger.jl b/src/schroedinger.jl index 2b785ca4..eddecc6c 100644 --- a/src/schroedinger.jl +++ b/src/schroedinger.jl @@ -122,31 +122,3 @@ function check_schroedinger(psi::Bra, H) check_multiplicable(psi, H) check_samebases(H) end - - -function _promote_time_and_state(u0, H::AbstractOperator, tspan) - Ts = eltype(H) - Tt = real(Ts) - p = Vector{Tt}(undef,0) - u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) - tspan_promote = DiffEqBase.promote_tspan(u0_promote, p, tspan, nothing, Dict{Symbol, Any}()) - return tspan_promote, u0_promote -end -_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan), u0), tspan) - -@inline function DiffEqBase.promote_u0(u0::T, p, t0) where {T<:Union{Bra,Ket}} - u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) - if u0data_promote !== u0.data - u0_promote = T(u0.basis, u0data_promote) - return u0_promote - end - return u0 -end -@inline function DiffEqBase.promote_u0(u0::Operator, p, t0) - u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) - if u0data_promote !== u0.data - u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote) - return u0_promote - end - return u0 -end diff --git a/src/timeevolution_base.jl b/src/timeevolution_base.jl index 8949df0c..c34bb0d9 100644 --- a/src/timeevolution_base.jl +++ b/src/timeevolution_base.jl @@ -121,3 +121,47 @@ macro skiptimechecks(ex) end Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T) + +function _promote_time_and_state(u0, H::AbstractOperator, tspan) + Ts = eltype(H) + Tt = real(Ts) + p = Vector{Tt}(undef,0) + u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) + tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) + return tspan_promote, u0_promote +end +function _promote_time_and_state(u0, H::AbstractOperator, J, tspan) + Ts = DiffEqBase.promote_dual(eltype(H), DiffEqBase.anyeltypedual(J)) + Tt = real(Ts) + p = Vector{Tt}(undef,0) + u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1]) + tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}()) + return tspan_promote, u0_promote +end + +_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan) + +@inline function DiffEqBase.promote_u0(u0::Ket, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Ket(u0.basis, u0data_promote) + return u0_promote + end + return u0 +end +@inline function DiffEqBase.promote_u0(u0::Bra, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Bra(u0.basis, u0data_promote) + return u0_promote + end + return u0 +end +@inline function DiffEqBase.promote_u0(u0::Operator, p, t0) + u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0) + if u0data_promote !== u0.data + u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote) + return u0_promote + end + return u0 +end \ No newline at end of file diff --git a/test/test_ForwardDiff.jl b/test/test_ForwardDiff.jl index 91b889fc..365282f2 100644 --- a/test/test_ForwardDiff.jl +++ b/test/test_ForwardDiff.jl @@ -1,6 +1,7 @@ using Test using OrdinaryDiffEq, QuantumOptics import ForwardDiff +import FiniteDiff # for some caese ForwardDiff.jl returns NaN due to issue with DiffEq.jl. see https://github.com/SciML/DiffEqBase.jl/issues/861 # Here we test; @@ -73,3 +74,35 @@ Ftdop(1.0) @test ForwardDiff.derivative(Ftdop, 1.0) isa Any end # testset + + +@testset "ForwardDiff with `master`" begin + +b = SpinBasis(1//2) +psi0 = spindown(b) +rho0 = dm(psi0) +params = [10.0, -3.0] + +# test to see if parameter propagates through Hamiltonian +H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # Hamiltonian +function cost1(p) # + tf, psif = timeevolution.master((0.0, pi), rho0, H(p), [sigmax(b)]) + return 1 - norm(psif) +end + +forwarddiff1 = ForwardDiff.gradient(cost1, params) +finitediff1 = FiniteDiff.finite_difference_gradient(cost1, params) +@test isapprox(forwarddiff1, finitediff1; atol=1e-5) + +# test to see if parameter propagates through Jump operator +J(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # jump operator +function cost2(p) + tf, psif = timeevolution.master((0.0, pi), rho0, sigmax(b), [J(p)]) + return 1 - norm(psif) +end + +forwarddiff2 = ForwardDiff.gradient(cost2, params) +finitediff2 = FiniteDiff.finite_difference_gradient(cost2, params) +@test isapprox(forwarddiff2, finitediff2; atol=1e-5) + +end