Skip to content

Commit

Permalink
add option to modify penalization, fix warnings from coef
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Sep 13, 2024
1 parent 67520ea commit 6d0fc69
Show file tree
Hide file tree
Showing 17 changed files with 1,228 additions and 1,063 deletions.
16 changes: 14 additions & 2 deletions R/coef.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,29 @@ coef.nhmm <- function(object, probs = c(0.025, 0.975), ...) {
p_s <- length(beta_s_raw)
p_o <- length(beta_o_raw)
sds <- try(
sqrt(diag(solve(-object$estimation_results$hessian))),
diag(solve(-object$estimation_results$hessian)),
silent = TRUE
)
if (inherits(sds, "try-error")) {
warning_(
paste0(
"Standard errors could not be computed due to singular Hessian.",
"Standard errors could not be computed due to singular Hessian. ",
"Confidence intervals will not be provided."
)
)
sds <- rep(NA, p_i + p_s + p_o)
} else {
if (any(sds < 0)) {
warning_(
paste0(
"Standard errors could not be computed due to negative variances. ",
"Confidence intervals will not be provided."
)
)
sds <- rep(NA, p_i + p_s + p_o)
} else {
sds <- sqrt(sds)
}
}
for(i in seq_along(probs)) {
q <- qnorm(probs[i])
Expand Down
10 changes: 6 additions & 4 deletions R/estimate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,26 @@ estimate_mnhmm <- function(
data = NULL, time = NULL, id = NULL, state_names = NULL,
channel_names = NULL, cluster_names = NULL, inits = "random", init_sd = 2,
restarts = 0L, threads = 1L, store_data = TRUE, verbose = TRUE,
restart_method = "1", ...) {
restart_method = "1", penalize = TRUE, penalty = 5, ...) {

call <- match.call()
model <- build_mnhmm(
observations, n_states, n_clusters, initial_formula,
transition_formula, emission_formula, cluster_formula, data, time, id,
state_names, channel_names, cluster_names
)
)
stopifnot_(
checkmate::test_flag(x = store_data),
"Argument {.arg store_data} must be a single {.cls logical} value.")
if (store_data) {
model$data <- data
}
if (restart_method == "1") {
out <- fit_mnhmm(model, inits, init_sd, restarts, threads, verbose, ...)
out <- fit_mnhmm(model, inits, init_sd, restarts, threads, verbose,
penalize, penalty, ...)
} else {
out <- fit_mnhmm2(model, inits, init_sd, restarts, threads, verbose, ...)
out <- fit_mnhmm2(model, inits, init_sd, restarts, threads, verbose,
penalize, penalty, ...)
}
attr(out, "call") <- call
out
Expand Down
5 changes: 3 additions & 2 deletions R/estimate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ estimate_nhmm <- function(
transition_formula = ~1, emission_formula = ~1,
data = NULL, time = NULL, id = NULL, state_names = NULL, channel_names = NULL,
inits = "random", init_sd = 2, restarts = 0L, threads = 1L,
store_data = TRUE, verbose = TRUE, ...) {
store_data = TRUE, verbose = TRUE, penalize = TRUE, penalty = 5, ...) {

call <- match.call()

Expand All @@ -88,7 +88,8 @@ estimate_nhmm <- function(
if (store_data) {
model$data <- data
}
out <- fit_nhmm(model, inits, init_sd, restarts, threads, verbose, ...)
out <- fit_nhmm(model, inits, init_sd, restarts, threads, verbose,
penalize, penalty, ...)
attr(out, "call") <- call
out
}
6 changes: 5 additions & 1 deletion R/fit_mnhmm.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Estimate a Mixture Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, penalize, penalty, ...) {
stopifnot_(
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
Expand Down Expand Up @@ -57,6 +57,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code, init = init,
data = list(
penalty = penalty,
penalize = penalize,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down Expand Up @@ -107,6 +109,8 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code,
data = list(
penalty = penalty,
penalize = penalize,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down
9 changes: 8 additions & 1 deletion R/fit_mnhmm2.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#' Estimate a Mixture Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose,
penalize, penalty, ...) {
stopifnot_(
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
Expand Down Expand Up @@ -61,6 +62,8 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code, init = init,
data = list(
penalize = penalize,
penalty = penalty,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down Expand Up @@ -99,6 +102,8 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code, init = out0[[idx[i]]]$par,
data = list(
penalize = penalize,
penalty = penalty,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down Expand Up @@ -149,6 +154,8 @@ fit_mnhmm2 <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code,
data = list(
penalize = penalize,
penalty = penalty,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down
6 changes: 5 additions & 1 deletion R/fit_nhmm.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' Estimate a Non-homogeneous Hidden Markov Model
#'
#' @noRd
fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, penalize, penalty, ...) {
stopifnot_(
checkmate::test_int(x = threads, lower = 1L),
"Argument {.arg threads} must be a single positive integer."
Expand Down Expand Up @@ -53,6 +53,8 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code, init = init,
data = list(
penalty = penalty,
penalize = penalize,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down Expand Up @@ -100,6 +102,8 @@ fit_nhmm <- function(model, inits, init_sd, restarts, threads, verbose, ...) {
c(list(
model_code,
data = list(
penalty = penalty,
penalize = penalize,
N = model$n_sequences,
T = model$sequence_lengths,
max_T = model$length_of_sequences,
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/include/data_covariates.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
array[max_T, N] vector[K_s] X_s; // covariates for transitions (including the intercept)
int<lower=1> K_o; // number of covariates for emission probabilities (including the intercept)
array[max_T, N] vector[K_o] X_o; // covariates for emissions (including the intercept)


int<lower=0,upper=1> penalize; // penalize the likelihood of the model
real<lower=0> penalty; // standard deviation of the priors
12 changes: 7 additions & 5 deletions inst/stan/include/priors_multichannel.stan
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
real prior = 0;
// priors for (very weak) regularisation
prior += normal_lpdf(to_vector(beta_i_raw) | 0, 5);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[s]) | 0, 5);
if (penalize == 1) {
// normal priors for regularisation; their standard deviation is `penalty`
prior += normal_lpdf(to_vector(beta_i_raw) | 0, penalty);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[s]) | 0, penalty);
}
prior += normal_lpdf(beta_o_raw | 0, penalty);
}
prior += normal_lpdf(beta_o_raw | 0, 5);
14 changes: 8 additions & 6 deletions inst/stan/include/priors_multichannel_mixture.stan
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// zero-mean normal priors for regularisation (sd = penalty), added only when penalize == 1
real prior = 0;
for (d in 1:D){
prior += normal_lpdf(to_vector(beta_i_raw[d]) | 0, 5);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[d, s]) | 0, 5);
if (penalize == 1) {
for (d in 1:D){
prior += normal_lpdf(to_vector(beta_i_raw[d]) | 0, penalty);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[d, s]) | 0, penalty);
}
prior += normal_lpdf(beta_o_raw[d] | 0, penalty);
}
prior += normal_lpdf(beta_o_raw[d] | 0, 5);
prior += normal_lpdf(to_vector(theta_raw) | 0, penalty);
}
prior += normal_lpdf(to_vector(theta_raw) | 0, 5);
11 changes: 7 additions & 4 deletions inst/stan/include/priors_singlechannel.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// zero-mean normal priors for regularisation (sd = penalty), added only when penalize == 1
real prior = normal_lpdf(to_vector(beta_i_raw) | 0, 5);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[s]) | 0, 5);
prior += normal_lpdf(to_vector(beta_o_raw[s]) | 0, 5);
real prior = 0;
if (penalize == 1) {
prior += normal_lpdf(to_vector(beta_i_raw) | 0, penalty);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[s]) | 0, penalty);
prior += normal_lpdf(to_vector(beta_o_raw[s]) | 0, penalty);
}
}
14 changes: 8 additions & 6 deletions inst/stan/include/priors_singlechannel_mixture.stan
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// zero-mean normal priors for regularisation (sd = penalty), added only when penalize == 1
real prior = 0;
for (d in 1:D){
prior += normal_lpdf(to_vector(beta_i_raw[d]) | 0, 5);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[d, s]) | 0, 5);
prior += normal_lpdf(to_vector(beta_o_raw[d, s]) | 0, 5);
if (penalize == 1) {
for (d in 1:D){
prior += normal_lpdf(to_vector(beta_i_raw[d]) | 0, penalty);
for(s in 1:S) {
prior += normal_lpdf(to_vector(beta_s_raw[d, s]) | 0, penalty);
prior += normal_lpdf(to_vector(beta_o_raw[d, s]) | 0, penalty);
}
}
prior += normal_lpdf(to_vector(theta_raw) | 0, penalty);
}
prior += normal_lpdf(to_vector(theta_raw) | 0, 5);
1 change: 0 additions & 1 deletion inst/stan/mnhmm.stan
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ model {
target += log_lik;
}
generated quantities {
print("not yet!");
real ploglik_N = prior + log_lik;
if (N > N_sample) {
ploglik_N = prior + loglik_sc_mix(beta_i_raw, beta_s_raw,
Expand Down
Loading

0 comments on commit 6d0fc69

Please sign in to comment.