From 3dd7214f09608d39816819ca4f1b694fdbcae05d Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 1 Jul 2025 11:39:12 -0700 Subject: [PATCH 1/2] Error check in get_frames_at, adjust indices in get_frames_in_range --- src/torchcodec/decoders/_video_decoder.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 548f59c3..1dc49596 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -220,9 +220,14 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: Returns: FrameBatch: The frames at the given indices. """ - indices = [ - index if index >= 0 else index + self._num_frames for index in indices - ] + for i, index in enumerate(indices): + index = index if index >= 0 else index + self._num_frames + if not 0 <= index < self._num_frames: + raise IndexError( + f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})." + ) + else: + indices[i] = index data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices @@ -247,6 +252,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc Returns: FrameBatch: The frames within the specified range. """ + start = start if start >= 0 else start + self._num_frames + stop = min(stop if stop >= 0 else stop + self._num_frames, self._num_frames) if not 0 <= start < self._num_frames: raise IndexError( f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})." From 730361a88d5eed54d4f6d39f45f254fd5bc10cca Mon Sep 17 00:00:00 2001 From: Daniel Flores Date: Tue, 1 Jul 2025 11:40:07 -0700 Subject: [PATCH 2/2] Update get_frames_at_fails regex, add get_frames_in_range tests --- test/test_decoders.py | 67 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/test/test_decoders.py b/test/test_decoders.py index dcf9a158..09e6e3d6 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -549,13 +549,10 @@ def test_get_frames_at(self, device, seek_mode): def test_get_frames_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - expected_converted_index = -10000 + len(decoder) - with pytest.raises( - RuntimeError, match=f"Invalid frame index={expected_converted_index}" - ): + with pytest.raises(IndexError, match="Index -\\d+ is out of bounds"): decoder.get_frames_at([-10000]) - with pytest.raises(RuntimeError, match="Invalid frame index=390"): + with pytest.raises(IndexError, match="Index 390 is out of bounds"): decoder.get_frames_at([390]) with pytest.raises(RuntimeError, match="Expected a value of type"): @@ -772,6 +769,66 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("stream_index", [3, None]) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_in_range_tensor_index_semantics( + self, stream_index, device, seek_mode + ): + decoder = VideoDecoder( + NASA_VIDEO.path, + stream_index=stream_index, + device=device, + seek_mode=seek_mode, + ) + # slices with upper bound greater than len(decoder) are supported + ref_frames387_389 = NASA_VIDEO.get_frame_data_by_range( + start=387, stop=390, stream_index=stream_index + ).to(device) + frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) + print(f"{frames387_389.data.shape=}") + assert frames387_389.data.shape == torch.Size( + [ + 3, + NASA_VIDEO.get_num_color_channels(stream_index=stream_index), + NASA_VIDEO.get_height(stream_index=stream_index), + NASA_VIDEO.get_width(stream_index=stream_index), + ] + ) + assert_frames_equal(ref_frames387_389, frames387_389.data) + + # test that negative values in the range are supported + ref_frames386_389 = NASA_VIDEO.get_frame_data_by_range( + start=386, stop=390, stream_index=stream_index + ).to(device) + frames386_389 = decoder.get_frames_in_range(start=-4, stop=1000) + assert frames386_389.data.shape == torch.Size( + [ + 4, + NASA_VIDEO.get_num_color_channels(stream_index=stream_index), + NASA_VIDEO.get_height(stream_index=stream_index), + NASA_VIDEO.get_width(stream_index=stream_index), + ] + ) + assert_frames_equal(ref_frames386_389, frames386_389.data) + + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_in_range_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) + + with pytest.raises(IndexError, match="Start index 1000 is out of bounds"): + decoder.get_frames_in_range(start=1000, stop=10) + + with pytest.raises(IndexError, match="Start index -\\d+ is out of bounds"): + decoder.get_frames_in_range(start=-1000, stop=10) + + with pytest.raises( + IndexError, + match="Stop index \\(-\\d+\\) must not be less than the start index", + ): + decoder.get_frames_in_range(start=0, stop=-1000) + @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) @patch("torchcodec._core._metadata._get_stream_json_metadata")