From a4952c58c4429ae48b1deca6a1145d40d1bc3326 Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Sun, 26 May 2024 07:19:33 +0200 Subject: [PATCH] Fixes websearch --- .../web_search_agent.py | 5 ++-- src/llama_cpp_agent/tools/__init__.py | 2 +- .../tools/web_search/__init__.py | 8 +++--- .../tools/web_search/default_web_crawlers.py | 2 +- .../default_web_search_providers.py | 2 +- src/llama_cpp_agent/tools/web_search/tool.py | 28 ++++++++++++------- 6 files changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/03_Tools_And_Function_Calling/web_search_agent.py b/examples/03_Tools_And_Function_Calling/web_search_agent.py index 70fdfd5..94b6221 100644 --- a/examples/03_Tools_And_Function_Calling/web_search_agent.py +++ b/examples/03_Tools_And_Function_Calling/web_search_agent.py @@ -4,6 +4,7 @@ from llama_cpp_agent.providers import LlamaCppServerProvider from llama_cpp_agent.tools import WebSearchTool + def send_message_to_user(message: str): """ Send a message to user. @@ -22,7 +23,7 @@ def send_message_to_user(message: str): add_tools_and_structures_documentation_to_system_prompt=True, ) -search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML, 20000) +search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML, max_tokens_search_results=20000) settings = provider.get_provider_default_settings() @@ -43,4 +44,4 @@ def send_message_to_user(message: str): llm_sampling_settings=settings) else: result = agent.get_chat_response(result[0]["return_value"], role=Roles.tool, prompt_suffix="\n```json\n", - structured_output_settings=output_settings, llm_sampling_settings=settings) \ No newline at end of file + structured_output_settings=output_settings, llm_sampling_settings=settings) diff --git a/src/llama_cpp_agent/tools/__init__.py b/src/llama_cpp_agent/tools/__init__.py index 47704c4..040046a 100644 --- a/src/llama_cpp_agent/tools/__init__.py +++ b/src/llama_cpp_agent/tools/__init__.py @@ -1 +1 @@ -from web_search import WebSearchTool, WebSearchProvider, WebCrawler, TrafilaturaWebCrawler, DDGWebSearchProvider \ No newline at end of file +from .web_search import WebSearchTool, WebSearchProvider, WebCrawler, TrafilaturaWebCrawler, DDGWebSearchProvider \ No newline at end of file diff --git a/src/llama_cpp_agent/tools/web_search/__init__.py b/src/llama_cpp_agent/tools/web_search/__init__.py index 9a41be9..94e2c63 100644 --- a/src/llama_cpp_agent/tools/web_search/__init__.py +++ b/src/llama_cpp_agent/tools/web_search/__init__.py @@ -1,4 +1,4 @@ -from tool import WebSearchTool -from web_search_interfaces import WebCrawler, WebSearchProvider -from default_web_crawlers import TrafilaturaWebCrawler -from default_web_search_providers import DDGWebSearchProvider +from .tool import WebSearchTool +from .web_search_interfaces import WebCrawler, WebSearchProvider +from .default_web_crawlers import TrafilaturaWebCrawler +from .default_web_search_providers import DDGWebSearchProvider diff --git a/src/llama_cpp_agent/tools/web_search/default_web_crawlers.py b/src/llama_cpp_agent/tools/web_search/default_web_crawlers.py index 281df80..73fcd35 100644 --- a/src/llama_cpp_agent/tools/web_search/default_web_crawlers.py +++ b/src/llama_cpp_agent/tools/web_search/default_web_crawlers.py @@ -1,6 +1,6 @@ import json -from web_search_interfaces import WebCrawler +from .web_search_interfaces import WebCrawler from trafilatura import fetch_url, extract diff --git a/src/llama_cpp_agent/tools/web_search/default_web_search_providers.py b/src/llama_cpp_agent/tools/web_search/default_web_search_providers.py index 2894ac5..bd4050f 100644 --- a/src/llama_cpp_agent/tools/web_search/default_web_search_providers.py +++ b/src/llama_cpp_agent/tools/web_search/default_web_search_providers.py @@ -1,6 +1,6 @@ from duckduckgo_search import DDGS -from web_search_interfaces import WebSearchProvider +from .web_search_interfaces import WebSearchProvider class DDGWebSearchProvider(WebSearchProvider): diff --git a/src/llama_cpp_agent/tools/web_search/tool.py b/src/llama_cpp_agent/tools/web_search/tool.py index d2d56bb..9d40a64 100644 --- a/src/llama_cpp_agent/tools/web_search/tool.py +++ b/src/llama_cpp_agent/tools/web_search/tool.py @@ -1,20 +1,18 @@ from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType -from llama_cpp_agent.chat_history.messages import Roles -from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings -from llama_cpp_agent.providers import LlamaCppServerProvider from llama_cpp_agent.providers.provider_base import LlmProvider, LlmProviderId -from web_search_interfaces import WebCrawler, WebSearchProvider -from default_web_crawlers import TrafilaturaWebCrawler -from default_web_search_providers import DDGWebSearchProvider +from .web_search_interfaces import WebCrawler, WebSearchProvider +from .default_web_crawlers import TrafilaturaWebCrawler +from .default_web_search_providers import DDGWebSearchProvider class WebSearchTool: - def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, context_character_limit: int = 7500, + def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None, temperature: int = 0.45, top_p: int = 0.95, - top_k: int = 40, + top_k: int = 40, max_tokens_search_results: int = 7500, max_tokens_per_summary: int = 750): + self.llm_provider = llm_provider self.summarising_agent = LlamaCppAgent(llm_provider, debug_output=True, system_prompt="You are a text summarization and information extraction specialist and you are able to summarize and filter out information relevant to a specific query.", predefined_messages_formatter_type=message_formatter_type) @@ -28,7 +26,7 @@ def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFo else: self.web_search_provider = web_search_provider - self.context_character_limit = context_character_limit + self.max_tokens_search_results = max_tokens_search_results settings = llm_provider.get_provider_default_settings() provider_id = llm_provider.get_provider_identifier() settings.temperature = temperature @@ -61,7 +59,17 @@ def search_web(self, search_query: str): result_string += web_info res = result_string.strip() - return "Based on the following results, answer the previous user query:\nResults:\n\n" + res[:self.context_character_limit] + tokens = self.llm_provider.tokenize(res) + if self.max_tokens_search_results < len(tokens): + remove_chars = len(tokens) - self.max_tokens_search_results + while True: + tokens = self.llm_provider.tokenize(res[:remove_chars]) + if self.max_tokens_search_results >= len(tokens): + break + else: + remove_chars += 100 + + return "Based on the following results, answer the previous user query:\nResults:\n\n" + res[:self.max_tokens_search_results] def get_tool(self): return self.search_web