88import re
99from datetime import datetime
1010from enum import StrEnum
11- from typing import Callable , NotRequired , TypedDict
11+ from typing import Callable , Literal , NotRequired , TypedDict
1212
1313import tinker
1414import torch
15+ import pydantic
1516
1617from tinker_cookbook .tokenizer_utils import Tokenizer
1718
1819logger = logging .getLogger (__name__ )
1920
21+ # Tool types are based on kosong (https://github.com/MoonshotAI/kosong).
2022
21- class ToolCall (TypedDict ):
22- name : str
23- # Each argument is a stringified JSON object
24- args : dict [str , str ]
23+
class StrictBase(pydantic.BaseModel):
    """
    Common pydantic base for the tool-related models in this module.

    Instances are frozen, so fields cannot be reassigned after construction,
    and unknown fields are rejected instead of being silently dropped.
    """

    model_config = pydantic.ConfigDict(extra="forbid", frozen=True)

    def __str__(self) -> str:
        # str() deliberately yields the same text as repr().
        return f"{self!r}"
33+
34+
class ToolCall(StrictBase):
    """
    Structured tool invocation following OpenAI/kosong format.

    This represents a request to invoke a tool/function. The structure follows
    the OpenAI function calling format for compatibility with various LLM APIs.

    Example:
        tool_call = ToolCall(
            function=ToolCall.FunctionBody(
                name="search",
                arguments='{"query_list": ["python async", "pydantic validation"]}'
            ),
            id="call_abc123"
        )
    """

    class FunctionBody(pydantic.BaseModel):
        """
        Tool call function body containing the tool name and arguments.

        The arguments field must be a valid JSON string that will be parsed
        by the tool implementation.
        """

        # Frozen so that the enclosing (frozen) ToolCall is really immutable
        # and hashable: a mutable nested model would let callers reassign
        # `tool_call.function.name` and would make `hash(tool_call)` fail.
        # `extra` is left at the pydantic default so payloads that validated
        # before this change keep validating.
        model_config = pydantic.ConfigDict(frozen=True)

        name: str
        """The name of the tool to be called."""
        arguments: str
        """Arguments of the tool call in JSON string format."""

    type: Literal["function"] = "function"
    """Tool call type, must be 'function' for compatibility."""

    id: str | None = None
    """Optional unique identifier for tracking this specific tool call."""

    function: FunctionBody
    """The function body containing tool name and arguments."""
73+
74+
class ToolOk(StrictBase):
    """
    Result of a tool call that completed successfully.

    Carries the primary output string plus two optional free-text fields
    that default to the empty string.
    """

    output: str
    """Primary output produced by the tool (required)."""

    message: str = ""
    """Optional human-readable note about the execution."""

    brief: str = ""
    """Optional short summary of the result, intended for logging."""
91+
92+
class ToolError(StrictBase):
    """
    Result of a tool call that failed or hit an error.

    Every field is optional and defaults to the empty string.
    """

    output: str = ""
    """Partial output, if any, produced before the error occurred."""

    message: str = ""
    """Description of what went wrong."""

    brief: str = ""
    """Short error summary, intended for logging."""
109+
110+
ToolReturnType = ToolOk | ToolError
"""Outcome of a tool execution: a ToolOk on success or a ToolError on failure."""
113+
114+
class ToolResult(StrictBase):
    """
    A tool execution outcome paired with the ID of the call that produced it.

    The ID lets a result (ToolOk or ToolError) be matched back to its
    originating tool call when several calls are in flight.

    Note: This class is defined for future use in handling multiple
    concurrent tool calls with result correlation.
    """

    tool_call_id: str | None
    """ID of the originating tool call; must be passed explicitly and may be None."""

    result: ToolReturnType
    """The execution outcome itself, either a ToolOk or a ToolError."""
25131
26132
27133# NOTE: we use a broad type definition for the role to be flexible
@@ -35,6 +141,17 @@ class Message(TypedDict):
35141 tool_calls : NotRequired [list [ToolCall ]]
36142 thinking : NotRequired [str ]
37143 trainable : NotRequired [bool ]
144+ tool_call_id : NotRequired [str ]
145+ name : NotRequired [str ]
146+
147+
def _tool_call_payload(tool_call: ToolCall) -> dict[str, object]:
    """Build the flat name/args dict that is embedded in <tool_call> blocks."""
    body = tool_call.function
    # The nested model keeps arguments as a JSON string; the flat format
    # carries the decoded value instead.
    decoded_args = json.loads(body.arguments)
    return {"name": body.name, "args": decoded_args}
38155
39156
40157class TrainOnWhat (StrEnum ):
@@ -369,7 +486,7 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
369486 if "tool_calls" in message :
370487 ac_content += "\n " .join (
371488 [
372- f"<tool_call>\n { json .dumps (tool_call )} \n </tool_call>"
489+ f"<tool_call>\n { json .dumps (_tool_call_payload ( tool_call ) )} \n </tool_call>"
373490 for tool_call in message ["tool_calls" ]
374491 ]
375492 )
@@ -425,15 +542,20 @@ def _parse_tool_call(self, tool_call_str: str) -> list[ToolCall] | None:
425542
426543 if not isinstance (tool_call , dict ):
427544 return None
428- if (
429- "name" not in tool_call
430- or "args" not in tool_call
431- or not isinstance (tool_call ["name" ], str )
432- or not isinstance (tool_call ["args" ], dict )
433- ):
545+ name = tool_call .get ("name" )
546+ args = tool_call .get ("args" )
547+ tool_id = tool_call .get ("id" )
548+ if not isinstance (name , str ) or not isinstance (args , dict ):
434549 return None
435-
436- return [ToolCall (** tool_call )]
550+ if tool_id is not None and not isinstance (tool_id , str ):
551+ tool_id = None
552+ # Convert to nested structure with arguments as JSON string
553+ return [
554+ ToolCall (
555+ function = ToolCall .FunctionBody (name = name , arguments = json .dumps (args )),
556+ id = tool_id ,
557+ )
558+ ]
437559
438560 def parse_response (self , response : list [int ]) -> tuple [Message , bool ]:
439561 assistant_message , parse_success = parse_response_for_stop_token (
@@ -485,7 +607,7 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
485607 if "tool_calls" in message :
486608 ac_content += "\n " .join (
487609 [
488- f"<tool_call>\n { json .dumps (tool_call )} \n </tool_call>"
610+ f"<tool_call>\n { json .dumps (_tool_call_payload ( tool_call ) )} \n </tool_call>"
489611 for tool_call in message ["tool_calls" ]
490612 ]
491613 )
0 commit comments