Skip to content

Commit

Permalink
Merge pull request #22 from TravisWheelerLab/dipui-debugging
Browse files Browse the repository at this point in the history
Dipui debugging
  • Loading branch information
georgeglidden authored Dec 19, 2024
2 parents e8043b2 + 7b52e60 commit e624ec4
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 30 deletions.
43 changes: 36 additions & 7 deletions diplomat/core_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
NoneType,
get_typecaster_required_arguments
)
from diplomat.predictor_ops import _get_predictor_settings
from diplomat.utils.pretty_printer import printer as print
from diplomat.utils.cli_tools import func_to_command, allow_arbitrary_flags, Flag, positional_argument_count, CLIError
from argparse import ArgumentParser
Expand Down Expand Up @@ -58,6 +59,29 @@ def _get_casted_args(tc_func, extra_args, error_on_miss=True):

return new_args

def _reconcile_arguments_with_predictor_settings(predictor_name, extra_args, passed_predictor_settings, precasted_args):
"""
PRIVATE: Compare items in extra_args to the predictor's ConfigSpec.
If a key/value pair in the argument matches a pair of setting name/type in the predictor ConfigSpec,
add it to passed_predictor_settings and remove it from extra_args. If the setting has already been set,
either in passed_predictor_settings or in precasted_args, then it will be ignored.
"""
if passed_predictor_settings == None:
passed_predictor_settings = {}
new_extra_args = {}
for plugin_name, config_spec in _get_predictor_settings(predictor_name):
for k, v in extra_args.items():
arg_type = type(v)
if((k in precasted_args) or (k in passed_predictor_settings)):
print(f"Warning: {k} is already set; skipping")
elif(k in config_spec):
print(f"Info: converted command line argument {(k,v)} to a {plugin_name} setting.")
passed_predictor_settings[k] = v
else:
extra_args[k] = v

return new_extra_args, passed_predictor_settings


