diff --git a/pytorchvideo/data/stream.py b/pytorchvideo/data/stream.py new file mode 100644 index 00000000..b5d976c6 --- /dev/null +++ b/pytorchvideo/data/stream.py @@ -0,0 +1,55 @@ + +from typing import Any, Callable, Dict, Iterable +from .video import Video + + +class Stream(Iterable): + """Create an iterable streaming clips of video data.""" + + def __init__( + self, + video: Video, + clip_duration: float, + clip_transform: Callable = None, + **get_clip_kwargs: Dict[str, Any], + ) -> None: + """ + Parameters + ---------- + video : Video + PyTorchVideo video instance to stream. + clip_duration : float + Maximum duration (in seconds) of the returned clip at every iteration. + clip_transform : Transform, optional + Optional transform to apply to each clip, by default None + get_clip_kwargs : Dict[str, Any] + Arguments to pass to the underlying video `get_clip` method. + """ + super().__init__() + self._clip_duration = clip_duration + self._clip_transform = clip_transform + self._video = video + self._get_clip_kwargs = get_clip_kwargs + + def __iter__(self): + current_time = 0.0 + while current_time < self._video.duration: + next_time = min( + self._video.duration, + current_time + self._clip_duration, + ) + video_data = self._video.get_clip( + current_time, + next_time, + **self._get_clip_kwargs, + ) + current_time = next_time + + if self._clip_transform: + video_data = self._clip_transform(video_data) + + yield video_data + + @property + def video(self): + return self._video diff --git a/pytorchvideo/transforms/transforms.py b/pytorchvideo/transforms/transforms.py index 7b981124..68f5d06c 100644 --- a/pytorchvideo/transforms/transforms.py +++ b/pytorchvideo/transforms/transforms.py @@ -1,8 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import pytorchvideo.transforms.functional +import pytorchvideo.data.video +import pytorchvideo.data.stream import torch import torchvision.transforms @@ -429,3 +431,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torchvision.transforms.Lambda( pytorchvideo.transforms.functional.div_255 )(x) + +class Streamed(torch.nn.Module): + """Apply a video transform in a streamed fashion (useful for large videos).""" + + def __init__( + self, + clip_duration: float, + clip_transform: Optional[Callable] = None, + return_iterable: bool = False, + **get_clip_kwargs, + ) -> None: + """ + Parameters + ---------- + clip_duration : float + Maximum duration (in seconds) of the transformed clip at every iteration. + clip_transform : Callable, optional + Optional transform to apply to each clip, by default None. + return_iterable : bool, optional + Decides if transform should return an iterable (more control over looping) or the iterated result, by default False. + """ + super().__init__() + self._clip_transform = clip_transform + self._clip_duration = clip_duration + self._return_iterable = return_iterable + self._get_clip_kwargs = get_clip_kwargs + + def __call__(self, video: pytorchvideo.data.video.Video) -> Union[pytorchvideo.data.stream.Stream, List[Any]]: + stream = pytorchvideo.data.stream.Stream( + video, + self._clip_duration, + self._clip_transform, + **self._get_clip_kwargs, + ) + if self._return_iterable: + return stream + return tuple(stream) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6b9a5fa6..da3c50db 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,11 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from typing import Tuple import unittest from collections import Counter from itertools import permutations import numpy as np import torch +from pytorchvideo.data.encoded_video import EncodedVideo from pytorchvideo.data.utils import thwc_to_cthw from pytorchvideo.transforms import ( ApplyTransformToKey, @@ -22,6 +24,7 @@ ShortSideScale, UniformCropVideo, UniformTemporalSubsample, + Streamed, create_video_transform, ) from pytorchvideo.transforms.functional import ( @@ -45,7 +48,7 @@ RandomCropVideo, RandomHorizontalFlipVideo, ) -from utils import create_dummy_video_frames, create_random_bbox +from utils import create_dummy_video_frames, create_random_bbox, temp_encoded_video class TestTransforms(unittest.TestCase): @@ -935,6 +938,56 @@ def test_permute(self): for p in list(permutations(range(0, 4))): self.assertTrue(video.permute(*p).equal(Permute(p)(video))) + def test_streamed(self): + fps = 4 + seconds = 5 + width = 12 + height = 8 + + def _check_result_shapes(result: Tuple): + self.assertEqual(len(result), seconds+1) + for i in range(seconds): + clip = result[i]["video"] + self.assertEqual(clip.shape[1], fps) + self.assertEqual(clip.shape[2], height) + self.assertEqual(clip.shape[3], width) + clip = result[-1]["video"] + self.assertEqual(clip.shape[1], fps//2) + self.assertEqual(clip.shape[2], height) + self.assertEqual(clip.shape[3], width) + + def _check_counter_result(test_case: unittest.TestCase, result: Tuple): + test_case.assertTrue(all((r["video"] == i).all().item() for i, r in enumerate(result))) + + class _CounterTransform: + def __init__(self) -> None: + self._counter = 0 + def __call__(self, video): + video = torch.full_like(video, fill_value=self._counter) + self._counter += 1 + return video + + with temp_encoded_video(fps*seconds+fps//2, fps=4, height=8, width=width) as (file_name, data): + video = EncodedVideo.from_path(file_name) + + # no transform + result = Streamed(clip_duration=1., clip_transform=None, return_iterable=False)(video) + _check_result_shapes(result) + + # simple transform (iterated through) + transform = ApplyTransformToKey("video", _CounterTransform()) + result = Streamed(clip_duration=1., clip_transform=transform, return_iterable=False)(video) + _check_result_shapes(result) + _check_counter_result(self, result) + + # simple transform (not iterated through) + transform = ApplyTransformToKey("video", _CounterTransform()) + result = Streamed(clip_duration=1., clip_transform=transform, return_iterable=True)(video) + self.assertRaises(TypeError, lambda: len(result)) + result = tuple(result) + _check_result_shapes(result) + _check_counter_result(self, result) + def test_video_transform_factory(self): # Test asserts/raises. self.assertRaises(TypeError, create_video_transform, mode="val", crop_size="s")