Skip to content

Commit

Permalink
Remove type piracy and enable function wrapping (#505)
Browse files Browse the repository at this point in the history
Ignore MPS functions for now since MPS does not seem to have a dylib.
  • Loading branch information
christiangnrd authored Dec 24, 2024
1 parent b949b14 commit 6a760a6
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 20 deletions.
95 changes: 95 additions & 0 deletions lib/mtl/libmtl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,83 @@ struct MTLTextureSwizzleChannels
alpha::MTLTextureSwizzle
end

function MTLTextureSwizzleChannelsMake(r, g, b, a)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLTextureSwizzleChannelsMake(r::MTLTextureSwizzle,
g::MTLTextureSwizzle,
b::MTLTextureSwizzle,
a::MTLTextureSwizzle)::MTLTextureSwizzleChannels
end

struct MTLOrigin
x::NSUInteger
y::NSUInteger
z::NSUInteger
MTLOrigin(x=0, y=0, z=0) = new(x, y, z)
end

function MTLOriginMake(x, y, z)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLOriginMake(x::NSUInteger,
y::NSUInteger,
z::NSUInteger)::MTLOrigin
end

struct MTLSize
width::NSUInteger
height::NSUInteger
depth::NSUInteger
MTLSize(w=1, h=1, d=1) = new(w, h, d)
end

function MTLSizeMake(width, height, depth)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLSizeMake(width::NSUInteger,
height::NSUInteger,
depth::NSUInteger)::MTLSize
end

struct MTLRegion
origin::MTLOrigin
size::MTLSize
MTLRegion(origin=MTLOrigin(), size=MTLSize()) = new(origin, size)
end

function MTLRegionMake1D(x, width)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake1D(x::NSUInteger,
width::NSUInteger)::MTLRegion
end

function MTLRegionMake2D(x, y, width, height)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake2D(x::NSUInteger,
y::NSUInteger,
width::NSUInteger,
height::NSUInteger)::MTLRegion
end

function MTLRegionMake3D(x, y, z, width, height, depth)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLRegionMake3D(x::NSUInteger,
y::NSUInteger,
z::NSUInteger,
width::NSUInteger,
height::NSUInteger,
depth::NSUInteger)::MTLRegion
end

struct MTLSamplePosition
x::Cfloat
y::Cfloat
end

function MTLSamplePositionMake(x, y)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLSamplePositionMake(x::Cfloat,
y::Cfloat)::MTLSamplePosition
end

const MTLCoordinate2D = MTLSamplePosition

function MTLCoordinate2DMake(x, y)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLCoordinate2DMake(x::Cfloat,
y::Cfloat)::MTLCoordinate2D
end

struct MTLResourceID
_impl::UInt64
end
Expand Down Expand Up @@ -654,6 +704,13 @@ struct MTLClearColor
alpha::Cdouble
end

function MTLClearColorMake(red, green, blue, alpha)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLClearColorMake(red::Cdouble,
green::Cdouble,
blue::Cdouble,
alpha::Cdouble)::MTLClearColor
end

@cenum MTLLoadAction::UInt64 begin
MTLLoadActionDontCare = 0x0000000000000000
MTLLoadActionLoad = 0x0000000000000001
Expand Down Expand Up @@ -1137,13 +1194,26 @@ end

const MTLPackedFloat3 = _MTLPackedFloat3

function MTLPackedFloat3Make(x, y, z)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLPackedFloat3Make(x::Cfloat,
y::Cfloat,
z::Cfloat)::MTLPackedFloat3
end

struct MTLPackedFloatQuaternion
x::Cfloat
y::Cfloat
z::Cfloat
w::Cfloat
end

function MTLPackedFloatQuaternionMake(x, y, z, w)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLPackedFloatQuaternionMake(x::Cfloat,
y::Cfloat,
z::Cfloat,
w::Cfloat)::MTLPackedFloatQuaternion
end

struct _MTLPackedFloat4x3
columns::NTuple{4,MTLPackedFloat3}
end
Expand Down Expand Up @@ -1308,6 +1378,11 @@ struct MTLIndirectCommandBufferExecutionRange
length::UInt32
end

function MTLIndirectCommandBufferExecutionRangeMake(location, length)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIndirectCommandBufferExecutionRangeMake(location::UInt32,
length::UInt32)::MTLIndirectCommandBufferExecutionRange
end

@cenum MTLFunctionLogType::UInt64 begin
MTLFunctionLogTypeValidation = 0x0000000000000000
end
Expand Down Expand Up @@ -1395,3 +1470,23 @@ end
end

const MTLIOCompressionContext = Ptr{Cvoid}

function MTLIOCompressionContextDefaultChunkSize()
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCompressionContextDefaultChunkSize()::Csize_t
end

function MTLIOCreateCompressionContext(path, type, chunkSize)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCreateCompressionContext(path::Cstring,
type::MTLIOCompressionMethod,
chunkSize::Csize_t)::MTLIOCompressionContext
end

function MTLIOCompressionContextAppendData(context, data, size)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOCompressionContextAppendData(context::MTLIOCompressionContext,
data::Ptr{Cvoid},
size::Csize_t)::Cvoid
end

function MTLIOFlushAndDestroyCompressionContext(context)
@ccall (Symbol("/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib")).MTLIOFlushAndDestroyCompressionContext(context::MTLIOCompressionContext)::MTLIOCompressionStatus
end
11 changes: 10 additions & 1 deletion res/wrap/libmps.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ printer_blacklist = [
"CF.*",
"MTL.*",
"NS.*",
"BOOL"
"BOOL",
# Not sure how to access the MPS functions so don't wrap for now
"MPSDataTypeBitsCount",
"MPSSizeofMPSDataType",
"MPSSizeofMPSDataType",
"MPSFindIntegerDivisionParams",
"MPSGetCustomKernelMaxBatchSize",
"MPSGetCustomKernelBatchedDestinationIndex",
"MPSGetCustomKernelBatchedSourceIndex",
"MPSGetCustomKernelBroadcastSourceIndex",
]

[codegen]
Expand Down
2 changes: 1 addition & 1 deletion res/wrap/libmtl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[general]
library_name = "libmtl"
library_name = "Symbol(\"/System/Library/Frameworks/Metal.framework/Resources/BridgeSupport/Metal.dylib\")"
output_file_path = "../../lib/mtl/libmtl.jl"

generate_isystem_symbols = false
Expand Down
31 changes: 13 additions & 18 deletions res/wrap/wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ using Clang_jll
Clang_jll.libclang = "/Applications/Xcode.app/Contents/Frameworks/libclang.dylib"

using Clang.Generators
using Clang.Generators: LinkEnumAlias
using Clang
using Glob
using JLD2
Expand All @@ -13,11 +12,8 @@ using Logging
# Use system SDK
SDK_PATH = `xcrun --show-sdk-path` |> open |> readchomp |> String

# Hack to prevent printing of functions for now
Generators.skip_check(dag::Generators.ExprDAG, node::Generators.ExprNode{Generators.FunctionProto}) = true

main(name::AbstractString; kwargs...) = main([name]; kwargs...)
function main(names=["all"]; sdk_path=SDK_PATH)
function main(names::AbstractVector=["all"]; sdk_path=SDK_PATH)
path_to_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/",framework*".framework","Headers")
path_to_mps_framework(framework) = joinpath(sdk_path, "System/Library/Frameworks/","MetalPerformanceShaders.framework","Frameworks",framework*".framework","Headers")

Expand All @@ -43,16 +39,6 @@ function main(names=["all"]; sdk_path=SDK_PATH)
push!(ctxs, tctx)
end

# if "all" in names || "libfoundation" in names || "foundation" in names
# fwpath = path_to_framework("Foundation")
# tctx = wrap("libfoundation", joinpath(foundation, "Foundation.h");, defines=["__builtin_va_list"])
# push!(ctxs, tctx)
# end
# if "all" in names || "libcf" in names || "cf" in names
# fwpath = path_to_framework("CoreFoundation")
# tctx = wrap("libfoundation", joinpath(fwpath, "CoreFoundation.h");, defines=["__builtin_va_list"])
# push!(ctxs, tctx)
# end
return ctxs
end

Expand Down Expand Up @@ -119,16 +105,22 @@ function create_objc_context(headers::Vector, args::Vector=String[], options::Di
"/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain"
]

regen = if haskey(options, "general") && haskey(options["general"], "regenerate_dependent_headers")
options["general"]["regenerate_dependent_headers"]
else
false
end

# Since the framework we're wrapping is a system header,
# find all dependent headers, then remove all but the relevant ones
# also temporarily disable logging
dep_headers_fname = if haskey(options, "general") && haskey(options["general"], "library_name")
options["general"]["library_name"]*".JLD2"
splitext(splitpath(options["general"]["output_file_path"])[end])[1]*".JLD2"
else
nothing
end
Base.CoreLogging._min_enabled_level[] = Logging.Info+1
dependent_headers = if !isnothing(dep_headers_fname) && isfile(dep_headers_fname)
dependent_headers = if !regen && !isnothing(dep_headers_fname) && isfile(dep_headers_fname)
JLD2.load(dep_headers_fname, "dep_headers")
else
all_headers = find_dependent_headers(headers,args,[])
Expand All @@ -137,7 +129,10 @@ function create_objc_context(headers::Vector, args::Vector=String[], options::Di
target_framework = "/"*joinpath(Sys.splitpath(header)[end-2:end-1])
dep_headers = append!(dep_headers, filter(s -> occursin(target_framework, s), all_headers))
end
JLD2.@save dep_headers_fname dep_headers
if haskey(options, "general") && haskey(options["general"], "extra_target_headers")
append!(dep_headers, options["general"]["extra_target_headers"])
end
regen || JLD2.@save dep_headers_fname dep_headers
dep_headers
end
Base.CoreLogging._min_enabled_level[] = Logging.Debug
Expand Down

0 comments on commit 6a760a6

Please sign in to comment.