@@ -5,6 +5,7 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
55 state0 = missing ,
66 forces = missing ,
77 info_level = 0 ,
8+ step_index = eachindex (states),
89 kwarg...
910 )
1011 Jutul. set_global_timer! (extra_timing)
@@ -41,7 +42,8 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
4142 t_solve = @elapsed solve_adjoint_generic! (∇G, X, F, storage, states, timesteps, G,
4243 info_level = info_level,
4344 state0 = state0,
44- forces = forces
45+ forces = forces,
46+ step_index = step_index
4547 )
4648 if info_level > 1
4749 jutul_message (" Adjoints" , " Adjoints solved in $(get_tstr (t_solve)) ." , color = :blue )
@@ -56,6 +58,7 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
5658
5759function solve_adjoint_generic! (∇G, X, F, storage, states, timesteps, G;
5860 info_level = 0 ,
61+ step_index = eachindex (states),
5962 state0 = missing ,
6063 forces = missing
6164 )
@@ -76,6 +79,7 @@ function solve_adjoint_generic!(∇G, X, F, storage, states, timesteps, G;
7679 end
7780 end
7881 if forces isa Vector
82+ forces = forces[step_index]
7983 @assert length (forces) == N " Expected $N forces (one per time-step), got $(length (forces)) ."
8084 end
8185 # Do sparsity detection if not already done.
0 commit comments