diff --git a/mne_rsa/rsa.py b/mne_rsa/rsa.py index 6b0dbfe..98838ee 100644 --- a/mne_rsa/rsa.py +++ b/mne_rsa/rsa.py @@ -161,26 +161,38 @@ def rsa_gen(rdm_data_gen, rdm_model, metric="spearman", ignore_nan=False): masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) + # Precompute ranks for Spearman + if metric == "spearman": + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] for rdm_data in rdm_data_gen: rdm_data = _ensure_condensed(rdm_data, "rdm_data") if ignore_nan: data_mask = ~np.isnan(rdm_data) masks = [m & data_mask for m in masks] - rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks) + rsa_vals = _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan) if return_array: yield np.asarray(rsa_vals) else: yield rsa_vals[0] -def _rsa_single_rdm(rdm_data, rdm_model, metric, masks): +def _rsa_single_rdm(rdm_data, rdm_model, metric, masks, ignore_nan): """Compute RSA between a single data RDM and one or more model RDMs.""" if metric == "spearman": - rsa_vals = [ - stats.spearmanr(rdm_data[mask], rdm_model_[mask])[0] - for rdm_model_, mask in zip(rdm_model, masks) - ] + if not ignore_nan: + rdm_data = stats.rankdata(rdm_data) + rsa_vals = [ + np.corrcoef(rdm_data, rdm_model_[mask])[0, 1] + for rdm_model_, mask in zip(rdm_model, masks) + ] + else: + rsa_vals = [ + np.corrcoef( + stats.rankdata(rdm_data[mask]), stats.rankdata(rdm_model_[mask]) + )[0, 1] + for rdm_model_, mask in zip(rdm_model, masks) + ] elif metric == "pearson": rsa_vals = [ stats.pearsonr(rdm_data[mask], rdm_model_[mask])[0] @@ -423,6 +435,9 @@ def rsa_array( masks = [~np.isnan(rdm) for rdm in rdm_model] else: masks = [slice(None)] * len(rdm_model) + # Precompute ranks for Spearman + if rsa_metric == "spearman": + rdm_model = [stats.rankdata(rdm) for rdm in rdm_model] if verbose: from tqdm import tqdm @@ -449,7 +464,7 @@ def rsa_single_patch(patch): patch_masks = [m & data_mask for m in masks] else: patch_masks = masks - return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks) + return _rsa_single_rdm(rdm_data, rdm_model, rsa_metric, patch_masks, ignore_nan) # Call RSA multiple times in parallel for each searchlight patch data = Parallel(n_jobs=n_jobs)(