From 2575787ce07b5d6cface105d764c5cc33adcef42 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 2 Mar 2025 23:58:46 -0600 Subject: [PATCH 1/4] feat: overlay `error` with a custom_call --- src/Ops.jl | 39 ++++++++++++++++++++++++++++----------- src/Overlay.jl | 10 +++++++++- src/mlir/IR/Attribute.jl | 2 +- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index b49ca654c0..1f5f02adad 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -568,7 +568,7 @@ end Tuple(rsize) end else - error("Invalid FFT type: $type") + Base.error("Invalid FFT type: $type") end res = MLIR.IR.result( @@ -841,7 +841,7 @@ end sizea = Dict(c => d for (c, d) in zip(ia, size(lhs))) sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs))) sizes = mergewith(sizea, sizeb) do da, db - da == db ? da : error("Invalid dimensions in einsum equation") + da == db ? da : Base.error("Invalid dimensions in einsum equation") end rsize = Tuple(sizes[i] for i in ic) @@ -921,7 +921,7 @@ end elseif typ <: TracedRNumber return typ((), res) else - error("Invalid type: $typ") + Base.error("Invalid type: $typ") end end, ) @@ -1074,9 +1074,7 @@ end @assert fn_name == "comparator" "$comparator: no function generated" ftype_attr = MLIR.IR.attr(func, "function_type") ftype = MLIR.IR.Type(ftype_attr) - @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( - "$comparator return type is not tensor" - ) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) "$comparator return type is not tensor" comparator = MLIR.IR.Region() MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) @@ -1495,7 +1493,7 @@ julia> Reactant.@jit( end if isnothing(fn) - error("hlo_call: could not find function $func_name in the provided module") + Base.error("hlo_call: could not find function $func_name in the provided module") end ftype_attr = MLIR.IR.attr(fn, "function_type") @@ -1774,7 +1772,7 @@ end end end if isnothing(path) - error("if_condition: could not find path for linear arg $i") + Base.error("if_condition: could not find path for linear arg $i") end Reactant.TracedUtils.set_mlir_data!( arg, @@ -1837,7 +1835,7 @@ end end end if isnothing(path) - error("if_condition: could not find path for linear arg $i") + Base.error("if_condition: could not find path for linear arg $i") end Reactant.TracedUtils.set_mlir_data!( arg, @@ -2063,7 +2061,7 @@ end corrected_traced_results = fmap(traced_false_results, traced_true_results) do fr, tr if fr isa MissingTracedValue && tr isa MissingTracedValue - error("Both false and true branches are missing") + Base.error("Both false and true branches are missing") elseif fr isa MissingTracedValue return tr else @@ -2088,7 +2086,7 @@ end end end if isnothing(argpath) - error("if_condition: could not find path for resarg $path") + Base.error("if_condition: could not find path for resarg $path") end Reactant.TracedUtils.set!(args, argpath, MLIR.IR.result(if_compiled, residx)) end @@ -2313,4 +2311,23 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi end end +@noinline function throw( + msg::String, + value::Union{TracedRNumber{Bool},Nothing}=nothing; + location=mlir_stacktrace("throw", @__FILE__, @__LINE__), +) + value === nothing && + (value = Reactant.TracedUtils.promote_to(TracedRNumber{Bool}, true)) + + return stablehlo.custom_call( + MLIR.IR.Value[value.mlir_data]; + result_0=MLIR.IR.Type[], + has_side_effect=true, + call_target_name="throw", + backend_config=MLIR.IR.Attribute(Dict("error_message" => MLIR.IR.Attribute(msg))), + api_version=MLIR.IR.Attribute(Int32(4)), + location, + ) +end + end # module Ops diff --git a/src/Overlay.jl b/src/Overlay.jl index 5d9b85c838..97a6daec3d 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -154,8 +154,16 @@ end ## fixes #493 @reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon) if use_overlayed_version(A) - error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.") + Base.inferencebarrier(error)( + "Reactant doesn't have a `Base._unique_dims` with the current interpreter." + ) else Base.inferencebarrier(Base._unique_dims)(A, dims) end end + +## `Base.error` --> `Ops.throw` +## XXX: `throw` is a Core method. Can it be overlayed? +@reactant_overlay @noinline function Base.error(msg::String) + return Ops.throw(msg) +end diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index d7aac00830..ab25c2eb47 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -96,7 +96,7 @@ isdict(attr::Attribute) = API.mlirAttributeIsADictionary(attr) Creates a dictionary attribute containing the given list of elements in the provided context. """ function Attribute(attrs::Dict; context::Context=context()) - attrs = map(splat(NamedAttribute), attrs) + attrs = [NamedAttribute(k, v) for (k, v) in attrs] return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), attrs)) end From bd8e335e4cf819bf7f0baa75dfe1d1dbd39ad228 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Mar 2025 10:26:14 -0600 Subject: [PATCH 2/4] fix: don't overlay for now --- src/Ops.jl | 16 ++++++++-------- src/Overlay.jl | 6 ------ 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 1f5f02adad..3cffdc3f10 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -568,7 +568,7 @@ end Tuple(rsize) end else - Base.error("Invalid FFT type: $type") + error("Invalid FFT type: $type") end res = MLIR.IR.result( @@ -841,7 +841,7 @@ end sizea = Dict(c => d for (c, d) in zip(ia, size(lhs))) sizeb = Dict(c => d for (c, d) in zip(ib, size(rhs))) sizes = mergewith(sizea, sizeb) do da, db - da == db ? da : Base.error("Invalid dimensions in einsum equation") + da == db ? da : error("Invalid dimensions in einsum equation") end rsize = Tuple(sizes[i] for i in ic) @@ -921,7 +921,7 @@ end elseif typ <: TracedRNumber return typ((), res) else - Base.error("Invalid type: $typ") + error("Invalid type: $typ") end end, ) @@ -1493,7 +1493,7 @@ julia> Reactant.@jit( end if isnothing(fn) - Base.error("hlo_call: could not find function $func_name in the provided module") + error("hlo_call: could not find function $func_name in the provided module") end ftype_attr = MLIR.IR.attr(fn, "function_type") @@ -1772,7 +1772,7 @@ end end end if isnothing(path) - Base.error("if_condition: could not find path for linear arg $i") + error("if_condition: could not find path for linear arg $i") end Reactant.TracedUtils.set_mlir_data!( arg, @@ -1835,7 +1835,7 @@ end end end if isnothing(path) - Base.error("if_condition: could not find path for linear arg $i") + error("if_condition: could not find path for linear arg $i") end Reactant.TracedUtils.set_mlir_data!( arg, @@ -2061,7 +2061,7 @@ end corrected_traced_results = fmap(traced_false_results, traced_true_results) do fr, tr if fr isa MissingTracedValue && tr isa MissingTracedValue - Base.error("Both false and true branches are missing") + error("Both false and true branches are missing") elseif fr isa MissingTracedValue return tr else @@ -2086,7 +2086,7 @@ end end end if isnothing(argpath) - Base.error("if_condition: could not find path for resarg $path") + error("if_condition: could not find path for resarg $path") end Reactant.TracedUtils.set!(args, argpath, MLIR.IR.result(if_compiled, residx)) end diff --git a/src/Overlay.jl b/src/Overlay.jl index 97a6daec3d..982d31f9cd 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -161,9 +161,3 @@ end Base.inferencebarrier(Base._unique_dims)(A, dims) end end - -## `Base.error` --> `Ops.throw` -## XXX: `throw` is a Core method. Can it be overlayed? -@reactant_overlay @noinline function Base.error(msg::String) - return Ops.throw(msg) -end From df5644ae9986f7cac4b2b15cef149e11f529c1d3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Mar 2025 13:49:01 -0600 Subject: [PATCH 3/4] feat: register ffi --- deps/ReactantExtra/API.cpp | 23 +++++++++++++++++++++++ src/Ops.jl | 24 +++++++++++++++++------- src/Overlay.jl | 4 +--- 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 0e87f25856..3657b13058 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -137,6 +137,12 @@ #include "llvm/Support/ExtensibleRTTI.h" +// XLA FFI +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" +#include "xla/service/custom_call_target_registry.h" + using namespace mlir; using namespace llvm; using namespace xla; @@ -2281,3 +2287,20 @@ ifrt_loaded_executable_num_devices(ifrt::LoadedExecutable *exec) { } #pragma endregion + +// Register an XLA FFI for throwing runtime errors +xla::ffi::Error xla_throw_error(xla::ffi::BufferR0 cond, + std::string_view message) { + if (cond.typed_data()[0]) + return xla::ffi::Error(xla::ffi::ErrorCode::kInternal, + std::string(message)); + return xla::ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(xla_throw_error_handler, xla_throw_error, + xla::ffi::Ffi::Bind() + .Arg>() + .Attr("message")); + +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "xla_throw_error", "Host", + xla_throw_error_handler); diff --git a/src/Ops.jl b/src/Ops.jl index 3cffdc3f10..7f80a50744 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2311,20 +2311,30 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi end end +""" + throw( + msg::String, + condition::Union{TracedRNumber{Bool},Nothing}=nothing; + location=mlir_stacktrace("throw", @__FILE__, @__LINE__) + ) + +Throw a runtime error with the given `msg` if `condition` is `true`. If `condition` is not provided, it defaults to `true`. +""" @noinline function throw( msg::String, - value::Union{TracedRNumber{Bool},Nothing}=nothing; - location=mlir_stacktrace("throw", @__FILE__, @__LINE__), + condition::Union{TracedRNumber{Bool},Nothing}=nothing; + location=mlir_stacktrace("throw", @__FILE__, @__LINE__) ) - value === nothing && - (value = Reactant.TracedUtils.promote_to(TracedRNumber{Bool}, true)) + if condition === nothing + condition = Reactant.TracedUtils.promote_to(TracedRNumber{Bool}, true) + end return stablehlo.custom_call( - MLIR.IR.Value[value.mlir_data]; + MLIR.IR.Value[condition.mlir_data]; result_0=MLIR.IR.Type[], has_side_effect=true, - call_target_name="throw", - backend_config=MLIR.IR.Attribute(Dict("error_message" => MLIR.IR.Attribute(msg))), + call_target_name="xla_throw_error", + backend_config=MLIR.IR.Attribute(Dict("message" => MLIR.IR.Attribute(msg))), api_version=MLIR.IR.Attribute(Int32(4)), location, ) diff --git a/src/Overlay.jl b/src/Overlay.jl index 982d31f9cd..5d9b85c838 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -154,9 +154,7 @@ end ## fixes #493 @reactant_overlay @noinline function Base._unique_dims(A::AbstractArray, dims::Colon) if use_overlayed_version(A) - Base.inferencebarrier(error)( - "Reactant doesn't have a `Base._unique_dims` with the current interpreter." - ) + error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.") else Base.inferencebarrier(Base._unique_dims)(A, dims) end From c85a7cbbfda893c2db1e3e8c1087678753e3c7c8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 3 Mar 2025 13:55:09 -0600 Subject: [PATCH 4/4] feat: add cuda handler --- deps/ReactantExtra/API.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 3657b13058..99db424c50 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -2304,3 +2304,22 @@ XLA_FFI_DEFINE_HANDLER(xla_throw_error_handler, xla_throw_error, XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "xla_throw_error", "Host", xla_throw_error_handler); + +#ifdef REACTANT_CUDA +#include "third_party/gpus/cuda/include/cuda.h" + +xla::ffi::Error xla_throw_error(cudaStream_t stream, + xla::ffi::BufferR0 cond, + std::string_view message) { + return xla_throw_error(cond, message); +} + +XLA_FFI_DEFINE_HANDLER(xla_throw_error_handler_cuda, xla_throw_error, + xla::ffi::Ffi::Bind() + .Ctx>() + .Arg>() + .Attr("message")); + +XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "xla_throw_error", "CUDA", + xla_throw_error_handler_cuda); +#endif