def _find_frontend(
contracts: Union[DIPLOMATContract, List[DIPLOMATContract]],
Expand Down Expand Up @@ -187,7 +211,6 @@ def track_with(
predictor: Optional[str] = None,
predictor_settings: Optional[Dict[str, Any]] = None,
help_extra: Flag = False,
dipui_file: Optional[PathLike] = None,
**extra_args
):
"""
Expand All @@ -210,10 +233,7 @@ def track_with(
To see valid values, run track with extra_help flag set to true.
"""
from diplomat import CLI_RUN

print(f"framestores: {frame_stores}")
print(f"dipui_file: {dipui_file}")


selected_frontend_name, selected_frontend = _find_frontend(
contracts=[DIPLOMATCommands.analyze_videos, DIPLOMATCommands.analyze_videos],
config=config,
Expand All @@ -234,9 +254,15 @@ def track_with(
print("No frame stores or videos passed, terminating.")
return

## TODO: compare extra_args to predictor.get_settings

# If some videos are supplied, run the frontends video analysis function.
if(videos is not None):
print("Running on videos... TEST")
print("Running on videos...")

precasted_args = _get_casted_args(selected_frontend.analyze_videos, extra_args, error_on_miss = False)
extra_args, predictor_settings = _reconcile_arguments_with_predictor_settings(predictor, extra_args, predictor_settings, precasted_args)

selected_frontend.analyze_videos(
config=config,
videos=videos,
Expand All @@ -249,13 +275,16 @@ def track_with(
# If some frame stores are supplied, run the frontends frame analysis function.
if(frame_stores is not None):
print("Running on frame stores...")

precasted_args = _get_casted_args(selected_frontend.analyze_frames, extra_args, error_on_miss = False)
extra_args, predictor_settings = _reconcile_arguments_with_predictor_settings(predictor, extra_args, predictor_settings, precasted_args)

selected_frontend.analyze_frames(
config=config,
frame_stores=frame_stores,
num_outputs=num_outputs,
predictor=predictor,
predictor_settings=predictor_settings,
dipui_file=dipui_file,
**_get_casted_args(selected_frontend.analyze_frames, extra_args)
)

Expand Down
3 changes: 1 addition & 2 deletions diplomat/frontends/deeplabcut/predict_videos_dlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ def _analyze_video(
save_as_csv,
dest_folder=None,
predictor_cls=None,
predictor_settings=None,
dipui_file=None
predictor_settings=None
) -> str:
print(f"Analyzing video: {video}")

Expand Down
51 changes: 35 additions & 16 deletions diplomat/predictor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,41 +21,60 @@ def list_predictor_plugins():
print("\t", predictor.get_description())
print()


@typecaster_function
@positional_argument_count(1)
def get_predictor_settings(predictor: Optional[Union[List[str], str]] = None):
def _get_predictor_settings(predictor_name: Optional[Union[List[str], str]] = None):
"""
Gets the available/modifiable settings for a specified predictor plugin.
Returns the available/modifiable settings for a specified predictor plugin.
:param predictor: The string or list of strings being the names of the predictor plugins to view customizable
:param predictor_name: The string or list of strings being the names of the predictor plugins to view customizable
settings for. If None, will print settings for all currently available predictors.
Defaults to None.
:return: Nothing, prints to console....
:return: ConfigSpec, a dictionary relating each predictor setting to a 3-tuple of its default value, type, and description strnig.
"""
from typing import Iterable

# Convert whatever the predictor_name argument is to a list of predictor plugins
if predictor is None:
if predictor_name is None:
predictors = processing.get_predictor_plugins()
elif isinstance(predictor, str):
predictors = [processing.get_predictor(predictor)]
elif isinstance(predictor, Iterable):
predictors = [processing.get_predictor(name) for name in predictor]
elif isinstance(predictor_name, str):
predictors = [processing.get_predictor(predictor_name)]
elif isinstance(predictor_name, Iterable):
predictor_name = [processing.get_predictor(name) for name in predictor]
else:
raise ValueError(
"Argument 'predictor_name' not of type Iterable[str], string, or None!!!"
)

# Print name, and settings for each plugin.
# yield name and settings for each plugin.
for predictor in predictors:
print(f"Plugin Name: {predictor.get_name()}")
plugin_name = predictor.get_name()
config_spec = predictor.get_settings()
if config_spec is None:
yield (plugin_name, {})
else:
yield (plugin_name, config_spec)

@typecaster_function
@positional_argument_count(1)
def get_predictor_settings(predictor: Optional[Union[List[str], str]] = None):
"""
Prints the available/modifiable settings for a specified predictor plugin.
:param predictor: The string or list of strings being the names of the predictor plugins to view customizable
settings for. If None, will print settings for all currently available predictors.
Defaults to None.
:return: Nothing, prints to console....
"""

# Print name, and settings for each plugin.
for (plugin_name, config_spec) in _get_predictor_settings(predictor):
print(f"Plugin Name: {plugin_name}")
print("Arguments: ")
if predictor.get_settings() is None:
if config_spec is {}:
print("None")
else:
for name, (def_val, val_type, desc) in predictor.get_settings().items():
for name, (def_val, val_type, desc) in config_spec.items():
print(f"Name: '{name}'")
print(f"Type: {get_type_name(val_type)}")
print(f"Default Value: {def_val}")
Expand Down
2 changes: 1 addition & 1 deletion diplomat/predictors/fpe/frame_pass_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
self.settings.export_all_info
)

progress_bar.message("Selecting Maximums")
progress_bar.message("Selecting Maximums - FPE")
return self.get_maximums(
self._frame_holder, progress_bar,
relaxed_radius=self.settings.relaxed_maximum_radius
Expand Down
24 changes: 21 additions & 3 deletions diplomat/predictors/sfpe/segmented_frame_pass_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def get_frame_holder(self):
output_path = Path(self.settings.dipui_file).resolve()
if os.path.exists(output_path):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path += timestamp
output_path = Path(self.settings.dipui_file + timestamp).resolve()
else:
output_path = output_path.parent / f"{output_path.stem}_{timestamp}{output_path.suffix}"

Expand Down Expand Up @@ -535,7 +535,7 @@ def _open(self):

self._segments = np.array(self._frame_holder.metadata["segments"], dtype=np.int64)
self._segment_scores = np.array(self._frame_holder.metadata["segment_scores"], dtype=np.float32)
elif(self.settings.storage_mode in ["memory","hybrid"]):
elif(self.settings.storage_mode in ["memory", "hybrid"]):
self._frame_holder = ForwardBackwardData(self.num_frames, self._num_total_bp)
else:
self._frame_holder = self.get_frame_holder()
Expand Down Expand Up @@ -1458,6 +1458,17 @@ def _export_frames(
if(p_bar is not None):
p_bar.update()

def _copy_to_disk(self, progress_bar: ProgressBar, new_frame_holder: ForwardBackwardData):
progress_bar.message("Saving to Disk")
progress_bar.reset(self._frame_holder.num_frames * self._frame_holder.num_bodyparts)

new_frame_holder.metadata = self._frame_holder.metadata
for frame_idx in range(len(self._frame_holder.frames)):
for bodypart_idx in range(len(self._frame_holder.frames[frame_idx])):
new_frame_holder.frames[frame_idx][bodypart_idx] = self._frame_holder.frames[frame_idx][
bodypart_idx]
progress_bar.update()

def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
if(self._restore_path is None):
self._run_frame_passes(progress_bar)
Expand All @@ -1483,7 +1494,14 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
self._height = self._frame_holder.metadata.height
self._resolve_frame_orderings(progress_bar)

progress_bar.message("Selecting Maximums")
progress_bar.message("Selecting Maximums - SFPE")

if(self._restore_path is None and self.settings.storage_mode == "hybrid"):
new_frame_holder = self.get_frame_holder()
self._copy_to_disk(progress_bar, new_frame_holder)
self._frame_holder = new_frame_holder
self._frame_holder._frames.flush()

return self.get_maximums(
self._frame_holder,
self._segments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
self._height = self._frame_holder.metadata.height
self._resolve_frame_orderings(progress_bar)

progress_bar.message("Selecting Maximums")
progress_bar.message("Selecting Maximums - Segmented SFPE")
poses = self.get_maximums(
self._frame_holder,
self._segments,
Expand Down

0 comments on commit e624ec4

Please sign in to comment.