diff --git a/syftr/agent_flows.py b/syftr/agent_flows.py
index 3a66a9c..be28166 100644
--- a/syftr/agent_flows.py
+++ b/syftr/agent_flows.py
@@ -13,6 +13,7 @@
+from llama_index.core.agent.react import ReActChatFormatter
 from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
 from llama_index.core.tools import QueryEngineTool, ToolMetadata
 
+from syftr.agents.react import ReActAgent as ReActAgentV2
 from syftr.logger import logger
@@ -160,6 +161,46 @@ def __init__(
         self.agent.reset()
 
 
+class LlamaIndexReactRAGAgentFlowV2(AgentFlow, AgentMixin):
+    def __init__(
+        self: AgentProtocol,
+        indexes: T.List[T.Tuple[str, str, VectorStoreIndex]],
+        llm: FunctionCallingLLM,
+        system_prompt: PromptTemplate | None = None,
+        template: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(llm=llm, template=template, **kwargs)  # type: ignore
+
+        prompt_template: PromptTemplate = self.get_prompt_template()
+        query_engines = [
+            index.as_query_engine(
+                llm=self.llm,
+                text_qa_template=prompt_template,
+                node_postprocessors=self.node_postprocessors,
+            )
+            for _, _, index in indexes
+        ]
+        query_engine_tools = [
+            QueryEngineTool(
+                query_engine=query_engine,
+                metadata=ToolMetadata(
+                    name="Grounding_%s" % name,
+                    description="Grounding data related to %s" % desc,
+                ),
+            )
+            for (name, desc, _), query_engine in zip(indexes, query_engines)
+        ]
+        # Unlike the AgentRunner-based agent above, the workflow-based agent has
+        # no update_prompts() or reset(); a custom system prompt is injected
+        # through the ReAct chat formatter instead.
+        formatter = (
+            ReActChatFormatter.from_defaults(system_header=system_prompt.template)
+            if system_prompt
+            else None
+        )
+        self.agent = ReActAgentV2(
+            tools=query_engine_tools,  # type: ignore
+            llm=self.llm,
+            verbose=True,
+            react_chat_formatter=formatter,
+        )
+
+
 class FlowJSONDecodeError(Exception):
     pass
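Before the new agent implementation, a quick sketch of the input shape the V2 flow consumes: (name, description, index) triples, each becoming a query-engine tool named "Grounding_<name>". The document text, tuple values, and the default embedding model behind from_documents are illustrative assumptions, not part of the patch:

from llama_index.core import Document, VectorStoreIndex

# toy index; real flows pass pre-built indexes (an embedding model must be configured)
index = VectorStoreIndex.from_documents([Document(text="Paris is the capital of France.")])
indexes = [("france", "French geography", index)]
# -> one tool named "Grounding_france", described as
#    "Grounding data related to French geography"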
diff --git a/syftr/agents/react.py b/syftr/agents/react.py
new file mode 100644
index 0000000..169da81
--- /dev/null
+++ b/syftr/agents/react.py
@@ -0,0 +1,267 @@
+import asyncio
+from typing import Any
+
+from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser
+from llama_index.core.agent.react.types import (
+    ActionReasoningStep,
+    ObservationReasoningStep,
+)
+from llama_index.core.llms import ChatMessage, ChatResponse
+from llama_index.core.llms.llm import LLM
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.tools import ToolOutput, ToolSelection
+from llama_index.core.tools.types import BaseTool
+from llama_index.core.workflow import (
+    Context,
+    Event,
+    StartEvent,
+    StopEvent,
+    Workflow,
+    step,
+)
+
+
+class PrepEvent(Event):
+    pass
+
+
+class InputEvent(Event):
+    input: list[ChatMessage]
+
+
+class StreamEvent(Event):
+    delta: str
+
+
+class ToolCallEvent(Event):
+    tool_calls: list[ToolSelection]
+
+
+class FunctionOutputEvent(Event):
+    output: ToolOutput
+
+
+class ReActAgent(Workflow):
+    """Implementation of RAG using a ReAct agent."""
+
+    def __init__(
+        self,
+        *args: Any,
+        llm: LLM,
+        tools: list[BaseTool] | None = None,
+        extra_context: str | None = None,
+        max_iterations: int = 10,
+        default_tool_choice: str = "required",
+        react_chat_formatter: ReActChatFormatter | None = None,
+        **kwargs: Any,
+    ) -> None:
+        # NOTE: Workflow defaults to a 10s timeout, which multi-step RAG queries
+        # can easily exceed; callers may want to pass timeout=None or a larger value.
+        super().__init__(*args, **kwargs)
+        self.tools = tools or []
+        self.llm = llm
+        self.formatter = react_chat_formatter or ReActChatFormatter.from_defaults(
+            context=extra_context or ""
+        )
+        self.output_parser = ReActOutputParser()
+        self.default_tool_choice = default_tool_choice
+        self.max_iterations = max_iterations
+
+    @step
+    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:
+        # clear sources
+        await ctx.store.set("sources", [])
+
+        # init memory if needed
+        memory = await ctx.store.get("memory", default=None)
+        if not memory:
+            memory = ChatMemoryBuffer.from_defaults(llm=self.llm)
+
+        # get user input
+        user_input = ev.input
+        user_msg = ChatMessage(role="user", content=user_input)
+        memory.put(user_msg)
+
+        # clear current reasoning
+        await ctx.store.set("current_reasoning", [])
+
+        # set memory and start counting iterations
+        await ctx.store.set("memory", memory)
+        await ctx.store.set("iteration_counter", 0)
+
+        return PrepEvent()
+
+    @step
+    async def prepare_chat_history(self, ctx: Context, ev: PrepEvent) -> InputEvent:
+        # get chat history
+        memory = await ctx.store.get("memory")
+        chat_history = memory.get()
+        current_reasoning = await ctx.store.get("current_reasoning", default=[])
+
+        # format the prompt with react instructions
+        llm_input = self.formatter.format(
+            self.tools, chat_history, current_reasoning=current_reasoning
+        )
+        return InputEvent(input=llm_input)
+
+    @step
+    async def handle_llm_input(
+        self, ctx: Context, ev: InputEvent
+    ) -> ToolCallEvent | PrepEvent | StopEvent:
+        chat_history = ev.input
+        current_reasoning = await ctx.store.get("current_reasoning", default=[])
+        memory = await ctx.store.get("memory")
+        iteration = await ctx.store.get("iteration_counter")
+        if iteration == self.max_iterations:
+            # out of budget: return the last message we have as the answer
+            sources = await ctx.store.get("sources", default=[])
+            response = memory.get()[-1].content
+            return StopEvent(
+                result={
+                    "response": response,
+                    "sources": [sources],
+                    "reasoning": current_reasoning,
+                }
+            )
+
+        response = self.llm.chat(
+            chat_history,
+            tools=[x.metadata.to_openai_tool() for x in self.tools],
+            # apply the configured default only on the first iteration, then
+            # fall back to "auto" so the model is allowed to stop calling tools
+            tool_choice=self.default_tool_choice if iteration == 0 else "auto",
+        )
+        try:
+            tool_calls = response.raw.choices[0].message.tool_calls
+        except AttributeError:  # occasionally we get a response with a different schema
+            tool_calls = []
+        if tool_calls:  # convert the tool call into a proposed text action
+            tool = tool_calls[0].function
+            message_content = (
+                "Thought: I need to use a tool to help me answer the question.\n"
+                + f"Action: {tool.name}\nAction Input: {tool.arguments}"
+            )
+        else:
+            message_content = response.message.content or ""
+        ctx.write_event_to_stream(StreamEvent(delta=message_content))
+        try:
+            reasoning_step = self.output_parser.parse(message_content)
+            current_reasoning.append(reasoning_step)
+            if reasoning_step.is_done:
+                memory.put(
+                    ChatMessage(role="assistant", content=reasoning_step.response)
+                )
+                await ctx.store.set("memory", memory)
+                await ctx.store.set("current_reasoning", current_reasoning)
+
+                sources = await ctx.store.get("sources", default=[])
+
+                return StopEvent(
+                    result={
+                        "response": reasoning_step.response,
+                        "sources": [sources],
+                        "reasoning": current_reasoning,
+                    }
+                )
+            elif isinstance(reasoning_step, ActionReasoningStep):
+                tool_name = reasoning_step.action
+                tool_args = reasoning_step.action_input
+                return ToolCallEvent(
+                    tool_calls=[
+                        ToolSelection(
+                            tool_id="fake",
+                            tool_name=tool_name,
+                            tool_kwargs=tool_args,
+                        )
+                    ]
+                )
+        except Exception as e:
+            current_reasoning.append(
+                ObservationReasoningStep(
+                    observation=f"There was an error in parsing my reasoning: {e}"
+                )
+            )
+            await ctx.store.set("current_reasoning", current_reasoning)
+
+        # no tool call and no final response yet, so iterate again
+        iteration += 1
+        await ctx.store.set("iteration_counter", iteration)
+        return PrepEvent()
+
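+    # Control-flow recap: new_user_msg seeds the context, prepare_chat_history
+    # renders the ReAct scratchpad, and handle_llm_input either finishes with a
+    # StopEvent, emits a ToolCallEvent for handle_tool_calls below, or loops
+    # back through PrepEvent.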
+    @step
+    async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> PrepEvent:
+        tool_calls = ev.tool_calls
+        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}
+        current_reasoning = await ctx.store.get("current_reasoning", default=[])
+        sources = await ctx.store.get("sources", default=[])
+
+        # call tools -- safely!
+        for tool_call in tool_calls:
+            tool = tools_by_name.get(tool_call.tool_name)
+            if not tool:
+                current_reasoning.append(
+                    ObservationReasoningStep(
+                        observation=f"Tool {tool_call.tool_name} does not exist"
+                    )
+                )
+                continue
+
+            try:
+                tool_output = tool(**tool_call.tool_kwargs)
+                sources.append(tool_output)
+                current_reasoning.append(
+                    ObservationReasoningStep(observation=tool_output.content)
+                )
+            except Exception as e:
+                current_reasoning.append(
+                    ObservationReasoningStep(
+                        observation=f"Error calling tool {tool.metadata.get_name()}: {e}"
+                    )
+                )
+
+        # save new state in context
+        await ctx.store.set("sources", sources)
+        await ctx.store.set("current_reasoning", current_reasoning)
+
+        # bump the iteration counter and prep the next iteration
+        iteration = await ctx.store.get("iteration_counter")
+        iteration += 1
+        await ctx.store.set("iteration_counter", iteration)
+        return PrepEvent()
+
+    def chat(self, query, invocation_id=None) -> ChatResponse:
+        # backwards-compatible sync entry point; invocation_id is accepted only
+        # for signature compatibility with existing callers
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            raise RuntimeError(
+                "ReActAgent.chat() cannot be called from a running event loop; "
+                "await ReActAgent.run() instead"
+            )
+
+        async def wrap_run() -> ChatResponse:
+            result = await self.run(input=query)
+            return ChatResponse(
+                message=ChatMessage.from_str(result["response"], role="assistant"),
+                additional_kwargs={},
+            )
+
+        return loop.run_until_complete(wrap_run())
+
+
+# Placeholder stubs that keep the syftr.flows imports resolvable while the
+# corresponding llama_index agent integrations are disabled (see syftr/flows.py).
+class IntrospectiveAgentWorker(Workflow):
+    pass
+
+
+class ToolInteractiveReflectionAgentWorker(Workflow):
+    pass
+
+
+class LATSAgentWorker(Workflow):
+    pass
+
+
+class CoAAgentPack(Workflow):
+    pass
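To make the new workflow's contract concrete, here is a minimal driver. It is a sketch under stated assumptions: any FunctionCallingLLM whose raw response is OpenAI-shaped works, OPENAI_API_KEY is set, the model name is illustrative, and default_tool_choice="auto" avoids forcing a tool call on the first step:

import asyncio

from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI

from syftr.agents.react import ReActAgent


def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


async def main() -> None:
    agent = ReActAgent(
        llm=OpenAI(model="gpt-4o-mini"),  # illustrative model choice
        tools=[FunctionTool.from_defaults(fn=add)],
        max_iterations=5,
        default_tool_choice="auto",
        timeout=120,  # Workflow defaults to 10s, too tight for multi-step runs
    )
    result = await agent.run(input="Use the add tool to compute 17 + 25.")
    print(result["response"])   # final answer text
    print(result["reasoning"])  # recorded reasoning steps


if __name__ == "__main__":
    asyncio.run(main())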
diff --git a/syftr/flows.py b/syftr/flows.py
index 0ab21c8..db7a4af 100644
--- a/syftr/flows.py
+++ b/syftr/flows.py
@@ -7,14 +7,15 @@
 from uuid import uuid4
 
 import llama_index.core.instrumentation as instrument
-from llama_index.agent.introspective import (
-    IntrospectiveAgentWorker,
-    ToolInteractiveReflectionAgentWorker,
-)
-from llama_index.agent.introspective.reflective.tool_interactive_reflection import (
-    StoppingCallable,
-)
-from llama_index.agent.lats import LATSAgentWorker
+
+# from llama_index.agent.introspective import (
+#     IntrospectiveAgentWorker,
+#     ToolInteractiveReflectionAgentWorker,
+# )
+# from llama_index.agent.introspective.reflective.tool_interactive_reflection import (
+#     StoppingCallable,
+# )
+# from llama_index.agent.lats import LATSAgentWorker
 from llama_index.core import (
     PromptTemplate,
     QueryBundle,
@@ -23,14 +24,19 @@
 )
 from llama_index.core.agent import (
     AgentChatResponse,
-    AgentRunner,
-    FunctionCallingAgentWorker,
+    AgentWorkflow,
+    FunctionAgent,
     ReActAgent,
 )
 from llama_index.core.agent.react.formatter import ReActChatFormatter
 from llama_index.core.evaluation import EvaluationResult
 from llama_index.core.indices.query.query_transform.base import HyDEQueryTransform
-from llama_index.core.llms import ChatMessage, CompletionResponse, MessageRole
+from llama_index.core.llms import (
+    ChatMessage,
+    ChatResponse,
+    CompletionResponse,
+    MessageRole,
+)
 from llama_index.core.llms.function_calling import FunctionCallingLLM
 from llama_index.core.llms.llm import LLM
 from llama_index.core.postprocessor import LLMRerank, PrevNextNodePostprocessor
@@ -47,9 +53,17 @@
 from llama_index.core.schema import NodeWithScore
 from llama_index.core.storage.docstore.types import BaseDocumentStore
 from llama_index.core.tools import BaseTool, FunctionTool, QueryEngineTool, ToolMetadata
-from llama_index.packs.agents_coa import CoAAgentPack
+
+# from llama_index.packs.agents_coa import CoAAgentPack
 from numpy import ceil
 
+from syftr.agents.react import (
+    CoAAgentPack,
+    IntrospectiveAgentWorker,
+    LATSAgentWorker,
+    ToolInteractiveReflectionAgentWorker,
+)
+from syftr.agents.react import ReActAgent as ReActAgentV2
 from syftr.configuration import cfg
 from syftr.instrumentation.arize import instrument_arize
 from syftr.instrumentation.tokens import LLMCallData, TokenTrackingEventHandler
@@ -671,7 +685,7 @@ def tools(self) -> T.List[BaseTool]:
         ]
 
     @property
-    def agent(self) -> AgentRunner:
+    def agent(self) -> AgentWorkflow:
         raise NotImplementedError()
 
     @dispatcher.span
@@ -682,9 +696,9 @@ def _generate(
         self,
     ) -> T.Tuple[CompletionResponse, float]:
         start_time = time.perf_counter()
         query = self.set_thinking(query)
-        response: AgentChatResponse = self.agent.chat(query)
+        response: AgentChatResponse | ChatResponse = self.agent.chat(query)
         try:
-            completion_response = CompletionResponse(text=response.response)
+            # str(ChatResponse) prepends the role ("assistant: ..."), so unpack
+            # the text explicitly for each response type
+            completion_response = CompletionResponse(
+                text=response.response
+                if isinstance(response, AgentChatResponse)
+                else (response.message.content or "")
+            )
         except TypeError:
             logger.error("Incorrect response from an agent: %s", response)
             raise
@@ -731,7 +745,7 @@ def __repr__(self):
         return f"{self.name}: {self.params}"
 
     @property
-    def agent(self) -> AgentRunner:
+    def agent(self) -> AgentWorkflow:
         synth = get_response_synthesizer(
             llm=self.subquestion_response_synthesizer_llm,
             use_async=True,
@@ -767,6 +781,54 @@ def agent(self) -> AgentRunner:
         )
 
 
+@dataclass(kw_only=True)
+class ReActAgentFlowV2(AgenticRAGFlow):
+    react_prompt: str = get_react_template()
+    subquestion_engine_llm: LLM
+    max_iterations: int = 10
+    name: str = "ReAct Agentic RAG Flow"
+    subquestion_response_synthesizer_llm: LLM
+
+    def __repr__(self):
+        return f"{self.name}: {self.params}"
+
+    @property
+    def agent(self) -> AgentWorkflow:
+        synth = get_response_synthesizer(
+            llm=self.subquestion_response_synthesizer_llm,
+            use_async=True,
+            response_mode=ResponseMode.TREE_SUMMARIZE,
+            text_qa_template=self.prompt_template,
+        )
+        sub_question_engine = SubQuestionQueryEngine.from_defaults(
+            query_engine_tools=self.tools,  # type: ignore
+            llm=self.subquestion_engine_llm,
+            verbose=self.verbose,
+            response_synthesizer=synth,
+            use_async=False,
+        )
+
+        tools = [
+            QueryEngineTool(
+                query_engine=sub_question_engine,
+                metadata=ToolMetadata(
+                    name=self.dataset_name.replace("/", "_"),
+                    description=self.dataset_description,
+                ),
+            )
+        ]
+
+        formatter = ReActChatFormatter.from_defaults(system_header=self.react_prompt)
+        return ReActAgentV2(
+            tools=tools,  # type: ignore
+            llm=self.response_synthesizer_llm,
+            max_iterations=self.max_iterations,
+            react_chat_formatter=formatter,
+            default_tool_choice="auto",
+            verbose=self.verbose,
+        )
+
+
 @dataclass(kw_only=True)
 class CritiqueAgentFlow(AgenticRAGFlow):
     subquestion_engine_llm: LLM
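One detail of the V2 flow above: it overrides default_tool_choice to "auto", while the agent's own default is "required". In the workflow, the configured default applies only to the first LLM call; every later call uses "auto" so the model can answer without a tool. A tiny sketch of that schedule (the helper name is hypothetical, mirroring handle_llm_input in syftr/agents/react.py):

def tool_choice_for(iteration: int, default: str = "required") -> str:
    # mirrors handle_llm_input: force the configured default only on iteration 0
    return default if iteration == 0 else "auto"

assert [tool_choice_for(i) for i in range(3)] == ["required", "auto", "auto"]
assert tool_choice_for(0, default="auto") == "auto"  # what ReActAgentFlowV2 opts into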
@@ -775,18 +837,18 @@ class CritiqueAgentFlow(AgenticRAGFlow):
     reflection_agent_llm: FunctionCallingLLM
     max_iterations: int = 10
     critique_template = get_critique_template()
-    stopping_condition: StoppingCallable = lambda critique_str: "PASS" in critique_str
+    stopping_condition: T.Callable = lambda critique_str: "PASS" in critique_str
     name: str = "Critique Agentic RAG Flow"
 
     def __repr__(self):
         return f"{self.name}: {self.params}"
 
     @property
-    def agent(self) -> AgentRunner:
+    def agent(self) -> AgentWorkflow:
         assert isinstance(self.response_synthesizer_llm, FunctionCallingLLM), (
             f"CritiqueAgentFlow requires FunctionCallingLLM. Got {type(self.response_synthesizer_llm)=}"
         )
-        main_worker = FunctionCallingAgentWorker.from_tools(
+        # FunctionAgent has no from_tools() factory; construct it directly
+        main_worker = FunctionAgent(
             tools=self.tools, llm=self.response_synthesizer_llm, verbose=self.verbose
         )
         synth = get_response_synthesizer(
@@ -813,7 +875,7 @@ def agent(self) -> AgentRunner:
             )
         ]
 
-        critique_agent_worker = FunctionCallingAgentWorker.from_tools(
+        critique_agent_worker = FunctionAgent(
             tools=tools,  # type: ignore
             llm=self.critique_agent_llm,
             verbose=self.verbose,
@@ -872,7 +934,7 @@ def tools(self) -> T.List[BaseTool]:
         ]
 
     @property
-    def agent(self) -> AgentRunner:
+    def agent(self) -> AgentWorkflow:
         agent_worker = LATSAgentWorker.from_tools(
             self.tools,
             llm=self.response_synthesizer_llm,
@@ -928,7 +990,7 @@ def divide(a: int, b: int):
         return tools
 
     @property
-    def agent(self) -> AgentRunner:
+    def agent(self) -> AgentWorkflow:
         pack = CoAAgentPack(
             tools=self.tools,
             llm=self.response_synthesizer_llm,
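Finally, a sketch of the backwards-compatible synchronous entry point that AgenticRAGFlow._generate relies on, assuming an agent built as in the driver example above and no event loop already running (chat() raises RuntimeError otherwise):

response = agent.chat("Use the add tool to compute 17 + 25.")
print(response.message.content)  # plain answer text; str(response) would prepend "assistant: "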