Skip to content

Commit

Permalink
Fixed provider identifier
Browse files — browse the repository at this point in the history
  • Loading branch information
Maximilian-Winter committed May 25, 2024
1 parent 744e20e commit d29b791
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
23 changes: 20 additions & 3 deletions examples/03_Tools_And_Function_Calling/web_search/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
from llama_cpp_agent.providers import LlamaCppServerProvider
from llama_cpp_agent.providers.provider_base import LlmProvider
from llama_cpp_agent.providers.provider_base import LlmProvider, LlmProviderId
from web_search_interfaces import WebCrawler, WebSearchProvider
from default_web_crawlers import TrafilaturaWebCrawler
from default_web_search_providers import DDGWebSearchProvider
Expand All @@ -11,7 +11,10 @@
class WebSearchTool:

def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, context_character_limit: int = 7500,
web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None):
web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None, temperature: int = 0.45,
top_p: int = 0.95,
top_k: int = 40,
max_tokens_per_summary: int = 750):
self.summarising_agent = LlamaCppAgent(llm_provider, debug_output=True,
system_prompt="You are a text summarization and information extraction specialist and you are able to summarize and filter out information relevant to a specific query.",
predefined_messages_formatter_type=message_formatter_type)
Expand All @@ -26,6 +29,20 @@ def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFo
self.web_search_provider = web_search_provider

self.context_character_limit = context_character_limit
settings = llm_provider.get_provider_default_settings()
provider_id = llm_provider.get_provider_identifier()
settings.temperature = temperature
settings.top_p = top_p
settings.top_k = top_k

if provider_id == LlmProviderId.llama_cpp_server:
settings.n_predict = max_tokens_per_summary
elif provider_id == LlmProviderId.tgi_server:
settings.max_new_tokens = max_tokens_per_summary
else:
settings.max_tokens = max_tokens_per_summary

self.settings = settings

def search_web(self, search_query: str):
"""
Expand All @@ -40,7 +57,7 @@ def search_web(self, search_query: str):
if web_info != "":
web_info = self.summarising_agent.get_chat_response(
f"Please summarize the following Website content and extract relevant information to this query:'{search_query}'.\n\n" + web_info,
add_response_to_chat_history=False, add_message_to_chat_history=False)
add_response_to_chat_history=False, add_message_to_chat_history=False, llm_sampling_settings=self.settings)
result_string += web_info

res = result_string.strip()
Expand Down
2 changes: 1 addition & 1 deletion src/llama_cpp_agent/providers/vllm_server.py
Original file line number Diff line number Diff line change
def is_using_json_schema_constraints(self):
    """Report whether this provider constrains structured output via a JSON schema.

    Always True for the vLLM server provider: structured-output settings are
    enforced through JSON-schema constraints rather than a grammar.

    Returns:
        bool: True.
    """
    return True

def get_provider_identifier(self) -> LlmProviderId:
    """Return the identifier enum member for this provider backend.

    NOTE(review): this previously returned ``LlmProviderId.tgi_server`` by
    mistake, which made provider dispatch (e.g. choosing between
    ``n_predict`` / ``max_new_tokens`` / ``max_tokens``) treat the vLLM
    server as a TGI server. It must identify as vLLM.

    Returns:
        LlmProviderId: ``LlmProviderId.vllm_server``.
    """
    return LlmProviderId.vllm_server

def get_provider_default_settings(self) -> VLLMServerSamplingSettings:
    """Return a fresh ``VLLMServerSamplingSettings`` instance.

    A new object is constructed on every call (no arguments, so all fields
    take their class defaults), so callers may mutate the returned settings
    without affecting other consumers of this provider.
    """
    return VLLMServerSamplingSettings()
Expand Down

0 comments on commit d29b791

Please sign in to comment.