From f833dbff559675f026c638aea4a500f4be6e8468 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Sun, 20 Aug 2023 23:37:57 -0400 Subject: [PATCH 1/8] load_json_output modified for general llm --- .../agent/conversational_agent/output_parser.py | 5 +++-- autochain/agent/structs.py | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/autochain/agent/conversational_agent/output_parser.py b/autochain/agent/conversational_agent/output_parser.py index b0518ec..d2a6169 100644 --- a/autochain/agent/conversational_agent/output_parser.py +++ b/autochain/agent/conversational_agent/output_parser.py @@ -5,13 +5,14 @@ from autochain.agent.message import BaseMessage from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser +from autochain.models.base import BaseLanguageModel from autochain.errors import OutputParserException from autochain.utils import print_with_color class ConvoJSONOutputParser(AgentOutputParser): - def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]: - response = self.load_json_output(message) + def parse(self, message: BaseMessage, llm: BaseLanguageModel) -> Union[AgentAction, AgentFinish]: + response = self.load_json_output(message, llm) action_name = response.get("tool", {}).get("name") action_args = response.get("tool", {}).get("args") diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index 435fd02..8b2839d 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -6,6 +6,7 @@ from autochain.models.base import Generation from autochain.models.chat_openai import ChatOpenAI +from autochain.models.base import BaseLanguageModel from pydantic import BaseModel from autochain.agent.message import BaseMessage, UserMessage @@ -56,7 +57,7 @@ def format_output(self) -> Dict[str, Any]: class AgentOutputParser(BaseModel): @staticmethod - def load_json_output(message: BaseMessage) -> Dict[str, Any]: + def load_json_output(message: BaseMessage, llm: BaseLanguageModel) -> Dict[str, Any]: """If the message contains a json response, try to parse it into dictionary""" text = message.content clean_text = "" @@ -65,14 +66,13 @@ def load_json_output(message: BaseMessage) -> Dict[str, Any]: clean_text = text[text.index("{") : text.rindex("}") + 1].strip() response = json.loads(clean_text) except Exception: - llm = ChatOpenAI(temperature=0) message = [ UserMessage( content=f"""Fix the following json into correct format -```json -{clean_text} -``` -""" + ```json + {clean_text} + ``` + """ ) ] full_output: Generation = llm.generate(message).generations[0] From b7ed9c2c0fe598436d6825ccb3a0108f6e7599c4 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Mon, 21 Aug 2023 22:35:38 -0400 Subject: [PATCH 2/8] json correction generalized with max_retry --- autochain/agent/structs.py | 94 ++++++++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 19 deletions(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index 8b2839d..de1536c 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -1,17 +1,15 @@ import json -import re from abc import abstractmethod from typing import Union, Any, Dict, List +from colorama import Fore -from autochain.models.base import Generation - -from autochain.models.chat_openai import ChatOpenAI from autochain.models.base import BaseLanguageModel from pydantic import BaseModel +from autochain.models.base import Generation from autochain.agent.message import BaseMessage, UserMessage from autochain.chain import constants -from autochain.errors import OutputParserException +from autochain.utils import print_with_color class AgentAction(BaseModel): @@ -56,29 +54,87 @@ def format_output(self) -> Dict[str, Any]: class AgentOutputParser(BaseModel): - @staticmethod - def load_json_output(message: BaseMessage, llm: BaseLanguageModel) -> Dict[str, Any]: - """If the message contains a json response, try to parse it into dictionary""" - text = message.content - clean_text = "" + def load_json_output( + self, + message: BaseMessage, + llm: BaseLanguageModel, + max_retry=3 + ) -> Dict[str, Any]: + """Try to parse JSON response from the message content.""" + text = message.content + print('Message: ', message) + clean_text = self._extract_json_text(text) + print('Clean text: ', clean_text) try: - clean_text = text[text.index("{") : text.rindex("}") + 1].strip() response = json.loads(clean_text) except Exception: - message = [ - UserMessage( - content=f"""Fix the following json into correct format + print_with_color( + 'Generating JSON format attempt FAILED! Trying Again...', + Fore.RED + ) + message = self._fix_message(clean_text) + print('Check Message: ', message) + full_output: Generation = llm.generate(message).generations[0] + print("TEST: ", full_output) + response = self._attempt_fix_and_generate(message, llm, max_retry, attempt=0) + + return response + + @staticmethod + def _fix_message(clean_text: str) -> UserMessage: + message = [UserMessage( + content=f""" + Fix the following json into correct format ```json {clean_text} ``` """ - ) - ] - full_output: Generation = llm.generate(message).generations[0] + )] + return message + + @staticmethod + def _extract_json_text(text: str) -> str: + """Extract JSON text from the input string.""" + clean_text = "" + try: + clean_text = text[text.index("{") : text.rindex("}") + 1].strip() + except Exception: + clean_text = text + return clean_text + + def _attempt_fix_and_generate( + self, + message: BaseMessage, + llm: BaseLanguageModel, + max_retry: int, + attempt: int + ) -> Dict[str, Any]: + + """Attempt to fix JSON format using model generation.""" + if attempt >= max_retry: + raise ValueError( + """ + Max retry reached. Model is unable to generate proper JSON output. + Try with another Model! + """ + ) + + full_output: Generation = llm.generate([message]).generations[0] + print('model output: ', full_output) + try: response = json.loads(full_output.message.content) - - return response + return response + except Exception: + print_with_color( + 'Generating JSON format attempt FAILED! Trying Again...', + Fore.RED + ) + clean_text = self._extract_json_text(full_output.message.content) + print("Clean Text: ", clean_text) + message = self._fix_message(clean_text) + print('Message: ', message) + return self._attempt_fix_and_generate(message, llm, max_retry, attempt=attempt + 1) @abstractmethod def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]: From d8bab2c9a611ab0cde8e1d158e134de4e8bda642 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Mon, 21 Aug 2023 22:37:05 -0400 Subject: [PATCH 3/8] passing llm into parse --- .../agent/conversational_agent/conversational_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/autochain/agent/conversational_agent/conversational_agent.py b/autochain/agent/conversational_agent/conversational_agent.py index cf93324..a1aa3b4 100644 --- a/autochain/agent/conversational_agent/conversational_agent.py +++ b/autochain/agent/conversational_agent/conversational_agent.py @@ -165,6 +165,7 @@ def plan( tool_strings = "\n\n".join( [f"> {tool.name}: \n{tool.description}" for tool in self.tools] ) + inputs = { "tool_names": tool_names, "tools": tool_strings, @@ -175,11 +176,14 @@ def plan( final_prompt = self.format_prompt( self.prompt_template, intermediate_steps, **inputs ) + logger.info(f"\nPlanning Input: {final_prompt[0].content} \n") full_output: Generation = self.llm.generate(final_prompt).generations[0] + agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse( - full_output.message + full_output.message, + self.llm ) print(f"Planning output: \n{repr(full_output.message.content)}", Fore.YELLOW) From 1aeb323b69a534462c8cd1aacc2612f007c35347 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Mon, 21 Aug 2023 22:43:28 -0400 Subject: [PATCH 4/8] Corrections made for message --- autochain/agent/structs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index de1536c..3bf6130 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -75,7 +75,7 @@ def load_json_output( ) message = self._fix_message(clean_text) print('Check Message: ', message) - full_output: Generation = llm.generate(message).generations[0] + full_output: Generation = llm.generate([message]).generations[0] print("TEST: ", full_output) response = self._attempt_fix_and_generate(message, llm, max_retry, attempt=0) @@ -83,14 +83,14 @@ def load_json_output( @staticmethod def _fix_message(clean_text: str) -> UserMessage: - message = [UserMessage( + message = UserMessage( content=f""" Fix the following json into correct format ```json {clean_text} ``` """ - )] + ) return message @staticmethod From be7d5ed09cb6034af4086898993f39a101718716 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Tue, 22 Aug 2023 13:21:44 -0400 Subject: [PATCH 5/8] more debugging for json correction --- autochain/agent/structs.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index 3bf6130..c7cbe5b 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -63,9 +63,8 @@ def load_json_output( ) -> Dict[str, Any]: """Try to parse JSON response from the message content.""" text = message.content - print('Message: ', message) clean_text = self._extract_json_text(text) - print('Clean text: ', clean_text) + try: response = json.loads(clean_text) except Exception: @@ -74,7 +73,6 @@ def load_json_output( Fore.RED ) message = self._fix_message(clean_text) - print('Check Message: ', message) full_output: Generation = llm.generate([message]).generations[0] print("TEST: ", full_output) response = self._attempt_fix_and_generate(message, llm, max_retry, attempt=0) @@ -131,9 +129,8 @@ def _attempt_fix_and_generate( Fore.RED ) clean_text = self._extract_json_text(full_output.message.content) - print("Clean Text: ", clean_text) message = self._fix_message(clean_text) - print('Message: ', message) + print("MESSAGE: ", message) return self._attempt_fix_and_generate(message, llm, max_retry, attempt=attempt + 1) @abstractmethod From c76155619732d2362ef4d7f3dab6c7fcb226a003 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Tue, 22 Aug 2023 22:11:01 -0400 Subject: [PATCH 6/8] comment and to do added --- autochain/agent/structs.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index c7cbe5b..20c125e 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -81,6 +81,14 @@ def load_json_output( @staticmethod def _fix_message(clean_text: str) -> UserMessage: + ''' + If the response from model is not proper, this function should + iteratively construct better response until response becomes json parseable + ''' + + # TO DO + # Construct this message better in order to make it better iteratively by + # _attempt_fix_and_generate recursive function message = UserMessage( content=f""" Fix the following json into correct format @@ -109,7 +117,7 @@ def _attempt_fix_and_generate( attempt: int ) -> Dict[str, Any]: - """Attempt to fix JSON format using model generation.""" + """Attempt to fix JSON format using model generation recursively.""" if attempt >= max_retry: raise ValueError( """ From a2afbf1244887df62f9ce333f57b7ebc45422928 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Tue, 22 Aug 2023 22:26:33 -0400 Subject: [PATCH 7/8] clearing prints and modifying initial llm gen --- autochain/agent/structs.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index 20c125e..64c4b54 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -74,8 +74,12 @@ def load_json_output( ) message = self._fix_message(clean_text) full_output: Generation = llm.generate([message]).generations[0] - print("TEST: ", full_output) - response = self._attempt_fix_and_generate(message, llm, max_retry, attempt=0) + response = self._attempt_fix_and_generate( + full_output.message.content, + llm, + max_retry, + attempt=0 + ) return response @@ -127,7 +131,7 @@ def _attempt_fix_and_generate( ) full_output: Generation = llm.generate([message]).generations[0] - print('model output: ', full_output) + try: response = json.loads(full_output.message.content) return response @@ -138,7 +142,6 @@ def _attempt_fix_and_generate( ) clean_text = self._extract_json_text(full_output.message.content) message = self._fix_message(clean_text) - print("MESSAGE: ", message) return self._attempt_fix_and_generate(message, llm, max_retry, attempt=attempt + 1) @abstractmethod From 4ccfbdd2ff0c6e0d241c4bf9935f06209f1b9765 Mon Sep 17 00:00:00 2001 From: tugrulguner Date: Tue, 22 Aug 2023 22:31:06 -0400 Subject: [PATCH 8/8] reversing back the llm gen parameter --- autochain/agent/structs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/autochain/agent/structs.py b/autochain/agent/structs.py index 64c4b54..f7efcb5 100644 --- a/autochain/agent/structs.py +++ b/autochain/agent/structs.py @@ -73,9 +73,8 @@ def load_json_output( Fore.RED ) message = self._fix_message(clean_text) - full_output: Generation = llm.generate([message]).generations[0] response = self._attempt_fix_and_generate( - full_output.message.content, + message, llm, max_retry, attempt=0