diff --git a/examples/flag.py b/examples/flag.py index 8d85fb4..a38a0b8 100644 --- a/examples/flag.py +++ b/examples/flag.py @@ -25,14 +25,14 @@ VIDEO_PTIME = 1 / 30 # 30fps -def generate_flag_frames(): +def generate_flag_frames() -> list[VideoFrame]: height, width = 480, 640 duck_np_array = cv2.imread(str(DUCK_JPEG_PATH)) data_bgr = cv2.resize(duck_np_array, (width, height)) # shrink and center it - M = np.float32([[0.5, 0, width / 4], [0, 0.5, height / 4]]) + M = np.array([[0.5, 0, width / 4], [0, 0.5, height / 4]]) data_bgr = cv2.warpAffine(data_bgr, M, (width, height)) # compute animation @@ -47,7 +47,8 @@ def generate_flag_frames(): map_y = id_y + 10 * np.sin(omega * id_x + phase) frames.append( VideoFrame.from_ndarray( - cv2.remap(data_bgr, map_x, map_y, cv2.INTER_LINEAR), format="bgr24" + cv2.remap(data_bgr, map_x, map_y, cv2.INTER_LINEAR).astype(np.uint8), + format="bgr24", ) ) diff --git a/vpx_rtp/codecs/vpx.py b/vpx_rtp/codecs/vpx.py index c5f1c93..d1181c7 100644 --- a/vpx_rtp/codecs/vpx.py +++ b/vpx_rtp/codecs/vpx.py @@ -5,7 +5,6 @@ from typing import List, Tuple, Type, TypeVar, cast from av import VideoFrame -from av.frame import Frame from av.packet import Packet from vpx_rtp.codecs._vpx import ffi, lib @@ -51,12 +50,12 @@ def number_of_threads(pixels: int, cpus: int) -> int: class VpxPayloadDescriptor: def __init__( self, - partition_start, - partition_id, - picture_id=None, - tl0picidx=None, - tid=None, - keyidx=None, + partition_start: int, + partition_id: int, + picture_id: int | None = None, + tl0picidx: int | None = None, + tid: tuple[int, int] | None = None, + keyidx: int | None = None, ) -> None: self.partition_start = partition_start self.partition_id = partition_id @@ -198,8 +197,8 @@ def __init__(self) -> None: def __del__(self) -> None: lib.vpx_codec_destroy(self.codec) - def decode(self, encoded_frame: JitterFrame) -> List[Frame]: - frames: List[Frame] = [] + def decode(self, encoded_frame: JitterFrame) -> list[VideoFrame]: + frames = list[VideoFrame]() result = lib.vpx_codec_decode( self.codec, encoded_frame.data, @@ -260,9 +259,8 @@ def __del__(self) -> None: lib.vpx_codec_destroy(self.codec) def encode( - self, frame: Frame, force_keyframe: bool = False + self, frame: VideoFrame, force_keyframe: bool = False ) -> Tuple[List[bytes], int]: - assert isinstance(frame, VideoFrame) if frame.format.name != "yuv420p": frame = frame.reformat(format="yuv420p") @@ -371,6 +369,8 @@ def encode( def pack(self, packet: Packet) -> Tuple[List[bytes], int]: payloads = self._packetize(bytes(packet), self.picture_id) + + assert packet.pts is not None, "Packet must have a PTS" timestamp = convert_timebase(packet.pts, packet.time_base, VIDEO_TIME_BASE) self.picture_id = (self.picture_id + 1) % (1 << 15) return payloads, timestamp diff --git a/vpx_rtp/jitterbuffer.py b/vpx_rtp/jitterbuffer.py index fcecae0..758fa76 100644 --- a/vpx_rtp/jitterbuffer.py +++ b/vpx_rtp/jitterbuffer.py @@ -67,6 +67,8 @@ def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]: remove = 0 timestamp = None + assert self._origin is not None, "origin must be set" + for count in range(self.capacity): pos = (self._origin + count) % self._capacity packet = self._packets[pos] @@ -98,7 +100,8 @@ def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]: def remove(self, count: int) -> None: assert count <= self._capacity - for i in range(count): + assert self._origin is not None, "origin must be set" + for _ in range(count): pos = self._origin % self._capacity self._packets[pos] = None self._origin = uint16_add(self._origin, 1) @@ -109,6 +112,7 @@ def smart_remove(self, count: int) -> bool: to prevent sending corrupted frames to the decoder. """ timestamp = None + assert self._origin is not None, "origin must be set" for i in range(self._capacity): pos = self._origin % self._capacity packet = self._packets[pos] diff --git a/vpx_rtp/rtcrtpparameters.py b/vpx_rtp/rtcrtpparameters.py index d12c2c7..ef62e78 100644 --- a/vpx_rtp/rtcrtpparameters.py +++ b/vpx_rtp/rtcrtpparameters.py @@ -11,24 +11,24 @@ class RTCRtpCodecParameters: codec settings. """ - mimeType: str - "The codec MIME media type/subtype, for instance `'audio/PCMU'`." - clockRate: int - "The codec clock rate expressed in Hertz." - channels: Optional[int] = None - "The number of channels supported (e.g. two for stereo)." - payloadType: Optional[int] = None - "The value that goes in the RTP Payload Type Field." - rtcpFeedback: List["RTCRtcpFeedback"] = field(default_factory=list) - "Transport layer and codec-specific feedback messages for this codec." - parameters: ParametersDict = field(default_factory=dict) - "Codec-specific parameters available for signaling." + mimeType: str # The codec MIME media type/subtype, for instance `'audio/PCMU'`. + clockRate: int # The codec clock rate expressed in Hertz. + payloadType: int # The value that goes in the RTP Payload Type Field. + channels: Optional[ + int + ] = None # The number of channels supported (e.g. two for stereo). + rtcpFeedback: List["RTCRtcpFeedback"] = field( + default_factory=list + ) # Transport layer and codec-specific feedback messages for this codec. + parameters: ParametersDict = field( + default_factory=dict + ) # Codec-specific parameters available for signaling. @property - def name(self): + def name(self) -> str: return self.mimeType.split("/")[1] - def __str__(self): + def __str__(self) -> str: s = f"{self.name}/{self.clockRate}" if self.channels == 2: s += "/2" diff --git a/vpx_rtp/rtp.py b/vpx_rtp/rtp.py index 0949f99..303d33d 100644 --- a/vpx_rtp/rtp.py +++ b/vpx_rtp/rtp.py @@ -3,6 +3,8 @@ from struct import pack, unpack, unpack_from from typing import Any, List, Optional, Tuple +from typing_extensions import Self + from vpx_rtp.rtcrtpparameters import RTCRtpParameters # used for NACK and retransmission @@ -92,7 +94,7 @@ def get(self, extension_profile: int, extension_value: bytes) -> HeaderExtension values.transport_sequence_number = unpack("!H", x_value)[0] return values - def set(self, values: HeaderExtensions): + def set(self, values: HeaderExtensions) -> tuple[int, bytes]: extensions = [] if values.mid is not None and self.__ids.mid: extensions.append((self.__ids.mid, values.mid.encode("utf8"))) @@ -259,6 +261,17 @@ def __init__( self.extensions = HeaderExtensions() self.payload = payload self.padding_size = 0 + self._parsed_data: bytes | None = None + + @property + def _data(self) -> bytes: + if self._parsed_data is None: + raise ValueError("RTP payload has not been parsed") + return self._parsed_data + + @_data.setter + def _data(self, value: bytes) -> None: + self._parsed_data = value def __repr__(self) -> str: return ( @@ -268,7 +281,9 @@ def __repr__(self) -> str: ) @classmethod - def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()): + def parse( + cls, data: bytes, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap() + ) -> Self: if len(data) < RTP_HEADER_LENGTH: raise ValueError( f"RTP packet length is less than {RTP_HEADER_LENGTH} bytes" @@ -321,7 +336,9 @@ def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()): return packet - def serialize(self, extensions_map=HeaderExtensionsMap()) -> bytes: + def serialize( + self, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap() + ) -> bytes: extension_profile, extension_value = extensions_map.set(self.extensions) has_extension = bool(extension_value)