Skip to content

Commit f6f1576

Browse files
committed
Add MCPToolChoice
1 parent 03b88e2 commit f6f1576

File tree

3 files changed

+28
-16
lines changed

3 files changed

+28
-16
lines changed

src/agents/model_settings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ def validate_from_none(value: None) -> _Omit:
4646
Omit = Annotated[_Omit, _OmitTypeAnnotation]
4747
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
4848

49+
@dataclass
50+
class MCPToolChoice:
51+
server_label: str
52+
name: str
4953

5054
@dataclass
5155
class ModelSettings:
@@ -70,7 +74,7 @@ class ModelSettings:
7074
presence_penalty: float | None = None
7175
"""The presence penalty to use when calling the model."""
7276

73-
tool_choice: Literal["auto", "required", "none"] | str | None = None
77+
tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None = None
7478
"""The tool choice to use when calling the model."""
7579

7680
parallel_tool_calls: bool | None = None

src/agents/models/openai_responses.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ..handoffs import Handoff
2626
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
2727
from ..logger import logger
28+
from ..model_settings import MCPToolChoice
2829
from ..tool import (
2930
CodeInterpreterTool,
3031
ComputerTool,
@@ -303,19 +304,16 @@ class ConvertedTools:
303304
class Converter:
304305
@classmethod
305306
def convert_tool_choice(
306-
cls, tool_choice: Literal["auto", "required", "none"] | str | dict[str, Any] | None
307+
cls, tool_choice: Literal["auto", "required", "none"] | str | MCPToolChoice | None
307308
) -> response_create_params.ToolChoice | NotGiven:
308309
if tool_choice is None:
309310
return NOT_GIVEN
310-
elif isinstance(tool_choice, dict):
311-
if tool_choice.get("type") == "mcp":
312-
return {
313-
"server_label": tool_choice.get("server_label") or "mcp",
314-
"type": "mcp",
315-
"name": tool_choice.get("name"),
316-
}
317-
else:
318-
raise UserError(f"Unknown tool choice: {tool_choice}")
311+
elif isinstance(tool_choice, MCPToolChoice):
312+
return {
313+
"server_label": tool_choice.server_label,
314+
"type": "mcp",
315+
"name": tool_choice.name,
316+
}
319317
elif tool_choice == "required":
320318
return "required"
321319
elif tool_choice == "auto":
@@ -343,10 +341,9 @@ def convert_tool_choice(
343341
"type": "code_interpreter",
344342
}
345343
elif tool_choice == "mcp":
346-
return {
347-
"server_label": "mcp",
348-
"type": "mcp",
349-
}
344+
# Note that this is still here for backwards compatibility,
345+
# but migrating to MCPToolChoice is recommended.
346+
return { "type": "mcp" }
350347
else:
351348
return {
352349
"type": "function",

tests/model_settings/test_serialization.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import TypeAdapter
66
from pydantic_core import to_json
77

8-
from agents.model_settings import ModelSettings
8+
from agents.model_settings import MCPToolChoice, ModelSettings
99

1010

1111
def verify_serialization(model_settings: ModelSettings) -> None:
@@ -29,6 +29,17 @@ def test_basic_serialization() -> None:
2929
verify_serialization(model_settings)
3030

3131

32+
def test_mcp_tool_choice_serialization() -> None:
33+
"""Tests whether ModelSettings with MCPToolChoice can be serialized to a JSON string."""
34+
# First, lets create a ModelSettings instance
35+
model_settings = ModelSettings(
36+
temperature=0.5,
37+
tool_choice=MCPToolChoice(server_label="mcp", name="mcp_tool"),
38+
)
39+
# Now, lets serialize the ModelSettings instance to a JSON string
40+
verify_serialization(model_settings)
41+
42+
3243
def test_all_fields_serialization() -> None:
3344
"""Tests whether ModelSettings can be serialized to a JSON string."""
3445

0 commit comments

Comments
 (0)