Skip to content

Commit

Permalink
further restrict directories returned by labeled frame app
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Jul 15, 2024
1 parent c45018d commit 4967266
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 8 deletions.
33 changes: 29 additions & 4 deletions lightning_pose/apps/labeled_frame_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def run():

# add a multiselect that shows existing model folders, and allows the user to de-select models
# assume we have args.model_dir and we search two levels down for model folders
model_folders = get_model_folders(args.model_dir, require_predictions=True)
model_folders = get_model_folders(
args.model_dir,
require_predictions=True,
require_tb_logs=args.require_tb_logs,
)

# get the last two levels of each path to be presented to user
model_folders_vis = get_model_folders_vis(model_folders)
Expand Down Expand Up @@ -268,8 +272,29 @@ def run():
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default=[])
parser.add_argument('--make_dir', action='store_true', default=False)
parser.add_argument('--use_ood', action='store_true', default=False)
parser.add_argument(
"--model_dir",
type=str,
default=[],
help="absolute path to directory that contains model subdirectories",
)
parser.add_argument(
"--make_dir",
action="store_true",
default=False,
help="create model_dir if it does not exist (i.e. no models have been trained yet"
)
parser.add_argument(
"--use_ood",
action="store_true",
default=False,
help="use predictions_new.csv instead of predictions.csv",
)
parser.add_argument(
"--require_tb_logs",
action="store_true",
default=False,
help="require model directory to contain the subdirectoy tb_logs, skip otherwise",
)

run()
17 changes: 13 additions & 4 deletions lightning_pose/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,12 @@ def compute_confidence(


# ------------ utils related to model finding in dir ---------
# write a function that finds all model folders in the model_dir
def get_model_folders(model_dir: str, require_predictions: bool = True) -> List[str]:
def get_model_folders(
model_dir: str,
require_predictions: bool = True,
require_tb_logs: bool = False,
) -> List[str]:
"""Find all model folders in a higher-level directory, conditional on directory contents."""
# strip trailing slash if present
if model_dir[-1] == os.sep:
model_dir = model_dir[:-1]
Expand All @@ -243,8 +247,13 @@ def get_model_folders(model_dir: str, require_predictions: bool = True) -> List[
for root, dirs, files in os.walk(model_dir):
if root.count(os.sep) - model_dir.count(os.sep) == 2:
# only include directory if it has predictions.csv file (model training finished)
if require_predictions:
if "predictions.csv" in os.listdir(root):
if require_predictions or require_tb_logs:
append = True
if require_predictions and ("predictions.csv" not in os.listdir(root)):
append &= False
if require_tb_logs and ("tb_logs" not in os.listdir(root)):
append &= False
if append:
model_folders.append(root)
else:
model_folders.append(root)
Expand Down
67 changes: 67 additions & 0 deletions tests/apps/test_app_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os


def test_get_model_folders(tmpdir):

from lightning_pose.apps.utils import get_model_folders

# create model directories
models = {
"00-11-22/33-44-55": {"predictions": True, "tb_logs": True},
"11-22-33/44-55-66": {"predictions": True, "tb_logs": False},
"22-33-44/55-66-77": {"predictions": False, "tb_logs": True},
"33-44-55/66-77-88": {"predictions": False, "tb_logs": False},
}
model_parent = os.path.join(str(tmpdir), "models")
for model, files in models.items():
model_dir = os.path.join(model_parent, model)
os.makedirs(model_dir)
if files["predictions"]:
os.mknod(os.path.join(model_dir, "predictions.csv"))
if files["tb_logs"]:
os.makedirs(os.path.join(model_dir, "tb_logs"))

# test 1: find all model directories
trained_models = get_model_folders(
model_parent,
require_predictions=False,
require_tb_logs=False,
)
for model in models.keys():
assert os.path.join(model_parent, model) in trained_models

# test 2: find trained model directories with predictions.csv
trained_models = get_model_folders(
model_parent,
require_predictions=True,
require_tb_logs=False,
)
for model, files in models.items():
if files["predictions"]:
assert os.path.join(model_parent, model) in trained_models
else:
assert os.path.join(model_parent, model) not in trained_models

# test 3: find trained model directories with config.yaml
trained_models = get_model_folders(
model_parent,
require_predictions=False,
require_tb_logs=True,
)
for model, files in models.items():
if files["tb_logs"]:
assert os.path.join(model_parent, model) in trained_models
else:
assert os.path.join(model_parent, model) not in trained_models

# test 4: find trained model directories with both predictions.csv and config.yaml
trained_models = get_model_folders(
model_parent,
require_predictions=True,
require_tb_logs=True,
)
for model, files in models.items():
if files["predictions"] and files["tb_logs"]:
assert os.path.join(model_parent, model) in trained_models
else:
assert os.path.join(model_parent, model) not in trained_models

0 comments on commit 4967266

Please sign in to comment.