Skip to content

Commit

Permalink
Fix the RAG question rewrite of the retriever node
Browse files Browse the repository at this point in the history
  • Loading branch information
fjsj committed Sep 12, 2024
1 parent 81d0ff6 commit f21cf01
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions django_ai_assistant/helpers/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,12 +424,17 @@ class AgentState(TypedDict):
def setup(state: AgentState):
return {"messages": [SystemMessage(content=self.get_instructions())]}

def history(state: AgentState):
history = message_history.messages if message_history else []
return {"messages": [*history, HumanMessage(content=state["input"])]}

def retriever(state: AgentState):
if not self.has_rag:
return

retriever = self.get_history_aware_retriever()
docs = retriever.invoke({"input": state["input"], "history": state["messages"]})
messages_without_input = state["messages"][:-1]
docs = retriever.invoke({"input": state["input"], "history": messages_without_input})

document_separator = self.get_document_separator()
document_prompt = self.get_document_prompt()
Expand All @@ -444,10 +449,6 @@ def retriever(state: AgentState):
)
}

def history(state: AgentState):
history = message_history.messages if message_history else []
return {"messages": [*history, HumanMessage(content=state["input"])]}

def agent(state: AgentState):
response = llm_with_tools.invoke(state["messages"])

Expand All @@ -467,16 +468,16 @@ def record_response(state: AgentState):
workflow = StateGraph(AgentState)

workflow.add_node("setup", setup)
workflow.add_node("retriever", retriever)
workflow.add_node("history", history)
workflow.add_node("retriever", retriever)
workflow.add_node("agent", agent)
workflow.add_node("tools", ToolNode(tools))
workflow.add_node("respond", record_response)

workflow.set_entry_point("setup")
workflow.add_edge("setup", "retriever")
workflow.add_edge("retriever", "history")
workflow.add_edge("history", "agent")
workflow.add_edge("setup", "history")
workflow.add_edge("history", "retriever")
workflow.add_edge("retriever", "agent")
workflow.add_conditional_edges(
"agent",
tool_selector,
Expand Down

0 comments on commit f21cf01

Please sign in to comment.