Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 135 additions & 23 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def __init__(self, agent: Agent, sess: AgentSession) -> None:

self._drain_blocked_tasks: list[asyncio.Task[Any]] = []

# Media enable state tracking for hot toggling
self._target_audio_input_enabled = sess.input.audio_enabled
self._current_audio_input_enabled: bool | None = None
self._target_audio_output_enabled = sess.output.audio_enabled
self._current_audio_output_enabled: bool | None = None

self._attached_stt: stt.STT | None = None
self._stt_prewarmed = False
self._attached_tts: tts.TTS | None = None
self._tts_prewarmed = False

if self._turn_detection_mode == "vad" and not self.vad:
logger.warning("turn_detection is set to 'vad', but no VAD model is provided")
self._turn_detection_mode = None
Expand Down Expand Up @@ -414,12 +425,6 @@ async def start(self) -> None:
if isinstance(self.llm, llm.LLM):
self.llm.prewarm()

if isinstance(self.stt, stt.STT):
self.stt.prewarm()

if isinstance(self.tts, tts.TTS):
self.tts.prewarm()

# don't use start_span for _start_session, avoid nested user/assistant turns
await self._start_session()
self._started = True
Expand Down Expand Up @@ -447,14 +452,6 @@ async def _start_session(self) -> None:
self.llm.on("metrics_collected", self._on_metrics_collected)
self.llm.on("error", self._on_error)

if isinstance(self.stt, stt.STT):
self.stt.on("metrics_collected", self._on_metrics_collected)
self.stt.on("error", self._on_error)

if isinstance(self.tts, tts.TTS):
self.tts.on("metrics_collected", self._on_metrics_collected)
self.tts.on("error", self._on_error)

if isinstance(self.vad, vad.VAD):
self.vad.on("metrics_collected", self._on_metrics_collected)

