Skip to content

Commit

Permalink
Add checks on version numbers (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed May 1, 2024
1 parent b51c415 commit 70a1f64
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 8 deletions.
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## Features
- [x] Add nicer ability to build models from source in the languages
- [x] download source if needed, similar to bridgestan
- [ ] Version checking
- [x] Version checking
- [ ] Fixed param sampler for 0 dimension parameters?
- [ ] Add wrapper around generate quantities method?
- [x] Add wraper around laplace sampling?
Expand Down
2 changes: 2 additions & 0 deletions clients/R/R/download.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
current_version <- packageVersion("tinystan")
current_version_list <- list(major = current_version$major, minor = current_version$minor,
patch = current_version$patch)
HOME_TINYSTAN <- path.expand(file.path("~", ".tinystan"))
CURRENT_TINYSTAN <- file.path(HOME_TINYSTAN, paste0("tinystan-", current_version))

Expand Down
20 changes: 20 additions & 0 deletions clients/R/R/tinystan.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ tinystan_model = function(lib, stanc_args = NULL, make_args = NULL, warn = TRUE)
dyn.load(lib, PACKAGE = lib_name)
sep <- .C("tinystan_separator_char_R", sep = raw(1), PACKAGE = lib_name)$sep
sep <- rawToChar(sep)

api_ver = .C("tinystan_api_version", major = integer(1), minor = integer(1),
patch = integer(1), PACKAGE = lib_name)

if (api_ver$major != current_version_list$major) {
msg = paste0("Incompatible TinyStan API version. Expected ", paste(current_version_list,
collapse = "."), " but found ", paste(api_ver, collapse = "."), ".\nYou need to re-compile your model.")
stop(msg)
} else if (api_ver$minor != current_version_list$minor || api_ver$patch != current_version_list$patch) {
msg = paste0("TinyStan API version mismatch. Expected ", paste(current_version_list,
collapse = "."), " but found ", paste(api_ver, collapse = "."), ".\nYou may need to re-compile your model.")
warning(msg)
}

ret <- list(lib = lib, lib_name = lib_name, sep = sep, code = code, built_with_so = built_with_so)
class(ret) <- c("tinystan_model", class(ret))
return(ret)
Expand All @@ -52,6 +66,12 @@ api_version = function(stan_model) {
PACKAGE = stan_model$lib_name)
}

#'@export
stan_version = function(stan_model) {
.C("tinystan_stan_version", major = integer(1), minor = integer(1), patch = integer(1),
PACKAGE = stan_model$lib_name)
}

