Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
8e537c7
feat: add Rotations extension
DhairyaLGandhi Nov 27, 2025
8da87ba
chore: working tree
DhairyaLGandhi Dec 1, 2025
97dd026
feat: convert orthogonal matrices to transpose
DhairyaLGandhi Dec 4, 2025
46cd839
test: check different inv calls
DhairyaLGandhi Dec 8, 2025
221c1f7
chore: working state
DhairyaLGandhi Dec 15, 2025
57451f1
Merge branch 'main' into dg/rot
DhairyaLGandhi Dec 23, 2025
ea0dc84
Merge branch 'dg/rot' of github.com:JuliaComputing/SymbolicCompilerPa…
DhairyaLGandhi Dec 23, 2025
07e420b
test: check multiple optimizations apply in parallel
DhairyaLGandhi Jan 5, 2026
3476e5b
chore: rotations ext
DhairyaLGandhi Jan 5, 2026
56bf118
chore: unused code
DhairyaLGandhi Jan 5, 2026
698cd4f
test(multiple): add ortho -> inv example
DhairyaLGandhi Jan 8, 2026
6516a64
chore: remove dispatch on CSEState
DhairyaLGandhi Jan 12, 2026
c905f15
chore: convert small views to static arrays
DhairyaLGandhi Jan 12, 2026
59aaf24
Merge branch 'main' into dg/rot
DhairyaLGandhi Jan 12, 2026
c3ee542
Merge pull request #14 from JuliaComputing/dg/multiple
DhairyaLGandhi Jan 12, 2026
9c5581d
chore: merge upstream
DhairyaLGandhi Jan 12, 2026
76e8000
test(multiple): import packages
DhairyaLGandhi Jan 12, 2026
e63dfa0
chore: check for sizes of views
DhairyaLGandhi Jan 19, 2026
20d0a5b
chore: working state
DhairyaLGandhi Jan 29, 2026
598018e
chore: working state
DhairyaLGandhi Feb 2, 2026
a645ac8
chore: use Base.ifelse instead of IfElse
DhairyaLGandhi Feb 2, 2026
a20cfda
Merge branch 'main' into dg/rot
DhairyaLGandhi Feb 2, 2026
9df123f
test: add Symbolics to test deps
DhairyaLGandhi Feb 2, 2026
de0b22f
test: add ortho tests to runtests
DhairyaLGandhi Feb 3, 2026
c21bed4
Update test/multiple.jl
DhairyaLGandhi Feb 3, 2026
1d2d1fe
Update test/multiple.jl
DhairyaLGandhi Feb 3, 2026
69d57dc
Merge branch 'dg/rot' into dg/mb
DhairyaLGandhi Feb 4, 2026
bf06dbd
chore: manage static ldiv
DhairyaLGandhi Feb 13, 2026
5821852
merge main
DhairyaLGandhi Feb 13, 2026
75f161d
Apply suggestions from code review
DhairyaLGandhi Feb 13, 2026
91acc97
Update ext/SCPLinearSolveExt/SCPLinearSolveExt.jl
DhairyaLGandhi Feb 16, 2026
34c6699
Merge pull request #16 from JuliaComputing/dg/mb
DhairyaLGandhi Feb 16, 2026
ddd2efd
test: simplify shapes
DhairyaLGandhi Feb 16, 2026
90ce954
build: add rotations to extras
DhairyaLGandhi Feb 16, 2026
13001c1
test: simplify shapes for testing multiple strats togeteher
DhairyaLGandhi Feb 17, 2026
efbae92
test: check if optimization is eexpected
DhairyaLGandhi Feb 17, 2026
a7f2d2a
test: set cases with no expected optimization
DhairyaLGandhi Feb 17, 2026
60ddb96
test: default checking ortho to true
DhairyaLGandhi Feb 17, 2026
7e67eae
chore: rm debug code
DhairyaLGandhi Feb 17, 2026
6a85912
test: handle rotation matrix appropriate
DhairyaLGandhi Feb 17, 2026
0552c65
chore: rm IfElse red
DhairyaLGandhi Feb 17, 2026
34d5fad
chore: log only once
DhairyaLGandhi Feb 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,32 +1,40 @@
name = "SymbolicCompilerPasses"
uuid = "3384d301-0fbe-4b40-9ae0-b0e68bedb069"
version = "0.1.2"
authors = ["Dhairya Gandhi <dhairya@juliahub.com>"]
version = "0.1.2"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"

[extensions]
SCPLinearSolveExt = ["LinearSolve"]
SymbolicCompilerPassesRotationsExt = ["Rotations"]

[compat]
LinearAlgebra = "1"
DataStructures = "0.19.3"
LinearAlgebra = "1.11.0, 1.10"
LinearSolve = "3.53.0"
PreallocationTools = "0.4.34, 1"
Rotations = "1.7.1"
StaticArrays = "1.9.15, 1"
SymbolicUtils = "4.1.0"
Symbolics = "7.2"
julia = "1"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"

