diff --git a/pytorchvideo/data/clip_sampling.py b/pytorchvideo/data/clip_sampling.py
index b91a102b..2e8d5938 100644
--- a/pytorchvideo/data/clip_sampling.py
+++ b/pytorchvideo/data/clip_sampling.py
@@ -58,7 +58,7 @@ def __init__(self, clip_duration: Union[float, Fraction]) -> None:
     @abstractmethod
     def __call__(
         self,
-        last_clip_time: Union[float, Fraction],
+        last_clip_end_time: Union[float, Fraction],
         video_duration: Union[float, Fraction],
         annotation: Dict[str, Any],
     ) -> ClipInfo:
@@ -111,7 +111,7 @@ def __init__(
         Args:
             clip_duration (Union[float, Fraction]):
                 The length of the clip to sample (in seconds).
-            stride (floUnion[float, Fraction]at, optional):
+            stride (Union[float, Fraction], optional):
                 The amount of seconds to offset the next clip by
                 default value of None is equivalent to no stride => stride == clip_duration.
             eps (float):
@@ -124,35 +124,33 @@ def __init__(
                 (1.0667s). The clips will be (in frame numbers):

                 with backpad_last = False
-                - [0, 32]
+                - [0, 31]

                 with backpad_last = True
-                - [0, 32]
-                - [8, 40], this is "back-padded" from [16, 48] to fit the last window
+                - [0, 31]
+                - [8, 39], this is "back-padded" from [16, 48] to fit the last window

         Note that you can use Fraction for clip_duration and stride if you want to
         avoid float precision issue and need accurate frames in each clip.
         """
         super().__init__(clip_duration)
-        self._stride = stride if stride is not None else clip_duration
+        self._stride = stride if stride is not None else self._clip_duration
         self._eps = eps
         self._backpad_last = backpad_last

-        assert (
-            self._stride > 0 and self._stride <= clip_duration
-        ), f"stride must be >0 and <= clip_duration ({clip_duration})"
+        assert self._stride > 0, "stride must be positive"

     def _clip_start_end(
         self,
-        last_clip_time: Union[float, Fraction],
+        last_clip_end_time: Union[float, Fraction],
         video_duration: Union[float, Fraction],
         backpad_last: bool,
     ) -> Tuple[Fraction, Fraction]:
         """
         Helper to calculate the start/end clip with backpad logic
         """
-        clip_start = Fraction(
-            max(last_clip_time - max(0, self._clip_duration - self._stride), 0)
-        )
+        delta = self._stride - self._clip_duration
+        last_end_time = -delta if last_clip_end_time is None else last_clip_end_time
+        clip_start = Fraction(last_end_time + delta)
         clip_end = Fraction(clip_start + self._clip_duration)
         if backpad_last:
             buffer_amount = max(0, clip_end - video_duration)
@@ -163,11 +161,14 @@ def _clip_start_end(
         return clip_start, clip_end

     def __call__(
-        self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
+        self,
+        last_clip_end_time: Optional[float],
+        video_duration: float,
+        annotation: Dict[str, Any],
     ) -> ClipInfo:
         """
         Args:
-            last_clip_time (float): the last clip end time sampled from this video. This
+            last_clip_end_time (float): the last clip end time sampled from this video. This
                 should be 0.0 if the video hasn't had clips sampled yet.
             video_duration: (float): the duration of the video that's being sampled in seconds
             annotation (Dict): Not used by this sampler.
@@ -178,7 +179,7 @@ def __call__(
                 to be sampled.
""" clip_start, clip_end = self._clip_start_end( - last_clip_time, video_duration, backpad_last=self._backpad_last + last_clip_end_time, video_duration, backpad_last=self._backpad_last ) # if they both end at the same time - it's the last clip @@ -188,7 +189,7 @@ def __call__( if self._backpad_last: is_last_clip = abs(next_clip_end - clip_end) < self._eps else: - is_last_clip = next_clip_end > video_duration + is_last_clip = (next_clip_end - video_duration) > self._eps clip_index = self._current_clip_index self._current_clip_index += 1 @@ -221,14 +222,19 @@ def __init__( self.truncation_duration = truncation_duration def __call__( - self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] + self, + last_clip_end_time: float, + video_duration: float, + annotation: Dict[str, Any], ) -> ClipInfo: truncated_video_duration = video_duration if self.truncation_duration is not None: truncated_video_duration = min(self.truncation_duration, video_duration) - return super().__call__(last_clip_time, truncated_video_duration, annotation) + return super().__call__( + last_clip_end_time, truncated_video_duration, annotation + ) class RandomClipSampler(ClipSampler): @@ -237,11 +243,14 @@ class RandomClipSampler(ClipSampler): """ def __call__( - self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] + self, + last_clip_end_time: float, + video_duration: float, + annotation: Dict[str, Any], ) -> ClipInfo: """ Args: - last_clip_time (float): Not used for RandomClipSampler. + last_clip_end_time (float): Not used for RandomClipSampler. video_duration: (float): the duration (in seconds) for the video that's being sampled annotation (Dict): Not used by this sampler. @@ -268,7 +277,10 @@ def __init__(self, clip_duration: float, num_clips: int) -> None: self._num_clips = num_clips def __call__( - self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] + self, + last_clip_end_time: Optional[float], + video_duration: float, + annotation: Dict[str, Any], ) -> ClipInfoList: ( @@ -291,7 +303,7 @@ def __call__( clip_index_list[i], aug_index_list[i], is_last_clip_list[i], - ) = super().__call__(last_clip_time, video_duration, annotation) + ) = super().__call__(last_clip_end_time, video_duration, annotation) return ClipInfoList( clip_start_list, @@ -316,14 +328,19 @@ def __init__( self.truncation_duration = truncation_duration def __call__( - self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] + self, + last_clip_end_time: Optional[float], + video_duration: float, + annotation: Dict[str, Any], ) -> ClipInfoList: truncated_video_duration = video_duration if self.truncation_duration is not None: truncated_video_duration = min(self.truncation_duration, video_duration) - return super().__call__(last_clip_time, truncated_video_duration, annotation) + return super().__call__( + last_clip_end_time, truncated_video_duration, annotation + ) class ConstantClipsPerVideoSampler(ClipSampler): @@ -340,11 +357,14 @@ def __init__( self._augs_per_clip = augs_per_clip def __call__( - self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] + self, + last_clip_end_time: Optional[float], + video_duration: float, + annotation: Dict[str, Any], ) -> ClipInfo: """ Args: - last_clip_time (float): Not used for ConstantClipsPerVideoSampler. + last_clip_end_time (float): Not used for ConstantClipsPerVideoSampler. video_duration: (float): the duration (in seconds) for the video that's being sampled. 
             annotation (Dict): Not used by this sampler.
diff --git a/pytorchvideo/data/labeled_video_dataset.py b/pytorchvideo/data/labeled_video_dataset.py
index 32f29ea2..1dde136a 100644
--- a/pytorchvideo/data/labeled_video_dataset.py
+++ b/pytorchvideo/data/labeled_video_dataset.py
@@ -81,7 +81,7 @@ def __init__(
         # clip time in these variables.
         self._loaded_video_label = None
         self._loaded_clip = None
-        self._next_clip_start_time = 0.0
+        self._last_clip_end_time = None
         self.video_path_handler = VideoPathHandler()

     @property
@@ -153,9 +153,7 @@ def __next__(self) -> dict:
                 clip_index,
                 aug_index,
                 is_last_clip,
-            ) = self._clip_sampler(
-                self._next_clip_start_time, video.duration, info_dict
-            )
+            ) = self._clip_sampler(self._last_clip_end_time, video.duration, info_dict)

             if isinstance(clip_start, list):  # multi-clip in each sample
@@ -182,7 +180,7 @@ def __next__(self) -> dict:
             if aug_index == 0:
                 self._loaded_clip = video.get_clip(clip_start, clip_end)

-            self._next_clip_start_time = clip_end
+            self._last_clip_end_time = clip_end

             video_is_null = (
                 self._loaded_clip is None or self._loaded_clip["video"] is None
@@ -194,7 +192,7 @@ def __next__(self) -> dict:
                 # to sample a new video on the next iteration.
                 self._loaded_video_label[0].close()
                 self._loaded_video_label = None
-                self._next_clip_start_time = 0.0
+                self._last_clip_end_time = None
                 self._clip_sampler.reset()
                 if video_is_null:
                     logger.debug(
diff --git a/tests/test_uniform_clip_sampler.py b/tests/test_uniform_clip_sampler.py
index 145b9e45..7b02aa73 100644
--- a/tests/test_uniform_clip_sampler.py
+++ b/tests/test_uniform_clip_sampler.py
@@ -1,8 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

+import copy
 import unittest
 from typing import Optional

+import numpy as np
 from parameterized import parameterized
 from pytorchvideo.data.clip_sampling import UniformClipSampler

@@ -21,13 +23,9 @@ def _num_clips(
     N = num_frames - window_size_frames
     if N < 0:
         return 1
-
-    result = N // stride_frames + 1
-
-    # handle padded frame
-    if backpad_last and N % stride_frames != 0:
-        result += 1
-    return result
+    result = int(N / stride_frames + 1)
+    pad = backpad_last and N % stride_frames != 0
+    return result + pad


 class TestUniformClipSampler(unittest.TestCase):
@@ -69,6 +67,22 @@ class TestUniformClipSampler(unittest.TestCase):
             # > half stride
             (False, 30, 32, 24, 107 / 30, 4),
             (True, 30, 32, 24, 107 / 30, 5),
+            (False, 30, 5, 1, 11 / 30, 7),
+            (True, 30, 5, 1, 11 / 30, 7),
+            # stride > window size
+            (False, 30, 1, 5, 11 / 30, 3),
+            (True, 30, 1, 5, 11 / 30, 3),
+            (True, 30, 1, 5, 1759 / 30, 353),
+            (False, 30, 3, 10, 132 / 30, 13),
+            (True, 30, 3, 10, 132 / 30, 14),
+            (False, 30, 6, 10, 111 / 30, 11),
+            (True, 30, 6, 10, 111 / 30, 12),
+            # stride <= window size
+            (False, 30, 10, 3, 132 / 30, 41),
+            (True, 30, 10, 3, 132 / 30, 42),
+            (False, 30, 10, 6, 111 / 30, 17),
+            (True, 30, 10, 6, 111 / 30, 18),
+            (True, 30, 1, 1, 132 / 30, 132),
         ]
     )
     def test_uniform_clip_sampler(
@@ -88,30 +102,68 @@ def test_uniform_clip_sampler(
             stride_frames / fps if stride_frames is not None else None,
             backpad_last=backpad_last,
         )
+        predicted_n_clips = _num_clips(
+            video_length,
+            fps,
+            stride_frames=stride_frames if stride_frames is not None else window_size,
+            window_size_frames=window_size,
+            backpad_last=backpad_last,
+        )
+        self.assertEqual(predicted_n_clips, expected_number_of_clips)
+
+        s_prime = stride_frames if stride_frames is not None else window_size
+        expected_start_end_times = [
+            ((i * s_prime) / fps, ((i * s_prime + window_size) / fps))
+            for i in range(expected_number_of_clips)
+        ]
+        if expected_start_end_times[-1][1] - video_length > 1e-6:
+            expected_start_end_times[-1] = (
+                video_length - window_size / fps,
+                video_length,
+            )
+
+        self.assertTrue(
+            (
+                expected_start_end_times[-1][0] + (s_prime / fps) > video_length
+                or expected_start_end_times[-1][-1] + (s_prime / fps) > video_length
+            )
+        )
+        if len(expected_start_end_times) >= 2:
+            self.assertNotAlmostEqual(
+                expected_start_end_times[-2][0], expected_start_end_times[-1][0]
+            )
+            self.assertNotAlmostEqual(
+                expected_start_end_times[-2][1], expected_start_end_times[-1][1]
+            )
+
+        start_end_times = []
-        last_clip_time = 0
+        last_clip_time = None
         annotation = {}
-        n_clips = 0
         while True:
             clip = sampler(last_clip_time, video_length, annotation)
-            last_clip_time = clip.clip_end_sec
-            n_clips += 1
+            last_clip_time = copy.deepcopy(clip.clip_end_sec)
+            n_frames = (clip.clip_end_sec - clip.clip_start_sec) * fps
+            int_n_frames = int(np.round(float(n_frames)))
+            self.assertAlmostEqual(float(int_n_frames), float(n_frames))
+            self.assertEqual(int_n_frames, window_size)
+
+            start_end_times.append(
+                (float(clip.clip_start_sec), float(clip.clip_end_sec))
+            )
             if clip.is_last_clip:
                 break

             # just in case we get an infinite loop
-            if n_clips > 2 * expected_number_of_clips:
+            if len(start_end_times) > 2 * expected_number_of_clips:
                 break

-        predicted_n_clips = _num_clips(
-            video_length,
-            fps,
-            stride_frames=stride_frames if stride_frames is not None else window_size,
-            window_size_frames=window_size,
-            backpad_last=backpad_last,
-        )
-        self.assertEqual(predicted_n_clips, expected_number_of_clips)
-        self.assertEqual(n_clips, expected_number_of_clips)
+        self.assertEqual(len(start_end_times), expected_number_of_clips)
+        for (start, end), (expected_start, expected_end) in zip(
+            start_end_times, expected_start_end_times
+        ):
+            self.assertAlmostEqual(float(start), expected_start)
+            self.assertAlmostEqual(float(end), expected_end)

     @parameterized.expand(
         [
@@ -142,6 +194,11 @@ def test_uniform_clip_sampler(
             (19 / 30, 30, 1, 32, True, 1),
             (33 / 30, 30, 1, 32, False, 2),
             (33 / 30, 30, 1, 32, True, 2),
+            (11 / 30, 30, 1, 5, False, 7),
+            (11 / 30, 30, 1, 5, True, 7),
+            (11 / 30, 30, 5, 1, False, 3),
+            (11 / 30, 30, 5, 1, True, 3),
+            (1759 / 30, 30, 5, 1, True, 353),
         ]
     )
    def test_num_clips(
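Below is a minimal usage sketch, not part of the patch, showing how the renamed interface behaves after this change: the sampler is called with last_clip_end_time, which starts as None rather than 0.0, and a stride larger than the clip duration is now accepted. The fps, clip, stride, and duration values are made up for illustration.

from fractions import Fraction

from pytorchvideo.data.clip_sampling import UniformClipSampler

fps = 30
# A 10-frame stride with a 5-frame window: stride > clip_duration is allowed post-patch.
sampler = UniformClipSampler(
    clip_duration=Fraction(5, fps),
    stride=Fraction(10, fps),
    backpad_last=True,
)

video_duration = Fraction(132, fps)  # hypothetical 132-frame video at 30 fps
last_clip_end_time = None  # None (not 0.0) now marks "no clip sampled yet"
while True:
    clip = sampler(last_clip_end_time, video_duration, {})
    print(float(clip.clip_start_sec), float(clip.clip_end_sec), clip.is_last_clip)
    last_clip_end_time = clip.clip_end_sec
    if clip.is_last_clip:
        break
sampler.reset()  # clear per-video state, as LabeledVideoDataset does between videos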