Skip to content

Commit

Permalink
Fixes websearch
Browse files: browse the repository at this point in the history
  • Loading branch information
Maximilian-Winter committed May 26, 2024
1 parent e99c16e commit a4952c5
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 19 deletions.
5 changes: 3 additions & 2 deletions examples/03_Tools_And_Function_Calling/web_search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from llama_cpp_agent.providers import LlamaCppServerProvider
from llama_cpp_agent.tools import WebSearchTool


def send_message_to_user(message: str):
"""
Send a message to user.
Expand All @@ -22,7 +23,7 @@ def send_message_to_user(message: str):
add_tools_and_structures_documentation_to_system_prompt=True,
)

search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML, 20000)
search_tool = WebSearchTool(provider, MessagesFormatterType.CHATML, max_tokens_search_results=20000)

settings = provider.get_provider_default_settings()

Expand All @@ -43,4 +44,4 @@ def send_message_to_user(message: str):
llm_sampling_settings=settings)
else:
result = agent.get_chat_response(result[0]["return_value"], role=Roles.tool, prompt_suffix="\n```json\n",
structured_output_settings=output_settings, llm_sampling_settings=settings)
structured_output_settings=output_settings, llm_sampling_settings=settings)
2 changes: 1 addition & 1 deletion src/llama_cpp_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from web_search import WebSearchTool, WebSearchProvider, WebCrawler, TrafilaturaWebCrawler, DDGWebSearchProvider
from .web_search import WebSearchTool, WebSearchProvider, WebCrawler, TrafilaturaWebCrawler, DDGWebSearchProvider
8 changes: 4 additions & 4 deletions src/llama_cpp_agent/tools/web_search/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tool import WebSearchTool
from web_search_interfaces import WebCrawler, WebSearchProvider
from default_web_crawlers import TrafilaturaWebCrawler
from default_web_search_providers import DDGWebSearchProvider
from .tool import WebSearchTool
from .web_search_interfaces import WebCrawler, WebSearchProvider
from .default_web_crawlers import TrafilaturaWebCrawler
from .default_web_search_providers import DDGWebSearchProvider
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json

from web_search_interfaces import WebCrawler
from .web_search_interfaces import WebCrawler
from trafilatura import fetch_url, extract


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from duckduckgo_search import DDGS

from web_search_interfaces import WebSearchProvider
from .web_search_interfaces import WebSearchProvider


class DDGWebSearchProvider(WebSearchProvider):
Expand Down
28 changes: 18 additions & 10 deletions src/llama_cpp_agent/tools/web_search/tool.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
from llama_cpp_agent.providers import LlamaCppServerProvider
from llama_cpp_agent.providers.provider_base import LlmProvider, LlmProviderId
from web_search_interfaces import WebCrawler, WebSearchProvider
from default_web_crawlers import TrafilaturaWebCrawler
from default_web_search_providers import DDGWebSearchProvider
from .web_search_interfaces import WebCrawler, WebSearchProvider
from .default_web_crawlers import TrafilaturaWebCrawler
from .default_web_search_providers import DDGWebSearchProvider


class WebSearchTool:

def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, context_character_limit: int = 7500,
def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType,
web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None, temperature: int = 0.45,
top_p: int = 0.95,
top_k: int = 40,
top_k: int = 40, max_tokens_search_results: int = 7500,
max_tokens_per_summary: int = 750):
self.llm_provider = llm_provider
self.summarising_agent = LlamaCppAgent(llm_provider, debug_output=True,
system_prompt="You are a text summarization and information extraction specialist and you are able to summarize and filter out information relevant to a specific query.",
predefined_messages_formatter_type=message_formatter_type)
Expand All @@ -28,7 +26,7 @@ def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFo
else:
self.web_search_provider = web_search_provider

self.context_character_limit = context_character_limit
self.max_tokens_search_results = max_tokens_search_results
settings = llm_provider.get_provider_default_settings()
provider_id = llm_provider.get_provider_identifier()
settings.temperature = temperature
Expand Down Expand Up @@ -61,7 +59,17 @@ def search_web(self, search_query: str):
result_string += web_info

res = result_string.strip()
return "Based on the following results, answer the previous user query:\nResults:\n\n" + res[:self.context_character_limit]
tokens = self.llm_provider.tokenize(res)
if self.max_tokens_search_results < len(tokens):
remove_chars = len(tokens) - self.max_tokens_search_results
while True:
tokens = self.llm_provider.tokenize(res[:remove_chars])
if self.max_tokens_search_results >= len(tokens):
break
else:
remove_chars += 100

return "Based on the following results, answer the previous user query:\nResults:\n\n" + res[:self.max_tokens_search_results]

def get_tool(self):
    """Return this instance's ``search_web`` callable.

    Nothing is invoked here; the attribute is handed back as-is so the
    caller can decide when to run the search.

    Returns:
        The object stored under ``self.search_web``.
    """
    return self.search_web
Expand Down

0 comments on commit a4952c5

Please sign in to comment.