Open-source LLM from HuggingFace Agent Support for json output parsing #130

Open: wants to merge 8 commits into main
6 changes: 5 additions & 1 deletion autochain/agent/conversational_agent/conversational_agent.py
@@ -165,6 +165,7 @@ def plan(
tool_strings = "\n\n".join(
[f"> {tool.name}: \n{tool.description}" for tool in self.tools]
)

inputs = {
"tool_names": tool_names,
"tools": tool_strings,
@@ -175,11 +176,14 @@ def plan(
final_prompt = self.format_prompt(
self.prompt_template, intermediate_steps, **inputs
)

logger.info(f"\nPlanning Input: {final_prompt[0].content} \n")

full_output: Generation = self.llm.generate(final_prompt).generations[0]

agent_output: Union[AgentAction, AgentFinish] = self.output_parser.parse(
full_output.message
full_output.message,
self.llm
)

print(f"Planning output: \n{repr(full_output.message.content)}", Fore.YELLOW)
5 changes: 3 additions & 2 deletions autochain/agent/conversational_agent/output_parser.py
@@ -5,13 +5,14 @@

from autochain.agent.message import BaseMessage
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
from autochain.models.base import BaseLanguageModel
from autochain.errors import OutputParserException
from autochain.utils import print_with_color


class ConvoJSONOutputParser(AgentOutputParser):
def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
response = self.load_json_output(message)
def parse(self, message: BaseMessage, llm: BaseLanguageModel) -> Union[AgentAction, AgentFinish]:
response = self.load_json_output(message, llm)

action_name = response.get("tool", {}).get("name")
action_args = response.get("tool", {}).get("args")
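For orientation, the dictionary the parser reads here has at least the "tool" shape sketched below; the concrete name and args values are made up, and any further keys the agent emits are not visible in this hunk:

```python
# Minimal shape of the parsed model output that this hunk reads; the values are
# hypothetical and other keys are not shown in the diff.
example_response = {
    "tool": {
        "name": "get_current_weather",   # hypothetical tool name
        "args": {"location": "Berlin"},  # hypothetical tool arguments
    }
}

action_name = example_response.get("tool", {}).get("name")
action_args = example_response.get("tool", {}).get("args")
```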
107 changes: 85 additions & 22 deletions autochain/agent/structs.py
@@ -1,16 +1,15 @@
import json
import re
from abc import abstractmethod
from typing import Union, Any, Dict, List
from colorama import Fore

from autochain.models.base import Generation

from autochain.models.chat_openai import ChatOpenAI
from autochain.models.base import BaseLanguageModel
from pydantic import BaseModel

from autochain.models.base import Generation
from autochain.agent.message import BaseMessage, UserMessage
from autochain.chain import constants
from autochain.errors import OutputParserException
from autochain.utils import print_with_color


class AgentAction(BaseModel):
@@ -55,30 +54,94 @@ def format_output(self) -> Dict[str, Any]:


class AgentOutputParser(BaseModel):
Contributor comment: I think there is a way to simplify the logic inside this function: first try the JSON load; if it runs into an exception, fix and generate, format the result into a message for load_json_output, decrease the retry count, and call load_json_output again; if there is no exception, just return.
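A minimal sketch of the simplified flow this comment describes, assuming the PR's _extract_json_text and _fix_message helpers and an llm.generate() that returns generations carrying a .message; the enclosing class is omitted, so this is illustrative rather than a drop-in replacement:

```python
import json
from typing import Any, Dict


def load_json_output(self, message, llm, max_retry: int = 3) -> Dict[str, Any]:
    clean_text = self._extract_json_text(message.content)
    try:
        # No exception: just return the parsed dictionary.
        return json.loads(clean_text)
    except Exception:
        if max_retry <= 0:
            raise ValueError("Max retry reached; the model cannot produce valid JSON.")
        # Fix and generate, then feed the result back in with one fewer retry.
        fix_prompt = self._fix_message(clean_text)
        generation = llm.generate([fix_prompt]).generations[0]
        return self.load_json_output(generation.message, llm, max_retry=max_retry - 1)
```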

@staticmethod
def load_json_output(message: BaseMessage) -> Dict[str, Any]:
"""If the message contains a json response, try to parse it into dictionary"""

def load_json_output(
self,
message: BaseMessage,
llm: BaseLanguageModel,
max_retry=3
) -> Dict[str, Any]:
"""Try to parse JSON response from the message content."""
text = message.content
clean_text = ""
clean_text = self._extract_json_text(text)

try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
response = json.loads(clean_text)
except Exception:
llm = ChatOpenAI(temperature=0)
message = [
UserMessage(
content=f"""Fix the following json into correct format
```json
{clean_text}
```
"""
)
]
full_output: Generation = llm.generate(message).generations[0]
response = json.loads(full_output.message.content)
print_with_color(
'Generating JSON format attempt FAILED! Trying Again...',
Fore.RED
)
message = self._fix_message(clean_text)
response = self._attempt_fix_and_generate(
message,
llm,
max_retry,
attempt=0
)

return response

@staticmethod
def _fix_message(clean_text: str) -> UserMessage:
Contributor comment: Can you make this a helper function inside of _attempt_fix_and_generate? (See the sketch after this method.)

"""
If the model's response is not proper JSON, construct a prompt that asks the
model to produce a corrected, parseable version of it.
"""

# TODO: construct this prompt so that each _attempt_fix_and_generate retry
# can improve on the previous attempt.
message = UserMessage(
content=f"""
Fix the following json into correct format
```json
{clean_text}
```
"""
)
return message
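A sketch of the refactor the comment above asks for: the prompt builder becomes a local helper inside _attempt_fix_and_generate rather than a separate @staticmethod. The helper body mirrors the PR's _fix_message, and the surrounding retry logic is elided, so treat this as illustrative only:

```python
from autochain.agent.message import UserMessage


def _attempt_fix_and_generate(self, message, llm, max_retry: int, attempt: int):
    def fix_message(clean_text: str) -> UserMessage:
        # Same prompt the standalone helper builds today.
        return UserMessage(
            content="Fix the following json into correct format\n"
            f"```json\n{clean_text}\n```\n"
        )

    # ... the PR's generation/retry logic would call fix_message(...) here ...
```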

@staticmethod
def _extract_json_text(text: str) -> str:
"""Extract JSON text from the input string."""
clean_text = ""
try:
clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
except Exception:
clean_text = text
return clean_text

def _attempt_fix_and_generate(
self,
message: BaseMessage,
llm: BaseLanguageModel,
max_retry: int,
attempt: int
) -> Dict[str, Any]:

"""Attempt to fix JSON format using model generation recursively."""
if attempt >= max_retry:
raise ValueError(
"""
Max retry reached. Model is unable to generate proper JSON output.
Try with another Model!
"""
)

full_output: Generation = llm.generate([message]).generations[0]

try:
response = json.loads(full_output.message.content)
return response
except Exception:
print_with_color(
'Generating JSON format attempt FAILED! Trying Again...',
Fore.RED
)
clean_text = self._extract_json_text(full_output.message.content)
message = self._fix_message(clean_text)
return self._attempt_fix_and_generate(message, llm, max_retry, attempt=attempt + 1)

@abstractmethod
def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
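End to end, the new signature would be exercised roughly as below. ChatOpenAI stands in for any BaseLanguageModel (a HuggingFace-backed model would slot in the same way once supported), AIMessage is assumed to exist in autochain.agent.message, and the malformed reply is made up:

```python
# Hedged usage sketch: the parser now repairs malformed JSON with whatever
# BaseLanguageModel the agent already uses, instead of a hard-coded ChatOpenAI.
from autochain.agent.conversational_agent.output_parser import ConvoJSONOutputParser
from autochain.agent.message import AIMessage  # assumed; any message with .content works
from autochain.models.chat_openai import ChatOpenAI  # stand-in for any BaseLanguageModel

llm = ChatOpenAI(temperature=0)
parser = ConvoJSONOutputParser()

# A reply whose JSON has a trailing comma; parse() asks `llm` to fix it,
# retrying up to max_retry times before raising.
reply = AIMessage(content='{"tool": {"name": "search", "args": {"query": "weather"},}}')
agent_output = parser.parse(reply, llm)  # AgentAction or AgentFinish
```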