remove UniformClipSampler assertion (#49)
Summary:
## Motivation and Context

Adds the ability to set the stride > window size for UniformClipSampler. Why?
- Feature extraction with ImageNet-based models on videos, e.g. retrieving a single frame every 5 frames (see the sketch below)
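
To illustrate, here is a minimal sketch of what this enables. It assumes a 30 fps source and uses Fraction durations (as the updated docstring suggests) to avoid float precision issues; the parameter values are illustrative, not taken from this diff:

```
from fractions import Fraction

from pytorchvideo.data.clip_sampling import UniformClipSampler

fps = 30  # assumed source frame rate
sampler = UniformClipSampler(
    clip_duration=Fraction(1, fps),  # single-frame window
    stride=Fraction(5, fps),         # advance 5 frames between clips (> window size)
    backpad_last=True,               # back-pad the final clip so it stays inside the video
)
```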

## How Has This Been Tested

- Added some unit tests
- Ran existing unit tests

## Other comments

This was a bit difficult to implement due to the way clip samplers are designed. Personally, I think they could be redesigned so that the implementation is much simpler. As an example, UniformClipSampler could be implemented in ~20 lines of code, compared to the current implementation:

```
def _num_clips(...) -> int:
    num_frames = round(duration_sec * fps)
    N = num_frames - window_size_frames
    if N < 0:
        return 1
    result = int(N / stride_frames + 1)
    pad = backpad_last and N % stride_frames != 0
    return result + pad

start_end_times = [
    (
        (i * s_prime) / fps,
        (i * s_prime + window_size) / fps,
    )
    for i in range(_num_clips(...))
]
if start_end_times[-1][1] - video_length > 1e-6:
    start_end_times[-1] = (video_length - window_size / fps, video_length)
```
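
For the motivating case above (a single-frame window, a 5-frame stride, 30 fps, and an 11-frame video), this gives N = 11 - 1 = 10 and int(10 / 5 + 1) = 3 clips, which lines up with the new `(False, 30, 1, 5, 11 / 30, 3)` test case added below.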

X-link: fairinternal/pytorchvideo#49

Reviewed By: lyttonhao

Differential Revision: D34811417

Pulled By: miguelmartin75

fbshipit-source-id: b9d93d5446d16008ede20e61a6314a89e6b1364e
miguelmartin75 authored and facebook-github-bot committed Apr 7, 2022
1 parent 5e58541 commit 104257a
Showing 3 changed files with 129 additions and 54 deletions.
74 changes: 47 additions & 27 deletions pytorchvideo/data/clip_sampling.py
@@ -58,7 +58,7 @@ def __init__(self, clip_duration: Union[float, Fraction]) -> None:
@abstractmethod
def __call__(
self,
last_clip_time: Union[float, Fraction],
last_clip_end_time: Union[float, Fraction],
video_duration: Union[float, Fraction],
annotation: Dict[str, Any],
) -> ClipInfo:
@@ -111,7 +111,7 @@ def __init__(
Args:
clip_duration (Union[float, Fraction]):
The length of the clip to sample (in seconds).
stride (floUnion[float, Fraction]at, optional):
stride (Union[float, Fraction], optional):
The amount of seconds to offset the next clip by
default value of None is equivalent to no stride => stride == clip_duration.
eps (float):
@@ -124,35 +124,33 @@ def __init__(
(1.0667s). The clips will be (in frame numbers):
with backpad_last = False
- [0, 32]
- [0, 31]
with backpad_last = True
- [0, 32]
- [8, 40], this is "back-padded" from [16, 48] to fit the last window
- [0, 31]
- [8, 39], this is "back-padded" from [16, 48] to fit the last window
Note that you can use Fraction for clip_duration and stride if you want to
avoid float precision issue and need accurate frames in each clip.
"""
super().__init__(clip_duration)
self._stride = stride if stride is not None else clip_duration
self._stride = stride if stride is not None else self._clip_duration
self._eps = eps
self._backpad_last = backpad_last

assert (
self._stride > 0 and self._stride <= clip_duration
), f"stride must be >0 and <= clip_duration ({clip_duration})"
assert self._stride > 0, "stride must be positive"

def _clip_start_end(
self,
last_clip_time: Union[float, Fraction],
last_clip_end_time: Union[float, Fraction],
video_duration: Union[float, Fraction],
backpad_last: bool,
) -> Tuple[Fraction, Fraction]:
"""
Helper to calculate the start/end clip with backpad logic
"""
clip_start = Fraction(
max(last_clip_time - max(0, self._clip_duration - self._stride), 0)
)
delta = self._stride - self._clip_duration
last_end_time = -delta if last_clip_end_time is None else last_clip_end_time
clip_start = Fraction(last_end_time + delta)
clip_end = Fraction(clip_start + self._clip_duration)
if backpad_last:
buffer_amount = max(0, clip_end - video_duration)
@@ -163,11 +161,14 @@ def _clip_start_end(
return clip_start, clip_end

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: Optional[float],
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfo:
"""
Args:
last_clip_time (float): the last clip end time sampled from this video. This
last_clip_end_time (float): the last clip end time sampled from this video. This
should be 0.0 if the video hasn't had clips sampled yet.
video_duration: (float): the duration of the video that's being sampled in seconds
annotation (Dict): Not used by this sampler.
@@ -178,7 +179,7 @@ def __call__(
to be sampled.
"""
clip_start, clip_end = self._clip_start_end(
last_clip_time, video_duration, backpad_last=self._backpad_last
last_clip_end_time, video_duration, backpad_last=self._backpad_last
)

# if they both end at the same time - it's the last clip
@@ -188,7 +189,7 @@ def __call__(
if self._backpad_last:
is_last_clip = abs(next_clip_end - clip_end) < self._eps
else:
is_last_clip = next_clip_end > video_duration
is_last_clip = (next_clip_end - video_duration) > self._eps

clip_index = self._current_clip_index
self._current_clip_index += 1
@@ -221,14 +222,19 @@ def __init__(
self.truncation_duration = truncation_duration

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: float,
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfo:

truncated_video_duration = video_duration
if self.truncation_duration is not None:
truncated_video_duration = min(self.truncation_duration, video_duration)

return super().__call__(last_clip_time, truncated_video_duration, annotation)
return super().__call__(
last_clip_end_time, truncated_video_duration, annotation
)


class RandomClipSampler(ClipSampler):
@@ -237,11 +243,14 @@ class RandomClipSampler(ClipSampler):
"""

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: float,
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfo:
"""
Args:
last_clip_time (float): Not used for RandomClipSampler.
last_clip_end_time (float): Not used for RandomClipSampler.
video_duration: (float): the duration (in seconds) for the video that's
being sampled
annotation (Dict): Not used by this sampler.
@@ -268,7 +277,10 @@ def __init__(self, clip_duration: float, num_clips: int) -> None:
self._num_clips = num_clips

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: Optional[float],
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfoList:

(
@@ -291,7 +303,7 @@ def __call__(
clip_index_list[i],
aug_index_list[i],
is_last_clip_list[i],
) = super().__call__(last_clip_time, video_duration, annotation)
) = super().__call__(last_clip_end_time, video_duration, annotation)

return ClipInfoList(
clip_start_list,
@@ -316,14 +328,19 @@ def __init__(
self.truncation_duration = truncation_duration

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: Optional[float],
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfoList:

truncated_video_duration = video_duration
if self.truncation_duration is not None:
truncated_video_duration = min(self.truncation_duration, video_duration)

return super().__call__(last_clip_time, truncated_video_duration, annotation)
return super().__call__(
last_clip_end_time, truncated_video_duration, annotation
)


class ConstantClipsPerVideoSampler(ClipSampler):
@@ -340,11 +357,14 @@ def __init__(
self._augs_per_clip = augs_per_clip

def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
self,
last_clip_end_time: Optional[float],
video_duration: float,
annotation: Dict[str, Any],
) -> ClipInfo:
"""
Args:
last_clip_time (float): Not used for ConstantClipsPerVideoSampler.
last_clip_end_time (float): Not used for ConstantClipsPerVideoSampler.
video_duration: (float): the duration (in seconds) for the video that's
being sampled.
annotation (Dict): Not used by this sampler.
10 changes: 4 additions & 6 deletions pytorchvideo/data/labeled_video_dataset.py
@@ -81,7 +81,7 @@ def __init__(
# clip time in these variables.
self._loaded_video_label = None
self._loaded_clip = None
self._next_clip_start_time = 0.0
self._last_clip_end_time = None
self.video_path_handler = VideoPathHandler()

@property
@@ -153,9 +153,7 @@ def __next__(self) -> dict:
clip_index,
aug_index,
is_last_clip,
) = self._clip_sampler(
self._next_clip_start_time, video.duration, info_dict
)
) = self._clip_sampler(self._last_clip_end_time, video.duration, info_dict)

if isinstance(clip_start, list): # multi-clip in each sample

@@ -182,7 +180,7 @@ def __next__(self) -> dict:
if aug_index == 0:
self._loaded_clip = video.get_clip(clip_start, clip_end)

self._next_clip_start_time = clip_end
self._last_clip_end_time = clip_end

video_is_null = (
self._loaded_clip is None or self._loaded_clip["video"] is None
@@ -194,7 +192,7 @@ def __next__(self) -> dict:
# to sample a new video on the next iteration.
self._loaded_video_label[0].close()
self._loaded_video_label = None
self._next_clip_start_time = 0.0
self._last_clip_end_time = None
self._clip_sampler.reset()
if video_is_null:
logger.debug(
99 changes: 78 additions & 21 deletions tests/test_uniform_clip_sampler.py
@@ -1,8 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import copy
import unittest
from typing import Optional

import numpy as np
from parameterized import parameterized
from pytorchvideo.data.clip_sampling import UniformClipSampler

@@ -21,13 +23,9 @@ def _num_clips(
N = num_frames - window_size_frames
if N < 0:
return 1

result = N // stride_frames + 1

# handle padded frame
if backpad_last and N % stride_frames != 0:
result += 1
return result
result = int(N / stride_frames + 1)
pad = backpad_last and N % stride_frames != 0
return result + pad


class TestUniformClipSampler(unittest.TestCase):
@@ -69,6 +67,22 @@ class TestUniformClipSampler(unittest.TestCase):
# > half stride
(False, 30, 32, 24, 107 / 30, 4),
(True, 30, 32, 24, 107 / 30, 5),
(False, 30, 5, 1, 11 / 30, 7),
(True, 30, 5, 1, 11 / 30, 7),
# stride > window size
(False, 30, 1, 5, 11 / 30, 3),
(True, 30, 1, 5, 11 / 30, 3),
(True, 30, 1, 5, 1759 / 30, 353),
(False, 30, 3, 10, 132 / 30, 13),
(True, 30, 3, 10, 132 / 30, 14),
(False, 30, 6, 10, 111 / 30, 11),
(True, 30, 6, 10, 111 / 30, 12),
# stride <= window size
(False, 30, 10, 3, 132 / 30, 41),
(True, 30, 10, 3, 132 / 30, 42),
(False, 30, 10, 6, 111 / 30, 17),
(True, 30, 10, 6, 111 / 30, 18),
(True, 30, 1, 1, 132 / 30, 132),
]
)
def test_uniform_clip_sampler(
@@ -88,30 +102,68 @@ def test_uniform_clip_sampler(
stride_frames / fps if stride_frames is not None else None,
backpad_last=backpad_last,
)
predicted_n_clips = _num_clips(
video_length,
fps,
stride_frames=stride_frames if stride_frames is not None else window_size,
window_size_frames=window_size,
backpad_last=backpad_last,
)
self.assertEqual(predicted_n_clips, expected_number_of_clips)

s_prime = stride_frames if stride_frames is not None else window_size
expected_start_end_times = [
((i * s_prime) / fps, ((i * s_prime + window_size) / fps))
for i in range(expected_number_of_clips)
]
if expected_start_end_times[-1][1] - video_length > 1e-6:
expected_start_end_times[-1] = (
video_length - window_size / fps,
video_length,
)

self.assertTrue(
(
expected_start_end_times[-1][0] + (s_prime / fps) > video_length
or expected_start_end_times[-1][-1] + (s_prime / fps) > video_length
)
)
if len(expected_start_end_times) >= 2:
self.assertNotAlmostEqual(
expected_start_end_times[-2][0], expected_start_end_times[-1][0]
)
self.assertNotAlmostEqual(
expected_start_end_times[-2][1], expected_start_end_times[-1][1]
)

start_end_times = []

last_clip_time = 0
last_clip_time = None
annotation = {}
n_clips = 0
while True:
clip = sampler(last_clip_time, video_length, annotation)
last_clip_time = clip.clip_end_sec
n_clips += 1
last_clip_time = copy.deepcopy(clip.clip_end_sec)
n_frames = (clip.clip_end_sec - clip.clip_start_sec) * fps
int_n_frames = int(np.round(float(n_frames)))
self.assertAlmostEqual(float(int_n_frames), float(n_frames))
self.assertEqual(int_n_frames, window_size)

start_end_times.append(
(float(clip.clip_start_sec), float(clip.clip_end_sec))
)
if clip.is_last_clip:
break

# just in case we get an infinite loop
if n_clips > 2 * expected_number_of_clips:
if len(start_end_times) > 2 * expected_number_of_clips:
break

predicted_n_clips = _num_clips(
video_length,
fps,
stride_frames=stride_frames if stride_frames is not None else window_size,
window_size_frames=window_size,
backpad_last=backpad_last,
)
self.assertEqual(predicted_n_clips, expected_number_of_clips)
self.assertEqual(n_clips, expected_number_of_clips)
self.assertEqual(len(start_end_times), expected_number_of_clips)
for (start, end), (expected_start, expected_end) in zip(
start_end_times, expected_start_end_times
):
self.assertAlmostEqual(float(start), expected_start)
self.assertAlmostEqual(float(end), expected_end)

@parameterized.expand(
[
@@ -142,6 +194,11 @@ def test_uniform_clip_sampler(
(19 / 30, 30, 1, 32, True, 1),
(33 / 30, 30, 1, 32, False, 2),
(33 / 30, 30, 1, 32, True, 2),
(11 / 30, 30, 1, 5, False, 7),
(11 / 30, 30, 1, 5, True, 7),
(11 / 30, 30, 5, 1, False, 3),
(11 / 30, 30, 5, 1, True, 3),
(1759 / 30, 30, 5, 1, True, 353),
]
)
def test_num_clips(
