Skip to content

Commit 2661fc4

Browse files
Merge pull request #511 from SciML/enzyme_default
Fix the enzyme default check
2 parents b46fab3 + abc75c3 commit 2661fc4

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- uses: actions/checkout@v2
2424
- uses: julia-actions/setup-julia@v1
2525
with:
26-
version: 1
26+
version: 1.6
2727
- uses: actions/cache@v1
2828
env:
2929
cache-name: cache-artifacts

src/concrete_solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function DiffEqBase._concrete_solve_adjoint(prob::Union{ODEProblem,SDEProblem},
2929
try
3030
Enzyme.autodiff(Enzyme.Duplicated(du, du),
3131
u0,p,prob.tspan[1]) do out,u,_p,t
32-
f(out, u, _p, t)
32+
prob.f(out, u, _p, t)
3333
nothing
3434
end
3535
true
@@ -611,7 +611,7 @@ function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::TrackerAdjoint,
611611
out,pullback = Tracker.forward(tracker_adjoint_forwardpass,u0,p)
612612
function tracker_adjoint_backpass(ybar)
613613
tmp = if eltype(ybar) <: Number
614-
ybar
614+
Array(ybar)
615615
elseif typeof(ybar[1]) <: Array
616616
return Array(ybar)
617617
else

test/complex_adjoints.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,17 @@ dp2 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = QuadratureAd
5151
dp3 = Zygote.gradient((p)->inner_loop(prob, p, loss_fun; sensealg = BacksolveAdjoint()), p)[1]
5252
@test dp1 dp2 dp3
5353
@test eltype(dp1) <: Float64
54+
55+
function fiip(du,u,p,t)
56+
du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
57+
du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
58+
end
59+
p = [1.5,1.0,3.0,1.0]; u0 = [1.0; 1.0]
60+
prob = ODEProblem(fiip,complex(u0),(0.0,10.0),complex(p))
61+
62+
function sum_of_solution(u0, p)
63+
_prob = remake(prob,u0=u0,p=p)
64+
real(sum(solve(_prob,Tsit5(),reltol=1e-6,abstol=1e-6,saveat=0.1)))
65+
end
66+
67+
dx = Zygote.gradient(sum_of_solution, complex(u0), complex(p))

0 commit comments

Comments
 (0)