Skip to content

Commit

Permalink
Aviary agent max_timesteps and fixed `test_gather_evidence_rejects_empty_docs` (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Oct 2, 2024
1 parent 6934144 commit 4fc6138
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 11 deletions.
27 changes: 24 additions & 3 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ async def run_agent(
else:
raise NotImplementedError(f"Didn't yet handle agent type {agent_type}.")

if "cannot answer" in answer.answer.lower() and agent_status != AgentStatus.TIMEOUT:
if (
"cannot answer" in answer.answer.lower()
and agent_status != AgentStatus.TRUNCATED
):
agent_status = AgentStatus.UNSURE
# stop after, so overall isn't reported as long-running step.
logger.info(
Expand All @@ -140,6 +143,11 @@ async def run_fake_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
if query.settings.agent.max_timesteps is not None:
logger.warning(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
" applicable with the fake agent, ignoring it."
)
env = PaperQAEnvironment(query, docs, **env_kwargs)
_, tools = await env.reset()
if on_env_reset_callback:
Expand Down Expand Up @@ -209,7 +217,14 @@ async def run_aviary_agent(
tools=tools,
)

timestep, max_timesteps = 0, query.settings.agent.max_timesteps
while not done:
if max_timesteps is not None and timestep >= max_timesteps:
logger.warning(
f"Agent didn't finish within {max_timesteps} timesteps, just answering."
)
await tools[-1]._tool_fn(question=query.query, state=env.state)
return env.state.answer, AgentStatus.TRUNCATED
agent_state.messages += obs
for attempt in Retrying(
stop=stop_after_attempt(5),
Expand All @@ -226,12 +241,13 @@ async def run_aviary_agent(
obs, reward, done, truncated = await env.step(action)
if on_env_step_callback:
await on_env_step_callback(obs, reward, done, truncated)
timestep += 1
status = AgentStatus.SUCCESS
except TimeoutError:
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
)
status = AgentStatus.TIMEOUT
status = AgentStatus.TRUNCATED
await tools[-1]._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
Expand All @@ -250,6 +266,11 @@ async def run_ldp_agent(
) = None,
**env_kwargs,
) -> tuple[Answer, AgentStatus]:
if query.settings.agent.max_timesteps is not None:
logger.warning(
f"Max timesteps (configured {query.settings.agent.max_timesteps}) is not"
" yet implemented with the ldp agent, ignoring it."
)
env = PaperQAEnvironment(query, docs, **env_kwargs)
done = False

Expand All @@ -274,7 +295,7 @@ async def run_ldp_agent(
logger.warning(
f"Agent timeout after {query.settings.agent.timeout}-sec, just answering."
)
status = AgentStatus.TIMEOUT
status = AgentStatus.TRUNCATED
await tools[-1]._tool_fn(question=query.query, state=env.state)
except Exception:
logger.exception(f"Agent {agent} failed.")
Expand Down
5 changes: 3 additions & 2 deletions paperqa/agents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class AgentStatus(StrEnum):
FAIL = "fail"
# SUCCESS - answer was generated
SUCCESS = "success"
# TIMEOUT - agent took too long, but an answer was generated
TIMEOUT = "timeout"
# TRUNCATED - agent didn't finish naturally (e.g. timeout, too many actions),
# so we prematurely answered
TRUNCATED = "truncated"
# UNSURE - the agent was unsure, but an answer is present
UNSURE = "unsure"

Expand Down
4 changes: 4 additions & 0 deletions paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ class AgentSettings(BaseModel):
" supplied."
),
)
max_timesteps: int | None = Field(
default=None,
description="Optional upper limit on the number of environment steps.",
)

index_concurrency: int = Field(
default=30,
Expand Down
22 changes: 16 additions & 6 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) ->
agent_type=agent_type,
)
# ensure that GenerateAnswerTool was called
assert response.status == AgentStatus.TIMEOUT, "Agent did not timeout"
assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout"
assert "I cannot answer" in response.answer.answer


Expand Down Expand Up @@ -287,18 +287,28 @@ async def test_gather_evidence_rejects_empty_docs() -> None:
# lead to an unsure status, which will break this test's assertions. Since
# this test is about a GatherEvidenceTool edge case, defeating
# GenerateAnswerTool is fine
original_doc = GenerateAnswer.gen_answer.__doc__
with patch.object(
GenerateAnswer, "gen_answer", return_value="Failed to answer question."
):
settings = Settings()
settings.agent.tool_names = {"gather_evidence", "gen_answer"}
GenerateAnswer,
"gen_answer",
return_value="Failed to answer question.",
autospec=True,
) as mock_gen_answer:
mock_gen_answer.__doc__ = original_doc
settings = Settings(
agent=AgentSettings(
tool_names={"gather_evidence", "gen_answer"}, max_timesteps=3
)
)
response = await agent_query(
query=QueryRequest(
query="Are COVID-19 vaccines effective?", settings=settings
),
docs=Docs(),
)
assert response.status == AgentStatus.FAIL, "Agent should have registered a failure"
assert (
response.status == AgentStatus.TRUNCATED
), "Agent should have hit its max timesteps"


@pytest.mark.flaky(reruns=3, only_rerun=["AssertionError", "EmptyDocsError"])
Expand Down

0 comments on commit 4fc6138

Please sign in to comment.