Skip to content

Commit

Permalink
Misc. cleanups (#191)
Browse files Browse the repository at this point in the history
* Use byref where applicable

* Properly capitalize BridgeStan
  • Loading branch information
WardBrian authored Dec 5, 2023
1 parent 4d2ee05 commit d37a860
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 43 deletions.
2 changes: 1 addition & 1 deletion R/R/compile.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ get_bridgestan_path <- function() {
tryCatch({
verify_bridgestan_path(path)
}, error = function(e) {
print(paste0("Bridgestan not found at location specified by $BRIDGESTAN ",
print(paste0("BridgeStan not found at location specified by $BRIDGESTAN ",
"environment variable, downloading version ", packageVersion("bridgestan"),
" to ", path))
get_bridgestan_src()
Expand Down
2 changes: 1 addition & 1 deletion R/R/download.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ get_bridgestan_src <- function() {

dir.create(HOME_BRIDGESTAN, showWarnings = FALSE, recursive = TRUE)
temp <- tempfile()
err_text <- paste("Failed to download Bridgestan", current_version, "from github.com.")
err_text <- paste("Failed to download BridgeStan", current_version, "from github.com.")
for (i in 1:RETRIES) {
tryCatch({
download.file(url, destfile = temp, mode = "wb", quiet = TRUE, method = "auto")
Expand Down
11 changes: 5 additions & 6 deletions python/bridgestan/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def set_bridgestan_path(path: Union[str, os.PathLike]) -> None:
os.environ["BRIDGESTAN"] = path


def get_bridgestan_path():
def get_bridgestan_path() -> str:
"""
Get the path to BridgeStan.
Expand All @@ -66,22 +66,22 @@ def get_bridgestan_path():
verify_bridgestan_path(path)
except ValueError:
print(
"Bridgestan not found at location specified by $BRIDGESTAN "
"BridgeStan not found at location specified by $BRIDGESTAN "
f"environment variable, downloading version {__version__} to {path}"
)
get_bridgestan_src()
num_files = len(list(HOME_BRIDGESTAN.iterdir()))
if num_files >= 5:
warnings.warn(
f"Found {num_files} different versions of Bridgestan in {HOME_BRIDGESTAN}. "
f"Found {num_files} different versions of BridgeStan in {HOME_BRIDGESTAN}. "
"Consider deleting old versions to save space."
)
print("Done!")

return path


def generate_so_name(model: Path):
def generate_so_name(model: Path) -> Path:
name = model.stem
return model.with_stem(f"{name}_model").with_suffix(".so")

Expand Down Expand Up @@ -138,7 +138,7 @@ def compile_model(
return output


def windows_dll_path_setup():
def windows_dll_path_setup() -> None:
"""Add tbb.dll to %PATH% on Windows."""
global WINDOWS_PATH_SET
if IS_WINDOWS and not WINDOWS_PATH_SET:
Expand Down Expand Up @@ -180,7 +180,6 @@ def windows_dll_path_setup():
os.path.dirname(out.stdout.decode().splitlines()[0])
)
os.add_dll_directory(mingw_dir)
WINDOWS_PATH_SET &= True
except:
# no default location
warnings.warn(
Expand Down
4 changes: 2 additions & 2 deletions python/bridgestan/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RETRIES = 5


def get_bridgestan_src():
def get_bridgestan_src() -> None:
"""
Download and unzip the BridgeStan source distribution for this version
Expand All @@ -24,7 +24,7 @@ def get_bridgestan_src():
)
HOME_BRIDGESTAN.mkdir(exist_ok=True)

err_text = f"Failed to download Bridgestan {__version__} from github.com."
err_text = f"Failed to download BridgeStan {__version__} from github.com."
for i in range(1, 1 + RETRIES):
try:
file_tmp, _ = urllib.request.urlretrieve(url, filename=None)
Expand Down
86 changes: 53 additions & 33 deletions python/bridgestan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ def __init__(
if capture_stan_prints:
self._set_print_callback(_print_callback, None)

err = ctypes.pointer(ctypes.c_char_p())
self.model = self._construct(str.encode(self.data), self.seed, err)
err = ctypes.c_char_p()
self.model = self._construct(
str.encode(self.data), self.seed, ctypes.byref(err)
)

if not self.model:
raise self._handle_error(err.contents, "bs_model_construct")
raise self._handle_error(err, "bs_model_construct")

if self.model_version() != __version_info__:
warnings.warn(
Expand Down Expand Up @@ -398,7 +400,7 @@ def param_constrain(
"Error: out must be same size as number of constrained parameters"
)

err = ctypes.pointer(ctypes.c_char_p())
err = ctypes.c_char_p()

rc = self._param_constrain(
self.model,
Expand All @@ -407,11 +409,11 @@ def param_constrain(
theta_unc,
out,
rng_ptr,
err,
ctypes.byref(err),
)

if rc:
raise self._handle_error(err.contents, "param_constrain")
raise self._handle_error(err, "param_constrain")
return out

def new_rng(self, seed) -> "StanRNG":
Expand Down Expand Up @@ -447,10 +449,10 @@ def param_unconstrain(
raise ValueError(
f"out size = {out.size} != unconstrained params size = {dims}"
)
err = ctypes.pointer(ctypes.c_char_p())
rc = self._param_unconstrain(self.model, theta, out, err)
err = ctypes.c_char_p()
rc = self._param_unconstrain(self.model, theta, out, ctypes.byref(err))
if rc:
raise self._handle_error(err.contents, "param_unconstrain")
raise self._handle_error(err, "param_unconstrain")
return out

def param_unconstrain_json(
Expand Down Expand Up @@ -479,10 +481,10 @@ def param_unconstrain_json(
f"out size = {out.size} != unconstrained params size = {dims}"
)
chars = theta_json.encode("UTF-8")
err = ctypes.pointer(ctypes.c_char_p())
rc = self._param_unconstrain_json(self.model, chars, out, err)
err = ctypes.c_char_p()
rc = self._param_unconstrain_json(self.model, chars, out, ctypes.byref(err))
if rc:
raise self._handle_error(err.contents, "param_unconstrain_json")
raise self._handle_error(err, "param_unconstrain_json")
return out

def log_density(
Expand All @@ -505,14 +507,19 @@ def log_density(
:return: The log density.
:raises RuntimeError: If the C++ Stan model throws an exception.
"""
lp = ctypes.pointer(ctypes.c_double())
err = ctypes.pointer(ctypes.c_char_p())
lp = ctypes.c_double()
err = ctypes.c_char_p()
rc = self._log_density(
self.model, int(propto), int(jacobian), theta_unc, lp, err
self.model,
int(propto),
int(jacobian),
theta_unc,
ctypes.byref(lp),
ctypes.byref(err),
)
if rc:
raise self._handle_error(err.contents, "log_density")
return lp.contents.value
raise self._handle_error(err, "log_density")
return lp.value

def log_density_gradient(
self,
Expand Down Expand Up @@ -547,14 +554,20 @@ def log_density_gradient(
out = np.zeros(shape=dims)
elif out.size != dims:
raise ValueError(f"out size = {out.size} != params size = {dims}")
lp = ctypes.pointer(ctypes.c_double())
err = ctypes.pointer(ctypes.c_char_p())
lp = ctypes.c_double()
err = ctypes.c_char_p()
rc = self._log_density_gradient(
self.model, int(propto), int(jacobian), theta_unc, lp, out, err
self.model,
int(propto),
int(jacobian),
theta_unc,
ctypes.byref(lp),
out,
ctypes.byref(err),
)
if rc:
raise self._handle_error(err.contents, "log_density_gradient")
return lp.contents.value, out
raise self._handle_error(err, "log_density_gradient")
return lp.value, out

def log_density_hessian(
self,
Expand Down Expand Up @@ -602,22 +615,22 @@ def log_density_hessian(
raise ValueError(
f"out_hess size = {out_hess.size} != params size^2 = {hess_size}"
)
lp = ctypes.pointer(ctypes.c_double())
err = ctypes.pointer(ctypes.c_char_p())
lp = ctypes.c_double()
err = ctypes.c_char_p()
rc = self._log_density_hessian(
self.model,
int(propto),
int(jacobian),
theta_unc,
lp,
ctypes.byref(lp),
out_grad,
out_hess,
err,
ctypes.byref(err),
)
if rc:
raise self._handle_error(err.contents, "log_density_hessian")
raise self._handle_error(err, "log_density_hessian")
out_hess = out_hess.reshape(dims, dims)
return lp.contents.value, out_grad, out_hess
return lp.value, out_grad, out_hess

def log_density_hessian_vector_product(
self,
Expand Down Expand Up @@ -650,15 +663,22 @@ def log_density_hessian_vector_product(
out = np.zeros(shape=dims)
elif out.size != dims:
raise ValueError(f"out size = {out.size} != params size = {dims}")
lp = ctypes.pointer(ctypes.c_double())
err = ctypes.pointer(ctypes.c_char_p())
lp = ctypes.c_double()
err = ctypes.c_char_p()
rc = self._log_density_hvp(
self.model, int(propto), int(jacobian), theta_unc, v, lp, out, err
self.model,
int(propto),
int(jacobian),
theta_unc,
v,
ctypes.byref(lp),
out,
ctypes.byref(err),
)
if rc:
raise self._handle_error(err.contents, "log_density_hessian_vector_product")
raise self._handle_error(err, "log_density_hessian_vector_product")

return lp.contents.value, out
return lp.value, out

def _handle_error(self, err: ctypes.c_char_p, method: str) -> Exception:
"""
Expand Down

0 comments on commit d37a860

Please sign in to comment.