diff --git a/Project.toml b/Project.toml index a80d27c..5ee0195 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,6 @@ MacroTools = "0.5" NNlib = "0.9" SpecialFunctions = "2" SymbolicUtils = "3, 4" -Symbolics = "6" +Symbolics = "6, 7" Zygote = "0.6, 0.7" julia = "1.10" diff --git a/src/utils.jl b/src/utils.jl index d260c25..03b75a4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,7 @@ using ChainRules using ChainRulesCore using Symbolics: Symbolics, @variables, @rule, unwrap, isdiv +import SymbolicUtils using SymbolicUtils.Code: toexpr using MacroTools using MacroTools: prewalk, postwalk @@ -13,19 +14,58 @@ Pick a strategy for raising the derivative of a function. If the derivative is like 1 over something, raise with the division rule; otherwise, raise with the multiplication rule. """ -function get_term_raiser(func) - @variables z - r1 = @rule -1 * (1 / ~x) => (-1) / ~x - der = frule((NoTangent(), true), func, z)[2] - term = unwrap(der) - maybe_rewrite = r1(term) - if maybe_rewrite !== nothing - term = maybe_rewrite +function get_term_raiser end + +@static if pkgversion(Symbolics) < v"7" + function get_term_raiser(func) + @variables z + r1 = @rule -1 * (1 / ~x) => (-1) / ~x + der = frule((NoTangent(), true), func, z)[2] + term = unwrap(der) + maybe_rewrite = r1(term) + if maybe_rewrite !== nothing + term = maybe_rewrite + end + if isdiv(term) && (term.num == 1 || term.num == -1) + term.den * term.num, raiseinv + else + term, raise + end + end +else + const COMMON_Z = only(@variables z) + const FALLBACK_RULE = (@rule -1 * (1 / ~x) => (-1) / ~x) + + function is_plusminus_one(@nospecialize(x)) + if x isa Int + return x == 1 || x == -1 + elseif x isa Int32 + return x == 1 || x == -1 + elseif x isa Float64 + return x == 1 || x == -1 + elseif x isa Float32 + return x == 1 || x == -1 + elseif x isa Number + return (x == 1)::Bool || (x == -1)::Bool + else + return false + end end - if isdiv(term) && (term.num == 1 || term.num == -1) - term.den * term.num, raiseinv - else - term, raise + + function get_term_raiser(func) + der = frule((NoTangent(), true), func, COMMON_Z)[2] + term = unwrap(der) + maybe_rewrite = FALLBACK_RULE(term) + if maybe_rewrite !== nothing + term = maybe_rewrite + end + if isdiv(term) + num, den = SymbolicUtils.arguments(term) + if SymbolicUtils.isconst(num) && is_plusminus_one(SymbolicUtils.unwrap_const(num)) + return num * den, raiseinv + end + end + return term, raise end end