-
Notifications
You must be signed in to change notification settings - Fork 45
Support negative indices in get_frames_in_range #746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -220,9 +220,14 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: | |
Returns: | ||
FrameBatch: The frames at the given indices. | ||
""" | ||
indices = [ | ||
index if index >= 0 else index + self._num_frames for index in indices | ||
] | ||
for i, index in enumerate(indices): | ||
index = index if index >= 0 else index + self._num_frames | ||
if not 0 <= index < self._num_frames: | ||
raise IndexError( | ||
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})." | ||
) | ||
else: | ||
indices[i] = index | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, we should avoid modifying the user-provided input inplace, as this can be surprising and have unintended effects. Let's keep the list comprehension logic? |
||
|
||
data, pts_seconds, duration_seconds = core.get_frames_at_indices( | ||
self._decoder, frame_indices=indices | ||
|
@@ -247,6 +252,8 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc | |
Returns: | ||
FrameBatch: The frames within the specified range. | ||
""" | ||
start = start if start >= 0 else start + self._num_frames | ||
stop = min(stop if stop >= 0 else stop + self._num_frames, self._num_frames) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to note that eventually, we'll want to support But we should do that separately. |
||
if not 0 <= start < self._num_frames: | ||
raise IndexError( | ||
f"Start index {start} is out of bounds; must be in the range [0, {self._num_frames})." | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -549,13 +549,10 @@ def test_get_frames_at(self, device, seek_mode): | |||||
def test_get_frames_at_fails(self, device, seek_mode): | ||||||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||||||
|
||||||
expected_converted_index = -10000 + len(decoder) | ||||||
with pytest.raises( | ||||||
RuntimeError, match=f"Invalid frame index={expected_converted_index}" | ||||||
): | ||||||
with pytest.raises(IndexError, match="Index -\\d+ is out of bounds"): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since pytest.raises accepts regex patterns, it is not necessary to calculate the index that will appear in the error. |
||||||
decoder.get_frames_at([-10000]) | ||||||
|
||||||
with pytest.raises(RuntimeError, match="Invalid frame index=390"): | ||||||
with pytest.raises(IndexError, match="Index 390 is out of bounds"): | ||||||
decoder.get_frames_at([390]) | ||||||
|
||||||
with pytest.raises(RuntimeError, match="Expected a value of type"): | ||||||
|
@@ -772,6 +769,66 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode): | |||||
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds | ||||||
) | ||||||
|
||||||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||||||
@pytest.mark.parametrize("stream_index", [3, None]) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I think we can remove the |
||||||
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) | ||||||
def test_get_frames_in_range_tensor_index_semantics( | ||||||
self, stream_index, device, seek_mode | ||||||
): | ||||||
decoder = VideoDecoder( | ||||||
NASA_VIDEO.path, | ||||||
stream_index=stream_index, | ||||||
device=device, | ||||||
seek_mode=seek_mode, | ||||||
) | ||||||
# slices with upper bound greater than len(decoder) are supported | ||||||
ref_frames387_389 = NASA_VIDEO.get_frame_data_by_range( | ||||||
start=387, stop=390, stream_index=stream_index | ||||||
).to(device) | ||||||
frames387_389 = decoder.get_frames_in_range(start=387, stop=1000) | ||||||
print(f"{frames387_389.data.shape=}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems like a debugging left-over:
Suggested change
|
||||||
assert frames387_389.data.shape == torch.Size( | ||||||
[ | ||||||
3, | ||||||
NASA_VIDEO.get_num_color_channels(stream_index=stream_index), | ||||||
NASA_VIDEO.get_height(stream_index=stream_index), | ||||||
NASA_VIDEO.get_width(stream_index=stream_index), | ||||||
] | ||||||
) | ||||||
assert_frames_equal(ref_frames387_389, frames387_389.data) | ||||||
|
||||||
# test that negative values in the range are supported | ||||||
ref_frames386_389 = NASA_VIDEO.get_frame_data_by_range( | ||||||
start=386, stop=390, stream_index=stream_index | ||||||
).to(device) | ||||||
frames386_389 = decoder.get_frames_in_range(start=-4, stop=1000) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's check for a negative
Suggested change
|
||||||
assert frames386_389.data.shape == torch.Size( | ||||||
[ | ||||||
4, | ||||||
NASA_VIDEO.get_num_color_channels(stream_index=stream_index), | ||||||
NASA_VIDEO.get_height(stream_index=stream_index), | ||||||
NASA_VIDEO.get_width(stream_index=stream_index), | ||||||
] | ||||||
) | ||||||
assert_frames_equal(ref_frames386_389, frames386_389.data) | ||||||
|
||||||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||||||
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) | ||||||
def test_get_frames_in_range_fails(self, device, seek_mode): | ||||||
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) | ||||||
|
||||||
with pytest.raises(IndexError, match="Start index 1000 is out of bounds"): | ||||||
decoder.get_frames_in_range(start=1000, stop=10) | ||||||
|
||||||
with pytest.raises(IndexError, match="Start index -\\d+ is out of bounds"): | ||||||
decoder.get_frames_in_range(start=-1000, stop=10) | ||||||
|
||||||
with pytest.raises( | ||||||
IndexError, | ||||||
match="Stop index \\(-\\d+\\) must not be less than the start index", | ||||||
): | ||||||
decoder.get_frames_in_range(start=0, stop=-1000) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I considered possible edge cases for As a result, we do not need to explicitly check that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Out of precaution and to be pedantic, maybe also test when both start and stop are positive:
so that |
||||||
|
||||||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||||||
@pytest.mark.parametrize("seek_mode", ("exact", "approximate")) | ||||||
@patch("torchcodec._core._metadata._get_stream_json_metadata") | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realize I'm sort of going back on the discussion from https://github.com/pytorch/torchcodec/pull/743/files#r2167398761, so I apologize for not bringing that up earlier. But I wonder if we actually want to do the validation and the conversion in Python. For performance reasons, it might be better to keep everything in C++, as the
indices
list can potentially be quite long, especially when sampling a large number of frames for long videos. I suspect we can easily havelen(indices) > Nx1000
, typically when using our samplers on long videos:torchcodec/src/torchcodec/samplers/_index_based.py
Line 180 in 103f714
As a somewhat related concern I think we'll eventually want to support
indices
not just as a list, but also as a tensor, e.g. for users who create their own sampling strategy usingtorch
utilities. And it'd be great if we could avoid copying that input tensor before passing it to the underlying decoding ops, which we can only do if we run the validation/conversion in C++.In any case, we can and should merge this PR as-is regardless of our conclusion, because we are already doing such validation in
main
, and this PR doesn't add much on top of the existing one.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed on both points - when I wrote the comment on #743, I was also wondering about performance, and figured let's go for consistency first, then deal with performance later. We can make the checking basically zero-cost on the C++ side if we do it as we iterate through the indices - but we'll want to not do a
TORCH_CHECK()
, but throw astd::out_of_range
so it becomes the right thing on the Python side. We can do all of that on a follow-up PR.