diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl
index 6c0ba5f8..94d41436 100644
--- a/src/DynamicExpressions.jl
+++ b/src/DynamicExpressions.jl
@@ -19,6 +19,7 @@ using DispatchDoctor: @stable, @unstable
     include("OperatorEnumConstruction.jl")
     include("Expression.jl")
     include("ExpressionAlgebra.jl")
+    include("SpecialOperators.jl")
     include("Random.jl")
     include("Parse.jl")
     include("ParametricExpression.jl")
@@ -76,6 +77,7 @@ import .StringsModule: get_op_name
 @reexport import .EvaluateModule:
     eval_tree_array, differentiable_eval_tree_array, EvalOptions
 import .EvaluateModule: ArrayBuffer
+@reexport import .SpecialOperatorsModule: AssignOperator, WhileOperator
 @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
 @reexport import .ChainRulesModule: NodeTangent, extract_gradient
 @reexport import .SimplifyModule: combine_operators, simplify_tree!
diff --git a/src/Evaluate.jl b/src/Evaluate.jl
index a03d0140..b426c3f6 100644
--- a/src/Evaluate.jl
+++ b/src/Evaluate.jl
@@ -10,6 +10,12 @@ import ..NodeUtilsModule: is_constant
 import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
 import ..ValueInterfaceModule: is_valid, is_valid_array
 
+# Overloaded by SpecialOperators.jl:
+function any_special_operators end
+function special_operator end
+function deg2_eval_special end
+function deg1_eval_special end
+
 const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
 
 macro return_on_nonfinite_val(eval_options, val, X)
@@ -218,6 +224,10 @@ function eval_tree_array(
             "Bumper and LoopVectorization features are only compatible with numeric element types",
         )
     end
+    if any_special_operators(operators)
+        cX = copy(cX)
+        # TODO: This is dangerous if the element type is mutable
+    end
     if _eval_options.bumper isa Val{true}
         return bumper_eval_tree_array(tree, cX, operators, _eval_options)
     end
@@ -264,7 +274,7 @@ function _eval_tree_array(
     # we can just return the constant result.
     if tree.degree == 0
         return deg0_eval(tree, cX, eval_options)
-    elseif is_constant(tree)
+    elseif !any_special_operators(operators) && is_constant(tree)
         # Speed hack for constant trees.
         const_result = dispatch_constant_tree(tree, operators)::ResultOk{T}
         !const_result.ok &&
@@ -329,6 +339,9 @@ end
     long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
     if long_compilation_time
         return quote
+            op = operators.binops[op_idx]
+            special_operator(op) &&
+                return deg2_eval_special(tree, cX, operators, op, eval_options)
             result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
             !result_l.ok && return result_l
             @return_on_nonfinite_array(eval_options, result_l.x)
@@ -336,7 +349,7 @@ end
             !result_r.ok && return result_r
             @return_on_nonfinite_array(eval_options, result_r.x)
             # op(x, y), for any x or y
-            deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options)
+            deg2_eval(result_l.x, result_r.x, op, eval_options)
         end
     end
     return quote
@@ -344,7 +357,9 @@ end
             $nbin,
             i -> i == op_idx,
             i -> let op = operators.binops[i]
-                if tree.l.degree == 0 && tree.r.degree == 0
+                if special_operator(op)
+                    deg2_eval_special(tree, cX, operators, op, eval_options)
+                elseif tree.l.degree == 0 && tree.r.degree == 0
                     deg2_l0_r0_eval(tree, cX, op, eval_options)
                 elseif tree.r.degree == 0
                     result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
@@ -352,7 +367,9 @@ end
                     @return_on_nonfinite_array(eval_options, result_l.x)
                     # op(x, y), where y is a constant or variable but x is not.
                     deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
-                elseif tree.l.degree == 0
+                elseif !any_special_operators(operators) && tree.l.degree == 0
+                    # This branch changes the execution order, so we cannot
+                    # use this branch when special operators are present.
                     result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
                     !result_r.ok && return result_r
                     @return_on_nonfinite_array(eval_options, result_r.x)
@@ -383,10 +400,13 @@ end
     long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
     if long_compilation_time
         return quote
+            op = operators.unaops[op_idx]
+            special_operator(op) &&
+                return deg1_eval_special(tree, cX, operators, op, eval_options)
             result = _eval_tree_array(tree.l, cX, operators, eval_options)
             !result.ok && return result
             @return_on_nonfinite_array(eval_options, result.x)
