diff --git a/naplib/localization/freesurfer.py b/naplib/localization/freesurfer.py index f29b552..81ea134 100644 --- a/naplib/localization/freesurfer.py +++ b/naplib/localization/freesurfer.py @@ -10,8 +10,10 @@ from nibabel.freesurfer.io import read_geometry, read_label, read_morph_data from scipy.spatial.distance import cdist from skspatial.objects import Line, Plane +from hdf5storage import loadmat from naplib.utils import dist_calc, load_freesurfer_label +from naplib import logger warnings.filterwarnings("ignore", message="nopython", append=True) @@ -109,6 +111,54 @@ } region2num = {v: k for k, v in num2region.items()} +num2region_mni = { + 0: 'unknown', + 1: 'bankssts', + 2: 'caudalanteriorcingulate', + 3: 'caudalmiddlefrontal', + 4: 'corpuscallosum', + 5: 'cuneus', + 6: 'entorhinal', + 7: 'fusiform', + 8: 'inferiorparietal', + 9: 'inferiortemporal', + 10: 'isthmuscingulate', + 11: 'lateraloccipital', + 12: 'lateralorbitofrontal', + 13: 'lingual', + 14: 'medialorbitofrontal', + 15: 'middletemporal', + 16: 'parahippocampal', + 17: 'paracentral', + 18: 'parsopercularis', + 19: 'parsorbitalis', + 20: 'parstriangularis', + 21: 'pericalcarine', + 22: 'postcentral', + 23: 'posteriorcingulate', + 24: 'precentral', + 25: 'precuneus', + 26: 'rostralanteriorcingulate', + 27: 'rostralmiddlefrontal', + 28: 'superiorfrontal', + 29: 'superiorparietal', + 30: 'superiortemporal', + 31: 'supramarginal', + 32: 'frontalpole', + 33: 'temporalpole', + 34: 'transversetemporal', + 35: 'insula', + 36: 'cMTG', + 37: 'mMTG', + 38: 'rMTG', + 39: 'cSTG', + 40: 'mSTG', + 41: 'rSTG', + # My custom labels + 42: "O_IFG", +} +region2num_mni = {v: k for k, v in num2region_mni.items()} + class Hemisphere: def __init__( @@ -147,12 +197,40 @@ def __init__( subject_dir = os.environ.get("SUBJECTS_DIR", "./") self.subject_dir = subject_dir + + self.atlas = 'FSAverage' - self.surf = read_geometry(self.surf_file(f"{hemi}.{surf_type}")) - self.cort = np.sort(read_label(self.label_file(f"{hemi}.cortex.label"))) - self.sulc = read_morph_data(self.surf_file(f"{hemi}.sulc")) + if os.path.exists(self.surf_file(f"{hemi}.{surf_type}")): + self.surf = read_geometry(self.surf_file(f"{hemi}.{surf_type}")) + else: + # try to find .mat file + surf_ = loadmat(self.surf_file(f"{hemi}_pial.mat")) + coords, faces = surf_['coords'], surf_['faces'] + faces -= 1 # make faces zero-indexed + self.surf = (coords, faces) + self.atlas = 'MNI152' + + if self.atlas == 'FSAverage': + self.surf_pial = read_geometry(self.surf_file(f"{hemi}.pial")) + else: + # try to find .mat file + surf_ = loadmat(self.surf_file(f"{hemi}_pial.mat")) + coords, faces = surf_['coords'], surf_['faces'] + faces -= 1 # make faces zero-indexed + self.surf_pial = (coords, faces) + + try: + self.cort = np.sort(read_label(self.label_file(f"{hemi}.cortex.label"))) + except Exception as e: + logger.warning(f'No {hemi}.cortext.label file found. Assuming the entire surface is cortex.') + self.cort = np.arange(self.surf[0].shape[0]) + + try: + self.sulc = read_morph_data(self.surf_file(f"{hemi}.sulc")) + except Exception as e: + logger.warning(f'No {hemi}.sulc file found. No sulcus information will be used.') + self.sulc = None - self.surf_pial = read_geometry(self.surf_file(f"{hemi}.pial")) self.load_labels() self.reset_overlay() @@ -196,18 +274,35 @@ def load_labels(self): """ self.ignore = np.zeros(self.n_verts, dtype=bool) annot_file = self.label_file(f"{self.hemi}.aparc.a2005s.annot") - for reg in ("Unknown", "Medial_wall"): - self.ignore[load_freesurfer_label(annot_file, reg)] = True + if os.path.exists(annot_file): + for reg in ("Unknown", "Medial_wall"): + self.ignore[load_freesurfer_label(annot_file, reg)] = True self.labels = np.zeros(self.n_verts, dtype=int) annot_file = self.label_file(f"{self.hemi}.aparc.a2009s.annot") - for ind, reg in num2region.items(): - if reg.startswith("O"): - continue - self.labels[load_freesurfer_label(annot_file, reg)] = ind + annot_file_mni = self.label_file(f"FSL_MNI152.{self.hemi}.aparc.split_STG_MTG.annot") + if self.atlas == 'FSAverage': + for ind, reg in num2region.items(): + if reg.startswith("O"): + continue + self.labels[load_freesurfer_label(annot_file, reg)] = ind + elif self.atlas == 'MNI152': + for ind, reg in num2region_mni.items(): + if reg.startswith("O"): + continue + self.labels[load_freesurfer_label(annot_file_mni, reg)] = ind + else: + raise ValueError('Bad atlas') self.labels[self.ignore] = 0 - self.num2label = num2region - self.label2num = {v: k for k, v in self.num2label.items()} + if self.atlas == 'FSAverage': + self.num2label = num2region + self.label2num = {v: k for k, v in self.num2label.items()} + elif self.atlas == 'MNI152': + self.num2label = num2region_mni + self.label2num = {v: k for k, v in self.num2label.items()} + else: + raise ValueError('Bad atlas') + self.simplified = False self.is_mangled_hg = False @@ -225,31 +320,47 @@ def simplify_labels(self): ------- self : instance of self """ - conversions = { - "Other": [], # Autofill all uncovered vertecies - "HG": ["G_temp_sup-G_T_transv"], - "pmHG": ["O_pmHG"], - "alHG": ["O_alHG"], - "Te1.0": ["O_Te10"], - "Te1.1": ["O_Te11"], - "Te1.2": ["O_Te12"], - "TTS": ["S_temporal_transverse"], - "PT": ["G_temp_sup-Plan_tempo"], - "PP": ["G_temp_sup-Plan_polar"], - "MTG": ["G_temporal_middle"], - "ITG": ["G_temporal_inf"], - "STG": ["G_temp_sup-Lateral"], - "mSTG": ["O_mSTG"], - "pSTG": ["O_pSTG"], - "STS": ["S_temporal_sup"], - "IFG": ["O_IFG"], - "IFG.opr": ["G_front_inf-Opercular"], - "IFG.tri": ["G_front_inf-Triangul"], - "IFG.orb": ["G_front_inf-Orbital"], - "Subcnt": ["G_and_S_subcentral"], - "Insula": ["G_Ins_lg_and_S_cent_ins", "G_insular_short"], - "T.Pole": ["Pole_temporal"], - } + if self.atlas == 'FSAverage': + conversions = { + "Other": [], # Autofill all uncovered vertecies + "HG": ["G_temp_sup-G_T_transv"], + "pmHG": ["O_pmHG"], + "alHG": ["O_alHG"], + "Te1.0": ["O_Te10"], + "Te1.1": ["O_Te11"], + "Te1.2": ["O_Te12"], + "TTS": ["S_temporal_transverse"], + "PT": ["G_temp_sup-Plan_tempo"], + "PP": ["G_temp_sup-Plan_polar"], + "MTG": ["G_temporal_middle"], + "ITG": ["G_temporal_inf"], + "STG": ["G_temp_sup-Lateral"], + "mSTG": ["O_mSTG"], + "pSTG": ["O_pSTG"], + "STS": ["S_temporal_sup"], + "IFG": ["O_IFG"], + "IFG.opr": ["G_front_inf-Opercular"], + "IFG.tri": ["G_front_inf-Triangul"], + "IFG.orb": ["G_front_inf-Orbital"], + "Subcnt": ["G_and_S_subcentral"], + "Insula": ["G_Ins_lg_and_S_cent_ins", "G_insular_short"], + "T.Pole": ["Pole_temporal"], + } + + + elif self.atlas == 'MNI152': + d1 = {k: [k] for k in region2num_mni.keys() if k not in ['O_IFG','parsopercularis','parstriangularis','parsorbitalis']} + d2_override = { + "Other": [], + "IFG": ["O_IFG"], + "IFG.opr": ["parsopercularis"], + "IFG.tri": ["parstriangularis"], + "IFG.orb": ["parsorbitalis"], + } + conversions = {**d1, **d2_override} + else: + raise ValueError('Bad atlas') + conversions = { key: [self.label2num[g] for g in groups] for key, groups in conversions.items() @@ -386,6 +497,9 @@ def split_hg(self, method="midpoint"): ) self.is_mangled_hg = True + if self.atlas == 'MNI152': + raise ValueError(f'split_hg() is not supported for MNI atlas.') + hg = self.filter_labels(["G_temp_sup-G_T_transv", "HG"]) if method == "midpoint": @@ -502,6 +616,9 @@ def remove_tts(self, method="split"): ------- self : instance of self """ + if self.atlas == 'MNI152': + raise ValueError(f'remove_tts() is not supported for MNI atlas.') + if self.is_mangled_tts: raise RuntimeError( "TTS cannot be removed as it is already mangled. Try changing order of operations?" @@ -557,6 +674,9 @@ def split_stg(self, method="tts_plane"): ------- self : instance of self """ + if self.atlas == 'MNI152': + raise ValueError(f'split_stg() is not supported for MNI atlas.') + if self.is_mangled_stg: raise RuntimeError( "STG cannot be split as it is already mangled. Try changing order of operations?" @@ -598,16 +718,29 @@ def join_ifg(self): ) self.is_mangled_ifg = True - ifg = self.filter_labels( - [ - "G_front_inf-Opercular", - "G_front_inf-Triangul", - "G_front_inf-Orbital", - "IFG.opr", - "IFG.tri", - "IFG.orb", - ] - ) + if self.atlas == 'FSAverage': + ifg = self.filter_labels( + [ + "G_front_inf-Opercular", + "G_front_inf-Triangul", + "G_front_inf-Orbital", + "IFG.opr", + "IFG.tri", + "IFG.orb", + ] + ) + else: # MNI152 + ifg = self.filter_labels( + [ + "parsopercularis", + "parstriangularis", + "parsorbitalis", + "IFG.opr", + "IFG.tri", + "IFG.orb", + ] + ) + self.labels[ifg] = self.label2num["IFG" if self.simplified else "O_IFG"] return self diff --git a/naplib/visualization/brain_plots.py b/naplib/visualization/brain_plots.py index 161b6b0..c62ef99 100644 --- a/naplib/visualization/brain_plots.py +++ b/naplib/visualization/brain_plots.py @@ -365,7 +365,7 @@ def plot_brain_elecs( def _plot_brain_elecs_standalone( brain, surfs, - sulci, + sulci=None, elecs=None, elec_isleft=None, elec_values=None, @@ -396,7 +396,6 @@ def _plot_brain_elecs_standalone( ) assert isinstance(surfs, dict) - assert isinstance(sulci, dict) if cortex not in colormap_map: raise ValueError( @@ -445,7 +444,8 @@ def _plot_brain_elecs_standalone( for i, hemi in enumerate(hemi_keys): verts = surfs[hemi][0] triangles = surfs[hemi][1] - sulc = sulci[hemi] + if sulci[hemi] is not None: + sulc = sulci[hemi] if isinstance(view, str): elev, azim = _view(hemi, mode=view, backend=backend) @@ -455,14 +455,21 @@ def _plot_brain_elecs_standalone( raise ValueError("Argument `view` should be a string or tuple.") # color by sulci - triangle_values_sulci = np.array( - [[sulc[nn] for nn in triangles[i]] for i in range(len(triangles))] - ).mean(1) - colors_sulci = cmap_sulci_func(triangle_values_sulci) + if sulci[hemi] is not None: + triangle_values_sulci = np.array( + [[sulc[nn] for nn in triangles[i]] for i in range(len(triangles))] + ).mean(1) + colors_sulci = cmap_sulci_func(triangle_values_sulci) + else: + colors_sulci = np.ones((len(triangles),4)) + colors_sulci[:,:3] = 0.5 + if backend == "plotly": + colors_sulci *= 255 colors_sulci = colors_sulci.astype("int") + if len(hemi_keys) == 2: # add some offset between hemispheres # if plotting both hemispheres on brain, need to offset since @@ -495,9 +502,13 @@ def _plot_brain_elecs_standalone( p3dc = ax.plot_trisurf( verts[:, 0], verts[:, 1], verts[:, 2], triangles=triangles ) - # set the face colors of the Poly3DCollection - colors_sulci[:, -1] = brain_alpha - p3dc.set_fc(colors_sulci) + if sulci[hemi] is not None: + # set the face colors of the Poly3DCollection + colors_sulci[:, -1] = brain_alpha + p3dc.set_fc(colors_sulci) + else: + p3dc.set_alpha(brain_alpha) + if elecs is None: continue