diff --git a/examples/03_Tools_And_Function_Calling/web_search/web_search.py b/examples/03_Tools_And_Function_Calling/web_search/web_search.py
index 90f6c97..79eb166 100644
--- a/examples/03_Tools_And_Function_Calling/web_search/web_search.py
+++ b/examples/03_Tools_And_Function_Calling/web_search/web_search.py
@@ -10,7 +10,7 @@ class WebSearchTool:
-    def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType,
+    def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, context_character_limit: int = 7500,
                  web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None):
         self.summarising_agent = LlamaCppAgent(llm_provider, debug_output=True,
                                                system_prompt="You are a text summarization and information extraction specialist and you are able to summarize and filter out information relevant to a specific query.",
@@ -25,6 +25,8 @@ def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFo
         else:
             self.web_search_provider = web_search_provider
 
+        self.context_character_limit = context_character_limit
+
     def search_web(self, search_query: str):
         """
         Search the web for information.
@@ -42,7 +44,7 @@ def search_web(self, search_query: str):
             result_string += web_info
 
         res = result_string.strip()
-        return "Based on the following results, answer the previous user query:\nResults:\n\n" + res
+        return "Based on the following results, answer the previous user query:\nResults:\n\n" + res[:self.context_character_limit]
 
     def get_tool(self):
         return self.search_web
@@ -67,22 +69,25 @@ def send_message_to_user(message: str):
     add_tools_and_structures_documentation_to_system_prompt=True,
 )
 
-search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML)
+search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML, 20000)
 
 settings = provider.get_provider_default_settings()
-settings.temperature = 0.45
-settings.max_tokens = 1024
+settings.temperature = 0.65
+# settings.top_p = 0.85
+# settings.top_k = 60
+# settings.tfs_z = 0.95
+settings.max_tokens = 2048
 output_settings = LlmStructuredOutputSettings.from_functions(
     [search_tool.get_tool(), send_message_to_user])
 user = input(">")
-result = agent.get_chat_response(user,
+result = agent.get_chat_response(user, prompt_suffix="\n```json\n",
                                  llm_sampling_settings=settings, structured_output_settings=output_settings)
 while True:
     if result[0]["function"] == "send_message_to_user":
         user = input(">")
-        result = agent.get_chat_response(user, structured_output_settings=output_settings,
+        result = agent.get_chat_response(user, prompt_suffix="\n```json\n", structured_output_settings=output_settings,
                                          llm_sampling_settings=settings)
     else:
-        result = agent.get_chat_response(result[0]["return_value"], role=Roles.tool,
+        result = agent.get_chat_response(result[0]["return_value"], role=Roles.tool, prompt_suffix="\n```json\n",
                                          structured_output_settings=output_settings, llm_sampling_settings=settings)
diff --git a/src/llama_cpp_agent/agent_memory/memory_tools.py b/src/llama_cpp_agent/agent_memory/memory_tools.py
index 393a3a7..afb1cfc 100644
--- a/src/llama_cpp_agent/agent_memory/memory_tools.py
+++ b/src/llama_cpp_agent/agent_memory/memory_tools.py
@@ -32,7 +32,7 @@ class core_memory_append(BaseModel):
 
     def run(self, core_memory_manager: CoreMemoryManager):
         return core_memory_manager.add_to_core_memory(
-            self.key.value, self.field, self.value
+            self.key, self.field, self.value
         )
 
 
@@ -50,7 +50,7 @@ class core_memory_replace(BaseModel):
 
     def run(self, core_memory_manager: CoreMemoryManager):
         return core_memory_manager.replace_in_core_memory(
-            self.key.value, self.field, self.new_value
+            self.key, self.field, self.new_value
         )
 
 
@@ -63,7 +63,7 @@ class core_memory_remove(BaseModel):
     field: str = Field(..., description="The field within the core memory.")
 
     def run(self, core_memory_manager: CoreMemoryManager):
-        return core_memory_manager.remove_from_core_memory(self.key.value, self.field)
+        return core_memory_manager.remove_from_core_memory(self.key, self.field)
 
 
 class conversation_search(BaseModel):
diff --git a/src/llama_cpp_agent/llm_agent.py b/src/llama_cpp_agent/llm_agent.py
index ce96504..7e6ff5d 100644
--- a/src/llama_cpp_agent/llm_agent.py
+++ b/src/llama_cpp_agent/llm_agent.py
@@ -298,7 +298,7 @@ def stream_results():
                         }
                     )
                 return structured_output_settings.handle_structured_output(
-                    full_response_stream
+                    full_response_stream, prompt_suffix=prompt_suffix
                 )
 
         if self.provider:
@@ -334,7 +334,7 @@ def stream_results():
                     )
 
                 return structured_output_settings.handle_structured_output(
-                    full_response
+                    full_response, prompt_suffix=prompt_suffix
                 )
             else:
                 text = completion["choices"][0]["text"]
@@ -353,7 +353,7 @@ def stream_results():
                     }
                 )
 
-                return structured_output_settings.handle_structured_output(text)
+                return structured_output_settings.handle_structured_output(text, prompt_suffix=prompt_suffix)
         return "Error: No model loaded!"
 
     def get_text_completion(
diff --git a/src/llama_cpp_agent/llm_output_settings/settings.py b/src/llama_cpp_agent/llm_output_settings/settings.py
index 26070f5..c492715 100644
--- a/src/llama_cpp_agent/llm_output_settings/settings.py
+++ b/src/llama_cpp_agent/llm_output_settings/settings.py
@@ -551,9 +551,11 @@ def get_json_schema(self):
             add_inner_thoughts=self.add_thoughts_and_reasoning_field,
         )
 
-    def handle_structured_output(self, llm_output: str):
+    def handle_structured_output(self, llm_output: str, prompt_suffix: str = None):
         if self.output_raw_json_string:
             return llm_output
+        if prompt_suffix:
+            llm_output = llm_output.replace(prompt_suffix, "", 1)
         if (
             self.output_type is LlmStructuredOutputType.function_calling
             or self.output_type is LlmStructuredOutputType.parallel_function_calling
diff --git a/src/llama_cpp_agent/messages_formatter.py b/src/llama_cpp_agent/messages_formatter.py
index 5be347b..0f3fa43 100644
--- a/src/llama_cpp_agent/messages_formatter.py
+++ b/src/llama_cpp_agent/messages_formatter.py
@@ -241,7 +241,7 @@ def _format_response(
     False,
     ["<|im_end|>", ""],
     use_user_role_for_function_call_result=False,
-    strip_prompt=False,
+    strip_prompt=True,
 )
 
 vicuna_formatter = MessagesFormatter(
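
Note on the prompt_suffix threading above: the example now passes prompt_suffix="\n```json\n" to get_chat_response, so the model's reply begins inside a JSON fence, and llm_agent.py forwards the same string to handle_structured_output, which strips its first occurrence from the raw completion before parsing. A minimal standalone sketch of that strip-then-parse round trip; parse logic and sample output here are illustrative stand-ins, not the library's actual function-calling parser:

    import json

    def handle_structured_output(llm_output: str, prompt_suffix: str = None) -> dict:
        # Mirrors the settings.py change: if the backend echoes the pre-filled
        # suffix back at the start of the completion, drop its first occurrence.
        if prompt_suffix:
            llm_output = llm_output.replace(prompt_suffix, "", 1)
        # The model was steered into a fenced JSON block; cut at the closing
        # fence and parse what remains. (Simplified, hypothetical parsing path.)
        body = llm_output.split("```")[0].strip()
        return json.loads(body)

    # Example: a completion that echoes the suffix back.
    raw = '\n```json\n{"function": "send_message_to_user", "params": {"message": "hi"}}\n```'
    print(handle_structured_output(raw, prompt_suffix="\n```json\n"))
    # {'function': 'send_message_to_user', 'params': {'message': 'hi'}}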
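
Similarly for context_character_limit: search_web now slices the concatenated results text to a fixed character budget (res[:self.context_character_limit], default 7500, set to 20000 in the example) before handing it back to the agent, so long pages cannot overflow the model's context window. A sketch of the same guard in isolation, assuming hypothetical page strings:

    def build_search_context(pages: list, context_character_limit: int = 7500) -> str:
        # Concatenate the scraped results, then hard-truncate to the character
        # budget, as WebSearchTool.search_web does after this patch.
        res = "".join(pages).strip()
        return ("Based on the following results, answer the previous user query:"
                "\nResults:\n\n" + res[:context_character_limit])

    # Two oversized pages collapse to at most 7500 characters of results text.
    pages = ["page one text " * 1000, "page two text " * 1000]
    print(len(build_search_context(pages)))

A plain character slice is a blunt instrument and can cut mid-sentence, but it is cheap and deterministic, which suits a guardrail whose only job is to keep the final prompt bounded.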