LTX 2 Improve encode_video by Accepting More Input Types
#13057
base: main
Changes from all commits: d5d2910 · 2e18d2c · 857735f · cd60b3d · 7354055

src/diffusers/pipelines/ltx2/export_utils.py
@@ -13,10 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Generator, Iterator
 from fractions import Fraction
-from typing import Optional
+from typing import List, Optional, Tuple, Union

 import numpy as np
+import PIL.Image
 import torch
+from tqdm import tqdm

 from ...utils import is_av_available
@@ -101,11 +105,52 @@ def _write_audio(

 def encode_video(
-    video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
+    video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
+    fps: int,
+    audio: Optional[torch.Tensor],
+    audio_sample_rate: Optional[int],
+    output_path: str,
+    video_chunks_number: int = 1,
 ) -> None:
-    video_np = video.cpu().numpy()
+    """
+    Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
+    https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
+
+    Args:
+        video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor` or `Iterator[torch.Tensor]`):
+            A video of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
+            input is an `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what
+            pipelines usually return with `output_type="np"`). If the input is an `Iterator[torch.Tensor]`, it
+            is treated as an iterator over video chunks.
+        fps (`int`):
+            The frames per second (FPS) of the encoded video.
+        audio (`torch.Tensor`, *optional*):
+            An audio waveform of shape [audio_channels, samples].
+        audio_sample_rate (`int`, *optional*):
+            The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
+        output_path (`str`):
+            The path to save the encoded video to.
+        video_chunks_number (`int`, *optional*, defaults to `1`):
+            The number of chunks to split the video into for encoding. Each chunk will be encoded separately.
+            The number of chunks to use often depends on the tiling config for the video VAE.
+    """
+    if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        # Pipeline output_type="pil"
+        video_frames = [np.array(frame) for frame in video]
+        video = np.stack(video_frames, axis=0)
+        video = torch.from_numpy(video)
+    elif isinstance(video, np.ndarray):
+        # Pipeline output_type="np"
+        is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
+        if np.all(is_denormalized):
+            video = (video * 255).round().astype("uint8")
+        video = torch.from_numpy(video)
+
+    if isinstance(video, torch.Tensor):
+        video = iter([video])
+
+    first_chunk = next(video)

-    _, height, width, _ = video_np.shape
+    _, height, width, _ = first_chunk.shape

     container = av.open(output_path, mode="w")
     stream = container.add_stream("libx264", rate=int(fps))
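For illustration, a minimal sketch of how the widened signature can be exercised with typical pipeline outputs. The random data, output file names, and chunk split are illustrative only; the import path follows the permalink cited in the review below:

    import numpy as np
    import PIL.Image
    import torch

    from diffusers.pipelines.ltx2.export_utils import encode_video

    # Stand-ins for pipeline outputs, shape [frames, height, width, channels].
    frames_u8 = np.random.randint(0, 256, (9, 64, 64, 3), dtype=np.uint8)

    video_pil = [PIL.Image.fromarray(f) for f in frames_u8]  # output_type="pil"
    video_np = frames_u8.astype(np.float32) / 255.0          # output_type="np", floats in [0, 1]
    video_pt = torch.from_numpy(frames_u8)                   # uint8 tensor in [0, 255]

    encode_video(video_pil, fps=24, audio=None, audio_sample_rate=None, output_path="pil.mp4")
    encode_video(video_np, fps=24, audio=None, audio_sample_rate=None, output_path="np.mp4")
    encode_video(
        iter([video_pt[:5], video_pt[5:]]),  # iterator of chunks, e.g. from tiled decoding
        fps=24,
        audio=None,
        audio_sample_rate=None,
        output_path="chunked.mp4",
        video_chunks_number=2,
    )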
@@ -119,10 +164,18 @@ def encode_video(

     audio_stream = _prepare_audio_stream(container, audio_sample_rate)

-    for frame_array in video_np:
-        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
-        for packet in stream.encode(frame):
-            container.mux(packet)
+    def all_tiles(
+        first_chunk: torch.Tensor, tiles_generator: Generator[Tuple[torch.Tensor, int], None, None]
+    ) -> Generator[Tuple[torch.Tensor, int], None, None]:
+        yield first_chunk
+        yield from tiles_generator
+
+    for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
Member: WDYT of getting rid of `all_tiles` and using `itertools.chain` instead?

    from itertools import chain

    for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number):
        video_chunk_cpu = video_chunk.to("cpu").numpy()
        for frame_array in video_chunk_cpu:
            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
            for packet in stream.encode(frame):
                container.mux(packet)
Collaborator (Author): This does the right thing but appears not to work well with `tqdm`:

    33%|████████████████████████████████▋ | 1/3 [00:04<00:09, 4.57s/it]
Member: Okay!
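For what it's worth, the two spellings yield the same sequence of chunks; a self-contained check with toy tensors (not the PR's data):

    from itertools import chain

    import torch

    def all_tiles(first_chunk, tiles_generator):
        # Same helper as in the diff above, minus the type hints.
        yield first_chunk
        yield from tiles_generator

    chunks = [torch.zeros(2, 4, 4, 3), torch.ones(2, 4, 4, 3)]

    it = iter(chunks)
    a = list(all_tiles(next(it), it))

    it = iter(chunks)
    b = list(chain([next(it)], it))

    assert len(a) == len(b) and all(torch.equal(x, y) for x, y in zip(a, b))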
Collaborator (Author): Actually, I think #13057 (comment) is wrong - as we generally supply a single tensor:

    diffusers/src/diffusers/pipelines/ltx2/export_utils.py, lines 146 to 149 in 857735f

So when we call [...], I think the underlying difference is that the original LTX 2 code will return an iterator over decoded tiles when performing tiled VAE decoding, whereas we will return the whole decoded output as a single tensor with the tiles stitched back together. So maybe it doesn't make sense to support `video_chunks_number` here.
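To make the distinction concrete, a toy sketch of the two decoding styles being contrasted; the `repeat_interleave` call is a stand-in for real VAE decoding, and the shapes are made up:

    import torch

    def decode_tiled_iterator(latents):
        # LTX-2 style (schematically): yield each decoded tile as it is produced.
        for tile in latents.split(2, dim=0):
            yield tile.repeat_interleave(8, dim=0)  # stand-in for decoding one tile

    def decode_stitched(latents):
        # diffusers style (schematically): decode everything, return one stitched tensor.
        return torch.cat(list(decode_tiled_iterator(latents)), dim=0)

    latents = torch.randn(6, 4, 8, 8)  # toy [latent_frames, channels, h, w]
    assert decode_stitched(latents).shape[0] == 48  # one tensor, not a stream of chunks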
+        video_chunk_cpu = video_chunk.to("cpu").numpy()
+        for frame_array in video_chunk_cpu:
+            frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
Member: Should we let the users control this format? 👀

Collaborator (Author): I think we could allow the users to specify the format, but this would be in tension with value checking as suggested in #13057 (comment): for example, if we always convert denormalized inputs with values in [0, 1] to `uint8` in [0, 255], that conversion only makes sense for an 8-bit format like `rgb24`. We could conditionally convert based on the supplied `video_format`:

    elif isinstance(video, np.ndarray):
        # Pipeline output_type="np"
        is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
        if np.all(is_denormalized) and video_format == "rgb24":
            video = (video * 255).round().astype("uint8")
        else:
            logger.warning(
                f"The video will be encoded using the input `video` values as-is with format {video_format}."
                " Make sure the values are in the proper range for the supplied format."
            )
        video = torch.from_numpy(video)

An alternative would be to only support `rgb24`.

EDIT: the right terminology here might be "pixel format" rather than "video format".
Member: Okay, let's go with this.
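For context on what user control might look like, a sketch of threading a pixel format through to PyAV; the `pixel_format` parameter and helper are hypothetical, not part of this PR:

    import av
    import numpy as np

    def frame_from_array(frame_array: np.ndarray, pixel_format: str = "rgb24") -> av.VideoFrame:
        # "rgb24" expects uint8 in [0, 255]; a 16-bit format such as "rgb48le"
        # expects uint16, so the [0, 1] -> *255 conversion above would be wrong for it.
        return av.VideoFrame.from_ndarray(frame_array, format=pixel_format)

    frame = frame_from_array(np.zeros((64, 64, 3), dtype=np.uint8))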
+            for packet in stream.encode(frame):
+                container.mux(packet)

     # Flush encoder
     for packet in stream.encode():
Member: When is this option helpful?

Collaborator (Author): The original LTX-2 code will use a `video_chunks_number` calculated from the video VAE tiling config, for example in two stage inference:
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py#L257

For the default `num_frames` value of `121` and default tiling config `TilingConfig.default()`, I believe this works out to `3` chunks. The idea seems to be that the chunks correspond to each tiled stride when decoding.

In practice, I haven't had any issues with the current code, which is equivalent to just using one chunk. I don't fully understand the reasoning behind why the original code supports it; my guess is that it is useful for very long videos or if there are compute constraints.
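A rough sketch of the kind of arithmetic this could involve; the stride and temporal compression values below are guesses for illustration, not read from `TilingConfig.default()`:

    import math

    def chunks_from_tiling(num_frames: int, latent_stride: int, temporal_scale: int) -> int:
        # Causal video VAEs typically map num_frames to 1 + (num_frames - 1) / scale latent frames.
        latent_frames = 1 + (num_frames - 1) // temporal_scale
        # One chunk per tiled stride over the latent frame axis.
        return math.ceil(latent_frames / latent_stride)

    # 121 frames with 8x temporal compression -> 16 latent frames; a stride of 6
    # would give ceil(16 / 6) = 3 chunks, matching the count mentioned above.
    print(chunks_from_tiling(121, latent_stride=6, temporal_scale=8))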
See #13057 (comment) for discussion about some complications for supporting `video_chunks_number`.
video_chunks_number.