From d29b79196b79eec5d43c84f5559b0f9f121dff0d Mon Sep 17 00:00:00 2001 From: Maximilian Winter Date: Sun, 26 May 2024 01:56:28 +0200 Subject: [PATCH] Fixed provider identifier --- .../web_search/web_search.py | 23 ++++++++++++++++--- src/llama_cpp_agent/providers/vllm_server.py | 2 +- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/examples/03_Tools_And_Function_Calling/web_search/web_search.py b/examples/03_Tools_And_Function_Calling/web_search/web_search.py index 79eb166..4ce2de3 100644 --- a/examples/03_Tools_And_Function_Calling/web_search/web_search.py +++ b/examples/03_Tools_And_Function_Calling/web_search/web_search.py @@ -2,7 +2,7 @@ from llama_cpp_agent.chat_history.messages import Roles from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings from llama_cpp_agent.providers import LlamaCppServerProvider -from llama_cpp_agent.providers.provider_base import LlmProvider +from llama_cpp_agent.providers.provider_base import LlmProvider, LlmProviderId from web_search_interfaces import WebCrawler, WebSearchProvider from default_web_crawlers import TrafilaturaWebCrawler from default_web_search_providers import DDGWebSearchProvider @@ -11,7 +11,10 @@ class WebSearchTool: def __init__(self, llm_provider: LlmProvider, message_formatter_type: MessagesFormatterType, context_character_limit: int = 7500, - web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None): + web_crawler: WebCrawler = None, web_search_provider: WebSearchProvider = None, temperature: float = 0.45, + top_p: float = 0.95, + top_k: int = 40, + max_tokens_per_summary: int = 750): self.summarising_agent = LlamaCppAgent(llm_provider, debug_output=True, system_prompt="You are a text summarization and information extraction specialist and you are able to summarize and filter out information relevant to a specific query.", predefined_messages_formatter_type=message_formatter_type) @@ -26,6 +29,20 @@ def __init__(self, llm_provider: LlmProvider,
message_formatter_type: MessagesFo self.web_search_provider = web_search_provider self.context_character_limit = context_character_limit + settings = llm_provider.get_provider_default_settings() + provider_id = llm_provider.get_provider_identifier() + settings.temperature = temperature + settings.top_p = top_p + settings.top_k = top_k + + if provider_id == LlmProviderId.llama_cpp_server: + settings.n_predict = max_tokens_per_summary + elif provider_id == LlmProviderId.tgi_server: + settings.max_new_tokens = max_tokens_per_summary + else: + settings.max_tokens = max_tokens_per_summary + + self.settings = settings def search_web(self, search_query: str): """ @@ -40,7 +57,7 @@ def search_web(self, search_query: str): if web_info != "": web_info = self.summarising_agent.get_chat_response( f"Please summarize the following Website content and extract relevant information to this query:'{search_query}'.\n\n" + web_info, - add_response_to_chat_history=False, add_message_to_chat_history=False) + add_response_to_chat_history=False, add_message_to_chat_history=False, llm_sampling_settings=self.settings) result_string += web_info res = result_string.strip() diff --git a/src/llama_cpp_agent/providers/vllm_server.py b/src/llama_cpp_agent/providers/vllm_server.py index 1a2040a..5157abd 100644 --- a/src/llama_cpp_agent/providers/vllm_server.py +++ b/src/llama_cpp_agent/providers/vllm_server.py @@ -86,7 +86,7 @@ def is_using_json_schema_constraints(self): return True def get_provider_identifier(self) -> LlmProviderId: - return LlmProviderId.tgi_server + return LlmProviderId.vllm_server def get_provider_default_settings(self) -> VLLMServerSamplingSettings: return VLLMServerSamplingSettings()