[targets]
test = ["Pkg", "SafeTestsets", "Test"]
test = ["Pkg", "SafeTestsets", "Test", "Symbolics", "Rotations"]
26 changes: 23 additions & 3 deletions ext/SCPLinearSolveExt/SCPLinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,37 @@ using SymbolicUtils
using SymbolicUtils.Code
using LinearSolve
using LinearAlgebra
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE, LINEARSOLVE_LIB
import SymbolicCompilerPasses: ldiv_transformation, SymbolicCompilerPasses, get_factorization, get_from_cache, FACTORIZATION_CACHE
using StaticArrays

__init__() = SymbolicCompilerPasses.LINEARSOLVE_LIB[] = true

const LINSOLVEPROB_CACHE = Dict()

function get_linear_prob(A::StaticArray, B::StaticArray)
prob = LinearSolve.LinearProblem(A, B)
end

function get_linear_prob(A::TA, B::TB) where {TA, TB}
get!(LINSOLVEPROB_CACHE, A) do
prob = LinearSolve.LinearProblem(A, B)
init(prob)
end# ::Base.promote_op(init, Tuple{Base.promote_op(LinearSolve.LinearProblem, Tuple{TA, TB})})
end

function linear_solve(A, B)
linsolve = get_factorization(A, B)
linsolve = get_linear_prob(A, B)
linsolve.b = B
sol = solve!(linsolve)
return sol.u
end

function linear_solve(A::StaticArray, B::StaticArray)
linsolve = get_linear_prob(A, B)
sol = solve(linsolve)
return sol.u
end

function get_factorization(A, B)
get!(FACTORIZATION_CACHE, A) do
prob = LinearSolve.LinearProblem(A, B)
Expand All @@ -25,7 +45,7 @@ end

function ldiv_transformation(safe_matches, ::Val{true})
@info "Using LinearSolve.jl for in-place backsolve optimizations.
In order to opt-out of using LinearSolve, set SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false." maxlog=Inf
In order to opt-out of using LinearSolve, set SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false." maxlog=1
# Build transformation
transformations = Dict{Int, Code.Assignment}()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module SymbolicCompilerPassesRotationsExt

using SymbolicCompilerPasses
import SymbolicCompilerPasses: is_orthogonal_type

using Rotations
is_orthogonal_type(::Rotations.Rotation) = true

end
7 changes: 6 additions & 1 deletion src/SymbolicCompilerPasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ using LinearAlgebra
using PreallocationTools
using SymbolicUtils
import SymbolicUtils: symtype, vartype, Sym, BasicSymbolic, Term, iscall, operation, arguments, maketerm, Const, shape, isterm, unwrap,
is_function_symbolic, is_called_function_symbolic, getname, Unknown
is_function_symbolic, is_called_function_symbolic, getname, Unknown, search_variables!, search_variables
import SymbolicUtils.Code: Code, OptimizationRule, substitute_in_ir, apply_optimization_rules, AbstractMatched,
Assignment, CSEState, lhs, rhs, apply_substitution_map, issym, isterm, toexpr,
_is_array_of_symbolics, MakeArray, shape
import SymbolicUtils: search_variables, search_variables!
using StaticArrays

using DataStructures

function bank(dic, key, value)
if haskey(dic, key)
dic[key] = vcat(dic[key], value)
Expand All @@ -20,8 +22,11 @@ function bank(dic, key, value)
end

include("matmuladd.jl")
include("ortho_inv_opt.jl")
include("hvncat_static_opt.jl")
include("ldiv_opt.jl")
include("la_opt.jl")

include("mb_opt.jl")

end # module SymbolicCompilerPasses
15 changes: 5 additions & 10 deletions src/hvncat_static_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,16 @@ function transform_hvncat_to_static(expr::Code.Let, match_data::Vector{HvncatMat
# Column vector: SVector{n}(elements...)
n = dims[1]
t = term(Core.apply_type, StaticArrays.SVector, n; type = Any)
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)
else
# Matrix: SMatrix{m,n}(elements...)
m, n = dims
t = term(Core.apply_type, StaticArrays.SMatrix, m, n; type = Any)
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)
end
static_ctor = Term{T}(
t,
elements;
type=symtype(lhs_var)
)

new_assignment = Assignment(lhs_var, static_ctor)
transformations[match.assignment_idx] = new_assignment
Expand Down
5 changes: 3 additions & 2 deletions src/ldiv_opt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
const FACTORIZATION_CACHE = WeakKeyDict()
const FACTORIZATION_CACHE = Dict()

struct LdivMatch{Ta, Tb, S <: Assignment, P <: AbstractString} <: AbstractMatched
A::Ta
Expand Down Expand Up @@ -149,6 +149,7 @@ function get_factorization(A)
qr_A = get!(FACTORIZATION_CACHE, A) do
qr(A)
end
# qr_A = qr(A)

