From c20629b2f870256ac810dd00336725784d141b8b Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 26 Apr 2024 12:23:55 -0400 Subject: [PATCH] Julia: Use more compact @ccall (#28) * Julia: Use more compact @ccall * Prefer Ref over Ptr for non-nullable arguments --- clients/julia/src/model.jl | 364 ++++++++++++------------------------- 1 file changed, 120 insertions(+), 244 deletions(-) diff --git a/clients/julia/src/model.jl b/clients/julia/src/model.jl index b3a5f08..a292f90 100644 --- a/clients/julia/src/model.jl +++ b/clients/julia/src/model.jl @@ -1,3 +1,4 @@ +using Base.Libc.Libdl: dlsym, dlopen, dllist """ Choices for the structure of the mass matrix used in the HMC sampler. @@ -72,7 +73,7 @@ mutable struct Model libname = model end - if warn && in(abspath(libname), Libc.Libdl.dllist()) + if warn && in(abspath(libname), dllist()) @warn "Loading a shared object '" * libname * "' which is already loaded.\n" * @@ -80,8 +81,8 @@ mutable struct Model end windows_dll_path_setup() - lib = Libc.Libdl.dlopen(libname) - sep = Char(ccall(Libc.Libdl.dlsym(lib, :tinystan_separator_char), Cchar, ())) + lib = dlopen(libname) + sep = Char(@ccall $(dlsym(lib, :tinystan_separator_char))()::Cchar) new(lib, sep) end @@ -95,20 +96,12 @@ function raise_for_error(lib::Ptr{Nothing}, return_code::Cint, err::Ref{Ptr{Cvoi if err[] == C_NULL error("Unknown error, function returned code $return_code") end - cstr = ccall( - Libc.Libdl.dlsym(lib, :tinystan_get_error_message), - Cstring, - (Ptr{Cvoid},), - err[], - ) + + cstr = @ccall $(dlsym(lib, :tinystan_get_error_message))(err[]::Ptr{Cvoid})::Cstring + msg = unsafe_string(cstr) - type = ccall( - Libc.Libdl.dlsym(lib, :tinystan_get_error_type), - Cint, - (Ptr{Cvoid},), - err[], - ) - ccall(Libc.Libdl.dlsym(lib, :tinystan_destroy_error), Cvoid, (Ptr{Cvoid},), err[]) + type = @ccall $(dlsym(lib, :tinystan_get_error_type))(err[]::Ptr{Cvoid})::Cint + @ccall $(dlsym(lib, :tinystan_destroy_error))(err[]::Ptr{Cvoid})::Cvoid exn = exceptions[type+1] throw(exn(msg)) end @@ -128,47 +121,33 @@ end function with_model(f, model::Model, data::String, seed::UInt32) err = Ref{Ptr{Cvoid}}() - model_ptr = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_create_model), - Ptr{Cvoid}, - (Cstring, Cuint, Ref{Ptr{Cvoid}}), - data, - seed, - err, - ) + model_ptr = @ccall $(dlsym(model.lib, :tinystan_create_model))( + data::Cstring, + seed::Cuint, + err::Ref{Ptr{Cvoid}}, + )::Ptr{Cvoid} raise_for_error(model.lib, Int32(model_ptr == C_NULL), err) try return f(model_ptr) finally - ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_destroy_model), - Cvoid, - (Ptr{Cvoid},), - model_ptr, - ) + @ccall $(dlsym(model.lib, :tinystan_destroy_model))(model_ptr::Ptr{Cvoid})::Cvoid end end function num_free_params(model::Model, model_ptr::Ptr{Cvoid}) Int( - ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_model_num_free_params), - Csize_t, - (Ptr{Cvoid},), - model_ptr, - ), + @ccall $(dlsym(model.lib, :tinystan_model_num_free_params))( + model_ptr::Ptr{Cvoid}, + )::Csize_t ) end function get_names(model::Model, model_ptr::Ptr{Cvoid}) - cstr = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_model_param_names), - Cstring, - (Ptr{Cvoid},), - model_ptr, - ) + cstr = @ccall $(dlsym(model.lib, :tinystan_model_param_names))( + model_ptr::Ptr{Cvoid}, + )::Cstring str = unsafe_string(cstr) if isempty(str) return String[] @@ -178,14 +157,11 @@ end function api_version(model::Model) major, minor, patch = Ref{Cint}(), Ref{Cint}(), Ref{Cint}() - ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_api_version), - Cvoid, - (Ptr{Cint}, Ptr{Cint}, Ptr{Cint}), - major, - minor, - patch, - ) + @ccall $(dlsym(model.lib, :tinystan_api_version))( + major::Ref{Cint}, + minor::Ref{Cint}, + patch::Ref{Cint}, + )::Cvoid (major[], minor[], patch[]) end @@ -288,68 +264,37 @@ function sample( end err = Ref{Ptr{Cvoid}}() - return_code = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_sample), - Cint, - ( - Ptr{Cvoid}, - Csize_t, - Cstring, - Cuint, - Cuint, - Cdouble, - Cint, - Cint, - HMCMetric, - Ptr{Cdouble}, - Bool, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cuint, - Cuint, - Cuint, - Bool, - Cdouble, - Cdouble, - Cint, - Cint, - Cint, - Ref{Cdouble}, - Csize_t, - Ptr{Cdouble}, - Ref{Ptr{Cvoid}}, - ), - model_ptr, - num_chains, - encode_inits(model.sep, inits), - seed, - id, - init_radius, - num_warmup, - num_samples, - metric, - init_inv_metric, - adapt, - delta, - gamma, - kappa, - t0, - init_buffer, - term_buffer, - window, - save_warmup, - stepsize, - stepsize_jitter, - max_depth, - refresh, - num_threads, - out, - length(out), - metric_out, - err, - ) + return_code = @ccall $(dlsym(model.lib, :tinystan_sample))( + model_ptr::Ptr{Cvoid}, + num_chains::Csize_t, + encode_inits(model.sep, inits)::Cstring, + seed::Cuint, + id::Cuint, + init_radius::Cdouble, + num_warmup::Cint, + num_samples::Cint, + metric::HMCMetric, + init_inv_metric::Ptr{Cdouble}, + adapt::Bool, + delta::Cdouble, + gamma::Cdouble, + kappa::Cdouble, + t0::Cdouble, + init_buffer::Cint, + term_buffer::Cint, + window::Cint, + save_warmup::Bool, + stepsize::Cdouble, + stepsize_jitter::Cdouble, + max_depth::Cint, + refresh::Cint, + num_threads::Cint, + out::Ref{Cdouble}, + length(out)::Csize_t, + metric_out::Ptr{Cdouble}, + err::Ref{Ptr{Cvoid}}, + )::Cint + raise_for_error(model.lib, return_code, err) out = permutedims(out, (3, 2, 1)) if save_metric @@ -427,60 +372,32 @@ function pathfinder( out = zeros(Float64, num_params, num_output) err = Ref{Ptr{Cvoid}}() - return_code = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_pathfinder), - Cint, - ( - Ptr{Cvoid}, - Csize_t, - Cstring, - Cuint, - Cuint, - Cdouble, - Cint, - Cint, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cint, - Cint, - Cint, - Bool, - Bool, - Cint, - Cint, - Ref{Cdouble}, - Csize_t, - Ref{Ptr{Cvoid}}, - ), - model_ptr, - num_paths, - encode_inits(model.sep, inits), - seed, - id, - init_radius, - num_draws, - max_history_size, - init_alpha, - tol_obj, - tol_rel_obj, - tol_grad, - tol_rel_grad, - tol_param, - num_iterations, - num_elbo_draws, - num_multi_draws, - calculate_lp, - psis_resample, - refresh, - num_threads, - out, - length(out), - err, - ) + return_code = @ccall $(dlsym(model.lib, :tinystan_pathfinder))( + model_ptr::Ptr{Cvoid}, + num_paths::Csize_t, + encode_inits(model.sep, inits)::Cstring, + seed::Cuint, + id::Cuint, + init_radius::Cdouble, + num_draws::Cint, + max_history_size::Cint, + init_alpha::Cdouble, + tol_obj::Cdouble, + tol_rel_obj::Cdouble, + tol_grad::Cdouble, + tol_rel_grad::Cdouble, + tol_param::Cdouble, + num_iterations::Cint, + num_elbo_draws::Cint, + num_multi_draws::Cint, + calculate_lp::Bool, + psis_resample::Bool, + refresh::Cint, + num_threads::Cint, + out::Ref{Cdouble}, + length(out)::Csize_t, + err::Ref{Ptr{Cvoid}}, + )::Cint raise_for_error(model.lib, return_code, err) return (param_names, transpose(out)) @@ -530,56 +447,32 @@ function optimize( out = zeros(Float64, num_params) err = Ref{Ptr{Cvoid}}() - return_code = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_optimize), - Cint, - ( - Ptr{Cvoid}, - Cstring, - Cuint, - Cuint, - Cdouble, - Cint, # really enum - Cint, - Bool, - Cint, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cdouble, - Cint, - Cint, - Ref{Cdouble}, - Csize_t, - Ref{Ptr{Cvoid}}, - ), - model_ptr, + return_code = @ccall $(dlsym(model.lib, :tinystan_optimize))( + model_ptr::Ptr{Cvoid}, if init === nothing C_NULL else init - end, - seed, - id, - init_radius, - algorithm, - num_iterations, - jacobian, - max_history_size, - init_alpha, - tol_obj, - tol_rel_obj, - tol_grad, - tol_rel_grad, - tol_param, - refresh, - num_threads, - out, - length(out), - err, - ) + end::Cstring, + seed::Cuint, + id::Cuint, + init_radius::Cdouble, + algorithm::OptimizationAlgorithm, + num_iterations::Cint, + jacobian::Bool, + max_history_size::Cint, + init_alpha::Cdouble, + tol_obj::Cdouble, + tol_rel_obj::Cdouble, + tol_grad::Cdouble, + tol_rel_grad::Cdouble, + tol_param::Cdouble, + refresh::Cint, + num_threads::Cint, + out::Ref{Cdouble}, + length(out)::Csize_t, + err::Ref{Ptr{Cvoid}}, + )::Cint raise_for_error(model.lib, return_code, err) return (param_names, out) end @@ -642,38 +535,21 @@ function laplace_sample( end err = Ref{Ptr{Cvoid}}() - return_code = ccall( - Libc.Libdl.dlsym(model.lib, :tinystan_laplace_sample), - Cint, - ( - Ptr{Cvoid}, - Ptr{Cdouble}, - Cstring, - Cuint, - Cint, - Bool, - Bool, - Cint, - Cint, - Ref{Cdouble}, - Csize_t, - Ptr{Cdouble}, - Ref{Ptr{Cvoid}}, - ), - model_ptr, - mode_array, - mode_json, - seed, - num_draws, - jacobian, - calculate_lp, - refresh, - num_threads, - out, - length(out), - hessian_out, - err, - ) + return_code = @ccall $(dlsym(model.lib, :tinystan_laplace_sample))( + model_ptr::Ptr{Cvoid}, + mode_array::Ptr{Cdouble}, + mode_json::Cstring, + seed::Cuint, + num_draws::Cint, + jacobian::Bool, + calculate_lp::Bool, + refresh::Cint, + num_threads::Cint, + out::Ref{Cdouble}, + length(out)::Csize_t, + hessian_out::Ptr{Cdouble}, + err::Ref{Ptr{Cvoid}}, + )::Cint raise_for_error(model.lib, return_code, err) if save_hessian