Commit 420b5c4

Fix GPT-OSS streaming tool call parsing (#4023)
Co-authored-by: QwertyJack <[email protected]>
Parent: c875e37

File tree

2 files changed: +331 −2 lines
lmdeploy/serve/openai/harmony_utils.py (7 additions, 2 deletions)
@@ -58,9 +58,14 @@ def parse_streaming(self, tokens: List[int]) -> DeltaMessage:
                                                 index=base_index,
                                                 function=DeltaFunctionCall(name=tool_name, arguments=''))
             elif delta_text:
+                # Continuing the same tool call. Ensure we don't duplicate the
+                # very first delta string in this chunk. Previously we initialized
+                # with arguments=delta_text and then appended again, causing
+                # duplicated content like "locationlocation".
                 if delta_tool_call is None:
-                    delta_tool_call = DeltaToolCall(index=base_index,
-                                                    function=DeltaFunctionCall(arguments=delta_text))
+                    # We are in the middle of a tool call carried over from the
+                    # previous chunk. Initialize an empty arguments buffer.
+                    delta_tool_call = DeltaToolCall(index=base_index, function=DeltaFunctionCall(arguments=''))
                 delta_tool_call.function.arguments += delta_text

         if delta_tool_call:
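
For context, here is a minimal standalone sketch (illustrative only, not lmdeploy code) of why seeding the buffer with the first delta and then appending it again duplicated content:

    # Hypothetical reduction of the accumulation logic in parse_streaming.
    def accumulate(deltas, seed_with_first):
        args = None
        for d in deltas:
            if args is None:
                # Old behavior seeded with the delta itself; the fix seeds with ''.
                args = d if seed_with_first else ''
            args += d
        return args

    chunks = ['{"loc', 'ation": "Paris"}']
    assert accumulate(chunks, seed_with_first=True) == '{"loc{"location": "Paris"}'  # old: duplicated
    assert accumulate(chunks, seed_with_first=False) == '{"location": "Paris"}'      # fixed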
(new file: 324 additions, 0 deletions)

@@ -0,0 +1,324 @@
import collections
import json
import os
import sys
import time
import types
from typing import Generator, List

import pytest
import shortuuid

# Ensure local package is imported (not any site-packages installation)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)


def _install_openai_harmony_stub():
    """Install a minimal stub for `openai_harmony` so the module imports
    without the real dependency.

    The GptOssChatParser test injects its own dummy parser, so the stub is sufficient.
    """
    if 'openai_harmony' in sys.modules:
        return
    m = types.ModuleType('openai_harmony')

    class HarmonyEncodingName:
        HARMONY_GPT_OSS = 'HARMONY_GPT_OSS'

    class Role:
        ASSISTANT = 'assistant'

    class StreamableParser:  # pragma: no cover - only the constructor is used

        def __init__(self, encoding, role=None):
            self.encoding = encoding
            self.role = role

    def load_harmony_encoding(name):  # pragma: no cover - not used in test
        return object()

    m.HarmonyEncodingName = HarmonyEncodingName
    m.Role = Role
    m.StreamableParser = StreamableParser
    m.load_harmony_encoding = load_harmony_encoding
    sys.modules['openai_harmony'] = m


TestExpects = collections.namedtuple('TestExpects', 'func_name location')


class DummyParser:
    """A minimal stand-in for Harmony's StreamableParser with channels.

    Control tokens:
      -1: start functions.get_weather (commentary)
      -4: start functions.get_time (commentary)
      -6: start functions.get_weather (again)
      -9: end current tool call, append to `messages`
      -2: switch to final (visible) content
      -3: switch to analysis (reasoning)
    Other tokens are interpreted as chr(token).
    """

    class _Msg:

        def __init__(self, channel, recipient):
            self.channel = channel
            self.recipient = recipient

    def __init__(self):
        self.current_channel = None
        self.current_recipient = None
        self.last_content_delta = ''
        self.messages = []

    def process(self, token):
        if token == -1:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_weather'
            self.last_content_delta = ''
            return
        if token == -4:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_time'
            self.last_content_delta = ''
            return
        if token == -6:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_weather'
            self.last_content_delta = ''
            return
        if token == -9:
            if self.current_channel == 'commentary' and self.current_recipient and self.current_recipient.startswith(
                    'functions.'):
                self.messages.append(self._Msg(self.current_channel, self.current_recipient))
            # reset recipient to signal end of current tool call
            self.current_recipient = None
            self.current_channel = None
            self.last_content_delta = ''
            return
        if token == -2:
            self.current_channel = 'final'
            self.current_recipient = None
            self.last_content_delta = ''
            return
        if token == -3:
            self.current_channel = 'analysis'
            self.current_recipient = None
            self.last_content_delta = ''
            return
        # regular character token
        self.last_content_delta = chr(token)


def _chat_completion_v1(request, token_chunks: List[List[int]]):
    from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
    from lmdeploy.serve.openai.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice,
                                                ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                                UsageInfo)

    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model

    parser = GptOssChatParser()
    parser.parser = DummyParser()

    if request.stream:

        def completion_stream_generator() -> Generator['ChatCompletionStreamResponse', None, None]:
            finish_reason = 'stop'
            for chunk in token_chunks:
                delta_message = parser.parse_streaming(chunk)
                choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                                 delta=delta_message,
                                                                 finish_reason=finish_reason,
                                                                 logprobs=None)
                response = ChatCompletionStreamResponse(id=request_id,
                                                        created=created_time,
                                                        model=model_name,
                                                        choices=[choice_data],
                                                        usage=None)
                yield response

        return completion_stream_generator()

    # Non-stream path: parse all tokens at once using parse_full
    tokens: List[int] = []
    for c in token_chunks:
        tokens.extend(c)
    message = parser.parse_full(tokens)
    finish_reason = 'tool_calls' if message.tool_calls else 'stop'
    choice_data = ChatCompletionResponseChoice(index=0, message=message, finish_reason=finish_reason)
    return ChatCompletionResponse(id=request_id,
                                  created=created_time,
                                  model=model_name,
                                  choices=[choice_data],
                                  usage=UsageInfo())


def _stream_parse(request, token_chunks: List[List[int]]):
    from lmdeploy.serve.openai.protocol import DeltaMessage

    content = ''
    reasoning_content = ''
    tool_calls_by_index = {}

    for i, stream_resp in enumerate(_chat_completion_v1(request, token_chunks)):
        delta_message: DeltaMessage = stream_resp.choices[0].delta
        if delta_message.content:
            content += delta_message.content
        if delta_message.reasoning_content:
            reasoning_content += delta_message.reasoning_content
        if delta_message.tool_calls:
            for c in delta_message.tool_calls:
                idx = c.index
                existing_call = tool_calls_by_index.get(idx, None)
                if not existing_call:
                    tool_calls_by_index[idx] = c
                    continue
                if c.function.name:
                    existing_call.function.name = c.function.name
                if c.function.arguments:
                    existing_call.function.arguments = existing_call.function.arguments or ''
                    existing_call.function.arguments += c.function.arguments
    # sorted list for stable order
    tool_calls = [tool_calls_by_index[i] for i in sorted(tool_calls_by_index.keys())]
    return content, reasoning_content, tool_calls


def _t(s: str) -> List[int]:
    return [ord(c) for c in s]


# Basic: single function call split across two chunks (bug repro scenario)
TOKENS_SINGLE_CALL_TWO_CHUNKS = [
    [-1] + _t('{"location": "Paris'),
    _t(', France"}'),
]

# Multiple calls with indices and different function names
TOKENS_TWO_CALLS_DIFFERENT_FUNCS = [
    [-1] + _t('{"location": "Berlin"}') + [-9] + [-4] + _t('{"city": "New'),
    _t(' York"}') + [-9],
]

# Interleaved channels: analysis, tool call, final content
TOKENS_INTERLEAVED = [
    [-3] + _t('Thinking about the weather. ') + [-1] + _t('{"location": "Par'),
    _t('is, France"}') + [-9] + [-2] + _t('Fetching the weather now.'),
]

# Two calls, same function name, indices increment
TOKENS_TWO_CALLS_SAME_FUNC = [
    [-1] + _t('{"location": "Tokyo"}') + [-9],
    [-6] + _t('{"location": "Ky'),
    _t('oto"}') + [-9],
]


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
])
def test_parser_stream_basic(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, token_chunks)

    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location
    assert content.strip() == ''
    assert (reasoning_content or '').strip() == ''


def test_parser_stream_multiple_calls_indices():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_TWO_CALLS_DIFFERENT_FUNCS)

    assert len(tool_calls) == 2
    # tool_calls sorted by index ensures stable order
    tc0, tc1 = tool_calls
    assert tc0.index == 0 and tc1.index == 1
    assert tc0.function.name == 'get_weather'
    assert json.loads(tc0.function.arguments)['location'] == 'Berlin'
    assert tc1.function.name == 'get_time'
    assert json.loads(tc1.function.arguments)['city'] == 'New York'
    assert (content or '').strip() == ''
    assert (reasoning_content or '').strip() == ''


def test_parser_stream_interleaved_channels():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_INTERLEAVED)

    assert json.loads(tool_calls[0].function.arguments)['location'] == 'Paris, France'
    assert reasoning_content == 'Thinking about the weather. '
    assert content == 'Fetching the weather now.'


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_stream_two_calls_same_func(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    _, _, tool_calls = _stream_parse(request, token_chunks)

    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location


def test_open_tool_call_no_args():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, [[-1]])

    assert len(tool_calls) == 1
    assert tool_calls[0].function.name == 'get_weather'
    assert (tool_calls[0].function.arguments or '') == ''
    assert (content or '') == ''
    assert (reasoning_content or '') == ''


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_nonstream(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    resp = _chat_completion_v1(ChatCompletionRequest(model='gpt-oss', messages=[], stream=False), token_chunks)

    assert len(resp.choices) == 1
    first_message = resp.choices[0].message
    assert first_message.content is None
    assert (first_message.reasoning_content or '') == ''
    assert len(first_message.tool_calls) == len(expects)
    for parsed_call, expected_call in zip(first_message.tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location
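
As a quick manual check (a sketch assuming the stub, DummyParser, and token fixtures above are in scope), the bug-repro chunks can be streamed directly and the accumulated arguments compared; before this fix, the first delta of the continuation chunk was emitted twice:

    _install_openai_harmony_stub()
    from lmdeploy.serve.openai.harmony_utils import GptOssChatParser

    parser = GptOssChatParser()
    parser.parser = DummyParser()  # bypass the real Harmony encoding, as the tests do

    args = ''
    for chunk in TOKENS_SINGLE_CALL_TWO_CHUNKS:
        delta = parser.parse_streaming(chunk)
        for call in (delta.tool_calls or []):
            args += call.function.arguments or ''
    assert args == '{"location": "Paris, France"}'  # no duplicated fragment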
