diff --git a/src/spdl/io/_core.py b/src/spdl/io/_core.py index cb6faa546..e91c44126 100644 --- a/src/spdl/io/_core.py +++ b/src/spdl/io/_core.py @@ -90,6 +90,7 @@ def log_api_usage_once(_: str) -> None: CUDABuffer = _libspdl_cuda.CUDABuffer CUDAConfig = _libspdl_cuda.CUDAConfig + _NvDecDecoder = _libspdl_cuda.NvDecDecoder SourceType = str | Path | bytes | UintArray | Tensor @@ -723,7 +724,17 @@ def decode_packets_nvdec( # long_name attribute # # "QuickTime / MOV", "FLV (Flash Video)", "Matroska / WebM" - match packets.codec.name: + if (codec := packets.codec) is None: + raise ValueError( + "The packets must have codec. " + "The packets object does not have codec when streaming demuxing. " + "The `decode_packets_nvdec` function is for one-off decoding. " + "If you want to use NVDEC with streaming demuxing, instantiate " + "NvDecDecoder object manually." + ) + + assert codec is not None # for type-checker + match codec.name: case "h264": packets = apply_bsf(packets, "h264_mp4toannexb") case "hevc": @@ -878,10 +889,10 @@ class NvDecDecoder: buffer = spdl.io.nv12_to_rgb(buffer) """ - def __init__(self, decoder) -> None: + def __init__(self, decoder: "_NvDecDecoder") -> None: log_api_usage_once("spdl.io.NvDecDecoder") - self._decoder = decoder + self._decoder: "_NvDecDecoder" = decoder def init( self, @@ -967,13 +978,13 @@ def flush(self) -> "list[CUDABuffer]": _THREAD_LOCAL = threading.local() -def _get_decoder(): +def _get_decoder() -> "_NvDecDecoder": if not hasattr(_THREAD_LOCAL, "_decoder"): - _THREAD_LOCAL._decoder = _libspdl_cuda._nvdec_decoder() + _THREAD_LOCAL._decoder = _libspdl_cuda._nvdec_decoder() # pyre-ignore[16] return _THREAD_LOCAL._decoder -def nvdec_decoder(use_cache: bool = True) -> NvDecDecoder: +def nvdec_decoder(use_cache: bool = True) -> "NvDecDecoder": """Instantiate an :py:class:`NvDecDecoder` object. Args: @@ -1030,10 +1041,12 @@ def convert_frames( stacklevel=2, ) kwargs.pop("pin_memory") - return _libspdl.convert_frames(frames, storage=storage, **kwargs) + return _libspdl.convert_frames(frames, storage=storage, **kwargs) # pyre-ignore[6] -def convert_array(vals, storage: "CPUStorage | None" = None) -> "CPUBuffer": +def convert_array( + vals: "UintArray", storage: "CPUStorage | None" = None +) -> "CPUBuffer": """Convert the given array to buffer. This function is intended to be used when sending class labels (which is @@ -1051,7 +1064,7 @@ def convert_array(vals, storage: "CPUStorage | None" = None) -> "CPUBuffer": def create_reference_audio_frame( - array, sample_fmt: str, sample_rate: int, pts: int + array: "UintArray", sample_fmt: str, sample_rate: int, pts: int ) -> "AudioFrames": """Create an AudioFrame object which refers to the given array/tensor. @@ -1101,7 +1114,7 @@ def create_reference_audio_frame( def create_reference_video_frame( - array, pix_fmt: str, frame_rate: tuple[int, int], pts: int + array: "UintArray", pix_fmt: str, frame_rate: tuple[int, int], pts: int ) -> "VideoFrames": """Create an VideoFrame object which refers to the given array/tensor. @@ -1292,7 +1305,7 @@ class Muxer: """ def __init__(self, dst: str | Path, /, *, format: str | None = None) -> None: - self._muxer = _libspdl.muxer(str(dst), format=format) + self._muxer: "_libspdl.Muxer" = _libspdl.muxer(str(dst), format=format) self._open = False @overload @@ -1367,7 +1380,7 @@ def add_remux_stream(self, codec: "AudioCodec | VideoCodec") -> None: for packets in demuxer.streaming_demux(num_packets=5): muxer.write(0, packets) """ - self._muxer.add_remux_stream(codec) + self._muxer.add_remux_stream(codec) # pyre-ignore[6] def open(self, muxer_config: dict[str, str] | None = None) -> "Muxer": """Open the muxer (output file) for writing. @@ -1404,7 +1417,7 @@ def write(self, stream_index: int, packets: "AudioPackets | VideoPackets") -> No stream_index: The stream to write to. packets: Audio/video data. """ - self._muxer.write(stream_index, packets) + self._muxer.write(stream_index, packets) # pyre-ignore[6] def flush(self) -> None: """Notify the muxer that all the streams are written.