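"""Tests for GptOssChatParser tool-call parsing in streaming and non-streaming modes.

A DummyParser stands in for Harmony's StreamableParser so the tests can drive
GptOssChatParser with hand-crafted token sequences (negative control tokens plus
character code points) without requiring the real `openai_harmony` package.
"""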
import collections
import json
import os
import sys
import time
import types
from typing import Generator, List

import pytest
import shortuuid

# Ensure local package is imported (not any site-packages installation)
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)


def _install_openai_harmony_stub():
| 19 | + """Install a minimal stub for `openai_harmony` so the module imports |
| 20 | + without the real dependency. |
| 21 | +
|
| 22 | + The GptOssChatParser test injects its own dummy parser, so the stub is sufficient. |
| 23 | + """ |
| 24 | + if 'openai_harmony' in sys.modules: |
| 25 | + return |
    m = types.ModuleType('openai_harmony')

    class HarmonyEncodingName:
        HARMONY_GPT_OSS = 'HARMONY_GPT_OSS'

    class Role:
        ASSISTANT = 'assistant'

    class StreamableParser:  # pragma: no cover - constructor only used

        def __init__(self, encoding, role=None):
            self.encoding = encoding
            self.role = role

    def load_harmony_encoding(name):  # pragma: no cover - not used in tests
        return object()

    m.HarmonyEncodingName = HarmonyEncodingName
    m.Role = Role
    m.StreamableParser = StreamableParser
    m.load_harmony_encoding = load_harmony_encoding
    sys.modules['openai_harmony'] = m


TestExpects = collections.namedtuple('TestExpects', 'func_name location')


class DummyParser:
    """A minimal stand-in for Harmony's StreamableParser with channels.

    Control tokens:
      -1: start functions.get_weather (commentary)
      -4: start functions.get_time (commentary)
      -6: start functions.get_weather (again)
      -9: end current tool call, append to `messages`
      -2: switch to final (visible) content
      -3: switch to analysis (reasoning)
    Other tokens are interpreted as chr(token).
    """

    class _Msg:

        def __init__(self, channel, recipient):
            self.channel = channel
            self.recipient = recipient

    def __init__(self):
        self.current_channel = None
        self.current_recipient = None
        self.last_content_delta = ''
        self.messages = []

    def process(self, token):
        if token == -1:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_weather'
            self.last_content_delta = ''
            return
        if token == -4:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_time'
            self.last_content_delta = ''
            return
        if token == -6:
            self.current_channel = 'commentary'
            self.current_recipient = 'functions.get_weather'
            self.last_content_delta = ''
            return
        if token == -9:
            if self.current_channel == 'commentary' and self.current_recipient and self.current_recipient.startswith(
                    'functions.'):
                self.messages.append(self._Msg(self.current_channel, self.current_recipient))
            # reset recipient to signal end of current tool call
            self.current_recipient = None
            self.current_channel = None
            self.last_content_delta = ''
            return
        if token == -2:
            self.current_channel = 'final'
            self.current_recipient = None
            self.last_content_delta = ''
            return
        if token == -3:
            self.current_channel = 'analysis'
            self.current_recipient = None
            self.last_content_delta = ''
            return
        # regular character token
        self.last_content_delta = chr(token)
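
# Illustrative walkthrough (not exercised directly as written by the tests): feeding
# [-1] + _t('{"x": 1}') + [-9] through DummyParser.process starts a
# `functions.get_weather` call on the commentary channel, streams the JSON
# argument text one character at a time via `last_content_delta`, and the
# trailing -9 records the finished call in `messages`.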


def _chat_completion_v1(request, token_chunks: List[List[int]]):
    """Build streaming / non-streaming chat-completion responses from pre-baked
    token chunks, loosely mirroring the server's chat-completions handler.

    A DummyParser is injected in place of the Harmony parser.
    """
    from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
    from lmdeploy.serve.openai.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice,
                                                ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
                                                UsageInfo)

    request_id = f'chat-{shortuuid.random()}'
    created_time = int(time.time())
    model_name = request.model

    parser = GptOssChatParser()
    parser.parser = DummyParser()

    if request.stream:

        def completion_stream_generator() -> Generator['ChatCompletionStreamResponse', None, None]:
            # The tests never inspect finish_reason on stream chunks, so every
            # chunk is simply marked 'stop'.
            finish_reason = 'stop'
            for chunk in token_chunks:
                delta_message = parser.parse_streaming(chunk)
                choice_data = ChatCompletionResponseStreamChoice(index=0,
                                                                 delta=delta_message,
                                                                 finish_reason=finish_reason,
                                                                 logprobs=None)
                response = ChatCompletionStreamResponse(id=request_id,
                                                        created=created_time,
                                                        model=model_name,
                                                        choices=[choice_data],
                                                        usage=None)
                yield response

        return completion_stream_generator()

    # Non-stream path: parse all tokens at once using parse_full
    tokens: List[int] = []
    for c in token_chunks:
        tokens.extend(c)
    message = parser.parse_full(tokens)
    finish_reason = 'tool_calls' if message.tool_calls else 'stop'
    choice_data = ChatCompletionResponseChoice(index=0, message=message, finish_reason=finish_reason)
    return ChatCompletionResponse(id=request_id,
                                  created=created_time,
                                  model=model_name,
                                  choices=[choice_data],
                                  usage=UsageInfo())


def _stream_parse(request, token_chunks: List[List[int]]):
    """Consume the streaming generator and accumulate deltas the way an
    OpenAI-style client would: content and reasoning_content are concatenated,
    and tool-call deltas are merged by index (name set once, arguments appended)."""
    from lmdeploy.serve.openai.protocol import DeltaMessage

    content = ''
    reasoning_content = ''
    tool_calls_by_index = {}

    for stream_resp in _chat_completion_v1(request, token_chunks):
        delta_message: DeltaMessage = stream_resp.choices[0].delta
        if delta_message.content:
            content += delta_message.content
        if delta_message.reasoning_content:
            reasoning_content += delta_message.reasoning_content
        if delta_message.tool_calls:
            for c in delta_message.tool_calls:
                idx = c.index
                existing_call = tool_calls_by_index.get(idx, None)
                if not existing_call:
                    tool_calls_by_index[idx] = c
                    continue
                if c.function.name:
                    existing_call.function.name = c.function.name
                if c.function.arguments:
                    existing_call.function.arguments = existing_call.function.arguments or ''
                    existing_call.function.arguments += c.function.arguments
    # sorted by index for a stable order
    tool_calls = [tool_calls_by_index[i] for i in sorted(tool_calls_by_index.keys())]
    return content, reasoning_content, tool_calls


def _t(s: str) -> List[int]:
    """Encode a string as one pseudo-token per character, e.g. _t('ab') == [97, 98]."""
    return [ord(c) for c in s]


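# Token fixtures: each fixture is a list of chunks; the streaming tests feed each
# chunk to parse_streaming as a single call, while the non-streaming test
# concatenates all chunks and uses parse_full. Splitting a JSON argument string
# across chunks exercises the incremental tool-call accumulation path.
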
# Basic: single function call split across two chunks (bug repro scenario)
TOKENS_SINGLE_CALL_TWO_CHUNKS = [
    [-1] + _t('{"location": "Paris'),
    _t(', France"}'),
]

# Multiple calls with indices and different function names
TOKENS_TWO_CALLS_DIFFERENT_FUNCS = [
    [-1] + _t('{"location": "Berlin"}') + [-9] + [-4] + _t('{"city": "New'),
    _t(' York"}') + [-9],
]

# Interleaved channels: analysis, tool call, final content
TOKENS_INTERLEAVED = [
    [-3] + _t('Thinking about the weather. ') + [-1] + _t('{"location": "Par'),
    _t('is, France"}') + [-9] + [-2] + _t('Fetching the weather now.'),
]

# Two calls, same function name, indices increment
TOKENS_TWO_CALLS_SAME_FUNC = [
    [-1] + _t('{"location": "Tokyo"}') + [-9],
    [-6] + _t('{"location": "Ky'),
    _t('oto"}') + [-9],
]


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
])
def test_parser_stream_basic(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, token_chunks)

    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location
    assert content.strip() == ''
    assert (reasoning_content or '').strip() == ''


def test_parser_stream_multiple_calls_indices():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_TWO_CALLS_DIFFERENT_FUNCS)

    assert len(tool_calls) == 2
    # tool_calls sorted by index ensures stable order
    tc0, tc1 = tool_calls
    assert tc0.index == 0 and tc1.index == 1
    assert tc0.function.name == 'get_weather'
    assert json.loads(tc0.function.arguments)['location'] == 'Berlin'
    assert tc1.function.name == 'get_time'
    assert json.loads(tc1.function.arguments)['city'] == 'New York'
    assert (content or '').strip() == ''
    assert (reasoning_content or '').strip() == ''


def test_parser_stream_interleaved_channels():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, TOKENS_INTERLEAVED)

    assert json.loads(tool_calls[0].function.arguments)['location'] == 'Paris, France'
    assert reasoning_content == 'Thinking about the weather. '
    assert content == 'Fetching the weather now.'


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_stream_two_calls_same_func(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    _, _, tool_calls = _stream_parse(request, token_chunks)

    assert len(tool_calls) == len(expects)
    for parsed_call, expected_call in zip(tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location


def test_open_tool_call_no_args():
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    request = ChatCompletionRequest(model='gpt-oss', messages=[], stream=True)
    content, reasoning_content, tool_calls = _stream_parse(request, [[-1]])

    assert len(tool_calls) == 1
    assert tool_calls[0].function.name == 'get_weather'
    assert (tool_calls[0].function.arguments or '') == ''
    assert (content or '') == ''
    assert (reasoning_content or '') == ''


@pytest.mark.parametrize(('token_chunks', 'expects'), [
    (TOKENS_SINGLE_CALL_TWO_CHUNKS, [TestExpects('get_weather', 'Paris, France')]),
    (TOKENS_TWO_CALLS_SAME_FUNC, [TestExpects('get_weather', 'Tokyo'),
                                  TestExpects('get_weather', 'Kyoto')]),
])
def test_parser_nonstream(token_chunks: List[List[int]], expects: List[TestExpects]):
    from lmdeploy.serve.openai.protocol import ChatCompletionRequest

    _install_openai_harmony_stub()
    resp = _chat_completion_v1(ChatCompletionRequest(model='gpt-oss', messages=[], stream=False), token_chunks)

    assert len(resp.choices) == 1
    first_message = resp.choices[0].message
    assert first_message.content is None
    assert (first_message.reasoning_content or '') == ''
    assert len(first_message.tool_calls) == len(expects)
    for parsed_call, expected_call in zip(first_message.tool_calls, expects):
        assert parsed_call.function.name == expected_call.func_name
        args = json.loads(parsed_call.function.arguments)
        assert args['location'] == expected_call.location