Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add caching of ranks for computing RSA with Spearman #34

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions mne_rsa/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _partial_correlation(rdm_data, rdm_model, masks=None, type="pearson"):
return -R_partial[0, 1:]


def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False):
def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False, cache_ranks=True):
gydis marked this conversation as resolved.
Show resolved Hide resolved
"""Generate RSA values between data and model RDMs.

Will yield RSA scores for each data RDM.
Expand Down Expand Up @@ -156,30 +156,44 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False):
return_array = False
rdm_model = [_ensure_condensed(rdm_model, "rdm_model")]


if ignore_nan:
if cache_ranks and (rsa_metric == "spearman"):
raise ValueError("ignore_nan and cache_ranks is not yet supported together")
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)

if cache_ranks and (rsa_metric == "spearman"):
# Precompute ranks for Spearman
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

for rdm_data in rdm_data_gen:
rdm_data = _ensure_condensed(rdm_data, "rdm_data")
if ignore_nan:
data_mask = ~np.isnan(rdm_data)
masks = [m & data_mask for m in masks]
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks)
rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks)
if return_array:
yield np.asarray(rsa_vals)
else:
yield rsa_vals[0]


def _rsa_single_rdm(rdm_data, rdm_model, metric, masks):
def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, cache_ranks):
"""Compute RSA between a single data RDM and one or more model RDMs."""
if metric == "spearman":
rsa_vals = [
stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0]
for rdm_model_, mask in zip(rdm_model, masks)
]
if cache_ranks:
rdm_data = stats.rankdata(rdm_data)
rsa_vals = [
np.corrcoef(rdm_data, rdm_model_[mask])[0,1]
gydis marked this conversation as resolved.
Show resolved Hide resolved
for rdm_model_, mask in zip(rdm_model, masks)
]
else:
rsa_vals = [
stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0]
for rdm_model_, mask in zip(rdm_model, masks)
]
elif metric == "pearson":
rsa_vals = [
stats.pearsonr(rdm_data[mask], rdm_model_[mask])[0]
gydis marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -312,6 +326,7 @@ def rsa_array(
y=None,
n_folds=1,
n_jobs=1,
cache_ranks=True,
verbose=False,
):
"""Perform RSA on an array of data, possibly in a searchlight pattern.
Expand Down Expand Up @@ -419,6 +434,8 @@ def rsa_array(
rdm_model = [_ensure_condensed(rdm_model, "rdm_model")]

if ignore_nan:
if cache_ranks and (rsa_metric == "spearman"):
raise ValueError("ignore_nan and cache_ranks is not yet supported together")
masks = [~np.isnan(rdm) for rdm in rdm_model]
else:
masks = [slice(None)] * len(rdm_model)
Expand All @@ -433,6 +450,10 @@ def rsa_array(
except AttributeError:
pass

if cache_ranks and (rsa_metric == "spearman"):
# Precompute ranks for Spearman
rdm_model = [stats.rankdata(rdm) for rdm in rdm_model]

def rsa_single_patch(patch):
"""Compute RSA for a single searchlight patch."""
if len(X) == 1: # Check number of folds
Expand All @@ -448,7 +469,7 @@ def rsa_single_patch(patch):
patch_masks = [m & data_mask for m in masks]
else:
patch_masks = masks
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks)
return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, cache_ranks=cache_ranks)

# Call RSA multiple times in parallel for each searchlight patch
data = Parallel(n_jobs=n_jobs)(
Expand Down