diff --git a/mcp_client_for_ollama/utils/streaming.py b/mcp_client_for_ollama/utils/streaming.py index 0396e76..7cbd8fe 100644 --- a/mcp_client_for_ollama/utils/streaming.py +++ b/mcp_client_for_ollama/utils/streaming.py @@ -4,6 +4,8 @@ Classes: StreamingManager: Handles streaming responses from Ollama. """ +import json +from ollama._types import Message from rich.markdown import Markdown from .metrics import display_metrics, extract_metrics @@ -141,4 +143,74 @@ async def process_streaming_response(self, stream, print_response=True, thinking if show_metrics and metrics: display_metrics(self.console, metrics) + # Check for JSON tool calls in the accumulated text if no tool_calls object was found + if not tool_calls and accumulated_text: + # Some models wrap JSON in markdown, let's strip it + text_to_parse = accumulated_text.strip() + if text_to_parse.startswith("```json"): + text_to_parse = text_to_parse[7:] + if text_to_parse.endswith("```"): + text_to_parse = text_to_parse[:-3] + text_to_parse = text_to_parse.strip() + + # Find the start and end of the JSON object/array + json_start = -1 + first_brace = text_to_parse.find('{') + first_bracket = text_to_parse.find('[') + + if first_brace == -1: + json_start = first_bracket + elif first_bracket == -1: + json_start = first_brace + else: + json_start = min(first_brace, first_bracket) + + if json_start != -1: + json_end = -1 + last_brace = text_to_parse.rfind('}') + last_bracket = text_to_parse.rfind(']') + json_end = max(last_brace, last_bracket) + + if json_end > json_start: + json_str = text_to_parse[json_start:json_end+1] + try: + parsed_json = json.loads(json_str) + + potential_tool_calls = [] + if isinstance(parsed_json, list): + potential_tool_calls = parsed_json + elif isinstance(parsed_json, dict): + # Some models wrap the call in a 'tool_calls' key + if 'tool_calls' in parsed_json and isinstance(parsed_json['tool_calls'], list): + potential_tool_calls = parsed_json['tool_calls'] + else: + potential_tool_calls = [parsed_json] + + for tc_json in potential_tool_calls: + # Case 1: Standard OpenAI/Ollama format {'function': {'name': ..., 'arguments': ...}} + if (isinstance(tc_json, dict) and 'function' in tc_json and + isinstance(tc_json['function'], dict) and 'name' in tc_json['function'] and + 'arguments' in tc_json['function']): + + tool_calls.append(Message.ToolCall( + function=Message.ToolCall.Function( + name=tc_json['function']['name'], + arguments=tc_json['function']['arguments'] + ) + )) + # Case 2: Flattened format {'name': ..., 'arguments': ...} as seen from qwen2.5 + elif (isinstance(tc_json, dict) and 'name' in tc_json and 'arguments' in tc_json): + tool_calls.append(Message.ToolCall( + function=Message.ToolCall.Function( + name=tc_json['name'], + arguments=tc_json['arguments'] + ) + )) + + if tool_calls: + accumulated_text = "" # Clear text if we have tool calls + + except json.JSONDecodeError: + pass # Not a valid JSON, treat as text + return accumulated_text, tool_calls, metrics