Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 160 additions & 111 deletions src/Nonlinear/ReverseAD/graph_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,76 @@ function _compute_gradient_sparsity!(
return
end

"""
_get_nonlinear_child_interactions(
nod::Nonlinear.Node,
num_children::Int,
)

Get the list of nonlinear child interaction pairs for a node.
Returns empty list of tuples `(i, j)` where `i` and `j` are child indices (1-indexed)
that have nonlinear interactions.

For example, for `*` with 2 children, the result is `[(1, 2)]` because children 1
and 2 interact nonlinearly, but children 1 and 1, or 2 and 2, do not.

For functions like `+` or `-`, the result is `[]` since there are no nonlinear
interactions between children.
"""
function _get_nonlinear_child_interactions(
nod::Nonlinear.Node,
num_children::Int,
)::Vector{Tuple{Int,Int}}
if nod.type == Nonlinear.NODE_CALL_UNIVARIATE
@assert num_children == 1
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, nod.index, nothing)
# Univariate operators :+ and :- don't create interactions
if op in (:+, :-)
return Tuple{Int,Int}[]
else
return [(1, 1)]
end
elseif nod.type == Nonlinear.NODE_CALL_MULTIVARIATE
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, nod.index, nothing)

if op in (:+, :-, :ifelse, :min, :max)
# No nonlinear interactions between children
return Tuple{Int,Int}[]
elseif op == :*
# All pairs of distinct children interact nonlinearly
result = Tuple{Int,Int}[]
for i in 1:num_children
for j in 1:(i-1)
push!(result, (j, i))
end
end
return result
elseif op == :/
@assert num_children == 2
# The numerator doesn't have a nonlinear interaction with itself.
return [(1, 2), (2, 2)]
else
# Conservative: assume all pairs interact
result = Tuple{Int,Int}[]
for i in 1:num_children
for j in 1:i
push!(result, (j, i))
end
end
return result
end
else
# Logic and comparison nodes don't generate hessian terms.
# Subexpression nodes are special cased.
return Tuple{Int,Int}[]
end
end

"""
_compute_hessian_sparsity(
nodes::Vector{Nonlinear.Node},
adj,
input_linearity::Vector{Linearity},
indexedset::Coloring.IndexedSet,
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
subexpression_variables::Vector{Vector{Int}},
)
Expand All @@ -193,142 +257,127 @@ Compute the sparsity pattern the Hessian of an expression.
* `subexpression_variables` is the list of all variables which appear in a
subexpression (including recursively).

Idea: consider the (non)linearity of a node *with respect to the output*. The
children of any node which is nonlinear with respect to the output should have
nonlinear interactions, hence nonzeros in the hessian. This is not true in
general, but holds for everything we consider.

A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
like that. By "nonlinear with respect to the output", we mean that the output
depends nonlinearly on the value of the node, regardless of how the node itself
depends on the input.
Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian.
"""
function _compute_hessian_sparsity(
nodes::Vector{Nonlinear.Node},
adj,
input_linearity::Vector{Linearity},
indexedset::Coloring.IndexedSet,
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
subexpression_variables::Vector{Vector{Int}},
)
# So start at the root of the tree and classify the linearity wrt the output.
# For each nonlinear node, do a mini DFS and collect the list of children.
# Add a nonlinear interaction between all children of a nonlinear node.
edge_list = Set{Tuple{Int,Int}}()
nonlinear_wrt_output = fill(false, length(nodes))
children_arr = SparseArrays.rowvals(adj)
stack = Int[]
stack_ignore = Bool[]
nonlinear_group = indexedset
if length(nodes) == 1 && nodes[1].type == Nonlinear.NODE_SUBEXPRESSION
# Subexpression comes in linearly, so append edge_list
for ij in subexpression_edgelist[nodes[1].index]
push!(edge_list, ij)
end
end
for k in 2:length(nodes)

# Stack entry: (node_index, child_group_index)
stack = Tuple{Int,Int}[]
# Map from child_group_index to variable indices
child_group_variables = Dict{Int,Set{Int}}()

