Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions src/spdl/io/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def log_api_usage_once(_: str) -> None:

CUDABuffer = _libspdl_cuda.CUDABuffer
CUDAConfig = _libspdl_cuda.CUDAConfig
_NvDecDecoder = _libspdl_cuda.NvDecDecoder

SourceType = str | Path | bytes | UintArray | Tensor

Expand Down Expand Up @@ -723,7 +724,17 @@ def decode_packets_nvdec(
# long_name attribute
#
# "QuickTime / MOV", "FLV (Flash Video)", "Matroska / WebM"
match packets.codec.name:
if (codec := packets.codec) is None:
raise ValueError(
"The packets must have codec. "
"The packets object does not have codec when streaming demuxing. "
"The `decode_packets_nvdec` function is for one-off decoding. "
"If you want to use NVDEC with streaming demuxing, instantiate "
"NvDecDecoder object manually."
)

assert codec is not None # for type-checker
match codec.name:
case "h264":
packets = apply_bsf(packets, "h264_mp4toannexb")
case "hevc":
Expand Down Expand Up @@ -878,10 +889,10 @@ class NvDecDecoder:
buffer = spdl.io.nv12_to_rgb(buffer)
"""

def __init__(self, decoder) -> None:
def __init__(self, decoder: "_NvDecDecoder") -> None:
log_api_usage_once("spdl.io.NvDecDecoder")

self._decoder = decoder
self._decoder: "_NvDecDecoder" = decoder

def init(
self,
Expand Down Expand Up @@ -967,13 +978,13 @@ def flush(self) -> "list[CUDABuffer]":
_THREAD_LOCAL = threading.local()


def _get_decoder():
def _get_decoder() -> "_NvDecDecoder":
if not hasattr(_THREAD_LOCAL, "_decoder"):
_THREAD_LOCAL._decoder = _libspdl_cuda._nvdec_decoder()
_THREAD_LOCAL._decoder = _libspdl_cuda._nvdec_decoder() # pyre-ignore[16]
return _THREAD_LOCAL._decoder


def nvdec_decoder(use_cache: bool = True) -> NvDecDecoder:
def nvdec_decoder(use_cache: bool = True) -> "NvDecDecoder":
"""Instantiate an :py:class:`NvDecDecoder` object.

Args:
Expand Down Expand Up @@ -1030,10 +1041,12 @@ def convert_frames(
stacklevel=2,
)
kwargs.pop("pin_memory")
return _libspdl.convert_frames(frames, storage=storage, **kwargs)
return _libspdl.convert_frames(frames, storage=storage, **kwargs) # pyre-ignore[6]


def convert_array(vals, storage: "CPUStorage | None" = None) -> "CPUBuffer":
def convert_array(
vals: "UintArray", storage: "CPUStorage | None" = None
) -> "CPUBuffer":
"""Convert the given array to buffer.

This function is intended to be used when sending class labels (which is
Expand All @@ -1051,7 +1064,7 @@ def convert_array(vals, storage: "CPUStorage | None" = None) -> "CPUBuffer":


def create_reference_audio_frame(
array, sample_fmt: str, sample_rate: int, pts: int
array: "UintArray", sample_fmt: str, sample_rate: int, pts: int
) -> "AudioFrames":
"""Create an AudioFrame object which refers to the given array/tensor.

Expand Down Expand Up @@ -1101,7 +1114,7 @@ def create_reference_audio_frame(


def create_reference_video_frame(
array, pix_fmt: str, frame_rate: tuple[int, int], pts: int
array: "UintArray", pix_fmt: str, frame_rate: tuple[int, int], pts: int
) -> "VideoFrames":
"""Create an VideoFrame object which refers to the given array/tensor.

Expand Down Expand Up @@ -1292,7 +1305,7 @@ class Muxer:
"""

def __init__(self, dst: str | Path, /, *, format: str | None = None) -> None:
self._muxer = _libspdl.muxer(str(dst), format=format)
self._muxer: "_libspdl.Muxer" = _libspdl.muxer(str(dst), format=format)
self._open = False

@overload
Expand Down Expand Up @@ -1367,7 +1380,7 @@ def add_remux_stream(self, codec: "AudioCodec | VideoCodec") -> None:
for packets in demuxer.streaming_demux(num_packets=5):
muxer.write(0, packets)
"""
self._muxer.add_remux_stream(codec)
self._muxer.add_remux_stream(codec) # pyre-ignore[6]

def open(self, muxer_config: dict[str, str] | None = None) -> "Muxer":
"""Open the muxer (output file) for writing.
Expand Down Expand Up @@ -1404,7 +1417,7 @@ def write(self, stream_index: int, packets: "AudioPackets | VideoPackets") -> No
stream_index: The stream to write to.
packets: Audio/video data.
"""
self._muxer.write(stream_index, packets)
self._muxer.write(stream_index, packets) # pyre-ignore[6]

def flush(self) -> None:
"""Notify the muxer that all the streams are written.
Expand Down
Loading