Support tracklet interpolation in ByteTrack (#385)
* support tracklet interpolation

* update readme.md

* update readme.md

* update readme.md

* update readme.md

* update based on 1-st comments

* fix a typo

* update docstrings

* update docstrings
GT9505 authored Dec 31, 2021
1 parent 522146e commit 07484a3
Showing 8 changed files with 177 additions and 20 deletions.
11 changes: 8 additions & 3 deletions configs/mot/bytetrack/README.md
@@ -27,6 +27,11 @@ Multi-object tracking (MOT) aims at estimating bounding boxes and identities of

## Results and models on MOT17

| Method | Detector | Train Set | Test Set | Public | Inf time (fps) | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :-------------: | :-----------------: | :------------: | :------: | :----: | :------------: | :--: | :--: |:--:|:--:| :---: | :----: | :------: |
| ByteTrack | YOLOX-X | CrowdHuman + half-train | half-val | N | - | 78.3 | 77.2 | 10845 | 24588 | 1425 | [config](bytetrack_yolox_x_crowdhuman_mot17-private-half.py) | [model](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth) &#124; [log](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500.log.json) |
Please note that the performance on `MOT17-half-val` is comparable to that reported in the manuscript, while the performance on `MOT17-test` is lower than that reported in the manuscript.

The reason is that ByteTrack tunes customized hyper-parameters (e.g., image resolution and the high threshold of detection score) for each video in the `MOT17-test` set, whereas we use one unified set of parameters for all videos.

| Method | Detector | Train Set | Test Set | Public | Inf time (fps) | HOTA | MOTA | IDF1 | FP | FN | IDSw. | Config | Download |
| :-------------: | :-----------------: | :------------: | :------: | :----: | :------------: | :--: | :--: | :--: |:--:|:--:| :---: | :----: | :------: |
| ByteTrack | YOLOX-X | CrowdHuman + half-train | half-val | N | - | - | 78.6 | 79.2 | 12909 | 21024 | 666 | [config](bytetrack_yolox_x_crowdhuman_mot17-private-half.py) | [model](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth) &#124; [log](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500.log.json) |
| ByteTrack | YOLOX-X | CrowdHuman + half-train | test | N | - | 61.7 | 78.1 | 74.8 | 36705 | 85032 | 2049 | [config](bytetrack_yolox_x_crowdhuman_mot17-private.py) | [model](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth) &#124; [log](https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500.log.json) |
@@ -99,8 +99,12 @@
],
filter_empty_gt=False),
pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
val=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)),
test=dict(
pipeline=test_pipeline,
interpolate_tracks_cfg=dict(min_num_frames=5, max_num_frames=20)))

# optimizer
# default 8 gpu
17 changes: 15 additions & 2 deletions configs/mot/bytetrack/metafile.yml
@@ -21,6 +21,19 @@ Models:
- Task: Multiple Object Tracking
Dataset: MOT17-half-val
Metrics:
MOTA: 77.2
IDF1: 78.3
MOTA: 78.6
IDF1: 79.2
Weights: https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth

- Name: bytetrack_yolox_x_crowdhuman_mot17-private
In Collection: ByteTrack
Config: configs/mot/bytetrack/bytetrack_yolox_x_crowdhuman_mot17-private.py
Metadata:
Training Data: CrowdHuman + MOT17-half-train
Results:
- Task: Multiple Object Tracking
Dataset: MOT17-test
Metrics:
MOTA: 78.1
IDF1: 74.8
Weights: https://download.openmmlab.com/mmtracking/mot/bytetrack/bytetrack_yolox_x/bytetrack_yolox_x_crowdhuman_mot17-private-half_20211218_205500-1985c9f0.pth
2 changes: 1 addition & 1 deletion configs/mot/tracktor/README.md
@@ -54,7 +54,7 @@ Our implementation outperform it by 4.9 points on MOTA and 3.3 points on IDF1.
| Tracktor | R50-FasterRCNN-FPN | R50 | half-train | half-val | Y | 3.2 | 57.3 | 63.4 | 1254 | 67091 | 614 | [config](tracktor_faster-rcnn_r50_fpn_4e_mot17-public-half.py) | [detector](https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-half-64ee2ed4.pth) [reid](https://download.openmmlab.com/mmtracking/mot/reid/reid_r50_6e_mot17-4bf6b63d.pth) |
| Tracktor | R50-FasterRCNN-FPN | R50 | half-train | half-val | N | 3.1 | 64.1 | 66.9 | 11088 | 45762 | 1233 | [config](tracktor_faster-rcnn_r50_fpn_4e_mot17-private-half.py) | [detector](https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-half-64ee2ed4.pth) [reid](https://download.openmmlab.com/mmtracking/mot/reid/reid_r50_6e_mot17-4bf6b63d.pth) |
| Tracktor | R50-FasterRCNN-FPN | R50 | train | test | Y | 3.2 | 61.2 | 58.4 | 8609 | 207627 | 2634 | [config](tracktor_faster-rcnn_r50_fpn_4e_mot17-public.py) | [detector](https://download.openmmlab.com/mmtracking/mot/faster_rcnn/faster-rcnn_r50_fpn_4e_mot17-ffa52ae7.pth) [reid](https://download.openmmlab.com/mmtracking/mot/reid/reid_r50_6e_mot17-4bf6b63d.pth) |
| Tracktor | R50-FasterRCNN-FPN* | R50 | train | test | Y | - | 56.3 | 55.1 | 8866 | 235449 | 1987 | - | - |
| Tracktor* | R50-FasterRCNN-FPN | R50 | train | test | Y | - | 56.3 | 55.1 | 8866 | 235449 | 1987 | - | - |
| Tracktor <br> (FP16) | R50-FasterRCNN-FPN | R50 | half-train | half-val | N | - | 64.7 | 66.6 | 10710 | 45270 | 1152 | [config](tracktor_faster-rcnn_r50_fpn_fp16_4e_mot17-private-half.py) | [detector](https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth) &#124; [detector_log](https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436.log.json) &#124; [reid](https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth) &#124; [reid_log](https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055.log.json) |

Note:
3 changes: 2 additions & 1 deletion mmtrack/core/track/__init__.py
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .correlation import depthwise_correlation
from .interpolation import interpolate_tracks
from .similarity import embed_similarity
from .transforms import imrenormalize, outs2results, results2outs

__all__ = [
'depthwise_correlation', 'outs2results', 'results2outs',
'embed_similarity', 'imrenormalize'
'embed_similarity', 'imrenormalize', 'interpolate_tracks'
]
87 changes: 87 additions & 0 deletions mmtrack/core/track/interpolation.py
@@ -0,0 +1,87 @@
import numpy as np


def _interpolate_track(track, track_id, max_num_frames=20):
    """Interpolate a track linearly to make the track more complete.

    Args:
        track (ndarray): With shape (N, 7). Each row denotes
            (frame_id, track_id, x1, y1, x2, y2, score).
        track_id (int): The id of the track. Every row of ``track`` must
            carry this id.
        max_num_frames (int, optional): The maximum disconnected length in the
            track. Defaults to 20.

    Returns:
        ndarray: The interpolated track with shape (M, 7), where M >= N.
            Each row denotes (frame_id, track_id, x1, y1, x2, y2, score).
    """
    assert (track[:, 1] == track_id).all(), \
        'The track id should not change when interpolating a track.'

    frame_ids = track[:, 0]
    interpolated_track = np.zeros((0, 7))
    # perform interpolation for the disconnected frames in the track.
    for i in np.where(np.diff(frame_ids) > 1)[0]:
        left_frame_id = frame_ids[i]
        right_frame_id = frame_ids[i + 1]
        num_disconnected_frames = int(right_frame_id - left_frame_id)

        if 1 < num_disconnected_frames < max_num_frames:
            left_bbox = track[i, 2:6]
            right_bbox = track[i + 1, 2:6]

            # perform interpolation for two adjacent tracklets.
            for j in range(1, num_disconnected_frames):
                cur_bbox = j / (num_disconnected_frames) * (
                    right_bbox - left_bbox) + left_bbox
                # the score (index 6) keeps the initial value 1.
                cur_result = np.ones((7, ))
                cur_result[0] = j + left_frame_id
                cur_result[1] = track_id
                cur_result[2:6] = cur_bbox

                interpolated_track = np.concatenate(
                    (interpolated_track, cur_result[None]), axis=0)

    # the interpolated rows are appended after the original rows.
    interpolated_track = np.concatenate((track, interpolated_track), axis=0)
    return interpolated_track


def interpolate_tracks(tracks, min_num_frames=5, max_num_frames=20):
    """Interpolate tracks linearly to make tracks more complete.

    This function is proposed in
    "ByteTrack: Multi-Object Tracking by Associating Every Detection Box."
    `ByteTrack <https://arxiv.org/abs/2110.06864>`_.

    Args:
        tracks (ndarray): With shape (N, 7). Each row denotes
            (frame_id, track_id, x1, y1, x2, y2, score).
        min_num_frames (int, optional): The minimum length of a track that will
            be interpolated. Defaults to 5.
        max_num_frames (int, optional): The maximum disconnected length in
            a track. Defaults to 20.

    Returns:
        ndarray: The interpolated tracks with shape (M, 7), sorted by
            frame id. Each row denotes
            (frame_id, track_id, x1, y1, x2, y2, score).
    """
    max_track_id = int(np.max(tracks[:, 1]))
    min_track_id = int(np.min(tracks[:, 1]))

    # perform interpolation for each track
    interpolated_tracks = []
    for track_id in range(min_track_id, max_track_id + 1):
        inds = tracks[:, 1] == track_id
        track = tracks[inds]
        num_frames = len(track)
        # tracks with at most 2 rows are dropped from the output.
        if num_frames <= 2:
            continue

        if num_frames > min_num_frames:
            interpolated_track = _interpolate_track(track, track_id,
                                                    max_num_frames)
        else:
            interpolated_track = track
        interpolated_tracks.append(interpolated_track)

    interpolated_tracks = np.concatenate(interpolated_tracks)
    return interpolated_tracks[interpolated_tracks[:, 0].argsort()]
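
The gap-filling rule in `mmtrack/core/track/interpolation.py` can be illustrated with a small, dependency-free sketch. It is not the code from this commit: `fill_gap` is a hypothetical helper written only for illustration, it works on plain Python lists instead of NumPy arrays, and it handles a single pair of neighbouring detections. It assumes the same row layout (frame_id, track_id, x1, y1, x2, y2, score) and the same condition `1 < gap < max_num_frames` as the diff.

```python
def fill_gap(left, right, max_num_frames=20):
    """Return linearly interpolated rows for the frames between two detections.

    Each row is [frame_id, track_id, x1, y1, x2, y2, score].
    `fill_gap` is a hypothetical helper used for illustration only.
    """
    left_frame = int(left[0])
    gap = int(right[0]) - left_frame
    # Same condition as the diff: skip consecutive frames and overly long gaps.
    if not 1 < gap < max_num_frames:
        return []
    rows = []
    for j in range(1, gap):
        ratio = j / gap
        box = [a + ratio * (b - a) for a, b in zip(left[2:6], right[2:6])]
        # Interpolated rows get score 1.0, mirroring the np.ones initialisation.
        rows.append([left_frame + j, left[1], *box, 1.0])
    return rows


# A track seen in frame 1 and again in frame 5: frames 2, 3 and 4 are filled.
left = [1, 7, 0.0, 0.0, 10.0, 10.0, 0.9]
right = [5, 7, 40.0, 40.0, 50.0, 50.0, 0.8]
filled = fill_gap(left, right)
for row in filled:
    print(row)
```

With these inputs the box moves by a quarter of the total displacement per frame, so frame 3 lands exactly halfway between the two observed boxes.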
48 changes: 37 additions & 11 deletions mmtrack/datasets/mot_challenge_dataset.py
@@ -10,7 +10,7 @@
from mmdet.core import eval_map
from mmdet.datasets import DATASETS

from mmtrack.core import results2outs
from mmtrack.core import interpolate_tracks, results2outs
from .coco_video_dataset import CocoVideoDataset


@@ -21,6 +21,12 @@ class MOTChallengeDataset(CocoVideoDataset):
Args:
visibility_thr (float, optional): The minimum visibility
for the objects during training. Default to -1.
interpolate_tracks_cfg (dict, optional): If not None, interpolate
tracks linearly to make tracks more complete. Defaults to None.
- min_num_frames (int, optional): The minimum length of a track
that will be interpolated. Defaults to 5.
- max_num_frames (int, optional): The maximum disconnected length
in a track. Defaults to 20.
detection_file (str, optional): The path of the public
detection file. Default to None.
"""
@@ -29,11 +35,13 @@ class MOTChallengeDataset(CocoVideoDataset):

def __init__(self,
visibility_thr=-1,
interpolate_tracks_cfg=None,
detection_file=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.visibility_thr = visibility_thr
self.interpolate_tracks_cfg = interpolate_tracks_cfg
self.detections = self.load_detections(detection_file)

def load_detections(self, detection_file=None):
@@ -181,21 +189,39 @@ def format_results(self, results, resfile_path=None, metrics=['track']):

def format_track_results(self, results, infos, resfile):
"""Format tracking results."""

results_per_video = []
for frame_id, result in enumerate(results):
outs_track = results2outs(bbox_results=result)
track_ids, bboxes = outs_track['ids'], outs_track['bboxes']
frame_ids = np.full_like(track_ids, frame_id)
results_per_frame = np.concatenate(
(frame_ids[:, None], track_ids[:, None], bboxes), axis=1)
results_per_video.append(results_per_frame)
# `results_per_video` is a ndarray with shape (N, 7). Each row denotes
# (frame_id, track_id, x1, y1, x2, y2, score)
results_per_video = np.concatenate(results_per_video)

if self.interpolate_tracks_cfg is not None:
results_per_video = interpolate_tracks(
results_per_video, **self.interpolate_tracks_cfg)

with open(resfile, 'wt') as f:
for res, info in zip(results, infos):
for frame_id, info in enumerate(infos):
# `mot_frame_id` is the actual frame id used for evaluation.
# It may not start from 0.
if 'mot_frame_id' in info:
frame = info['mot_frame_id']
mot_frame_id = info['mot_frame_id']
else:
frame = info['frame_id'] + 1
mot_frame_id = info['frame_id'] + 1

outs_track = results2outs(bbox_results=res)
for bbox, label, id in zip(outs_track['bboxes'],
outs_track['labels'],
outs_track['ids']):
x1, y1, x2, y2, conf = bbox
results_per_frame = \
results_per_video[results_per_video[:, 0] == frame_id]
for i in range(len(results_per_frame)):
_, track_id, x1, y1, x2, y2, conf = results_per_frame[i]
f.writelines(
f'{frame},{id},{x1:.3f},{y1:.3f},{(x2-x1):.3f},' +
f'{(y2-y1):.3f},{conf:.3f},-1,-1,-1\n')
f'{mot_frame_id},{track_id},{x1:.3f},{y1:.3f},' +
f'{(x2-x1):.3f},{(y2-y1):.3f},{conf:.3f},-1,-1,-1\n')
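
The result line written by `format_track_results` converts a corner-format box (x1, y1, x2, y2) into top-left corner plus width and height, followed by the confidence and three `-1` placeholders. A minimal standalone sketch of that conversion is given here; `to_mot_line` is a hypothetical helper rather than part of mmtrack, and the explicit `int()` cast on the track id is a choice of this sketch, not something the diff does.

```python
def to_mot_line(mot_frame_id, track_id, x1, y1, x2, y2, conf):
    """Format one tracked box as a comma-separated result line.

    Hypothetical helper: the field order follows the f-string in the diff
    (frame, id, x, y, width, height, conf, -1, -1, -1).
    """
    width, height = x2 - x1, y2 - y1
    return (f'{mot_frame_id},{int(track_id)},{x1:.3f},{y1:.3f},'
            f'{width:.3f},{height:.3f},{conf:.3f},-1,-1,-1')


line = to_mot_line(1, 3, 10.0, 20.0, 50.0, 100.0, 0.9)
print(line)  # 1,3,10.000,20.000,40.000,80.000,0.900,-1,-1,-1
```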

def format_bbox_results(self, results, infos, resfile):
"""Format detection results."""
21 changes: 21 additions & 0 deletions tests/test_core/test_track/test_interpolation.py
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np


def test_interpolate_tracks():
    from mmtrack.core import interpolate_tracks
    frame_id = np.arange(100) // 10
    tracklet_id = np.random.randint(low=1, high=5, size=(100))
    bboxes = np.random.random((100, 4)) * 100
    scores = np.random.random((100)) * 100
    in_results = np.concatenate(
        (frame_id[:, None], tracklet_id[:, None], bboxes, scores[:, None]),
        axis=1)
    out_results = interpolate_tracks(in_results)
    assert out_results.shape[1] == in_results.shape[1]
    # the range of frame ids should not change
    assert min(out_results[:, 0]) == min(in_results[:, 0])
    assert max(out_results[:, 0]) == max(in_results[:, 0])
    # the range of track ids should not change
    assert min(out_results[:, 1]) == min(in_results[:, 1])
    assert max(out_results[:, 1]) == max(in_results[:, 1])