for k in 1:length(nodes)
nod = nodes[k]
@assert nod.type != Nonlinear.NODE_MOI_VARIABLE
if nonlinear_wrt_output[k]
continue # already seen this node one way or another
elseif input_linearity[k] == CONSTANT
continue # definitely not nonlinear

if input_linearity[k] == CONSTANT
continue # No hessian contribution from constant nodes
end
@assert !nonlinear_wrt_output[nod.parent]
# check if the parent depends nonlinearly on the value of this node
par = nodes[nod.parent]
if par.type == Nonlinear.NODE_CALL_UNIVARIATE
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, par.index, nothing)
if op === nothing || (op != :+ && op != :-)
nonlinear_wrt_output[k] = true

# Check if this node has nonlinear child interactions
children_idx = SparseArrays.nzrange(adj, k)
num_children = length(children_idx)
interactions = _get_nonlinear_child_interactions(nod, num_children)

if !isempty(interactions)
# This node has nonlinear child interactions, so collect variables from its children
empty!(child_group_variables)

# DFS from all children, tracking child index
for (child_position, cidx) in enumerate(children_idx)
child_node_idx = children_arr[cidx]
push!(stack, (child_node_idx, child_position))
end
elseif par.type == Nonlinear.NODE_CALL_MULTIVARIATE
op = get(
Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS,
par.index,
nothing,
)
if op === nothing
nonlinear_wrt_output[k] = true
elseif op in (:+, :-, :ifelse)
# pass
elseif op == :*
# check if all siblings are constant
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
if !all(
i ->
input_linearity[children_arr[i]] == CONSTANT ||
children_arr[i] == k,
sibling_idx,
)
# at least one sibling isn't constant
nonlinear_wrt_output[k] = true

while length(stack) > 0
r, child_group_idx = pop!(stack)

# Don't traverse into logical conditions or comparisons
if nodes[r].type == Nonlinear.NODE_LOGIC ||
nodes[r].type == Nonlinear.NODE_COMPARISON
continue
end
elseif op == :/
# check if denominator is nonconstant
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
if input_linearity[children_arr[last(sibling_idx)]] != CONSTANT
nonlinear_wrt_output[k] = true

r_children_idx = SparseArrays.nzrange(adj, r)
for cidx in r_children_idx
push!(stack, (children_arr[cidx], child_group_idx))
end

if nodes[r].type == Nonlinear.NODE_VARIABLE
if !haskey(child_group_variables, child_group_idx)
child_group_variables[child_group_idx] = Set{Int}()
end
push!(
child_group_variables[child_group_idx],
nodes[r].index,
)
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
sub_vars = subexpression_variables[nodes[r].index]
if !haskey(child_group_variables, child_group_idx)
child_group_variables[child_group_idx] = Set{Int}()
end
union!(child_group_variables[child_group_idx], sub_vars)
end
else
nonlinear_wrt_output[k] = true
end
end
if nod.type == Nonlinear.NODE_SUBEXPRESSION && !nonlinear_wrt_output[k]
# subexpression comes in linearly, so append edge_list
_add_hessian_edges!(edge_list, interactions, child_group_variables)
elseif nod.type == Nonlinear.NODE_SUBEXPRESSION
for ij in subexpression_edgelist[nod.index]
push!(edge_list, ij)
end
end
if !nonlinear_wrt_output[k]
continue
end
# do a DFS from here, including all children
@assert isempty(stack)
@assert isempty(stack_ignore)
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
for sidx in sibling_idx
push!(stack, children_arr[sidx])
push!(stack_ignore, false)
end
empty!(nonlinear_group)
while length(stack) > 0
r = pop!(stack)
should_ignore = pop!(stack_ignore)
nonlinear_wrt_output[r] = true
if nodes[r].type == Nonlinear.NODE_LOGIC ||
nodes[r].type == Nonlinear.NODE_COMPARISON
# don't count the nonlinear interactions inside
# logical conditions or comparisons
should_ignore = true
end
children_idx = SparseArrays.nzrange(adj, r)
for cidx in children_idx
push!(stack, children_arr[cidx])
push!(stack_ignore, should_ignore)
end
if should_ignore
continue
end
if nodes[r].type == Nonlinear.NODE_VARIABLE
push!(nonlinear_group, nodes[r].index)
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
# append all variables in subexpression
union!(nonlinear_group, subexpression_variables[nodes[r].index])
end
return edge_list
end

