Skip to content

Commit

Permalink
fix: passes typing
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuagruenstein committed Jun 2, 2024
1 parent 3138089 commit 43096f6
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 32 deletions.
7 changes: 4 additions & 3 deletions examples/flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
VIDEO_PTIME = 1 / 30 # 30fps


def generate_flag_frames():
def generate_flag_frames() -> list[VideoFrame]:
height, width = 480, 640

duck_np_array = cv2.imread(str(DUCK_JPEG_PATH))
data_bgr = cv2.resize(duck_np_array, (width, height))

# shrink and center it
M = np.float32([[0.5, 0, width / 4], [0, 0.5, height / 4]])
M = np.array([[0.5, 0, width / 4], [0, 0.5, height / 4]])
data_bgr = cv2.warpAffine(data_bgr, M, (width, height))

# compute animation
Expand All @@ -47,7 +47,8 @@ def generate_flag_frames():
map_y = id_y + 10 * np.sin(omega * id_x + phase)
frames.append(
VideoFrame.from_ndarray(
cv2.remap(data_bgr, map_x, map_y, cv2.INTER_LINEAR), format="bgr24"
cv2.remap(data_bgr, map_x, map_y, cv2.INTER_LINEAR).astype(np.uint8),
format="bgr24",
)
)

Expand Down
22 changes: 11 additions & 11 deletions vpx_rtp/codecs/vpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import List, Tuple, Type, TypeVar, cast

from av import VideoFrame
from av.frame import Frame
from av.packet import Packet

from vpx_rtp.codecs._vpx import ffi, lib
Expand Down Expand Up @@ -51,12 +50,12 @@ def number_of_threads(pixels: int, cpus: int) -> int:
class VpxPayloadDescriptor:
def __init__(
self,
partition_start,
partition_id,
picture_id=None,
tl0picidx=None,
tid=None,
keyidx=None,
partition_start: int,
partition_id: int,
picture_id: int | None = None,
tl0picidx: int | None = None,
tid: tuple[int, int] | None = None,
keyidx: int | None = None,
) -> None:
self.partition_start = partition_start
self.partition_id = partition_id
Expand Down Expand Up @@ -198,8 +197,8 @@ def __init__(self) -> None:
def __del__(self) -> None:
lib.vpx_codec_destroy(self.codec)

