diff --git a/sgpt/__main__.py b/sgpt/__main__.py index 1c7bc113..1e1c9e8e 100644 --- a/sgpt/__main__.py +++ b/sgpt/__main__.py @@ -1,3 +1,3 @@ -from .app import entry_point +from sgpt.app import entry_point entry_point() diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index c6302132..04e36073 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -58,28 +58,46 @@ def make_messages(self, prompt: str) -> List[Dict[str, str]]: def handle_function_call( self, messages: List[dict[str, Any]], - name: str, - arguments: str, + name_map, + arguments_map, ) -> Generator[str, None, None]: + all_tool_calls = [ + { + "id": id, + "type": "function", + "function": {"name": name, "arguments": arguments}, + } + for id, name, arguments in zip( + name_map.keys(), name_map.values(), arguments_map.values() + ) + ] messages.append( { "role": "assistant", "content": "", - "function_call": {"name": name, "arguments": arguments}, + "tool_calls": all_tool_calls, } ) if messages and messages[-1]["role"] == "assistant": yield "\n" - dict_args = json.loads(arguments) - joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items()) - yield f"> @FunctionCall `{name}({joined_args})` \n\n" + all_function_res_msgs = [] + for id, name, arguments in zip( + name_map.keys(), name_map.values(), arguments_map.values() + ): + dict_args = json.loads(arguments) + joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items()) + yield f"> @FunctionCall `{name}({joined_args})` \n\n" + + result = get_function(name)(**dict_args) + if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true": + yield f"```text\n{result}\n```\n" - result = get_function(name)(**dict_args) - if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true": - yield f"```text\n{result}\n```\n" - messages.append({"role": "function", "content": result, "name": name}) + all_function_res_msgs.append( + {"role": "tool", "content": result, "tool_call_id": id} + ) + messages += all_function_res_msgs @cache def get_completion( @@ -90,7 +108,6 @@ def get_completion( messages: List[Dict[str, Any]], functions: Optional[List[Dict[str, str]]], ) -> Generator[str, None, None]: - name = arguments = "" is_shell_role = self.role.name == DefaultRoles.SHELL.value is_code_role = self.role.name == DefaultRoles.CODE.value is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value @@ -112,6 +129,8 @@ def get_completion( ) try: + arguments_map = {} + name_map = {} for chunk in response: delta = chunk.choices[0].delta @@ -119,14 +138,21 @@ def get_completion( tool_calls = ( delta.get("tool_calls") if use_litellm else delta.tool_calls ) + if tool_calls: for tool_call in tool_calls: - if tool_call.function.name: - name = tool_call.function.name - if tool_call.function.arguments: - arguments += tool_call.function.arguments + if tool_call.id is not None: + id = tool_call.id + name_map[id] = tool_call.function.name + arguments_map[id] = "" + else: + arguments_map[id] += tool_call.function.arguments if chunk.choices[0].finish_reason == "tool_calls": - yield from self.handle_function_call(messages, name, arguments) + yield from self.handle_function_call( + messages, + name_map, + arguments_map, + ) yield from self.get_completion( model=model, temperature=temperature,