Skip to content

Commit

Permalink
[REF] Use new function to run LDA commands (#587)
Browse files Browse the repository at this point in the history
* Add new function for running shell commands.

* Add test for new utility function.
  • Loading branch information
tsalo authored Oct 26, 2021
1 parent fd46f9c commit b1d7b74
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 13 deletions.
26 changes: 13 additions & 13 deletions nimare/annotate/lda.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Topic modeling with latent Dirichlet allocation via MALLET."""
import logging
import os
import os.path as op
import shutil
import subprocess

import numpy as np
import pandas as pd
Expand All @@ -12,6 +10,7 @@
from ..base import NiMAREBase
from ..due import due
from ..extract import download_mallet, utils
from ..utils import run_shell_command

LGR = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,12 +72,12 @@ def __init__(
self, text_df, text_column="abstract", n_topics=50, n_iters=1000, alpha="auto", beta=0.001
):
mallet_dir = download_mallet()
mallet_bin = op.join(mallet_dir, "bin/mallet")
mallet_bin = os.path.join(mallet_dir, "bin/mallet")

model_dir = utils._get_dataset_dir("mallet_model")
text_dir = op.join(model_dir, "texts")
text_dir = os.path.join(model_dir, "texts")

if not op.isdir(model_dir):
if not os.path.isdir(model_dir):
os.mkdir(model_dir)

if alpha == "auto":
Expand All @@ -90,7 +89,7 @@ def __init__(
self.model_dir = model_dir

# Check for presence of text files and convert if necessary
if not op.isdir(text_dir):
if not os.path.isdir(text_dir):
LGR.info("Texts folder not found. Creating text files...")
os.mkdir(text_dir)

Expand All @@ -104,11 +103,11 @@ def __init__(

for id_ in text_df["id"].values:
text = text_df.loc[text_df["id"] == id_, text_column].values[0]
with open(op.join(text_dir, str(id_) + ".txt"), "w") as fo:
with open(os.path.join(text_dir, str(id_) + ".txt"), "w") as fo:
fo.write(text)

# Run MALLET topic modeling
LGR.info("Generating topics...")
LGR.info("Compiling MALLET commands...")
import_str = (
f"{mallet_bin} import-dir "
f"--input {text_dir} "
Expand Down Expand Up @@ -142,8 +141,9 @@ def fit(self):
p_word_g_topic_ : :obj:`numpy.ndarray`
Probability of each word given a topic
"""
subprocess.call(self.commands_[0], shell=True)
subprocess.call(self.commands_[1], shell=True)
LGR.info("Generating topics...")
run_shell_command(self.commands_[0])
run_shell_command(self.commands_[1])

# Read in and convert doc_topics and topic_keys.
topic_names = [f"topic_{i:03d}" for i in range(self.params["n_topics"])]
Expand All @@ -158,7 +158,7 @@ def fit(self):
# on an individual id basis by the weights.
n_cols = (2 * self.params["n_topics"]) + 1
dt_df = pd.read_csv(
op.join(self.model_dir, "doc_topics.txt"),
os.path.join(self.model_dir, "doc_topics.txt"),
delimiter="\t",
skiprows=1,
header=None,
Expand Down Expand Up @@ -194,7 +194,7 @@ def fit(self):

# Topic word weights
p_word_g_topic_df = pd.read_csv(
op.join(self.model_dir, "topic_word_weights.txt"),
os.path.join(self.model_dir, "topic_word_weights.txt"),
dtype=str,
keep_default_na=False,
na_values=[],
Expand All @@ -213,7 +213,7 @@ def fit(self):
shutil.rmtree(self.model_dir)

def _clean_str(self, string):
return op.basename(op.splitext(string)[0])
return os.path.basename(os.path.splitext(string)[0])

def _get_sort(self, lst):
return [i[0] for i in sorted(enumerate(lst), key=lambda x: x[1])]
25 changes: 25 additions & 0 deletions nimare/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import os.path as op
import time

import nibabel as nib
import numpy as np
Expand Down Expand Up @@ -164,3 +165,27 @@ def test_mm2vox():
img = utils.get_template(space="mni152_2mm", mask=None)
aff = img.affine
assert np.array_equal(utils.mm2vox(test, aff), true)


def test_run_shell_command(caplog):
"""Test run_shell_command."""
with caplog.at_level(logging.INFO):
utils.run_shell_command("echo 'output'")
assert "output" in caplog.text

# Check that the exception is registered as such
with pytest.raises(Exception) as execinfo:
utils.run_shell_command("echo 'Error!' 1>&2;exit 64")
assert "Error!" in str(execinfo.value)

# Check that the function actually waits until the command completes
dur = 3
start = time.time()
with caplog.at_level(logging.INFO):
utils.run_shell_command(f"echo 'hi';sleep {dur}s;echo 'bye'")
end = time.time()

assert "hi" in caplog.text
assert "bye" in caplog.text
duration = end - start
assert duration >= dur
26 changes: 26 additions & 0 deletions nimare/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import os.path as op
import re
import subprocess
from functools import wraps
from tempfile import mkstemp

Expand Down Expand Up @@ -935,3 +936,28 @@ def boolean_unmask(data_array, bool_array):
unmasked_data[bool_array] = data_array
unmasked_data = unmasked_data.T
return unmasked_data


def run_shell_command(command, env=None):
"""Run a given command with certain environment variables set."""
merged_env = os.environ
if env:
merged_env.update(env)

process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True,
env=merged_env,
)
while True:
line = process.stdout.readline()
line = str(line, "utf-8")[:-1]
LGR.info(line)
if line == "" and process.poll() is not None:
break

if process.returncode != 0:
stderr_line = str(process.stderr.read(), "utf-8")[:-1]
raise Exception(f"Non zero return code: {process.returncode}\n{command}\n\n{stderr_line}")

0 comments on commit b1d7b74

Please sign in to comment.