diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 57177dd6..d285c2d0 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -11,6 +11,8 @@ Changelog](https://keepachangelog.com). ([5a1cc29](https://github.com/JuliaInterop/Clang.jl/commit/5a1cc29c154ed925f01e59dfd705cbf8042158e4)). - Added bindings for Clang 17, which should allow compatibility with Julia 1.12 ([#494]). +- Experimental support for automatically dereferencing struct fields in + `Base.getproperty()` with the `auto_field_dereference` option ([#502]). ### Fixed diff --git a/gen/generator.toml b/gen/generator.toml index 08cbfe04..b3029e6c 100644 --- a/gen/generator.toml +++ b/gen/generator.toml @@ -181,6 +181,19 @@ wrap_variadic_function = false # generate getproperty/setproperty! methods for the types in the following list field_access_method_list = [] +# EXPERIMENTAL: +# By default the getproperty!(x::Ptr, ::Symbol) methods created for wrapped +# types will return pointers (Ptr{T}) to the struct fields. That behaviour is +# useful for accessing nested struct fields but it does require explicitly +# calling unsafe_load() every time. When enabled this option will automatically +# call unsafe_load() for you *except on nested struct fields and arrays*, which +# should make explicitly calling unsafe_load() unnecessary in most cases. A @ptr +# macro will be defined for cases where you really do want a pointer to a field +# (e.g. for writing), which supports syntax like `@ptr(foo.bar)`. +# +# This should be used with `field_access_method_list`. +auto_field_dereference = false + # the generator will prefix the function argument names in the following list with a "_" to # prevent the generated symbols from conflicting with the symbols defined and exported in Base. function_argument_conflict_symbols = [] diff --git a/src/generator/codegen.jl b/src/generator/codegen.jl index be91290f..56eb67f7 100644 --- a/src/generator/codegen.jl +++ b/src/generator/codegen.jl @@ -296,19 +296,20 @@ end ############################### Struct ############################### -function _emit_getproperty_ptr!(body, root_cursor, cursor, options) +function _emit_pointer_access!(body, root_cursor, cursor, options) field_cursors = fields(getCursorType(cursor)) field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors for field_cursor in field_cursors n = name(field_cursor) if isempty(n) - _emit_getproperty_ptr!(body, root_cursor, field_cursor, options) + _emit_pointer_access!(body, root_cursor, field_cursor, options) continue end fsym = make_symbol_safe(n) fty = getCursorType(field_cursor) ty = translate(tojulia(fty), options) offset = getOffsetOf(getCursorType(root_cursor), n) + if isBitField(field_cursor) w = getFieldDeclBitWidth(field_cursor) @assert w <= 32 # Bit fields should not be larger than int(32 bits) @@ -322,12 +323,74 @@ function _emit_getproperty_ptr!(body, root_cursor, cursor, options) end end -# Base.getproperty(x::Ptr, f::Symbol) -> Ptr +# getptr(x::Ptr, f::Symbol) -> Ptr +function emit_getptr!(dag, node, options) + sym = make_symbol_safe(node.id) + signature = Expr(:call, :getptr, :(x::Ptr{$sym}), :(f::Symbol)) + body = Expr(:block) + _emit_pointer_access!(body, node.cursor, node.cursor, options) + + # The default field access exception changed to FieldError in 1.12 + throw_expr = :( + @static if VERSION >= v"1.12.0-DEV" + throw(FieldError($sym, f)) + else + error($("Unrecognized field of type `$sym`") * ": $f") + end + ) + Base.remove_linenums!(throw_expr) + throw_expr.args[2] = nothing # Remove the sticky LineNumberNode from the macro + + push!(body.args, throw_expr) + push!(node.exprs, Expr(:function, signature, body)) + return dag +end + +function emit_deref_getproperty!(body, root_cursor, cursor, options) + field_cursors = fields(getCursorType(cursor)) + field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors + for field_cursor in field_cursors + n = name(field_cursor) + if isempty(n) + emit_deref_getproperty!(body, root_cursor, field_cursor, options) + continue + end + fsym = make_symbol_safe(n) + fty = getCursorType(field_cursor) + canonical_type = getCanonicalType(fty) + + return_expr = :(getptr(x, f)) + + # Automatically dereference all field types except for nested structs + # and arrays. + if !(canonical_type isa Union{CLRecord, CLConstantArray}) && !isBitField(field_cursor) + return_expr = :(unsafe_load($return_expr)) + elseif isBitField(field_cursor) + return_expr = :(getbitfieldproperty(x, $return_expr)) + end + + ex = :(f === $(QuoteNode(fsym)) && return $return_expr) + push!(body.args, ex) + end +end + +# Base.getproperty(x::Ptr, f::Symbol) function emit_getproperty_ptr!(dag, node, options) + auto_deref = get(options, "auto_field_dereference", false) sym = make_symbol_safe(node.id) + + # If automatically dereferencing, we first need to emit getptr!() + if auto_deref + emit_getptr!(dag, node, options) + end + signature = Expr(:call, :(Base.getproperty), :(x::Ptr{$sym}), :(f::Symbol)) body = Expr(:block) - _emit_getproperty_ptr!(body, node.cursor, node.cursor, options) + if auto_deref + emit_deref_getproperty!(body, node.cursor, node.cursor, options) + else + _emit_pointer_access!(body, node.cursor, node.cursor, options) + end push!(body.args, :(return getfield(x, f))) getproperty_expr = Expr(:function, signature, body) push!(node.exprs, getproperty_expr) @@ -345,45 +408,24 @@ end function emit_getproperty!(dag, node, options) sym = make_symbol_safe(node.id) - ref_expr = :(r = Ref{$sym}(x)) - conv_expr = :(ptr = Base.unsafe_convert(Ptr{$sym}, r)) - fptr_expr = :(fptr = getproperty(ptr, f)) - - load_expr = :(GC.@preserve r unsafe_load(fptr)) - load_expr.args[2] = nothing + # Build the macrocall manually so we can set the extra LineNumberNode to + # nothing to stop it from being printed. + return_expr = :(GC.@preserve r getproperty(ptr, f)) + return_expr.args[2] = nothing - load_base_expr = :(GC.@preserve r unsafe_load(baseptr32)) - load_base_expr.args[2] = nothing - load_next_expr = :(GC.@preserve r unsafe_load(baseptr32 + 4)) - load_next_expr.args[2] = nothing - - if is_bitfield_type(node.type) - ex = quote - if fptr isa Ptr - return $load_expr - else - baseptr, offset, width = fptr - ty = eltype(baseptr) - baseptr32 = convert(Ptr{UInt32}, baseptr) - u64 = $load_base_expr - if offset + width > 32 - u64 |= ($load_next_expr) << 32 - end - u64 = (u64 >> offset) & ((1 << width) - 1) - return u64 % ty - end + ex = quote + function Base.getproperty(x::$sym, f::Symbol) + r = Ref{$sym}(x) + ptr = Base.unsafe_convert(Ptr{$sym}, r) + return $return_expr end - else - ex = load_expr end + # Remove line number nodes and the enclosing :block node rm_line_num_node!(ex) + ex = ex.args[1] - signature = Expr(:call, :(Base.getproperty), :(x::$sym), :(f::Symbol)) - body = Expr(:block, ref_expr, conv_expr, fptr_expr, ex) - getproperty_expr = Expr(:function, signature, body) - - push!(node.exprs, getproperty_expr) + push!(node.exprs, ex) return dag end @@ -391,27 +433,19 @@ end function emit_setproperty!(dag, node, options) sym = make_symbol_safe(node.id) signature = Expr(:call, :(Base.setproperty!), :(x::Ptr{$sym}), :(f::Symbol), :v) - store_expr = :(unsafe_store!(getproperty(x, f), v)) + + auto_deref = get(options, "auto_field_dereference", false) + pointer_getter = auto_deref ? :getptr : :getproperty + store_expr = :(unsafe_store!($pointer_getter(x, f), v)) + if is_bitfield_type(node.type) body = quote - fptr = getproperty(x, f) + fptr = $pointer_getter(x, f) if fptr isa Ptr $store_expr else - baseptr, offset, width = fptr - baseptr32 = convert(Ptr{UInt32}, baseptr) - u64 = unsafe_load(baseptr32) - straddle = offset + width > 32 - if straddle - u64 |= unsafe_load(baseptr32 + 4) << 32 - end - mask = ((1 << width) - 1) - u64 &= ~(mask << offset) - u64 |= (unsigned(v) & mask) << offset - unsafe_store!(baseptr32, u64 & typemax(UInt32)) - if straddle - unsafe_store!(baseptr32 + 4, u64 >> 32) - end + # setbitfieldproperty!() is emitted by ProloguePrinter + setbitfieldproperty!(fptr, v) end end rm_line_num_node!(body) @@ -431,7 +465,7 @@ function get_names_types(root_cursor, cursor, options) for field_cursor in field_cursors n = name(field_cursor) if isempty(n) - _emit_getproperty_ptr!(root_cursor, field_cursor, options) + _emit_pointer_access!(root_cursor, field_cursor, options) continue end fsym = make_symbol_safe(n) diff --git a/src/generator/passes.jl b/src/generator/passes.jl index a4d49ebf..2abb4e9b 100644 --- a/src/generator/passes.jl +++ b/src/generator/passes.jl @@ -1094,6 +1094,7 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict) use_native_enum = get(general_options, "use_julia_native_enum_type", false) print_CEnum = get(general_options, "print_using_CEnum", true) wrap_variadic_function = get(codegen_options, "wrap_variadic_function", false) + auto_deref = get(codegen_options, "auto_field_dereference", false) show_info && @info "[ProloguePrinter]: print to $(x.file)" open(x.file, "w") do io @@ -1132,6 +1133,75 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict) """) end + # Print the bitfield helpers if there are any bitfield structs + if any(is_bitfield_type(node.type) for node in dag.nodes) + # These expressions include macrocalls, which are oddly clingy to + # their LineNumberNode's (rm_line_num_node!() will give wrong + # results). Instead we create the expressions explicitly and remove + # the LineNumberNode's by setting their 2nd argument to nothing. + u64_expr = :(GC.@preserve obj_handle unsafe_load(baseptr32)) + u64_expr.args[2] = nothing + ptrload_expr = :(GC.@preserve obj_handle unsafe_load(baseptr32 + 4)) + ptrload_expr.args[2] = nothing + + get_expr = quote + function getbitfieldproperty(obj_handle, bitfield_info) + baseptr, offset, width = bitfield_info + ty = eltype(baseptr) + baseptr32 = convert(Ptr{UInt32}, baseptr) + u64 = $u64_expr + if offset + width > 32 + u64 |= ($ptrload_expr) << 32 + end + u64 = (u64 >> offset) & ((1 << width) - 1) + return u64 % ty + end + end + + set_expr = quote + function setbitfieldproperty!(bitfield_info, value) + baseptr, offset, width = bitfield_info + baseptr32 = convert(Ptr{UInt32}, baseptr) + u64 = unsafe_load(baseptr32) + straddle = offset + width > 32 + if straddle + u64 |= unsafe_load(baseptr32 + 4) << 32 + end + mask = ((1 << width) - 1) + u64 &= ~(mask << offset) + u64 |= (unsigned(value) & mask) << offset + unsafe_store!(baseptr32, u64 & typemax(UInt32)) + if straddle + unsafe_store!(baseptr32 + 4, u64 >> 32) + end + end + end + + # Remove line number nodes and the extra :block node + rm_line_num_node!(get_expr) + rm_line_num_node!(set_expr) + get_expr = get_expr.args[1] + set_expr = set_expr.args[1] + + println(io, string(get_expr), "\n") + println(io, string(set_expr), "\n") + end + + if auto_deref + println(io, raw""" + macro ptr(expr) + if !Meta.isexpr(expr, :.) + error("Expression is not a property access, cannot use @ptr on it.") + end + + quote + local penultimate_obj = $(esc(expr.args[1])) + getptr(penultimate_obj, $(esc(expr.args[2]))) + end + end + """) + end + # print prelogue patches if !isempty(prologue_file_path) println(io, read(prologue_file_path, String)) diff --git a/test/generators.jl b/test/generators.jl index 9f72c7e0..3d094549 100644 --- a/test/generators.jl +++ b/test/generators.jl @@ -249,3 +249,109 @@ end @test docstring_has("callback") end end + +@testset "Struct getproperty()/setproperty!()" begin + options = Dict("general" => Dict{String, Any}("auto_mutability" => true, + "auto_mutability_with_new" => false, + "auto_mutability_includelist" => ["WithFields"]), + "codegen" => Dict{String, Any}("field_access_method_list" => ["WithFields", "Other"])) + + # Test the default getproperty()/setproperty!() behaviour + mktemp() do path, io + options["general"]["output_file_path"] = path + ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options) + build!(ctx) + + println(read(path, String)) + + m = Module() + Base.include(m, path) + + # We now have to run in the latest world to use the new definitions + Base.invokelatest() do + obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1)) + + GC.@preserve obj begin + obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj)) + + # The default getproperty() should basically always return a + # pointer to the field (except for bitfields, which are tested + # elsewhere). + @test obj_ptr.int_value isa Ptr{Cint} + @test obj_ptr.int_ptr isa Ptr{Ptr{Cint}} + @test obj_ptr.struct_value isa Ptr{m.Other} + @test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct} + @test obj_ptr.array isa Ptr{NTuple{2, Cint}} + + # Sanity test + int_value = unsafe_load(obj_ptr.int_value) + @test int_value == obj.int_value + + # Test setproperty!() + obj_ptr.int_value = int_value + 1 + @test unsafe_load(obj_ptr.int_value) == int_value + 1 + end + end + end + + # Test the auto_field_dereference option + mktemp() do path, io + options["general"]["output_file_path"] = path + options["codegen"]["auto_field_dereference"] = true + ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options) + build!(ctx) + + println(read(path, String)) + + m = Module() + Base.include(m, path) + + # We now have to run in the latest world to use the new definitions + Base.invokelatest() do + obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1)) + + GC.@preserve obj begin + obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj)) + + # Test getproperty() + @test obj_ptr.int_value isa Cint + @test obj_ptr.int_value == obj.int_value + @test obj_ptr.int_ptr isa Ptr{Cint} + + @test obj_ptr.struct_value isa Ptr{m.Other} + @test obj_ptr.struct_value.i == obj.struct_value.i + @test obj_ptr.struct_ptr isa Ptr{m.Other} + @test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct} + + @test obj_ptr.array isa Ptr{NTuple{2, Cint}} + + field_exception_t = @static if VERSION >= v"1.12.0-DEV" + FieldError + else + ErrorException + end + @test_throws field_exception_t obj_ptr.foo + + # Test @ptr + val_ptr = @eval m @ptr $obj_ptr.int_value + @test val_ptr isa Ptr{Cint} + int_ptr = @eval m @ptr $obj_ptr.int_ptr + @test int_ptr isa Ptr{Ptr{Cint}} + + @test_throws LoadError (@eval m @ptr $obj_ptr) + @test_throws field_exception_t (@eval m @ptr $obj_ptr.foo) + + # Test setproperty!() + new_value = obj.int_value * 2 + obj_ptr.int_value = new_value + @test obj.int_value == new_value + + new_value = obj.struct_value.i * 2 + obj_ptr.struct_value.i = new_value + @test obj.struct_value.i == new_value + + @test_throws field_exception_t obj_ptr.foo = 1 + end + end + end +end diff --git a/test/include/struct-properties.h b/test/include/struct-properties.h new file mode 100644 index 00000000..d836eba1 --- /dev/null +++ b/test/include/struct-properties.h @@ -0,0 +1,18 @@ +typedef struct { + int i; +} TypedefStruct; + +struct Other { + int i; +}; + +struct WithFields { + int int_value; + int* int_ptr; + + struct Other struct_value; + struct Other* struct_ptr; + TypedefStruct typedef_struct_value; + + int array[2]; +}; diff --git a/test/runtests.jl b/test/runtests.jl index acef7399..4fb038d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,10 @@ using Clang using Test +# Temporary hack to make @doc work in 1.11 for the documentation tests. See: +# https://github.com/JuliaLang/julia/issues/54664 +using REPL + include("jllenvs.jl") include("file.jl") include("generators.jl") diff --git a/test/test_bitfield.jl b/test/test_bitfield.jl index 51bc6da1..fec17ec8 100644 --- a/test/test_bitfield.jl +++ b/test/test_bitfield.jl @@ -36,8 +36,10 @@ function build_libbitfield_binarybuilder() success = true try cd(@__DIR__) do + # Delete any old products and rebuild + rm("products"; force=true, recursive=true) run(`$(Base.julia_cmd()) --project bitfield/build_tarballs.jl`) - # from Pkg.download_verify_unpack + # Note that we filter out the extra log file that's generated tarball_path = only(filter(!contains("-logs.v"), readdir("products"))) dest = "build" @@ -59,21 +61,8 @@ function build_libbitfield() error("Could not build libbitfield binary") end - # Generate wrappers - @info "Building libbitfield wrapper" - args = get_default_args() - headers = joinpath(@__DIR__, "build", "include", "bitfield.h") - options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml")) - lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield") - options["general"]["library_name"] = "\"$(escape_string(lib_path))\"" - options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl") - ctx = create_context(headers, args, options) - build!(ctx) - - # Call a function to ensure build is successful - include("LibBitField.jl") - m = Base.@invokelatest LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3) - Base.@invokelatest LibBitField.toBitfield(Ref(m)) + # Test the binary + generate_wrappers(false) catch e @warn "Building libbitfield failed: $e" success = false @@ -81,22 +70,51 @@ function build_libbitfield() return success end +function generate_wrappers(auto_deref::Bool) + @info "Building libbitfield wrapper" + args = get_default_args() + headers = joinpath(@__DIR__, "build", "include", "bitfield.h") + options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml")) + options["codegen"]["auto_field_dereference"] = auto_deref + options["codegen"]["field_access_method_list"] = ["BitField"] + + lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield") + options["general"]["library_name"] = "\"$(escape_string(lib_path))\"" + options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl") + ctx = create_context(headers, args, options) + build!(ctx) + + # Call a function to ensure build is successful + anonmod = Module() + Base.include(anonmod, "LibBitField.jl") + m = Base.@invokelatest anonmod.LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3) + Base.@invokelatest anonmod.LibBitField.toBitfield(Ref(m)) + return anonmod +end @testset "Bitfield" begin if build_libbitfield() - bf = Ref(LibBitField.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3))) - m = Ref(LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)) - GC.@preserve bf m begin - pbf = Ptr{LibBitField.BitField}(pointer_from_objref(bf)) - pm = Ptr{LibBitField.Mirror}(pointer_from_objref(m)) - @test LibBitField.toMirror(bf) == m[] - @test LibBitField.toBitfield(m).a == bf[].a - @test LibBitField.toBitfield(m).b == bf[].b - @test LibBitField.toBitfield(m).c == bf[].c - @test LibBitField.toBitfield(m).d == bf[].d - @test LibBitField.toBitfield(m).e == bf[].e - @test LibBitField.toBitfield(m).f == bf[].f + # Test the wrappers with and without auto-dereferencing. In the case of + # bitfields they should have identical behaviour. + for auto_deref in [false, true] + anonmod = generate_wrappers(auto_deref) + lib = anonmod.LibBitField + + bf = Ref(lib.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3))) + m = Ref(lib.Mirror(10, 1.5, 1e6, -4, 7, 3)) + + GC.@preserve bf m begin + pbf = Ptr{lib.BitField}(pointer_from_objref(bf)) + pm = Ptr{lib.Mirror}(pointer_from_objref(m)) + @test lib.toMirror(bf) == m[] + @test lib.toBitfield(m).a == bf[].a + @test lib.toBitfield(m).b == bf[].b + @test lib.toBitfield(m).c == bf[].c + @test lib.toBitfield(m).d == bf[].d + @test lib.toBitfield(m).e == bf[].e + @test lib.toBitfield(m).f == bf[].f + end end end end