Skip to content

Commit

Permalink
enables structured output assistants
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeximenes committed Sep 12, 2024
1 parent 81d0ff6 commit 83410a9
Show file tree
Hide file tree
Showing 7 changed files with 582 additions and 43 deletions.
77 changes: 74 additions & 3 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import abc
import inspect
import json
import re
from typing import Annotated, Any, ClassVar, Sequence, TypedDict, cast
from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast

from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
Expand Down Expand Up @@ -37,6 +38,7 @@
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from django_ai_assistant.decorators import with_cast_id
from django_ai_assistant.exceptions import (
Expand Down Expand Up @@ -79,6 +81,12 @@ class AIAssistant(abc.ABC): # noqa: F821
When True, the assistant will use a retriever to get documents to provide as context to the LLM.
Additionally, the assistant class should implement the `get_retriever` method to return
the retriever to use."""
structured_output: Dict[str, Any] | Type[BaseModel] | Type | None = None
"""Structured output to use for the assistant.\n
Defaults to `None`.
When not `None`, the assistant will return a structured output in the provided format.
See https://python.langchain.com/v0.2/docs/how_to/structured_output/ for the available formats.
"""
_user: Any | None
"""The current user the assistant is helping. A model instance.\n
Set by the constructor.
Expand Down Expand Up @@ -269,6 +277,27 @@ def get_llm(self) -> BaseChatModel:
model_kwargs=model_kwargs,
)

def get_structured_output_llm(self) -> Runnable:
    """Build the runnable used to produce the structured output.

    Wraps the chat model returned by `get_llm` with
    `with_structured_output`, so invoking it yields data matching
    `self.structured_output` instead of free-form text.

    Returns:
        Runnable: The chat model bound to the structured output format.

    Raises:
        ValueError: If `structured_output` is not set on the assistant.
    """
    if not self.structured_output:
        raise ValueError("structured_output is not defined")

    llm = self.get_llm()

    # ChatOpenAI supports the json_schema method, which enables strict mode;
    # other providers fall back to the generic json_mode.
    # https://platform.openai.com/docs/guides/structured-outputs
    output_method = "json_schema" if isinstance(llm, ChatOpenAI) else "json_mode"

    return llm.with_structured_output(self.structured_output, method=output_method)

def get_tools(self) -> Sequence[BaseTool]:
"""Get the list of method tools the assistant can use.
By default, this is the `_method_tools` attribute, which are all `@method_tool`s.\n
Expand Down Expand Up @@ -422,7 +451,36 @@ class AgentState(TypedDict):
output: str

def setup(state: AgentState):
    """Seed the conversation with the assistant's system instructions.

    When a structured output is configured, also append a system message
    telling the model it will eventually answer in JSON (including the
    schema when the output is a Pydantic model).
    """
    system_messages = [SystemMessage(content=self.get_instructions())]

    if self.structured_output:
        # A JSON schema can only be derived for Pydantic model classes;
        # dict/other output formats get no inline schema description.
        is_pydantic_model = inspect.isclass(self.structured_output) and issubclass(
            self.structured_output, BaseModel
        )
        if is_pydantic_model:
            schema = json.dumps(self.structured_output.model_json_schema())
            schema_information = f"JSON will have the following schema:\n\n{schema}\n\n"
        else:
            schema_information = ""

        # The assistant won't have access to the schema of the structured
        # output before the last step of the chat. This message gives it
        # visibility into the expected fields so it can gather the
        # necessary information by using tools.
        structured_output_notice = (
            "In the last step of this chat you will be asked to respond in JSON. "
            + schema_information
            + "Gather information using tools. "
            "Don't generate JSON until you are explicitly told to. "
        )
        system_messages.append(SystemMessage(content=structured_output_notice))

    return {"messages": system_messages}

def retriever(state: AgentState):
if not self.has_rag:
Expand Down Expand Up @@ -462,7 +520,20 @@ def tool_selector(state: AgentState):
return "continue"

def record_response(state: AgentState):
    """Produce the run's final output from the accumulated messages.

    Without a structured output, the last message's content is returned
    as-is; otherwise the model is re-invoked through the structured
    output wrapper so the answer conforms to the requested format.
    """
    if not self.structured_output:
        return {"output": state["messages"][-1].content}

    structured_llm = self.get_structured_output_llm()
    # Append a final nudge so the model answers from what was gathered
    # during the conversation rather than asking for more information.
    final_prompt = [
        *state["messages"],
        SystemMessage(
            content="Use the information gathered in the conversation to answer."
        ),
    ]
    return {"output": structured_llm.invoke(final_prompt)}

workflow = StateGraph(AgentState)

Expand Down
4 changes: 1 addition & 3 deletions example/demo/views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

from django.contrib import messages
from django.http import JsonResponse
from django.shortcuts import get_object_or_404, redirect, render
Expand Down Expand Up @@ -122,4 +120,4 @@ def get(self, request, *args, **kwargs):
a = TourGuideAIAssistant()
data = a.run(f"My coordinates are: ({coordinates})")

return JsonResponse(json.loads(data))
return JsonResponse(data.model_dump())
42 changes: 15 additions & 27 deletions example/tour_guide/ai_assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,37 @@

from django.utils import timezone

from pydantic import BaseModel, Field

from django_ai_assistant import AIAssistant, method_tool
from tour_guide.integrations import fetch_points_of_interest


def _tour_guide_example_json():
return json.dumps(
{
"nearby_attractions": [
{
"attraction_name": f"<attraction-{i}-name-here>",
"attraction_description": f"<attraction-{i}-description-here>",
"attraction_url": f"<attraction-{i}-imdb-page-url-here>",
}
for i in range(1, 6)
]
},
indent=2,
).translate( # Necessary due to ChatPromptTemplate
str.maketrans(
{
"{": "{{",
"}": "}}",
}
)
class Attraction(BaseModel):
    """A single point of interest returned by the tour guide assistant.

    The ``description=`` texts double as instructions to the LLM when this
    model is used as the assistant's structured output schema.
    """

    # English name of the attraction.
    attraction_name: str = Field(description="The name of the attraction in english")
    attraction_description: str = Field(
        description="The description of the attraction, provide information in an entertaining way"
    )
    # May be left empty when no real URL is found in the source data.
    attraction_url: str = Field(
        description="The URL of the attraction, keep empty if you don't have this information"
    )


class TourGuide(BaseModel):
    """Structured output schema for the tour guide assistant's response."""

    nearby_attractions: list[Attraction] = Field(description="The list of nearby attractions")


class TourGuideAIAssistant(AIAssistant):
id = "tour_guide_assistant" # noqa: A003
name = "Tour Guide Assistant"
instructions = (
"You are a tour guide assistant that offers information about nearby attractions. "
"You will receive the user coordinates and should use available tools to find nearby attractions. "
"Only include in your response the items that are relevant to a tourist visiting the area. "
"Only call the find_nearby_attractions tool once. "
"Your response should only contain valid JSON data. DON'T include '```json' in your response. "
"The JSON should be formatted according to the following structure: \n"
f"\n\n{_tour_guide_example_json()}\n\n\n"
"In the 'attraction_name' field provide the name of the attraction in english. "
"In the 'attraction_description' field generate an overview about the attraction with the most important information, "
"curiosities and interesting facts. "
"Only include a value for the 'attraction_url' field if you find a real value in the provided data otherwise keep it empty. "
)
model = "gpt-4o-2024-08-06"
structured_output = TourGuide

def get_instructions(self):
# Warning: this will use the server's timezone
Expand Down
89 changes: 80 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 83410a9

Please sign in to comment.