DEVELOPMENT.md (2 changes: 1 addition & 1 deletion)
@@ -105,7 +105,7 @@ To see how the agent work open up agents.py

**Video**

- * The agent receives the video track, and calls agent.llm._watch_video_track
+ * The agent receives the video track, and calls agent.llm.watch_video_track
* The LLM uses the VideoForwarder to write the video to a websocket or webrtc connection
* The STS writes the reply on agent.llm.audio_track and the RealtimeTranscriptEvent / RealtimePartialTranscriptEvent
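A minimal, self-contained sketch of that flow. Only the `watch_video_track` hook and its `shared_forwarder` parameter appear in this diff; the class names and forwarder internals below are toy stand-ins, not the real API:

```python
import asyncio


class VideoForwarder:
    """Toy stand-in: the real forwarder writes frames to a websocket/WebRTC connection."""

    def __init__(self, track):
        self.track = track

    async def start(self):
        print(f"forwarding frames from {self.track}")


class ToyRealtimeLLM:
    """Toy stand-in for an LLM with realtime video support."""

    async def watch_video_track(self, track, shared_forwarder=None):
        # Reuse the shared forwarder when the agent provides one,
        # otherwise create a dedicated forwarder for this track.
        forwarder = shared_forwarder or VideoForwarder(track)
        await forwarder.start()
        # The STS side then writes its reply onto the LLM's audio track and
        # emits RealtimeTranscriptEvent / RealtimePartialTranscriptEvent.


asyncio.run(ToyRealtimeLLM().watch_video_track("camera-track"))
```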

agents-core/vision_agents/core/agents/agents.py (132 changes: 62 additions & 70 deletions)
@@ -5,7 +5,7 @@
import time
import uuid
from dataclasses import asdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeGuard
from uuid import uuid4

import getstream.models
@@ -30,7 +30,7 @@
RealtimeUserSpeechTranscriptionEvent,
RealtimeAgentSpeechTranscriptionEvent,
)
-from ..llm.llm import LLM
+from ..llm.llm import AudioLLM, LLM, VideoLLM
from ..llm.realtime import Realtime
from ..mcp import MCPBaseServer, MCPManager
from ..processors.base_processor import Processor, ProcessorType, filter_processors
@@ -110,6 +110,18 @@ def default_agent_options():
return AgentOptions(model_dir=_DEFAULT_MODEL_DIR)


+def _is_audio_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[AudioLLM]:
+    return isinstance(llm, AudioLLM)
+
+
+def _is_video_llm(llm: LLM | VideoLLM | AudioLLM) -> TypeGuard[VideoLLM]:
+    return isinstance(llm, VideoLLM)
+
+
+def _is_realtime_llm(llm: LLM | AudioLLM | VideoLLM | Realtime) -> TypeGuard[Realtime]:
+    return isinstance(llm, Realtime)


class Agent:
"""
Agent class makes it easy to build your own video AI.
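The helpers added above are plain `isinstance` checks at runtime, but the `TypeGuard` return type lets static checkers narrow the union at the call site. A simplified, self-contained sketch of the pattern (the class bodies are stand-ins, not the real API):

```python
from typing import TypeGuard


class LLM: ...


class AudioLLM(LLM):
    async def simple_audio_response(self, pcm: bytes) -> None: ...


def _is_audio_llm(llm: LLM) -> TypeGuard[AudioLLM]:
    # Runtime behaviour is just isinstance(); the TypeGuard annotation tells
    # the type checker that a True result narrows `llm` to AudioLLM.
    return isinstance(llm, AudioLLM)


async def reply(llm: LLM, pcm: bytes) -> None:
    if _is_audio_llm(llm):
        # `llm` is narrowed to AudioLLM here, so this audio-only method
        # type-checks without a cast.
        await llm.simple_audio_response(pcm)
```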
@@ -140,7 +152,7 @@ def __init__(
# edge network for video & audio
edge: "StreamEdge",
# llm, optionally with sts/realtime capabilities
-llm: LLM | Realtime,
+llm: LLM | AudioLLM | VideoLLM,
# the agent's user info
agent_user: User,
# instructions
@@ -428,8 +440,8 @@ async def _on_tts_audio_write_to_output(event: TTSAudioEvent):

@self.events.subscribe
async def on_stt_transcript_event_create_response(event: STTTranscriptEvent):
-if self.realtime_mode or not self.llm:
-# when running in realtime mode, there is no need to send the response to the LLM
+if _is_audio_llm(self.llm):
+# There is no need to send the response to the LLM if it handles audio itself.
return

user_id = event.user_id()
@@ -497,7 +509,7 @@ async def join(self, call: Call) -> "AgentSessionContextManager":

# Ensure Realtime providers are ready before proceeding (they manage their own connection)
self.logger.info(f"🤖 Agent joining call: {call.id}")
-if isinstance(self.llm, Realtime):
+if _is_realtime_llm(self.llm):
await self.llm.connect()

with self.span("edge.join"):
@@ -805,7 +817,7 @@ async def on_audio_received(event: AudioReceivedEvent):

# Always listen to remote video tracks so we can forward frames to Realtime providers
@self.edge.events.subscribe
-async def on_track(event: TrackAddedEvent):
+async def on_video_track_added(event: TrackAddedEvent):
track_id = event.track_id
track_type = event.track_type
user = event.user
@@ -819,12 +831,12 @@ async def on_track(event: TrackAddedEvent):
f"🎥 Track re-added: {track_type_name} ({track_id}), switching to it"
)

-if self.realtime_mode and isinstance(self.llm, Realtime):
+if _is_video_llm(self.llm):
# Get the existing forwarder and switch to this track
_, _, forwarder = self._active_video_tracks[track_id]
track = self.edge.add_track_subscriber(track_id)
if track and forwarder:
-await self.llm._watch_video_track(
+await self.llm.watch_video_track(
track, shared_forwarder=forwarder
)
self._current_video_track_id = track_id
@@ -835,7 +847,7 @@ async def on_track(event: TrackAddedEvent):
task.add_done_callback(_log_task_exception)

@self.edge.events.subscribe
-async def on_track_removed(event: TrackRemovedEvent):
+async def on_video_track_removed(event: TrackRemovedEvent):
track_id = event.track_id
track_type = event.track_type
if not track_id:
@@ -853,11 +865,7 @@ async def on_track_removed(event: TrackRemovedEvent):
self._active_video_tracks.pop(track_id, None)

# If this was the active track, switch to any other available track
-if (
-track_id == self._current_video_track_id
-and self.realtime_mode
-and isinstance(self.llm, Realtime)
-):
+if _is_video_llm(self.llm) and track_id == self._current_video_track_id:
self.logger.info(
"🎥 Active video track removed, switching to next available"
)
@@ -883,7 +891,7 @@ async def _reply_to_audio(
)

# when in Realtime mode call the Realtime directly (non-blocking)
-if self.realtime_mode and isinstance(self.llm, Realtime):
+if _is_audio_llm(self.llm):
# TODO: this behaviour should be easy to change in the agent class
asyncio.create_task(
self.llm.simple_audio_response(pcm_data, participant)
Expand Down Expand Up @@ -919,9 +927,9 @@ async def _switch_to_next_available_track(self) -> None:

# Get the track and forwarder
track = self.edge.add_track_subscriber(track_id)
-if track and forwarder and isinstance(self.llm, Realtime):
+if track and forwarder and _is_video_llm(self.llm):
# Send to Realtime provider
-await self.llm._watch_video_track(track, shared_forwarder=forwarder)
+await self.llm.watch_video_track(track, shared_forwarder=forwarder)
self._current_video_track_id = track_id
return
else:
@@ -984,7 +992,7 @@ async def recv(self):
# If Realtime provider supports video, switch to this new track
track_type_name = TrackType.Name(track_type)

-if self.realtime_mode:
+if _is_video_llm(self.llm):
if self._video_track:
# We have a video publisher (e.g., YOLO processor)
# Create a separate forwarder for the PROCESSED video track
Expand All @@ -1000,22 +1008,20 @@ async def recv(self):
await processed_forwarder.start()
self._video_forwarders.append(processed_forwarder)

-if isinstance(self.llm, Realtime):
-# Send PROCESSED frames with the processed forwarder
-await self.llm._watch_video_track(
-self._video_track, shared_forwarder=processed_forwarder
-)
-self._current_video_track_id = track_id
+# Send PROCESSED frames with the processed forwarder
+await self.llm.watch_video_track(
+self._video_track, shared_forwarder=processed_forwarder
+)
+self._current_video_track_id = track_id
else:
# No video publisher, send raw frames - switch to this new track
self.logger.info(
f"🎥 Switching to {track_type_name} track: {track_id}"
)
-if isinstance(self.llm, Realtime):
-await self.llm._watch_video_track(
-track, shared_forwarder=raw_forwarder
-)
-self._current_video_track_id = track_id
+await self.llm.watch_video_track(
+track, shared_forwarder=raw_forwarder
+)
+self._current_video_track_id = track_id

has_image_processors = len(self.image_processors) > 0
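The branch above prefers the processed track (e.g. YOLO-annotated frames) when a video publisher exists and falls back to raw camera frames otherwise. A toy sketch of that routing decision with stand-in types (only `watch_video_track` and `shared_forwarder` come from this diff):

```python
import asyncio


class VideoForwarder:
    """Toy stand-in for the shared frame forwarder."""

    def __init__(self, track):
        self.track = track

    async def start(self):
        pass


class ToyVideoLLM:
    async def watch_video_track(self, track, shared_forwarder=None):
        print(f"LLM now watching: {track}")


async def route_video(llm, raw_track, raw_forwarder, processed_track=None):
    if processed_track is not None:
        # A video publisher exists: forward the PROCESSED frames.
        processed_forwarder = VideoForwarder(processed_track)
        await processed_forwarder.start()
        await llm.watch_video_track(processed_track, shared_forwarder=processed_forwarder)
    else:
        # No publisher: forward the raw camera frames.
        await llm.watch_video_track(raw_track, shared_forwarder=raw_forwarder)


asyncio.run(route_video(ToyVideoLLM(), "raw-cam", VideoForwarder("raw-cam"), "yolo-out"))
```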

@@ -1106,8 +1112,8 @@

async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None:
"""Handle turn detection events."""
-# In realtime mode, the LLM handles turn detection, interruption, and responses itself
-if self.realtime_mode:
+# Skip the turn event handling if the model doesn't require TTS or SST audio itself.
+if _is_audio_llm(self.llm):
return

if isinstance(event, TurnStartedEvent):
@@ -1141,56 +1147,44 @@ async def _on_turn_event(self, event: TurnStartedEvent | TurnEndedEvent) -> None
self.logger.info(
f"👉 Turn ended - participant {participant_id} finished (confidence: {event.confidence})"
)
+if not event.participant or event.participant.user_id == self.agent_user.id:
+# Exit early if the event is triggered by the model response.
+return

-# When turn detection is enabled, trigger LLM response when user's turn ends
+# When turn detection is enabled, trigger LLM response when user's turn ends.
# This is the signal that the user has finished speaking and expects a response
-if event.participant and event.participant.user_id != self.agent_user.id:
-# Get the accumulated transcript for this speaker
-transcript = self._pending_user_transcripts.get(
-event.participant.user_id, ""
+transcript = self._pending_user_transcripts.get(
+event.participant.user_id, ""
)
+if transcript.strip():
+self.logger.info(
+f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
+)

-if transcript and transcript.strip():
-self.logger.info(
-f"🤖 Triggering LLM response after turn ended for {event.participant.user_id}"
-)

-# Create participant object if we have metadata
-participant = None
-if hasattr(event, "custom") and event.custom:
-# Try to extract participant info from custom metadata
-participant = event.custom.get("participant")
+# Create participant object if we have metadata
+participant = None
+if hasattr(event, "custom") and event.custom:
+# Try to extract participant info from custom metadata
+participant = event.custom.get("participant")

-# Trigger LLM response with the complete transcript
-if self.llm:
-await self.simple_response(transcript, participant)
+# Trigger LLM response with the complete transcript
+await self.simple_response(transcript, participant)

-# Clear the pending transcript for this speaker
-self._pending_user_transcripts[event.participant.user_id] = ""
+# Clear the pending transcript for this speaker
+self._pending_user_transcripts[event.participant.user_id] = ""

async def _on_stt_error(self, error):
"""Handle STT service errors."""
self.logger.error(f"❌ STT Error: {error}")

-@property
-def realtime_mode(self) -> bool:
-"""Check if the agent is in Realtime mode.
-
-Returns:
-True if `llm` is a `Realtime` implementation; otherwise False.
-"""
-if self.llm is not None and isinstance(self.llm, Realtime):
-return True
-return False

@property
def publish_audio(self) -> bool:
"""Whether the agent should publish an outbound audio track.

Returns:
True if TTS is configured, when in Realtime mode, or if there are audio publishers.
"""
-if self.tts is not None or self.realtime_mode:
+if self.tts is not None or _is_audio_llm(self.llm):
return True
# Also publish audio if there are audio publishers (e.g., HeyGen avatar)
if self.audio_publishers:
@@ -1227,9 +1221,7 @@ def _needs_audio_or_video_input(self) -> bool:
# Video input needed for:
# - Video processors (for frame analysis)
# - Realtime mode with video (multimodal LLMs)
-needs_video = len(self.video_processors) > 0 or (
-self.realtime_mode and isinstance(self.llm, Realtime)
-)
+needs_video = len(self.video_processors) > 0 or _is_video_llm(self.llm)

return needs_audio or needs_video

@@ -1280,7 +1272,7 @@ def image_processors(self) -> List[Any]:

def _validate_configuration(self):
"""Validate the agent configuration."""
-if self.realtime_mode:
+if _is_audio_llm(self.llm):
# Realtime mode - should not have separate STT/TTS
if self.stt or self.tts:
self.logger.warning(
@@ -1317,8 +1309,8 @@ def _prepare_rtc(self):

# Set up audio track if TTS is available
if self.publish_audio:
-if self.realtime_mode and isinstance(self.llm, Realtime):
-self._audio_track = self.llm.output_track
+if _is_audio_llm(self.llm):
+self._audio_track = self.llm.output_audio_track
self.logger.info("🎵 Using Realtime provider output track for audio")
elif self.audio_publishers:
# Get the first audio publisher to create the track
agents-core/vision_agents/core/llm/__init__.py (12 changes: 10 additions & 2 deletions)
@@ -1,5 +1,13 @@
-from .llm import LLM
+from .llm import LLM, AudioLLM, VideoLLM, OmniLLM
from .realtime import Realtime
from .function_registry import FunctionRegistry, function_registry

__all__ = ["LLM", "Realtime", "FunctionRegistry", "function_registry"]
__all__ = [
"LLM",
"AudioLLM",
"VideoLLM",
"OmniLLM",
"Realtime",
"FunctionRegistry",
"function_registry",
]
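With the widened exports, downstream code can import the capability classes from the package root. A hypothetical consumer sketch (assumes the package is installed; `describe` is illustrative and not part of the API):

```python
from vision_agents.core.llm import LLM, AudioLLM, VideoLLM, Realtime


def describe(llm: LLM) -> str:
    # Report which capabilities an LLM instance advertises.
    kinds = []
    if isinstance(llm, AudioLLM):
        kinds.append("audio")
    if isinstance(llm, VideoLLM):
        kinds.append("video")
    if isinstance(llm, Realtime):
        kinds.append("realtime")
    return "+".join(kinds) or "text-only"
```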