From 538daf697d7fad632947c029f9dc6efa3de505ec Mon Sep 17 00:00:00 2001 From: Ori Kotek Date: Thu, 23 Nov 2023 16:37:18 +0200 Subject: [PATCH 1/3] Changed to support HuggingFace self served TGI serving --- pr_agent/algo/ai_handler.py | 51 ++++++++++++++++++++++++------------- pr_agent/algo/utils.py | 2 +- requirements.txt | 4 +-- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 9a48cdc3d..619124c38 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -3,8 +3,10 @@ import litellm import openai from litellm import acompletion -from openai.error import APIError, RateLimitError, Timeout, TryAgain +from openai import APIError, RateLimitError from retry import retry +from tenacity import TryAgain + from pr_agent.config_loader import get_settings from pr_agent.log import get_logger @@ -24,7 +26,7 @@ def __init__(self): Raises a ValueError if the OpenAI key is missing. """ self.azure = False - + self.api_base = None if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key litellm.openai_key = get_settings().openai.key @@ -53,8 +55,9 @@ def __init__(self): litellm.replicate_key = get_settings().replicate.key if get_settings().get("HUGGINGFACE.KEY", None): litellm.huggingface_key = get_settings().huggingface.key - if get_settings().get("HUGGINGFACE.API_BASE", None): - litellm.api_base = get_settings().huggingface.api_base + if get_settings().get("HUGGINGFACE.API_BASE", None): + litellm.api_base = get_settings().huggingface.api_base + self.api_base = get_settings().huggingface.api_base if get_settings().get("VERTEXAI.VERTEX_PROJECT", None): litellm.vertex_project = get_settings().vertexai.vertex_project litellm.vertex_location = get_settings().get( @@ -68,22 +71,29 @@ def deployment_id(self): """ return get_settings().get("OPENAI.DEPLOYMENT_ID", None) - @retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError), - tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3)) - async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2): + @retry( + exceptions=(APIError, TryAgain, AttributeError, RateLimitError), + tries=OPENAI_RETRIES, + delay=2, + backoff=2, + jitter=(1, 3), + ) + async def chat_completion( + self, model: str, system: str, user: str, temperature: float = 0.2 + ): """ Performs a chat completion using the OpenAI ChatCompletion API. Retries in case of API errors or timeouts. - + Args: model (str): The model to use for chat completion. temperature (float): The temperature parameter for chat completion. system (str): The system message for chat completion. user (str): The user message for chat completion. - + Returns: tuple: A tuple containing the response and finish reason from the API. - + Raises: TryAgain: If the API response is empty or there are no choices in the response. APIError: If there is an error during OpenAI inference. @@ -105,22 +115,29 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: deployment_id=deployment_id, messages=messages, temperature=temperature, - force_timeout=get_settings().config.ai_timeout + force_timeout=get_settings().config.ai_timeout, + api_base=self.api_base, ) - except (APIError, Timeout, TryAgain) as e: + except (APIError, TryAgain) as e: get_logger().error("Error during OpenAI inference: ", e) raise - except (RateLimitError) as e: + except RateLimitError as e: get_logger().error("Rate limit error during OpenAI inference: ", e) raise - except (Exception) as e: + except Exception as e: get_logger().error("Unknown error during OpenAI inference: ", e) raise TryAgain from e if response is None or len(response["choices"]) == 0: raise TryAgain - resp = response["choices"][0]['message']['content'] + resp = response["choices"][0]["message"]["content"] finish_reason = response["choices"][0]["finish_reason"] usage = response.get("usage") - get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason, - model=model, usage=usage) + get_logger().info( + "AI response", + response=resp, + messages=messages, + finish_reason=finish_reason, + model=model, + usage=usage, + ) return resp, finish_reason diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index d3377deeb..317619411 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -370,7 +370,7 @@ def get_user_labels(current_labels: List[str] = None): def get_max_tokens(model): settings = get_settings() - max_tokens_model = MAX_TOKENS[model] + max_tokens_model = MAX_TOKENS.get(model, 4000) if settings.config.max_model_tokens: max_tokens_model = min(settings.config.max_model_tokens, max_tokens_model) # get_logger().debug(f"limiting max tokens to {max_tokens_model}") diff --git a/requirements.txt b/requirements.txt index eae08f4cc..68a909f18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ dynaconf==3.1.12 fastapi==0.99.0 PyGithub==1.59.* retry==0.9.2 -openai==0.27.8 +openai==1.3.5 Jinja2==3.1.2 tiktoken==0.4.0 uvicorn==0.22.0 @@ -13,7 +13,6 @@ atlassian-python-api==3.39.0 GitPython==3.1.32 PyYAML==6.0 starlette-context==0.3.6 -litellm==0.12.5 boto3==1.28.25 google-cloud-storage==2.10.0 ujson==5.8.0 @@ -23,3 +22,4 @@ pinecone-client pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main loguru==0.7.2 google-cloud-aiplatform==1.35.0 +litellm @ git+https://github.com/Codium-ai/litellm.git \ No newline at end of file From 5d9c767254cb1bcb60ad846a0e5a748ea0f747d9 Mon Sep 17 00:00:00 2001 From: Ori Kotek Date: Mon, 27 Nov 2023 19:43:22 +0200 Subject: [PATCH 2/3] Further adaptations for TGI --- pr_agent/algo/ai_handler.py | 6 +++++- requirements.txt | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index 619124c38..f1b4fadd0 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -55,7 +55,7 @@ def __init__(self): litellm.replicate_key = get_settings().replicate.key if get_settings().get("HUGGINGFACE.KEY", None): litellm.huggingface_key = get_settings().huggingface.key - if get_settings().get("HUGGINGFACE.API_BASE", None): + if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model: litellm.api_base = get_settings().huggingface.api_base self.api_base = get_settings().huggingface.api_base if get_settings().get("VERTEXAI.VERTEX_PROJECT", None): @@ -109,6 +109,10 @@ async def chat_completion( ) if self.azure: model = 'azure/' + model + system = get_settings().get("CONFIG.MODEL_SYSTEM_PREFIX", "") + system + \ + get_settings().get("CONFIG.MODEL_SYSTEM_SUFFIX", "") + user = get_settings().get("CONFIG.MODEL_USER_PREFIX", "") + user + \ + get_settings().get("CONFIG.MODEL_USER_SUFFIX", "") messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] response = await acompletion( model=model, diff --git a/requirements.txt b/requirements.txt index 68a909f18..93dabdb20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,4 @@ pinecone-client pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main loguru==0.7.2 google-cloud-aiplatform==1.35.0 -litellm @ git+https://github.com/Codium-ai/litellm.git \ No newline at end of file +litellm==1.7.1 \ No newline at end of file From 7390f52c059644aad5070e0d2d35847c3ef349c2 Mon Sep 17 00:00:00 2001 From: Ori Kotek Date: Wed, 29 Nov 2023 09:15:04 +0200 Subject: [PATCH 3/3] Further adaptations for TGI --- pr_agent/algo/ai_handler.py | 15 ++++++++++++++- pr_agent/git_providers/github_provider.py | 3 ++- requirements.txt | 3 ++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py index f1b4fadd0..f45dc71dd 100644 --- a/pr_agent/algo/ai_handler.py +++ b/pr_agent/algo/ai_handler.py @@ -109,11 +109,18 @@ async def chat_completion( ) if self.azure: model = 'azure/' + model + system = get_settings().get("CONFIG.MODEL_SYSTEM_PREFIX", "") + system + \ get_settings().get("CONFIG.MODEL_SYSTEM_SUFFIX", "") + suffix = '' + yaml_start = '```yaml' + if user.endswith(yaml_start): + user = user[:-len(yaml_start)] + suffix = '\n' + yaml_start + '\n' user = get_settings().get("CONFIG.MODEL_USER_PREFIX", "") + user + \ - get_settings().get("CONFIG.MODEL_USER_SUFFIX", "") + get_settings().get("CONFIG.MODEL_USER_SUFFIX", "") + suffix messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + stop = get_settings().get("CONFIG.MODEL_STOP", None) response = await acompletion( model=model, deployment_id=deployment_id, @@ -121,6 +128,7 @@ async def chat_completion( temperature=temperature, force_timeout=get_settings().config.ai_timeout, api_base=self.api_base, + stop=stop, ) except (APIError, TryAgain) as e: get_logger().error("Error during OpenAI inference: ", e) @@ -134,6 +142,11 @@ async def chat_completion( if response is None or len(response["choices"]) == 0: raise TryAgain resp = response["choices"][0]["message"]["content"] + if stop: + for stop_word in stop: + if resp.endswith(stop_word): + resp = resp[:-len(stop_word)] + break finish_reason = response["choices"][0]["finish_reason"] usage = response.get("usage") get_logger().info( diff --git a/pr_agent/git_providers/github_provider.py b/pr_agent/git_providers/github_provider.py index 1fb851644..bad1c44fb 100644 --- a/pr_agent/git_providers/github_provider.py +++ b/pr_agent/git_providers/github_provider.py @@ -327,7 +327,8 @@ def add_eyes_reaction(self, issue_comment_id: int) -> Optional[int]: def remove_reaction(self, issue_comment_id: int, reaction_id: int) -> bool: try: - self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id) + if reaction_id: + self.pr.get_issue_comment(issue_comment_id).delete_reaction(reaction_id) return True except Exception as e: get_logger().exception(f"Failed to remove eyes reaction, error: {e}") diff --git a/requirements.txt b/requirements.txt index 93dabdb20..8cedbbf10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ pinecone-client pinecone-datasets @ git+https://github.com/mrT23/pinecone-datasets.git@main loguru==0.7.2 google-cloud-aiplatform==1.35.0 -litellm==1.7.1 \ No newline at end of file +litellm==1.7.1 +tenacity==8.1.0