From fc61b974cc61455134927154308c8e9a90103859 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Wed, 29 Mar 2023 17:50:09 +0200 Subject: [PATCH 01/14] Start predict CLI command --- src/morphoclass/console/cmd_predict.py | 248 +++++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 src/morphoclass/console/cmd_predict.py diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py new file mode 100644 index 0000000..518c210 --- /dev/null +++ b/src/morphoclass/console/cmd_predict.py @@ -0,0 +1,248 @@ +# Copyright © 2022-2022 Blue Brain Project/EPFL +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the `morphoclass predict` CLI command.""" +import functools +import logging +import textwrap + +import click + +logger = logging.getLogger(__name__) + + +@click.command( + name="predict", + help="Run inference.", +) +@click.help_option("-h", "--help") +@click.option( + "-i", + "--input-csv", + required=True, + type=click.Path(exists=True, dir_okay=False), + help=textwrap.dedent( + """ + A CSV file with paths to the morphology files in the + first column + """ + ).strip(), +) +@click.option( + "-c", + "--checkpoint", + "checkpoint_file", + required=True, + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help=textwrap.dedent( + """ + The path to the pre-trained model checkpoint. + """ + ).strip(), +) +@click.option( + "-o", + "--output-dir", + required=True, + type=click.Path(exists=False, file_okay=False, writable=True), + help="Output directory for the results.", +) +@click.option( + "-n", + "--results-name", + required=False, + type=click.STRING, + help="The filename of the results file", +) +def cli(input_csv, checkpoint_file, output_dir, results_name): + """Run the `deepm predict` CLI command. + + Parameters + ---------- + input_csv + The CSV file with the input data paths. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ + import json + import pathlib + from datetime import datetime + + input_csv = pathlib.Path(input_csv).resolve() + output_dir = pathlib.Path(output_dir).resolve() + checkpoint_file = pathlib.Path(checkpoint_file).resolve() + if results_name is None: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + results_name = f"results_{timestamp}" + results_path = output_dir / (results_name + ".json") + click.secho(f"Input CSV : {input_csv}", fg="yellow") + click.secho(f"Output file : {results_path}", fg="yellow") + click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") + if results_path.exists(): + msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) ' + click.secho(msg, fg="red", bold=True, nl=False) + response = input() + if response.strip().lower() != "y": + click.secho("Stopping.", fg="red") + return + else: + click.secho("You chose to overwrite, proceeding...", fg="red") + + click.secho("✔ Loading checkpoint...", fg="green", bold=True) + import numpy as np + import torch + + from morphoclass.data import MorphologyDataset + + checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) + model_class = checkpoint["model_class"] + click.secho(f"Model : {model_class}", fg="yellow") + if "metadata" in checkpoint: + timestamp = checkpoint["metadata"]["timestamp"] + click.secho(f"Created on : {timestamp}", fg="yellow") + + click.secho("✔ Loading data...", fg="green", bold=True) + dataset = MorphologyDataset.from_csv(csv_file=input_csv) + click.echo(f"> Dataset length: {len(dataset)}") + + click.secho("✔ Computing predictions...", fg="green", bold=True) + if "ManNet" in model_class: + logits = predict_gnn(dataset, checkpoint) + predictions = logits.argmax(axis=1) + elif "CNN" in model_class: + logits = predict_cnn(dataset, checkpoint) + predictions = logits.argmax(axis=1) + elif "XGB" in model_class: + predictions = predict_xgb(dataset, checkpoint) + else: + click.secho( + f"Model not recognized: {model_name}. Stopping.", + fg="red", + bold=True, + nl=False, + ) + return + logger.info(f"Accuracy: {np.mean(predictions == dataset_pi[2]):.2f}") + + click.secho("✔ Exporting results...", fg="green", bold=True) + prediction_lables = {} + for sample, sample_pred in zip(dataset.data, predictions): + sample_path = str(sample.file) + pred_label = dataset.class_dict[sample_pred] + prediction_lables[str(sample_path)] = pred_label + + results = dict() + results["predictions"] = prediction_lables + results["checkpoint_path"] = str(checkpoint_file) + results["model"] = model_name + with open(results_path, "w") as fp: + json.dump(results, fp) + + click.secho("✔ Done.", fg="green", bold=True) + + +def predict_gnn(dataset, checkpoint): + """Compute predictions with a GNN (ManNet) classifier. + + Parameters + ---------- + dataset + The morphology dataset. + checkpoint + The model checkpoint. + + Returns + ------- + logits + The predictions logits. + """ + import torch + + import morphoclass.models + + model_cls = getattr(morphoclass.models, checkpoint["model_class"].rpartition(".")[2]) + model = model_cls(**checkpoint["model_params"]) + model.load_state_dict(checkpoint["all"]["model"]) + model.eval() + logits = [model_cnn(sample) for sample in dataset] + + return np.array(logits) + + +def predict_cnn(dataset, checkpoint): + """Compute predictions with a CNN classifier. + + Parameters + ---------- + dataset + The persistence image dataset. + checkpoint + The model checkpoint. + + Returns + ------- + logits + The predictions logits. + """ + import numpy as np + import torch + from torch.utils.data import DataLoader, TensorDataset + + import morphoclass.models + from morphoclass.data import MorphologyDataLoader + + # Model + model_cls = getattr(morphoclass.models, checkpoint["model_class"].rpartition(".")[2]) + model = model_cls(**checkpoint["model_params"]) + model.load_state_dict(checkpoint["all"]["model"]) + + # Data + loader = MorphologyDataLoader(dataset) + + # Evaluation + model.eval() + logits = [] + with torch.no_grad(): + for batch in iter(loader): + batch_logits = model(batch).numpy() + logits.append(batch_logits) + if len(logits) > 0: + logits = np.concatenate(logits) + else: + logits = np.array(logits) + + return logits + + +def predict_xgb(dataset, checkpoint): + """Compute predictions with XGBoost classifier. + + Parameters + ---------- + dataset + The morphology persistence image dataset. + checkpoint + The model checkpoint. + + Returns + ------- + predictions + The predictions. + """ + model = checkpoint["all"]["model"] + predictions = [model.predict(sample.image.numpy().reshape(1, 10000)) for sample in dataset] + return predictions From 29c18ec7330841f7b12b827e0c7ea21cb350bf4e Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 12:21:21 +0200 Subject: [PATCH 02/14] Fix issues --- requirements.txt | 2 +- setup.cfg | 4 +- src/morphoclass/console/cmd_predict.py | 57 +++++++++++--------------- src/morphoclass/console/main.py | 2 + 4 files changed, 30 insertions(+), 35 deletions(-) diff --git a/requirements.txt b/requirements.txt index e91ca17..af2be82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ scipy==1.7.3 seaborn==0.11.0 shap[plots]==0.39.0 tmd==2.1.0 -torch==1.7.1 +torch==1.9.0 tqdm==4.53.0 umap-learn==0.5.1 xgboost==1.4.2 diff --git a/setup.cfg b/setup.cfg index 6a0083a..3f00e1c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,8 +36,8 @@ install_requires = imbalanced-learn jinja2 matplotlib - morphio - morphology-workflows>=0.3.0 + #morphio + #morphology-workflows>=0.3.0 networkx neurom>=3 NeuroR>=1.6.1 diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index 518c210..ed6f59b 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -27,14 +27,13 @@ ) @click.help_option("-h", "--help") @click.option( - "-i", - "--input-csv", + "-f", + "--features-dir", required=True, - type=click.Path(exists=True, dir_okay=False), + type=click.Path(exists=True, dir_okay=True), help=textwrap.dedent( """ - A CSV file with paths to the morphology files in the - first column + The path to the extracted features of the morphologies to classify """ ).strip(), ) @@ -64,13 +63,13 @@ type=click.STRING, help="The filename of the results file", ) -def cli(input_csv, checkpoint_file, output_dir, results_name): +def cli(features_dir, checkpoint_file, output_dir, results_name): """Run the `deepm predict` CLI command. Parameters ---------- - input_csv - The CSV file with the input data paths. + features_dir + The path to the features of the morphologies. checkpoint_file The path to the checkpoint file. output_dir @@ -82,16 +81,16 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): import pathlib from datetime import datetime - input_csv = pathlib.Path(input_csv).resolve() + features_dir = pathlib.Path(features_dir).resolve() output_dir = pathlib.Path(output_dir).resolve() checkpoint_file = pathlib.Path(checkpoint_file).resolve() if results_name is None: timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") results_name = f"results_{timestamp}" results_path = output_dir / (results_name + ".json") - click.secho(f"Input CSV : {input_csv}", fg="yellow") - click.secho(f"Output file : {results_path}", fg="yellow") - click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") + click.secho(f"Features Dir : {features_dir}", fg="yellow") + click.secho(f"Output file : {results_path}", fg="yellow") + click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") if results_path.exists(): msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) ' click.secho(msg, fg="red", bold=True, nl=False) @@ -107,6 +106,7 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): import torch from morphoclass.data import MorphologyDataset + from morphoclass.data.morphology_data import MorphologyData checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) model_class = checkpoint["model_class"] @@ -116,7 +116,10 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): click.secho(f"Created on : {timestamp}", fg="yellow") click.secho("✔ Loading data...", fg="green", bold=True) - dataset = MorphologyDataset.from_csv(csv_file=input_csv) + data = [] + for path in sorted(features_dir.glob("*.features")): + data.append(MorphologyData.load(path)) + dataset = MorphologyDataset(data) click.echo(f"> Dataset length: {len(dataset)}") click.secho("✔ Computing predictions...", fg="green", bold=True) @@ -136,19 +139,18 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): nl=False, ) return - logger.info(f"Accuracy: {np.mean(predictions == dataset_pi[2]):.2f}") click.secho("✔ Exporting results...", fg="green", bold=True) - prediction_lables = {} + prediction_labels = {} for sample, sample_pred in zip(dataset.data, predictions): - sample_path = str(sample.file) - pred_label = dataset.class_dict[sample_pred] - prediction_lables[str(sample_path)] = pred_label + sample_path = str(sample.path) + pred_label = dataset.y_to_label[sample_pred] + prediction_labels[sample_path] = pred_label results = dict() - results["predictions"] = prediction_lables + results["predictions"] = prediction_labels results["checkpoint_path"] = str(checkpoint_file) - results["model"] = model_name + results["model"] = model_class with open(results_path, "w") as fp: json.dump(results, fp) @@ -178,7 +180,7 @@ def predict_gnn(dataset, checkpoint): model = model_cls(**checkpoint["model_params"]) model.load_state_dict(checkpoint["all"]["model"]) model.eval() - logits = [model_cnn(sample) for sample in dataset] + logits = [model(sample) for sample in dataset] return np.array(logits) @@ -203,23 +205,14 @@ def predict_cnn(dataset, checkpoint): from torch.utils.data import DataLoader, TensorDataset import morphoclass.models - from morphoclass.data import MorphologyDataLoader # Model model_cls = getattr(morphoclass.models, checkpoint["model_class"].rpartition(".")[2]) model = model_cls(**checkpoint["model_params"]) model.load_state_dict(checkpoint["all"]["model"]) - # Data - loader = MorphologyDataLoader(dataset) - # Evaluation - model.eval() - logits = [] - with torch.no_grad(): - for batch in iter(loader): - batch_logits = model(batch).numpy() - logits.append(batch_logits) + logits = [model(sample.image).detach().numpy() for sample in dataset] if len(logits) > 0: logits = np.concatenate(logits) else: @@ -244,5 +237,5 @@ def predict_xgb(dataset, checkpoint): The predictions. """ model = checkpoint["all"]["model"] - predictions = [model.predict(sample.image.numpy().reshape(1, 10000)) for sample in dataset] + predictions = [model.predict(sample.image.numpy().reshape(1, 10000))[0] for sample in dataset] return predictions diff --git a/src/morphoclass/console/main.py b/src/morphoclass/console/main.py index 886d2f8..0e510fd 100644 --- a/src/morphoclass/console/main.py +++ b/src/morphoclass/console/main.py @@ -26,6 +26,7 @@ from morphoclass.console import cmd_organise_dataset from morphoclass.console import cmd_performance_table from morphoclass.console import cmd_plot_dataset_stats +from morphoclass.console import cmd_predict from morphoclass.console import cmd_preprocess_dataset from morphoclass.console import cmd_train from morphoclass.console import cmd_xai @@ -137,6 +138,7 @@ def cli(verbose: int, log_file_path: pathlib.Path | None) -> None: cli.add_command(cmd_xai.cli) cli.add_command(cmd_organise_dataset.cli) cli.add_command(cmd_plot_dataset_stats.cli) +cli.add_command(cmd_predict.cli) cli.add_command(cmd_preprocess_dataset.cli) cli.add_command(cmd_train.cli) cli.add_command(cmd_evaluate.cli) From a69c073e4047edead7a1144886d7a26282aab752 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 12:27:03 +0200 Subject: [PATCH 03/14] Change typo --- src/morphoclass/console/cmd_predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index ed6f59b..2cd68e8 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -64,7 +64,7 @@ help="The filename of the results file", ) def cli(features_dir, checkpoint_file, output_dir, results_name): - """Run the `deepm predict` CLI command. + """Run the `morphoclass predict` CLI command. Parameters ---------- From 3d3e364d1465eba6e45fb106f343ca13a7f38a31 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 12:54:55 +0200 Subject: [PATCH 04/14] Try to combine extract-features and predict --- .../console/cmd_extract_features.py | 31 ++++ .../cmd_extract_features_and_predict.py | 142 ++++++++++++++++++ src/morphoclass/console/cmd_predict.py | 9 +- src/morphoclass/console/main.py | 2 + 4 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 src/morphoclass/console/cmd_extract_features_and_predict.py diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index c952be5..8b1c635 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -114,6 +114,37 @@ def cli( keep_diagram: bool, force: bool, ) -> None: + + return extract_features( + csv_path, + neurite_type, + feature, + output_dir, + orient, + no_simplify_graph, + keep_diagram, + force, + ) + +def extract_features( + csv_path: StrPath, + neurite_type: Literal["apical", "axon", "basal", "all"], + feature: Literal[ + "graph-rd", + "graph-proj", + "diagram-tmd-rd", + "diagram-tmd-proj", + "diagram-deepwalk", + "image-tmd-rd", + "image-tmd-proj", + "image-deepwalk", + ], + output_dir: StrPath, + orient: bool, + no_simplify_graph: bool, + keep_diagram: bool, + force: bool, +) """Extract morphology features.""" output_dir = pathlib.Path(output_dir) if output_dir.exists() and not force: diff --git a/src/morphoclass/console/cmd_extract_features_and_predict.py b/src/morphoclass/console/cmd_extract_features_and_predict.py new file mode 100644 index 0000000..c26c674 --- /dev/null +++ b/src/morphoclass/console/cmd_extract_features_and_predict.py @@ -0,0 +1,142 @@ +# Copyright © 2022-2022 Blue Brain Project/EPFL +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of the `morphoclass predict` CLI command.""" +import functools +import logging +import textwrap + +import click + +logger = logging.getLogger(__name__) + + +@click.command( + name="predict", + help="Run inference.", +) +@click.help_option("-h", "--help") +@click.option( + "-i", + "--input_csv", + required=True, + type=click.Path(exists=True, dir_okay=True), + help=textwrap.dedent( + """ + The CSV path with the path to all the morphologies to classify. + """ + ).strip(), +) +@click.option( + "-c", + "--checkpoint", + "checkpoint_file", + required=True, + type=click.Path(exists=True, file_okay=True, dir_okay=False), + help=textwrap.dedent( + """ + The path to the pre-trained model checkpoint. + """ + ).strip(), +) +@click.option( + "-o", + "--output-dir", + required=True, + type=click.Path(exists=False, file_okay=False, writable=True), + help="Output directory for the results.", +) +@click.option( + "-n", + "--results-name", + required=False, + type=click.STRING, + help="The filename of the results file", +) +def cli(input_csv, checkpoint_file, output_dir, results_name): + """Run the `morphoclass predict` CLI command. + + Parameters + ---------- + input_csv + The CSV with all the morphologies path. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ + import json + import pathlib + from datetime import datetime + + input_csv = pathlib.Path(input_csv).resolve() + output_dir = pathlib.Path(output_dir).resolve() + checkpoint_file = pathlib.Path(checkpoint_file).resolve() + if results_name is None: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + results_name = f"results_{timestamp}" + results_path = output_dir / (results_name + ".json") + click.secho(f"Input CSV : {input_csv}", fg="yellow") + click.secho(f"Output file : {results_path}", fg="yellow") + click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow") + if results_path.exists(): + msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) ' + click.secho(msg, fg="red", bold=True, nl=False) + response = input() + if response.strip().lower() != "y": + click.secho("Stopping.", fg="red") + return + else: + click.secho("You chose to overwrite, proceeding...", fg="red") + + click.secho("✔ Loading checkpoint...", fg="green", bold=True) + import numpy as np + import torch + + from morphoclass.console.cmd_extract_features import extract_features + from morphoclass.console.cmd_predict import predict + + checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) + neurites = ["apical", "axon", "basal", "all"] + neurite_type = [neurite for neurite in neurites if neurite in str(checkpoint["features_dir"])] + features_type = [ + "graph-rd", + "graph-proj", + "diagram-tmd-rd", + "diagram-tmd-proj", + "diagram-deepwalk", + "image-tmd-rd", + "image-tmd-proj", + "image-deepwalk", + ] + feature = [feature for feature in features_type if feature in str(checkpoint["features_dir"])] + + extract_features( + input_csv, + neurite_type, + feature, + output_dir / "features", + False, + False, + False, + False, + ) + + predict( + features_dir=output_dir / "features", + checkpoint_file=checkpoint_file, + output_dir=output_dir, + results_name=results_name, + ) diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index 2cd68e8..ec89ba2 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of the `morphoclass predict` CLI command.""" +"""Implementation of the `morphoclass predict-after-extraction` CLI command.""" import functools import logging import textwrap @@ -22,8 +22,8 @@ @click.command( - name="predict", - help="Run inference.", + name="predict-after-extraction", + help="Run inference from features directory.", ) @click.help_option("-h", "--help") @click.option( @@ -77,6 +77,9 @@ def cli(features_dir, checkpoint_file, output_dir, results_name): results_name File prefix for results output files. """ + return predict(features_dir, checkpoint_file, output_dir, results_name) + +def predict(features_dir, checkpoint_file, output_dir, results_name) import json import pathlib from datetime import datetime diff --git a/src/morphoclass/console/main.py b/src/morphoclass/console/main.py index 0e510fd..cdb27f9 100644 --- a/src/morphoclass/console/main.py +++ b/src/morphoclass/console/main.py @@ -22,6 +22,7 @@ import morphoclass from morphoclass.console import cmd_evaluate from morphoclass.console import cmd_extract_features +from morphoclass.console import cmd_extract_features_and_predict from morphoclass.console import cmd_morphometrics from morphoclass.console import cmd_organise_dataset from morphoclass.console import cmd_performance_table @@ -144,4 +145,5 @@ def cli(verbose: int, log_file_path: pathlib.Path | None) -> None: cli.add_command(cmd_evaluate.cli) cli.add_command(cmd_performance_table.cli) cli.add_command(cmd_extract_features.cli) +cli.add_command(cmd_extract_features_and_predict.cli) cli.add_command(cmd_morphometrics.cli) From c10b17fe018dec078062d7959e238468c8485bd6 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 12:59:35 +0200 Subject: [PATCH 05/14] Make flake8 and black happy --- .../console/cmd_extract_features.py | 6 ++--- .../cmd_extract_features_and_predict.py | 27 ++++++++++--------- src/morphoclass/console/cmd_predict.py | 21 ++++++++++----- src/morphoclass/models/concatenet.py | 1 - .../add_random_points_to_reduction_mask.py | 1 - 5 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index 8b1c635..1ec4b25 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -114,7 +114,6 @@ def cli( keep_diagram: bool, force: bool, ) -> None: - return extract_features( csv_path, neurite_type, @@ -124,7 +123,8 @@ def cli( no_simplify_graph, keep_diagram, force, - ) + ) + def extract_features( csv_path: StrPath, @@ -144,7 +144,7 @@ def extract_features( no_simplify_graph: bool, keep_diagram: bool, force: bool, -) +): """Extract morphology features.""" output_dir = pathlib.Path(output_dir) if output_dir.exists() and not force: diff --git a/src/morphoclass/console/cmd_extract_features_and_predict.py b/src/morphoclass/console/cmd_extract_features_and_predict.py index c26c674..d3898f1 100644 --- a/src/morphoclass/console/cmd_extract_features_and_predict.py +++ b/src/morphoclass/console/cmd_extract_features_and_predict.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the `morphoclass predict` CLI command.""" -import functools import logging import textwrap @@ -32,9 +31,9 @@ required=True, type=click.Path(exists=True, dir_okay=True), help=textwrap.dedent( - """ - The CSV path with the path to all the morphologies to classify. - """ + """ + The CSV path with the path to all the morphologies to classify. + """ ).strip(), ) @click.option( @@ -44,9 +43,9 @@ required=True, type=click.Path(exists=True, file_okay=True, dir_okay=False), help=textwrap.dedent( - """ - The path to the pre-trained model checkpoint. - """ + """ + The path to the pre-trained model checkpoint. + """ ).strip(), ) @click.option( @@ -77,7 +76,6 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): results_name File prefix for results output files. """ - import json import pathlib from datetime import datetime @@ -102,15 +100,16 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): click.secho("You chose to overwrite, proceeding...", fg="red") click.secho("✔ Loading checkpoint...", fg="green", bold=True) - import numpy as np import torch from morphoclass.console.cmd_extract_features import extract_features from morphoclass.console.cmd_predict import predict - checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) neurites = ["apical", "axon", "basal", "all"] - neurite_type = [neurite for neurite in neurites if neurite in str(checkpoint["features_dir"])] + neurite_type = [ + neurite for neurite in neurites if neurite in str(checkpoint["features_dir"]) + ] features_type = [ "graph-rd", "graph-proj", @@ -121,7 +120,11 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): "image-tmd-proj", "image-deepwalk", ] - feature = [feature for feature in features_type if feature in str(checkpoint["features_dir"])] + feature = [ + feature + for feature in features_type + if feature in str(checkpoint["features_dir"]) + ] extract_features( input_csv, diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index ec89ba2..3184307 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -32,7 +32,7 @@ required=True, type=click.Path(exists=True, dir_okay=True), help=textwrap.dedent( - """ + """ The path to the extracted features of the morphologies to classify """ ).strip(), @@ -44,7 +44,7 @@ required=True, type=click.Path(exists=True, file_okay=True, dir_okay=False), help=textwrap.dedent( - """ + """ The path to the pre-trained model checkpoint. """ ).strip(), @@ -79,7 +79,8 @@ def cli(features_dir, checkpoint_file, output_dir, results_name): """ return predict(features_dir, checkpoint_file, output_dir, results_name) -def predict(features_dir, checkpoint_file, output_dir, results_name) + +def predict(features_dir, checkpoint_file, output_dir, results_name): import json import pathlib from datetime import datetime @@ -111,7 +112,7 @@ def predict(features_dir, checkpoint_file, output_dir, results_name) from morphoclass.data import MorphologyDataset from morphoclass.data.morphology_data import MorphologyData - checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu')) + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) model_class = checkpoint["model_class"] click.secho(f"Model : {model_class}", fg="yellow") if "metadata" in checkpoint: @@ -179,7 +180,9 @@ def predict_gnn(dataset, checkpoint): import morphoclass.models - model_cls = getattr(morphoclass.models, checkpoint["model_class"].rpartition(".")[2]) + model_cls = getattr( + morphoclass.models, checkpoint["model_class"].rpartition(".")[2] + ) model = model_cls(**checkpoint["model_params"]) model.load_state_dict(checkpoint["all"]["model"]) model.eval() @@ -210,7 +213,9 @@ def predict_cnn(dataset, checkpoint): import morphoclass.models # Model - model_cls = getattr(morphoclass.models, checkpoint["model_class"].rpartition(".")[2]) + model_cls = getattr( + morphoclass.models, checkpoint["model_class"].rpartition(".")[2] + ) model = model_cls(**checkpoint["model_params"]) model.load_state_dict(checkpoint["all"]["model"]) @@ -240,5 +245,7 @@ def predict_xgb(dataset, checkpoint): The predictions. """ model = checkpoint["all"]["model"] - predictions = [model.predict(sample.image.numpy().reshape(1, 10000))[0] for sample in dataset] + predictions = [ + model.predict(sample.image.numpy().reshape(1, 10000))[0] for sample in dataset + ] return predictions diff --git a/src/morphoclass/models/concatenet.py b/src/morphoclass/models/concatenet.py index 6ab610b..9953f16 100644 --- a/src/morphoclass/models/concatenet.py +++ b/src/morphoclass/models/concatenet.py @@ -45,7 +45,6 @@ class ConcateNet(nn.Module): """ def __init__(self, n_node_features, n_classes, n_features_perslay, bn=False): - super().__init__() self.n_node_features = n_node_features self.n_classes = n_classes diff --git a/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py b/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py index 54178cd..8279c1e 100644 --- a/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py +++ b/src/morphoclass/transforms/augmentors/add_random_points_to_reduction_mask.py @@ -32,7 +32,6 @@ class AddRandomPointsToReductionMask: """ def __init__(self, n_points): - self.n_points = n_points @require_field("tmd_neurites_masks") From 099fd7e635bb8053481303b6d42877a9a8b86e0a Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:02:04 +0200 Subject: [PATCH 06/14] Make flake8 happy again --- src/morphoclass/console/cmd_predict.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index 3184307..c2f3f2f 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the `morphoclass predict-after-extraction` CLI command.""" -import functools import logging import textwrap @@ -106,7 +105,6 @@ def predict(features_dir, checkpoint_file, output_dir, results_name): click.secho("You chose to overwrite, proceeding...", fg="red") click.secho("✔ Loading checkpoint...", fg="green", bold=True) - import numpy as np import torch from morphoclass.data import MorphologyDataset @@ -137,7 +135,7 @@ def predict(features_dir, checkpoint_file, output_dir, results_name): predictions = predict_xgb(dataset, checkpoint) else: click.secho( - f"Model not recognized: {model_name}. Stopping.", + f"Model not recognized: {model_class}. Stopping.", fg="red", bold=True, nl=False, @@ -176,7 +174,7 @@ def predict_gnn(dataset, checkpoint): logits The predictions logits. """ - import torch + import numpy as np import morphoclass.models @@ -207,8 +205,6 @@ def predict_cnn(dataset, checkpoint): The predictions logits. """ import numpy as np - import torch - from torch.utils.data import DataLoader, TensorDataset import morphoclass.models From a4b5a834ee28319d9f6e52847b8f0461a7461141 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:04:22 +0200 Subject: [PATCH 07/14] Fix apidoc --- ...orphoclass.console.cmd_extract_features_and_predict.rst | 7 +++++++ docs/source/api/morphoclass.console.cmd_predict.rst | 7 +++++++ docs/source/api/morphoclass.console.cmd_predict_bu.rst | 7 +++++++ docs/source/api/morphoclass.console.rst | 3 +++ 4 files changed, 24 insertions(+) create mode 100644 docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst create mode 100644 docs/source/api/morphoclass.console.cmd_predict.rst create mode 100644 docs/source/api/morphoclass.console.cmd_predict_bu.rst diff --git a/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst b/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst new file mode 100644 index 0000000..e144ac5 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_extract_features_and_predict.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_extract\_features\_and\_predict module +=============================================================== + +.. automodule:: morphoclass.console.cmd_extract_features_and_predict + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.cmd_predict.rst b/docs/source/api/morphoclass.console.cmd_predict.rst new file mode 100644 index 0000000..9dcbf34 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_predict.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_predict module +======================================= + +.. automodule:: morphoclass.console.cmd_predict + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.cmd_predict_bu.rst b/docs/source/api/morphoclass.console.cmd_predict_bu.rst new file mode 100644 index 0000000..dc67db6 --- /dev/null +++ b/docs/source/api/morphoclass.console.cmd_predict_bu.rst @@ -0,0 +1,7 @@ +morphoclass.console.cmd\_predict\_bu module +=========================================== + +.. automodule:: morphoclass.console.cmd_predict_bu + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api/morphoclass.console.rst b/docs/source/api/morphoclass.console.rst index 67415de..2411d9b 100644 --- a/docs/source/api/morphoclass.console.rst +++ b/docs/source/api/morphoclass.console.rst @@ -9,10 +9,13 @@ Submodules morphoclass.console.cmd_evaluate morphoclass.console.cmd_extract_features + morphoclass.console.cmd_extract_features_and_predict morphoclass.console.cmd_morphometrics morphoclass.console.cmd_organise_dataset morphoclass.console.cmd_performance_table morphoclass.console.cmd_plot_dataset_stats + morphoclass.console.cmd_predict + morphoclass.console.cmd_predict_bu morphoclass.console.cmd_preprocess_dataset morphoclass.console.cmd_train morphoclass.console.cmd_xai From 1a614cc625fab5c44ad86bb5ce81637ae925a4e9 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:05:17 +0200 Subject: [PATCH 08/14] Remove useless doc --- docs/source/api/morphoclass.console.cmd_predict_bu.rst | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 docs/source/api/morphoclass.console.cmd_predict_bu.rst diff --git a/docs/source/api/morphoclass.console.cmd_predict_bu.rst b/docs/source/api/morphoclass.console.cmd_predict_bu.rst deleted file mode 100644 index dc67db6..0000000 --- a/docs/source/api/morphoclass.console.cmd_predict_bu.rst +++ /dev/null @@ -1,7 +0,0 @@ -morphoclass.console.cmd\_predict\_bu module -=========================================== - -.. automodule:: morphoclass.console.cmd_predict_bu - :members: - :undoc-members: - :show-inheritance: From de49de8aa964dbe28fdd70e7aaa40686455b5dea Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:12:45 +0200 Subject: [PATCH 09/14] Try to make flake8 happy --- .../console/cmd_extract_features.py | 1 + src/morphoclass/console/cmd_predict.py | 24 +++++++++++++++---- src/morphoclass/data/tns_dataset.py | 4 +--- src/morphoclass/vis.py | 2 +- tests/unit/test_metrics.py | 2 +- 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index 1ec4b25..d84e94f 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -114,6 +114,7 @@ def cli( keep_diagram: bool, force: bool, ) -> None: + """Extract morphology features.""" return extract_features( csv_path, neurite_type, diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index c2f3f2f..930c949 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -63,7 +63,7 @@ help="The filename of the results file", ) def cli(features_dir, checkpoint_file, output_dir, results_name): - """Run the `morphoclass predict` CLI command. + """Run the `morphoclass predict-after-extraction` CLI command. Parameters ---------- @@ -80,6 +80,19 @@ def cli(features_dir, checkpoint_file, output_dir, results_name): def predict(features_dir, checkpoint_file, output_dir, results_name): + """Run the predict command. + + Parameters + ---------- + features_dir + The path to the features of the morphologies. + checkpoint_file + The path to the checkpoint file. + output_dir + The path to the output directory. + results_name + File prefix for results output files. + """ import json import pathlib from datetime import datetime @@ -149,10 +162,11 @@ def predict(features_dir, checkpoint_file, output_dir, results_name): pred_label = dataset.y_to_label[sample_pred] prediction_labels[sample_path] = pred_label - results = dict() - results["predictions"] = prediction_labels - results["checkpoint_path"] = str(checkpoint_file) - results["model"] = model_class + results = { + "predictions": prediction_labels, + "checkpoint_path": str(checkpoint_file), + "model": model_class + } with open(results_path, "w") as fp: json.dump(results, fp) diff --git a/src/morphoclass/data/tns_dataset.py b/src/morphoclass/data/tns_dataset.py index 6be3bb2..c5a190c 100644 --- a/src/morphoclass/data/tns_dataset.py +++ b/src/morphoclass/data/tns_dataset.py @@ -135,9 +135,7 @@ def __init__( f"No data corresponding to layer {layer} found in data_path" ) else: - self.class_dict = { - n: m_type for n, m_type in enumerate(sorted(self.m_types)) - } + self.class_dict = dict(enumerate(sorted(self.m_types))) self.class_dict_inv = {v: k for k, v in self.class_dict.items()} self.distributions = input_distributions diff --git a/src/morphoclass/vis.py b/src/morphoclass/vis.py index 008afef..2b7e45c 100644 --- a/src/morphoclass/vis.py +++ b/src/morphoclass/vis.py @@ -820,7 +820,7 @@ def plot_neurite( nx.draw( g, ax=ax, - pos={n: xy for n, xy in enumerate(zip(px, py))}, + pos=dict(enumerate(zip(px, py))), nodelist=[0], node_color="red", node_size=soma_size, diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index 06043b4..dfb987d 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -158,7 +158,7 @@ def test_inter_rater_score(targets, predictions, kind, score): def test_inter_rater_score_fail(targets, predictions): - with pytest.raises(Exception): + with pytest.raises(ValueError): morphoclass.metrics.inter_rater_score(targets, predictions, kind="invalid") From 7a5a18c0013fde379a905cb512e4bb6b3380e5ca Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:14:38 +0200 Subject: [PATCH 10/14] Make isort happy --- src/morphoclass/console/cmd_extract_features_and_predict.py | 2 ++ src/morphoclass/console/cmd_predict.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/morphoclass/console/cmd_extract_features_and_predict.py b/src/morphoclass/console/cmd_extract_features_and_predict.py index d3898f1..b3f21cb 100644 --- a/src/morphoclass/console/cmd_extract_features_and_predict.py +++ b/src/morphoclass/console/cmd_extract_features_and_predict.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the `morphoclass predict` CLI command.""" +from __future__ import annotations + import logging import textwrap diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index 930c949..415654a 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Implementation of the `morphoclass predict-after-extraction` CLI command.""" +from __future__ import annotations + import logging import textwrap From 54a1ee2b44426b46ecee523533c4ced2e9f272f8 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 13:58:28 +0200 Subject: [PATCH 11/14] Make mypy happy --- src/morphoclass/console/cmd_extract_features.py | 2 +- src/morphoclass/console/cmd_extract_features_and_predict.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index d84e94f..954ab20 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -145,7 +145,7 @@ def extract_features( no_simplify_graph: bool, keep_diagram: bool, force: bool, -): +) -> None: """Extract morphology features.""" output_dir = pathlib.Path(output_dir) if output_dir.exists() and not force: diff --git a/src/morphoclass/console/cmd_extract_features_and_predict.py b/src/morphoclass/console/cmd_extract_features_and_predict.py index b3f21cb..30d19e6 100644 --- a/src/morphoclass/console/cmd_extract_features_and_predict.py +++ b/src/morphoclass/console/cmd_extract_features_and_predict.py @@ -130,8 +130,8 @@ def cli(input_csv, checkpoint_file, output_dir, results_name): extract_features( input_csv, - neurite_type, - feature, + neurite_type[0], + feature[0], output_dir / "features", False, False, From c00ae2e9d6a50da26cca84fd56c06525cf197096 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 14:36:07 +0200 Subject: [PATCH 12/14] Make mypy happy (2) --- src/morphoclass/console/cmd_extract_features.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/morphoclass/console/cmd_extract_features.py b/src/morphoclass/console/cmd_extract_features.py index 954ab20..93eda5e 100644 --- a/src/morphoclass/console/cmd_extract_features.py +++ b/src/morphoclass/console/cmd_extract_features.py @@ -129,17 +129,8 @@ def cli( def extract_features( csv_path: StrPath, - neurite_type: Literal["apical", "axon", "basal", "all"], - feature: Literal[ - "graph-rd", - "graph-proj", - "diagram-tmd-rd", - "diagram-tmd-proj", - "diagram-deepwalk", - "image-tmd-rd", - "image-tmd-proj", - "image-deepwalk", - ], + neurite_type: str, + feature: str, output_dir: StrPath, orient: bool, no_simplify_graph: bool, From ceb1e09896010a4dc833f8e36904d826f18c0df6 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 14:39:37 +0200 Subject: [PATCH 13/14] Make apidoc happy --- docs/source/api/morphoclass.console.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/api/morphoclass.console.rst b/docs/source/api/morphoclass.console.rst index 2411d9b..54ab684 100644 --- a/docs/source/api/morphoclass.console.rst +++ b/docs/source/api/morphoclass.console.rst @@ -15,7 +15,6 @@ Submodules morphoclass.console.cmd_performance_table morphoclass.console.cmd_plot_dataset_stats morphoclass.console.cmd_predict - morphoclass.console.cmd_predict_bu morphoclass.console.cmd_preprocess_dataset morphoclass.console.cmd_train morphoclass.console.cmd_xai From 23a55b0dba217cb5d779a66c9bfd27bcdb70b603 Mon Sep 17 00:00:00 2001 From: Delattre Emilie Date: Thu, 30 Mar 2023 14:49:30 +0200 Subject: [PATCH 14/14] Make black happy --- src/morphoclass/console/cmd_predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/morphoclass/console/cmd_predict.py b/src/morphoclass/console/cmd_predict.py index 415654a..dbf461b 100644 --- a/src/morphoclass/console/cmd_predict.py +++ b/src/morphoclass/console/cmd_predict.py @@ -167,7 +167,7 @@ def predict(features_dir, checkpoint_file, output_dir, results_name): results = { "predictions": prediction_labels, "checkpoint_path": str(checkpoint_file), - "model": model_class + "model": model_class, } with open(results_path, "w") as fp: json.dump(results, fp)