"""
_add_hessian_edges!(
edge_list::Set{Tuple{Int,Int}},
interactions::Vector{Tuple{Int,Int}},
child_variables::Dict{Int,Set{Int}},
)

Add hessian edges based on the operator's nonlinear interaction pattern.
"""
function _add_hessian_edges!(
edge_list::Set{Tuple{Int,Int}},
interactions::Vector{Tuple{Int,Int}},
child_variables::Dict{Int,Set{Int}},
)
for (child_i, child_j) in interactions
if child_i == child_j
# Within-child interactions: add all pairs from a single child
if haskey(child_variables, child_i)
vars = child_variables[child_i]
for vi in vars
for vj in vars
i, j = minmax(vi, vj)
push!(edge_list, (j, i))
end
end
end
end
for i_ in 1:nonlinear_group.nnz
i = nonlinear_group.nzidx[i_]
for j_ in 1:nonlinear_group.nnz
j = nonlinear_group.nzidx[j_]
if j > i
continue # Only lower triangle.
else
# Between-child interactions: add pairs from different children
if haskey(child_variables, child_i) &&
haskey(child_variables, child_j)
vars_i = child_variables[child_i]
vars_j = child_variables[child_j]
for vi in vars_i
for vj in vars_j
i, j = minmax(vi, vj)
push!(edge_list, (j, i))
end
end
push!(edge_list, (i, j))
end
end
end
return edge_list
end

"""
Expand Down
1 change: 0 additions & 1 deletion src/Nonlinear/ReverseAD/mathoptinterface_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
subex.nodes,
subex.adj,
linearity,
coloring_storage,
subexpression_edgelist,
subexpression_variables,
)
Expand Down
1 change: 0 additions & 1 deletion src/Nonlinear/ReverseAD/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ struct _FunctionStorage
nodes,
adj,
linearity,
coloring_storage,
subexpression_edgelist,
subexpression_variables,
)
Expand Down
16 changes: 8 additions & 8 deletions test/Nonlinear/ReverseAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,6 @@ function test_linearity()
nodes,
adj,
ret,
indexed_set,
Set{Tuple{Int,Int}}[],
Vector{Int}[],
)
Expand All @@ -585,12 +584,7 @@ function test_linearity()
[1, 2],
)
_test_linearity(:(3 * 4 * ($x + $y)), ReverseAD.LINEAR)
_test_linearity(
:($z * $y),
ReverseAD.NONLINEAR,
Set([(3, 2), (3, 3), (2, 2)]),
[2, 3],
)
_test_linearity(:($z * $y), ReverseAD.NONLINEAR, Set([(3, 2)]), [2, 3])
_test_linearity(:(3 + 4), ReverseAD.CONSTANT)
_test_linearity(:(sin(3) + $x), ReverseAD.LINEAR)
_test_linearity(
Expand Down Expand Up @@ -635,6 +629,12 @@ function test_linearity()
Set([(1, 1)]),
[1],
)
_test_linearity(
:(($x + $y)/$z),
ReverseAD.NONLINEAR,
Set([(3, 3), (3, 2), (3, 1)]),
[1, 2, 3],
)
return
end

Expand Down Expand Up @@ -1416,7 +1416,7 @@ function test_hessian_reinterpret_unsafe()
x_v = ones(5)
MOI.eval_hessian_lagrangian(evaluator, H, x_v, 0.0, [1.0, 1.0])
@test count(isapprox.(H, 1.0; atol = 1e-8)) == 3
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 6
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 5
@test sort(H_s[round.(Bool, H)]) == [(3, 1), (3, 2), (5, 4)]
return
end
Expand Down
Loading