Skip to content

Commit 53c6c38

Browse files
authored
Use stricter typing for tool calls (following kosong library) (#127)
1 parent 40cf35f commit 53c6c38

File tree

3 files changed

+162
-25
lines changed

3 files changed

+162
-25
lines changed

tinker_cookbook/recipes/tool_use/search/search_env.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,19 +164,23 @@ async def step(self, action: Action) -> StepResult:
164164
self.past_messages.append(message)
165165

166166
if "tool_calls" in message:
167+
tool_calls = message["tool_calls"]
167168
failure_result = StepResult(
168169
reward=0.0,
169170
episode_done=True,
170171
next_observation=tinker.ModelInput.empty(),
171172
next_stop_condition=self.stop_condition,
172173
)
173-
if message["tool_calls"][0]["name"] == "search":
174+
# Check if tool_calls list is not empty
175+
if not tool_calls:
176+
return failure_result
177+
if tool_calls[0].function.name == "search":
174178
self.current_num_calls += 1
175179
if self.current_num_calls > self.max_num_calls:
176180
return failure_result
177181
# NOTE(tianyi): seems wasteful: we should share the client somehow
178182
try:
179-
tool_return_message = await self.call_search_tool(message["tool_calls"][0])
183+
tool_return_message = await self.call_search_tool(tool_calls[0])
180184
self.past_messages.extend(tool_return_message)
181185
except Exception as e:
182186
logger.error(f"Error calling search tool: {repr(e)}")

tinker_cookbook/recipes/tool_use/search/tools.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23
import logging
34
from abc import ABC, abstractmethod
45
from typing import Any
@@ -142,17 +143,27 @@ async def _query_chroma_with_retry(self, query_embeddings: list[list[float]]) ->
142143
raise RuntimeError("All ChromaDB query attempts failed")
143144

144145
async def invoke(self, tool_call: ToolCall) -> list[Message]:
145-
if tool_call["name"] != "search":
146-
raise ValueError(f"Invalid tool name: {tool_call['name']}")
147-
if not isinstance(tool_call["args"], dict) or "query_list" not in tool_call["args"]:
146+
if tool_call.function.name != "search":
147+
raise ValueError(f"Invalid tool name: {tool_call.function.name}")
148+
149+
# Parse arguments with error handling
150+
try:
151+
args = json.loads(tool_call.function.arguments)
152+
except json.JSONDecodeError as e:
153+
return [
154+
Message(
155+
role="tool",
156+
content=f"Error invoking search tool: Invalid JSON in arguments - {str(e)}",
157+
)
158+
]
159+
160+
query_list = args.get("query_list")
161+
if not isinstance(query_list, list):
148162
return [
149163
Message(role="tool", content="Error invoking search tool: query_list is required")
150164
]
151-
query_list = tool_call["args"]["query_list"]
152-
if (
153-
not isinstance(query_list, list)
154-
or not len(query_list) > 0
155-
or not all(isinstance(query, str) and len(query.strip()) > 0 for query in query_list)
165+
if not query_list or not all(
166+
isinstance(query, str) and query.strip() for query in query_list
156167
):
157168
return [
158169
Message(

tinker_cookbook/renderers.py

Lines changed: 137 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,126 @@
88
import re
99
from datetime import datetime
1010
from enum import StrEnum
11-
from typing import Callable, NotRequired, TypedDict
11+
from typing import Callable, Literal, NotRequired, TypedDict
1212

1313
import tinker
1414
import torch
15+
import pydantic
1516

1617
from tinker_cookbook.tokenizer_utils import Tokenizer
1718

1819
logger = logging.getLogger(__name__)
1920

21+
# Tool types are based on kosong (https://github.com/MoonshotAI/kosong).
2022

21-
class ToolCall(TypedDict):
22-
name: str
23-
# Each argument is a stringified JSON object
24-
args: dict[str, str]
23+
24+
class StrictBase(pydantic.BaseModel):
    """
    Pydantic base model whose instances are frozen and which rejects
    unknown fields instead of silently dropping them.
    """

    # extra="forbid": unexpected fields fail validation.
    # frozen=True: attribute assignment after construction is not allowed.
    model_config = pydantic.ConfigDict(extra="forbid", frozen=True)

    def __str__(self) -> str:
        # The string form is deliberately the same as the repr.
        return f"{self!r}"
33+
34+
35+
class ToolCall(StrictBase):
    """
    Structured tool invocation following OpenAI/kosong format.

    This represents a request to invoke a tool/function. The structure follows
    the OpenAI function calling format for compatibility with various LLM APIs.

    Both this model and the nested ``FunctionBody`` derive from ``StrictBase``,
    so the whole tool call is immutable, hashable, and rejects unknown fields
    at every level.

    Example:
        tool_call = ToolCall(
            function=ToolCall.FunctionBody(
                name="search",
                arguments='{"query_list": ["python async", "pydantic validation"]}'
            ),
            id="call_abc123"
        )
    """

    # Derives from StrictBase (not plain pydantic.BaseModel) so that the nested
    # body gets the same frozen / extra="forbid" guarantees as the outer model.
    # With a plain BaseModel, `tool_call.function.name = ...` would succeed and
    # hashing a ToolCall would fail on the unhashable nested model.
    class FunctionBody(StrictBase):
        """
        Tool call function body containing the tool name and arguments.

        The arguments field is expected to hold a JSON string; it is stored
        as-is (not validated here) and parsed by whoever consumes the call.
        """

        name: str
        """The name of the tool to be called."""
        arguments: str
        """Arguments of the tool call in JSON string format."""

    type: Literal["function"] = "function"
    """Tool call type, must be 'function' for compatibility."""

    id: str | None = None
    """Optional unique identifier for tracking this specific tool call."""

    function: FunctionBody
    """The function body containing tool name and arguments."""
73+
74+
75+
class ToolOk(StrictBase):
    """
    Result of a tool call that completed successfully.

    Only ``output`` is required; ``message`` and ``brief`` default to the
    empty string.
    """

    output: str
    """Primary result produced by the tool."""

    message: str = ""
    """Human-readable note about the execution (optional)."""

    brief: str = ""
    """Short summary of the result, intended for logging (optional)."""
91+
92+
93+
class ToolError(StrictBase):
    """
    Result of a tool call that failed or hit an error.

    Every field defaults to the empty string, so an instance can be built
    with only the details that are available.
    """

    output: str = ""
    """Partial output produced before the error occurred, if any."""

    message: str = ""
    """Description of what went wrong."""

    brief: str = ""
    """Short error summary, intended for logging."""
109+
110+
111+
# Outcome of executing a tool: either a ToolOk (success) or a ToolError (failure).
ToolReturnType = ToolOk | ToolError
113+
114+
115+
class ToolResult(StrictBase):
    """
    A tool execution outcome paired with the ID of the call that produced it.

    The ID lets a result (ToolOk or ToolError) be matched back to its tool
    call when several calls are in flight.

    Note: defined for future use in handling multiple concurrent tool calls
    with result correlation.
    """

    tool_call_id: str | None
    """ID of the originating tool call; must be passed explicitly, may be None."""

    result: ToolReturnType
    """The execution outcome itself: a ToolOk or a ToolError."""
25131

26132

27133
# NOTE: we use a broad type definition for the role to be flexible
@@ -35,6 +141,17 @@ class Message(TypedDict):
35141
tool_calls: NotRequired[list[ToolCall]]
36142
thinking: NotRequired[str]
37143
trainable: NotRequired[bool]
144+
tool_call_id: NotRequired[str]
145+
name: NotRequired[str]
146+
147+
148+
def _tool_call_payload(tool_call: ToolCall) -> dict[str, object]:
    """Flatten a ToolCall into the ``name``/``args`` dict embedded in <tool_call> blocks."""
    body = tool_call.function
    payload: dict[str, object] = {"name": body.name}
    # The arguments are stored as a JSON string; decode them so they are
    # embedded as a nested object rather than as a quoted string.
    payload["args"] = json.loads(body.arguments)
    return payload
38155

39156

40157
class TrainOnWhat(StrEnum):
@@ -369,7 +486,7 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
369486
if "tool_calls" in message:
370487
ac_content += "\n".join(
371488
[
372-
f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>"
489+
f"<tool_call>\n{json.dumps(_tool_call_payload(tool_call))}\n</tool_call>"
373490
for tool_call in message["tool_calls"]
374491
]
375492
)
@@ -425,15 +542,20 @@ def _parse_tool_call(self, tool_call_str: str) -> list[ToolCall] | None:
425542

426543
if not isinstance(tool_call, dict):
427544
return None
428-
if (
429-
"name" not in tool_call
430-
or "args" not in tool_call
431-
or not isinstance(tool_call["name"], str)
432-
or not isinstance(tool_call["args"], dict)
433-
):
545+
name = tool_call.get("name")
546+
args = tool_call.get("args")
547+
tool_id = tool_call.get("id")
548+
if not isinstance(name, str) or not isinstance(args, dict):
434549
return None
435-
436-
return [ToolCall(**tool_call)]
550+
if tool_id is not None and not isinstance(tool_id, str):
551+
tool_id = None
552+
# Convert to nested structure with arguments as JSON string
553+
return [
554+
ToolCall(
555+
function=ToolCall.FunctionBody(name=name, arguments=json.dumps(args)),
556+
id=tool_id,
557+
)
558+
]
437559

438560
def parse_response(self, response: list[int]) -> tuple[Message, bool]:
439561
assistant_message, parse_success = parse_response_for_stop_token(
@@ -485,7 +607,7 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
485607
if "tool_calls" in message:
486608
ac_content += "\n".join(
487609
[
488-
f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>"
610+
f"<tool_call>\n{json.dumps(_tool_call_payload(tool_call))}\n</tool_call>"
489611
for tool_call in message["tool_calls"]
490612
]
491613
)

0 commit comments

Comments
 (0)