Skip to content

wip #1046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: rm/pr1045
Choose a base branch
from
Draft

wip #1046

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,33 @@ class AgentBase:
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
"""Configuration for MCP servers."""

async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(
self.mcp_servers, convert_schemas_to_strict, run_context, self
)

async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
mcp_tools = await self.get_mcp_tools(run_context)

async def _check_tool_enabled(tool: Tool) -> bool:
if not isinstance(tool, FunctionTool):
return True

attr = tool.is_enabled
if isinstance(attr, bool):
return attr
res = attr(run_context, self)
if inspect.isawaitable(res):
return bool(await res)
return bool(res)

results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
return [*mcp_tools, *enabled]


@dataclass
class Agent(AgentBase, Generic[TContext]):
Expand Down Expand Up @@ -262,30 +289,3 @@ async def get_prompt(
) -> ResponsePromptParam | None:
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)

async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(
self.mcp_servers, convert_schemas_to_strict, run_context, self
)

async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
mcp_tools = await self.get_mcp_tools(run_context)

async def _check_tool_enabled(tool: Tool) -> bool:
if not isinstance(tool, FunctionTool):
return True

attr = tool.is_enabled
if isinstance(attr, bool):
return attr
res = attr(run_context, self)
if inspect.isawaitable(res):
return bool(await res)
return bool(res)

results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
return [*mcp_tools, *enabled]
14 changes: 8 additions & 6 deletions src/agents/model_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
class _OmitTypeAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_none(value: None) -> _Omit:
return _Omit()
Expand All @@ -39,12 +39,14 @@ def validate_from_none(value: None) -> _Omit:
from_none_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: None
),
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
)


Omit = Annotated[_Omit, _OmitTypeAnnotation]
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
ToolChoice: TypeAlias = Literal["auto", "required", "none"] | str | None


@dataclass
class ModelSettings:
Expand Down
69 changes: 69 additions & 0 deletions src/agents/realtime/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from typing import (
Literal,
Union,
)

from typing_extensions import NotRequired, TypeAlias, TypedDict

from ..model_settings import ToolChoice
from ..tool import FunctionTool


class RealtimeClientMessage(TypedDict, total=False):
type: str # explicitly required
# All additional keys are permitted because total=False


class UserInputText(TypedDict):
type: Literal["input_text"]
text: str


class RealtimeUserInputMessage(TypedDict):
type: Literal["message"]
role: Literal["user"]
content: list[UserInputText]


RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage]


RealtimeAudioFormat: TypeAlias = Union[Literal["pcm16", "g711_ulaw", "g711_alaw"], str]


class RealtimeInputAudioTranscriptionConfig(TypedDict, total=False):
language: NotRequired[str]
model: NotRequired[Literal["gpt-4o-transcribe", "gpt-4o-mini-transcribe", "whisper-1"] | str]
prompt: NotRequired[str]


class RealtimeTurnDetectionConfig(TypedDict, total=False):
"""Turn detection config. Allows extra vendor keys if needed."""

type: NotRequired[Literal["semantic_vad", "server_vad"]]
create_response: NotRequired[bool]
eagerness: NotRequired[Literal["auto", "low", "medium", "high"]]
interrupt_response: NotRequired[bool]
prefix_padding_ms: NotRequired[int]
silence_duration_ms: NotRequired[int]
threshold: NotRequired[float]


class RealtimeSessionConfig(TypedDict):
model: NotRequired[str]
instructions: NotRequired[str]
modalities: NotRequired[list[Literal["text", "audio"]]]
voice: NotRequired[str]

input_audio_format: NotRequired[RealtimeAudioFormat]
output_audio_format: NotRequired[RealtimeAudioFormat]
input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig]
turn_detection: NotRequired[RealtimeTurnDetectionConfig]

tool_choice: NotRequired[ToolChoice]
tools: NotRequired[list[FunctionTool]]

# TODO (rm) Add tracing support
# tracing: NotRequired[RealtimeTracingConfig | None]
209 changes: 209 additions & 0 deletions src/agents/realtime/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
from dataclasses import dataclass
from typing import Literal, Union

from typing_extensions import TypeAlias

from ..run import RunContextWrapper
from ..tool import Tool
from .agent import RealtimeAgent
from .items import RealtimeItem
from .transport_events import RealtimeTransportAudioEvent, RealtimeTransportEvent


@dataclass
class RealtimeEventInfo:
context: RunContextWrapper
"""The context for the event."""


@dataclass
class RealtimeAgentStartEvent:
"""A new agent has started."""

type: Literal["agent_start"] = "agent_start"

agent: RealtimeAgent
"""The new agent."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeAgentEndEvent:
"""An agent has ended."""

type: Literal["agent_end"] = "agent_end"

agent: RealtimeAgent
"""The agent that ended."""

output: str
"""The output of the agent."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeHandoffEvent:
"""An agent has handed off to another agent."""

type: Literal["handoff"] = "handoff"

from_agent: RealtimeAgent
"""The agent that handed off."""

to_agent: RealtimeAgent
"""The agent that was handed off to."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeToolStart:
"""An agent is starting a tool call."""

type: Literal["tool_start"] = "tool_start"

agent: RealtimeAgent
"""The agent that updated."""

tool: Tool

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeToolEnd:
"""An agent has ended a tool call."""

type: Literal["tool_end"] = "tool_end"

agent: RealtimeAgent
"""The agent that ended the tool call."""

tool: Tool
"""The tool that was called."""

output: str
"""The output of the tool call."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeRawTransportEvent:
"""Forwards raw events from the transport layer."""

type: Literal["raw_transport_event"] = "raw_transport_event"

data: RealtimeTransportEvent
"""The raw data from the transport layer."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeAudioStart:
"""Triggered when the agent starts generating audio."""

type: Literal["audio_start"] = "audio_start"

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeAudioEnd:
"""Triggered when the agent stops generating audio."""

type: Literal["audio_end"] = "audio_end"

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeAudio:
"""Triggered when the agent generates new audio to be played."""

type: Literal["audio"] = "audio"

audio: RealtimeTransportAudioEvent
"""The audio event from the transport layer."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeAudioInterrupted:
"""Triggered when the agent is interrupted. Can be listened to by the user to stop audio playback
or give visual indicators to the user.
"""

type: Literal["audio_interrupted"] = "audio_interrupted"

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeError:
"""An error has occurred."""

type: Literal["error"] = "error"

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeHistoryUpdated:
"""The history has been updated. Contains the full history of the session."""

type: Literal["history_updated"] = "history_updated"

history: list[RealtimeItem]
"""The full history of the session."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


@dataclass
class RealtimeHistoryAdded:
"""A new item has been added to the history."""

type: Literal["history_added"] = "history_added"

item: RealtimeItem
"""The new item that was added to the history."""

info: RealtimeEventInfo
"""Common info for all events, such as the context."""


# TODO (rm) Add guardrails

RealtimeSessionEvent: TypeAlias = Union[
RealtimeAgentStartEvent,
RealtimeAgentEndEvent,
RealtimeHandoffEvent,
RealtimeToolStart,
RealtimeToolEnd,
RealtimeRawTransportEvent,
RealtimeAudioStart,
RealtimeAudioEnd,
RealtimeAudio,
RealtimeAudioInterrupted,
RealtimeError,
RealtimeHistoryUpdated,
RealtimeHistoryAdded,
]
"""An event emitted by the realtime session."""
Loading
Loading