Skip to content

Commit 417c361

Browse files
feat: more linear algebra operations (non factorization changes) (#1883)
* feat: add cheaper det/logabsdet if triangular * feat: lowering of inv * fix: accidental promotion in norm * feat: support cross * feat: symmetric/hermitian/banded check * feat: normalize/normalize\! * Apply suggestion from @avik-pal * feat: cholesky decomposition (#1884) * refactor: move LU into a separate file * feat: lower cholesky * feat: lowering cholesky ldiv * test: cholesky * fix: revert change to is symm/hermitian * fix: overload normalize * test: use low cond number matrix * Update test/integration/linear_algebra.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5ef447d commit 417c361

File tree

9 files changed

+653
-228
lines changed

9 files changed

+653
-228
lines changed

src/Overlay.jl

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -230,36 +230,42 @@ end
230230
end
231231

232232
# LinearAlgebra
233-
@reactant_overlay @noinline function LinearAlgebra.lu(x::AbstractArray; kwargs...)
234-
if use_overlayed_version(x)
235-
return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
236-
else
237-
return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...)
238-
end
239-
end
240-
@reactant_overlay @noinline function LinearAlgebra.lu(
241-
x::AbstractArray, pivot::RowMaximum; kwargs...
242-
)
243-
if use_overlayed_version(x)
244-
return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
245-
else
246-
return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...)
247-
end
248-
end
249-
@reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...)
250-
if use_overlayed_version(x)
251-
return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...)
252-
else
253-
return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...)
254-
end
255-
end
256-
@reactant_overlay @noinline function LinearAlgebra.lu!(
257-
x::AbstractArray, pivot::RowMaximum; kwargs...
233+
## Various factorizations
234+
## TODO: specialize for `cholesky!` --> cholcopy
235+
factorization_copy(f::F, x, pivot) where {F} = x
236+
factorization_copy(f::F, x) where {F} = x
237+
238+
for (jlop, rop, default_pivot) in (
239+
(:lu, :overloaded_lu, RowMaximum),
240+
(:lu!, :overloaded_lu, RowMaximum),
241+
(:cholesky, :overloaded_cholesky, NoPivot),
242+
(:cholesky!, :overloaded_cholesky, NoPivot),
258243
)
259-
if use_overlayed_version(x)
260-
return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...)
261-
else
262-
return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...)
244+
@eval begin
245+
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
246+
x::AbstractArray; kwargs...
247+
)
248+
if use_overlayed_version(x)
249+
pivot = $(default_pivot)()
250+
return TracedLinearAlgebra.$(rop)(
251+
factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
252+
)
253+
else
254+
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x; kwargs...)
255+
end
256+
end
257+
258+
@reactant_overlay @noinline function LinearAlgebra.$(jlop)(
259+
x::AbstractArray, pivot::$(default_pivot); kwargs...
260+
)
261+
if use_overlayed_version(x)
262+
return TracedLinearAlgebra.$(rop)(
263+
factorization_copy(LinearAlgebra.$(jlop), x, pivot), pivot; kwargs...
264+
)
265+
else
266+
return Base.inferencebarrier(LinearAlgebra.$(jlop))(x, pivot; kwargs...)
267+
end
268+
end
263269
end
264270
end
265271

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module Reactant
33
using ReactantCore:
44
ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array
55

6-
using LinearAlgebra: LinearAlgebra, RowMaximum
6+
using LinearAlgebra: LinearAlgebra, RowMaximum, NoPivot
77
using Random: Random, AbstractRNG
88
using EnumX: @enumx
99
using Functors: Functors, @leaf

src/TracedRArray.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -730,11 +730,20 @@ end
730730

731731
# stack
732732
function overloaded_stack(dims::Union{Integer,Colon}, xs)
733-
@assert allequal([ndims(x) for x in xs]) "All arrays must have the same number of \
734-
dimensions..."
735-
dims = dims isa Colon ? ndims(first(xs)) + 1 : dims
733+
dims = dims isa Colon ? nothing : dims
736734
res = []
737-
for x in xs
735+
prev_dims = nothing
736+
for x in unwrapped_broadcast(identity, xs)
737+
cur_dims = ndims(x)
738+
if prev_dims === nothing
739+
prev_dims = cur_dims
740+
else
741+
@assert prev_dims == cur_dims "All arrays must have the same number of \
742+
dimensions..."
743+
end
744+
745+
dims === nothing && (dims = cur_dims + 1)
746+
738747
new_shape = ntuple(
739748
i -> i == dims ? 1 : (i < dims ? size(x, i) : size(x, i - 1)), ndims(x) + 1
740749
)

src/TracedRNumber.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,25 @@ Base.copy(x::TracedRNumber{T}) where {T} = TracedRNumber{T}((), x.mlir_data)
2626
function Base.eps(::Type{TracedRNumber{T}}) where {T}
2727
return Reactant.promote_to(TracedRNumber{T}, eps(T))
2828
end
29+
Base.eps(x::TracedRNumber{T}) where {T} = eps(typeof(x))
2930

3031
function Base.typemin(::Type{TracedRNumber{T}}) where {T}
3132
return Reactant.promote_to(TracedRNumber{T}, typemin(T))
3233
end
34+
Base.typemin(x::TracedRNumber{T}) where {T} = typemin(typeof(x))
35+
3336
function Base.typemax(::Type{TracedRNumber{T}}) where {T}
3437
return Reactant.promote_to(TracedRNumber{T}, typemax(T))
3538
end
39+
Base.typemax(x::TracedRNumber{T}) where {T} = typemax(typeof(x))
40+
41+
function Base.nextfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
42+
return @opcall next_after(x, typemax(x))
43+
end
44+
45+
function Base.prevfloat(x::TracedRNumber{T}) where {T<:AbstractFloat}
46+
return @opcall next_after(x, typemin(x))
47+
end
3648

3749
function Base.rtoldefault(T::Type{<:TracedRNumber})
3850
return T(Base.rtoldefault(unwrapped_eltype(T)))

0 commit comments

Comments
 (0)