Skip to content

Commit

Permalink
feature_rsa
Browse files Browse the repository at this point in the history
  • Loading branch information
bbuchsbaum committed Dec 20, 2024
1 parent 6b697f2 commit d961893
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 130 deletions.
8 changes: 5 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ S3method(crossval_samples,mvpa_model)
S3method(crossval_samples,sequential_blocked_cross_validation)
S3method(crossval_samples,twofold_blocked_cross_validation)
S3method(data_sample,mvpa_dataset)
S3method(format_result,feature_rsa_model)
S3method(get_samples,mvpa_dataset)
S3method(get_samples,mvpa_surface_dataset)
S3method(get_searchlight,mvpa_image_dataset)
Expand All @@ -20,6 +21,7 @@ S3method(has_test_set,mvpa_model)
S3method(merge_predictions,classification_prediction)
S3method(merge_predictions,regression_prediction)
S3method(merge_results,binary_classification_result)
S3method(merge_results,feature_rsa_model)
S3method(merge_results,multiway_classification_result)
S3method(merge_results,regional_mvpa_result)
S3method(merge_results,regression_result)
Expand Down Expand Up @@ -58,6 +60,7 @@ S3method(print,searchlight_result)
S3method(print,twofold_blocked_cross_validation)
S3method(prob_observed,binary_classification_result)
S3method(prob_observed,multiway_classification_result)
S3method(run_regional,feature_rsa_model)
S3method(run_regional,mvpa_model)
S3method(run_regional,rsa_model)
S3method(run_searchlight,model_spec)
Expand All @@ -69,11 +72,13 @@ S3method(sub_result,binary_classification_result)
S3method(sub_result,multiway_classification_result)
S3method(summary,feature_rsa_model)
S3method(test_design,mvpa_design)
S3method(train_model,feature_rsa_model)
S3method(train_model,rsa_model)
S3method(train_model,vector_rsa_model)
S3method(tune_grid,mvpa_model)
S3method(y_test,mvpa_design)
S3method(y_test,mvpa_model)
S3method(y_train,feature_rsa_model)
S3method(y_train,mvpa_design)
S3method(y_train,mvpa_model)
export(MVPAModels)
Expand Down Expand Up @@ -150,9 +155,7 @@ importFrom(Rfit,rfit)
importFrom(assertthat,assert_that)
importFrom(corpcor,invcov.shrink)
importFrom(dplyr,bind_rows)
importFrom(dplyr,do)
importFrom(dplyr,filter)
importFrom(dplyr,rowwise)
importFrom(ffmanova,ffmanova)
importFrom(furrr,future_map)
importFrom(furrr,future_pmap)
Expand All @@ -167,7 +170,6 @@ importFrom(neuroim2,NeuroVol)
importFrom(neuroim2,ROIVec)
importFrom(neuroim2,SparseNeuroVec)
importFrom(neuroim2,coords)
importFrom(neuroim2,indices)
importFrom(neuroim2,map_values)
importFrom(neuroim2,read_vec)
importFrom(neuroim2,read_vol)
Expand Down
62 changes: 62 additions & 0 deletions R/feature_rsa_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ y_train.feature_rsa_model <- function(object) {
object$F
}

#' @export
format_result.feature_rsa_model <- function(obj, result, error_message=NULL, context, ...) {
if (!is.null(error_message)) {
return(tibble::tibble(
Expand Down Expand Up @@ -261,6 +262,67 @@ format_result.feature_rsa_model <- function(obj, result, error_message=NULL, con
}
}

#' Merge Multiple Results for Feature RSA Model
#'
#' @param obj A \code{feature_rsa_model} object
#' @param result_set A data frame of results from cross-validation folds
#' @param indices The voxel indices used (may not be relevant for feature_rsa_model)
#' @param id An identifier for the merged result (e.g., ROI id)
#' @param ... Additional arguments
#' @return A tibble with merged results
#' @export
merge_results.feature_rsa_model <- function(obj, result_set, indices, id, ...) {
# If any errors occurred, return an error tibble
if (any(result_set$error)) {
emessage <- result_set$error_message[which(result_set$error)[1]]
return(tibble::tibble(
result=list(NULL),
indices=list(indices),
performance=list(NULL),
id=id,
error=TRUE,
error_message=emessage,
warning=any(result_set$warning),
warning_message=if(any(result_set$warning)) result_set$warning_message[which(result_set$warning)[1]] else "~"
))
}

# If no errors, combine all fold predictions and observed values
# Each fold should have columns: observed, predicted, performance, etc.
# result_set might have multiple rows (one per fold)

# Extract observed and predicted from each fold
# They are stored as lists, each element a matrix
observed_list <- result_set$observed
predicted_list <- result_set$predicted

# Combine rows (observations) across folds
combined_observed <- do.call(rbind, observed_list)
combined_predicted <- do.call(rbind, predicted_list)

# Compute performance on the combined set
perf <- evaluate_model(obj, combined_predicted, combined_observed)

# Create a single combined result. We can store the combined predictions and observed as well.
# 'result' typically is some kind of result object; here we can store the combined predictions.
# For consistency, let's mimic mvpa_model: store combined predictions/observed in 'result' as well.
combined_result <- list(
observed = combined_observed,
predicted = combined_predicted
)

tibble::tibble(
result = list(combined_result),
indices = list(indices),
performance = list(perf),
id = id,
error = FALSE,
error_message = "~",
warning = any(result_set$warning),
warning_message = if(any(result_set$warning)) result_set$warning_message[which(result_set$warning)[1]] else "~"
)
}

#' Print Method for Feature RSA Model
#'
#' @param x The feature RSA model
Expand Down
1 change: 1 addition & 0 deletions R/mvpa_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ format_result.mvpa_model <- function(obj, result, error_message=NULL, context, .
}



#' @keywords internal
get_multiclass_perf <- function(split_list=NULL, class_metrics=TRUE) {
function(result) {
Expand Down
4 changes: 2 additions & 2 deletions man/evaluate_model.feature_rsa_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 16 additions & 34 deletions man/feature_rsa_design.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 0 additions & 18 deletions man/feature_rsa_iterate.Rd

This file was deleted.

57 changes: 12 additions & 45 deletions man/feature_rsa_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/merge_results.feature_rsa_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/predict_model.feature_rsa_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/print.feature_rsa_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d961893

Please sign in to comment.