Skip to content

Commit

Permalink
Add input size checks (#241)
Browse files Browse the repository at this point in the history
* Add input checks in Python

* Add input checks in Julia

* Add input checks in R
  • Loading branch information
WardBrian authored Aug 29, 2024
1 parent 721ea5b commit 340cb04
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 48 deletions.
31 changes: 26 additions & 5 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ StanModel <- R6::R6Class("StanModel",
}
private$model <- ret$ptr_out

# pre-compute to avoid repeated work in bounds checks
private$unc_dims <- self$param_unc_num()

model_version <- self$model_version()
if (packageVersion("bridgestan") != paste(model_version$major, model_version$minor, model_version$patch, sep = ".")) {
warning(paste0("The version of the compiled model does not match the version of the R library. ",
Expand Down Expand Up @@ -167,6 +170,9 @@ StanModel <- R6::R6Class("StanModel",
} else {
rng_ptr <- as.raw(rng$ptr)
}
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
vars <- .C("bs_param_constrain_R", as.raw(private$model),
as.logical(include_tp), as.logical(include_gq), as.double(theta_unc),
theta = double(self$param_num(include_tp = include_tp, include_gq = include_gq)),
Expand Down Expand Up @@ -202,7 +208,7 @@ StanModel <- R6::R6Class("StanModel",
param_unconstrain = function(theta) {
vars <- .C("bs_param_unconstrain_R", as.raw(private$model),
as.double(theta),
theta_unc = double(self$param_unc_num()),
theta_unc = double(private$unc_dims),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
Expand All @@ -223,7 +229,7 @@ StanModel <- R6::R6Class("StanModel",
param_unconstrain_json = function(json) {
vars <- .C("bs_param_unconstrain_json_R", as.raw(private$model),
as.character(json),
theta_unc = double(self$param_unc_num()),
theta_unc = double(private$unc_dims),
return_code = as.integer(0),
err_msg = as.character(""),
err_ptr = raw(8),
Expand All @@ -241,6 +247,9 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return The log density.
log_density = function(theta_unc, propto = TRUE, jacobian = TRUE) {
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
vars <- .C("bs_log_density_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1),
Expand All @@ -262,7 +271,10 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density) and `gradient` (the gradient).
log_density_gradient = function(theta_unc, propto = TRUE, jacobian = TRUE) {
dims <- self$param_unc_num()
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
dims <- private$unc_dims
vars <- .C("bs_log_density_gradient_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1), gradient = double(dims),
Expand All @@ -284,7 +296,10 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density), `gradient` (the gradient), and `hessian` (the Hessian).
log_density_hessian = function(theta_unc, propto = TRUE, jacobian = TRUE) {
dims <- self$param_unc_num()
if (length(theta_unc) != private$unc_dims) {
stop("Incorrect number of unconstrained parameters.")
}
dims <- private$unc_dims
vars <- .C("bs_log_density_hessian_R", as.raw(private$model),
as.logical(propto), as.logical(jacobian), as.double(theta_unc),
val = double(1), gradient = double(dims), hess = double(dims * dims),
Expand All @@ -308,7 +323,12 @@ StanModel <- R6::R6Class("StanModel",
#' @param jacobian If `TRUE`, include change of variables terms for constrained parameters.
#' @return List containing entries `val` (the log density) and `Hvp` (the hessian-vector product).
log_density_hessian_vector_product = function(theta_unc, v, propto = TRUE, jacobian = TRUE){
dims <- self$param_unc_num()
dims <- private$unc_dims
if (length(theta_unc) != dims) {
stop("Incorrect number of unconstrained parameters.")
} else if (length(v) != dims) {
stop("Incorrect number of vector elements.")
}
vars <- .C("bs_log_density_hessian_vector_product_R",
as.raw(private$model), as.logical(propto), as.logical(jacobian),
as.double(theta_unc),
Expand All @@ -331,6 +351,7 @@ StanModel <- R6::R6Class("StanModel",
lib_name = NA,
model = NA,
seed = NA,
unc_dims = NA,
finalize = function() {
.C("bs_model_destruct_R",
as.raw(private$model),
Expand Down
2 changes: 2 additions & 0 deletions R/tests/testthat/test_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ test_that("param_constrain handles rng arguments", {

# require at least one present
expect_error(full$param_constrain(c(1.2), include_gq = TRUE), "rng must be provided")

expect_error(full$param_constrain(c(1.2, 1.2)), "Incorrect number of unconstrained parameters")
})


Expand Down
57 changes: 40 additions & 17 deletions julia/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ mutable struct StanModel
stanmodel::Ptr{StanModelStruct}
@const data::String
@const seed::UInt32
@const param_unc_num::Int

function StanModel(
lib::String,
Expand Down Expand Up @@ -85,7 +86,11 @@ mutable struct StanModel
error(handle_error(lib, err, "bs_model_construct"))
end

sm = new(lib, stanmodel, data, seed)
# compute now to avoid re-computing in bounds checks later
param_unc_num =
@ccall $(dlsym(lib, :bs_param_unc_num))(stanmodel::Ptr{StanModelStruct})::Cint

sm = new(lib, stanmodel, data, seed, param_unc_num)

function f(sm)
@ccall $(dlsym(sm.lib, :bs_model_destruct))(
Expand Down Expand Up @@ -279,7 +284,13 @@ function param_constrain!(
rng::Union{StanRNG,Nothing} = nothing,
)
dims = param_num(sm; include_tp = include_tp, include_gq = include_gq)
if length(out) != dims
if length(theta_unc) != sm.param_unc_num
throw(
DimensionMismatch(
"theta_unc must be same size as number of unconstrained parameters",
),
)
elseif length(out) != dims
throw(
DimensionMismatch("out must be same size as number of constrained parameters"),
)
Expand Down Expand Up @@ -359,8 +370,7 @@ The result is stored in the vector `out`, and a reference is returned. See
This is the inverse of [`param_constrain!`](@ref).
"""
function param_unconstrain!(sm::StanModel, theta::Vector{Float64}, out::Vector{Float64})
dims = param_unc_num(sm)
if length(out) != dims
if length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -396,7 +406,7 @@ re-using existing memory.
This is the inverse of [`param_constrain`](@ref).
"""
function param_unconstrain(sm::StanModel, theta::Vector{Float64})
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
param_unconstrain!(sm, theta, out)
end

Expand All @@ -411,8 +421,7 @@ The result is stored in the vector `out`, and a reference is returned. See
[`param_unconstrain_json`](@ref) for a version which allocates fresh memory.
"""
function param_unconstrain_json!(sm::StanModel, theta::String, out::Vector{Float64})
dims = param_unc_num(sm)
if length(out) != dims
if length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -445,7 +454,7 @@ See [`param_unconstrain_json!`](@ref) for a version which allows
re-using existing memory.
"""
function param_unconstrain_json(sm::StanModel, theta::String)
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
param_unconstrain_json!(sm, theta, out)
end

Expand Down Expand Up @@ -498,8 +507,11 @@ function log_density_gradient!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out) != dims
if length(q) != sm.param_unc_num
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -541,7 +553,7 @@ function log_density_gradient(
propto::Bool = true,
jacobian::Bool = true,
)
grad = zeros(param_unc_num(sm))
grad = zeros(sm.param_unc_num)
log_density_gradient!(sm, q, grad; propto = propto, jacobian = jacobian)
end

Expand All @@ -565,8 +577,12 @@ function log_density_hessian!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out_grad) != dims
dims = sm.param_unc_num
if length(q) != dims
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(out_grad) != dims
throw(
DimensionMismatch(
"out_grad must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -615,7 +631,7 @@ function log_density_hessian(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
dims = sm.param_unc_num
grad = zeros(dims)
hess = zeros(dims * dims)
log_density_hessian!(sm, q, grad, hess; propto = propto, jacobian = jacobian)
Expand All @@ -641,8 +657,15 @@ function log_density_hessian_vector_product!(
propto::Bool = true,
jacobian::Bool = true,
)
dims = param_unc_num(sm)
if length(out) != dims
if length(q) != sm.param_unc_num
throw(
DimensionMismatch("q must be same size as number of unconstrained parameters"),
)
elseif length(v) != sm.param_unc_num
throw(
DimensionMismatch("v must be same size as number of unconstrained parameters"),
)
elseif length(out) != sm.param_unc_num
throw(
DimensionMismatch(
"out must be same size as number of unconstrained parameters",
Expand Down Expand Up @@ -687,7 +710,7 @@ function log_density_hessian_vector_product(
propto::Bool = true,
jacobian::Bool = true,
)
out = zeros(param_unc_num(sm))
out = zeros(sm.param_unc_num)
log_density_hessian_vector_product!(sm, q, v, out; propto = propto, jacobian = jacobian)
end

Expand Down
14 changes: 14 additions & 0 deletions julia/test/model_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ end


model2 = load_test_model("full", false)
a = randn(BridgeStan.param_unc_num(model2))
rng = StanRNG(model2, 1234)
@test 1 == length(BridgeStan.param_constrain(model2, a))
@test 2 == length(BridgeStan.param_constrain(model2, a; include_tp = true))
Expand Down Expand Up @@ -392,6 +393,12 @@ end
jacobian = true,
)

y_unc_bad = zeros(length(y_unc) + 1)
@test_throws DimensionMismatch BridgeStan.log_density_gradient(model, y_unc_bad)

y_unc_bad = zeros(length(y_unc) - 1)
@test_throws DimensionMismatch BridgeStan.log_density_gradient(model, y_unc_bad)

end

@testset "log_density_hessian" begin
Expand Down Expand Up @@ -473,6 +480,13 @@ end
jacobian = true,
)


y_unc_bad = zeros(length(y_unc) + 1)
@test_throws DimensionMismatch BridgeStan.log_density_hessian(model, y_unc_bad)

y_unc_bad = zeros(length(y_unc) - 1)
@test_throws DimensionMismatch BridgeStan.log_density_hessian(model, y_unc_bad)

end

end
Expand Down
18 changes: 12 additions & 6 deletions python/bridgestan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def __init__(

num_params = self._param_unc_num(self.model)

param_sized_array = array_ptr(
dtype=ctypes.c_double,
flags=("C_CONTIGUOUS",),
shape=(num_params,),
)

param_sized_out_array = array_ptr(
dtype=ctypes.c_double,
flags=("C_CONTIGUOUS", "WRITEABLE"),
Expand Down Expand Up @@ -227,7 +233,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
writeable_double_array,
ctypes.c_void_p,
star_star_char,
Expand Down Expand Up @@ -257,7 +263,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
star_star_char,
]
Expand All @@ -268,7 +274,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
star_star_char,
Expand All @@ -280,7 +286,7 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
param_sqrd_sized_out_array,
Expand All @@ -293,8 +299,8 @@ def __init__(
ctypes.c_void_p,
ctypes.c_bool,
ctypes.c_bool,
double_array,
double_array,
param_sized_array,
param_sized_array,
ctypes.POINTER(ctypes.c_double),
param_sized_out_array,
star_star_char,
Expand Down
Loading

0 comments on commit 340cb04

Please sign in to comment.