Skip to content

Commit 932cfdd

Browse files
committed
Improve handling of substates in adjoint
1 parent 7decc01 commit 932cfdd

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

src/DictOptimization/optimization.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ function solve_and_differentiate_for_optimization(x, dopt::DictParameters, setup
4444
end
4545
Jutul.AdjointsDI.solve_adjoint_generic!(
4646
g, x, setup_from_vector, S, states, dt, objective,
47+
step_index = step_ix
4748
)
4849
# g = Jutul.AdjointsDI.solve_adjoint_generic(
4950
# x, setup_from_vector, states, dt, objective,

src/ad/AdjointsDI/adjoints.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
55
state0 = missing,
66
forces = missing,
77
info_level = 0,
8+
step_index = eachindex(states),
89
kwarg...
910
)
1011
Jutul.set_global_timer!(extra_timing)
@@ -41,7 +42,8 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
4142
t_solve = @elapsed solve_adjoint_generic!(∇G, X, F, storage, states, timesteps, G,
4243
info_level = info_level,
4344
state0 = state0,
44-
forces = forces
45+
forces = forces,
46+
step_index = step_index
4547
)
4648
if info_level > 1
4749
jutul_message("Adjoints", "Adjoints solved in $(get_tstr(t_solve)).", color = :blue)
@@ -56,6 +58,7 @@ function solve_adjoint_generic(X, F, states, reports_or_timesteps, G;
5658

5759
function solve_adjoint_generic!(∇G, X, F, storage, states, timesteps, G;
5860
info_level = 0,
61+
step_index = eachindex(states),
5962
state0 = missing,
6063
forces = missing
6164
)
@@ -76,6 +79,7 @@ function solve_adjoint_generic!(∇G, X, F, storage, states, timesteps, G;
7679
end
7780
end
7881
if forces isa Vector
82+
forces = forces[step_index]
7983
@assert length(forces) == N "Expected $N forces (one per time-step), got $(length(forces))."
8084
end
8185
# Do sparsity detection if not already done.

0 commit comments

Comments
 (0)