diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index b24b56e5..fc0eb4d9 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -121,7 +121,10 @@ async def run_agent( else: raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.") - if "cannot answer" in answer.answer.lower() and agent_status != AgentStatus.TIMEOUT: + if ( + "cannot answer" in answer.answer.lower() + and agent_status != AgentStatus.TRUNCATED + ): agent_status = AgentStatus.UNSURE # stop after, so overall isn't reported as long-running step. logger.info( @@ -140,6 +143,11 @@ async def run_fake_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: + if query.settings.agent.max_timesteps is not None: + logger.warning( + f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" + " applicable with the fake agent, ignoring it." + ) env = PaperQAEnvironment(query, docs, **env_kwargs) _, tools = await env.reset() if on_env_reset_callback: @@ -209,7 +217,14 @@ async def run_aviary_agent( tools=tools, ) + timestep, max_timesteps = 0, query.settings.agent.max_timesteps while not done: + if max_timesteps is not None and timestep >= max_timesteps: + logger.warning( + f"Agent didn't finish within {max_timesteps} timesteps, just answering." + ) + await tools[-1]._tool_fn(question=query.query, state=env.state) + return env.state.answer, AgentStatus.TRUNCATED agent_state.messages += obs for attempt in Retrying( stop=stop_after_attempt(5), @@ -226,12 +241,13 @@ async def run_aviary_agent( obs, reward, done, truncated = await env.step(action) if on_env_step_callback: await on_env_step_callback(obs, reward, done, truncated) + timestep += 1 status = AgentStatus.SUCCESS except TimeoutError: logger.warning( f"Agent timeout after {query.settings.agent.timeout}-sec, just answering." ) - status = AgentStatus.TIMEOUT + status = AgentStatus.TRUNCATED await tools[-1]._tool_fn(question=query.query, state=env.state) except Exception: logger.exception(f"Agent {agent} failed.") @@ -250,6 +266,11 @@ async def run_ldp_agent( ) = None, **env_kwargs, ) -> tuple[Answer, AgentStatus]: + if query.settings.agent.max_timesteps is not None: + logger.warning( + f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not" + " yet implemented with the ldp agent, ignoring it." + ) env = PaperQAEnvironment(query, docs, **env_kwargs) done = False @@ -274,7 +295,7 @@ async def run_ldp_agent( logger.warning( f"Agent timeout after {query.settings.agent.timeout}-sec, just answering." ) - status = AgentStatus.TIMEOUT + status = AgentStatus.TRUNCATED await tools[-1]._tool_fn(question=query.query, state=env.state) except Exception: logger.exception(f"Agent {agent} failed.") diff --git a/paperqa/agents/models.py b/paperqa/agents/models.py index e35c92fa..266f3b5a 100644 --- a/paperqa/agents/models.py +++ b/paperqa/agents/models.py @@ -39,8 +39,9 @@ class AgentStatus(StrEnum): FAIL = "fail" # SUCCESS - answer was generated SUCCESS = "success" - # TIMEOUT - agent took too long, but an answer was generated - TIMEOUT = "timeout" + # TRUNCATED - agent didn't finish naturally (e.g. timeout, too many actions), + # so we prematurely answered + TRUNCATED = "truncated" # UNSURE - the agent was unsure, but an answer is present UNSURE = "unsure" diff --git a/paperqa/settings.py b/paperqa/settings.py index abb9e298..690e223a 100644 --- a/paperqa/settings.py +++ b/paperqa/settings.py @@ -406,6 +406,10 @@ class AgentSettings(BaseModel): " supplied." ), ) + max_timesteps: int | None = Field( + default=None, + description="Optional upper limit on the number of environment steps.", + ) index_concurrency: int = Field( default=30, diff --git a/tests/test_agents.py b/tests/test_agents.py index fa6acd8f..633b197f 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -243,7 +243,7 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> agent_type=agent_type, ) # ensure that GenerateAnswerTool was called - assert response.status == AgentStatus.TIMEOUT, "Agent did not timeout" + assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout" assert "I cannot answer" in response.answer.answer @@ -287,18 +287,28 @@ async def test_gather_evidence_rejects_empty_docs() -> None: # lead to an unsure status, which will break this test's assertions. Since # this test is about a GatherEvidenceTool edge case, defeating # GenerateAnswerTool is fine + original_doc = GenerateAnswer.gen_answer.__doc__ with patch.object( - GenerateAnswer, "gen_answer", return_value="Failed to answer question." - ): - settings = Settings() - settings.agent.tool_names = {"gather_evidence", "gen_answer"} + GenerateAnswer, + "gen_answer", + return_value="Failed to answer question.", + autospec=True, + ) as mock_gen_answer: + mock_gen_answer.__doc__ = original_doc + settings = Settings( + agent=AgentSettings( + tool_names={"gather_evidence", "gen_answer"}, max_timesteps=3 + ) + ) response = await agent_query( query=QueryRequest( query="Are COVID-19 vaccines effective?", settings=settings ), docs=Docs(), ) - assert response.status == AgentStatus.FAIL, "Agent should have registered a failure" + assert ( + response.status == AgentStatus.TRUNCATED + ), "Agent should have hit its max timesteps" @pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])