From b1d7b7464a3ced88ba9d924486831d2b1fc15c97 Mon Sep 17 00:00:00 2001 From: Taylor Salo Date: Tue, 26 Oct 2021 16:08:02 -0400 Subject: [PATCH] [REF] Use new function to run LDA commands (#587) * Add new function for running shell commands. * Add test for new utility function. --- nimare/annotate/lda.py | 26 +++++++++++++------------- nimare/tests/test_utils.py | 25 +++++++++++++++++++++++++ nimare/utils.py | 26 ++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/nimare/annotate/lda.py b/nimare/annotate/lda.py index 3dbede6f8..1cd1a30ec 100644 --- a/nimare/annotate/lda.py +++ b/nimare/annotate/lda.py @@ -1,9 +1,7 @@ """Topic modeling with latent Dirichlet allocation via MALLET.""" import logging import os -import os.path as op import shutil -import subprocess import numpy as np import pandas as pd @@ -12,6 +10,7 @@ from ..base import NiMAREBase from ..due import due from ..extract import download_mallet, utils +from ..utils import run_shell_command LGR = logging.getLogger(__name__) @@ -73,12 +72,12 @@ def __init__( self, text_df, text_column="abstract", n_topics=50, n_iters=1000, alpha="auto", beta=0.001 ): mallet_dir = download_mallet() - mallet_bin = op.join(mallet_dir, "bin/mallet") + mallet_bin = os.path.join(mallet_dir, "bin/mallet") model_dir = utils._get_dataset_dir("mallet_model") - text_dir = op.join(model_dir, "texts") + text_dir = os.path.join(model_dir, "texts") - if not op.isdir(model_dir): + if not os.path.isdir(model_dir): os.mkdir(model_dir) if alpha == "auto": @@ -90,7 +89,7 @@ def __init__( self.model_dir = model_dir # Check for presence of text files and convert if necessary - if not op.isdir(text_dir): + if not os.path.isdir(text_dir): LGR.info("Texts folder not found. Creating text files...") os.mkdir(text_dir) @@ -104,11 +103,11 @@ def __init__( for id_ in text_df["id"].values: text = text_df.loc[text_df["id"] == id_, text_column].values[0] - with open(op.join(text_dir, str(id_) + ".txt"), "w") as fo: + with open(os.path.join(text_dir, str(id_) + ".txt"), "w") as fo: fo.write(text) # Run MALLET topic modeling - LGR.info("Generating topics...") + LGR.info("Compiling MALLET commands...") import_str = ( f"{mallet_bin} import-dir " f"--input {text_dir} " @@ -142,8 +141,9 @@ def fit(self): p_word_g_topic_ : :obj:`numpy.ndarray` Probability of each word given a topic """ - subprocess.call(self.commands_[0], shell=True) - subprocess.call(self.commands_[1], shell=True) + LGR.info("Generating topics...") + run_shell_command(self.commands_[0]) + run_shell_command(self.commands_[1]) # Read in and convert doc_topics and topic_keys. topic_names = [f"topic_{i:03d}" for i in range(self.params["n_topics"])] @@ -158,7 +158,7 @@ def fit(self): # on an individual id basis by the weights. n_cols = (2 * self.params["n_topics"]) + 1 dt_df = pd.read_csv( - op.join(self.model_dir, "doc_topics.txt"), + os.path.join(self.model_dir, "doc_topics.txt"), delimiter="\t", skiprows=1, header=None, @@ -194,7 +194,7 @@ def fit(self): # Topic word weights p_word_g_topic_df = pd.read_csv( - op.join(self.model_dir, "topic_word_weights.txt"), + os.path.join(self.model_dir, "topic_word_weights.txt"), dtype=str, keep_default_na=False, na_values=[], @@ -213,7 +213,7 @@ def fit(self): shutil.rmtree(self.model_dir) def _clean_str(self, string): - return op.basename(op.splitext(string)[0]) + return os.path.basename(os.path.splitext(string)[0]) def _get_sort(self, lst): return [i[0] for i in sorted(enumerate(lst), key=lambda x: x[1])] diff --git a/nimare/tests/test_utils.py b/nimare/tests/test_utils.py index 946500866..244691470 100644 --- a/nimare/tests/test_utils.py +++ b/nimare/tests/test_utils.py @@ -2,6 +2,7 @@ import logging import os import os.path as op +import time import nibabel as nib import numpy as np @@ -164,3 +165,27 @@ def test_mm2vox(): img = utils.get_template(space="mni152_2mm", mask=None) aff = img.affine assert np.array_equal(utils.mm2vox(test, aff), true) + + +def test_run_shell_command(caplog): + """Test run_shell_command.""" + with caplog.at_level(logging.INFO): + utils.run_shell_command("echo 'output'") + assert "output" in caplog.text + + # Check that the exception is registered as such + with pytest.raises(Exception) as execinfo: + utils.run_shell_command("echo 'Error!' 1>&2;exit 64") + assert "Error!" in str(execinfo.value) + + # Check that the function actually waits until the command completes + dur = 3 + start = time.time() + with caplog.at_level(logging.INFO): + utils.run_shell_command(f"echo 'hi';sleep {dur}s;echo 'bye'") + end = time.time() + + assert "hi" in caplog.text + assert "bye" in caplog.text + duration = end - start + assert duration >= dur diff --git a/nimare/utils.py b/nimare/utils.py index 42f96915a..c3a6257a6 100755 --- a/nimare/utils.py +++ b/nimare/utils.py @@ -5,6 +5,7 @@ import os import os.path as op import re +import subprocess from functools import wraps from tempfile import mkstemp @@ -935,3 +936,28 @@ def boolean_unmask(data_array, bool_array): unmasked_data[bool_array] = data_array unmasked_data = unmasked_data.T return unmasked_data + + +def run_shell_command(command, env=None): + """Run a given command with certain environment variables set.""" + merged_env = os.environ + if env: + merged_env.update(env) + + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + env=merged_env, + ) + while True: + line = process.stdout.readline() + line = str(line, "utf-8")[:-1] + LGR.info(line) + if line == "" and process.poll() is not None: + break + + if process.returncode != 0: + stderr_line = str(process.stderr.read(), "utf-8")[:-1] + raise Exception(f"Non zero return code: {process.returncode}\n{command}\n\n{stderr_line}")