Skip to content

Commit

Permalink
Fixed test_tool_failure by building the index as part of the fixture
Browse files (browse the repository at this point in the history)
  • Loading branch information
jamesbraza committed Oct 9, 2024
1 parent c122640 commit 8bfa3b4
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tests/test_task.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,7 +7,7 @@
from ldp.alg.runners import Evaluator, EvaluatorConfig

from paperqa import Docs, QueryRequest, Settings
from paperqa.agents import SearchIndex
from paperqa.agents import get_directory_index
from paperqa.agents.task import (
GradablePaperQAEnvironment,
LitQATaskDataset,
Expand Down Expand Up @@ -77,6 +77,7 @@ def test___len__(

@pytest.mark.asyncio
async def test_evaluation(self, base_query_request: QueryRequest) -> None:
await get_directory_index(settings=base_query_request.settings) # Build
docs = Docs()
# Why are we constructing a TaskConfig here using a serialized QueryRequest and
# Docs? It's to confirm everything works as if hydrating from a YAML config file
Expand Down Expand Up @@ -136,15 +137,14 @@ async def test_tool_failure(self, base_query_request: QueryRequest) -> None:
dataset=dataset,
callbacks=[metrics_callback],
)
with patch.object(
SearchIndex,
"query",
with patch(
"paperqa.agents.search.SearchIndex",
side_effect=Exception("Totally unexpected but retryable error."),
) as mock_query:
) as mock_SearchIndex:
await evaluator.evaluate() # Confirm this does not crash
assert (
metrics_callback.eval_means["truncation_rate"] == 1.0
), "Expected 100% truncations due to max_rollout_steps"
mock_query.assert_awaited(), "Expected failures to come from unit test"
mock_SearchIndex.assert_called(), "Expected failures to come from unit test"
assert metrics_callback.eval_means["correct"] == 0.0
assert metrics_callback.eval_means["correct_unsure"] == 0.0

0 comments on commit 8bfa3b4

Please sign in to comment.