diff --git a/src/Nonlinear/ReverseAD/graph_tools.jl b/src/Nonlinear/ReverseAD/graph_tools.jl index 7e9c366c26..af9813a970 100644 --- a/src/Nonlinear/ReverseAD/graph_tools.jl +++ b/src/Nonlinear/ReverseAD/graph_tools.jl @@ -175,12 +175,76 @@ function _compute_gradient_sparsity!( return end +""" + _get_nonlinear_child_interactions( + nod::Nonlinear.Node, + num_children::Int, + ) + +Get the list of nonlinear child interaction pairs for a node. +Returns a list of tuples `(i, j)` where `i` and `j` are child indices (1-indexed) +that have nonlinear interactions. + +For example, for `*` with 2 children, the result is `[(1, 2)]` because children 1 +and 2 interact nonlinearly, but children 1 and 1, or 2 and 2, do not. + +For functions like `+` or `-`, the result is `[]` since there are no nonlinear +interactions between children. +""" +function _get_nonlinear_child_interactions( + nod::Nonlinear.Node, + num_children::Int, +)::Vector{Tuple{Int,Int}} + if nod.type == Nonlinear.NODE_CALL_UNIVARIATE + @assert num_children == 1 + op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, nod.index, nothing) + # Univariate operators :+ and :- don't create interactions + if op in (:+, :-) + return Tuple{Int,Int}[] + else + return [(1, 1)] + end + elseif nod.type == Nonlinear.NODE_CALL_MULTIVARIATE + op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, nod.index, nothing) + + if op in (:+, :-, :ifelse, :min, :max) + # No nonlinear interactions between children + return Tuple{Int,Int}[] + elseif op == :* + # All pairs of distinct children interact nonlinearly + result = Tuple{Int,Int}[] + for i in 1:num_children + for j in 1:(i-1) + push!(result, (j, i)) + end + end + return result + elseif op == :/ + @assert num_children == 2 + # The numerator doesn't have a nonlinear interaction with itself. 
+ return [(1, 2), (2, 2)] + else + # Conservative: for any other operator, assume all pairs (including i == j) interact + result = Tuple{Int,Int}[] + for i in 1:num_children + for j in 1:i + push!(result, (j, i)) + end + end + return result + end + else + # Logic and comparison nodes don't generate hessian terms. + # Subexpression nodes are special cased. + return Tuple{Int,Int}[] + end +end + """ _compute_hessian_sparsity( nodes::Vector{Nonlinear.Node}, adj, input_linearity::Vector{Linearity}, - indexedset::Coloring.IndexedSet, subexpression_edgelist::Vector{Set{Tuple{Int,Int}}}, subexpression_variables::Vector{Vector{Int}}, ) @@ -193,142 +257,127 @@ Compute the sparsity pattern the Hessian of an expression. * `subexpression_variables` is the list of all variables which appear in a subexpression (including recursively). -Idea: consider the (non)linearity of a node *with respect to the output*. The -children of any node which is nonlinear with respect to the output should have -nonlinear interactions, hence nonzeros in the hessian. This is not true in -general, but holds for everything we consider. - -A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions -like that. By "nonlinear with respect to the output", we mean that the output -depends nonlinearly on the value of the node, regardless of how the node itself -depends on the input. +Returns a `Set{Tuple{Int,Int}}` containing the nonzero entries of the Hessian. """ function _compute_hessian_sparsity( nodes::Vector{Nonlinear.Node}, adj, input_linearity::Vector{Linearity}, - indexedset::Coloring.IndexedSet, subexpression_edgelist::Vector{Set{Tuple{Int,Int}}}, subexpression_variables::Vector{Vector{Int}}, ) - # So start at the root of the tree and classify the linearity wrt the output. - # For each nonlinear node, do a mini DFS and collect the list of children. - # Add a nonlinear interaction between all children of a nonlinear node. 
edge_list = Set{Tuple{Int,Int}}() - nonlinear_wrt_output = fill(false, length(nodes)) children_arr = SparseArrays.rowvals(adj) - stack = Int[] - stack_ignore = Bool[] - nonlinear_group = indexedset - if length(nodes) == 1 && nodes[1].type == Nonlinear.NODE_SUBEXPRESSION - # Subexpression comes in linearly, so append edge_list - for ij in subexpression_edgelist[nodes[1].index] - push!(edge_list, ij) - end - end - for k in 2:length(nodes) + + # Stack entry: (node_index, child_group_index) + stack = Tuple{Int,Int}[] + # Map from child_group_index to the variable indices found beneath that child + child_group_variables = Dict{Int,Set{Int}}() + + for k in 1:length(nodes) nod = nodes[k] @assert nod.type != Nonlinear.NODE_MOI_VARIABLE - if nonlinear_wrt_output[k] - continue # already seen this node one way or another - elseif input_linearity[k] == CONSTANT - continue # definitely not nonlinear + + if input_linearity[k] == CONSTANT + continue # No hessian contribution from constant nodes end - @assert !nonlinear_wrt_output[nod.parent] - # check if the parent depends nonlinearly on the value of this node - par = nodes[nod.parent] - if par.type == Nonlinear.NODE_CALL_UNIVARIATE - op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, par.index, nothing) - if op === nothing || (op != :+ && op != :-) - nonlinear_wrt_output[k] = true + + # Check if this node has nonlinear child interactions + children_idx = SparseArrays.nzrange(adj, k) + num_children = length(children_idx) + interactions = _get_nonlinear_child_interactions(nod, num_children) + + if !isempty(interactions) + # This node has nonlinear child interactions, so collect variables from its children + empty!(child_group_variables) + + # DFS from all children, tracking child index + for (child_position, cidx) in enumerate(children_idx) + child_node_idx = children_arr[cidx] + push!(stack, (child_node_idx, child_position)) end - elseif par.type == Nonlinear.NODE_CALL_MULTIVARIATE - op = get( - Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, - 
par.index, - nothing, 
- ) - if op === nothing - nonlinear_wrt_output[k] = true - elseif op in (:+, :-, :ifelse) - # pass - elseif op == :* - # check if all siblings are constant - sibling_idx = SparseArrays.nzrange(adj, nod.parent) - if !all( - i -> - input_linearity[children_arr[i]] == CONSTANT || - children_arr[i] == k, - sibling_idx, - ) - # at least one sibling isn't constant - nonlinear_wrt_output[k] = true + + while length(stack) > 0 + r, child_group_idx = pop!(stack) + + # Don't traverse into logical conditions or comparisons + if nodes[r].type == Nonlinear.NODE_LOGIC || + nodes[r].type == Nonlinear.NODE_COMPARISON + continue end - elseif op == :/ - # check if denominator is nonconstant - sibling_idx = SparseArrays.nzrange(adj, nod.parent) - if input_linearity[children_arr[last(sibling_idx)]] != CONSTANT - nonlinear_wrt_output[k] = true + + r_children_idx = SparseArrays.nzrange(adj, r) + for cidx in r_children_idx + push!(stack, (children_arr[cidx], child_group_idx)) + end + + if nodes[r].type == Nonlinear.NODE_VARIABLE + if !haskey(child_group_variables, child_group_idx) + child_group_variables[child_group_idx] = Set{Int}() + end + push!( + child_group_variables[child_group_idx], + nodes[r].index, + ) + elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION + sub_vars = subexpression_variables[nodes[r].index] + if !haskey(child_group_variables, child_group_idx) + child_group_variables[child_group_idx] = Set{Int}() + end + union!(child_group_variables[child_group_idx], sub_vars) end - else - nonlinear_wrt_output[k] = true end - end - if nod.type == Nonlinear.NODE_SUBEXPRESSION && !nonlinear_wrt_output[k] - # subexpression comes in linearly, so append edge_list + _add_hessian_edges!(edge_list, interactions, child_group_variables) + elseif nod.type == Nonlinear.NODE_SUBEXPRESSION for ij in subexpression_edgelist[nod.index] push!(edge_list, ij) end end - if !nonlinear_wrt_output[k] - continue - end - # do a DFS from here, including all children - @assert isempty(stack) - @assert 
isempty(stack_ignore) - sibling_idx = SparseArrays.nzrange(adj, nod.parent) - for sidx in sibling_idx - push!(stack, children_arr[sidx]) - push!(stack_ignore, false) - end - empty!(nonlinear_group) - while length(stack) > 0 - r = pop!(stack) - should_ignore = pop!(stack_ignore) - nonlinear_wrt_output[r] = true - if nodes[r].type == Nonlinear.NODE_LOGIC || - nodes[r].type == Nonlinear.NODE_COMPARISON - # don't count the nonlinear interactions inside - # logical conditions or comparisons - should_ignore = true - end - children_idx = SparseArrays.nzrange(adj, r) - for cidx in children_idx - push!(stack, children_arr[cidx]) - push!(stack_ignore, should_ignore) - end - if should_ignore - continue - end - if nodes[r].type == Nonlinear.NODE_VARIABLE - push!(nonlinear_group, nodes[r].index) - elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION - # append all variables in subexpression - union!(nonlinear_group, subexpression_variables[nodes[r].index]) + end + return edge_list +end + +""" + _add_hessian_edges!( + edge_list::Set{Tuple{Int,Int}}, + interactions::Vector{Tuple{Int,Int}}, + child_variables::Dict{Int,Set{Int}}, + ) + +Add hessian edges, stored as `(max, min)` index pairs, for each child pair in `interactions`. +""" +function _add_hessian_edges!( + edge_list::Set{Tuple{Int,Int}}, + interactions::Vector{Tuple{Int,Int}}, + child_variables::Dict{Int,Set{Int}}, +) + for (child_i, child_j) in interactions + if child_i == child_j + # Within-child interactions: add all pairs from a single child + if haskey(child_variables, child_i) + vars = child_variables[child_i] + for vi in vars + for vj in vars + i, j = minmax(vi, vj) + push!(edge_list, (j, i)) + end + end end - end - for i_ in 1:nonlinear_group.nnz - i = nonlinear_group.nzidx[i_] - for j_ in 1:nonlinear_group.nnz - j = nonlinear_group.nzidx[j_] - if j > i - continue # Only lower triangle. 
+ else + # Between-child interactions: add pairs from different children + if haskey(child_variables, child_i) && + haskey(child_variables, child_j) + vars_i = child_variables[child_i] + vars_j = child_variables[child_j] + for vi in vars_i + for vj in vars_j + i, j = minmax(vi, vj) + push!(edge_list, (j, i)) + end end - push!(edge_list, (i, j)) end end end - return edge_list end """ diff --git a/src/Nonlinear/ReverseAD/mathoptinterface_api.jl b/src/Nonlinear/ReverseAD/mathoptinterface_api.jl index 1766c25e73..f67d9caf8e 100644 --- a/src/Nonlinear/ReverseAD/mathoptinterface_api.jl +++ b/src/Nonlinear/ReverseAD/mathoptinterface_api.jl @@ -93,7 +93,6 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol}) subex.nodes, subex.adj, linearity, - coloring_storage, subexpression_edgelist, subexpression_variables, ) diff --git a/src/Nonlinear/ReverseAD/types.jl b/src/Nonlinear/ReverseAD/types.jl index fc599accb0..ce3e807b94 100644 --- a/src/Nonlinear/ReverseAD/types.jl +++ b/src/Nonlinear/ReverseAD/types.jl @@ -91,7 +91,6 @@ struct _FunctionStorage nodes, adj, linearity, - coloring_storage, subexpression_edgelist, subexpression_variables, ) diff --git a/test/Nonlinear/ReverseAD.jl b/test/Nonlinear/ReverseAD.jl index 608f28d4a0..32e1852e2b 100644 --- a/test/Nonlinear/ReverseAD.jl +++ b/test/Nonlinear/ReverseAD.jl @@ -561,7 +561,6 @@ function test_linearity() nodes, adj, ret, - indexed_set, Set{Tuple{Int,Int}}[], Vector{Int}[], ) @@ -585,12 +584,7 @@ function test_linearity() [1, 2], ) _test_linearity(:(3 * 4 * ($x + $y)), ReverseAD.LINEAR) - _test_linearity( - :($z * $y), - ReverseAD.NONLINEAR, - Set([(3, 2), (3, 3), (2, 2)]), - [2, 3], - ) + _test_linearity(:($z * $y), ReverseAD.NONLINEAR, Set([(3, 2)]), [2, 3]) _test_linearity(:(3 + 4), ReverseAD.CONSTANT) _test_linearity(:(sin(3) + $x), ReverseAD.LINEAR) _test_linearity( @@ -635,6 +629,12 @@ function test_linearity() Set([(1, 1)]), [1], ) + _test_linearity( + :(($x + $y)/$z), + ReverseAD.NONLINEAR, 
+ Set([(3, 3), (3, 2), (3, 1)]), + [1, 2, 3], + ) return end @@ -1416,7 +1416,7 @@ function test_hessian_reinterpret_unsafe() x_v = ones(5) MOI.eval_hessian_lagrangian(evaluator, H, x_v, 0.0, [1.0, 1.0]) @test count(isapprox.(H, 1.0; atol = 1e-8)) == 3 - @test count(isapprox.(H, 0.0; atol = 1e-8)) == 6 + @test count(isapprox.(H, 0.0; atol = 1e-8)) == 5 @test sort(H_s[round.(Bool, H)]) == [(3, 1), (3, 2), (5, 4)] return end