-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add multi-animal support to DLCLive #72
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import tensorflow as tf | ||
import typing | ||
from pathlib import Path | ||
from scipy.optimize import linear_sum_assignment | ||
from typing import Optional, Tuple, List | ||
|
||
try: | ||
|
@@ -24,6 +25,13 @@ | |
except Exception: | ||
pass | ||
|
||
from deeplabcut.pose_estimation_tensorflow.config import load_config | ||
from deeplabcut.pose_estimation_tensorflow.core import ( | ||
predict, predict_multianimal, | ||
) | ||
from deeplabcut.pose_estimation_tensorflow.lib import ( | ||
trackingutils, inferenceutils, | ||
) | ||
from dlclive.graph import ( | ||
read_graph, | ||
finalize_graph, | ||
|
@@ -86,7 +94,7 @@ class DLCLive(object): | |
display_lik : float, optional | ||
Likelihood threshold for display | ||
|
||
display_raidus : int, optional | ||
display_radius : int, optional | ||
radius for keypoint display in pixels, default=3 | ||
""" | ||
|
||
|
@@ -294,7 +302,7 @@ def init_inference(self, frame=None, **kwargs): | |
graph = finalize_graph(graph_def) | ||
output_nodes = get_output_nodes(graph) | ||
output_nodes = [on.replace("DLC/", "") for on in output_nodes] | ||
|
||
tf_version_2 = tf.__version__[0] == '2' | ||
|
||
if tf_version_2: | ||
|
@@ -311,7 +319,7 @@ def init_inference(self, frame=None, **kwargs): | |
output_nodes, | ||
input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, | ||
) | ||
|
||
try: | ||
tflite_model = converter.convert() | ||
except Exception: | ||
|
@@ -478,3 +486,154 @@ def close(self): | |
self.is_initialized = False | ||
if self.display is not None: | ||
self.display.destroy() | ||
|
||
|
||
class MultiAnimalDLCLive(DLCLive): | ||
def __init__( | ||
self, | ||
model_path, | ||
n_animals, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are these not specified in the config file? |
||
n_multibodyparts, | ||
track_method: str = "box", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this could be a literal |
||
min_hits: int = 1, | ||
max_age: int = 1, | ||
sim_threshold: float = 0.6, | ||
resize: Optional[float] = None, | ||
convert2rgb: bool = True, | ||
processor: Optional['Processor'] = None, | ||
display: typing.Union[bool, Display] = False, | ||
pcutoff: float = 0.5, | ||
display_radius: int = 3, | ||
display_cmap: str = "bmy", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing arguments from parent signature:
Any reason? We should try and make uniform API if possible. Since these are all in the Two approaches: add the missing arguments, or else remove all of them and accept a |
||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. needs docstring! Since inherits from |
||
if track_method not in ("box", "ellipse"): | ||
raise ValueError("`track_method` should be either `box` or `ellipse`.") | ||
|
||
self.model_path = model_path | ||
super().__init__( | ||
Path(model_path).parent, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. parent class expects a |
||
resize=resize, | ||
convert2rgb=convert2rgb, | ||
processor=processor, | ||
display=display, | ||
pcutoff=pcutoff, | ||
display_radius=display_radius, | ||
display_cmap=display_cmap, | ||
) | ||
self.n_animals = n_animals | ||
self.n_multibodyparts = n_multibodyparts | ||
self.track_method = track_method | ||
self.min_hits = min_hits | ||
self.max_age = max_age | ||
self.sim_threshold = sim_threshold | ||
|
||
def read_config(self): | ||
cfg_path = Path(self.path).resolve() / "pose_cfg.yaml" | ||
if not cfg_path.exists(): | ||
raise FileNotFoundError( | ||
f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory" | ||
) | ||
|
||
self.cfg = load_config(cfg_path) | ||
self.cfg["batch_size"] = 1 | ||
self.cfg["init_weights"] = self.model_path.split(".")[0] | ||
self.identity_only = self.cfg["num_idchannel"] > 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not declared in init |
||
|
||
def init_inference( | ||
self, | ||
frame=None, | ||
allow_growth=False, | ||
**kwargs, | ||
): | ||
self.sess, self.inputs, self.outputs = predict.setup_pose_prediction( | ||
self.cfg, allow_growth=allow_growth, | ||
) | ||
|
||
if self.track_method == "box": | ||
self.mot_tracker = trackingutils.SORTBox( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. attribute not declared in |
||
self.max_age, self.min_hits, self.sim_threshold, | ||
) | ||
else: | ||
self.mot_tracker = trackingutils.SORTEllipse( | ||
self.max_age, self.min_hits, self.sim_threshold, | ||
) | ||
|
||
|
||
data = { | ||
"metadata": { | ||
"all_joints_names": self.cfg["all_joints_names"], | ||
"PAFgraph": self.cfg["partaffinityfield_graph"], | ||
"PAFinds": self.cfg.get("paf_best", np.arange(self.n_multibodyparts)) | ||
} | ||
} | ||
# Hack to avoid IndexError when determining _has_identity | ||
temp = {"identity": []} if self.identity_only else {} | ||
data["frame0"] = temp | ||
self.ass = inferenceutils.Assembler( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
data, | ||
max_n_individuals=self.n_animals, | ||
n_multibodyparts=self.n_multibodyparts, | ||
greedy=True, # TODO Benchmark vs optimal matching | ||
identity_only=self.identity_only, | ||
max_overlap=1, | ||
) | ||
|
||
if frame is not None: | ||
pose = self.get_pose(frame, **kwargs) | ||
else: | ||
pose = None | ||
|
||
self.is_initialized = True | ||
|
||
return pose | ||
|
||
def get_pose(self, frame=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now might also be a good time to add docstrings to these methods too -- their function is straightforward enough, but documenting what goes on in them, etc. either here or documenting the arguments/attrs in the |
||
if frame is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. iirc |
||
raise DLCLiveError("No frame provided for live pose estimation") | ||
|
||
frame = self.process_frame(frame) | ||
data_dict = predict_multianimal.predict_batched_peaks_and_costs( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure what happens from here to L627, some comments would be nice! |
||
self.cfg, | ||
np.expand_dims(frame, axis=0), | ||
self.sess, | ||
self.inputs, | ||
self.outputs, | ||
)[0] | ||
|
||
assemblies, unique = self.ass._assemble(data_dict, ind_frame=0) | ||
pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan) | ||
if self.n_animals == 1: | ||
pose[0] = assemblies[0].data | ||
else: | ||
animals = np.stack([a.data for a in assemblies]) | ||
if not self.ass.identity_only: | ||
if self.track_method == "box": | ||
xy = trackingutils.calc_bboxes_from_keypoints(animals) | ||
else: | ||
xy = animals[..., :2] | ||
trackers = self.mot_tracker.track(xy)[:, -2:].astype(np.int) | ||
else: | ||
# Optimal identity assignment based on soft voting | ||
mat = np.zeros( | ||
(len(assemblies), self.n_animals) | ||
) | ||
for nrow, assembly in enumerate(assemblies): | ||
for k, v in assembly.soft_identity.items(): | ||
mat[nrow, k] = v | ||
inds = linear_sum_assignment(mat, maximize=True) | ||
trackers = np.c_[inds][:, ::-1] | ||
for i, j in trackers: | ||
pose[i] = animals[j] | ||
self.pose = (pose, unique) | ||
|
||
if self.display is not None: | ||
self.display.display_frame(frame, self.pose) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did this get tested? not sure if the maDLC pose is any different than regular |
||
|
||
if self.resize is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing corrections for |
||
self.pose[0][..., :2] *= 1 / self.resize | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems like this should happen for |
||
self.pose[1][:, :2] *= 1 / self.resize | ||
|
||
if self.processor: | ||
self.pose = self.processor.process(self.pose, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does the processor need to be different for maDLC pose? |
||
|
||
return self.pose | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type hint!
model_path: Union[Path, str]
n_animals: int
n_multibodyparts: int