From 3e4c5911673c388376063bc2395e05a8e5c0b4b3 Mon Sep 17 00:00:00 2001 From: Ravi Patel Date: Mon, 22 Apr 2024 11:57:17 -0500 Subject: [PATCH 1/3] added ability to use ollama and other OpenAI compatible APIs --- browserpilot/agents/compilers/instruction_compiler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/browserpilot/agents/compilers/instruction_compiler.py b/browserpilot/agents/compilers/instruction_compiler.py index a7c8dfb..418fd0d 100644 --- a/browserpilot/agents/compilers/instruction_compiler.py +++ b/browserpilot/agents/compilers/instruction_compiler.py @@ -16,8 +16,7 @@ logger = logging.getLogger(__name__) # Instantiate OpenAI with OPENAI_API_KEY. -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) - +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=os.environ.get("OPENAI_API_BASE_URL", None)) """Set up all the prompt variables.""" # Designated tokens. From 016aa4152125efff041a6b229a6cacf3e563c66e Mon Sep 17 00:00:00 2001 From: Ravi Patel Date: Mon, 22 Apr 2024 14:22:49 -0500 Subject: [PATCH 2/3] added ability to use ollama and other OpenAI compatible APIs --- browserpilot/agents/compilers/instruction_compiler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/browserpilot/agents/compilers/instruction_compiler.py b/browserpilot/agents/compilers/instruction_compiler.py index 418fd0d..5026093 100644 --- a/browserpilot/agents/compilers/instruction_compiler.py +++ b/browserpilot/agents/compilers/instruction_compiler.py @@ -16,7 +16,8 @@ logger = logging.getLogger(__name__) # Instantiate OpenAI with OPENAI_API_KEY. -client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), base_url=os.environ.get("OPENAI_API_BASE_URL", None)) +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "").strip(), + base_url=os.environ.get("OPENAI_API_BASE_URL", None)) """Set up all the prompt variables.""" # Designated tokens.
@@ -287,7 +288,8 @@ def get_completion( return text try: - if "gpt-3.5-turbo" in model or "gpt-4" in model: + + if "gpt-3.5-turbo" in model or "gpt-4" in model or ('api.openai.com' not in str(client.base_url)): response = client.chat.completions.create( model=model, messages=[{"role": "user", "content": prompt}], From 8b270f689fd37fe150bbf3b5c4c7999a992bcef8 Mon Sep 17 00:00:00 2001 From: Ravi Patel Date: Mon, 22 Apr 2024 18:24:45 -0500 Subject: [PATCH 3/3] fixed llama3 prompting issues. added regex to extract code from response --- .../agents/compilers/instruction_compiler.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/browserpilot/agents/compilers/instruction_compiler.py b/browserpilot/agents/compilers/instruction_compiler.py index 5026093..bfa8d97 100644 --- a/browserpilot/agents/compilers/instruction_compiler.py +++ b/browserpilot/agents/compilers/instruction_compiler.py @@ -6,6 +6,7 @@ import logging import traceback import os +import re from typing import Dict, List, Union @@ -288,8 +289,7 @@ def get_completion( return text try: - - if "gpt-3.5-turbo" in model or "gpt-4" in model or ('api.openai.com' not in str(client.base_url)): + if "gpt-3.5-turbo" in model or "gpt-4" in model: response = client.chat.completions.create( model=model, messages=[{"role": "user", "content": prompt}], @@ -301,6 +301,21 @@ def get_completion( stop=stop, ) text = response.choices[0].message.content + elif 'api.openai.com' not in str(client.base_url): + # LLama3 is more aggressive and stops at the first stop token. 
+ # might want to adjust the stop tokens here. + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_tokens, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + temperature=temperature + ) + raw_text = response.choices[0].message.content + text = "\n".join(re.findall(r"```([^`]+)```", raw_text)) + + else: response = client.completions.create( model=model, @@ -336,6 +351,7 @@ def get_action_output(self, instructions): """Get the action output for the given instructions.""" prompt = self.base_prompt.format(instructions=instructions) completion = self.get_completion(prompt).strip() + action_output = completion.strip() lines = [line for line in action_output.split("\n") if not line.startswith("import ")] action_output = "\n".join(lines)