qr_A
end
Expand All @@ -160,7 +161,7 @@ ldiv_transformation(x, ::Nothing) = ldiv_transformation(x, Val(false))
function ldiv_transformation(safe_matches, ::Val{false})
@warn "Backsolve may be sped up by adding LinearSolve.jl.
In order to enable this optimization, add LinearSolve.jl to your environment.
To opt-out of using LinearSolve, set SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false." maxlog=Inf
To opt-out of using LinearSolve, set SymbolicCompilerPasses.LINEARSOLVE_LIB[] = false." maxlog=1

# Build transformation
transformations = Dict{Int, Code.Assignment}()
Expand Down
63 changes: 63 additions & 0 deletions src/mb_opt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
function detect_small_views(expr::Code.Let, state)
matches = []
for (i, p) in enumerate(expr.pairs)
r = rhs(p)
iscall(r) || continue
if operation(r) === view
arr, inds... = arguments(r)
myt = find_term(inds[1], expr)
is_small_hvncat(size(Code.rhs(myt))...) || continue
push!(matches, (idx = i, expr = r))
end
end
matches
end

function construct_type(dims)
# if length(dims) == 1
# return Core.apply_type(SVector, dims[1])
# else
# return Core.apply_type(SVector, Tuple(dims))
# end
Core.apply_type(SVector, length(dims))
end

function find_term(target, expr::Code.Let)
filter(expr.pairs) do p
Code.lhs(p) === target
end |> only
end

function transform_view(expr, match_data, state)
new_pairs = []
idxs = Set(getproperty.(match_data, :idx))
transformations = Dict()
for match in match_data
idx = match.idx
r = match.expr
T = symtype(r)
V = vartype(r)
arr, inds... = arguments(r)
t = term(construct_type, inds[1])
transformations[idx] = Term{V}(t, [r], type = T)
end

for (i, p) in enumerate(expr.pairs)
if i in idxs
new_rhs = transformations[i]
push!(new_pairs, Code.Assignment(lhs(p), new_rhs))
else
push!(new_pairs, p)
end
end

Code.Let(new_pairs, expr.body, expr.let_block)
end


const MB_VIEW_RULE = OptimizationRule(
"MB_VIEW_RULE",
detect_small_views,
transform_view,
10,
)
101 changes: 101 additions & 0 deletions src/ortho_inv_opt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
is_orthogonal_matrix(A) = A * A' ≈ I(size(A, 1))
is_orthogonal_type(x) = begin
if issym(x)
return getmetadata(x, IsOrthogonal, false)
else
return false
end
false
end

struct IsOrthogonal end

struct InvMatched{T, C} <: AbstractMatched
A::T
candidate::C
idx::Int
end

function detect_orthogonal_matrix(expr, state::Code.CSEState)
# var_to_ortho = Dict()
# ortho_to_var = Dict()
# for v in search_variables(expr)
# @show symtype(v)
# bank(var_to_ortho, v, getmetadata(v, IsOrthogonal, false))
# bank(ortho_to_var, getmetadata(v, IsOrthogonal, false), v)
# end

# if !haskey(ortho_to_var, true)
# @warn "not found any metadata arrays"
# return nothing
# end

idxs = findall(expr.pairs) do p
r = rhs(p)
iscall(r) || return false
op = operation(r)
if op === inv
args = arguments(r)
length(args) == 1 || return false
getmetadata(args[1], IsOrthogonal, false) == true || return false
return true
end
false
end

candidates = expr.pairs[idxs]

matches = map(idxs, candidates) do idx, candidate
A = arguments(rhs(candidate))[1]
InvMatched(A, candidate, idx)
end

f = filter(!isnothing, matches)
isempty(f) ? nothing : f
end

function transform_inv_optimization(expr, matches, state::Code.CSEState)

expr_copy = deepcopy(expr)
map(matches) do match
A = match.A
if getmetadata(A, IsOrthogonal, false) == true
expr_copy.pairs[match.idx] = Code.Assignment(
lhs(match.candidate),
transpose(A)
)
else
t = term(is_orthogonal_type, A)
# code = IfElse(
# t,
# transpose(A),
# inv(A)
# )
code = ifelse(t, transpose(A), inv(A))
expr_copy.pairs[match.idx] = Code.Assignment(
lhs(match.candidate),
code
)
end
end
expr_copy
end

const ORTHO_INV_RULE = OptimizationRule(
"Ortho_Inv",
detect_orthogonal_matrix,
transform_inv_optimization,
10
)

function ortho_inv_opt(expr, state::Code.CSEState)

# Try to apply optimization rules
optimized = apply_optimization_rules(expr, state, [ORTHO_INV_RULE])
if optimized !== nothing
return optimized
end

# If no optimization applied, return original expression
return expr
end
Loading
Loading