-            deg1_eval(result.x, operators.unaops[op_idx], eval_options)
+            deg1_eval(result.x, op, eval_options)
         end
     end
     # This @nif lets us generate an if statement over choice of operator,
@@ -396,13 +416,20 @@ end
             $nuna,
             i -> i == op_idx,
             i -> let op = operators.unaops[i]
-                if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
+                if special_operator(op)
+                    deg1_eval_special(tree, cX, operators, op, eval_options)
+                elseif !any_special_operators(operators) &&
+                    tree.l.degree == 2 &&
+                    tree.l.l.degree == 0 &&
+                    tree.l.r.degree == 0
                     # op(op2(x, y)), where x, y, z are constants or variables.
                     l_op_idx = tree.l.op
                     dispatch_deg1_l2_ll0_lr0_eval(
                         tree, cX, op, l_op_idx, operators.binops, eval_options
                     )
-                elseif tree.l.degree == 1 && tree.l.l.degree == 0
+                elseif !any_special_operators(operators) &&
+                    tree.l.degree == 1 &&
+                    tree.l.l.degree == 0
                     # op(op2(x)), where x is a constant or variable.
                     l_op_idx = tree.l.op
                     dispatch_deg1_l1_ll0_eval(
diff --git a/src/Simplify.jl b/src/Simplify.jl
index cf2592a2..9c1f783b 100644
--- a/src/Simplify.jl
+++ b/src/Simplify.jl
@@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
 import ..NodeUtilsModule: tree_mapreduce, is_node_constant
 import ..OperatorEnumModule: AbstractOperatorEnum
 import ..ValueInterfaceModule: is_valid
+import ..EvaluateModule: any_special_operators
 
 _una_op_kernel(f::F, l::T) where {F,T} = f(l)
 _bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
@@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
 # This is only defined for `Node` as it is not possible for, e.g.,
 # `GraphNode`.
 function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
+    # Skip simplification if special operators are in use
+    any_special_operators(operators) && return tree
+    return _combine_operators(tree, operators)
+end
+
+function _combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
     # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
     # ((const + var) + const) => (const + var)
     # ((const * var) * const) => (const * var)
@@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
     if tree.degree == 0
         return tree
     elseif tree.degree == 1
-        tree.l = combine_operators(tree.l, operators)
+        tree.l = _combine_operators(tree.l, operators)
     elseif tree.degree == 2
-        tree.l = combine_operators(tree.l, operators)
-        tree.r = combine_operators(tree.r, operators)
+        tree.l = _combine_operators(tree.l, operators)
+        tree.r = _combine_operators(tree.r, operators)
     end
 
     top_level_constant =
@@ -123,6 +130,11 @@ end
 
 # Simplify tree
 function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
+    # Skip simplification if special operators are in use
+    if any_special_operators(operators)
+        return tree
+    end
+
     return tree_mapreduce(
         identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
     )
diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl
new file mode 100644
index 00000000..e5b88ee3
--- /dev/null
+++ b/src/SpecialOperators.jl
@@ -0,0 +1,102 @@
+module SpecialOperatorsModule
+
+using ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
+using ..EvaluateModule:
+    _eval_tree_array, @return_on_nonfinite_array, deg2_eval, ResultOk, get_filled_array
+using ..ExpressionModule: AbstractExpression
+using ..ExpressionAlgebraModule: @declare_expression_operator
+
+import ..EvaluateModule:
+    special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
+import ..StringsModule: get_op_name
+
+# Trait: `true` for operators that Evaluate.jl hands to `deg1_eval_special` /
+# `deg2_eval_special` instead of its regular kernels.
+# Use this to customize evaluation behavior for operators:
+@inline special_operator(::Type{F}) where {F} = false
+@inline special_operator(::F) where {F} = special_operator(F)
+
+# Unary operator: evaluates its argument, stores the values in row
+# `target_register` of the feature matrix, and returns them.
+Base.@kwdef struct AssignOperator <: Function
+    target_register::Int
+end
+@declare_expression_operator((op::AssignOperator), 1)
+@inline special_operator(::Type{AssignOperator}) = true
+get_op_name(o::AssignOperator) = "ASSIGN_OP:{FEATURE_" * string(o.target_register) * "}"
+
+function deg1_eval_special(tree, cX, operators, op::AssignOperator, eval_options)
+    result = _eval_tree_array(tree.l, cX, operators, eval_options)
+    !result.ok && return result
+    @return_on_nonfinite_array(eval_options, result.x)
+    # Broadcasted row write: bounds-checked, unlike an `@inbounds` index loop.
+    cX[op.target_register, :] .= result.x
+    return result
+end
+
+# Binary operator `while(cond, body)`: re-evaluates `body` on the columns where
+# `cond > 0` until no column is active, or fails after `max_iters` iterations.
+Base.@kwdef struct WhileOperator <: Function
+    max_iters::Int = 100
+end
+
+@declare_expression_operator((op::WhileOperator), 2)
+@inline special_operator(::Type{WhileOperator}) = true
+get_op_name(o::WhileOperator) = "while"
+
+# TODO: Need to void any instance of buffer when using while loop.
+function deg2_eval_special(tree, cX, operators, op::WhileOperator, eval_options)
+    cond = tree.l
+    body = tree.r
+    # `mask[j]` stays true while column `j` still satisfies the condition
+    mask = trues(size(cX, 2))
+    X = @view cX[:, mask]
+    # Initialize the result array for all columns
+    result_array = get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
+    body_result = ResultOk(result_array, true)
+
+    for _ in 1:(op.max_iters)
+        cond_result = _eval_tree_array(cond, X, operators, eval_options)
+        !cond_result.ok && return cond_result
+        @return_on_nonfinite_array(eval_options, cond_result.x)
+
+        # Compare against a zero of the evaluated element type, not a `Float64`
+        new_mask = cond_result.x .> zero(eltype(cond_result.x))
+        any(new_mask) || return body_result
+
+        # Track which columns are still active
+        mask[mask] .= new_mask
+        X = @view cX[:, mask]
+
+        # Evaluate just for active columns
+        iter_result = _eval_tree_array(body, X, operators, eval_options)
+        !iter_result.ok && return iter_result
+
+        # Update the corresponding elements in the result array
+        body_result.x[mask] .= iter_result.x
+        @return_on_nonfinite_array(eval_options, body_result.x)
+    end
+
+    # We passed max_iters, so this result is invalid
+    return ResultOk(body_result.x, false)
+end
+
+# Defined after every in-module `special_operator` method: per the Julia manual, a
+# generated function may only call methods defined before its own definition.
+# NOTE(review): `special_operator` methods added later (e.g. by downstream code)
+# may therefore go unnoticed here — confirm before advertising it as extensible.
+@generated function any_special_operators(
+    ::Union{O,Type{O}}
+) where {B,U,O<:OperatorEnum{B,U}}
+    return any(special_operator, B.types) || any(special_operator, U.types)
+end
+
+# Fallback so callers holding any other `AbstractOperatorEnum` (the guards in
+# `simplify_tree!` / `combine_operators`) do not hit a `MethodError`.
+function any_special_operators(
+    ::Union{AbstractOperatorEnum,Type{<:AbstractOperatorEnum}}
+)
+    return false
+end
+
+end
diff --git a/src/Strings.jl b/src/Strings.jl
index a13eae31..a6a9bbcf 100644
--- a/src/Strings.jl
+++ b/src/Strings.jl
@@ -56,6 +56,20 @@ end
     end
 end
 
+# Operator names may carry `{FEATURE_<i>}` placeholders (e.g. `get_op_name` of an
+# `AssignOperator`); each one is rewritten to `f_variable(i, variable_names)`.
+const FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH = length("{FEATURE_")
+function replace_feature_placeholders(s::String, f_variable::Function, variable_names)
+    return replace(
+        s,
+        r"\{FEATURE_(\d+)\}" =>
+            m -> f_variable(
+                parse(Int, m[(begin + FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH):(end - 1)]),
+                variable_names,
+            ),
+    )
+end
+
 # Can overload these for custom behavior:
 needs_brackets(val::Real) = false
 needs_brackets(val::AbstractArray) = false
@@ -104,12 +118,33 @@ function combine_op_with_inputs(op, l, r)::Vector{Char}
     end
 end
 function combine_op_with_inputs(op, l)
-    # "op(l)"
-    out = copy(op)
-    push!(out, '(')
-    append!(out, strip_brackets(l))
-    push!(out, ')')
-    return out
+    # Check if this is an assignment operator with our special prefix
+    op_str = String(op)
+    if startswith(op_str, "ASSIGN_OP:")
+        # Extract the variable name from the operator name
+        var_name = op_str[11:end]
+        # Format: (var ← expr)
+        out = ['(']
+        append!(out, collect(var_name))
+        append!(out, collect(" ← "))
+        # Ensure the expression is always wrapped in parentheses for clarity
+        if !isempty(l) && first(l) == '(' && last(l) == ')'
+            append!(out, l)
+        else
+            push!(out, '(')
+            append!(out, strip_brackets(l))
+            push!(out, ')')
+        end
+        push!(out, ')')
+        return out
+    else
+        # Regular unary operator: "op(l)"
+        out = copy(op)
+        push!(out, '(')
+        append!(out, strip_brackets(l))
+        push!(out, ')')
+        return out
+    end
 end
 
 """
@@ -179,7 +214,9 @@ function string_tree(
             c
         end,
     )
-    return String(strip_brackets(raw_output))
+    string_output = String(strip_brackets(raw_output))
+    string_output = replace_feature_placeholders(string_output, f_variable, variable_names)
+    return string_output
 end
 
 # Print an equation
diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl
new file mode 100644
index 00000000..1eb36cac
--- /dev/null
+++ b/test/test_special_operators.jl
@@ -0,0 +1,171 @@
+using TestItems: @testitem
+
+@testitem "AssignOperator basic functionality" begin
+    using DynamicExpressions
+    using Test
+    using Random
+
+    # Define operators and variable names
+    assign_x2 = AssignOperator(; target_register=2)
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_x2]
+    )
+    variable_names = ["x1", "x2", "x3", "x4", "x5"]
+
+    # Test data
+    X = zeros(Float64, 2, 3)
+    X[1, :] .= [1.0, 2.0, 3.0]
+    X[2, :] .= [0.5, 1.5, 2.5]
+
+    # 1. Basic register assignment - assign constant to register 2,
+    # and then add the return to `x2` (which should now be 3.0!)
+    x1 = Expression(Node(; feature=1); operators, variable_names)
+    x2 = Expression(Node(; feature=2); operators, variable_names)
+    assign_expr = assign_x2(0.0 * x1 + 3.0) + x2
+
+    @test string_tree(assign_expr) == "(x2 ← ((0.0 * x1) + 3.0)) + x2"
+
+    # We should see that x2 will become 3.0 _before_ adding
+    result, completed = eval_tree_array(assign_expr, X)
+    @test completed == true
+    @test all(==(6.0), result)
+
+    # We should also see that X is not changed by this
+    @test X[2, :] == [0.5, 1.5, 2.5]
+
+    # But, with the reverse order, we get the x2 _before_ it was reassigned
+    assign_expr_reverse = x2 + assign_x2(0.0 * x1 + 3.0)
+    @test string_tree(assign_expr_reverse) == "x2 + (x2 ← ((0.0 * x1) + 3.0))"
+    result, completed = eval_tree_array(assign_expr_reverse, X)
+    @test completed == true
+    @test result == [3.5, 4.5, 5.5]
+end
+
+@testitem "AssignOperator with self-assignment" begin
+    using DynamicExpressions
+    using Test
+    using Random
+
+    assign_x1 = AssignOperator(; target_register=1)
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_x1]
+    )
+    variable_names = ["a", "b", "c"]
+    X = rand(Float64, 2, 10)
+
+    x1 = Expression(Node(; feature=1); operators, variable_names)
+
+    expr = assign_x1(assign_x1(x1 * 2) + x1)
+    @test string_tree(expr) == "a ← ((a ← (a * 2.0)) + a)"
+
+    result, completed = eval_tree_array(expr, X)
+    @test completed == true
+    @test result == X[1, :] .* 4.0
+end
+
+@testitem "Simplification disabled with special operators" begin
+    using DynamicExpressions
+    using Test
+
+    # Create operators with and without special operator
+    assign_op = AssignOperator(; target_register=1)
+    special_operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op]
+    )
+    normal_operators = OperatorEnum(;
+        binary_operators=[+, -, *, /], unary_operators=[sin, cos]
+    )
+
+    @test DynamicExpressions.SpecialOperatorsModule.any_special_operators(special_operators)
+    @test !DynamicExpressions.SpecialOperatorsModule.any_special_operators(normal_operators)
+
+    # Create expressions using the Expression constructor
+    const_val = 2.0
+
+    # Simple expression that should simplify: 2.0 + 2.0
+    raw_node = Node(; op=1, l=Node(; val=const_val), r=Node(; val=const_val))
+    simple_expr = Expression(copy(raw_node); operators=normal_operators)
+    simple_expr_special = Expression(copy(raw_node); operators=special_operators)
+
+    @test string_tree(simple_expr) == "2.0 + 2.0"
+    @test string_tree(simple_expr_special) == "2.0 + 2.0"
+
+    # Test normal simplification works
+    simplified = simplify_tree!(simple_expr)
+    @test string_tree(simplified) == "4.0"
+
+    # Test simplification is disabled with special operators
+    not_simplified = simplify_tree!(simple_expr_special)
+    @test string_tree(not_simplified) == "2.0 + 2.0"
+end
+
+@testitem "WhileOperator basic functionality" begin
+    using DynamicExpressions
+    using Test
+
+    # Define operators
+    while_op = WhileOperator(; max_iters=100)
+    assign_x2 = AssignOperator(; target_register=2)
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /, while_op],  # While is binary operator
+        unary_operators=[assign_x2],
+    )
+    variable_names = ["x1", "x2", "x3"]
+
+    # Test data - x2 starts at 1.0 for all samples
+    X = zeros(Float64, 2, 3)
+    X[2, :] .= 1.0  # x2 initial value
+
+    # Build expression: while (3.0 - x2 > 0) do x2 = x2 + 1.0
+    x2 = Expression(Node(; feature=2); operators, variable_names)
+    expr = while_op(3.0 - x2, assign_x2(x2 + 1.0))
+
+    @test string_tree(expr) == "while(3.0 - x2, x2 ← (x2 + 1.0))"
+
+    result, completed = eval_tree_array(expr, X)
+    @test completed == true
+    @test all(result .≈ 3.0)  # After 2 iterations, x2 becomes 3.0
+    @test X[2, :] == [1.0, 1.0, 1.0]  # Original data unchanged
+end
+
+@testitem "Fibonacci sequence with WhileOperator" begin
+    using DynamicExpressions
+    using Test
+
+    # Define operators
+    while_op = WhileOperator(; max_iters=100)
+    assign_ops = [AssignOperator(; target_register=i) for i in 1:5]
+    operators = OperatorEnum(;
+        binary_operators=[+, -, *, /, while_op], unary_operators=assign_ops
+    )
+    variable_names = ["x1", "x2", "x3", "x4", "x5"]
+
+    # Test data - x2=n (loop counter), x3=0 (F(0)), x4=1 (F(1))
+    X = zeros(Float64, 5, 4)
+    # Set different Fibonacci sequence positions to calculate
+    X[2, :] = [3.0, 5.0, 7.0, 10.0]  # Calculate F(3), F(5), F(7), F(10)
+
+    # Initialize all rows with F(0)=0, F(1)=1
+    X[3, :] .= 0.0  # x3 = 0.0 (F(0))
+    X[4, :] .= 1.0  # x4 = 1.0 (F(1))
+
+    xs = [Expression(Node(; feature=i); operators, variable_names) for i in 1:5]
+
+    # Build expression:
+    condition = xs[2]  # WhileOperator implicitly checks if > 0
+    body =
+        assign_ops[5](xs[3]) +
+        assign_ops[3](xs[4]) +
+        assign_ops[4](xs[5] + xs[4]) +
+        assign_ops[2](xs[2] - 1.0)
+    expr = (while_op(condition, body) * 0.0) + xs[3]
+
+    @test string_tree(expr) ==
+        "(while(x2, (((x5 ← (x3)) + (x3 ← (x4))) + (x4 ← (x5 + x4))) + (x2 ← (x2 - 1.0))) * 0.0) + x3"
+
+    result, completed = eval_tree_array(expr, X)
+    @test completed == true
+
+    # Test each Fibonacci number is correctly calculated
+    @test result ≈ [2.0, 5.0, 13.0, 55.0]  # F(3)=2, F(5)=5, F(7)=13, F(10)=55
+end
diff --git a/test/unittest.jl b/test/unittest.jl
index 42ae11bb..209f9d9e 100644
--- a/test/unittest.jl
+++ b/test/unittest.jl
@@ -133,3 +133,4 @@ include("test_expression_math.jl")
 include("test_structured_expression.jl")
 include("test_readonlynode.jl")
 include("test_zygote_gradient_wrapper.jl")
+include("test_special_operators.jl")