Person ReID and keypoint retrieval #150

Open · wants to merge 8 commits into base: main
14 changes: 8 additions & 6 deletions pytorchvideo/data/__init__.py
@@ -2,12 +2,6 @@

from .ava import Ava # noqa
from .charades import Charades # noqa
-from .clip_sampling import (  # noqa; noqa
-    ClipSampler,
-    RandomClipSampler,
-    UniformClipSampler,
-    make_clip_sampler,
-)
from .domsev import DomsevFrameDataset, DomsevVideoDataset # noqa
from .epic_kitchen_forecasting import EpicKitchenForecasting # noqa
from .epic_kitchen_recognition import EpicKitchenRecognition # noqa
@@ -16,3 +10,11 @@
from .labeled_video_dataset import LabeledVideoDataset, labeled_video_dataset # noqa
from .ssv2 import SSv2
from .ucf101 import Ucf101 # noqa


+from .clip_sampling import (  # noqa; noqa
+    ClipSampler,
+    RandomClipSampler,
+    UniformClipSampler,
+    make_clip_sampler,
+)
2 changes: 1 addition & 1 deletion pytorchvideo/data/clip_sampling.py
@@ -3,7 +3,7 @@
import random
from abc import ABC, abstractmethod
from fractions import Fraction
-from typing import Any, Dict, NamedTuple, Optional, Tuple, Union, List
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union


class ClipInfo(NamedTuple):
80 changes: 80 additions & 0 deletions pytorchvideo/neural_engine/hook.py
@@ -72,9 +72,30 @@ def full_decode(status: OrderedDict, **args):
decode_audio = args.get("decode_audio", True)
video = EncodedVideo.from_path(status["path"], decode_audio, decoder)
frames = video.get_clip(0, video.duration)

return frames


def center_keypoints_in_bbox(bboxes_per_frame, keypoints_per_frame):
    # calculate bbox centers (boxes are in (x1, y1, x2, y2) format)
bboxes_per_frame_center_x = (
bboxes_per_frame[:, 0] + bboxes_per_frame[:, 2]
) / 2 # (x1+x2)/2
bboxes_per_frame_center_y = (
bboxes_per_frame[:, 1] + bboxes_per_frame[:, 3]
) / 2 # (y1+y2)/2

# change origin of the keypoints to center of each bbox
keypoints_per_frame[:, :, 0] = keypoints_per_frame[
:, :, 0
] - bboxes_per_frame_center_x.unsqueeze(1)
keypoints_per_frame[:, :, 1] = keypoints_per_frame[
:, :, 1
] - bboxes_per_frame_center_y.unsqueeze(1)

return keypoints_per_frame
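For intuition, a minimal sketch of what the centering does, using hypothetical tensors that are not part of the diff. Each keypoint becomes an offset from the center of its bounding box:

```python
import torch

# two boxes in (x1, y1, x2, y2) format, three (x, y) keypoints each
bboxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                       [20.0, 20.0, 40.0, 40.0]])  # shape (N=2, 4)
keypoints = torch.tensor([[[5.0, 5.0], [6.0, 4.0], [4.0, 6.0]],
                          [[30.0, 30.0], [25.0, 35.0], [35.0, 25.0]]])  # (N=2, K=3, 2)

centered = center_keypoints_in_bbox(bboxes, keypoints)
# box centers are (5, 5) and (30, 30), so:
# centered[0] -> [[0, 0], [1, -1], [-1, 1]]
# centered[1] -> [[0, 0], [-5, 5], [5, -5]]
```

Note that the function mutates `keypoints_per_frame` in place and also returns it, so callers should not assume the input is preserved.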


class DecodeHook(HookBase):
def __init__(
self,
@@ -87,20 +108,79 @@ def __init__(
# Decoding params
self.decode_audio = decode_audio
self.decoder = decoder

# Hook params
self.executor = executor
self.inputs = ["path"]
self.outputs = ["video", "audio"] if decode_audio else ["video"]
self.fail_strategy = fail_strategy
self.priority = priority

# frame and bounding-box tracker
self.frame_tracker = []

def _populate_frame_tracker(self, model, frames):
"""
Generates a data structure to track bounding boxes and
keypoint coordinates. Useful for extracting the frame-id given
the bounding number from a video for action-recognition.
"""

for frame_id, frame in enumerate(frames):
model_outputs = model.predict(frame)

# get bounding-box coordinates (x1, y1, x2, y2)
            bboxes_per_frame = (
                model_outputs["instances"][model_outputs["instances"].pred_classes == 0]
                .pred_boxes.tensor.to("cpu")
            )  # no squeeze: keep shape (N, 4) even when a single person is detected

# get keypoints (slice to select only the x,y coordinates)
keypoints_per_frame = (
model_outputs["instances"][model_outputs["instances"].pred_classes == 0]
.pred_keypoints[:, :, :2]
.to("cpu")
)

            # center keypoints w.r.t. the respective bounding-box centers
keypoints_per_frame = center_keypoints_in_bbox(
bboxes_per_frame=bboxes_per_frame,
keypoints_per_frame=keypoints_per_frame,
)

# sanity check
if bboxes_per_frame.shape[0] != keypoints_per_frame.shape[0]:
                raise ValueError(
                    "bboxes_per_frame and keypoints_per_frame must have "
                    "the same first dimension."
                )

# append bbox_info to frame_tracker
for i in range(bboxes_per_frame.shape[0]):
bbox_coord = bboxes_per_frame[i, :]
keypoint_per_bbox = keypoints_per_frame[i, :, :]

bbox_info = {
"frame_id": frame_id,
"bbox_id": i,
"person_id": None,
"bbox_coord": bbox_coord,
"keypoint_coord": keypoint_per_bbox,
}

self.frame_tracker.append(bbox_info)

def _run(
self,
status: OrderedDict,
):
frames = self.executor(
status, decode_audio=self.decode_audio, decoder=self.decoder
)

        # populate the frame tracker while decoding videos
        # (assumes the detection model is stored on the hook as self.model,
        # set in the elided part of __init__ above)
        self._populate_frame_tracker(model=self.model, frames=frames)

return frames
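For reference, each entry appended to `frame_tracker` is a flat dict, one per detected person per frame. Below is a minimal sketch of the record layout and a typical lookup; the 17-keypoint COCO skeleton is an assumption, since the diff leaves the number of keypoints up to the detection model:

```python
import torch

# one record per detected person per frame
frame_tracker = [
    {
        "frame_id": 0,                         # index of the frame in the decoded clip
        "bbox_id": 0,                          # index of the person box within that frame
        "person_id": None,                     # placeholder, filled later by a ReID stage
        "bbox_coord": torch.zeros(4),          # (x1, y1, x2, y2) box coordinates
        "keypoint_coord": torch.zeros(17, 2),  # keypoints re-centered on the box center
    },
]

# typical lookup: every person box detected in frame 0
boxes_in_frame_0 = [e["bbox_coord"] for e in frame_tracker if e["frame_id"] == 0]
```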


58 changes: 58 additions & 0 deletions pytorchvideo/neural_engine/retrieval_hook.py
@@ -0,0 +1,58 @@
from collections import OrderedDict
from typing import Callable

import torch
from .hook import HookBase  # relative import: HookBase lives in this package


def create_keypoint_features_db(frame_tracker):
    # flatten each (K, 2) keypoint tensor to (K*2,) and stack into an (N, K*2) db
    return torch.stack([bbox["keypoint_coord"].flatten() for bbox in frame_tracker])


def calculate_distance_scores(action_query, keypoint_feature_db):
    # cosine similarity of the query against every row of the db
    # (higher score means more similar, despite the "distance" in the name)
    scores = torch.nn.functional.cosine_similarity(action_query, keypoint_feature_db)
    return scores


def get_closest_keypoint_feature_match(scores, method, n):
    if method == "topk":
        # indices of the n best-scoring boxes
        return torch.topk(scores, n).indices.tolist()
    elif method == "softmax":
        # indices of every box whose softmax-normalized score exceeds the threshold n;
        # flatten (rather than squeeze) keeps the result a list even for a single match
        score_probs = torch.nn.functional.softmax(scores, dim=0)
        return (score_probs > n).nonzero().flatten().tolist()
    raise ValueError(f"Unsupported method: {method}")


def bbox_to_frame_executor(frame_tracker, best_bbox_matches):
    # map each matched bbox index back to the frame it was detected in
    return [frame_tracker[bbox_id]["frame_id"] for bbox_id in best_bbox_matches]


class PeopleKeypointRetrievalHook(HookBase):
def __init__(self, executor: Callable = bbox_to_frame_executor):
self.executor = executor
self.inputs = ["frame_tracker", "action_query"]
        self.outputs = ["frame_id_list"]  # keep in sync with the key returned by _run

def _run(self, status: OrderedDict):
# extract frame_tracker and action_query feature
frame_tracker = status["frame_tracker"]
action_query = status["action_query"]

# combine multiple keypoint features into a single tensor
keypoint_feature_db = create_keypoint_features_db(frame_tracker)

# find feature closest to action_query from the keypoint_feature_db
distance_scores = calculate_distance_scores(
action_query=action_query, keypoint_feature_db=keypoint_feature_db
)

# extract the index (bbox_id) of the best matches
best_bbox_match_list = get_closest_keypoint_feature_match(
scores=distance_scores, method="softmax", n=0.9
)

# get frame_id_list from the best_bbox_match_list
frame_id_list = self.executor(
frame_tracker=frame_tracker, best_bbox_matches=best_bbox_match_list
)

return {"frame_id_list": frame_id_list}