From b5391dfc2697ae44956d5197fecc1f691c7586cf Mon Sep 17 00:00:00 2001 From: Jessy <30733203+jeylau@users.noreply.github.com> Date: Thu, 9 Jun 2022 21:06:29 +0200 Subject: [PATCH 1/5] Fix typo --- dlclive/dlclive.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 210671e..b251f1c 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -86,7 +86,7 @@ class DLCLive(object): display_lik : float, optional Likelihood threshold for display - display_raidus : int, optional + display_radius : int, optional radius for keypoint display in pixels, default=3 """ From 85008676547ff6e48f271d4b5bc751c0b59d70a0 Mon Sep 17 00:00:00 2001 From: Jessy <30733203+jeylau@users.noreply.github.com> Date: Thu, 9 Jun 2022 21:07:41 +0200 Subject: [PATCH 2/5] Draft MultiAnimalDLCLive --- dlclive/dlclive.py | 162 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 160 insertions(+), 2 deletions(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index b251f1c..f6ff77d 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -13,6 +13,7 @@ import tensorflow as tf import typing from pathlib import Path +from scipy.optimize import linear_sum_assignment from typing import Optional, Tuple, List try: @@ -24,6 +25,13 @@ except Exception: pass +from deeplabcut.pose_estimation_tensorflow.config import load_config +from deeplabcut.pose_estimation_tensorflow.core import ( + predict, predict_multianimal, +) +from deeplabcut.pose_estimation_tensorflow.lib import ( + trackingutils, inferenceutils, +) from dlclive.graph import ( read_graph, finalize_graph, @@ -294,7 +302,7 @@ def init_inference(self, frame=None, **kwargs): graph = finalize_graph(graph_def) output_nodes = get_output_nodes(graph) output_nodes = [on.replace("DLC/", "") for on in output_nodes] - + tf_version_2 = tf.__version__[0] == '2' if tf_version_2: @@ -311,7 +319,7 @@ def init_inference(self, frame=None, **kwargs): output_nodes, input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, ) - + try: tflite_model = converter.convert() except Exception: @@ -478,3 +486,153 @@ def close(self): self.is_initialized = False if self.display is not None: self.display.destroy() + + +class MultiAnimalDLCLive(DLCLive): + def __init__( + self, + model_path, + n_animals, + n_multibodyparts, + track_method: str = "box", + min_hits: int = 1, + max_age: int = 1, + sim_threshold: float = 0.6, + resize: Optional[float] = None, + convert2rgb: bool = True, + processor: Optional['Processor'] = None, + display: typing.Union[bool, Display] = False, + pcutoff: float = 0.5, + display_radius: int = 3, + display_cmap: str = "bmy", + ): + if track_method not in ("box", "ellipse"): + raise ValueError("`track_method` should be either `box` or `ellipse`.") + + self.model_path = model_path + super().__init__( + Path(model_path).parent, + resize=resize, + convert2rgb=convert2rgb, + processor=processor, + display=display, + pcutoff=pcutoff, + display_radius=display_radius, + display_cmap=display_cmap, + ) + self.n_animals = n_animals + self.n_multibodyparts = n_multibodyparts + self.track_method = track_method + self.min_hits = min_hits + self.max_age = max_age + self.sim_threshold = sim_threshold + + def read_config(self): + cfg_path = Path(self.path).resolve() / "pose_cfg.yaml" + if not cfg_path.exists(): + raise FileNotFoundError( + f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory" + ) + + self.cfg = load_config(cfg_path) + self.cfg["batch_size"] = 1 + self.cfg["init_weights"] = self.model_path.split(".")[0] + self.identity_only = self.cfg["num_idchannel"] > 0 + + def init_inference( + self, + frame=None, + allow_growth=False, + **kwargs, + ): + self.sess, self.inputs, self.outputs = predict.setup_pose_prediction( + self.cfg, allow_growth=allow_growth, + ) + + if self.track_method == "box": + self.mot_tracker = trackingutils.SORTBox( + self.max_age, self.min_hits, self.sim_threshold, + ) + else: + self.mot_tracker = trackingutils.SORTEllipse( + self.max_age, self.min_hits, self.sim_threshold, + ) + + + data = { + "metadata": { + "all_joints_names": self.cfg["all_joints_names"], + "PAFgraph": self.cfg["partaffinityfield_graph"], + "PAFinds": self.cfg.get("paf_best", np.arange(self.n_multibodyparts)) + } + } + # Hack to avoid IndexError when determining _has_identity + temp = {"identity": []} if self.identity_only else {} + data["frame0"] = temp + self.ass = inferenceutils.Assembler( + data, + max_n_individuals=self.n_animals, + n_multibodyparts=self.n_multibodyparts, + greedy=True, # TODO Benchmark vs optimal matching + identity_only=self.identity_only, + max_overlap=1, + ) + + if frame is not None: + pose = self.get_pose(frame, **kwargs) + else: + pose = None + + self.is_initialized = True + + return pose + + def get_pose(self, frame=None, **kwargs): + if frame is None: + raise DLCLiveError("No frame provided for live pose estimation") + + frame = self.process_frame(frame) + data_dict = predict_multianimal.predict_batched_peaks_and_costs( + self.cfg, + np.expand_dims(frame, axis=0), + self.sess, + self.inputs, + self.outputs, + )[0] + + assemblies, unique = self.ass._assemble(data_dict, ind_frame=0) + pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan) + if self.n_animals == 1: + pose[0] = assemblies[0].data + else: + animals = np.stack([a.data for a in assemblies]) + if not self.ass.identity_only: + if self.track_method == "box": + xy = trackingutils.calc_bboxes_from_keypoints(animals) + else: + xy = animals[..., :2] + trackers = self.mot_tracker.track(xy)[:, -2:].astype(np.int) + else: + # Optimal identity assignment based on soft voting + mat = np.zeros( + (len(assemblies), self.n_animals) + ) + for nrow, assembly in enumerate(assemblies): + for k, v in assembly.soft_identity.items(): + mat[nrow, k] = v + inds = linear_sum_assignment(mat, maximize=True) + trackers = np.c_[inds][:, ::-1] + for i, j in trackers: + pose[i] = animals[j] + self.pose = (pose, unique) + + if self.display is not None: + self.display.display_frame(frame, self.pose) + + if self.resize is not None: + self.pose[..., :2] *= 1 / self.resize + + if self.processor: + self.pose = self.processor.process(self.pose, **kwargs) + + return self.pose From 5b223d2bae0c71d2bcf4b55173811a48a4b3f11f Mon Sep 17 00:00:00 2001 From: Jessy <30733203+jeylau@users.noreply.github.com> Date: Thu, 9 Jun 2022 22:37:38 +0200 Subject: [PATCH 3/5] Fix resizing --- dlclive/dlclive.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index f6ff77d..98435fb 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -630,7 +630,8 @@ def get_pose(self, frame=None, **kwargs): self.display.display_frame(frame, self.pose) if self.resize is not None: - self.pose[..., :2] *= 1 / self.resize + self.pose[0][..., :2] *= 1 / self.resize + self.pose[1][:, :2] *= 1 / self.resize if self.processor: self.pose = self.processor.process(self.pose, **kwargs) From 79cb7844e064f4459fc863716045880f48d3c9f1 Mon Sep 17 00:00:00 2001 From: Jessy Lauer <30733203+jeylau@users.noreply.github.com> Date: Thu, 10 Nov 2022 11:51:24 +0100 Subject: [PATCH 4/5] Minor fixes --- dlclive/dlclive.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 98435fb..11b93ca 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -598,7 +598,11 @@ def get_pose(self, frame=None, **kwargs): self.sess, self.inputs, self.outputs, - )[0] + ) + if not data_dict: + return + else: + data_dict = data_dict[0] assemblies, unique = self.ass._assemble(data_dict, ind_frame=0) pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan) @@ -622,8 +626,10 @@ def get_pose(self, frame=None, **kwargs): mat[nrow, k] = v inds = linear_sum_assignment(mat, maximize=True) trackers = np.c_[inds][:, ::-1] - for i, j in trackers: - pose[i] = animals[j] + # Discard trackers of false positives + trackers = trackers[trackers[:, 0] < self.n_animals] + for pose_ind, animal_ind in trackers: + pose[pose_ind] = animals[animal_ind] self.pose = (pose, unique) if self.display is not None: From 61e3dc13f23cbe9311af481f43338c4c24935d46 Mon Sep 17 00:00:00 2001 From: Jessy Lauer <30733203+jeylau@users.noreply.github.com> Date: Thu, 10 Nov 2022 14:26:58 +0100 Subject: [PATCH 5/5] Handle absence of assemblies --- dlclive/dlclive.py | 49 +++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/dlclive/dlclive.py b/dlclive/dlclive.py index 11b93ca..3b1f88a 100644 --- a/dlclive/dlclive.py +++ b/dlclive/dlclive.py @@ -604,32 +604,33 @@ def get_pose(self, frame=None, **kwargs): else: data_dict = data_dict[0] - assemblies, unique = self.ass._assemble(data_dict, ind_frame=0) pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan) - if self.n_animals == 1: - pose[0] = assemblies[0].data - else: - animals = np.stack([a.data for a in assemblies]) - if not self.ass.identity_only: - if self.track_method == "box": - xy = trackingutils.calc_bboxes_from_keypoints(animals) - else: - xy = animals[..., :2] - trackers = self.mot_tracker.track(xy)[:, -2:].astype(np.int) + assemblies, unique = self.ass._assemble(data_dict, ind_frame=0) + if assemblies: + if self.n_animals == 1: + pose[0] = assemblies[0].data else: - # Optimal identity assignment based on soft voting - mat = np.zeros( - (len(assemblies), self.n_animals) - ) - for nrow, assembly in enumerate(assemblies): - for k, v in assembly.soft_identity.items(): - mat[nrow, k] = v - inds = linear_sum_assignment(mat, maximize=True) - trackers = np.c_[inds][:, ::-1] - # Discard trackers of false positives - trackers = trackers[trackers[:, 0] < self.n_animals] - for pose_ind, animal_ind in trackers: - pose[pose_ind] = animals[animal_ind] + animals = np.stack([a.data for a in assemblies]) + if not self.ass.identity_only: + if self.track_method == "box": + xy = trackingutils.calc_bboxes_from_keypoints(animals) + else: + xy = animals[..., :2] + trackers = self.mot_tracker.track(xy)[:, -2:].astype(np.int) + else: + # Optimal identity assignment based on soft voting + mat = np.zeros( + (len(assemblies), self.n_animals) + ) + for nrow, assembly in enumerate(assemblies): + for k, v in assembly.soft_identity.items(): + mat[nrow, k] = v + inds = linear_sum_assignment(mat, maximize=True) + trackers = np.c_[inds][:, ::-1] + # Discard trackers of false positives + trackers = trackers[trackers[:, 0] < self.n_animals] + for pose_ind, animal_ind in trackers: + pose[pose_ind] = animals[animal_ind] self.pose = (pose, unique) if self.display is not None: