import json
import sys
from pathlib import Path
from unittest.mock import MagicMock

import pytest
from utils.constant import BACKEND_LIST, TOOL_REASONING_MODEL_LIST

from utils.tool_reasoning_definitions import (  # isort: skip
    CALCULATOR_TOOL, REASONING_PARSER_NAMES, SEARCH_TOOL, THINK_END_TOKEN, THINK_START_TOKEN, WEATHER_TOOL,
    WEATHER_TOOL_CN, assert_arguments_parseable, assert_tool_call_fields, build_messages_with_tool_response,
    build_reasoning_tool_roundtrip_messages, collect_stream_reasoning, get_reasoning_content, get_reasoning_tokens,
    make_logged_client, setup_log_file)

# Make the project root and the vendored lmdeploy checkout importable for the
# lazy parser imports below.
_LMDEPLOY_ROOT = str(Path(__file__).resolve().parents[3] / 'lmdeploy')
_PROJECT_ROOT = str(Path(__file__).resolve().parents[3])
for _p in (_PROJECT_ROOT, _LMDEPLOY_ROOT):
    if _p not in sys.path:
        sys.path.insert(0, _p)

# Lazy imports - only used by parser unit-test classes
_deepseek_parser_cls = None
_qwen_parser_cls = None
_parser_manager = None


def _get_deepseek_parser_cls():
    """Import and cache ``DeepSeekR1ReasoningParser`` on first use."""
    global _deepseek_parser_cls
    if _deepseek_parser_cls is None:
        from lmdeploy.serve.openai.reasoning_parser import DeepSeekR1ReasoningParser
        _deepseek_parser_cls = DeepSeekR1ReasoningParser
    return _deepseek_parser_cls


def _get_qwen_parser_cls():
    """Import and cache ``QwenQwQReasoningParser`` on first use."""
    global _qwen_parser_cls
    if _qwen_parser_cls is None:
        from lmdeploy.serve.openai.reasoning_parser import QwenQwQReasoningParser
        _qwen_parser_cls = QwenQwQReasoningParser
    return _qwen_parser_cls


def _get_parser_manager():
    """Import and cache ``ReasoningParserManager`` on first use."""
    global _parser_manager
    if _parser_manager is None:
        from lmdeploy.serve.openai.reasoning_parser import ReasoningParserManager
        _parser_manager = ReasoningParserManager
    return _parser_manager


def _make_mock_tokenizer(vocab=None):
    """Create a mock tokenizer with a configurable vocab dict.

    FIX: the default vocab must map the think tags to their token ids.  The
    previous literal contained two empty-string keys (the tag text had been
    lost), which collapse into a single useless entry; token ids 100/101 are
    used as the think start/end ids throughout the streaming tests.
    """
    tok = MagicMock()
    default_vocab = {
        THINK_START_TOKEN: 100,
        THINK_END_TOKEN: 101,
    }
    tok.get_vocab.return_value = vocab or default_vocab
    return tok


def _make_mock_request():
    """Create a mock ChatCompletionRequest."""
    req = MagicMock()
    req.model = 'test-model'
    return req


def _simple_tokenize(text, vocab):
    """Tokenise *text* into a list of integer token IDs.

    Known special tokens are mapped via *vocab*; every other character gets a deterministic ID > 200 so it never
    collides with specials.
    """
    ids = []
    i = 0
    # Longest-match first so '</think>' wins over any shorter special.
    sorted_tokens = sorted(vocab.keys(), key=len, reverse=True)
    while i < len(text):
        matched = False
        for tok_str in sorted_tokens:
            if text[i:i + len(tok_str)] == tok_str:
                ids.append(vocab[tok_str])
                i += len(tok_str)
                matched = True
                break
        if not matched:
            ids.append(200 + ord(text[i]))
            i += 1
    return ids


def _run_streaming_extraction(parser, deltas, vocab):
    """Simulate streaming through *parser* with a list of text *deltas*.

    Returns ``(reasoning_content, content)`` aggregated from all deltas.
    """
    reasoning_parts = []
    content_parts = []
    previous_text = ''
    previous_token_ids = []

    for d in deltas:
        current_text = previous_text + d
        current_token_ids = _simple_tokenize(current_text, vocab)
        delta_token_ids = _simple_tokenize(d, vocab)

        result = parser.extract_reasoning_content_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=d,
            previous_token_ids=previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
        )

        if result is not None:
            rc = getattr(result, 'reasoning_content', None)
            ct = getattr(result, 'content', None)
            if rc:
                reasoning_parts.append(rc)
            if ct:
                content_parts.append(ct)

        previous_text = current_text
        previous_token_ids = current_token_ids

    reasoning = ''.join(reasoning_parts) if reasoning_parts else None
    content = ''.join(content_parts) if content_parts else None
    return reasoning, content


# ---------------------------------------------------------------------------
# Shared constants for parser unit tests
# ---------------------------------------------------------------------------

# FIX: restored think-token keys (the tag text had been stripped, leaving
# duplicate empty-string keys).
_DEFAULT_VOCAB = {THINK_START_TOKEN: 100, THINK_END_TOKEN: 101}

_BOTH_PARSERS = pytest.mark.parametrize('parser_factory', [
    pytest.param(_get_deepseek_parser_cls, id='deepseek'),
    pytest.param(_get_qwen_parser_cls, id='qwen'),
])

# ---------------------------------------------------------------------------
# Marks
# ---------------------------------------------------------------------------

_CLASS_MARKS = [
    pytest.mark.order(9),
    pytest.mark.reasoning,
    pytest.mark.deepseek_r1_parser,
    pytest.mark.qwenqwq_parser,
    pytest.mark.flaky(reruns=2),
    pytest.mark.parametrize('backend', BACKEND_LIST),
    pytest.mark.parametrize('model_case', TOOL_REASONING_MODEL_LIST),
]

_CLASS_MARKS_STREAM = _CLASS_MARKS + [
    pytest.mark.parametrize('stream', [False, True], ids=['nonstream', 'stream']),
]

_PARSER_MARKS = [
    pytest.mark.order(9),
    pytest.mark.reasoning,
    pytest.mark.reasoning_parser,
]


def _apply_marks(cls):
    """Apply the shared API-level marks to *cls* (no stream parametrize)."""
    for m in _CLASS_MARKS:
        cls = m(cls)
    return cls


def _apply_marks_stream(cls):
    """Apply API-level marks WITH stream parametrize to *cls*."""
    for m in _CLASS_MARKS_STREAM:
        cls = m(cls)
    return cls


def _apply_parser_marks(cls):
    """Apply lightweight marks to parser unit-test classes (no parametrize)."""
    for m in _PARSER_MARKS:
        cls = m(cls)
    return cls


# ---------------------------------------------------------------------------
# Shared assertion helpers
# ---------------------------------------------------------------------------


def _assert_no_tag_leakage(reasoning, content):
    """Assert that <think>/</think> tags do not appear in reasoning or
    content."""
    # FIX: the failure messages previously lost the tag text; name the
    # offending token explicitly so failures are self-describing.
    for label, text in [('reasoning', reasoning), ('content', content)]:
        assert THINK_START_TOKEN not in text, (f'{THINK_START_TOKEN} leaked into {label}: {text[:100]}')
        assert THINK_END_TOKEN not in text, (f'{THINK_END_TOKEN} leaked into {label}: {text[:100]}')


# ---------------------------------------------------------------------------
# Logging helpers – uses shared StreamTee / setup_log_file / make_logged_client
# from utils.tool_reasoning_definitions.
# ---------------------------------------------------------------------------


class _ReasoningTestBase:
    """Mixin providing per-test API request/response logging and a unified
    ``_call_api`` helper covering both streaming and non-streaming modes."""

    @pytest.fixture(autouse=True)
    def _setup_logging(self, request, config, backend, model_case):
        """Create the log directory and compute the log-file path."""
        self._log_file = setup_log_file(config, request.node.name, 'reasoning')

    def _get_client(self):
        """Return *(client, model_name)* with transparent logging."""
        return make_logged_client(self._log_file)

    def _call_api(self, stream, messages, **create_kwargs):
        """Unified API call for both streaming and non-streaming.

        Returns a dict with keys:
            reasoning, content, finish_reason, role, tool_calls (list of dicts),
            chunk_count, reasoning_chunks, content_chunks,
            finish_reason_count, role_count,
            _response (non-stream only), _choice (non-stream only).
        """
        api, model_name = self._get_client()
        create_kwargs.setdefault('temperature', 0)
        create_kwargs.setdefault('max_completion_tokens', 1024)
        create_kwargs.setdefault('logprobs', False)
        body = create_kwargs.pop('extra_body', {})
        body.setdefault('enable_thinking', True)
        create_kwargs['extra_body'] = body

        if not stream:
            response = api.chat.completions.create(model=model_name, messages=messages, **create_kwargs)
            choice = response.choices[0]
            calls = []
            if choice.message.tool_calls:
                calls = [{
                    'name': call.function.name,
                    'args_str': call.function.arguments,
                    'id': call.id,
                } for call in choice.message.tool_calls]
            return {
                'reasoning': get_reasoning_content(choice.message) or '',
                'content': choice.message.content or '',
                'finish_reason': choice.finish_reason,
                'role': choice.message.role,
                'tool_calls': calls,
                '_response': response,
                '_choice': choice,
            }

        response = api.chat.completions.create(model=model_name, messages=messages, stream=True, **create_kwargs)
        agg = collect_stream_reasoning(response)
        calls = [agg['tool_calls'][idx] for idx in sorted(agg['tool_calls'].keys())]
        return {
            'reasoning': agg['reasoning_content'] or '',
            'content': agg['content'] or '',
            'finish_reason': agg['finish_reason'],
            'role': agg.get('role'),
            'tool_calls': calls,
            'chunk_count': agg.get('chunk_count', 0),
            'reasoning_chunks': agg.get('reasoning_chunks', 0),
            'content_chunks': agg.get('content_chunks', 0),
            'finish_reason_count': agg.get('finish_reason_count', 0),
            'role_count': agg.get('role_count', 0),
        }


# ---------------------------------------------------------------------------
# Message constants
# ---------------------------------------------------------------------------

MESSAGES_REASONING_BASIC = [
    {
        'role': 'system',
        'content': 'You are a helpful assistant. Think step by step before answering.',
    },
    {
        'role': 'user',
        'content': 'What is 37 * 43? Explain your reasoning.',
    },
]

MESSAGES_REASONING_COMPLEX = [
    {
        'role': 'system',
        'content': 'You are a math tutor. Always think through the problem '
                   'step by step before providing the final answer.',
    },
    {
        'role': 'user',
        'content': 'A train leaves station A at 60 km/h and another train '
                   'leaves station B at 80 km/h. If the stations are 280 km '
                   'apart and trains leave at the same time heading towards '
                   'each other, when will they meet?',
    },
]

MESSAGES_REASONING_SIMPLE = [
    {
        'role': 'user',
        'content': 'What is 2 + 2?',
    },
]

MESSAGES_REASONING_WEATHER_TOOL = [
    {
        'role': 'system',
        'content': 'You are a helpful assistant that can use tools. '
                   'Think step by step before deciding whether to use a tool.',
    },
    {
        'role': 'user',
        'content': "I'm traveling to Dallas, TX tomorrow. Should I pack "
                   'an umbrella? Check the weather first.',
    },
]

MESSAGES_REASONING_CN = [
    {
        'role': 'system',
        'content': '你是一个有用的助手。请逐步思考后再回答问题。',
    },
    {
        'role': 'user',
        'content': '一辆火车从A站以60公里/小时出发,另一辆从B站以80公里/小时出发。'
                   '如果两站相距280公里,两车相向而行,何时相遇?',
    },
]

MESSAGES_REASONING_MULTI_TURN = [
    {
        'role': 'system',
        'content': 'You are a math tutor. Think step by step.',
    },
    {
        'role': 'user',
        'content': 'What is the sum of the first 10 natural numbers?',
    },
    {
        'role': 'assistant',
        'content': 'The sum of 1 to 10 is 55.',
    },
    {
        'role': 'user',
        'content': 'Now what is the sum of the first 100 natural numbers? '
                   'Explain your reasoning.',
    },
]

MESSAGES_REASONING_PARALLEL_TOOLS = [
    {
        'role': 'system',
        'content': 'You are a helpful assistant. Think about what tools to '
                   'use, then call them. You can call multiple tools at once.',
    },
    {
        'role': 'user',
        'content': "What's the weather in Dallas, TX? Also calculate 37 * 43.",
    },
]

MESSAGES_REASONING_SEARCH_TOOL = [
    {
        'role': 'system',
        'content': 'You are a helpful assistant that can use tools. '
                   'Think step by step before deciding whether to use a tool.',
    },
    {
        'role': 'user',
        'content': 'Search for the latest advances in quantum computing '
                   'in 2025. Summarize the key breakthroughs.',
    },
]


def _build_search_roundtrip_messages(tool_call_id='call_search_001'):
    """Build multi-turn messages for web_search round-trip."""
    tool_payload = json.dumps({
        'title': '2024 Nobel Prize in Physics',
        'snippet': 'The 2024 Nobel Prize in Physics was awarded to '
                   'John Hopfield and Geoffrey Hinton for foundational '
                   'discoveries in machine learning with artificial '
                   'neural networks.',
        'url': 'https://www.nobelprize.org/prizes/physics/2024/',
    })
    return [
        {
            'role': 'system',
            'content': 'You are a helpful assistant that can use tools. '
                       'Think through problems step by step.',
        },
        {
            'role': 'user',
            'content': 'Search for who won the 2024 Nobel Prize in Physics.',
        },
        {
            'role': 'assistant',
            'content': 'Let me search for the 2024 Nobel Prize in Physics winner.',
            'tool_calls': [{
                'id': tool_call_id,
                'type': 'function',
                'function': {
                    'name': 'web_search',
                    'arguments': '{"query": "2024 Nobel Prize in Physics winner"}',
                },
            }],
        },
        {
            'role': 'tool',
            'tool_call_id': tool_call_id,
            'content': tool_payload,
        },
    ]


# ===========================================================================
# Parser unit-test classes (no API calls, no logging needed)
# ===========================================================================


@_apply_parser_marks
class TestReasoningParserManager:
    """Verify that all parsers are correctly registered."""

    def test_all_parser_names_registered(self):
        mgr = _get_parser_manager()
        for name in REASONING_PARSER_NAMES:
            assert mgr.get(name) is not None, (f'Parser "{name}" not found in ReasoningParserManager')

    @pytest.mark.parametrize('name,expected_cls_name', [
        ('deepseek-r1', 'DeepSeekR1ReasoningParser'),
        ('qwen-qwq', 'QwenQwQReasoningParser'),
        ('intern-s1', 'QwenQwQReasoningParser'),
    ])
    def test_specific_parser_class(self, name, expected_cls_name):
        registered = _get_parser_manager().get(name)
        assert registered is not None
        assert registered.__name__ == expected_cls_name

    def test_unknown_parser_returns_none(self):
        assert _get_parser_manager().get('unknown-parser-xyz') is None
@_apply_parser_marks
@pytest.mark.deepseek_r1_parser
class TestDeepSeekR1ParserNonStreaming:
    """Unit tests for DeepSeekR1ReasoningParser.extract_reasoning_content.

    NOTE(review): the <think>/</think> markers in the fixture strings below
    were reconstructed from the surrounding assertions (the tag text had been
    stripped from this copy of the file) — confirm against the upstream tests.
    """

    @pytest.fixture(autouse=True)
    def _setup(self):
        self.tok = _make_mock_tokenizer()
        self.parser = _get_deepseek_parser_cls()(self.tok)
        self.req = _make_mock_request()

    def test_full_think_tags(self):
        model_output = '<think>Let me think step by step...</think>The answer is 42.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert 'step by step' in reasoning
        assert final is not None
        assert '42' in final

    def test_missing_start_tag(self):
        # DeepSeek-R1 may omit the opening tag and emit only </think>.
        model_output = 'I need to reason about this...</think>The answer is 7.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert final is not None
        assert '7' in final

    def test_no_think_tags_at_all(self):
        model_output = 'Just a plain answer without reasoning.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning == model_output
        assert final is None

    def test_empty_reasoning(self):
        model_output = '<think></think>Direct answer.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert final is not None
        assert 'Direct answer' in final

    def test_only_think_tags(self):
        model_output = '<think>All reasoning, no final answer.</think>'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert 'reasoning' in reasoning.lower()
        assert final is None

    def test_multiline_reasoning(self):
        model_output = ('<think>\nStep 1: calculate 37 * 40 = 1480\n'
                        'Step 2: calculate 37 * 3 = 111\n'
                        'Step 3: 1480 + 111 = 1591\n</think>The answer is 1591.')
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert 'Step 1' in reasoning
        assert 'Step 3' in reasoning
        assert final is not None
        assert '1591' in final

    def test_unclosed_think_tag(self):
        # NOTE(review): without a closing tag the parser returns the whole
        # output verbatim as reasoning; reconstructed with the opening tag —
        # confirm the parser does not strip <think> in this branch.
        model_output = '<think>I am still reasoning about this problem...'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning == model_output
        assert final is None

    def test_empty_model_output(self):
        reasoning, final = self.parser.extract_reasoning_content('', self.req)
        assert reasoning == ''
        assert final is None

    def test_multiple_think_blocks(self):
        # Splitting happens at the FIRST </think>; later blocks stay in content.
        model_output = ('<think>first reasoning</think>middle text'
                        '<think>second reasoning</think>final text')
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert 'first reasoning' in reasoning
        assert final is not None
        assert 'middle text' in final
        assert 'second reasoning' in final

    def test_newlines_between_think_and_content(self):
        model_output = '<think>Step 1: 37*43=1591</think>\n\nThe answer is 1591.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert '1591' in reasoning
        assert final is not None
        assert final.startswith('\n\n'), (f'DeepSeek parser should preserve leading \\n\\n in content, '
                                          f'got: {repr(final[:20])}')

    def test_single_newline_between_think_and_content(self):
        model_output = '<think>reasoning</think>\nThe answer.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning == 'reasoning'
        assert final is not None
        assert final.startswith('\n')


@_apply_parser_marks
@pytest.mark.deepseek_r1_parser
class TestDeepSeekR1ParserStreaming:
    """Unit tests for DeepSeekR1ReasoningParser streaming extraction.

    NOTE(review): the tag text in the previous/current/delta strings was
    reconstructed by aligning with the token ids (100 = <think>,
    101 = </think>) — confirm against the upstream tests.
    """

    @pytest.fixture(autouse=True)
    def _setup(self):
        self.tok = _make_mock_tokenizer(_DEFAULT_VOCAB)
        self.parser = _get_deepseek_parser_cls()(self.tok)

    def test_think_start_token_only(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='',
                                                                 current_text='<think>',
                                                                 delta_text='<think>',
                                                                 previous_token_ids=[],
                                                                 current_token_ids=[100],
                                                                 delta_token_ids=[100])
        assert result is None

    def test_think_end_token_only(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>Some reasoning',
                                                                 current_text='<think>Some reasoning</think>',
                                                                 delta_text='</think>',
                                                                 previous_token_ids=[100, 200, 201],
                                                                 current_token_ids=[100, 200, 201, 101],
                                                                 delta_token_ids=[101])
        assert result is not None
        assert result.content == ''

    def test_reasoning_continues(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>Step 1',
                                                                 current_text='<think>Step 1: multiply',
                                                                 delta_text=': multiply',
                                                                 previous_token_ids=[100, 200],
                                                                 current_token_ids=[100, 200, 201],
                                                                 delta_token_ids=[201])
        assert result is not None
        assert result.reasoning_content == ': multiply'

    def test_reasoning_ends_in_delta(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>Step 1',
                                                                 current_text='<think>Step 1. Done</think>Answer',
                                                                 delta_text='. Done</think>Answer',
                                                                 previous_token_ids=[100, 200],
                                                                 current_token_ids=[100, 200, 201, 101, 300],
                                                                 delta_token_ids=[201, 101, 300])
        assert result is not None
        assert '. Done' in result.reasoning_content
        assert result.content == 'Answer'

    def test_content_after_think_end(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>Reasoning</think>Prev',
                                                                 current_text='<think>Reasoning</think>Prev content',
                                                                 delta_text=' content',
                                                                 previous_token_ids=[100, 200, 101, 300],
                                                                 current_token_ids=[100, 200, 101, 300, 301],
                                                                 delta_token_ids=[301])
        assert result is not None
        assert result.content == ' content'

    def test_no_think_in_previous_with_end_in_delta(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='Some reasoning text',
                                                                 current_text='Some reasoning text</think>Final',
                                                                 delta_text='</think>Final',
                                                                 previous_token_ids=[200, 201],
                                                                 current_token_ids=[200, 201, 101, 300],
                                                                 delta_token_ids=[101, 300])
        assert result is not None
        assert result.content == 'Final'

    def test_no_think_tags_at_all_streaming(self):
        # DeepSeek treats tag-less streaming output as reasoning.
        result = self.parser.extract_reasoning_content_streaming(previous_text='',
                                                                 current_text='Just a plain answer.',
                                                                 delta_text='Just a plain answer.',
                                                                 previous_token_ids=[],
                                                                 current_token_ids=[200, 201, 202],
                                                                 delta_token_ids=[200, 201, 202])
        assert result is not None
        assert result.reasoning_content == 'Just a plain answer.'
        assert result.content is None

    def test_empty_reasoning_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='',
                                                                 current_text='<think></think>Direct answer.',
                                                                 delta_text='<think></think>Direct answer.',
                                                                 previous_token_ids=[],
                                                                 current_token_ids=[100, 101, 200, 201],
                                                                 delta_token_ids=[100, 101, 200, 201])
        assert result is not None
        assert result.reasoning_content is not None
        assert result.content is not None
        assert 'Direct answer' in result.content

    def test_only_think_tags_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>All reasoning here.',
                                                                 current_text='<think>All reasoning here.</think>',
                                                                 delta_text='</think>',
                                                                 previous_token_ids=[100, 200, 201],
                                                                 current_token_ids=[100, 200, 201, 101],
                                                                 delta_token_ids=[101])
        assert result is not None
        assert result.content == ''

    def test_multiline_reasoning_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text='<think>Step 1: calculate\n',
            current_text='<think>Step 1: calculate\nStep 2: add\nStep 3: done',
            delta_text='Step 2: add\nStep 3: done',
            previous_token_ids=[100, 200, 201],
            current_token_ids=[100, 200, 201, 202, 203, 204],
            delta_token_ids=[202, 203, 204])
        assert result is not None
        assert result.reasoning_content == 'Step 2: add\nStep 3: done'

    def test_unclosed_think_tag_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text='<think>I am still reasoning',
            current_text='<think>I am still reasoning about this...',
            delta_text=' about this...',
            previous_token_ids=[100, 200, 201],
            current_token_ids=[100, 200, 201, 202, 203],
            delta_token_ids=[202, 203])
        assert result is not None
        assert result.reasoning_content == ' about this...'
        assert result.content is None

    def test_empty_model_output_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='',
                                                                 current_text='',
                                                                 delta_text='',
                                                                 previous_token_ids=[],
                                                                 current_token_ids=[],
                                                                 delta_token_ids=[])
        if result is not None:
            assert (result.reasoning_content or '') == ''
            assert (result.content or '') == ''

    def test_multiple_think_blocks_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text='<think>first reasoning</think>middle text',
            current_text=('<think>first reasoning</think>middle text'
                          '<think>second reasoning</think>final'),
            delta_text='<think>second reasoning</think>final',
            previous_token_ids=[100, 200, 101, 300],
            current_token_ids=[100, 200, 101, 300, 100, 201, 101, 301],
            delta_token_ids=[100, 201, 101, 301])
        assert result is not None
        assert result.content is not None
        assert 'second reasoning' in result.content

    def test_newlines_between_think_and_content_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text='<think>Step 1: reasoning',
            current_text='<think>Step 1: reasoning. Done</think>\n\nThe answer is 1591.',
            delta_text='. Done</think>\n\nThe answer is 1591.',
            previous_token_ids=[100, 200],
            current_token_ids=[100, 200, 201, 101, 300, 301],
            delta_token_ids=[201, 101, 300, 301])
        assert result is not None
        assert '. Done' in result.reasoning_content
        assert result.content is not None
        assert result.content.startswith('\n\n')
        assert '1591' in result.content

    def test_newline_only_content_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(previous_text='<think>reasoning',
                                                                 current_text='<think>reasoning</think>\n\n',
                                                                 delta_text='</think>\n\n',
                                                                 previous_token_ids=[100, 200],
                                                                 current_token_ids=[100, 200, 101, 300],
                                                                 delta_token_ids=[101, 300])
        assert result is not None
        assert result.content == '\n\n'


@_apply_parser_marks
@pytest.mark.qwenqwq_parser
class TestQwenQwQParserNonStreaming:
    """Unit tests for QwenQwQReasoningParser.extract_reasoning_content.

    NOTE(review): <think>/</think> markers reconstructed from the assertions —
    confirm against the upstream tests.  (Class continues below.)
    """

    @pytest.fixture(autouse=True)
    def _setup(self):
        self.tok = _make_mock_tokenizer()
        self.parser = _get_qwen_parser_cls()(self.tok)
        self.req = _make_mock_request()

    def test_full_think_tags(self):
        model_output = '<think>Let me reason about this...</think>The answer is 42.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert 'reason' in reasoning.lower()
        assert final is not None
        assert '42' in final

    def test_missing_start_tag(self):
        model_output = 'Reasoning here...</think>Final answer.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert final is not None
        assert 'Final answer' in final

    def test_newline_stripping(self):
        # Qwen strips the newlines that pad the think block.
        model_output = '<think>\nStep 1: think\nStep 2: conclude\n</think>Done.'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert not reasoning.startswith('\n')
        assert not reasoning.endswith('\n')
        assert final is not None

    def test_only_reasoning_no_content(self):
        model_output = '<think>Only reasoning content here.</think>'
        reasoning, final = self.parser.extract_reasoning_content(model_output, self.req)
        assert reasoning is not None
        assert final is None
+ reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert final is None + + def test_empty_reasoning_tags(self): + model_output = 'Direct answer.' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert final is not None + assert 'Direct answer' in final + + def test_multiple_think_blocks(self): + model_output = ('first reasoningmiddle text' + 'second reasoningfinal text') + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert 'first reasoning' in reasoning + assert final is not None + assert 'middle text' in final + assert 'second reasoning' in final + + def test_newlines_between_think_and_content(self): + model_output = '\nStep 1: 37*43=1591\n\n\nThe answer is 1591.' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert not reasoning.startswith('\n') + assert not reasoning.endswith('\n') + assert '1591' in reasoning + assert final is not None + assert final.startswith('\n\n') + + def test_single_newline_between_think_and_content(self): + model_output = '\nreasoning\n\nThe answer.' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning == 'reasoning' + assert final is not None + assert final.startswith('\n') + + def test_unclosed_think_should_return_reasoning(self): + """ without → reasoning should be preserved.""" + model_output = 'I am still reasoning about this problem...' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None, ('reasoning should not be None when is present but unclosed') + assert reasoning == 'I am still reasoning about this problem...' 
+ assert final is None + + def test_no_tags_all_reasoning(self): + """No think tags at all → everything is reasoning.""" + model_output = 'Plain output without any think tokens.' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert reasoning == model_output + assert final is None + + def test_empty_input(self): + """Edge case: empty string → ('', None).""" + reasoning, final = self.parser.extract_reasoning_content('', self.req) + assert final is None + + def test_prefix_before_think_tags(self): + """Text before causes wrong content extraction.""" + model_output = 'Some prefixreasoning hereThe answer.' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert 'reasoning here' in reasoning + assert final is not None + assert final == 'The answer.' + assert '<' not in final and '>' not in final + + def test_prefix_only_whitespace_before_think(self): + """Whitespace-only prefix before .""" + model_output = ' reasoninganswer' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert final is not None + assert final == 'answer' + + def test_unclosed_think_with_newlines(self): + r"""\nreasoning\n without .""" + model_output = '\nI am reasoning step by step\n' + reasoning, final = self.parser.extract_reasoning_content(model_output, self.req) + assert reasoning is not None + assert 'reasoning step by step' in reasoning + assert final is None + + +@_apply_parser_marks +@pytest.mark.qwenqwq_parser +class TestQwenQwQParserStreaming: + """Unit tests for QwenQwQReasoningParser streaming extraction.""" + + @pytest.fixture(autouse=True) + def _setup(self): + self.tok = _make_mock_tokenizer(_DEFAULT_VOCAB) + self.parser = _get_qwen_parser_cls()(self.tok) + + def test_think_start_text_skipped(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='', + 
current_text='', + delta_text='', + previous_token_ids=[], + current_token_ids=[100], + delta_token_ids=[100]) + assert result is not None + assert result.content == '' + + def test_think_end_text_skipped(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='Some text', + current_text='Some text', + delta_text='', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 101], + delta_token_ids=[101]) + assert result is not None + assert result.content == '' + + def test_reasoning_continues(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='Step 1', + current_text='Step 1, Step 2', + delta_text=', Step 2', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 201], + delta_token_ids=[201]) + assert result is not None + assert result.reasoning_content == ', Step 2' + + def test_reasoning_ends_in_delta(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='Step 1', + current_text='Step 1. FinalResult', + delta_text='. FinalResult', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 201, 101, 300], + delta_token_ids=[201, 101, 300]) + assert result is not None + assert '. 
Final' in (result.reasoning_content or '') + assert result.content == 'Result' + + def test_content_after_think_closed(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='ReasoningPrev', + current_text='ReasoningPrev more', + delta_text=' more', + previous_token_ids=[100, 200, 101, 300], + current_token_ids=[100, 200, 101, 300, 301], + delta_token_ids=[301]) + assert result is not None + assert result.content == ' more' + + def test_no_think_tags_streaming(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='', + current_text='Plain answer without reasoning.', + delta_text='Plain answer without reasoning.', + previous_token_ids=[], + current_token_ids=[200, 201, 202], + delta_token_ids=[200, 201, 202]) + assert result is not None + assert result.content == 'Plain answer without reasoning.' + assert result.reasoning_content is None + + def test_missing_start_tag_streaming(self): + result = self.parser.extract_reasoning_content_streaming( + previous_text='Reasoning without start tag', + current_text='Reasoning without start tagFinal answer.', + delta_text='Final answer.', + previous_token_ids=[200, 201], + current_token_ids=[200, 201, 101, 300], + delta_token_ids=[101, 300]) + assert result is not None + assert result.content == 'Final answer.' 
+ + def test_only_reasoning_no_content_streaming(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='Only reasoning content.', + current_text='Only reasoning content.', + delta_text='', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 101], + delta_token_ids=[101]) + assert result is not None + assert result.content == '' + + def test_empty_reasoning_tags_streaming(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='', + current_text='Direct answer.', + delta_text='Direct answer.', + previous_token_ids=[], + current_token_ids=[100, 101, 200], + delta_token_ids=[100, 101, 200]) + assert result is not None + assert result.content is not None + assert 'Direct answer' in result.content + + def test_newline_in_reasoning_streaming(self): + result = self.parser.extract_reasoning_content_streaming( + previous_text='\nStep 1: think', + current_text='\nStep 1: think\nStep 2: conclude\n', + delta_text='\nStep 2: conclude\n', + previous_token_ids=[100, 200, 201], + current_token_ids=[100, 200, 201, 202, 203], + delta_token_ids=[202, 203]) + assert result is not None + assert result.reasoning_content == '\nStep 2: conclude\n' + + def test_unclosed_think_tag_streaming(self): + result = self.parser.extract_reasoning_content_streaming( + previous_text='I am still reasoning', + current_text='I am still reasoning about this...', + delta_text=' about this...', + previous_token_ids=[100, 200, 201], + current_token_ids=[100, 200, 201, 202, 203], + delta_token_ids=[202, 203]) + assert result is not None + assert result.reasoning_content == ' about this...' 
+ assert result.content is None + + def test_empty_model_output_streaming(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='', + current_text='', + delta_text='', + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[]) + if result is not None: + assert (result.reasoning_content or '') == '' + assert (result.content or '') == '' + + def test_multiple_think_blocks_streaming(self): + result = self.parser.extract_reasoning_content_streaming( + previous_text='first reasoningmiddle text', + current_text=('first reasoningmiddle text' + 'second reasoningfinal'), + delta_text='second reasoningfinal', + previous_token_ids=[100, 200, 101, 300], + current_token_ids=[100, 200, 101, 300, 100, 201, 101, 301], + delta_token_ids=[100, 201, 101, 301]) + assert result is not None + assert result.content is not None + assert 'second reasoning' in result.content + + def test_newlines_between_think_and_content_streaming(self): + result = self.parser.extract_reasoning_content_streaming( + previous_text='Step 1: reasoning', + current_text='Step 1: reasoning. Done\n\nThe answer is 1591.', + delta_text='. Done\n\nThe answer is 1591.', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 201, 101, 300, 301], + delta_token_ids=[201, 101, 300, 301]) + assert result is not None + assert '. 
Done' in result.reasoning_content + assert result.content is not None + assert result.content.startswith('\n\n') + assert '1591' in result.content + + def test_newline_only_content_streaming(self): + result = self.parser.extract_reasoning_content_streaming(previous_text='reasoning', + current_text='reasoning\n\n', + delta_text='\n\n', + previous_token_ids=[100, 200], + current_token_ids=[100, 200, 101, 300], + delta_token_ids=[101, 300]) + assert result is not None + assert result.content == '\n\n' + + def test_unclosed_think_should_return_reasoning(self): + """ without → reasoning should be accumulated across + all deltas and content should remain None.""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['', 'I am still ', 'reasoning about ', 'this problem...'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None, ('reasoning should not be None when ' + ' is present but unclosed') + assert 'I am still reasoning about this problem...' == reasoning + assert content is None + + def test_no_tags_all_reasoning_streaming(self): + """No think tags at all → everything is reasoning (Qwen treats tag-less + output as content).""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['Plain output ', 'without any ', 'think tokens.'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + full_text = 'Plain output without any think tokens.' 
+ # Qwen streaming without tags emits content (not reasoning) + assert (reasoning or '') + (content or '') == full_text + + def test_prefix_before_think_tags_streaming(self): + """Text before must not corrupt content (streaming).""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['Some prefix', '', 'reasoning here', '', 'The answer.'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None + assert 'reasoning here' in reasoning + assert content is not None + assert 'The answer.' in content + assert '<' not in content and '>' not in content + + def test_prefix_only_whitespace_before_think_streaming(self): + """Whitespace-only prefix before (streaming).""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = [' ', '', 'reasoning', '', 'answer'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None + assert content is not None + assert 'answer' in content + + def test_unclosed_think_with_newlines_streaming(self): + r"""\nreasoning\n without (streaming).""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['', '\n', 'I am reasoning ', 'step by step', '\n'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None, ('reasoning should not be None for unclosed ' + ' with newlines') + assert 'reasoning step by step' in reasoning + assert content is None + + def test_think_tag_embedded_in_delta_should_not_leak(self): + """ embedded in a larger delta must not leak into + reasoning_content.""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + # is part of a larger delta, not a separate token + deltas = ['I am reasoning', ' step by step', '', 'The answer.'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None + assert 
THINK_START_TOKEN not in reasoning, (' tag leaked into reasoning_content') + assert THINK_END_TOKEN not in reasoning, (' tag leaked into reasoning_content') + assert 'I am reasoning step by step' in reasoning + assert content is not None + assert 'The answer.' in content + + def test_think_tag_embedded_with_end_in_same_delta(self): + """..... all in one large delta.""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['short reasoninganswer'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None + assert THINK_START_TOKEN not in reasoning, (' tag leaked into reasoning_content') + assert 'short reasoning' in reasoning + assert content is not None + assert 'answer' in content + + def test_think_tag_embedded_no_end_token(self): + """ embedded in delta, output truncated (no at all) — + tag must not appear in reasoning.""" + parser = _get_qwen_parser_cls()(_make_mock_tokenizer(_DEFAULT_VOCAB)) + deltas = ['I am reasoning about', ' a complex problem...'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None + assert THINK_START_TOKEN not in reasoning, (' tag leaked into reasoning_content ' + 'when output is truncated') + assert 'I am reasoning about a complex problem...' in reasoning + assert content is None + + +# --------------------------------------------------------------------------- +# Parser edge cases — parametrized over both parsers (was separate per-parser) +# --------------------------------------------------------------------------- + + +@_apply_parser_marks +class TestReasoningParserEdgeCases: + """Edge cases that both parsers should handle identically.""" + + @_BOTH_PARSERS + def test_unicode_in_reasoning(self, parser_factory): + """Chinese / Unicode content inside think tags.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = '让我想想... 
37 × 43 = 1591答案是1591。' + reasoning, final = parser.extract_reasoning_content(model_output, req) + assert reasoning is not None + assert '37' in reasoning + assert final is not None + assert '1591' in final + + @_BOTH_PARSERS + def test_very_long_reasoning(self, parser_factory): + """Long reasoning content should be fully extracted.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + long_reasoning = 'Step ' + '. Step '.join(str(i) for i in range(100)) + model_output = f'{long_reasoning}Final.' + reasoning, final = parser.extract_reasoning_content(model_output, req) + assert reasoning is not None + assert len(reasoning) > 100 + assert 'Step 99' in reasoning + assert final == 'Final.' + + @_BOTH_PARSERS + def test_special_chars_in_reasoning(self, parser_factory): + """Special characters inside think tags.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = ('Let\'s check: 2 > 1 & 3 < 5, also "quoted" text' + 'All good.') + reasoning, final = parser.extract_reasoning_content(model_output, req) + assert reasoning is not None + assert '&' in reasoning or '>' in reasoning + assert final is not None + + @pytest.mark.parametrize('parser_name', REASONING_PARSER_NAMES) + def test_parser_instantiation(self, parser_name): + """Every registered parser should be instantiable with a mock + tokenizer.""" + mgr = _get_parser_manager() + cls = mgr.get(parser_name) + assert cls is not None + tok = _make_mock_tokenizer() + parser = cls(tok) + assert parser is not None + assert parser.think_start_token == THINK_START_TOKEN + assert parser.think_end_token == THINK_END_TOKEN + + +# --------------------------------------------------------------------------- +# Parser init / robustness tests +# --------------------------------------------------------------------------- + + +@_apply_parser_marks +class TestReasoningParserInitErrors: + """Initialization error handling for reasoning 
parsers.""" + + def test_deepseek_none_tokenizer(self): + with pytest.raises(ValueError, match='(?i)tokenizer'): + _get_deepseek_parser_cls()(None) + + def test_qwen_none_tokenizer(self): + with pytest.raises(ValueError, match='(?i)tokenizer'): + _get_qwen_parser_cls()(None) + + def test_deepseek_missing_vocab_tokens(self): + tok = _make_mock_tokenizer(vocab={'unrelated': 999}) + with pytest.raises(RuntimeError, match='(?i)think.*token'): + _get_deepseek_parser_cls()(tok) + + def test_deepseek_token_ids_not_none(self): + tok = _make_mock_tokenizer() + parser = _get_deepseek_parser_cls()(tok) + assert parser.think_start_token_id is not None + assert parser.think_end_token_id is not None + assert parser.think_start_token_id == 100 + assert parser.think_end_token_id == 101 + + # NOTE: token_properties for both parsers are covered by + # TestReasoningParserEdgeCases.test_parser_instantiation. + + def test_vocab_property_accessible(self): + tok = _make_mock_tokenizer() + parser = _get_deepseek_parser_cls()(tok) + assert '' in parser.vocab + assert '' in parser.vocab + assert parser.vocab[''] == 100 + + +# --------------------------------------------------------------------------- +# Parser dual-mode tests (already parametrized) +# --------------------------------------------------------------------------- + + +@_apply_parser_marks +class TestReasoningParserDualMode: + """Same extraction scenario verified in both streaming and non- + streaming.""" + + @_BOTH_PARSERS + def test_simple_extraction_both_modes(self, parser_factory): + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'Step by step reasoningFinal answer' + ns_reasoning, ns_content = parser.extract_reasoning_content(model_output, req) + deltas = ['', 'Step by step ', 'reasoning', '', 'Final answer'] + s_reasoning, s_content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert ns_reasoning is not None and len(ns_reasoning) > 0 + assert 
s_reasoning is not None and len(s_reasoning) > 0 + assert ns_content is not None and s_content is not None + assert ns_content.strip() == s_content.strip() + + @_BOTH_PARSERS + def test_no_end_token_both_modes(self, parser_factory): + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'Just reasoning without end token' + ns_reasoning, ns_content = parser.extract_reasoning_content(model_output, req) + deltas = ['Just reasoning ', 'without end token'] + s_reasoning, s_content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert ns_reasoning == model_output + assert ns_content is None + assert s_reasoning is not None + + @_BOTH_PARSERS + def test_empty_output_both_modes(self, parser_factory): + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + ns_reasoning, ns_content = parser.extract_reasoning_content('', req) + s_reasoning, s_content = _run_streaming_extraction(parser, [''], _DEFAULT_VOCAB) + assert isinstance(ns_reasoning, (str, type(None))) + assert isinstance(ns_content, (str, type(None))) + + @_BOTH_PARSERS + def test_incremental_deltas_both_modes(self, parser_factory): + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'ABCD' + ns_reasoning, ns_content = parser.extract_reasoning_content(model_output, req) + deltas = ['', 'A', 'B', '', 'C', 'D'] + s_reasoning, s_content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert ns_content is not None and s_content is not None + assert ns_content.strip() == s_content.strip() + + +# --------------------------------------------------------------------------- +# Advanced edge cases — parametrized over both parsers (was separate) +# --------------------------------------------------------------------------- + + +@_apply_parser_marks +class TestReasoningParserAdvancedEdgeCases: + """Advanced edge cases — parametrized over both 
parsers.""" + + @_BOTH_PARSERS + def test_multiple_end_tokens(self, parser_factory): + """Multiple tokens: should stop at the first one.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'FirstMiddleLast' + reasoning, content = parser.extract_reasoning_content(model_output, req) + assert reasoning is not None and 'First' in reasoning + assert content is not None and 'Middle' in content + + @_BOTH_PARSERS + def test_nested_think_tokens(self, parser_factory): + """Nested inside .""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'OuterInnerContent' + reasoning, content = parser.extract_reasoning_content(model_output, req) + assert reasoning is not None + assert 'Outer' in reasoning and 'Inner' in reasoning + assert content is not None and 'Content' in content + + @_BOTH_PARSERS + def test_malformed_similar_tokens(self, parser_factory): + """Tags like should be treated as plain text.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + req = _make_mock_request() + model_output = 'Not real tagsContent' + reasoning, content = parser.extract_reasoning_content(model_output, req) + assert reasoning == model_output + assert content is None + + @_BOTH_PARSERS + def test_streaming_no_end_token(self, parser_factory): + """Streaming with only start token — reasoning continues.""" + tok = _make_mock_tokenizer() + parser = parser_factory()(tok) + deltas = ['', 'Reasoning ', 'without ', 'end'] + reasoning, content = _run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None and 'Reasoning' in reasoning + assert content is None + + def test_deepseek_streaming_multiple_end_tokens(self): + """Multiple in streaming — DeepSeek specific.""" + tok = _make_mock_tokenizer() + parser = _get_deepseek_parser_cls()(tok) + deltas = ['', 'First', '', 'Middle', '', 'Last'] + reasoning, content = 
_run_streaming_extraction(parser, deltas, _DEFAULT_VOCAB) + assert reasoning is not None and 'First' in reasoning + assert content is not None and 'Middle' in content + + +# --------------------------------------------------------------------------- +# Parser independence tests +# --------------------------------------------------------------------------- + + +@_apply_parser_marks +class TestReasoningParserIndependence: + """Verify that different parser instances don't share state.""" + + def test_parsers_do_not_share_state(self): + tok = _make_mock_tokenizer() + req = _make_mock_request() + deepseek_parser = _get_deepseek_parser_cls()(tok) + qwen_parser = _get_qwen_parser_cls()(tok) + model_output = 'Shared reasoningShared content' + ds_reasoning, ds_content = deepseek_parser.extract_reasoning_content(model_output, req) + qw_reasoning, qw_content = qwen_parser.extract_reasoning_content(model_output, req) + assert ds_reasoning is not None and 'Shared reasoning' in ds_reasoning + assert qw_reasoning is not None and 'Shared reasoning' in qw_reasoning + assert ds_content is not None and 'Shared content' in ds_content + assert qw_content is not None and 'Shared content' in qw_content + + def test_parsers_independent_streaming(self): + tok = _make_mock_tokenizer() + deepseek_parser = _get_deepseek_parser_cls()(tok) + qwen_parser = _get_qwen_parser_cls()(tok) + deltas = ['', 'Step 1 ', 'Step 2', '', 'Answer'] + ds_reasoning, ds_content = _run_streaming_extraction(deepseek_parser, deltas, _DEFAULT_VOCAB) + qw_reasoning, qw_content = _run_streaming_extraction(qwen_parser, deltas, _DEFAULT_VOCAB) + assert ds_reasoning is not None and qw_reasoning is not None + assert ds_content is not None and qw_content is not None + assert ds_content.strip() == qw_content.strip() + + def test_multiple_instances_same_parser(self): + tok1 = _make_mock_tokenizer() + tok2 = _make_mock_tokenizer() + req = _make_mock_request() + parser1 = _get_deepseek_parser_cls()(tok1) + parser2 = 
_get_deepseek_parser_cls()(tok2) + r1, c1 = parser1.extract_reasoning_content('Reasoning AContent A', req) + r2, c2 = parser2.extract_reasoning_content('Reasoning BContent B', req) + assert r1 is not None and 'A' in r1 + assert r2 is not None and 'B' in r2 + assert c1 is not None and 'A' in c1 + assert c2 is not None and 'B' in c2 + r1_again, c1_again = parser1.extract_reasoning_content('Reasoning AContent A', req) + assert r1_again == r1 and c1_again == c1 + + +# =========================================================================== +# API-level test classes (parametrized by backend × model × stream) +# =========================================================================== + +# --------------------------------------------------------------------------- +# Basic reasoning: presence, quality, separation +# (merged from TestReasoningBasic, TestReasoningStreaming, +# TestReasoningContentQuality, TestReasoningContentQualityStreaming) +# --------------------------------------------------------------------------- + + +@_apply_marks_stream +class TestReasoningBasic(_ReasoningTestBase): + """Basic reasoning_content presence, quality, and content separation.""" + + def test_reasoning_content_present(self, backend, model_case, stream): + """Model should populate reasoning_content for math questions.""" + r = self._call_api(stream, MESSAGES_REASONING_BASIC) + assert r['finish_reason'] in ('stop', 'length') + assert len(r['reasoning']) > 10, (f'reasoning too short ({len(r["reasoning"])} chars)') + assert any(kw in r['reasoning'] for kw in ('37', '43', '1591', 'multiply', '*', '×')) + assert len(r['content'].strip()) > 0 + assert '1591' in r['content'] + assert r['reasoning'].strip() != r['content'].strip() + _assert_no_tag_leakage(r['reasoning'], r['content']) + + def test_reasoning_quality_complex(self, backend, model_case, stream): + """Complex train problem: reasoning should contain calculation steps.""" + r = self._call_api(stream, MESSAGES_REASONING_COMPLEX, 
max_completion_tokens=2048) + assert len(r['reasoning']) > 50 + assert any(kw in r['reasoning'] for kw in ('60', '80', '140', '280')) + assert len(r['content'].strip()) > 0 + assert '2' in r['content'] + _assert_no_tag_leakage(r['reasoning'], r['content']) + if stream: + assert r['reasoning_chunks'] > 1 + + +# --------------------------------------------------------------------------- +# Streaming ↔ Non-streaming consistency (cross-mode comparison) +# --------------------------------------------------------------------------- + + +@_apply_marks +class TestReasoningStreamConsistency(_ReasoningTestBase): + """Both modes must produce reasoning AND content with correct + separation.""" + + def test_reasoning_presence_consistent(self, backend, model_case): + client, model_name = self._get_client() + ns_resp = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_BASIC, + temperature=0, + max_completion_tokens=1024, + logprobs=False) + ns_reasoning = get_reasoning_content(ns_resp.choices[0].message) + ns_content = ns_resp.choices[0].message.content or '' + + stream = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_BASIC, + temperature=0, + max_completion_tokens=1024, + logprobs=False, + stream=True) + result = collect_stream_reasoning(stream) + + assert ns_reasoning is not None and len(ns_reasoning) > 0 + assert len(result['reasoning_content']) > 0 + assert len(ns_content.strip()) > 0 + assert len(result['content'].strip()) > 0 + assert '1591' in ns_content + assert '1591' in result['content'] + for text in [ns_reasoning, ns_content, result['reasoning_content'], result['content']]: + assert THINK_START_TOKEN not in text + assert THINK_END_TOKEN not in text + + +# --------------------------------------------------------------------------- +# Tool calls + tool_choice (merged from TestReasoningWithToolCalls, +# TestReasoningWithToolCallsStreaming, TestReasoningWithToolChoice, +# TestReasoningWithToolChoiceStreaming) +# 
--------------------------------------------------------------------------- + + +@_apply_marks_stream +class TestReasoningWithTools(_ReasoningTestBase): + """Reasoning with tool calls under different tool_choice settings.""" + + def test_tool_choice_auto(self, backend, model_case, stream): + """tool_choice='auto': weather question should trigger weather tool.""" + r = self._call_api(stream, + MESSAGES_REASONING_WEATHER_TOOL, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice='auto') + if len(r['tool_calls']) > 0: + assert r['finish_reason'] == 'tool_calls' + for tc in r['tool_calls']: + assert tc['name'] in ('get_current_weather', 'web_search') + parsed = json.loads(tc['args_str']) + assert isinstance(parsed, dict) + else: + assert len(r['content'].strip()) > 0 + + def test_tool_choice_required(self, backend, model_case, stream): + """tool_choice='required': must produce tool call.""" + try: + r = self._call_api(stream, MESSAGES_REASONING_WEATHER_TOOL, tools=[WEATHER_TOOL], tool_choice='required') + assert len(r['tool_calls']) >= 1 + assert r['finish_reason'] == 'tool_calls' + tc = r['tool_calls'][0] + assert tc['name'] == 'get_current_weather' + parsed = json.loads(tc['args_str']) + assert 'city' in parsed + except Exception as e: + pytest.skip(f'tool_choice="required" not supported: {e}') + + def test_tool_choice_none(self, backend, model_case, stream): + """tool_choice='none': no tool calls, text answer instead.""" + r = self._call_api(stream, MESSAGES_REASONING_WEATHER_TOOL, tools=[WEATHER_TOOL], tool_choice='none') + assert len(r['tool_calls']) == 0 + assert r['finish_reason'] in ('stop', 'length') + assert len(r['reasoning'] + r['content']) > 0 + _assert_no_tag_leakage(r['reasoning'], r['content']) + + def test_tool_choice_specific(self, backend, model_case, stream): + """Force get_current_weather: must call exactly that tool.""" + r = self._call_api(stream, + MESSAGES_REASONING_WEATHER_TOOL, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice={ + 'type': 
'function', + 'function': { + 'name': 'get_current_weather' + } + }) + assert r['finish_reason'] == 'tool_calls' + assert len(r['tool_calls']) >= 1 + tc = r['tool_calls'][0] + assert tc['name'] == 'get_current_weather' + parsed = json.loads(tc['args_str']) + assert 'city' in parsed + assert 'dallas' in parsed['city'].lower() + + +# --------------------------------------------------------------------------- +# Parallel tool calls +# --------------------------------------------------------------------------- + + +@_apply_marks_stream +class TestReasoningParallelToolCalls(_ReasoningTestBase): + """Reasoning model calling multiple tools in parallel.""" + + def test_parallel_tools(self, backend, model_case, stream): + r = self._call_api(stream, MESSAGES_REASONING_PARALLEL_TOOLS, tools=[WEATHER_TOOL, CALCULATOR_TOOL]) + assert len(r['tool_calls']) >= 1 + assert r['finish_reason'] == 'tool_calls' + ids = [tc['id'] for tc in r['tool_calls'] if tc.get('id')] + if len(ids) >= 2: + assert len(set(ids)) == len(ids), f'IDs must be unique: {ids}' + for tc in r['tool_calls']: + assert tc['name'] in ('get_current_weather', 'calculate') + parsed = json.loads(tc['args_str']) + assert isinstance(parsed, dict) + + +# --------------------------------------------------------------------------- +# Tool round-trip: reason → tool → result → answer +# --------------------------------------------------------------------------- + + +@_apply_marks_stream +class TestReasoningToolRoundTrip(_ReasoningTestBase): + """Multi-turn: reason → tool → result → reasoning → answer.""" + + def test_after_tool_result(self, backend, model_case, stream): + r = self._call_api(stream, build_reasoning_tool_roundtrip_messages(), tools=[WEATHER_TOOL]) + assert r['finish_reason'] in ('stop', 'length') + assert len(r['tool_calls']) == 0 + assert len(r['reasoning'] + r['content']) > 0 + if r['content']: + has_ref = any(kw in r['content'].lower() + for kw in ('sunny', 'umbrella', 'dallas', 'clear', 'rain', 'weather', 
'no')) + assert has_ref, f'Content should reference weather: {r["content"][:200]}' + _assert_no_tag_leakage(r['reasoning'], r['content']) + + +# --------------------------------------------------------------------------- +# Streaming ↔ Non-streaming tool-call consistency (cross-mode comparison) +# --------------------------------------------------------------------------- + + +@_apply_marks +class TestReasoningToolCallConsistency(_ReasoningTestBase): + """Compare streaming vs non-streaming tool-call results.""" + + def test_tool_call_stream_vs_nonstream(self, backend, model_case): + client, model_name = self._get_client() + common_kwargs = dict(model=model_name, + messages=MESSAGES_REASONING_WEATHER_TOOL, + temperature=0, + max_completion_tokens=1024, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + ns_resp = client.chat.completions.create(**common_kwargs) + ns_choice = ns_resp.choices[0] + assert ns_choice.finish_reason == 'tool_calls' + assert ns_choice.message.tool_calls is not None + ns_tc = ns_choice.message.tool_calls[0] + assert_tool_call_fields(ns_tc) + ns_parsed = assert_arguments_parseable(ns_tc.function.arguments) + assert isinstance(ns_tc.id, str) and len(ns_tc.id) >= 9 + + stream = client.chat.completions.create(**common_kwargs, stream=True) + sr = collect_stream_reasoning(stream) + assert sr['finish_reason'] == 'tool_calls' + assert sr['finish_reason_count'] == 1 + assert sr['role'] == 'assistant' + assert sr['role_count'] == 1 + s_tc = list(sr['tool_calls'].values())[0] + assert s_tc['name'] is not None + assert s_tc['id'] is not None and len(s_tc['id']) >= 9 + s_parsed = json.loads(s_tc['args_str']) + + assert s_tc['name'] == ns_tc.function.name + assert ns_parsed == s_parsed + assert sr['finish_reason'] == ns_choice.finish_reason + + def test_streaming_role_exactly_once(self, backend, model_case): + client, model_name = self._get_client() + stream = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_BASIC, + 
temperature=0, + max_completion_tokens=1024, + logprobs=False, + stream=True) + result = collect_stream_reasoning(stream) + assert result['role'] == 'assistant' + assert result['role_count'] == 1 + + def test_streaming_function_name_not_fragmented(self, backend, model_case): + client, model_name = self._get_client() + stream = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_WEATHER_TOOL, + temperature=0, + max_completion_tokens=1024, + tools=[WEATHER_TOOL], + tool_choice={ + 'type': 'function', + 'function': { + 'name': 'get_current_weather' + } + }, + logprobs=False, + stream=True) + name_events = [] + for chunk in stream: + if not chunk.choices: + continue + delta = chunk.choices[0].delta + if delta.tool_calls: + for tc in delta.tool_calls: + if tc.function and tc.function.name: + name_events.append(tc.function.name) + assert len(name_events) == 1 + assert name_events[0] == 'get_current_weather' + + +# --------------------------------------------------------------------------- +# Tool-result consistency (cross-mode comparison) +# --------------------------------------------------------------------------- + + +@_apply_marks +class TestReasoningToolResultConsistency(_ReasoningTestBase): + """After providing tool results, streaming content must match non- + streaming.""" + + def test_tool_result_stream_vs_nonstream(self, backend, model_case): + client, model_name = self._get_client() + messages = build_messages_with_tool_response() + common_kwargs = dict(model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=256, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + ns_resp = client.chat.completions.create(**common_kwargs) + ns_choice = ns_resp.choices[0] + assert ns_choice.finish_reason != 'tool_calls' + assert ns_choice.message.content is not None + assert len(ns_choice.message.content) > 0 + + stream = client.chat.completions.create(**common_kwargs, stream=True) + chunks = [] + finish_count = 0 + for 
chunk in stream: + if not chunk.choices: + continue + delta = chunk.choices[0].delta + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_count += 1 + assert finish_count == 1 + streamed_content = ''.join(chunks) + assert streamed_content == ns_choice.message.content + + def test_tool_result_no_tag_leakage(self, backend, model_case): + client, model_name = self._get_client() + response = client.chat.completions.create(model=model_name, + messages=build_messages_with_tool_response(), + temperature=0, + max_completion_tokens=256, + tools=[WEATHER_TOOL], + logprobs=False) + content = response.choices[0].message.content or '' + assert THINK_START_TOKEN not in content + assert THINK_END_TOKEN not in content + + def test_reasoning_roundtrip_stream_vs_nonstream(self, backend, model_case): + client, model_name = self._get_client() + messages = build_reasoning_tool_roundtrip_messages() + common_kwargs = dict(model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=512, + tools=[WEATHER_TOOL], + logprobs=False) + + ns_resp = client.chat.completions.create(**common_kwargs) + ns_choice = ns_resp.choices[0] + ns_content = ns_choice.message.content or '' + ns_reasoning = get_reasoning_content(ns_choice.message) + + stream = client.chat.completions.create(**common_kwargs, stream=True) + sr = collect_stream_reasoning(stream) + assert sr['finish_reason'] == ns_choice.finish_reason + assert sr['content'] == ns_content + if ns_reasoning: + assert len(sr['reasoning_content']) > 0 + + +# --------------------------------------------------------------------------- +# Web search tool +# --------------------------------------------------------------------------- + + +@_apply_marks_stream +class TestReasoningWebSearchTool(_ReasoningTestBase): + """Tests for web_search tool call — forced, auto, and round-trip.""" + + def test_web_search_forced(self, backend, model_case, stream): + r = self._call_api(stream, + 
MESSAGES_REASONING_SEARCH_TOOL, + tools=[SEARCH_TOOL], + tool_choice={ + 'type': 'function', + 'function': { + 'name': 'web_search' + } + }) + assert r['finish_reason'] == 'tool_calls' + assert len(r['tool_calls']) >= 1 + tc = r['tool_calls'][0] + assert tc['name'] == 'web_search' + parsed = json.loads(tc['args_str']) + assert 'query' in parsed and len(parsed['query']) > 0 + + def test_web_search_auto(self, backend, model_case, stream): + r = self._call_api(stream, + MESSAGES_REASONING_SEARCH_TOOL, + tools=[SEARCH_TOOL, WEATHER_TOOL], + tool_choice='auto') + if len(r['tool_calls']) > 0: + assert r['finish_reason'] == 'tool_calls' + names = [tc['name'] for tc in r['tool_calls']] + assert 'web_search' in names + for tc in r['tool_calls']: + if tc['name'] == 'web_search': + parsed = json.loads(tc['args_str']) + assert 'query' in parsed + else: + assert len(r['content'].strip()) > 0 + + def test_web_search_roundtrip(self, backend, model_case, stream): + r = self._call_api(stream, _build_search_roundtrip_messages(), tools=[SEARCH_TOOL]) + assert r['finish_reason'] in ('stop', 'length') + assert len(r['tool_calls']) == 0 + assert len(r['reasoning'] + r['content']) > 0 + if r['content']: + has_ref = any(kw in r['content'].lower() + for kw in ('hopfield', 'hinton', 'nobel', 'physics', 'machine learning', 'neural network')) + assert has_ref + _assert_no_tag_leakage(r['reasoning'], r['content']) + + +# --------------------------------------------------------------------------- +# Token accounting +# --------------------------------------------------------------------------- + + +@_apply_marks +class TestReasoningTokenAccounting(_ReasoningTestBase): + """Verify token usage includes reasoning tokens when available.""" + + def test_usage_present(self, backend, model_case): + client, model_name = self._get_client() + response = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_BASIC, + temperature=0, + max_completion_tokens=1024, + logprobs=False) + 
assert response.usage is not None + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens == (response.usage.prompt_tokens + response.usage.completion_tokens) + assert response.usage.completion_tokens > 10 + + def test_reasoning_tokens_if_available(self, backend, model_case): + client, model_name = self._get_client() + response = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_COMPLEX, + temperature=0, + max_completion_tokens=2048, + logprobs=False) + rt = get_reasoning_tokens(response) + if rt is not None: + assert rt >= 0 + assert rt <= response.usage.completion_tokens + if response.usage.completion_tokens > 50: + assert rt > 0 + + def test_usage_present_streaming(self, backend, model_case): + client, model_name = self._get_client() + stream = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_BASIC, + temperature=0, + max_completion_tokens=1024, + logprobs=False, + stream=True, + stream_options={'include_usage': True}) + usage = None + for chunk in stream: + chunk_usage = getattr(chunk, 'usage', None) + if chunk_usage is not None: + usage = chunk_usage + if usage is not None: + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + + def test_reasoning_tokens_streaming_if_available(self, backend, model_case): + client, model_name = self._get_client() + stream = client.chat.completions.create(model=model_name, + messages=MESSAGES_REASONING_COMPLEX, + temperature=0, + max_completion_tokens=2048, + logprobs=False, + stream=True, + stream_options={'include_usage': True}) + usage = None + for chunk in stream: + chunk_usage = getattr(chunk, 'usage', None) + if chunk_usage is not None: + usage = chunk_usage + if usage is not None: + details = getattr(usage, 'completion_tokens_details', None) + rt = getattr(details, 'reasoning_tokens', None) if details 
            # Fall back to a flat usage.reasoning_tokens attribute if the
            # nested completion_tokens_details object is absent.
            if rt is None:
                rt = getattr(usage, 'reasoning_tokens', None)
            if rt is not None:
                assert rt >= 0
                assert rt <= usage.completion_tokens


# ---------------------------------------------------------------------------
# Multilingual reasoning
# ---------------------------------------------------------------------------


@_apply_marks_stream
class TestReasoningMultilingual(_ReasoningTestBase):
    """Reasoning with Chinese / multilingual prompts."""

    def test_chinese_reasoning(self, backend, model_case, stream):
        """Chinese math prompt: reasoning should show intermediate numbers
        and the final content should contain the answer ('2')."""
        r = self._call_api(stream, MESSAGES_REASONING_CN, max_completion_tokens=2048)
        assert r['finish_reason'] in ('stop', 'length')
        assert len(r['reasoning']) > 20
        # Intermediate values expected while working through the problem.
        assert any(kw in r['reasoning'] for kw in ('60', '80', '140', '280'))
        assert len(r['content'].strip()) > 0
        assert '2' in r['content']
        _assert_no_tag_leakage(r['reasoning'], r['content'])

    def test_chinese_with_tool(self, backend, model_case, stream):
        """Chinese weather question with a Chinese-described tool: the model
        must call get_current_weather with a recognizable city argument."""
        messages = [
            {
                'role': 'system',
                'content': '你是一个有用的助手,可以使用工具。请先思考是否需要使用工具。'
            },
            {
                'role': 'user',
                'content': '北京今天的天气怎么样?我需要带伞吗?'
            },
        ]
        r = self._call_api(stream, messages, tools=[WEATHER_TOOL_CN])
        assert len(r['tool_calls']) > 0
        assert r['finish_reason'] == 'tool_calls'
        tc = r['tool_calls'][0]
        assert tc['name'] == 'get_current_weather'
        parsed = json.loads(tc['args_str'])
        assert 'city' in parsed
        # The city may come back in Chinese or romanized form.
        assert '北京' in parsed['city'] or 'Beijing' in parsed['city']


# ---------------------------------------------------------------------------
# Multi-turn reasoning
# ---------------------------------------------------------------------------


@_apply_marks_stream
class TestReasoningMultiTurn(_ReasoningTestBase):
    """Multi-turn conversations where reasoning persists."""

    def test_multi_turn_reasoning(self, backend, model_case, stream):
        """Sum 1..100 across a multi-turn history: reasoning should mention
        the method (Gauss/formula) or intermediate values, content the
        answer 5050."""
        r = self._call_api(stream, MESSAGES_REASONING_MULTI_TURN, max_completion_tokens=2048)
        assert r['finish_reason'] in ('stop', 'length')
        assert len(r['reasoning']) > 20
        assert any(kw in r['reasoning'] for kw in ('100', '101', '5050', '5,050', 'formula', 'Gauss', 'n(n', 'n *'))
        assert len(r['content'].strip()) > 0
        assert '5050' in r['content'] or '5,050' in r['content']
        _assert_no_tag_leakage(r['reasoning'], r['content'])


# ---------------------------------------------------------------------------
# Response-level validation (separate streaming / non-streaming methods)
# ---------------------------------------------------------------------------


@_apply_marks
class TestReasoningResponseValidation(_ReasoningTestBase):
    """Validate response-level fields in reasoning mode."""

    def test_model_id_created_fields(self, backend, model_case):
        """Non-streaming: model/id/created/choices envelope fields are set
        and the message carries reasoning without leaked think tags."""
        client, model_name = self._get_client()
        response = client.chat.completions.create(model=model_name,
                                                  messages=MESSAGES_REASONING_BASIC,
                                                  temperature=0,
                                                  max_completion_tokens=1024,
                                                  logprobs=False)
        assert response.model is not None and len(response.model) > 0
        assert response.id is not None and len(str(response.id)) > 0
        assert response.created is not None and response.created > 0
        assert len(response.choices) >= 1
        assert response.choices[0].index == 0
        assert response.choices[0].finish_reason in ('stop', 'length')
        assert response.choices[0].message.role == 'assistant'
        msg = response.choices[0].message
        reasoning = get_reasoning_content(msg)
        assert reasoning is not None
        assert msg.content is not None and len(msg.content.strip()) > 0
        # Think tags must never leak into the content channel.
        assert THINK_START_TOKEN not in (msg.content or '')
        assert THINK_END_TOKEN not in (msg.content or '')

    def test_model_id_created_fields_streaming(self, backend, model_case):
        """Streaming: first chunk carries envelope fields; a role delta and a
        terminal finish_reason must be observed over the stream."""
        client, model_name = self._get_client()
        stream = client.chat.completions.create(model=model_name,
                                                messages=MESSAGES_REASONING_BASIC,
                                                temperature=0,
                                                max_completion_tokens=1024,
                                                logprobs=False,
                                                stream=True)
        first_chunk = None
        chunk_count = 0
        has_role = False
        last_finish = None
        for chunk in stream:
            chunk_count += 1
            if first_chunk is None:
                first_chunk = chunk
            if chunk.choices and chunk.choices[0].delta.role:
                has_role = True
            if chunk.choices and chunk.choices[0].finish_reason:
                last_finish = chunk.choices[0].finish_reason
        assert first_chunk is not None
        assert first_chunk.model is not None and len(first_chunk.model) > 0
        assert first_chunk.id is not None
        assert first_chunk.created is not None and first_chunk.created > 0
        assert has_role
        assert last_finish in ('stop', 'length')


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


@_apply_marks_stream
class TestReasoningEdgeCases(_ReasoningTestBase):
    """Edge cases for reasoning functionality."""

    def test_simple_question(self, backend, model_case, stream):
        """'What is 2+2?' should produce answer '4'."""
        r = self._call_api(stream, MESSAGES_REASONING_SIMPLE, max_completion_tokens=512)
        assert r['finish_reason'] in ('stop', 'length')
        # The answer may appear in either the reasoning or content channel.
        full = r['reasoning'] + r['content']
        assert len(full) > 0
        assert '4' in full
        _assert_no_tag_leakage(r['reasoning'], r['content'])

    def test_no_tools_provided(self, backend, model_case, stream):
        """Without tools, weather question produces text answer."""
        r = self._call_api(stream, MESSAGES_REASONING_WEATHER_TOOL)
        assert r['finish_reason'] in ('stop', 'length')
        assert len(r['tool_calls']) == 0
        assert len(r['reasoning'] + r['content']) > 0
        _assert_no_tag_leakage(r['reasoning'], r['content'])

    def test_empty_tools(self, backend, model_case, stream):
        """Empty tools list: no tool calls, pure reasoning + text."""
        from openai import BadRequestError
        try:
            r = self._call_api(stream, MESSAGES_REASONING_BASIC, tools=[])
        except BadRequestError:
            # Some backends reject tools=[] outright; that is acceptable.
            pytest.skip('Backend rejects empty tools list')
        assert len(r['tool_calls']) == 0
        assert len(r['reasoning'] + r['content']) > 0

    def test_low_max_tokens(self, backend, model_case, stream):
        """Very low max_tokens: truncated but valid output."""
        r = self._call_api(stream, MESSAGES_REASONING_COMPLEX, max_completion_tokens=50)
        assert r['finish_reason'] in ('stop', 'length')
        assert len(r['reasoning'] + r['content']) > 0
        _assert_no_tag_leakage(r['reasoning'], r['content'])

    def test_reasoning_not_parsed_as_tool_call(self, backend, model_case, stream):
        """Reasoning mentioning function names must not be extracted as tool
        calls."""
        messages = [{
            'role':
            'user',
            'content': ('Explain the proof that the square root of 2 is irrational. '
                        'Do not call any tools, just explain in text.'),
        }]
        r = self._call_api(stream, messages, tools=[CALCULATOR_TOOL, WEATHER_TOOL], tool_choice='auto')
        assert len(r['content'].strip()) > 0
        assert r['finish_reason'] in ('stop', 'length')
        # The prose proof must not be misparsed into tool calls.
        assert len(r['tool_calls']) == 0
        _assert_no_tag_leakage(r['reasoning'], r['content'])


# ---------------------------------------------------------------------------
# Disable thinking (enable_thinking=False)
# ---------------------------------------------------------------------------


@_apply_marks_stream
class TestReasoningDisableThinking(_ReasoningTestBase):
    """Tests with enable_thinking=False — non-think mode."""

    def test_no_reasoning_content(self, backend, model_case, stream):
        """enable_thinking=False: reasoning_content should be absent."""
        r = self._call_api(stream, MESSAGES_REASONING_BASIC, extra_body={'enable_thinking': False})
        assert r['finish_reason'] in ('stop', 'length')
        assert len(r['content'].strip()) > 0
        assert THINK_START_TOKEN not in r['content']
        assert THINK_END_TOKEN not in r['content']
        # reasoning should be empty
        assert r['reasoning'] == ''
        if stream:
            assert r['reasoning_chunks'] == 0

    def test_with_tool_call(self, backend, model_case, stream):
        """enable_thinking=False + tool call: should call tool without
        reasoning."""
        r = self._call_api(stream,
                           MESSAGES_REASONING_WEATHER_TOOL,
                           tools=[WEATHER_TOOL],
                           extra_body={'enable_thinking': False})
        if len(r['tool_calls']) > 0:
            assert r['finish_reason'] == 'tool_calls'
            tc = r['tool_calls'][0]
            assert tc['name'] == 'get_current_weather'
            parsed = json.loads(tc['args_str'])
            assert 'city' in parsed
        # With thinking disabled, the reasoning channel must stay empty
        # whether or not a tool was called.
        assert r['reasoning'] == ''
        _assert_no_tag_leakage(r['reasoning'], r['content'])

    def test_content_quality(self, backend, model_case, stream):
        """enable_thinking=False: content should still contain correct
        answer."""
        r = self._call_api(stream, MESSAGES_REASONING_BASIC, extra_body={'enable_thinking': False})
        assert len(r['content'].strip()) > 0
        assert '1591' in r['content']
diff --git a/autotest/interface/restful/test_restful_tool_calls.py b/autotest/interface/restful/test_restful_tool_calls.py
new file mode 100644
index 0000000000..05e2944dde
--- /dev/null
+++ b/autotest/interface/restful/test_restful_tool_calls.py
@@ -0,0 +1,1164 @@
import json

import pytest
from utils.constant import BACKEND_LIST, TOOL_REASONING_MODEL_LIST
from utils.tool_reasoning_definitions import (ALL_OPTIONAL_TOOL, CALCULATOR_TOOL, NESTED_PARAM_TOOL, SEARCH_TOOL,
                                              WEATHER_TOOL, WEATHER_TOOL_CN, assert_arguments_parseable,
                                              assert_tool_call_fields, build_messages_with_parallel_tool_responses,
                                              build_messages_with_tool_response, collect_stream_parallel_tool_calls,
                                              collect_stream_tool_call, make_logged_client, setup_log_file)

# Shared pytest marks applied to every test class in this module.
_CLASS_MARKS = [
    pytest.mark.order(8),
    pytest.mark.tool_call,
    pytest.mark.flaky(reruns=2),
    pytest.mark.parametrize('backend', BACKEND_LIST),
    pytest.mark.parametrize('model_case', TOOL_REASONING_MODEL_LIST),
]


def _apply_marks(cls):
    """Apply the shared set of marks to *cls* and return it."""
    for m in _CLASS_MARKS:
        cls = m(cls)
    return cls


# ---------------------------------------------------------------------------
# Logging helpers – uses shared StreamTee / setup_log_file / make_logged_client
# from utils.tool_reasoning_definitions.
# ---------------------------------------------------------------------------


class _ToolCallTestBase:
    """Mixin providing per-test API request/response logging to *log_path*."""

    @pytest.fixture(autouse=True)
    def _setup_logging(self, request, config, backend, model_case):
        """Create the log directory and compute the log-file path."""
        self._log_file = setup_log_file(config, request.node.name, 'tool_calls')

    def _get_client(self):
        """Return *(client, model_name)* with transparent logging."""
        return make_logged_client(self._log_file)


# ---------------------------------------------------------------------------
# Message constants
# ---------------------------------------------------------------------------

MESSAGES_ASKING_FOR_WEATHER = [
    {
        'role':
        'system',
        'content':
        'You are a helpful assistant that can use tools. '
        'When asked about weather, use the get_current_weather tool.',
    },
    {
        'role': 'user',
        'content': "What's the weather like in Dallas, TX?",
    },
]

MESSAGES_ASKING_FOR_SEARCH = [
    {
        'role':
        'system',
        'content':
        'You are a helpful assistant with access to tools. '
        'Use the web_search tool when asked to look something up.',
    },
    {
        'role': 'user',
        'content': 'Search the web for the latest news about AI.',
    },
]

MESSAGES_ASKING_FOR_CALCULATION = [
    {
        'role': 'system',
        'content': 'You are a helpful assistant. When asked math questions, '
        'use the calculate tool.',
    },
    {
        'role': 'user',
        'content': 'What is 1234 * 5678?',
    },
]

MESSAGES_ASKING_FOR_WEATHER_CN = [
    {
        'role': 'system',
        'content': '你是一个有用的助手,可以使用工具。'
        '当被问到天气时,请使用get_current_weather工具。',
    },
    {
        'role': 'user',
        'content': '北京今天的天气怎么样?',
    },
]

MESSAGES_NO_TOOL_NEEDED = [
    {
        'role': 'user',
        'content': 'Hi, please introduce yourself briefly.',
    },
]

MESSAGES_PARALLEL_WEATHER = [
    {
        'role':
        'system',
        'content':
        'You are a helpful assistant. When asked about weather '
        'in multiple cities, call the weather tool for each city '
        'separately.',
    },
    {
        'role': 'user',
        'content': "What's the weather in Dallas, TX and also in "
        'San Francisco, CA?',
    },
]

MESSAGES_PARALLEL_MIXED = [
    {
        'role':
        'system',
        'content':
        'You are a helpful assistant with access to multiple tools. '
        'You can call multiple tools in parallel when needed.',
    },
    {
        'role': 'user',
        'content': "What's the weather in Dallas, TX? "
        'Also calculate 1234 * 5678.',
    },
]

# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------


@_apply_marks
class TestToolCallBasic(_ToolCallTestBase):
    """Basic tool call: response structure, finish_reason, field validation."""

    def test_non_streaming(self, backend, model_case):
        """Non-streaming: complete tool call response structure."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            logprobs=False,
        )

        # Response-level checks
        assert response.object == 'chat.completion'
        assert response.id is not None and len(response.id) > 0
        assert response.model is not None and len(response.model) > 0
        assert len(response.choices) == 1

        choice = response.choices[0]
        tool_calls = choice.message.tool_calls

        assert choice.message.role == 'assistant'
        assert choice.finish_reason == 'tool_calls'
        assert tool_calls is not None and len(tool_calls) >= 1

        tc = tool_calls[0]
        assert_tool_call_fields(tc)
        assert tc.type == 'function'
        assert tc.function.name == WEATHER_TOOL['function']['name']

        parsed_args = assert_arguments_parseable(tc.function.arguments)
        assert isinstance(parsed_args.get('city'), str) and len(parsed_args['city']) > 0
        assert isinstance(parsed_args.get('state'), str) and len(parsed_args['state']) > 0
        # Token usage sanity
        assert response.usage is not None
        assert response.usage.prompt_tokens > 0
        assert response.usage.completion_tokens > 0

    def test_streaming(self, backend, model_case):
        """Streaming: tool call id / name streamed once, args accumulated."""
        client, model_name = self._get_client()

        stream = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            logprobs=False,
            stream=True,
        )

        r = collect_stream_tool_call(stream)

        assert r['finish_reason_count'] == 1, (f'Expected exactly 1 finish_reason, got {r["finish_reason_count"]}')
        assert r['finish_reason'] == 'tool_calls'
        assert r['role'] == 'assistant'
        assert isinstance(r['tool_call_id'], str) and len(r['tool_call_id']) >= 1
        assert r['tool_call_id'].strip() == r['tool_call_id'], 'tool_call_id has leading/trailing whitespace'
        assert r['function_name'] == WEATHER_TOOL['function']['name']

        streamed_args = assert_arguments_parseable(r['args_str'])
        assert isinstance(streamed_args.get('city'), str) and len(streamed_args['city']) > 0
        assert isinstance(streamed_args.get('state'), str) and len(streamed_args['state']) > 0


@_apply_marks
class TestToolCallStreamConsistency(_ToolCallTestBase):
    """Streaming and non-streaming tool call results must match."""

    def test_stream_nonstream_consistency(self, backend, model_case):
        # Non-streaming
        client, model_name = self._get_client()

        ns_resp = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            logprobs=False,
        )
        ns_choice = ns_resp.choices[0]
        assert ns_choice.finish_reason == 'tool_calls'
        assert ns_choice.message.tool_calls is not None and len(ns_choice.message.tool_calls) >= 1
        ns_tc = ns_choice.message.tool_calls[0]
        ns_name = ns_tc.function.name
        ns_args = json.loads(ns_tc.function.arguments)

        # Streaming (same request, temperature=0, so results should match)
        stream = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            logprobs=False,
            stream=True,
        )
        r = collect_stream_tool_call(stream)
        assert r['finish_reason'] == 'tool_calls'
        s_args = json.loads(r['args_str'])

        assert ns_name == r['function_name'], (f'Function name mismatch: non-stream={ns_name}, '
                                               f'stream={r["function_name"]}')
        assert ns_args == s_args, (f'Arguments mismatch: non-stream={ns_args}, stream={s_args}')
        # Verify both resolved to a valid tool name
        assert ns_name in ('get_current_weather', 'web_search'), (f'Unexpected function name: {ns_name}')


@_apply_marks
class TestToolCallChoice(_ToolCallTestBase):
    """Test all tool_choice variants."""

    # -- auto ----------------------------------------------------------------
    def test_tool_choice_auto(self, backend, model_case):
        """tool_choice='auto': model decides whether to call a tool."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            tool_choice='auto',
            logprobs=False,
        )

        choice = response.choices[0]
        assert choice.message.role == 'assistant'
        assert choice.finish_reason in ('stop', 'length',
                                        'tool_calls'), (f'Unexpected finish_reason: {choice.finish_reason}')
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            for tc in choice.message.tool_calls:
                assert_tool_call_fields(tc)
                assert tc.type == 'function'
                assert tc.function.name in ('get_current_weather',
                                            'web_search'), (f'Unexpected tool: {tc.function.name}')
                assert_arguments_parseable(tc.function.arguments)
            assert choice.finish_reason == 'tool_calls'
else: + assert choice.message.content is not None and len(choice.message.content.strip()) > 0 + assert choice.finish_reason in ('stop', 'length') + + # -- none ---------------------------------------------------------------- + def test_tool_choice_none(self, backend, model_case): + """tool_choice='none': model must NOT return tool calls.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice='none', + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0), ('tool_choice="none" but got tool_calls in response') + assert choice.message.content is not None + assert len(choice.message.content.strip()) > 0, ('tool_choice="none" should produce non-empty text content') + assert choice.finish_reason in ('stop', 'length') + + # -- required ------------------------------------------------------------ + def test_tool_choice_required(self, backend, model_case): + """tool_choice='required': model MUST return at least one tool call. + + Only skip when the *server* rejects the request (HTTP error). 
+ """ + client, model_name = self._get_client() + + try: + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice='required', + logprobs=False, + ) + except Exception as e: + # Only skip if the server itself rejects the request + pytest.skip(f'tool_choice="required" not supported by server: {e}') + + # Validation MUST fail loudly — never skip on assertion errors + choice = response.choices[0] + assert choice.message.role == 'assistant' + assert choice.message.tool_calls is not None, ('tool_choice="required" but got no tool_calls') + assert len(choice.message.tool_calls) >= 1 + for tc in choice.message.tool_calls: + assert_tool_call_fields(tc) + assert_arguments_parseable(tc.function.arguments) + + def test_tool_choice_required_streaming(self, backend, model_case): + """tool_choice='required' + streaming: must return tool call chunks. + + Only skip if the server rejects the request, not on parse errors. 
+ """ + client, model_name = self._get_client() + + try: + stream = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice='required', + logprobs=False, + stream=True, + ) + r = collect_stream_tool_call(stream) + except Exception as e: + pytest.skip(f'tool_choice="required" streaming not supported by server: {e}') + + # Validation MUST fail loudly + assert r['function_name'] is not None, ('tool_choice="required" streaming but no function name received') + assert len(r['args_str']) > 0, ('tool_choice="required" streaming but no arguments received') + assert r['tool_call_id'] is not None + assert_arguments_parseable(r['args_str']) + assert r['finish_reason'] == 'tool_calls' + + # -- specific function --------------------------------------------------- + def test_tool_choice_specific_function(self, backend, model_case): + """Force a specific tool via tool_choice={type, function}.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + tool_choice={ + 'type': 'function', + 'function': { + 'name': 'get_current_weather' + }, + }, + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + + tool_calls = choice.message.tool_calls + assert tool_calls is not None and len(tool_calls) >= 1 + + tc = tool_calls[0] + assert_tool_call_fields(tc) + assert tc.function.name == 'get_current_weather' + assert_arguments_parseable(tc.function.arguments) + + +@_apply_marks +class TestToolCallArgumentsParsing(_ToolCallTestBase): + """Validate that arguments are parseable and contain expected keys.""" + + def test_weather_args(self, backend, model_case): + """Weather tool args should contain city & state.""" + client, model_name = 
        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL],
            logprobs=False,
        )

        tc = response.choices[0].message.tool_calls[0]
        parsed = json.loads(tc.function.arguments)

        assert 'city' in parsed, f'Missing "city": {parsed}'
        assert 'state' in parsed, f'Missing "state": {parsed}'
        assert 'dallas' in parsed['city'].lower()
        assert 'tx' in parsed['state'].lower()

    def test_weather_args_streaming(self, backend, model_case):
        """Streaming: weather tool args should contain city & state."""
        client, model_name = self._get_client()

        stream = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL],
            logprobs=False,
            stream=True,
        )

        r = collect_stream_tool_call(stream)
        parsed = json.loads(r['args_str'])

        assert 'city' in parsed
        assert 'state' in parsed

    def test_search_tool_args(self, backend, model_case):
        """Search tool args should contain a non-empty query."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_SEARCH,
            temperature=0,
            max_completion_tokens=200,
            tools=[SEARCH_TOOL],
            logprobs=False,
        )

        choice = response.choices[0]
        # Tool use is at the model's discretion here; validate only if called.
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            tc = choice.message.tool_calls[0]
            assert tc.function.name == 'web_search'
            parsed = json.loads(tc.function.arguments)
            assert 'query' in parsed
            assert isinstance(parsed['query'], str) and len(parsed['query']) > 0

    def test_enum_parameter_constraint(self, backend, model_case):
        """Enum-constrained params should only return valid values."""
        client, model_name = self._get_client()

        messages = [
            {
                'role': 'system',
                'content': 'You are a helpful weather assistant. '
                'Always use the weather tool and specify the unit.'
            },
            {
                'role': 'user',
                'content': 'What is the weather in Miami, FL in celsius?'
            },
        ]

        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL],
            tool_choice={
                'type': 'function',
                'function': {
                    'name': 'get_current_weather'
                },
            },
            logprobs=False,
        )

        tc = response.choices[0].message.tool_calls[0]
        parsed = json.loads(tc.function.arguments)
        # 'unit' is optional; when present it must come from the enum.
        if 'unit' in parsed:
            assert parsed['unit'] in ('celsius', 'fahrenheit'), (f'unit should be from enum, got "{parsed["unit"]}"')


@_apply_marks
class TestToolCallMultipleTools(_ToolCallTestBase):
    """Model should pick the right tool from a multi-tool list."""

    def test_selects_weather_tool(self, backend, model_case):
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            logprobs=False,
        )

        choice = response.choices[0]
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            assert choice.message.tool_calls[0].function.name == ('get_current_weather')

    def test_selects_calculator_tool(self, backend, model_case):
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_CALCULATION,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL, CALCULATOR_TOOL],
            logprobs=False,
        )

        choice = response.choices[0]
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            tc = choice.message.tool_calls[0]
            assert tc.function.name == 'calculate'
            parsed = json.loads(tc.function.arguments)
            assert 'expression' in parsed

    def test_no_tool_when_not_needed(self, backend, model_case):
        """Unrelated question + tool_choice=auto → prefer text."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_NO_TOOL_NEEDED,
            temperature=0,
            max_completion_tokens=200,
            tools=[WEATHER_TOOL, SEARCH_TOOL],
            tool_choice='auto',
            logprobs=False,
        )

        choice = response.choices[0]
        assert choice.message.role == 'assistant'
        # Either outcome is valid; just require well-formed output.
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            for tc in choice.message.tool_calls:
                assert_tool_call_fields(tc)
        else:
            assert choice.message.content is not None
            assert len(choice.message.content) > 0

    def test_large_number_of_tools(self, backend, model_case):
        """Model should still pick the right tool among 10+ definitions."""
        client, model_name = self._get_client()

        # One real tool plus ten no-op decoys.
        tools = [WEATHER_TOOL]
        for i in range(10):
            tools.append({
                'type': 'function',
                'function': {
                    'name': f'dummy_tool_{i}',
                    'description': f'A dummy tool number {i} (does nothing).',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'input': {
                                'type': 'string',
                                'description': f'Input for dummy tool {i}',
                            },
                        },
                        'required': ['input'],
                    },
                },
            })

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_ASKING_FOR_WEATHER,
            temperature=0,
            max_completion_tokens=200,
            tools=tools,
            logprobs=False,
        )

        choice = response.choices[0]
        if choice.message.tool_calls and len(choice.message.tool_calls) > 0:
            tc = choice.message.tool_calls[0]
            assert_tool_call_fields(tc)
            assert_arguments_parseable(tc.function.arguments)
            assert tc.function.name == 'get_current_weather', (f'Expected weather tool, got "{tc.function.name}"')


