Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create CLI command morphoclass predict #88

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
morphoclass.console.cmd\_extract\_features\_and\_predict module
===============================================================

.. automodule:: morphoclass.console.cmd_extract_features_and_predict
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/api/morphoclass.console.cmd_predict.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
morphoclass.console.cmd\_predict module
=======================================

.. automodule:: morphoclass.console.cmd_predict
:members:
:undoc-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/api/morphoclass.console.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ Submodules

morphoclass.console.cmd_evaluate
morphoclass.console.cmd_extract_features
morphoclass.console.cmd_extract_features_and_predict
morphoclass.console.cmd_morphometrics
morphoclass.console.cmd_organise_dataset
morphoclass.console.cmd_performance_table
morphoclass.console.cmd_plot_dataset_stats
morphoclass.console.cmd_predict
morphoclass.console.cmd_preprocess_dataset
morphoclass.console.cmd_train
morphoclass.console.cmd_xai
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ scipy==1.7.3
seaborn==0.11.0
shap[plots]==0.39.0
tmd==2.1.0
torch==1.7.1
torch==1.9.0
tqdm==4.53.0
umap-learn==0.5.1
xgboost==1.4.2
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ install_requires =
imbalanced-learn
jinja2
matplotlib
morphio
morphology-workflows>=0.3.0
#morphio
#morphology-workflows>=0.3.0
networkx
neurom>=3
NeuroR>=1.6.1
Expand Down
23 changes: 23 additions & 0 deletions src/morphoclass/console/cmd_extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ def cli(
no_simplify_graph: bool,
keep_diagram: bool,
force: bool,
) -> None:
"""Extract morphology features."""
return extract_features(
csv_path,
neurite_type,
feature,
output_dir,
orient,
no_simplify_graph,
keep_diagram,
force,
)


def extract_features(
csv_path: StrPath,
neurite_type: str,
feature: str,
output_dir: StrPath,
orient: bool,
no_simplify_graph: bool,
keep_diagram: bool,
force: bool,
) -> None:
"""Extract morphology features."""
output_dir = pathlib.Path(output_dir)
Expand Down
147 changes: 147 additions & 0 deletions src/morphoclass/console/cmd_extract_features_and_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright © 2022-2022 Blue Brain Project/EPFL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the `morphoclass predict` CLI command."""
from __future__ import annotations

import logging
import textwrap

import click

logger = logging.getLogger(__name__)


@click.command(
name="predict",
help="Run inference.",
)
@click.help_option("-h", "--help")
@click.option(
"-i",
"--input_csv",
required=True,
type=click.Path(exists=True, dir_okay=True),
help=textwrap.dedent(
"""
The CSV path with the path to all the morphologies to classify.
"""
).strip(),
)
@click.option(
"-c",
"--checkpoint",
"checkpoint_file",
required=True,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
help=textwrap.dedent(
"""
The path to the pre-trained model checkpoint.
"""
).strip(),
)
@click.option(
"-o",
"--output-dir",
required=True,
type=click.Path(exists=False, file_okay=False, writable=True),
help="Output directory for the results.",
)
@click.option(
"-n",
"--results-name",
required=False,
type=click.STRING,
help="The filename of the results file",
)
def cli(input_csv, checkpoint_file, output_dir, results_name):
"""Run the `morphoclass predict` CLI command.

Parameters
----------
input_csv
The CSV with all the morphologies path.
checkpoint_file
The path to the checkpoint file.
output_dir
The path to the output directory.
results_name
File prefix for results output files.
"""
import pathlib
from datetime import datetime

input_csv = pathlib.Path(input_csv).resolve()
output_dir = pathlib.Path(output_dir).resolve()
checkpoint_file = pathlib.Path(checkpoint_file).resolve()
if results_name is None:
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_name = f"results_{timestamp}"
results_path = output_dir / (results_name + ".json")
click.secho(f"Input CSV : {input_csv}", fg="yellow")
click.secho(f"Output file : {results_path}", fg="yellow")
click.secho(f"Checkpoint : {checkpoint_file}", fg="yellow")
if results_path.exists():
msg = f'Results file "{results_path}" exists, overwrite? (y/[n]) '
click.secho(msg, fg="red", bold=True, nl=False)
response = input()
if response.strip().lower() != "y":
click.secho("Stopping.", fg="red")
return
else:
click.secho("You chose to overwrite, proceeding...", fg="red")

click.secho("✔ Loading checkpoint...", fg="green", bold=True)
import torch

from morphoclass.console.cmd_extract_features import extract_features
from morphoclass.console.cmd_predict import predict

checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
neurites = ["apical", "axon", "basal", "all"]
neurite_type = [
neurite for neurite in neurites if neurite in str(checkpoint["features_dir"])
]
features_type = [
"graph-rd",
"graph-proj",
"diagram-tmd-rd",
"diagram-tmd-proj",
"diagram-deepwalk",
"image-tmd-rd",
"image-tmd-proj",
"image-deepwalk",
]
feature = [
feature
for feature in features_type
if feature in str(checkpoint["features_dir"])
]

extract_features(
input_csv,
neurite_type[0],
feature[0],
output_dir / "features",
False,
False,
False,
False,
)

predict(
features_dir=output_dir / "features",
checkpoint_file=checkpoint_file,
output_dir=output_dir,
results_name=results_name,
)
Loading