diff --git a/diplomat/core_ops.py b/diplomat/core_ops.py index 914b753..fea6ee7 100644 --- a/diplomat/core_ops.py +++ b/diplomat/core_ops.py @@ -12,6 +12,7 @@ NoneType, get_typecaster_required_arguments ) +from diplomat.predictor_ops import _get_predictor_settings from diplomat.utils.pretty_printer import printer as print from diplomat.utils.cli_tools import func_to_command, allow_arbitrary_flags, Flag, positional_argument_count, CLIError from argparse import ArgumentParser @@ -58,6 +59,29 @@ def _get_casted_args(tc_func, extra_args, error_on_miss=True): return new_args +def _reconcile_arguments_with_predictor_settings(predictor_name, extra_args, passed_predictor_settings, precasted_args): + """ + PRIVATE: Compare items in extra_args to the predictor's ConfigSpec. + If a key/value pair in the argument matches a pair of setting name/type in the predictor ConfigSpec, + add it to passed_predictor_settings and remove it from extra_args. If the setting has already been set, + either in passed_predictor_settings or in precasted_args, then it will be ignored. + """ + if passed_predictor_settings == None: + passed_predictor_settings = {} + new_extra_args = {} + for plugin_name, config_spec in _get_predictor_settings(predictor_name): + for k, v in extra_args.items(): + arg_type = type(v) + if((k in precasted_args) or (k in passed_predictor_settings)): + print(f"Warning: {k} is already set; skipping") + elif(k in config_spec): + print(f"Info: converted command line argument {(k,v)} to a {plugin_name} setting.") + passed_predictor_settings[k] = v + else: + extra_args[k] = v + + return new_extra_args, passed_predictor_settings + def _find_frontend( contracts: Union[DIPLOMATContract, List[DIPLOMATContract]], @@ -187,7 +211,6 @@ def track_with( predictor: Optional[str] = None, predictor_settings: Optional[Dict[str, Any]] = None, help_extra: Flag = False, - dipui_file: Optional[PathLike] = None, **extra_args ): """ @@ -210,10 +233,7 @@ def track_with( To see valid values, run track with extra_help flag set to true. """ from diplomat import CLI_RUN - - print(f"framestores: {frame_stores}") - print(f"dipui_file: {dipui_file}") - + selected_frontend_name, selected_frontend = _find_frontend( contracts=[DIPLOMATCommands.analyze_videos, DIPLOMATCommands.analyze_videos], config=config, @@ -234,9 +254,15 @@ def track_with( print("No frame stores or videos passed, terminating.") return + ## TODO: compare extra_args to predictor.get_settings + # If some videos are supplied, run the frontends video analysis function. if(videos is not None): - print("Running on videos... TEST") + print("Running on videos...") + + precasted_args = _get_casted_args(selected_frontend.analyze_videos, extra_args, error_on_miss = False) + extra_args, predictor_settings = _reconcile_arguments_with_predictor_settings(predictor, extra_args, predictor_settings, precasted_args) + selected_frontend.analyze_videos( config=config, videos=videos, @@ -249,13 +275,16 @@ def track_with( # If some frame stores are supplied, run the frontends frame analysis function. if(frame_stores is not None): print("Running on frame stores...") + + precasted_args = _get_casted_args(selected_frontend.analyze_frames, extra_args, error_on_miss = False) + extra_args, predictor_settings = _reconcile_arguments_with_predictor_settings(predictor, extra_args, predictor_settings, precasted_args) + selected_frontend.analyze_frames( config=config, frame_stores=frame_stores, num_outputs=num_outputs, predictor=predictor, predictor_settings=predictor_settings, - dipui_file=dipui_file, **_get_casted_args(selected_frontend.analyze_frames, extra_args) ) diff --git a/diplomat/frontends/deeplabcut/predict_videos_dlc.py b/diplomat/frontends/deeplabcut/predict_videos_dlc.py index 008e3ce..6f528f2 100644 --- a/diplomat/frontends/deeplabcut/predict_videos_dlc.py +++ b/diplomat/frontends/deeplabcut/predict_videos_dlc.py @@ -196,8 +196,7 @@ def _analyze_video( save_as_csv, dest_folder=None, predictor_cls=None, - predictor_settings=None, - dipui_file=None + predictor_settings=None ) -> str: print(f"Analyzing video: {video}") diff --git a/diplomat/predictor_ops.py b/diplomat/predictor_ops.py index 3b06988..4955362 100644 --- a/diplomat/predictor_ops.py +++ b/diplomat/predictor_ops.py @@ -21,41 +21,60 @@ def list_predictor_plugins(): print("\t", predictor.get_description()) print() - -@typecaster_function -@positional_argument_count(1) -def get_predictor_settings(predictor: Optional[Union[List[str], str]] = None): +def _get_predictor_settings(predictor_name: Optional[Union[List[str], str]] = None): """ - Gets the available/modifiable settings for a specified predictor plugin. + Returns the available/modifiable settings for a specified predictor plugin. - :param predictor: The string or list of strings being the names of the predictor plugins to view customizable + :param predictor_name: The string or list of strings being the names of the predictor plugins to view customizable settings for. If None, will print settings for all currently available predictors. Defaults to None. - :return: Nothing, prints to console.... + :return: ConfigSpec, a dictionary relating each predictor setting to a 3-tuple of its default value, type, and description strnig. """ from typing import Iterable # Convert whatever the predictor_name argument is to a list of predictor plugins - if predictor is None: + if predictor_name is None: predictors = processing.get_predictor_plugins() - elif isinstance(predictor, str): - predictors = [processing.get_predictor(predictor)] - elif isinstance(predictor, Iterable): - predictors = [processing.get_predictor(name) for name in predictor] + elif isinstance(predictor_name, str): + predictors = [processing.get_predictor(predictor_name)] + elif isinstance(predictor_name, Iterable): + predictor_name = [processing.get_predictor(name) for name in predictor] else: raise ValueError( "Argument 'predictor_name' not of type Iterable[str], string, or None!!!" ) - # Print name, and settings for each plugin. + # yield name and settings for each plugin. for predictor in predictors: - print(f"Plugin Name: {predictor.get_name()}") + plugin_name = predictor.get_name() + config_spec = predictor.get_settings() + if config_spec is None: + yield (plugin_name, {}) + else: + yield (plugin_name, config_spec) + +@typecaster_function +@positional_argument_count(1) +def get_predictor_settings(predictor: Optional[Union[List[str], str]] = None): + """ + Prints the available/modifiable settings for a specified predictor plugin. + + :param predictor: The string or list of strings being the names of the predictor plugins to view customizable + settings for. If None, will print settings for all currently available predictors. + Defaults to None. + + :return: Nothing, prints to console.... + """ + + # Print name, and settings for each plugin. + for (plugin_name, config_spec) in _get_predictor_settings(predictor): + print(f"Plugin Name: {plugin_name}") print("Arguments: ") - if predictor.get_settings() is None: + if config_spec is {}: print("None") else: - for name, (def_val, val_type, desc) in predictor.get_settings().items(): + for name, (def_val, val_type, desc) in config_spec.items(): print(f"Name: '{name}'") print(f"Type: {get_type_name(val_type)}") print(f"Default Value: {def_val}") diff --git a/diplomat/predictors/fpe/frame_pass_engine.py b/diplomat/predictors/fpe/frame_pass_engine.py index 136872b..312d868 100644 --- a/diplomat/predictors/fpe/frame_pass_engine.py +++ b/diplomat/predictors/fpe/frame_pass_engine.py @@ -409,7 +409,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]: self.settings.export_all_info ) - progress_bar.message("Selecting Maximums") + progress_bar.message("Selecting Maximums - FPE") return self.get_maximums( self._frame_holder, progress_bar, relaxed_radius=self.settings.relaxed_maximum_radius diff --git a/diplomat/predictors/sfpe/segmented_frame_pass_engine.py b/diplomat/predictors/sfpe/segmented_frame_pass_engine.py index 9daff71..6cee5ec 100644 --- a/diplomat/predictors/sfpe/segmented_frame_pass_engine.py +++ b/diplomat/predictors/sfpe/segmented_frame_pass_engine.py @@ -471,7 +471,7 @@ def get_frame_holder(self): output_path = Path(self.settings.dipui_file).resolve() if os.path.exists(output_path): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path += timestamp + output_path = Path(self.settings.dipui_file + timestamp).resolve() else: output_path = output_path.parent / f"{output_path.stem}_{timestamp}{output_path.suffix}" @@ -535,7 +535,7 @@ def _open(self): self._segments = np.array(self._frame_holder.metadata["segments"], dtype=np.int64) self._segment_scores = np.array(self._frame_holder.metadata["segment_scores"], dtype=np.float32) - elif(self.settings.storage_mode in ["memory","hybrid"]): + elif(self.settings.storage_mode in ["memory", "hybrid"]): self._frame_holder = ForwardBackwardData(self.num_frames, self._num_total_bp) else: self._frame_holder = self.get_frame_holder() @@ -1458,6 +1458,17 @@ def _export_frames( if(p_bar is not None): p_bar.update() + def _copy_to_disk(self, progress_bar: ProgressBar, new_frame_holder: ForwardBackwardData): + progress_bar.message("Saving to Disk") + progress_bar.reset(self._frame_holder.num_frames * self._frame_holder.num_bodyparts) + + new_frame_holder.metadata = self._frame_holder.metadata + for frame_idx in range(len(self._frame_holder.frames)): + for bodypart_idx in range(len(self._frame_holder.frames[frame_idx])): + new_frame_holder.frames[frame_idx][bodypart_idx] = self._frame_holder.frames[frame_idx][ + bodypart_idx] + progress_bar.update() + def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]: if(self._restore_path is None): self._run_frame_passes(progress_bar) @@ -1483,7 +1494,14 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]: self._height = self._frame_holder.metadata.height self._resolve_frame_orderings(progress_bar) - progress_bar.message("Selecting Maximums") + progress_bar.message("Selecting Maximums - SFPE") + + if(self._restore_path is None and self.settings.storage_mode == "hybrid"): + new_frame_holder = self.get_frame_holder() + self._copy_to_disk(progress_bar, new_frame_holder) + self._frame_holder = new_frame_holder + self._frame_holder._frames.flush() + return self.get_maximums( self._frame_holder, self._segments, diff --git a/diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py b/diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py index 7fdf0fa..e24f877 100644 --- a/diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py +++ b/diplomat/predictors/supervised_sfpe/supervised_segmented_frame_pass_engine.py @@ -650,7 +650,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]: self._height = self._frame_holder.metadata.height self._resolve_frame_orderings(progress_bar) - progress_bar.message("Selecting Maximums") + progress_bar.message("Selecting Maximums - Segmented SFPE") poses = self.get_maximums( self._frame_holder, self._segments,