Skip to content

Commit

Permalink
feature_rsa_update
Browse files Browse the repository at this point in the history
  • Loading branch information
bbuchsbaum committed Dec 20, 2024
1 parent d28df8a commit 6b697f2
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 98 deletions.
86 changes: 46 additions & 40 deletions R/feature_rsa_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,68 +147,38 @@ feature_rsa_model <- function(dataset,
}


#' Train an RSA Model
#'
#' Trains a Feature RSA model using scca, pls, or pca method.
#' @param obj An object of class \code{feature_rsa_model}
#' @param train_dat The training data
#' @param indices The indices of the training data
#' @param ... Additional args
#' @return The trained model
#' @export
train_model.feature_rsa_model <- function(obj, train_dat, indices, ...) {
train_model.feature_rsa_model <- function(obj, train_dat, ytrain, ...) {
X <- as.matrix(train_dat)
F <- obj$design$F[indices,,drop=FALSE]

method <- obj$method
Fsub <- as.matrix(ytrain) # ytrain is the already subsetted portion of F

if (method == "pls") {
# Fit PLS: F ~ X, scale=TRUE handled by plsr internally
obj$trained_model <- pls::plsr(F ~ X, scale=TRUE)
obj$training_indices <- indices

predicted <- predict_model(obj, X)
obj$performance <- evaluate_model(obj, predicted, F)
return(obj)

} else if (method == "scca") {
# scca requires scaling
if (obj$method == "pls") {
obj$trained_model <- pls::plsr(Fsub ~ X, scale=TRUE)
} else if (obj$method == "scca") {
sx <- .standardize(X)
sf <- .standardize(F)
sf <- .standardize(Fsub)
scca_res <- whitening::scca(sx$X_sc, sf$X_sc, scale=FALSE)
obj$trained_model <- scca_res
obj$training_indices <- indices
obj$scca_x_mean <- sx$mean
obj$scca_x_sd <- sx$sd
obj$scca_f_mean <- sf$mean
obj$scca_f_sd <- sf$sd

predicted <- predict_model(obj, X)
obj$performance <- evaluate_model(obj, predicted, F)
return(obj)

} else if (method == "pca") {
# PCA on X (scaled)
} else if (obj$method == "pca") {
sx <- .standardize(X)
sf <- .standardize(F)
sf <- .standardize(Fsub)
pca_res <- prcomp(sx$X_sc, scale.=FALSE)
PC_train <- pca_res$x
PC_train_i <- cbind(1, PC_train)
coefs <- solve(t(PC_train_i)%*%PC_train_i, t(PC_train_i)%*%sf$X_sc)

coefs <- solve(t(PC_train_i) %*% PC_train_i, t(PC_train_i) %*% sf$X_sc)
obj$trained_model <- pca_res
obj$training_indices <- indices
obj$pcarot <- pca_res$rotation
obj$pca_x_mean <- sx$mean
obj$pca_x_sd <- sx$sd
obj$pca_f_mean <- sf$mean
obj$pca_f_sd <- sf$sd
obj$pca_coefs <- coefs

predicted <- predict_model(obj, X)
obj$performance <- evaluate_model(obj, predicted, F)
return(obj)
}
obj
}


Expand Down Expand Up @@ -255,6 +225,42 @@ evaluate_model.feature_rsa_model <- function(object, predicted, observed, ...) {
}


#' @export
y_train.feature_rsa_model <- function(object) {
object$F
}

format_result.feature_rsa_model <- function(obj, result, error_message=NULL, context, ...) {
if (!is.null(error_message)) {
return(tibble::tibble(
observed=list(NULL),
predicted=list(NULL),
error=TRUE,
error_message=error_message
))
} else {
# Predict on test data
testX <- tibble::as_tibble(context$test, .name_repair=.name_repair)
pred <- predict_model(obj, testX)

# observed is ytest
observed <- as.matrix(context$ytest)

# Evaluate
perf <- evaluate_model(obj, pred, observed)

# Return a tibble
# Store predicted and observed for optional inspection
tibble::tibble(
observed=list(observed),
predicted=list(pred),
performance=list(perf),
error=FALSE,
error_message="~"
)
}
}

#' Print Method for Feature RSA Model
#'
#' @param x The feature RSA model
Expand Down
2 changes: 1 addition & 1 deletion R/mvpa_iterate.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ internal_crossval <- function(mspec, roi, id) {
if (ncol(train) < 2) {
# Return an error message
return(
format_result(mspec, error_message="error: less than 2 features", context=list(roi=roi, ytrain=ytrain, ytest=ytest, train=train, test=test, .id=.id))
format_result(mspec, NULL, error_message="error: less than 2 features", context=list(roi=roi, ytrain=ytrain, ytest=ytest, train=train, test=test, .id=.id))
)
}

Expand Down
59 changes: 2 additions & 57 deletions R/vector_rsa_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,63 +130,8 @@ train_model.vector_rsa_model <- function(obj, train_dat, indices, ...) {
return(scores)
}

# Perform vector RSA for a subset of data
#
# @param roi A subset of data, usually representing one ROI or one trial block.
# @param mod_spec The RSA model specification.
# @param rnum roi ids
# @return A tibble with RSA results and potentially error information.
# @noRd
# do_vector_rsa <- function(roi, mod_spec, rnum) {
# xtrain <- tibble::as_tibble(neuroim2::values(roi$train_roi), .name_repair=.name_repair)
# ind <- indices(roi$train_roi)
# tryCatch({
# scores <- train_model(mod_spec, xtrain)
# tibble::tibble(result = list(NULL), indices=list(ind), performance=list(scores), id=rnum, error = FALSE, error_message = "~")
# }, error = function(e) {
# tibble::tibble(result = list(NULL), indices=list(ind), performance=list(NULL), id=rnum, error = TRUE, error_message = e$message)
# })
# }



#' Iterate over data sets applying the vector RSA model
#
# @param mod_spec The model specification.
# @param vox_list A list of voxel sets to analyze.
# @param ids Identifiers for each data set.
# @noRd
# vector_rsa_iterate <- function(mod_spec, vox_list, ids = seq_along(vox_list)) {
# # Ensure IDs match the number of data sets
# if (length(ids) != length(vox_list)) {
# stop("Length of ids must match the number of data sets.")
# }
#
# assert_that(length(ids) == length(vox_list), msg=paste("length(ids) = ", length(ids), "::", "length(vox_list) =", length(vox_list)))
#
# sframe <- get_samples(mod_spec$dataset, vox_list)
# ## iterate over searchlights using parallel futures
# sf <- sframe %>% dplyr::mutate(rnum=ids)
#
# fut_vector_rsa(mod_spec,sf)
#
# }


# Apply the RSA model in parallel using futures
#
# @param mod_spec The model specification.
# @param sf A tibble containing the data sets and their identifiers.
# @param method Method for computing similarities.
# @return A combined result of all RSA analyses.
# @noRd
# fut_vector_rsa <- function(mod_spec, sf, ...) {
# gc()
# sf %>% furrr::future_pmap(function(sample, rnum, .id) {
# do_vector_rsa(as_roi(sample, mod_spec$dataset), mod_spec, rnum, ...)
# }, .options = furrr::furrr_options(seed = T)) %>% dplyr::bind_rows()
#
# }



#' @noRd
#' @keywords internal
Expand Down

0 comments on commit 6b697f2

Please sign in to comment.