@_apply_marks
class TestToolCallParallel(_ToolCallTestBase):
    """Parallel tool calls in a single response."""

    def test_parallel_same_tool(self, backend, model_case):
        """Two cities → two weather tool calls."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_PARALLEL_WEATHER,
            temperature=0,
            max_completion_tokens=300,
            tools=[WEATHER_TOOL],
            logprobs=False,
        )

        tool_calls = response.choices[0].message.tool_calls

        # Hard assertion: two cities asked → must get ≥ 2 tool calls
        assert tool_calls is not None and len(tool_calls) >= 2, (f'Expected ≥2 parallel tool calls for two cities, '
                                                                 f'got {len(tool_calls) if tool_calls else 0}')

        for tc in tool_calls:
            assert_tool_call_fields(tc)
            assert tc.function.name == 'get_current_weather'
            parsed = assert_arguments_parseable(tc.function.arguments)
            assert 'city' in parsed and 'state' in parsed

        # Parallel calls must carry distinct ids.
        ids = [tc.id for tc in tool_calls]
        assert len(set(ids)) == len(ids), (f'IDs should be unique, got {ids}')
        assert response.choices[0].finish_reason == 'tool_calls'

    def test_parallel_same_tool_streaming(self, backend, model_case):
        """Streaming: parallel tool calls indexed correctly."""
        client, model_name = self._get_client()

        stream = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_PARALLEL_WEATHER,
            temperature=0,
            max_completion_tokens=300,
            tools=[WEATHER_TOOL],
            logprobs=False,
            stream=True,
        )

        tc_data, fr_count = collect_stream_parallel_tool_calls(stream)
        assert fr_count == 1

        # Hard assertion: must receive ≥ 2 distinct tool call indices
        assert len(tc_data) >= 2, (f'Expected ≥2 parallel streaming tool calls, '
                                   f'got {len(tc_data)} indices: {list(tc_data.keys())}')

        for idx, data in tc_data.items():
            assert data['name'] is not None, (f'Index {idx}: missing function name')
            assert len(data['args_str']) > 0, (f'Index {idx}: missing arguments')
            parsed = json.loads(data['args_str'])
            assert isinstance(parsed, dict)

    def test_parallel_mixed_tools(self, backend, model_case):
        """Weather + calculator in one request."""
        client, model_name = self._get_client()

        response = client.chat.completions.create(
            model=model_name,
            messages=MESSAGES_PARALLEL_MIXED,
temperature=0, + max_completion_tokens=400, + tools=[WEATHER_TOOL, CALCULATOR_TOOL], + logprobs=False, + ) + + tool_calls = response.choices[0].message.tool_calls + + # Hard assertion: weather + calculation asked → ≥ 2 tool calls + assert tool_calls is not None and len(tool_calls) >= 2, (f'Expected ≥2 parallel tool calls (weather+calc), ' + f'got {len(tool_calls) if tool_calls else 0}') + + for tc in tool_calls: + assert_tool_call_fields(tc) + assert_arguments_parseable(tc.function.arguments) + + ids = [tc.id for tc in tool_calls] + assert len(set(ids)) == len(ids), (f'Tool call IDs should be unique, got {ids}') + + names = {tc.function.name for tc in tool_calls} + assert len(names) >= 2, (f'Expected ≥2 distinct tool names, got {names}') + assert 'get_current_weather' in names, (f'Expected get_current_weather in tool calls, got {names}') + assert 'calculate' in names, (f'Expected calculate in tool calls, got {names}') + + +@_apply_marks +class TestToolCallWithResults(_ToolCallTestBase): + """Feed tool results back; model should reply with text.""" + + def test_single_result(self, backend, model_case): + client, model_name = self._get_client() + messages = build_messages_with_tool_response() + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.finish_reason in ('stop', 'length') + assert choice.message.role == 'assistant' + assert (choice.message.tool_calls is None or len(choice.message.tool_calls) == 0) + assert choice.message.content and len(choice.message.content) > 0 + assert '98' in choice.message.content or 'Dallas' in choice.message.content + + def test_multiple_results(self, backend, model_case): + """Feed two parallel tool results back at once.""" + client, model_name = self._get_client() + messages = build_messages_with_parallel_tool_responses() + + response = 
client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=300, + tools=[WEATHER_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + assert choice.finish_reason in ('stop', 'length') + assert choice.message.content and len(choice.message.content) > 0 + + content = choice.message.content + has_dallas = 'Dallas' in content or '98' in content + has_sf = 'San Francisco' in content or '65' in content + assert has_dallas or has_sf + + +@_apply_marks +class TestToolCallMultilingual(_ToolCallTestBase): + + def test_chinese_description(self, backend, model_case): + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER_CN, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL_CN], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + assert tc.function.name == 'get_current_weather' + parsed = assert_arguments_parseable(tc.function.arguments) + assert 'city' in parsed + assert isinstance(parsed['city'], str) and len(parsed['city']) > 0 + + def test_chinese_description_streaming(self, backend, model_case): + client, model_name = self._get_client() + + stream = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER_CN, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL_CN], + logprobs=False, + stream=True, + ) + + r = collect_stream_tool_call(stream) + if r['function_name'] is not None: + assert r['function_name'] == 'get_current_weather' + parsed = json.loads(r['args_str']) + assert isinstance(parsed, dict) + assert 'city' in parsed + + def test_mixed_language_tools(self, backend, model_case): + """Pass Chinese + English tool 
definitions together.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER_CN, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL_CN, SEARCH_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + for tc in choice.message.tool_calls: + assert_tool_call_fields(tc) + assert_arguments_parseable(tc.function.arguments) + + def test_unicode_arguments(self, backend, model_case): + """Chinese query → tool arguments with Unicode chars.""" + client, model_name = self._get_client() + + messages = [ + { + 'role': 'system', + 'content': 'You are a helpful assistant. Use the search tool.' + }, + { + 'role': 'user', + 'content': '请搜索一下"人工智能最新进展"' + }, + ] + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=200, + tools=[SEARCH_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + assert_arguments_parseable(tc.function.arguments) + + +@_apply_marks +class TestToolCallComplexParams(_ToolCallTestBase): + """Nested objects, arrays, enum constraints, all-optional params.""" + + def test_nested_object_parameters(self, backend, model_case): + client, model_name = self._get_client() + + messages = [ + { + 'role': 'system', + 'content': 'You are a helpful assistant. Use the create_event ' + 'tool when asked to schedule events.' + }, + { + 'role': + 'user', + 'content': + 'Schedule a team meeting titled "Sprint Review" at ' + 'the Conference Room in New York with attendees ' + 'alice@example.com and bob@example.com, high priority.' 
+ }, + ] + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=400, + tools=[NESTED_PARAM_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + assert tc.function.name == 'create_event' + + parsed = json.loads(tc.function.arguments) + assert 'title' in parsed + + if 'location' in parsed: + assert isinstance(parsed['location'], dict) + if 'attendees' in parsed: + assert isinstance(parsed['attendees'], list) + if 'priority' in parsed: + assert parsed['priority'] in ('low', 'medium', 'high') + + def test_all_optional_parameters(self, backend, model_case): + client, model_name = self._get_client() + + messages = [ + { + 'role': 'system', + 'content': 'You are a logging assistant. ' + 'Use the log_message tool to log messages.' + }, + { + 'role': 'user', + 'content': 'Log an info message saying ' + '"System started successfully".' 
+ }, + ] + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=200, + tools=[ALL_OPTIONAL_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + parsed = assert_arguments_parseable(tc.function.arguments) + + if 'message' in parsed: + assert isinstance(parsed['message'], str) + if 'level' in parsed: + assert parsed['level'] in ('debug', 'info', 'warning', 'error') + + +@_apply_marks +class TestToolCallResponseValidation(_ToolCallTestBase): + """Validate response-level fields when tool calls are returned.""" + + def test_content_null_when_tool_calls_present(self, backend, model_case): + """Per OpenAI spec, content should be null or empty when tool_calls + exist.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL], + tool_choice={ + 'type': 'function', + 'function': { + 'name': 'get_current_weather' + }, + }, + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.tool_calls is not None + assert len(choice.message.tool_calls) >= 1 + + # Per OpenAI spec: content should be null or empty when tool_calls + # are present. 
+ if choice.message.content is not None: + assert choice.message.content.strip() == '', (f'content should be null/empty when tool_calls are ' + f'present, got: {choice.message.content!r}') + + def test_usage_field_present(self, backend, model_case): + """usage.prompt_tokens / completion_tokens / total_tokens.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL], + logprobs=False, + ) + + assert response.usage is not None + assert response.usage.prompt_tokens > 0 + assert response.usage.total_tokens > 0 + + choice = response.choices[0] + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens == (response.usage.prompt_tokens + response.usage.completion_tokens) + + def test_model_and_metadata_fields(self, backend, model_case): + """Response must contain model, id, and created.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL], + logprobs=False, + ) + + assert response.model is not None + assert isinstance(response.model, str) and len(response.model) > 0 + assert response.id is not None + assert response.created is not None + + def test_choices_structure(self, backend, model_case): + """choices[0].index should be 0.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL], + logprobs=False, + n=1, + ) + + assert len(response.choices) >= 1 + assert response.choices[0].index == 0 + + +@_apply_marks +class TestToolCallEdgeCases(_ToolCallTestBase): + """Edge cases and robustness 
tests.""" + + def test_empty_tools_list(self, backend, model_case): + """Empty tools list should behave like no tools.""" + client, model_name = self._get_client() + + try: + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_NO_TOOL_NEEDED, + temperature=0, + max_completion_tokens=100, + tools=[], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + assert choice.message.content is not None + assert (choice.message.tool_calls is None or len(choice.message.tool_calls) == 0) + except Exception: + # Some backends reject an empty tools list + pass + + def test_tool_call_with_max_tokens(self, backend, model_case): + """With sufficient max_tokens, tool call structure must be valid.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=500, + tools=[WEATHER_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + assert_arguments_parseable(tc.function.arguments) + + def test_tool_call_id_format(self, backend, model_case): + """ID should be a non-empty string with no leading/trailing spaces.""" + client, model_name = self._get_client() + + response = client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_WEATHER, + temperature=0, + max_completion_tokens=200, + tools=[WEATHER_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert isinstance(tc.id, str) + assert len(tc.id) >= 1 + assert tc.id.strip() == tc.id + + def test_multi_turn_conversation(self, backend, model_case): + """Tool call → result → follow-up question → 
possible second call.""" + client, model_name = self._get_client() + + messages = build_messages_with_tool_response() + messages.append({ + 'role': 'user', + 'content': 'Now search the web for how to stay cool in hot weather.', + }) + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=300, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + if tc.function.name == 'web_search': + parsed = json.loads(tc.function.arguments) + assert 'query' in parsed + else: + assert choice.message.content and len(choice.message.content) > 0 + + def test_special_characters_in_query(self, backend, model_case): + """Quotes, angle brackets, Unicode → JSON args still parseable.""" + client, model_name = self._get_client() + + messages = [ + { + 'role': 'system', + 'content': 'You are a helpful assistant that can use tools.' + }, + { + 'role': + 'user', + 'content': + 'Search for "what\'s the latest on AI & ML?" 
' + '(include results with special chars: <>, "quotes", ' + 'and unicode: café, naïve, résumé)' + }, + ] + + response = client.chat.completions.create( + model=model_name, + messages=messages, + temperature=0, + max_completion_tokens=200, + tools=[SEARCH_TOOL], + logprobs=False, + ) + + choice = response.choices[0] + assert choice.message.role == 'assistant' + if choice.message.tool_calls and len(choice.message.tool_calls) > 0: + tc = choice.message.tool_calls[0] + assert_tool_call_fields(tc) + parsed = assert_arguments_parseable(tc.function.arguments) + if 'query' in parsed: + assert isinstance(parsed['query'], str) + assert len(parsed['query']) > 0 diff --git a/autotest/utils/constant.py b/autotest/utils/constant.py index 800138cee9..7a52d3d526 100644 --- a/autotest/utils/constant.py +++ b/autotest/utils/constant.py @@ -54,10 +54,29 @@ BACKEND_LIST = ['turbomind', 'pytorch'] RESTFUL_MODEL_LIST = [ - 'Qwen/Qwen3-0.6B', 'Qwen/Qwen3-VL-2B-Instruct', 'Qwen/Qwen3-30B-A3B', 'internlm/Intern-S1', - 'internlm/internlm2_5-20b-chat', 'internlm/internlm2_5-20b', 'Qwen/Qwen3-32B', 'OpenGVLab/InternVL3_5-30B-A3B', - 'OpenGVLab/InternVL3-38B', 'Qwen/Qwen3-VL-8B-Instruct', 'internlm/internlm3-8b-instruct', - 'meta-llama/Llama-3.2-3B-Instruct', 'Qwen/Qwen3-VL-30B-A3B-Instruct' + 'Qwen/Qwen3-0.6B', + 'Qwen/Qwen3-VL-2B-Instruct', + 'Qwen/Qwen3-30B-A3B', + 'internlm/Intern-S1', + 'internlm/internlm2_5-20b-chat', + 'internlm/internlm2_5-20b', + 'Qwen/Qwen3-32B', + 'OpenGVLab/InternVL3_5-30B-A3B', + 'OpenGVLab/InternVL3-38B', + 'Qwen/Qwen3-VL-8B-Instruct', + 'internlm/internlm3-8b-instruct', + 'meta-llama/Llama-3.2-3B-Instruct', + 'Qwen/Qwen3-VL-30B-A3B-Instruct', +] + +TOOL_REASONING_MODEL_LIST = [ + 'Qwen/Qwen3-8B-FP8', + 'internlm/Intern-S1-Pro-FP8', + 'internlm/internlm2_5-7b-chat', + 'Qwen/Qwen2.5-7B-Instruct', + 'meta-llama/Meta-Llama-3.1-70B-Instruct', + 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B', + 'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B', ] RESTFUL_BASE_MODEL_LIST = [ diff --git 
a/autotest/utils/tool_reasoning_definitions.py b/autotest/utils/tool_reasoning_definitions.py new file mode 100644 index 0000000000..b6defcf83b --- /dev/null +++ b/autotest/utils/tool_reasoning_definitions.py @@ -0,0 +1,691 @@ +"""Shared definitions for tool-call and reasoning tests. + +This module centralises tool schemas, message templates, connection settings, +assertion helpers and stream-consumption helpers so they can be imported by +both ``test_restful_tool_calls.py`` and ``test_restful_reasoning.py`` (and +future test modules) without duplication. +""" + +import json +import os +import re + +from openai import OpenAI +from utils.constant import DEFAULT_PORT + +BASE_HTTP_URL = os.environ.get('LMDEPLOY_BASE_URL', 'http://localhost') +BASE_URL = os.environ.get('LMDEPLOY_API_URL', ':'.join([BASE_HTTP_URL, str(DEFAULT_PORT)])) + +#: Supported reasoning parser names (from lmdeploy ReasoningParserManager) +REASONING_PARSER_NAMES = ['deepseek-r1', 'qwen-qwq', 'intern-s1'] + +#: Think-tag delimiters used by DeepSeek-R1 and QwenQwQ parsers +THINK_START_TOKEN = '<think>' +THINK_END_TOKEN = '</think>' + +#: Supported tool parser names (from lmdeploy ToolParserManager) +TOOL_PARSER_NAMES = ['qwen', 'qwen3', 'qwen2d5', 'internlm', 'intern-s1', 'llama3'] + +#: Tool-call tag delimiters — Qwen family (qwen, qwen3, qwen2d5) +TOOL_CALL_START_TOKEN = '<tool_call>' +TOOL_CALL_END_TOKEN = '</tool_call>' + +#: Tool-call tag delimiters — InternLM family +INTERNLM_ACTION_START = '<|action_start|><|plugin|>' +INTERNLM_ACTION_END = '<|action_end|>' + +#: Llama 3 bot token +LLAMA3_BOT_TOKEN = '<|python_tag|>' + +#: Mapping: --tool-call-parser / --reasoning-parser value -> pytest -m expression +TOOL_PARSER_MARK_MAP = { + 'qwen': 'qwen3_parser', + 'qwen3': 'qwen3_parser', + 'qwen2d5': 'qwen2d5_parser', + 'internlm': 'internlm2_parser', + 'intern-s1': 'internlm2_parser', + 'llama3': 'llama3_parser', +} + +REASONING_PARSER_MARK_MAP = { + 'deepseek-r1': 'deepseek_r1_parser', + 'qwen-qwq': 'qwenqwq_parser', + 'intern-s1': 
'qwenqwq_parser', +} + +# -- Basic tools (English) -------------------------------------------------- + +WEATHER_TOOL = { + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'description': 'Get the current weather in a given location', + 'parameters': { + 'type': 'object', + 'properties': { + 'city': { + 'type': 'string', + 'description': 'The city to find the weather for, ' + 'e.g. San Francisco', + }, + 'state': { + 'type': 'string', + 'description': 'The state abbreviation, e.g. CA', + }, + 'unit': { + 'type': 'string', + 'description': 'The unit for temperature', + 'enum': ['celsius', 'fahrenheit'], + }, + }, + 'required': ['city', 'state'], + }, + }, +} + +SEARCH_TOOL = { + 'type': 'function', + 'function': { + 'name': 'web_search', + 'description': 'Search the web for information', + 'parameters': { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'The search query string', + }, + }, + 'required': ['query'], + }, + }, +} + +CALCULATOR_TOOL = { + 'type': 'function', + 'function': { + 'name': 'calculate', + 'description': 'Perform a mathematical calculation', + 'parameters': { + 'type': 'object', + 'properties': { + 'expression': { + 'type': 'string', + 'description': 'The math expression to evaluate, e.g. 
2+2', + }, + }, + 'required': ['expression'], + }, + }, +} + +# -- Chinese tool ------------------------------------------------------------ + +WEATHER_TOOL_CN = { + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'description': '获取指定城市的当前天气信息', + 'parameters': { + 'type': 'object', + 'properties': { + 'city': { + 'type': 'string', + 'description': '城市名称,例如:北京', + }, + 'unit': { + 'type': 'string', + 'description': '温度单位', + 'enum': ['摄氏度', '华氏度'], + }, + }, + 'required': ['city'], + }, + }, +} + +# -- Complex-parameter tools ------------------------------------------------- + +NESTED_PARAM_TOOL = { + 'type': 'function', + 'function': { + 'name': 'create_event', + 'description': 'Create a calendar event with nested location details', + 'parameters': { + 'type': 'object', + 'properties': { + 'title': { + 'type': 'string', + 'description': 'Event title', + }, + 'location': { + 'type': 'object', + 'description': 'Event location details', + 'properties': { + 'venue': { + 'type': 'string', + 'description': 'Venue name', + }, + 'address': { + 'type': 'string', + 'description': 'Street address', + }, + 'city': { + 'type': 'string', + 'description': 'City name', + }, + }, + 'required': ['venue', 'city'], + }, + 'attendees': { + 'type': 'array', + 'description': 'List of attendee emails', + 'items': { + 'type': 'string' + }, + }, + 'priority': { + 'type': 'string', + 'description': 'Event priority level', + 'enum': ['low', 'medium', 'high'], + }, + }, + 'required': ['title', 'location'], + }, + }, +} + +ALL_OPTIONAL_TOOL = { + 'type': 'function', + 'function': { + 'name': 'log_message', + 'description': 'Log a message with optional metadata', + 'parameters': { + 'type': 'object', + 'properties': { + 'message': { + 'type': 'string', + 'description': 'The log message', + }, + 'level': { + 'type': 'string', + 'description': 'Log level', + 'enum': ['debug', 'info', 'warning', 'error'], + }, + 'timestamp': { + 'type': 'string', + 'description': 'Optional 
ISO timestamp', + }, + }, + # NOTE: no 'required' key — all params are optional + }, + }, +} + + +def get_client_and_model(base_url=None): + """Return an ``OpenAI`` client and the first available model name.""" + url = base_url or BASE_URL + client = OpenAI(api_key='YOUR_API_KEY', base_url=f'{url}/v1') + model_name = client.models.list().data[0].id + return client, model_name + + +# -- Logging / client helpers ------------------------------------------------ + + +class StreamTee: + """Transparent iterator proxy: yields every chunk unchanged while + recording each ``repr(chunk)`` to the log file.""" + + def __init__(self, stream, log_file): + self._stream = stream + self._log_file = log_file + + def __iter__(self): + try: + for chunk in self._stream: + try: + with open(self._log_file, 'a', encoding='utf-8') as f: + f.write(repr(chunk) + '\n') + except Exception: + pass + yield chunk + except Exception: + raise + + +def setup_log_file(config, test_name, category): + """Compute log-file path and ensure the directory exists. + + Parameters + ---------- + config : dict + Test configuration (must contain ``log_path`` or defaults to + ``./logs``). + test_name : str + Raw test node name (will be sanitised for filesystem safety). + category : str + Subdirectory under *log_path*, e.g. ``'tool_calls'`` or + ``'reasoning'``. + + Returns + ------- + str + Full path to the log file. + """ + safe_test_name = re.sub(r'[^\w\.-]', '_', test_name) + log_base = config.get('log_path', './logs') + log_dir = os.path.join(log_base, category) + os.makedirs(log_dir, exist_ok=True) + return os.path.join(log_dir, f'{safe_test_name}.log') + + +def make_logged_client(log_file): + """Return ``(client, model_name)`` with transparent logging. + + Every ``chat.completions.create`` call is intercepted to: + + 1. Inject ``extra_body={'spaces_between_special_tokens': False}``. + 2. For streaming calls, wrap the iterator with :class:`StreamTee`. + 3. 
For non-streaming calls, append ``repr(response)`` to *log_file*. + """ + client, model_name = get_client_and_model() + _original_create = client.chat.completions.create + + def _logged_create(*args, **kwargs): + kwargs.setdefault('extra_body', dict(spaces_between_special_tokens=False)) + is_stream = kwargs.get('stream', False) + result = _original_create(*args, **kwargs) + if is_stream: + return StreamTee(result, log_file) + try: + with open(log_file, 'a', encoding='utf-8') as f: + f.write(repr(result) + '\n') + except Exception: + pass + return result + + client.chat.completions.create = _logged_create + return client, model_name + + +# -- Assertion helpers ------------------------------------------------------- + + +def assert_tool_call_fields(tool_call): + """Assert a single tool call object has all required fields.""" + assert tool_call.type == 'function', (f'tool_call.type should be "function", got {tool_call.type}') + assert tool_call.function is not None, ('tool_call.function should not be None') + assert isinstance(tool_call.id, str), (f'tool_call.id should be a string, got {type(tool_call.id)}') + assert len(tool_call.id) >= 1, (f'tool_call.id should be non-empty, got "{tool_call.id}"') + assert isinstance(tool_call.function.name, + str), (f'function.name should be a string, got {type(tool_call.function.name)}') + assert len(tool_call.function.name) > 0, ('function.name should be non-empty') + assert isinstance(tool_call.function.arguments, str), (f'function.arguments should be a string, ' + f'got {type(tool_call.function.arguments)}') + + +def assert_arguments_parseable(arguments_str): + """Assert *arguments_str* is valid JSON dict; return the parsed dict.""" + parsed = json.loads(arguments_str) + assert isinstance(parsed, dict), (f'Parsed arguments should be a dict, got {type(parsed)}') + return parsed + + +# -- Stream consumption helpers ---------------------------------------------- + + +def collect_stream_tool_call(stream): + """Consume a streaming 
response and return aggregated tool-call data. + + Returns a dict with keys: + function_name, args_str, tool_call_id, finish_reason, role, + finish_reason_count + """ + result = { + 'function_name': None, + 'args_str': '', + 'tool_call_id': None, + 'finish_reason': None, + 'role': None, + 'finish_reason_count': 0, + } + + for chunk in stream: + choice = chunk.choices[0] + + if choice.finish_reason: + result['finish_reason'] = choice.finish_reason + result['finish_reason_count'] += 1 + + delta = choice.delta + if delta.role: + result['role'] = delta.role + + if delta.tool_calls and len(delta.tool_calls) > 0: + tc = delta.tool_calls[0] + if tc.id: + result['tool_call_id'] = tc.id + if tc.function: + if tc.function.name: + result['function_name'] = tc.function.name + if tc.function.arguments: + result['args_str'] += tc.function.arguments + + return result + + +def collect_stream_parallel_tool_calls(stream): + """Consume a streaming response that may contain parallel tool calls. + + Returns (tool_calls_data, finish_reason_count) where tool_calls_data is a dict index -> {name, args_str, id}. 
+ """ + tool_calls_data = {} + finish_reason_count = 0 + + for chunk in stream: + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + + streamed_tool_calls = chunk.choices[0].delta.tool_calls + if streamed_tool_calls: + for stc in streamed_tool_calls: + idx = stc.index if stc.index is not None else 0 + if idx not in tool_calls_data: + tool_calls_data[idx] = { + 'name': None, + 'args_str': '', + 'id': None, + } + if stc.id: + tool_calls_data[idx]['id'] = stc.id + if stc.function: + if stc.function.name: + tool_calls_data[idx]['name'] = stc.function.name + if stc.function.arguments: + tool_calls_data[idx]['args_str'] += (stc.function.arguments) + + return tool_calls_data, finish_reason_count + + +def collect_stream_content(stream): + """Consume a streaming response and return (chunks, finish_reason).""" + chunks = [] + finish_reason = None + for chunk in stream: + delta = chunk.choices[0].delta + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason = chunk.choices[0].finish_reason + return chunks, finish_reason + + +def collect_stream_reasoning(stream): + """Consume a streaming response, collecting reasoning + content + tool + calls. 
+ + Returns a dict with keys: + reasoning_content – aggregated reasoning string + content – aggregated final content string + tool_calls – dict index -> {name, args_str, id} + finish_reason – last non-None finish_reason + finish_reason_count – how many chunks carried a non-None finish_reason + role – first non-None role value + role_count – how many chunks carried a non-None role + chunk_count – total number of chunks received + reasoning_chunks – number of chunks containing reasoning + content_chunks – number of chunks containing content + """ + result = { + 'reasoning_content': '', + 'content': '', + 'tool_calls': {}, + 'finish_reason': None, + 'finish_reason_count': 0, + 'role': None, + 'role_count': 0, + 'chunk_count': 0, + 'reasoning_chunks': 0, + 'content_chunks': 0, + } + + for chunk in stream: + result['chunk_count'] += 1 + if not chunk.choices: + continue + choice = chunk.choices[0] + + if choice.finish_reason is not None: + result['finish_reason'] = choice.finish_reason + result['finish_reason_count'] += 1 + + delta = choice.delta + if delta.role: + result['role'] = delta.role + result['role_count'] += 1 + + # -- reasoning_content (lmdeploy extension field) ------------------- + rc = getattr(delta, 'reasoning_content', None) + if rc: + result['reasoning_content'] += rc + result['reasoning_chunks'] += 1 + + # -- regular content ------------------------------------------------ + if delta.content: + result['content'] += delta.content + result['content_chunks'] += 1 + + # -- tool calls ----------------------------------------------------- + if delta.tool_calls: + for stc in delta.tool_calls: + idx = stc.index if stc.index is not None else 0 + if idx not in result['tool_calls']: + result['tool_calls'][idx] = { + 'name': None, + 'args_str': '', + 'id': None, + } + if stc.id: + result['tool_calls'][idx]['id'] = stc.id + if stc.function: + if stc.function.name: + result['tool_calls'][idx]['name'] = stc.function.name + if stc.function.arguments: + 
result['tool_calls'][idx]['args_str'] += (stc.function.arguments) + + return result + + +# -- Reasoning extraction helpers ------------------------------------------- + + +def get_reasoning_content(message): + """Extract reasoning_content from a chat completion message. + + Different backends may expose reasoning in different ways: + - ``message.reasoning_content`` (lmdeploy / OpenAI extension) + - Wrapped inside ``...`` tags in ``message.content`` + This helper tries both and returns the reasoning string (or *None*). + """ + reasoning = getattr(message, 'reasoning_content', None) + if reasoning: + return reasoning + + content = message.content or '' + if THINK_START_TOKEN in content and THINK_END_TOKEN in content: + start = content.index(THINK_START_TOKEN) + len(THINK_START_TOKEN) + end = content.index(THINK_END_TOKEN) + extracted = content[start:end].strip() + if extracted: + return extracted + + return None + + +def get_reasoning_tokens(response): + """Extract reasoning_tokens from usage, handling various response shapes. + + Returns int or *None* if not available. 
+ """ + usage = response.usage + if usage is None: + return None + + # completion_tokens_details.reasoning_tokens (OpenAI style) + details = getattr(usage, 'completion_tokens_details', None) + if details is not None: + rt = getattr(details, 'reasoning_tokens', None) + if rt is not None: + return rt + + # Direct attribute + rt = getattr(usage, 'reasoning_tokens', None) + if rt is not None: + return rt + + return None + + +# -- Message-building helpers ----------------------------------------------- + + +def build_messages_with_tool_response( + tool_call_id='call_test_001', + function_name='get_current_weather', +): + """Build message list: user ask → assistant tool_call → tool result.""" + return [ + { + 'role': 'system', + 'content': 'You are a helpful assistant that can use tools.', + }, + { + 'role': 'user', + 'content': "What's the weather like in Dallas, TX?", + }, + { + 'role': + 'assistant', + 'content': + None, + 'tool_calls': [{ + 'id': tool_call_id, + 'type': 'function', + 'function': { + 'name': function_name, + 'arguments': '{"city": "Dallas", "state": "TX"}', + }, + }], + }, + { + 'role': 'tool', + 'tool_call_id': tool_call_id, + 'content': json.dumps({ + 'temperature': 98, + 'unit': 'fahrenheit', + 'description': 'Sunny with clear skies', + }), + }, + ] + + +def build_messages_with_parallel_tool_responses(): + """Build message list simulating two parallel tool calls + results.""" + return [ + { + 'role': 'system', + 'content': 'You are a helpful assistant that can use tools.', + }, + { + 'role': 'user', + 'content': "What's the weather in Dallas, TX and San Francisco, CA?", + }, + { + 'role': + 'assistant', + 'content': + None, + 'tool_calls': [ + { + 'id': 'call_001', + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'arguments': '{"city": "Dallas", "state": "TX"}', + }, + }, + { + 'id': 'call_002', + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'arguments': '{"city": "San Francisco", "state": "CA"}', + 
}, + }, + ], + }, + { + 'role': 'tool', + 'tool_call_id': 'call_001', + 'content': json.dumps({ + 'temperature': 98, + 'unit': 'fahrenheit', + 'description': 'Sunny', + }), + }, + { + 'role': 'tool', + 'tool_call_id': 'call_002', + 'content': json.dumps({ + 'temperature': 65, + 'unit': 'fahrenheit', + 'description': 'Foggy', + }), + }, + ] + + +def build_reasoning_tool_roundtrip_messages(tool_call_id='call_reason_001'): + """Build multi-turn messages: user → assistant (reasoning+tool_call) → tool → continue.""" + return [ + { + 'role': 'system', + 'content': 'You are a helpful assistant that can use tools. ' + 'Think through problems step by step.', + }, + { + 'role': 'user', + 'content': "I'm visiting Dallas, TX. Should I bring an umbrella?", + }, + { + 'role': + 'assistant', + 'content': + 'Let me think about this. To answer whether you need ' + 'an umbrella, I should check the current weather in ' + 'Dallas, TX.', + 'tool_calls': [{ + 'id': tool_call_id, + 'type': 'function', + 'function': { + 'name': 'get_current_weather', + 'arguments': '{"city": "Dallas", "state": "TX"}', + }, + }], + }, + { + 'role': + 'tool', + 'tool_call_id': + tool_call_id, + 'content': + json.dumps({ + 'temperature': 95, + 'unit': 'fahrenheit', + 'description': 'Sunny with clear skies', + 'precipitation': '0%', + }), + }, + ]