From 4cefb1a5332eed7ee8715a62d55a4ec6b013bb5e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 1 Mar 2025 21:42:01 +0000 Subject: [PATCH 01/12] wip: introduce special assignment operator --- src/DynamicExpressions.jl | 2 ++ src/Evaluate.jl | 32 ++++++++++++++++++---- src/SpecialOperators.jl | 57 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 src/SpecialOperators.jl diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 6c0ba5f8..74e92c80 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -12,6 +12,7 @@ using DispatchDoctor: @stable, @unstable include("NodePreallocation.jl") include("Strings.jl") include("Evaluate.jl") + include("SpecialOperators.jl") include("EvaluateDerivative.jl") include("ChainRules.jl") include("EvaluationHelpers.jl") @@ -76,6 +77,7 @@ import .StringsModule: get_op_name @reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvalOptions import .EvaluateModule: ArrayBuffer +@reexport import .SpecialOperatorsModule: AssignOperator @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index a03d0140..9d9888d1 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -218,6 +218,10 @@ function eval_tree_array( "Bumper and LoopVectorization features are only compatible with numeric element types", ) end + if any_special_operators(typeof(operators)) + cX = copy(cX) + # TODO: This is dangerous if the element type is mutable + end if _eval_options.bumper isa Val{true} return bumper_eval_tree_array(tree, cX, operators, _eval_options) end @@ -329,6 +333,8 @@ end long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote + op = operators.binops[op_idx] + special_operator(op) && return deg2_eval_special(tree, cX, op, eval_options) result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) @@ -336,7 +342,7 @@ end !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) # op(x, y), for any x or y - deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options) + deg2_eval(result_l.x, result_r.x, op, eval_options) end end return quote @@ -344,7 +350,9 @@ end $nbin, i -> i == op_idx, i -> let op = operators.binops[i] - if tree.l.degree == 0 && tree.r.degree == 0 + if special_operator(op) + deg2_eval_special(tree, cX, op, eval_options) + elseif tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) elseif tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, eval_options) @@ -380,13 +388,16 @@ end eval_options::EvalOptions, ) where {T} nuna = get_nuna(operators) + special_operators = any_special_operators(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote + op = operators.unaops[op_idx] + special_operator(op) && return deg1_eval_special(tree, cX, op, eval_options) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) - deg1_eval(result.x, operators.unaops[op_idx], eval_options) + deg1_eval(result.x, op, eval_options) end end # This @nif lets us generate an if statement over choice of operator, @@ -396,13 +407,18 @@ end $nuna, i -> i == op_idx, i -> let op = operators.unaops[i] - if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0 + if special_operator(op) + deg1_eval_special(tree, cX, op, eval_options) + elseif !special_operators && + tree.l.degree == 2 && + tree.l.l.degree == 0 && + tree.l.r.degree == 0 # op(op2(x, y)), where x, y, z are constants or variables. l_op_idx = tree.l.op dispatch_deg1_l2_ll0_lr0_eval( tree, cX, op, l_op_idx, operators.binops, eval_options ) - elseif tree.l.degree == 1 && tree.l.l.degree == 0 + elseif !special_operators && tree.l.degree == 1 && tree.l.l.degree == 0 # op(op2(x)), where x is a constant or variable. l_op_idx = tree.l.op dispatch_deg1_l1_ll0_eval( @@ -925,4 +941,10 @@ end end end +# Overloaded by SpecialOperators.jl: +function any_special_operators end +function special_operator end +function deg2_eval_special end +function deg1_eval_special end + end diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl new file mode 100644 index 00000000..a0086b8a --- /dev/null +++ b/src/SpecialOperators.jl @@ -0,0 +1,57 @@ +module SpecialOperatorsModule + +using ..OperatorEnumModule: OperatorEnum +using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval + +import ..EvaluateModule: + special_operator, deg2_eval_special, deg1_eval_special, any_special_operators + +function any_special_operators(::Type{OperatorEnum{B,U}}) where {B,U} + return any(special_operator, B.types) || any(special_operator, U.types) +end + +# Use this to customize evaluation behavior for operators: +@inline special_operator(::Type) = false +@inline special_operator(f) = special_operator(typeof(f)) + +# Base.@kwdef struct WhileOperator <: Function +# max_iters::Int = 100 +# end +Base.@kwdef struct AssignOperator <: Function + target_register::Int +end + +# @inline special_operator(::Type{WhileOperator}) = true +@inline special_operator(::Type{AssignOperator}) = true + +# function deg2_eval_special(tree, cX, op::WhileOperator, eval_options) +# cond = tree.l +# body = tree.r +# for _ in 1:(op.max_iters) +# let cond_result = _eval_tree_array(cond, cX, operators, eval_options) +# !cond_result.ok && return cond_result +# @return_on_nonfinite_array(eval_options, cond_result.x) +# end +# let body_result = _eval_tree_array(body, cX, operators, eval_options) +# !body_result.ok && return body_result +# @return_on_nonfinite_array(eval_options, body_result.x) +# # TODO: Need to somehow mask instances +# end +# end + +# return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2)) +# end +# TODO: Need to void any instance of buffer when using while loop. + +function deg1_eval_special(tree, cX, op::AssignOperator, eval_options) + result = _eval_tree_array(tree.l, cX, operators, eval_options) + !result.ok && return result + @return_on_nonfinite_array(eval_options, result.x) + target_register = op.target_register + @inbounds @simd for i in eachindex(axes(cX, 2)) + cX[target_register, i] = result.x[i] + end + return result +end + +end \ No newline at end of file From 4557455148e4dfb3307e230c8e264bb6b2e07184 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 00:38:02 +0000 Subject: [PATCH 02/12] wip: working printing for assignment operator --- src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 32 ++++++++++++++----------- src/SpecialOperators.jl | 20 +++++++++------- src/Strings.jl | 16 ++++++++++++- test/test_special_operators.jl | 43 ++++++++++++++++++++++++++++++++++ test/unittest.jl | 1 + 6 files changed, 90 insertions(+), 24 deletions(-) create mode 100644 test/test_special_operators.jl diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 74e92c80..1a95c60d 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -12,7 +12,6 @@ using DispatchDoctor: @stable, @unstable include("NodePreallocation.jl") include("Strings.jl") include("Evaluate.jl") - include("SpecialOperators.jl") include("EvaluateDerivative.jl") include("ChainRules.jl") include("EvaluationHelpers.jl") @@ -20,6 +19,7 @@ using DispatchDoctor: @stable, @unstable include("OperatorEnumConstruction.jl") include("Expression.jl") include("ExpressionAlgebra.jl") + include("SpecialOperators.jl") include("Random.jl") include("Parse.jl") include("ParametricExpression.jl") diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 9d9888d1..ac9ef657 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -10,6 +10,14 @@ import ..NodeUtilsModule: is_constant import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded import ..ValueInterfaceModule: is_valid, is_valid_array +# Overloaded by SpecialOperators.jl: +function any_special_operators(_) + return false +end +function special_operator end +function deg2_eval_special end +function deg1_eval_special end + const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15 macro return_on_nonfinite_val(eval_options, val, X) @@ -268,7 +276,7 @@ function _eval_tree_array( # we can just return the constant result. if tree.degree == 0 return deg0_eval(tree, cX, eval_options) - elseif is_constant(tree) + elseif !any_special_operators(operators) && is_constant(tree) # Speed hack for constant trees. const_result = dispatch_constant_tree(tree, operators)::ResultOk{T} !const_result.ok && @@ -330,6 +338,7 @@ end eval_options::EvalOptions, ) where {T} nbin = get_nbin(operators) + special_operators = any_special_operators(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote @@ -352,15 +361,15 @@ end i -> let op = operators.binops[i] if special_operator(op) deg2_eval_special(tree, cX, op, eval_options) - elseif tree.l.degree == 0 && tree.r.degree == 0 + elseif !$(special_operators) && tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) - elseif tree.r.degree == 0 + elseif !$(special_operators) && tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) - elseif tree.l.degree == 0 + elseif !$(special_operators) && tree.l.degree == 0 result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) @@ -393,7 +402,8 @@ end if long_compilation_time return quote op = operators.unaops[op_idx] - special_operator(op) && return deg1_eval_special(tree, cX, op, eval_options) + special_operator(op) && + return deg1_eval_special(tree, cX, op, eval_options, operators) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) @@ -408,8 +418,8 @@ end i -> i == op_idx, i -> let op = operators.unaops[i] if special_operator(op) - deg1_eval_special(tree, cX, op, eval_options) - elseif !special_operators && + deg1_eval_special(tree, cX, op, eval_options, operators) + elseif !$(special_operators) && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0 @@ -418,7 +428,7 @@ end dispatch_deg1_l2_ll0_lr0_eval( tree, cX, op, l_op_idx, operators.binops, eval_options ) - elseif !special_operators && tree.l.degree == 1 && tree.l.l.degree == 0 + elseif !$(special_operators) && tree.l.degree == 1 && tree.l.l.degree == 0 # op(op2(x)), where x is a constant or variable. l_op_idx = tree.l.op dispatch_deg1_l1_ll0_eval( @@ -941,10 +951,4 @@ end end end -# Overloaded by SpecialOperators.jl: -function any_special_operators end -function special_operator end -function deg2_eval_special end -function deg1_eval_special end - end diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index a0086b8a..341cc2c0 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -2,11 +2,14 @@ module SpecialOperatorsModule using ..OperatorEnumModule: OperatorEnum using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval +using ..ExpressionModule: AbstractExpression +using ..ExpressionAlgebraModule: @declare_expression_operator import ..EvaluateModule: special_operator, deg2_eval_special, deg1_eval_special, any_special_operators +import ..StringsModule: get_op_name -function any_special_operators(::Type{OperatorEnum{B,U}}) where {B,U} +function any_special_operators(::Union{O,Type{O}}) where {B,U,O<:OperatorEnum{B,U}} return any(special_operator, B.types) || any(special_operator, U.types) end @@ -14,16 +17,17 @@ end @inline special_operator(::Type) = false @inline special_operator(f) = special_operator(typeof(f)) -# Base.@kwdef struct WhileOperator <: Function -# max_iters::Int = 100 -# end Base.@kwdef struct AssignOperator <: Function target_register::Int end - -# @inline special_operator(::Type{WhileOperator}) = true +@declare_expression_operator((op::AssignOperator), 1) @inline special_operator(::Type{AssignOperator}) = true +get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]" +# Base.@kwdef struct WhileOperator <: Function +# max_iters::Int = 100 +# end +# @inline special_operator(::Type{WhileOperator}) = true # function deg2_eval_special(tree, cX, op::WhileOperator, eval_options) # cond = tree.l # body = tree.r @@ -43,7 +47,7 @@ end # end # TODO: Need to void any instance of buffer when using while loop. -function deg1_eval_special(tree, cX, op::AssignOperator, eval_options) +function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) @@ -54,4 +58,4 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options) return result end -end \ No newline at end of file +end diff --git a/src/Strings.jl b/src/Strings.jl index a13eae31..dc035dc4 100644 --- a/src/Strings.jl +++ b/src/Strings.jl @@ -56,6 +56,18 @@ end end end +const FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH = length("{FEATURE_") +function replace_feature_placeholders(s::String, f_variable::Function, variable_names) + return replace( + s, + r"\{FEATURE_(\d+)\}" => + m -> f_variable( + parse(Int, m[(begin + FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH):(end - 1)]), + variable_names, + ), + ) +end + # Can overload these for custom behavior: needs_brackets(val::Real) = false needs_brackets(val::AbstractArray) = false @@ -179,7 +191,9 @@ function string_tree( c end, ) - return String(strip_brackets(raw_output)) + string_output = String(strip_brackets(raw_output)) + string_output = replace_feature_placeholders(string_output, f_variable, variable_names) + return string_output end # Print an equation diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl new file mode 100644 index 00000000..e95bc616 --- /dev/null +++ b/test/test_special_operators.jl @@ -0,0 +1,43 @@ +using TestItems: @testitem + +@testitem "AssignOperator basic functionality" begin + using DynamicExpressions + using DynamicExpressions.SpecialOperatorsModule: AssignOperator + using DynamicExpressions.EvaluateModule: eval_tree_array + using Test + using Random + + # Define operators and variable names + assign_op2 = AssignOperator(; target_register=2) + operators = OperatorEnum(; + binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op2] + ) + variable_names = ["x1", "x2", "x3", "x4", "x5"] + + # Test data + X = zeros(Float64, 2, 3) + X[1, :] .= [1.0, 2.0, 3.0] + X[2, :] .= [0.5, 1.5, 2.5] + + # 1. Basic register assignment - assign constant to register 2, + # and then add the return to `x2` (which should now be 3.0!) + x1 = Expression(Node(; feature=1); operators, variable_names) + x2 = Expression(Node(; feature=2); operators, variable_names) + assign_expr = assign_op2(0.0 * x1 + 3.0) + x2 + + @test string_tree(assign_expr) == "[x2 =]((0.0 * x1) + 3.0) + x2" + + # We should see that x2 will become 3.0 _before_ adding + result, completed = eval_tree_array(assign_expr, X) + @test completed == true + @test all(==(6.0), result) + + # We should also see that X is not changed by this + @test X[2, :] == [0.5, 1.5, 2.5] + + # But, with the reverse order, we get the x2 _before_ it was reassigned + assign_expr_reverse = x2 + assign_op2(0.0 * x1 + 3.0) + result, completed = eval_tree_array(assign_expr_reverse, X) + @test completed == true + @test result == [3.5, 4.5, 5.5] +end diff --git a/test/unittest.jl b/test/unittest.jl index 42ae11bb..209f9d9e 100644 --- a/test/unittest.jl +++ b/test/unittest.jl @@ -133,3 +133,4 @@ include("test_expression_math.jl") include("test_structured_expression.jl") include("test_readonlynode.jl") include("test_zygote_gradient_wrapper.jl") +include("test_special_operators.jl") From ef9302cb5f4cf2b0d5e2a295e58e4bf8b9d63037 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 01:02:50 +0000 Subject: [PATCH 03/12] wip: restore some branches which ARE compatible with special operators --- src/Evaluate.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index ac9ef657..a65cf7d8 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -361,15 +361,17 @@ end i -> let op = operators.binops[i] if special_operator(op) deg2_eval_special(tree, cX, op, eval_options) - elseif !$(special_operators) && tree.l.degree == 0 && tree.r.degree == 0 + elseif tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) - elseif !$(special_operators) && tree.r.degree == 0 + elseif tree.r.degree == 0 result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) elseif !$(special_operators) && tree.l.degree == 0 + # This branch changes the execution order, so we cannot + # use this branch when special operators are present. result_r = _eval_tree_array(tree.r, cX, operators, eval_options) !result_r.ok && return result_r @return_on_nonfinite_array(eval_options, result_r.x) From 30f163e777c3b01e042031d80fe246f47098f804 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 01:05:13 +0000 Subject: [PATCH 04/12] test: clean up --- test/test_special_operators.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl index e95bc616..13d2a86c 100644 --- a/test/test_special_operators.jl +++ b/test/test_special_operators.jl @@ -8,9 +8,9 @@ using TestItems: @testitem using Random # Define operators and variable names - assign_op2 = AssignOperator(; target_register=2) + assign_x2 = AssignOperator(; target_register=2) operators = OperatorEnum(; - binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op2] + binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_x2] ) variable_names = ["x1", "x2", "x3", "x4", "x5"] @@ -23,7 +23,7 @@ using TestItems: @testitem # and then add the return to `x2` (which should now be 3.0!) x1 = Expression(Node(; feature=1); operators, variable_names) x2 = Expression(Node(; feature=2); operators, variable_names) - assign_expr = assign_op2(0.0 * x1 + 3.0) + x2 + assign_expr = assign_x2(0.0 * x1 + 3.0) + x2 @test string_tree(assign_expr) == "[x2 =]((0.0 * x1) + 3.0) + x2" @@ -36,7 +36,7 @@ using TestItems: @testitem @test X[2, :] == [0.5, 1.5, 2.5] # But, with the reverse order, we get the x2 _before_ it was reassigned - assign_expr_reverse = x2 + assign_op2(0.0 * x1 + 3.0) + assign_expr_reverse = x2 + assign_x2(0.0 * x1 + 3.0) result, completed = eval_tree_array(assign_expr_reverse, X) @test completed == true @test result == [3.5, 4.5, 5.5] From 60152910a9daec3b7bd7784fb9b132a6062eac13 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 01:11:58 +0000 Subject: [PATCH 05/12] test: more complex assignment --- test/test_special_operators.jl | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl index 13d2a86c..fa667bdb 100644 --- a/test/test_special_operators.jl +++ b/test/test_special_operators.jl @@ -2,8 +2,6 @@ using TestItems: @testitem @testitem "AssignOperator basic functionality" begin using DynamicExpressions - using DynamicExpressions.SpecialOperatorsModule: AssignOperator - using DynamicExpressions.EvaluateModule: eval_tree_array using Test using Random @@ -37,7 +35,32 @@ using TestItems: @testitem # But, with the reverse order, we get the x2 _before_ it was reassigned assign_expr_reverse = x2 + assign_x2(0.0 * x1 + 3.0) + @test string_tree(assign_expr_reverse) == "x2 + [x2 =]((0.0 * x1) + 3.0)" result, completed = eval_tree_array(assign_expr_reverse, X) @test completed == true @test result == [3.5, 4.5, 5.5] end + +@testitem "AssignOperator with self-assignment" begin + using DynamicExpressions + using Test + using Random + + assign_x1 = AssignOperator(; target_register=1) + operators = OperatorEnum(; + binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_x1] + ) + variable_names = ["a", "b", "c"] + X = rand(Float64, 2, 10) + + x1 = Expression(Node(; feature=1); operators, variable_names) + x2 = Expression(Node(; feature=2); operators, variable_names) + x3 = Expression(Node(; feature=3); operators, variable_names) + + expr = assign_x1(assign_x1(x1 * 2) + x1) + @test string_tree(expr) == "[a =]([a =](a * 2.0) + a)" + + result, completed = eval_tree_array(expr, X) + @test completed == true + @test result == X[1, :] .* 4.0 +end From f3f78aef4318c635c4d18b4fd67ee58497e59ccc Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 01:26:21 +0000 Subject: [PATCH 06/12] fix: avoid simplification when given special operators --- src/Simplify.jl | 18 ++++++++++++++--- test/test_special_operators.jl | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/Simplify.jl b/src/Simplify.jl index cf2592a2..9c1f783b 100644 --- a/src/Simplify.jl +++ b/src/Simplify.jl @@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set import ..NodeUtilsModule: tree_mapreduce, is_node_constant import ..OperatorEnumModule: AbstractOperatorEnum import ..ValueInterfaceModule: is_valid +import ..EvaluateModule: any_special_operators _una_op_kernel(f::F, l::T) where {F,T} = f(l) _bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r) @@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree # This is only defined for `Node` as it is not possible for, e.g., # `GraphNode`. function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T} + # Skip simplification if special operators are in use + any_special_operators(operators) && return tree + return _combine_operators(tree, operators) +end + +function _combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T} # NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before. # ((const + var) + const) => (const + var) # ((const * var) * const) => (const * var) @@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where if tree.degree == 0 return tree elseif tree.degree == 1 - tree.l = combine_operators(tree.l, operators) + tree.l = _combine_operators(tree.l, operators) elseif tree.degree == 2 - tree.l = combine_operators(tree.l, operators) - tree.r = combine_operators(tree.r, operators) + tree.l = _combine_operators(tree.l, operators) + tree.r = _combine_operators(tree.r, operators) end top_level_constant = @@ -123,6 +130,11 @@ end # Simplify tree function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum) + # Skip simplification if special operators are in use + if any_special_operators(operators) + return tree + end + return tree_mapreduce( identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree); ) diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl index fa667bdb..24f10247 100644 --- a/test/test_special_operators.jl +++ b/test/test_special_operators.jl @@ -64,3 +64,39 @@ end @test completed == true @test result == X[1, :] .* 4.0 end + +@testitem "Simplification disabled with special operators" begin + using DynamicExpressions + using Test + + # Create operators with and without special operator + assign_op = AssignOperator(; target_register=1) + special_operators = OperatorEnum(; + binary_operators=[+, -, *, /], unary_operators=[sin, cos, assign_op] + ) + normal_operators = OperatorEnum(; + binary_operators=[+, -, *, /], unary_operators=[sin, cos] + ) + + @test DynamicExpressions.SpecialOperatorsModule.any_special_operators(special_operators) + @test !DynamicExpressions.SpecialOperatorsModule.any_special_operators(normal_operators) + + # Create expressions using the Expression constructor + const_val = 2.0 + + # Simple expression that should simplify: 2.0 + 2.0 + raw_node = Node(; op=1, l=Node(; val=const_val), r=Node(; val=const_val)) + simple_expr = Expression(copy(raw_node); operators=normal_operators) + simple_expr_special = Expression(copy(raw_node); operators=special_operators) + + @test string_tree(simple_expr) == "2.0 + 2.0" + @test string_tree(simple_expr_special) == "2.0 + 2.0" + + # Test normal simplification works + simplified = simplify_tree!(simple_expr) + @test string_tree(simplified) == "4.0" + + # Test simplification is disabled with special operators + not_simplified = simplify_tree!(simple_expr_special) + @test string_tree(not_simplified) == "2.0 + 2.0" +end From 6e0334b9bafa315426fdf82f5f03f1200e23da46 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 02:26:56 +0000 Subject: [PATCH 07/12] feat: create WhileOperator --- src/DynamicExpressions.jl | 2 +- src/Evaluate.jl | 5 ++- src/SpecialOperators.jl | 69 +++++++++++++++++++++------------ test/test_special_operators.jl | 71 ++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 27 deletions(-) diff --git a/src/DynamicExpressions.jl b/src/DynamicExpressions.jl index 1a95c60d..94d41436 100644 --- a/src/DynamicExpressions.jl +++ b/src/DynamicExpressions.jl @@ -77,7 +77,7 @@ import .StringsModule: get_op_name @reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array, EvalOptions import .EvaluateModule: ArrayBuffer -@reexport import .SpecialOperatorsModule: AssignOperator +@reexport import .SpecialOperatorsModule: AssignOperator, WhileOperator @reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array @reexport import .ChainRulesModule: NodeTangent, extract_gradient @reexport import .SimplifyModule: combine_operators, simplify_tree! diff --git a/src/Evaluate.jl b/src/Evaluate.jl index a65cf7d8..5c69a76e 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -343,7 +343,8 @@ end if long_compilation_time return quote op = operators.binops[op_idx] - special_operator(op) && return deg2_eval_special(tree, cX, op, eval_options) + special_operator(op) && + return deg2_eval_special(tree, cX, op, eval_options, operators) result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) @@ -360,7 +361,7 @@ end i -> i == op_idx, i -> let op = operators.binops[i] if special_operator(op) - deg2_eval_special(tree, cX, op, eval_options) + deg2_eval_special(tree, cX, op, eval_options, operators) elseif tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) elseif tree.r.degree == 0 diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index 341cc2c0..39a227fa 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -1,7 +1,8 @@ module SpecialOperatorsModule using ..OperatorEnumModule: OperatorEnum -using ..EvaluateModule: _eval_tree_array, @return_on_nonfinite_array, deg2_eval +using ..EvaluateModule: + _eval_tree_array, @return_on_nonfinite_array, deg2_eval, ResultOk, get_filled_array using ..ExpressionModule: AbstractExpression using ..ExpressionAlgebraModule: @declare_expression_operator @@ -24,29 +25,6 @@ end @inline special_operator(::Type{AssignOperator}) = true get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]" -# Base.@kwdef struct WhileOperator <: Function -# max_iters::Int = 100 -# end -# @inline special_operator(::Type{WhileOperator}) = true -# function deg2_eval_special(tree, cX, op::WhileOperator, eval_options) -# cond = tree.l -# body = tree.r -# for _ in 1:(op.max_iters) -# let cond_result = _eval_tree_array(cond, cX, operators, eval_options) -# !cond_result.ok && return cond_result -# @return_on_nonfinite_array(eval_options, cond_result.x) -# end -# let body_result = _eval_tree_array(body, cX, operators, eval_options) -# !body_result.ok && return body_result -# @return_on_nonfinite_array(eval_options, body_result.x) -# # TODO: Need to somehow mask instances -# end -# end - -# return get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2)) -# end -# TODO: Need to void any instance of buffer when using while loop. - function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @@ -58,4 +36,47 @@ function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators return result end +Base.@kwdef struct WhileOperator <: Function + max_iters::Int = 100 +end + +@declare_expression_operator((op::WhileOperator), 2) +@inline special_operator(::Type{WhileOperator}) = true +get_op_name(o::WhileOperator) = "while" + +# TODO: Need to void any instance of buffer when using while loop. +function deg2_eval_special(tree, cX, op::WhileOperator, eval_options, operators) + cond = tree.l + body = tree.r + mask = trues(size(cX, 2)) + X = @view cX[:, mask] + # Initialize the result array for all columns + result_array = get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2)) + body_result = ResultOk(result_array, true) + + for _ in 1:(op.max_iters) + cond_result = _eval_tree_array(cond, X, operators, eval_options) + !cond_result.ok && return cond_result + @return_on_nonfinite_array(eval_options, cond_result.x) + + new_mask = cond_result.x .> 0.0 + any(new_mask) || return body_result + + # Track which columns are still active + mask[mask] .= new_mask + X = @view cX[:, mask] + + # Evaluate just for active columns + iter_result = _eval_tree_array(body, X, operators, eval_options) + !iter_result.ok && return iter_result + + # Update the corresponding elements in the result array + body_result.x[mask] .= iter_result.x + @return_on_nonfinite_array(eval_options, body_result.x) + end + + # We passed max_iters, so this result is invalid + return ResultOk(body_result.x, false) +end + end diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl index 24f10247..c3b929a9 100644 --- a/test/test_special_operators.jl +++ b/test/test_special_operators.jl @@ -100,3 +100,74 @@ end not_simplified = simplify_tree!(simple_expr_special) @test string_tree(not_simplified) == "2.0 + 2.0" end + +@testitem "WhileOperator basic functionality" begin + using DynamicExpressions + using Test + + # Define operators + while_op = WhileOperator(; max_iters=100) + assign_x2 = AssignOperator(; target_register=2) + operators = OperatorEnum(; + binary_operators=[+, -, *, /, while_op], # While is binary operator + unary_operators=[assign_x2], + ) + variable_names = ["x1", "x2", "x3"] + + # Test data - x2 starts at 1.0 for all samples + X = zeros(Float64, 2, 3) + X[2, :] .= 1.0 # x2 initial value + + # Build expression: while (3.0 - x2 > 0) do x2 = x2 + 1.0 + x2 = Expression(Node(; feature=2); operators, variable_names) + expr = while_op(3.0 - x2, assign_x2(x2 + 1.0)) + + @test string_tree(expr) == "while(3.0 - x2, [x2 =](x2 + 1.0))" + + result, completed = eval_tree_array(expr, X) + @test completed == true + @test all(result .≈ 3.0) # After 2 iterations, x2 becomes 3.0 + @test X[2, :] == [1.0, 1.0, 1.0] # Original data unchanged +end + +@testitem "Fibonacci sequence with WhileOperator" begin + using DynamicExpressions + using Test + + # Define operators + while_op = WhileOperator(; max_iters=100) + assign_ops = [AssignOperator(; target_register=i) for i in 1:5] + operators = OperatorEnum(; + binary_operators=[+, -, *, /, while_op], unary_operators=assign_ops + ) + variable_names = ["x1", "x2", "x3", "x4", "x5"] + + # Test data - x2=5 (counter), x3=0 (F(0)), x4=1 (F(1)) + X = zeros(Float64, 5, 4) + # Set different Fibonacci sequence positions to calculate + X[2, :] = [3.0, 5.0, 7.0, 10.0] # Calculate F(3), F(5), F(7), F(10) + + # Initialize all rows with F(0)=0, F(1)=1 + X[3, :] .= 0.0 # x3 = 0.0 (F(0)) + X[4, :] .= 1.0 # x4 = 1.0 (F(1)) + + xs = [Expression(Node(; feature=i); operators, variable_names) for i in 1:5] + + # Build expression: + condition = xs[2] # WhileOperator implicitly checks if > 0 + body = + assign_ops[5](xs[3]) + + assign_ops[3](xs[4]) + + assign_ops[4](xs[5] + xs[4]) + + assign_ops[2](xs[2] - 1.0) + expr = (while_op(condition, body) * 0.0) + xs[3] + + @test string_tree(expr) == + "(while(x2, (([x5 =](x3) + [x3 =](x4)) + [x4 =](x5 + x4)) + [x2 =](x2 - 1.0)) * 0.0) + x3" + + result, completed = eval_tree_array(expr, X) + @test completed == true + + # Test each Fibonacci number is correctly calculated + @test result ≈ [2.0, 5.0, 13.0, 55.0] # F(3)=2, F(5)=5, F(7)=13, F(10)=55 +end From c90383f5a314673d5195b2bcef7ddcf0fb8d5f6e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 02:29:12 +0000 Subject: [PATCH 08/12] refactor: better signatures --- src/Evaluate.jl | 8 ++++---- src/SpecialOperators.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 5c69a76e..5e9e4724 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -344,7 +344,7 @@ end return quote op = operators.binops[op_idx] special_operator(op) && - return deg2_eval_special(tree, cX, op, eval_options, operators) + return deg2_eval_special(tree, cX, operators, op, eval_options) result_l = _eval_tree_array(tree.l, cX, operators, eval_options) !result_l.ok && return result_l @return_on_nonfinite_array(eval_options, result_l.x) @@ -361,7 +361,7 @@ end i -> i == op_idx, i -> let op = operators.binops[i] if special_operator(op) - deg2_eval_special(tree, cX, op, eval_options, operators) + deg2_eval_special(tree, cX, operators, op, eval_options) elseif tree.l.degree == 0 && tree.r.degree == 0 deg2_l0_r0_eval(tree, cX, op, eval_options) elseif tree.r.degree == 0 @@ -406,7 +406,7 @@ end return quote op = operators.unaops[op_idx] special_operator(op) && - return deg1_eval_special(tree, cX, op, eval_options, operators) + return deg1_eval_special(tree, cX, operators, op, eval_options) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) @@ -421,7 +421,7 @@ end i -> i == op_idx, i -> let op = operators.unaops[i] if special_operator(op) - deg1_eval_special(tree, cX, op, eval_options, operators) + deg1_eval_special(tree, cX, operators, op, eval_options) elseif !$(special_operators) && tree.l.degree == 2 && tree.l.l.degree == 0 && diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index 39a227fa..9fcb7aff 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -25,7 +25,7 @@ end @inline special_operator(::Type{AssignOperator}) = true get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]" -function deg1_eval_special(tree, cX, op::AssignOperator, eval_options, operators) +function deg1_eval_special(tree, cX, operators, op::AssignOperator, eval_options) result = _eval_tree_array(tree.l, cX, operators, eval_options) !result.ok && return result @return_on_nonfinite_array(eval_options, result.x) @@ -45,7 +45,7 @@ end get_op_name(o::WhileOperator) = "while" # TODO: Need to void any instance of buffer when using while loop. -function deg2_eval_special(tree, cX, op::WhileOperator, eval_options, operators) +function deg2_eval_special(tree, cX, operators, op::WhileOperator, eval_options) cond = tree.l body = tree.r mask = trues(size(cX, 2)) From 1ff3bbaf00ec71efd0ef3a154c2d7c9ca133d3e0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 02:48:07 +0000 Subject: [PATCH 09/12] fix: try to improve type inference --- src/SpecialOperators.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index 9fcb7aff..da7ead1f 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -15,8 +15,8 @@ function any_special_operators(::Union{O,Type{O}}) where {B,U,O<:OperatorEnum{B, end # Use this to customize evaluation behavior for operators: -@inline special_operator(::Type) = false -@inline special_operator(f) = special_operator(typeof(f)) +@inline special_operator(::Type{F}) where {F} = false +@inline special_operator(::F) where {F} = special_operator(F) Base.@kwdef struct AssignOperator <: Function target_register::Int From 525370a4697e5fe3d23c4c53e73228eb9b1ed3ac Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 03:11:53 +0000 Subject: [PATCH 10/12] refactor: better printing for assignment operator --- src/SpecialOperators.jl | 2 +- src/Strings.jl | 33 +++++++++++++++++++++++++++------ test/test_special_operators.jl | 10 +++++----- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index da7ead1f..69f45cf0 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -23,7 +23,7 @@ Base.@kwdef struct AssignOperator <: Function end @declare_expression_operator((op::AssignOperator), 1) @inline special_operator(::Type{AssignOperator}) = true -get_op_name(o::AssignOperator) = "[{FEATURE_" * string(o.target_register) * "} =]" +get_op_name(o::AssignOperator) = "ASSIGN_OP:{FEATURE_" * string(o.target_register) * "}" function deg1_eval_special(tree, cX, operators, op::AssignOperator, eval_options) result = _eval_tree_array(tree.l, cX, operators, eval_options) diff --git a/src/Strings.jl b/src/Strings.jl index dc035dc4..a6a9bbcf 100644 --- a/src/Strings.jl +++ b/src/Strings.jl @@ -116,12 +116,33 @@ function combine_op_with_inputs(op, l, r)::Vector{Char} end end function combine_op_with_inputs(op, l) - # "op(l)" - out = copy(op) - push!(out, '(') - append!(out, strip_brackets(l)) - push!(out, ')') - return out + # Check if this is an assignment operator with our special prefix + op_str = String(op) + if startswith(op_str, "ASSIGN_OP:") + # Extract the variable name from the operator name + var_name = op_str[11:end] + # Format: (var ← expr) + out = ['('] + append!(out, collect(var_name)) + append!(out, collect(" ← ")) + # Ensure the expression is always wrapped in parentheses for clarity + if l[1] == '(' && l[end] == ')' + append!(out, l) + else + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ')') + end + push!(out, ')') + return out + else + # Regular unary operator: "op(l)" + out = copy(op) + push!(out, '(') + append!(out, strip_brackets(l)) + push!(out, ')') + return out + end end """ diff --git a/test/test_special_operators.jl b/test/test_special_operators.jl index c3b929a9..1eb36cac 100644 --- a/test/test_special_operators.jl +++ b/test/test_special_operators.jl @@ -23,7 +23,7 @@ using TestItems: @testitem x2 = Expression(Node(; feature=2); operators, variable_names) assign_expr = assign_x2(0.0 * x1 + 3.0) + x2 - @test string_tree(assign_expr) == "[x2 =]((0.0 * x1) + 3.0) + x2" + @test string_tree(assign_expr) == "(x2 ← ((0.0 * x1) + 3.0)) + x2" # We should see that x2 will become 3.0 _before_ adding result, completed = eval_tree_array(assign_expr, X) @@ -35,7 +35,7 @@ using TestItems: @testitem # But, with the reverse order, we get the x2 _before_ it was reassigned assign_expr_reverse = x2 + assign_x2(0.0 * x1 + 3.0) - @test string_tree(assign_expr_reverse) == "x2 + [x2 =]((0.0 * x1) + 3.0)" + @test string_tree(assign_expr_reverse) == "x2 + (x2 ← ((0.0 * x1) + 3.0))" result, completed = eval_tree_array(assign_expr_reverse, X) @test completed == true @test result == [3.5, 4.5, 5.5] @@ -58,7 +58,7 @@ end x3 = Expression(Node(; feature=3); operators, variable_names) expr = assign_x1(assign_x1(x1 * 2) + x1) - @test string_tree(expr) == "[a =]([a =](a * 2.0) + a)" + @test string_tree(expr) == "a ← ((a ← (a * 2.0)) + a)" result, completed = eval_tree_array(expr, X) @test completed == true @@ -122,7 +122,7 @@ end x2 = Expression(Node(; feature=2); operators, variable_names) expr = while_op(3.0 - x2, assign_x2(x2 + 1.0)) - @test string_tree(expr) == "while(3.0 - x2, [x2 =](x2 + 1.0))" + @test string_tree(expr) == "while(3.0 - x2, x2 ← (x2 + 1.0))" result, completed = eval_tree_array(expr, X) @test completed == true @@ -163,7 +163,7 @@ end expr = (while_op(condition, body) * 0.0) + xs[3] @test string_tree(expr) == - "(while(x2, (([x5 =](x3) + [x3 =](x4)) + [x4 =](x5 + x4)) + [x2 =](x2 - 1.0)) * 0.0) + x3" + "(while(x2, (((x5 ← (x3)) + (x3 ← (x4))) + (x4 ← (x5 + x4))) + (x2 ← (x2 - 1.0))) * 0.0) + x3" result, completed = eval_tree_array(expr, X) @test completed == true From b37c9443cd5f0825f716228c60e3161dd5d4e9a3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 03:48:37 +0000 Subject: [PATCH 11/12] fix: try to improve type inference --- src/SpecialOperators.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index 69f45cf0..3f2b828f 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -10,7 +10,9 @@ import ..EvaluateModule: special_operator, deg2_eval_special, deg1_eval_special, any_special_operators import ..StringsModule: get_op_name -function any_special_operators(::Union{O,Type{O}}) where {B,U,O<:OperatorEnum{B,U}} +@generated function any_special_operators( + ::Union{O,Type{O}} +) where {B,U,O<:OperatorEnum{B,U}} return any(special_operator, B.types) || any(special_operator, U.types) end From 6267f5e7907f4e3329443bc4bc90b71c585b9803 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 2 Mar 2025 03:57:26 +0000 Subject: [PATCH 12/12] fix: try to improve type inference --- src/Evaluate.jl | 16 +++++++--------- src/SpecialOperators.jl | 8 ++++---- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/Evaluate.jl b/src/Evaluate.jl index 5e9e4724..b426c3f6 100644 --- a/src/Evaluate.jl +++ b/src/Evaluate.jl @@ -11,9 +11,7 @@ import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization import ..ValueInterfaceModule: is_valid, is_valid_array # Overloaded by SpecialOperators.jl: -function any_special_operators(_) - return false -end +function any_special_operators end function special_operator end function deg2_eval_special end function deg1_eval_special end @@ -226,7 +224,7 @@ function eval_tree_array( "Bumper and LoopVectorization features are only compatible with numeric element types", ) end - if any_special_operators(typeof(operators)) + if any_special_operators(operators) cX = copy(cX) # TODO: This is dangerous if the element type is mutable end @@ -338,7 +336,6 @@ end eval_options::EvalOptions, ) where {T} nbin = get_nbin(operators) - special_operators = any_special_operators(operators) long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote @@ -370,7 +367,7 @@ end @return_on_nonfinite_array(eval_options, result_l.x) # op(x, y), where y is a constant or variable but x is not. deg2_r0_eval(tree, result_l.x, cX, op, eval_options) - elseif !$(special_operators) && tree.l.degree == 0 + elseif !any_special_operators(operators) && tree.l.degree == 0 # This branch changes the execution order, so we cannot # use this branch when special operators are present. result_r = _eval_tree_array(tree.r, cX, operators, eval_options) @@ -400,7 +397,6 @@ end eval_options::EvalOptions, ) where {T} nuna = get_nuna(operators) - special_operators = any_special_operators(operators) long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN if long_compilation_time return quote @@ -422,7 +418,7 @@ end i -> let op = operators.unaops[i] if special_operator(op) deg1_eval_special(tree, cX, operators, op, eval_options) - elseif !$(special_operators) && + elseif !any_special_operators(operators) && tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0 @@ -431,7 +427,9 @@ end dispatch_deg1_l2_ll0_lr0_eval( tree, cX, op, l_op_idx, operators.binops, eval_options ) - elseif !$(special_operators) && tree.l.degree == 1 && tree.l.l.degree == 0 + elseif !any_special_operators(operators) && + tree.l.degree == 1 && + tree.l.l.degree == 0 # op(op2(x)), where x is a constant or variable. l_op_idx = tree.l.op dispatch_deg1_l1_ll0_eval( diff --git a/src/SpecialOperators.jl b/src/SpecialOperators.jl index 3f2b828f..e5b88ee3 100644 --- a/src/SpecialOperators.jl +++ b/src/SpecialOperators.jl @@ -10,16 +10,16 @@ import ..EvaluateModule: special_operator, deg2_eval_special, deg1_eval_special, any_special_operators import ..StringsModule: get_op_name +# Use this to customize evaluation behavior for operators: +@inline special_operator(::Type{F}) where {F} = false +@inline special_operator(::F) where {F} = special_operator(F) + @generated function any_special_operators( ::Union{O,Type{O}} ) where {B,U,O<:OperatorEnum{B,U}} return any(special_operator, B.types) || any(special_operator, U.types) end -# Use this to customize evaluation behavior for operators: -@inline special_operator(::Type{F}) where {F} = false -@inline special_operator(::F) where {F} = special_operator(F) - Base.@kwdef struct AssignOperator <: Function target_register::Int end