def decode(self, encoded_frame: JitterFrame) -> List[Frame]:
frames: List[Frame] = []
def decode(self, encoded_frame: JitterFrame) -> list[VideoFrame]:
frames = list[VideoFrame]()
result = lib.vpx_codec_decode(
self.codec,
encoded_frame.data,
Expand Down Expand Up @@ -260,9 +259,8 @@ def __del__(self) -> None:
lib.vpx_codec_destroy(self.codec)

def encode(
self, frame: Frame, force_keyframe: bool = False
self, frame: VideoFrame, force_keyframe: bool = False
) -> Tuple[List[bytes], int]:
assert isinstance(frame, VideoFrame)
if frame.format.name != "yuv420p":
frame = frame.reformat(format="yuv420p")

Expand Down Expand Up @@ -371,6 +369,8 @@ def encode(

def pack(self, packet: Packet) -> Tuple[List[bytes], int]:
payloads = self._packetize(bytes(packet), self.picture_id)

assert packet.pts is not None, "Packet must have a PTS"
timestamp = convert_timebase(packet.pts, packet.time_base, VIDEO_TIME_BASE)
self.picture_id = (self.picture_id + 1) % (1 << 15)
return payloads, timestamp
Expand Down
6 changes: 5 additions & 1 deletion vpx_rtp/jitterbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]:
remove = 0
timestamp = None

assert self._origin is not None, "origin must be set"

for count in range(self.capacity):
pos = (self._origin + count) % self._capacity
packet = self._packets[pos]
Expand Down Expand Up @@ -98,7 +100,8 @@ def _remove_frame(self, sequence_number: int) -> Optional[JitterFrame]:

def remove(self, count: int) -> None:
assert count <= self._capacity
for i in range(count):
assert self._origin is not None, "origin must be set"
for _ in range(count):
pos = self._origin % self._capacity
self._packets[pos] = None
self._origin = uint16_add(self._origin, 1)
Expand All @@ -109,6 +112,7 @@ def smart_remove(self, count: int) -> bool:
to prevent sending corrupted frames to the decoder.
"""
timestamp = None
assert self._origin is not None, "origin must be set"
for i in range(self._capacity):
pos = self._origin % self._capacity
packet = self._packets[pos]
Expand Down
28 changes: 14 additions & 14 deletions vpx_rtp/rtcrtpparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ class RTCRtpCodecParameters:
codec settings.
"""

mimeType: str
"The codec MIME media type/subtype, for instance `'audio/PCMU'`."
clockRate: int
"The codec clock rate expressed in Hertz."
channels: Optional[int] = None
"The number of channels supported (e.g. two for stereo)."
payloadType: Optional[int] = None
"The value that goes in the RTP Payload Type Field."
rtcpFeedback: List["RTCRtcpFeedback"] = field(default_factory=list)
"Transport layer and codec-specific feedback messages for this codec."
parameters: ParametersDict = field(default_factory=dict)
"Codec-specific parameters available for signaling."
mimeType: str # The codec MIME media type/subtype, for instance `'audio/PCMU'`.
clockRate: int # The codec clock rate expressed in Hertz.
payloadType: int # The value that goes in the RTP Payload Type Field.
channels: Optional[
int
] = None # The number of channels supported (e.g. two for stereo).
rtcpFeedback: List["RTCRtcpFeedback"] = field(
default_factory=list
) # Transport layer and codec-specific feedback messages for this codec.
parameters: ParametersDict = field(
default_factory=dict
) # Codec-specific parameters available for signaling.

@property
def name(self):
def name(self) -> str:
return self.mimeType.split("/")[1]

def __str__(self):
def __str__(self) -> str:
s = f"{self.name}/{self.clockRate}"
if self.channels == 2:
s += "/2"
Expand Down
23 changes: 20 additions & 3 deletions vpx_rtp/rtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from struct import pack, unpack, unpack_from
from typing import Any, List, Optional, Tuple

from typing_extensions import Self

from vpx_rtp.rtcrtpparameters import RTCRtpParameters

# used for NACK and retransmission
Expand Down Expand Up @@ -92,7 +94,7 @@ def get(self, extension_profile: int, extension_value: bytes) -> HeaderExtension
values.transport_sequence_number = unpack("!H", x_value)[0]
return values

def set(self, values: HeaderExtensions):
def set(self, values: HeaderExtensions) -> tuple[int, bytes]:
extensions = []
if values.mid is not None and self.__ids.mid:
extensions.append((self.__ids.mid, values.mid.encode("utf8")))
Expand Down Expand Up @@ -259,6 +261,17 @@ def __init__(
self.extensions = HeaderExtensions()
self.payload = payload
self.padding_size = 0
self._parsed_data: bytes | None = None

@property
def _data(self) -> bytes:
if self._parsed_data is None:
raise ValueError("RTP payload has not been parsed")
return self._parsed_data

@_data.setter
def _data(self, value: bytes) -> None:
self._parsed_data = value

def __repr__(self) -> str:
return (
Expand All @@ -268,7 +281,9 @@ def __repr__(self) -> str:
)

@classmethod
def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()):
def parse(
cls, data: bytes, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap()
) -> Self:
if len(data) < RTP_HEADER_LENGTH:
raise ValueError(
f"RTP packet length is less than {RTP_HEADER_LENGTH} bytes"
Expand Down Expand Up @@ -321,7 +336,9 @@ def parse(cls, data: bytes, extensions_map=HeaderExtensionsMap()):

return packet

def serialize(self, extensions_map=HeaderExtensionsMap()) -> bytes:
def serialize(
self, extensions_map: HeaderExtensionsMap = HeaderExtensionsMap()
) -> bytes:
extension_profile, extension_value = extensions_map.set(self.extensions)
has_extension = bool(extension_value)

Expand Down

0 comments on commit 43096f6

Please sign in to comment.