diff --git a/bluepyemodel/emodel_pipeline/emodel_settings.py b/bluepyemodel/emodel_pipeline/emodel_settings.py index 140097d1..20d87d0f 100644 --- a/bluepyemodel/emodel_pipeline/emodel_settings.py +++ b/bluepyemodel/emodel_pipeline/emodel_settings.py @@ -446,6 +446,14 @@ def __init__( self.plot_IV_curves = plot_IV_curves self.plot_FI_curve_comparison = plot_FI_curve_comparison self.plot_traces_comparison = plot_traces_comparison + if extract_absolute_amplitudes is True: + if any((plot_IV_curves, plot_FI_curve_comparison)): + logger.warning( + "The 'plot_IV_curves' and 'plot_FI_curve_comparison' features do not " + "support absolute current amplitude. These plots have been disabled." + ) + self.plot_IV_curves = False + self.plot_FI_curve_comparison = False if pickle_cells_extraction is False: if any((plot_IV_curves, plot_FI_curve_comparison, plot_traces_comparison)): logger.warning( diff --git a/bluepyemodel/emodel_pipeline/plotting.py b/bluepyemodel/emodel_pipeline/plotting.py index ad8b2c36..314961f9 100644 --- a/bluepyemodel/emodel_pipeline/plotting.py +++ b/bluepyemodel/emodel_pipeline/plotting.py @@ -43,6 +43,7 @@ from bluepyemodel.emodel_pipeline.plotting_utils import get_experimental_FI_curve_for_plotting from bluepyemodel.emodel_pipeline.plotting_utils import get_impedance from bluepyemodel.emodel_pipeline.plotting_utils import get_ordered_currentscape_keys +from bluepyemodel.emodel_pipeline.plotting_utils import get_original_protocol_name from bluepyemodel.emodel_pipeline.plotting_utils import get_recording_names from bluepyemodel.emodel_pipeline.plotting_utils import get_simulated_FI_curve_for_plotting from bluepyemodel.emodel_pipeline.plotting_utils import get_sinespec_evaluator @@ -50,7 +51,10 @@ from bluepyemodel.emodel_pipeline.plotting_utils import get_traces_names_and_float_responses from bluepyemodel.emodel_pipeline.plotting_utils import get_traces_ylabel from bluepyemodel.emodel_pipeline.plotting_utils import get_voltage_currents_from_files +from bluepyemodel.emodel_pipeline.plotting_utils import plot_fi_curves from bluepyemodel.emodel_pipeline.plotting_utils import rel_to_abs_amplitude +from bluepyemodel.emodel_pipeline.plotting_utils import save_fig +from bluepyemodel.emodel_pipeline.plotting_utils import update_evaluator from bluepyemodel.evaluation.evaluation import compute_responses from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point from bluepyemodel.evaluation.evaluator import PRE_PROTOCOLS @@ -85,14 +89,6 @@ } -def save_fig(figures_dir, figure_name, dpi=100): - """Save a matplotlib figure""" - p = Path(figures_dir) / figure_name - plt.savefig(str(p), dpi=dpi, bbox_inches="tight") - plt.close("all") - plt.clf() - - def optimisation( optimiser, emodel, @@ -1205,8 +1201,11 @@ def run_and_plot_EPSP( def plot_IV_curves( evaluator, emodels, + access_point, figures_dir, efel_settings, + mapper, + seeds, prot_name="iv", custom_bluepyefe_cells_pklpath=None, write_fig=True, @@ -1218,8 +1217,11 @@ def plot_IV_curves( Args: evaluator (CellEvaluator): cell evaluator emodels (list): list of EModels + access_point (DataAccessPoint): data access point figures_dir (str or Path): output directory for the figure to be saved on efel_settings (dict): eFEL settings in the form {setting_name: setting_value}. + mapper (map): used to parallelize the evaluation of the individual in the population. + seeds (list): if not None, filter emodels to keep only the ones with these seeds. prot_name (str): Only recordings from this protocol will be used. custom_bluepyefe_cells_pklpath (str): file path to the cells.pkl output of BluePyEfe. If None, will use usual file path used in BluePyEfe, @@ -1227,14 +1229,28 @@ def plot_IV_curves( write_fig (bool): whether to save the figure n_bin (int): number of bins to use """ - # pylint: disable=too-many-nested-blocks, possibly-used-before-assignment + # pylint: disable=too-many-nested-blocks, possibly-used-before-assignment, disable=too-many-locals, disable=too-many-statements # note: should maybe also check location and recorded variable make_dir(figures_dir) if efel_settings is None: efel_settings = bluepyefe.tools.DEFAULT_EFEL_SETTINGS.copy() + lower_bound = -100 + upper_bound = 100 + + # Generate amplitude points + sim_amp_points = list(map(int, numpy.linspace(lower_bound, upper_bound, n_bin + 1))) # add missing features (if any) to evaluator - updated_evaluator = fill_in_IV_curve_evaluator(evaluator, efel_settings, prot_name) + updated_evaluator = fill_in_IV_curve_evaluator( + evaluator, efel_settings, prot_name, sim_amp_points + ) + + emodels = compute_responses( + access_point, + updated_evaluator, + map_function=mapper, + seeds=seeds, + ) emodel_name = None cells = None @@ -1349,6 +1365,9 @@ def plot_IV_curves( def plot_FI_curves_comparison( evaluator, emodels, + access_point, + seeds, + mapper, figures_dir, prot_name, custom_bluepyefe_cells_pklpath=None, @@ -1362,6 +1381,9 @@ def plot_FI_curves_comparison( Args: evaluator (CellEvaluator): cell evaluator emodels (list): list of EModels + access_point (DataAccessPoint): data access point + seeds (list): if not None, filter emodels to keep only the ones with these seeds. + mapper (map): used to parallelize the evaluation of the individual in the population. figures_dir (str or Path): output directory for the figure to be saved on prot_name (str): name of the protocol to use for the FI curve custom_bluepyefe_cells_pklpath (str): file path to the cells.pkl output of BluePyEfe. @@ -1373,8 +1395,8 @@ def plot_FI_curves_comparison( # pylint: disable=too-many-nested-blocks, possibly-used-before-assignment make_dir(figures_dir) - emodel_name = None - cells = None + emodel_name, cells = None, None + updated_evaluator = copy.deepcopy(evaluator) for emodel in emodels: # do not re-extract data if the emodel is the same as previously if custom_bluepyefe_cells_pklpath is not None: @@ -1382,15 +1404,9 @@ def plot_FI_curves_comparison( cells = read_extraction_output(custom_bluepyefe_cells_pklpath) if cells is None: continue + # experimental FI curve - ( - expt_amp_rel, - expt_freq_rel, - expt_freq_rel_err, - expt_amp, - expt_freq_abs, - expt_freq_abs_err, - ) = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) + expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) elif emodel_name != emodel.emodel_metadata.emodel or cells is None: # take extraction data from pickle file and rearange it for plotting cells = read_extraction_output_cells(emodel.emodel_metadata.emodel) @@ -1399,57 +1415,47 @@ def plot_FI_curves_comparison( continue # experimental FI curve - ( - expt_amp_rel, - expt_freq_rel, - expt_freq_rel_err, - expt_amp, - expt_freq_abs, - expt_freq_abs_err, - ) = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) + expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) emodel_name = emodel.emodel_metadata.emodel - # simulated FI curve - simulated_amp_rel, simulated_amp, simulated_freq = get_simulated_FI_curve_for_plotting( - evaluator, emodel.responses, prot_name + expt_data_amp_rel = expt_data[0] + prot_name_original = get_original_protocol_name(prot_name, evaluator) + updated_evaluator = update_evaluator( + expt_data_amp_rel, prot_name_original, updated_evaluator ) - # plotting - _, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3)) - ax[0].errorbar( - expt_amp_rel, - expt_freq_rel, - yerr=expt_freq_rel_err, - marker="o", - color="grey", - label="experiment", - ) - ax[0].set_xlabel("Amplitude (% of rheobase)") - ax[0].set_ylabel("Mean Frequency (Hz)") - ax[0].set_title("FI curve (relative amplitude)") - ax[0].plot(simulated_amp_rel, simulated_freq, "o", color="blue", label="model") + updated_evaluator.fitness_protocols["main_protocol"].execution_order = ( + updated_evaluator.fitness_protocols["main_protocol"].compute_execution_order() + ) - ax[1].errorbar( - expt_amp, - expt_freq_abs, - yerr=expt_freq_abs_err, - marker="o", - color="grey", - label="experiment", - ) - ax[1].set_xlabel("Amplitude (nA)") - ax[1].set_ylabel("Voltage (mV)") - ax[1].set_title("IV curve (absolute amplitude)") - ax[1].plot(simulated_amp, simulated_freq, "o", color="blue", label="model") + emodels = compute_responses(access_point, updated_evaluator, mapper, seeds) + for emodel in emodels: + # do not re-extract data if the emodel is the same as previously + if custom_bluepyefe_cells_pklpath is not None: + if cells is None: + cells = read_extraction_output(custom_bluepyefe_cells_pklpath) + if cells is None: + continue - ax[0].legend() - ax[1].legend() - if write_fig: - save_fig( - figures_dir, - emodel.emodel_metadata.as_string(emodel.seed) + "__FI_curve_comparison.pdf", - ) + # experimental FI curve + expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) + elif emodel_name != emodel.emodel_metadata.emodel or cells is None: + # take extraction data from pickle file and rearange it for plotting + cells = read_extraction_output_cells(emodel.emodel_metadata.emodel) + emodel_name = emodel.emodel_metadata.emodel + if cells is None: + continue + + # experimental FI curve + expt_data = get_experimental_FI_curve_for_plotting(cells, prot_name, n_bin=n_bin) + + emodel_name = emodel.emodel_metadata.emodel + + sim_data = get_simulated_FI_curve_for_plotting( + updated_evaluator, emodel.responses, prot_name + ) + plot_fi_curves(expt_data, sim_data, figures_dir, emodel, write_fig) def phase_plot( @@ -2046,10 +2052,13 @@ def plot_models( plot_IV_curves( cell_evaluator, emodels, + access_point, figures_dir_IV_curves, # maybe we should give efel_settings as an argument of plot_models, # like the other pipeline_settings access_point.pipeline_settings.efel_settings, + mapper, + seeds, IV_curve_prot_name, custom_bluepyefe_cells_pklpath=custom_bluepyefe_cells_pklpath, ) @@ -2059,6 +2068,9 @@ def plot_models( plot_FI_curves_comparison( cell_evaluator, emodels, + access_point, + seeds, + mapper, figures_dir_FI_curves, FI_curve_prot_name, custom_bluepyefe_cells_pklpath=custom_bluepyefe_cells_pklpath, diff --git a/bluepyemodel/emodel_pipeline/plotting_utils.py b/bluepyemodel/emodel_pipeline/plotting_utils.py index 527b0402..52ce9fc1 100644 --- a/bluepyemodel/emodel_pipeline/plotting_utils.py +++ b/bluepyemodel/emodel_pipeline/plotting_utils.py @@ -21,13 +21,16 @@ from pathlib import Path import efel +import matplotlib.pyplot as plt import numpy from bluepyopt.ephys.objectives import SingletonWeightObjective from bluepyemodel.ecode.sinespec import SineSpec from bluepyemodel.evaluation.efel_feature_bpem import eFELFeatureBPEM from bluepyemodel.evaluation.evaluator import PRE_PROTOCOLS +from bluepyemodel.evaluation.evaluator import define_protocol from bluepyemodel.evaluation.evaluator import soma_loc +from bluepyemodel.evaluation.protocol_configuration import ProtocolConfiguration from bluepyemodel.evaluation.protocols import BPEMProtocol from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom @@ -81,6 +84,7 @@ def get_recording_names(protocol_config, stimuli): def get_traces_names_and_float_responses(responses, recording_names): """Extract the names of the traces to be plotted, as well as the float responses values.""" + # pylint: disable=too-many-nested-blocks traces_names = [] threshold = None @@ -236,14 +240,83 @@ def extract_experimental_data_for_IV_curve(cells, efel_settings, prot_name="iv", return exp_peak, exp_vd -def fill_in_IV_curve_evaluator(evaluator, efel_settings, prot_name="iv"): +def find_matching_feature(evaluator, protocol_name): + conditions = [ + lambda feat: protocol_name in feat.recording_names[""], + lambda feat: protocol_name.split(".")[0] in feat.recording_names[""] + and feat.stimulus_current() is not None, + lambda feat: protocol_name.split("_")[0] in feat.recording_names[""] + and feat.stimulus_current() is not None, + ] + + for condition in conditions: + for objective in evaluator.fitness_calculator.objectives: + feat = objective.features[0] + if condition(feat): + return feat, condition == conditions[0] + + return None, False + + +def create_protocol(amp_rel, amp, feature, protocol, protocol_name): + """ + Create a new protocol with adjusted stimulus amplitude based on a threshold. + + Arguments: + amp_rel (float): Relative amplitude as a percentage of the threshold current. + amp (float): Absolute amplitude to use for recalculating the stimulus amplitude. + feature (eFELFeatureBPEM, optional): Feature object used to retrieve the threshold + current for scaling. + protocol (BPEMProtocol): The original protocol to modify. + protocol_name (str): Name for the new protocol. + + Returns: + BPEMProtocol: A new protocol object with the adjusted threshold-based stimulus amplitude. + """ + if amp is None: + s_amp = amp_rel + elif amp_rel is not None and feature is not None: + s_amp = feature.stimulus_current() * amp_rel / amp + else: + s_amp = None + + stimuli = [ + { + "holding_current": protocol.stimuli[0].holding_current, + "threshold_current": protocol.stimuli[0].threshold_current, + "amp": s_amp, + "thresh_perc": amp_rel, + "delay": protocol.stimuli[0].delay, + "duration": protocol.stimuli[0].duration, + "totduration": protocol.stimuli[0].total_duration, + } + ] + recordings = [ + { + "type": "CompRecording", + "name": f"{protocol_name}.soma.v", + "location": "soma", + "variable": "v", + } + ] + my_protocol_configuration = ProtocolConfiguration( + name=protocol_name, stimuli=stimuli, recordings=recordings, validation=False + ) + p = define_protocol(my_protocol_configuration) + + return p + + +def fill_in_IV_curve_evaluator(evaluator, efel_settings, prot_name="iv", new_amps=None): """Returns a copy of the evaluator, with missing features added for IV_curve computation. Args: evaluator (CellEvaluator): cell evaluator efel_settings (dict): eFEL settings in the form {setting_name: setting_value}. prot_name (str): Only recordings from this protocol will be used. + new_amps (list): List of amplitudes to extend the protocols with. """ + # pylint: disable=too-many-branches updated_evaluator = copy.deepcopy(evaluator) # find protocols we expect to have the features we want to plot prot_max_v = [] @@ -258,24 +331,52 @@ def fill_in_IV_curve_evaluator(evaluator, efel_settings, prot_name="iv"): else: prot_v_deflection.append(prot.name) + if new_amps is not None: + prot_name_original = get_original_protocol_name(prot_name, evaluator) + for amp in new_amps: + protocol_name_amp = f"{prot_name_original.split('_')[0]}_{amp}" + if 0 <= amp < 100: + if protocol_name_amp not in prot_max_v: + prot_max_v.append(protocol_name_amp) + elif amp < 0: + if protocol_name_amp not in prot_v_deflection: + prot_v_deflection.append(protocol_name_amp) + # maps protocols of interest with all its associated features # also get protocol data we need for feature registration prots_to_feats = {} prots_data = {} - for objective in evaluator.fitness_calculator.objectives: - feat = objective.features[0] - for protocol_name in prot_v_deflection + prot_max_v: - if protocol_name in feat.recording_names[""]: + + for protocol_name in prot_v_deflection + prot_max_v: + matched_feat, feat_already_present = find_matching_feature(evaluator, protocol_name) + if matched_feat is not None: + if feat_already_present: if protocol_name not in prots_to_feats: prots_to_feats[protocol_name] = [] if protocol_name not in prots_data: prots_data[protocol_name] = { - "stimulus_current": feat.stimulus_current, - "stim_start": feat.stim_start, - "stim_end": feat.stim_end, + "stimulus_current": matched_feat.stimulus_current, + "stim_start": matched_feat.stim_start, + "stim_end": matched_feat.stim_end, } - prots_to_feats[protocol_name].append(feat.efel_feature_name) - continue + prots_to_feats[protocol_name].append(matched_feat.efel_feature_name) + else: + p_rel_name = matched_feat.recording_names[""].split(".")[0] + amp_rel = float(protocol_name.split("_")[1]) + amp = float(matched_feat.recording_names[""].split(".")[0].split("_")[-1]) + p_rel = updated_evaluator.fitness_protocols["main_protocol"].protocols[p_rel_name] + p = create_protocol(amp_rel, amp, matched_feat, p_rel, protocol_name) + updated_evaluator.fitness_protocols["main_protocol"].protocols[protocol_name] = p + + if protocol_name not in prots_to_feats: + prots_to_feats[protocol_name] = [] + if protocol_name not in prots_data: + prots_data[protocol_name] = { + "stimulus_current": matched_feat.stimulus_current() * amp_rel / amp, + "stim_start": matched_feat.stim_start, + "stim_end": matched_feat.stim_end, + } + prots_to_feats[protocol_name].append(matched_feat.efel_feature_name) # add missing features for protocol_name, feat_list in prots_to_feats.items(): @@ -326,6 +427,10 @@ def fill_in_IV_curve_evaluator(evaluator, efel_settings, prot_name="iv"): ) ) + updated_evaluator.fitness_protocols["main_protocol"].execution_order = ( + updated_evaluator.fitness_protocols["main_protocol"].compute_execution_order() + ) + return updated_evaluator @@ -381,7 +486,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name): simulated_amp_rel = [] simulated_amp = [] for val in values: - if prot_name.lower() in val.lower(): + if prot_name.lower().split("_")[0] in val.lower(): protocol_name = get_protocol_name(val) amp_temp = float(protocol_name.split("_")[-1]) if "mean_frequency" in val: @@ -650,3 +755,105 @@ def get_voltage_currents_from_files(key_dict, output_dir): ionic_concentrations = [numpy.loadtxt(ion_conc_path)[:, 1] for ion_conc_path in ion_conc_paths] return time, voltage, currents, ionic_concentrations + + +def get_original_protocol_name(prot_name, evaluator): + """Retrieve the protocol name as defined by the user, preserving the original case""" + for protocol_name in evaluator.fitness_protocols["main_protocol"].protocols: + if prot_name.lower() in protocol_name.lower(): + return protocol_name + return prot_name + + +def update_evaluator(expt_amp_rel, prot_name, evaluator): + """update evaluator with new simulation protocols.""" + for amp_rel in expt_amp_rel: + protocol_name = f"{prot_name.split('_')[0]}_{int(amp_rel)}" + protocol = evaluator.fitness_protocols["main_protocol"].protocols[prot_name] + if protocol_name not in evaluator.fitness_protocols["main_protocol"].protocols: + p = create_protocol(int(amp_rel), None, None, protocol, protocol_name) + evaluator.fitness_protocols["main_protocol"].protocols[protocol_name] = p + + for objective in evaluator.fitness_calculator.objectives: + feat = objective.features[0] + if ( + protocol_name.split("_", maxsplit=1)[0] in feat.recording_names[""] + and "mean_frequency" in feat.efel_feature_name + ): + feat_name = f"{protocol_name}.soma.v.mean_frequency" + amp_rel = float(protocol_name.split("_")[1]) + amp = float(feat.recording_names[""].split(".")[0].split("_")[-1]) + evaluator.fitness_calculator.objectives.append( + SingletonWeightObjective( + feat_name, + eFELFeatureBPEM( + feat_name, + efel_feature_name="mean_frequency", + recording_names={"": f"{protocol_name}.soma.v"}, + stim_start=feat.stim_start, + stim_end=feat.stim_end, + exp_mean=1.0, # fodder: not used + exp_std=1.0, # fodder: not used + threshold=feat.threshold, + stimulus_current=feat.stimulus_current() * amp_rel / amp, + weight=1.0, + ), + 1.0, + ) + ) + break + return evaluator + + +def plot_fi_curves(expt_data, sim_data, figures_dir, emodel, write_fig): + """Plot and save the FI curves.""" + ( + expt_amp_rel, + expt_freq_rel, + expt_freq_rel_err, + expt_amp, + expt_freq_abs, + expt_freq_abs_err, + ) = expt_data + simulated_amp_rel, simulated_amp, simulated_freq = sim_data + + _, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 3)) + ax[0].errorbar( + expt_amp_rel, + expt_freq_rel, + yerr=expt_freq_rel_err, + marker="o", + color="grey", + label="experiment", + ) + ax[0].plot(simulated_amp_rel, simulated_freq, "o", color="blue", label="model") + ax[0].set_xlabel("Amplitude (% of rheobase)") + ax[0].set_ylabel("Mean Frequency (Hz)") + ax[0].set_title("FI curve (relative amplitude)") + ax[0].legend() + + ax[1].errorbar( + expt_amp, + expt_freq_abs, + yerr=expt_freq_abs_err, + marker="o", + color="grey", + label="experiment", + ) + ax[1].plot(simulated_amp, simulated_freq, "o", color="blue", label="model") + ax[1].set_xlabel("Amplitude (nA)") + ax[1].set_ylabel("Voltage (mV)") + ax[1].set_title("FI curve (absolute amplitude)") + ax[1].legend() + + if write_fig: + filename = f"{emodel.emodel_metadata.as_string(emodel.seed)}__FI_curve_comparison.pdf" + save_fig(figures_dir, filename) + + +def save_fig(figures_dir, figure_name, dpi=100): + """Save a matplotlib figure""" + p = Path(figures_dir) / figure_name + plt.savefig(str(p), dpi=dpi, bbox_inches="tight") + plt.close("all") + plt.clf() diff --git a/examples/nexus/exploit_model.ipynb b/examples/nexus/exploit_model.ipynb index c8b4a70b..cdd80fe9 100644 --- a/examples/nexus/exploit_model.ipynb +++ b/examples/nexus/exploit_model.ipynb @@ -221,7 +221,10 @@ "metadata": {}, "outputs": [], "source": [ - "evaluator.fitness_protocols[\"main_protocol\"].protocols[protocol_name] = my_protocol" + "evaluator.fitness_protocols[\"main_protocol\"].protocols[protocol_name] = my_protocol\n", + "evaluator.fitness_protocols[\"main_protocol\"].execution_order = (\n", + " evaluator.fitness_protocols[\"main_protocol\"].compute_execution_order()\n", + ")" ] }, {