from scipy.optimize import linear_sum_assignment

from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.pose_estimation_tensorflow.core import (
    predict,
    predict_multianimal,
)
from deeplabcut.pose_estimation_tensorflow.lib import (
    inferenceutils,
    trackingutils,
)


class MultiAnimalDLCLive(DLCLive):
    """
    Live pose estimation for multi-animal DeepLabCut models.

    Each frame is run through the network, the detected keypoints are
    assembled into individuals, and every individual is assigned to one of
    ``n_animals`` output slots, either with a SORT tracker or, when the model
    has identity channels, by optimal assignment on the identity votes.

    Parameters
    -----------

    model_path : str or path-like
        Path to a model snapshot file. Its parent directory is handed to
        :class:`DLCLive` as the model directory, and the path up to the first
        dot of the file name is used as ``init_weights``.

    n_animals : int
        Number of output slots (first axis of the returned pose array); also
        passed to the assembler as ``max_n_individuals``.

    n_multibodyparts : int
        Number of body parts per individual (second axis of the pose array).

    track_method : str, optional
        Either ``"box"`` or ``"ellipse"``; selects the SORT tracker variant.

    min_hits, max_age, sim_threshold : optional
        Forwarded to the SORT tracker constructor.
        NOTE(review): exact semantics are defined by
        ``deeplabcut...trackingutils``; confirm there.

    resize, convert2rgb, processor, display, pcutoff, display_radius, display_cmap :
        Forwarded unchanged to :class:`DLCLive`.

    Raises
    -------
    ValueError
        If ``track_method`` is neither ``"box"`` nor ``"ellipse"``.
    """

    def __init__(
        self,
        model_path,
        n_animals,
        n_multibodyparts,
        track_method: str = "box",
        min_hits: int = 1,
        max_age: int = 1,
        sim_threshold: float = 0.6,
        resize: Optional[float] = None,
        convert2rgb: bool = True,
        processor: Optional['Processor'] = None,
        display: typing.Union[bool, Display] = False,
        pcutoff: float = 0.5,
        display_radius: int = 3,
        display_cmap: str = "bmy",
    ):
        if track_method not in ("box", "ellipse"):
            raise ValueError("`track_method` should be either `box` or `ellipse`.")

        # Assigned before the parent initializer runs, as in the original
        # ordering. NOTE(review): presumably DLCLive.__init__ calls
        # `read_config`, which reads `self.model_path` -- confirm.
        self.model_path = model_path
        super().__init__(
            Path(model_path).parent,
            resize=resize,
            convert2rgb=convert2rgb,
            processor=processor,
            display=display,
            pcutoff=pcutoff,
            display_radius=display_radius,
            display_cmap=display_cmap,
        )
        self.n_animals = n_animals
        self.n_multibodyparts = n_multibodyparts
        self.track_method = track_method
        self.min_hits = min_hits
        self.max_age = max_age
        self.sim_threshold = sim_threshold

    def read_config(self):
        """
        Load ``pose_cfg.yaml`` from the model directory into ``self.cfg``.

        The batch size is forced to 1, ``init_weights`` is pointed at the
        snapshot given by ``model_path``, and ``self.identity_only`` is set
        to whether the model declares identity channels
        (``num_idchannel > 0``).

        Raises
        -------
        FileNotFoundError
            If ``pose_cfg.yaml`` does not exist in the model directory.
        """
        cfg_path = Path(self.path).resolve() / "pose_cfg.yaml"
        if not cfg_path.exists():
            raise FileNotFoundError(
                f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory"
            )

        self.cfg = load_config(cfg_path)
        self.cfg["batch_size"] = 1
        # Strip the extension(s) from the file name only. Splitting the whole
        # path on "." breaks on "./relative" paths and on directories whose
        # names contain dots, and fails outright on Path objects.
        snapshot = Path(self.model_path)
        self.cfg["init_weights"] = str(
            snapshot.parent / snapshot.name.split(".")[0]
        )
        self.identity_only = self.cfg["num_idchannel"] > 0

    def init_inference(
        self,
        frame=None,
        allow_growth=False,
        **kwargs,
    ):
        """
        Set up the prediction session, the tracker and the assembler.

        Parameters
        -----------
        frame : optional
            If given, a pose is computed for it right away with
            :meth:`get_pose`.
        allow_growth : bool, optional
            Forwarded to ``predict.setup_pose_prediction``.
        **kwargs
            Forwarded to :meth:`get_pose` when ``frame`` is given.

        Returns
        --------
        The result of :meth:`get_pose` for ``frame``, or ``None`` when no
        frame was supplied.
        """
        self.sess, self.inputs, self.outputs = predict.setup_pose_prediction(
            self.cfg, allow_growth=allow_growth,
        )

        if self.track_method == "box":
            tracker_cls = trackingutils.SORTBox
        else:
            tracker_cls = trackingutils.SORTEllipse
        self.mot_tracker = tracker_cls(
            self.max_age, self.min_hits, self.sim_threshold,
        )

        data = {
            "metadata": {
                "all_joints_names": self.cfg["all_joints_names"],
                "PAFgraph": self.cfg["partaffinityfield_graph"],
                "PAFinds": self.cfg.get(
                    "paf_best", np.arange(self.n_multibodyparts)
                ),
            }
        }
        # Hack to avoid IndexError when determining _has_identity
        data["frame0"] = {"identity": []} if self.identity_only else {}
        self.ass = inferenceutils.Assembler(
            data,
            max_n_individuals=self.n_animals,
            n_multibodyparts=self.n_multibodyparts,
            greedy=True,  # TODO Benchmark vs optimal matching
            identity_only=self.identity_only,
            max_overlap=1,
        )

        pose = self.get_pose(frame, **kwargs) if frame is not None else None

        self.is_initialized = True

        return pose

    def _assign_assemblies(self, assemblies):
        """Place assembled individuals into the fixed ``n_animals`` pose slots.

        Returns an array of shape ``(n_animals, n_multibodyparts, 4)`` that is
        NaN wherever no individual was assigned.
        """
        pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan)
        if not assemblies:
            return pose

        if self.n_animals == 1:
            pose[0] = assemblies[0].data
            return pose

        animals = np.stack([a.data for a in assemblies])
        if self.ass.identity_only:
            # Optimal identity assignment based on soft voting
            votes = np.zeros((len(assemblies), self.n_animals))
            for row, assembly in enumerate(assemblies):
                for identity, weight in assembly.soft_identity.items():
                    votes[row, identity] = weight
            rows, cols = linear_sum_assignment(votes, maximize=True)
            # Column 0: output slot (identity); column 1: assembly index.
            trackers = np.column_stack((cols, rows))
        else:
            if self.track_method == "box":
                xy = trackingutils.calc_bboxes_from_keypoints(animals)
            else:
                xy = animals[..., :2]
            # `np.int` was removed in NumPy 1.24; the builtin `int` is the
            # documented replacement and yields the same dtype.
            trackers = self.mot_tracker.track(xy)[:, -2:].astype(int)

        # Discard trackers of false positives
        trackers = trackers[trackers[:, 0] < self.n_animals]
        for pose_ind, animal_ind in trackers:
            pose[pose_ind] = animals[animal_ind]
        return pose

    def get_pose(self, frame=None, **kwargs):
        """
        Estimate the pose of every animal in ``frame``.

        Parameters
        -----------
        frame : array-like
            Image to analyze; it is first passed through ``process_frame``.
        **kwargs
            Forwarded to ``processor.process`` when a processor is set.

        Returns
        --------
        ``None`` if the network returned no predictions. Otherwise
        ``self.pose``: a tuple ``(pose, unique)`` where ``pose`` has shape
        ``(n_animals, n_multibodyparts, 4)`` (NaN for unassigned slots) and
        ``unique`` is the second value returned by the assembler, or whatever
        the processor returns if one is set.

        Raises
        -------
        DLCLiveError
            If ``frame`` is ``None``.
        """
        if frame is None:
            raise DLCLiveError("No frame provided for live pose estimation")

        frame = self.process_frame(frame)
        predictions = predict_multianimal.predict_batched_peaks_and_costs(
            self.cfg,
            np.expand_dims(frame, axis=0),
            self.sess,
            self.inputs,
            self.outputs,
        )
        if not predictions:
            return None

        assemblies, unique = self.ass._assemble(predictions[0], ind_frame=0)
        pose = self._assign_assemblies(assemblies)
        self.pose = (pose, unique)

        if self.display is not None:
            self.display.display_frame(frame, self.pose)

        if self.resize is not None:
            # Map coordinates back to the original (un-resized) frame.
            scale = 1 / self.resize
            pose[..., :2] *= scale
            # NOTE(review): the assembler presumably returns None here when
            # the model has no unique body parts -- guard so resizing does
            # not raise in that case.
            if unique is not None:
                unique[:, :2] *= scale

        if self.processor:
            self.pose = self.processor.process(self.pose, **kwargs)

        return self.pose