diff --git a/pytorchvideo/transforms/transforms.py b/pytorchvideo/transforms/transforms.py
index 7b98112..be7b2c1 100644
--- a/pytorchvideo/transforms/transforms.py
+++ b/pytorchvideo/transforms/transforms.py
@@ -339,6 +339,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.permute(*self._dims)
 
 
+class Grayscale(torchvision.transforms.Grayscale):
+    """
+    Converts RGB frames of (CTHW) video clip to grayscale.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): video tensor with shape (C, T, H, W).
+        """
+        vid = x.permute(1, 0, 2, 3)
+        vid = super().forward(vid)
+        vid = vid.permute(1, 0, 2, 3)
+        return vid
+
+
 class OpSampler(torch.nn.Module):
     """
     Given a list of transforms with weights, OpSampler applies weighted sampling to
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 46a0782..bc8f3b1 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -10,8 +10,8 @@ from pytorchvideo.transforms import (
     ApplyTransformToKey,
     AugMix,
-    create_video_transform,
     CutMix,
+    Grayscale,
     MixUp,
     MixVideo,
     Normalize,
@@ -23,6 +23,7 @@ ShortSideScale,
     UniformCropVideo,
     UniformTemporalSubsample,
+    create_video_transform,
 )
 from pytorchvideo.transforms.functional import (
     clip_boxes_to_image,
@@ -38,7 +39,7 @@ uniform_temporal_subsample,
     uniform_temporal_subsample_repeated,
 )
-from torchvision.transforms import Compose
+from torchvision.transforms import Compose, Grayscale as FrameGrayscale
 from torchvision.transforms._transforms_video import (
     CenterCropVideo,
     NormalizeVideo,
@@ -935,6 +936,26 @@ def test_permute(self):
         for p in list(permutations(range(0, 4))):
             self.assertTrue(video.permute(*p).equal(Permute(p)(video)))
 
+    def test_grayscale(self):
+        video = thwc_to_cthw(create_dummy_video_frames(10, 30, 40)).to(
+            dtype=torch.float32
+        )
+        transform = Grayscale()
+
+        actual = transform(video)
+        self.assertEqual(actual.shape[0], 1)
+
+        # Apply frame-wise grayscale.
+        framewise_gray = []
+        for n in range(video.shape[1]):
+            frame = video[:, n, :].squeeze(1)
+            frame_gray = FrameGrayscale()(frame)
+            framewise_gray.append(frame_gray)
+
+        framewise_gray = torch.stack(framewise_gray).transpose(0, 1)
+        self.assertEqual(actual.shape, framewise_gray.shape)
+        self.assertTrue(torch.equal(actual, framewise_gray))
+
     def test_video_transform_factory(self):
         # Test asserts/raises.
         self.assertRaises(TypeError, create_video_transform, mode="val", crop_size="s")
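
For reference, a minimal usage sketch of the transform added above (not part of the diff). It assumes `Grayscale` is re-exported from `pytorchvideo.transforms`, as the test import implies; the corresponding `__init__.py` change is not shown in this diff.

```python
import torch
from pytorchvideo.transforms import Grayscale  # assumes the package re-exports the new class

# Dummy RGB clip in (C, T, H, W) layout: 3 channels, 8 frames, 32x32 pixels.
clip = torch.rand(3, 8, 32, 32)

# The transform permutes to (T, C, H, W), applies torchvision's per-frame
# grayscale conversion, and permutes back, so the channel dim collapses to 1.
gray = Grayscale()(clip)
print(gray.shape)  # torch.Size([1, 8, 32, 32])
```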