#' @noRd
with_model = function(model, data, seed, block) {
ffi_ret <- .C("tinystan_create_model_R", model_ptr = raw(8), as.character(data),
Expand Down
9 changes: 8 additions & 1 deletion clients/R/tests/testthat/test_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,12 @@ test_that("model loads", {
})

test_that("api_version is correct", {
expect_equal(api_version(bernoulli_model), list(major = 0, minor = 1, patch = 0))
expect_equal(api_version(bernoulli_model), current_version_list)
})

test_that("stan version is valid", {
stan_version <- stan_version(bernoulli_model)
expect_equal(stan_version$major, 2)
expect_gte(stan_version$minor, 34)
expect_gte(stan_version$patch, 0)
})
1 change: 1 addition & 0 deletions clients/julia/src/TinyStan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export Model,
OptimizationAlgorithm,
laplace_sample,
api_version,
stan_version,
compile_model,
get_tinystan_path,
set_tinystan_path!
Expand Down
30 changes: 29 additions & 1 deletion clients/julia/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,27 @@ mutable struct Model

windows_dll_path_setup()
lib = dlopen(libname)
sep = Char(@ccall $(dlsym(lib, :tinystan_separator_char))()::Cchar)

major, minor, patch = Ref{Cint}(), Ref{Cint}(), Ref{Cint}()
@ccall $(dlsym(lib, :tinystan_api_version))(
major::Ref{Cint},
minor::Ref{Cint},
patch::Ref{Cint},
)::Cvoid
api_ver = VersionNumber(major[], minor[], patch[])
if api_ver.major != TinyStan.pkg_version.major
error(
"Incompatible TinyStan API version. " *
"Expected $(TinyStan.pkg_version) but got $api_ver.\n" *
"You need to re-compile your model.",
)
elseif api_ver != TinyStan.pkg_version
@warn "TinyStan API version does not match. " *
"Expected $(TinyStan.pkg_version) but got $api_ver.\n" *
"You may need to re-compile your model."
end

sep = Char(@ccall $(dlsym(lib, :tinystan_separator_char))()::Cchar)
new(lib, sep)
end

Expand Down Expand Up @@ -165,6 +184,15 @@ function api_version(model::Model)
(major[], minor[], patch[])
end

function stan_version(model::Model)
major, minor, patch = Ref{Cint}(), Ref{Cint}(), Ref{Cint}()
@ccall $(dlsym(model.lib, :tinystan_stan_version))(
major::Ref{Cint},
minor::Ref{Cint},
patch::Ref{Cint},
)::Cvoid
(major[], minor[], patch[])
end

"""
sample(model::Model, data::String=""; num_chains::Int=4, inits::Union{nothing,AbstractString,AbstractArray{AbstractString}}=nothing, seed::Union{Nothing,UInt32}=nothing, id::Int=1, init_radius=2.0, num_warmup::Int=1000, num_samples::Int=1000, metric::HMCMetric=DIAGONAL, init_inv_metric::Union{Nothing,Array{Float64}}=nothing, save_metric::Bool=false, adapt::Bool=true, delta::Float64=0.8, gamma::Float64=0.05, kappa::Float64=0.75, t0::Int=10, init_buffer::Int=75, term_buffer::Int=50, window::Int=25, save_warmup::Bool=false, stepsize::Float64=1.0, stepsize_jitter::Float64=0.0, max_depth::Int=10, refresh::Int=0, num_threads::Int=-1)
Expand Down
9 changes: 8 additions & 1 deletion clients/julia/test/test_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
end

@testset "API version" begin
@test api_version(bernoulli_model) == (0, 1, 0)
@test VersionNumber(api_version(bernoulli_model)) == TinyStan.pkg_version
end

@testset "Stan version" begin
ver = stan_version(bernoulli_model)
@test ver[1] == 2
@test ver[2] >= 34
@test ver[3] >= 0
end

end
10 changes: 9 additions & 1 deletion clients/python/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,12 @@ def test_model_loads():

def test_api_version():
model = tinystan.Model(STAN_FOLDER / "bernoulli" / "bernoulli_model.so")
assert model.api_version() == (0, 1, 0)
assert model.api_version() == tinystan.__version.__version_info__


def test_stan_version():
model = tinystan.Model(STAN_FOLDER / "bernoulli" / "bernoulli_model.so")
stan_version = model.stan_version()
assert stan_version[0] == 2
assert stan_version[1] >= 34
assert stan_version[2] >= 0
37 changes: 34 additions & 3 deletions clients/python/tinystan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numpy.ctypeslib import ndpointer
from stanio import dump_stan_json

from .__version import __version_info__
from .compile import compile_model, windows_dll_path_setup
from .output import StanOutput
from .util import validate_readable
Expand Down Expand Up @@ -197,6 +198,28 @@ def __init__(

self._lib = ctypes.CDLL(self.lib_path)

self._version = self._lib.tinystan_api_version
self._version.restype = None
self._version.argtypes = [
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
]

api_ver = self.api_version()
if api_ver[0] != __version_info__[0]:
raise RuntimeError(
"Incompatible TinyStan API version. Expected "
f"{__version_info__} but got {api_ver}.\n"
"You need to re-compile your model."
)
if api_ver != __version_info__:
warnings.warn(
"TinyStan API version does not match. Expected "
f"{__version_info__} but got {api_ver}.\n"
"You may need to re-compile your model."
)

self._create_model = self._lib.tinystan_create_model
self._create_model.restype = ctypes.c_void_p
self._create_model.argtypes = [ctypes.c_char_p, ctypes.c_uint, err_ptr]
Expand All @@ -213,9 +236,9 @@ def __init__(
self._num_free_params.restype = ctypes.c_size_t
self._num_free_params.argtypes = [ctypes.c_void_p]

self._version = self._lib.tinystan_api_version
self._version.restype = None
self._version.argtypes = [
self._stan_version = self._lib.tinystan_stan_version
self._stan_version.restype = None
self._stan_version.argtypes = [
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
Expand Down Expand Up @@ -394,6 +417,14 @@ def api_version(self):
self._version(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
return (major.value, minor.value, patch.value)

def stan_version(self):
"""Return the version of Stan backing this model."""
major, minor, patch = ctypes.c_int(), ctypes.c_int(), ctypes.c_int()
self._stan_version(
ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)
)
return (major.value, minor.value, patch.value)

def sample(
self,
data: StanData = "",
Expand Down

0 comments on commit 70a1f64

Please sign in to comment.