Expand Down Expand Up @@ -538,16 +535,19 @@ async def _list_mcp_tools_task(
logger.exception("failed to update the instructions")

await self._resume_scheduling_task()
initial_audio_enabled = self._target_audio_input_enabled
self._audio_recognition = AudioRecognition(
hooks=self,
stt=self._agent.stt_node if self.stt else None,
vad=self.vad,
stt=self._agent.stt_node if (initial_audio_enabled and self.stt) else None,
vad=self.vad if initial_audio_enabled else None,
turn_detector=self.turn_detection if not isinstance(self.turn_detection, str) else None,
min_endpointing_delay=self.min_endpointing_delay,
max_endpointing_delay=self.max_endpointing_delay,
turn_detection_mode=self._turn_detection_mode,
)
self._audio_recognition.start()
self._apply_audio_input_enabled()
self._apply_audio_output_enabled()

@tracer.start_as_current_span("drain_agent_activity")
async def drain(self) -> None:
Expand Down Expand Up @@ -650,13 +650,8 @@ async def _close_session(self) -> None:
self._rt_session.off("metrics_collected", self._on_metrics_collected)
self._rt_session.off("error", self._on_error)

if isinstance(self.stt, stt.STT):
self.stt.off("metrics_collected", self._on_metrics_collected)
self.stt.off("error", self._on_error)

if isinstance(self.tts, tts.TTS):
self.tts.off("metrics_collected", self._on_metrics_collected)
self.tts.off("error", self._on_error)
self._detach_stt_handlers()
self._detach_tts_handlers()

if isinstance(self.vad, vad.VAD):
self.vad.off("metrics_collected", self._on_metrics_collected)
Expand All @@ -669,6 +664,10 @@ async def _close_session(self) -> None:

if self._audio_recognition is not None:
await self._audio_recognition.aclose()
self._audio_recognition = None

self._current_audio_input_enabled = None
self._current_audio_output_enabled = None

await self._interrupt_paused_speech(old_task=self._interrupt_paused_speech_task)
self._interrupt_paused_speech_task = None
Expand Down Expand Up @@ -902,6 +901,14 @@ def on_playout_done(_: SpeechHandle) -> None:

return future

def on_audio_input_enabled_changed(self, enabled: bool) -> None:
    """Record the requested audio-input state and try to apply it.

    Args:
        enabled: The new desired state for audio input.
    """
    # Remember the request first; the apply step reads the target from
    # ``self`` rather than taking it as an argument.
    self._target_audio_input_enabled = enabled
    self._apply_audio_input_enabled()

def on_audio_output_enabled_changed(self, enabled: bool) -> None:
    """Record the requested audio-output state and try to apply it.

    Args:
        enabled: The new desired state for audio output.
    """
    # Remember the request first; the apply step reads the target from
    # ``self`` rather than taking it as an argument.
    self._target_audio_output_enabled = enabled
    self._apply_audio_output_enabled()

def clear_user_turn(self) -> None:
if self._audio_recognition:
self._audio_recognition.clear_user_turn()
Expand Down Expand Up @@ -2399,6 +2406,111 @@ async def _interrupt_paused_speech(self, old_task: asyncio.Task[None] | None = N
if self._session.options.resume_false_interruption and self._session.output.audio:
self._session.output.audio.resume()

def _apply_audio_input_enabled(self) -> None:
    """Bring the audio-input pipeline in line with the requested state.

    Does nothing when there is no audio recognition object yet, or when the
    applied state already equals the target. Otherwise:

    * enabling attaches the STT handlers, then hands the recognizer the VAD
      and (when an STT is configured) the agent's ``stt_node``;
    * disabling clears STT and VAD on the recognizer, then detaches the STT
      handlers.
    """
    if self._audio_recognition is None:
        return

    target = self._target_audio_input_enabled
    if target == self._current_audio_input_enabled:
        # Already applied; avoid touching the recognizer again.
        return

    if not target:
        self._audio_recognition.update_stt(None)
        self._audio_recognition.update_vad(None)
        self._detach_stt_handlers()
    else:
        self._attach_stt_handlers()
        self._audio_recognition.update_vad(self.vad or None)
        self._audio_recognition.update_stt(self._agent.stt_node if self.stt else None)

    self._current_audio_input_enabled = target

def _apply_audio_output_enabled(self) -> None:
    """Bring the audio-output pipeline in line with the requested state.

    Does nothing when there is no audio recognition object, or when the
    applied state already equals the target. Enabling attaches the TTS
    handlers. Disabling detaches them and, if output was previously enabled
    and the activity has started, force-interrupts and logs a warning should
    that interruption fail.
    """
    # NOTE(review): the output path is gated on the *input* recognizer
    # existing -- presumably a proxy for "session is running"; confirm.
    if self._audio_recognition is None:
        return

    target = self._target_audio_output_enabled
    was_enabled = self._current_audio_output_enabled
    if was_enabled == target:
        return

    if target:
        self._attach_tts_handlers()
    else:
        self._detach_tts_handlers()
        # Only a real enabled -> disabled transition on a started activity
        # needs to interrupt; the initial apply has ``was_enabled`` of None.
        if was_enabled and self._started:
            pending = self.interrupt(force=True)

            def _log_interrupt(fut: asyncio.Future[None]) -> None:
                # A cancelled future has no exception to report.
                exc = None if fut.cancelled() else fut.exception()
                if exc:
                    logger.warning(
                        "failed to interrupt pending speeches after disabling audio output",
                        exc_info=exc,
                    )

            pending.add_done_callback(_log_interrupt)

    self._current_audio_output_enabled = target

def _attach_stt_handlers(self) -> None:
    """Subscribe this activity's callbacks to the current STT model.

    If ``self.stt`` is not an ``stt.STT`` instance, any existing subscription
    is dropped instead. If the model is already the attached one, nothing
    happens. A different previously attached model is detached first.
    """
    candidate = self.stt
    if not isinstance(candidate, stt.STT):
        self._detach_stt_handlers()
        return

    previous = self._attached_stt
    if previous is candidate:
        return
    if previous is not None:
        self._detach_stt_handlers()

    subscriptions = (
        ("metrics_collected", self._on_metrics_collected),
        ("error", self._on_error),
    )
    for event_name, callback in subscriptions:
        candidate.on(event_name, callback)

    # NOTE(review): the flag is a single bool, so a replacement model
    # attached after the first one is never prewarmed -- confirm intended.
    if not self._stt_prewarmed:
        candidate.prewarm()
        self._stt_prewarmed = True

    self._attached_stt = candidate

def _detach_stt_handlers(self) -> None:
    """Unsubscribe this activity's callbacks from the attached STT model.

    No-op when nothing is attached; otherwise removes both subscriptions and
    clears ``_attached_stt``.
    """
    attached = self._attached_stt
    if attached is None:
        return

    attached.off("metrics_collected", self._on_metrics_collected)
    attached.off("error", self._on_error)
    self._attached_stt = None

def _attach_tts_handlers(self) -> None:
    """Subscribe this activity's callbacks to the current TTS model.

    If ``self.tts`` is not a ``tts.TTS`` instance, any existing subscription
    is dropped instead. If the model is already the attached one, nothing
    happens. A different previously attached model is detached first.
    """
    candidate = self.tts
    if not isinstance(candidate, tts.TTS):
        self._detach_tts_handlers()
        return

    previous = self._attached_tts
    if previous is candidate:
        return
    if previous is not None:
        self._detach_tts_handlers()

    subscriptions = (
        ("metrics_collected", self._on_metrics_collected),
        ("error", self._on_error),
    )
    for event_name, callback in subscriptions:
        candidate.on(event_name, callback)

    # NOTE(review): the flag is a single bool, so a replacement model
    # attached after the first one is never prewarmed -- confirm intended.
    if not self._tts_prewarmed:
        candidate.prewarm()
        self._tts_prewarmed = True

    self._attached_tts = candidate

def _detach_tts_handlers(self) -> None:
    """Unsubscribe this activity's callbacks from the attached TTS model.

    No-op when nothing is attached; otherwise removes both subscriptions and
    clears ``_attached_tts``.
    """
    attached = self._attached_tts
    if attached is None:
        return

    attached.off("metrics_collected", self._on_metrics_collected)
    attached.off("error", self._on_error)
    self._attached_tts = None

# move them to the end to avoid shadowing the same named modules for mypy
@property
def vad(self) -> vad.VAD | None:
Expand Down
19 changes: 18 additions & 1 deletion livekit-agents/livekit/agents/voice/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,16 @@ def __init__(
self._tts_error_counts = 0

# configurable IO
self._input = io.AgentInput(self._on_video_input_changed, self._on_audio_input_changed)
self._input = io.AgentInput(
self._on_video_input_changed,
self._on_audio_input_changed,
audio_enabled_changed=self._on_audio_input_enabled_changed,
)
self._output = io.AgentOutput(
self._on_video_output_changed,
self._on_audio_output_changed,
self._on_text_output_changed,
audio_enabled_changed=self._on_audio_output_enabled_changed,
)

self._forward_audio_atask: asyncio.Task[None] | None = None
Expand Down Expand Up @@ -1154,6 +1159,12 @@ def _on_audio_input_changed(self) -> None:
self._forward_audio_task(), name="_forward_audio_task"
)

def _on_audio_input_enabled_changed(self, enabled: bool) -> None:
    """Forward an audio-input enable/disable change to the activities.

    The current activity is notified first, then the next (pending) one;
    either is skipped when it is ``None``.

    Args:
        enabled: The new audio-input enabled state.
    """
    # Attributes are looked up one at a time, in order, so each is read only
    # after the previous activity has been notified.
    for slot in ("_activity", "_next_activity"):
        activity = getattr(self, slot)
        if activity is not None:
            activity.on_audio_input_enabled_changed(enabled)

def _on_video_output_changed(self) -> None:
    """Handle a change of the video output; currently a no-op."""

Expand All @@ -1172,6 +1183,12 @@ def _on_audio_output_changed(self) -> None:
def _on_text_output_changed(self) -> None:
    """Handle a change of the text output; currently a no-op."""

def _on_audio_output_enabled_changed(self, enabled: bool) -> None:
    """Forward an audio-output enable/disable change to the activities.

    The current activity is notified first, then the next (pending) one;
    either is skipped when it is ``None``.

    Args:
        enabled: The new audio-output enabled state.
    """
    # Attributes are looked up one at a time, in order, so each is read only
    # after the previous activity has been notified.
    for slot in ("_activity", "_next_activity"):
        activity = getattr(self, slot)
        if activity is not None:
            activity.on_audio_output_enabled_changed(enabled)

# ---

async def __aenter__(self) -> AgentSession:
Expand Down
26 changes: 25 additions & 1 deletion livekit-agents/livekit/agents/voice/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,12 +341,16 @@ def __repr__(self) -> str:

class AgentInput:
def __init__(
self, video_changed: Callable[[], None], audio_changed: Callable[[], None]
self,
video_changed: Callable[[], None],
audio_changed: Callable[[], None],
audio_enabled_changed: Callable[[bool], None] | None = None,
) -> None:
self._video_stream: VideoInput | None = None
self._audio_stream: AudioInput | None = None
self._video_changed = video_changed
self._audio_changed = audio_changed
self._audio_enabled_changed = audio_enabled_changed

# enabled by default
self._audio_enabled = True
Expand All @@ -362,13 +366,22 @@ def set_audio_enabled(self, enable: bool) -> None:
self._audio_enabled = enable

if not self._audio_stream:
if self._audio_changed:
self._audio_changed()
if self._audio_enabled_changed:
self._audio_enabled_changed(enable)
return

if enable:
self._audio_stream.on_attached()
else:
self._audio_stream.on_detached()

self._audio_changed()

if self._audio_enabled_changed:
self._audio_enabled_changed(enable)

def set_video_enabled(self, enable: bool) -> None:
if enable and not self._video_stream:
logger.warning("Cannot enable video input when it's not set")
Expand Down Expand Up @@ -443,13 +456,15 @@ def __init__(
video_changed: Callable[[], None],
audio_changed: Callable[[], None],
transcription_changed: Callable[[], None],
audio_enabled_changed: Callable[[bool], None] | None = None,
) -> None:
self._video_sink: VideoOutput | None = None
self._audio_sink: AudioOutput | None = None
self._transcription_sink: TextOutput | None = None
self._video_changed = video_changed
self._audio_changed = audio_changed
self._transcription_changed = transcription_changed
self._audio_enabled_changed = audio_enabled_changed

self._audio_enabled = True
self._video_enabled = True
Expand Down Expand Up @@ -482,13 +497,22 @@ def set_audio_enabled(self, enabled: bool) -> None:
self._audio_enabled = enabled

if not self._audio_sink:
if self._audio_changed:
self._audio_changed()
if self._audio_enabled_changed:
self._audio_enabled_changed(enabled)
return

if enabled:
self._audio_sink.on_attached()
else:
self._audio_sink.on_detached()

self._audio_changed()

if self._audio_enabled_changed:
self._audio_enabled_changed(enabled)

def set_transcription_enabled(self, enabled: bool) -> None:
if enabled and not self._transcription_sink:
logger.warning("Cannot enable transcription output when it's not set")
Expand Down
Loading
Loading