Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Not rebuilding SearchIndex every paper_search #512

Merged
merged 5 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def make_initial_state_and_tools(self) -> tuple[EnvironmentState, list[Tool]]:
return self.state, self.tools

async def reset(self) -> tuple[list[Message], list[Tool]]:
# NOTE: don't build the index here, as sometimes we asyncio.gather over this
# method, and our current design (as of v5.0.10) could hit race conditions
# because index building does not use file locks
self._docs.clear_docs()
self.state, self.tools = self.make_initial_state_and_tools()
return (
Expand Down
18 changes: 10 additions & 8 deletions paperqa/agents/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .env import PaperQAEnvironment
from .helpers import litellm_get_search_query, table_formatter
from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler
from .search import SearchDocumentStorage, SearchIndex
from .search import SearchDocumentStorage, SearchIndex, get_directory_index
from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch

if TYPE_CHECKING:
Expand All @@ -51,7 +51,7 @@ async def agent_query(
if docs is None:
docs = Docs()

search_index = SearchIndex(
answers_index = SearchIndex(
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved
fields=[*SearchIndex.REQUIRED_FIELDS, "question"],
index_name="answers",
index_directory=query.settings.agent.index.index_directory,
Expand All @@ -63,15 +63,15 @@ async def agent_query(

agent_logger.info(f"[bold blue]Answer: {response.answer.answer}[/bold blue]")

await search_index.add_document(
await answers_index.add_document(
{
"file_location": str(response.answer.id),
"body": response.answer.answer,
"question": response.answer.question,
},
document=response,
)
await search_index.save_index()
await answers_index.save_index()
return response


Expand Down Expand Up @@ -106,6 +106,8 @@ async def run_agent(
f" query {query.model_dump()}."
)

# Build the index once here, and then all tools won't need to rebuild it
await get_directory_index(settings=query.settings)
if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE:
answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs)
elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type):
Expand Down Expand Up @@ -288,7 +290,7 @@ async def index_search(
fields = [*SearchIndex.REQUIRED_FIELDS]
if index_name == "answers":
fields.append("question")
search_index = SearchIndex(
search_or_answers_index = SearchIndex(
fields=fields,
index_name=index_name,
index_directory=index_directory,
Expand All @@ -301,15 +303,15 @@ async def index_search(

results = [
(AnswerResponse(**a[0]) if index_name == "answers" else a[0], a[1])
for a in await search_index.query(query=query, keep_filenames=True)
for a in await search_or_answers_index.query(query=query, keep_filenames=True)
]

if results:
console = Console(record=True)
# Render the table to a string
console.print(table_formatter(results))
else:
count = await search_index.count
agent_logger.info(f"No results found. Searched {count} docs")
count = await search_or_answers_index.count
agent_logger.info(f"No results found. Searched {count} docs.")

return results
35 changes: 20 additions & 15 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,11 @@ async def process_file(
WARN_IF_INDEXING_MORE_THAN = 999


async def get_directory_index(
async def get_directory_index( # noqa: PLR0912
index_name: str | None = None,
sync_index_w_directory: bool = True,
settings: MaybeSettings = None,
build: bool = True,
) -> SearchIndex:
"""
Create a Tantivy index by reading from a directory of text files.
Expand All @@ -436,15 +437,14 @@ async def get_directory_index(
Args:
index_name: Deprecated override on the name of the index. If unspecified,
the default behavior is to generate the name from the input settings.
sync_index_w_directory: Sync the index (add or delete index files) with the
source paper directory.
sync_index_w_directory: Opt-out flag to sync the index (add or delete index
files) with the source paper directory.
settings: Application settings.
build: Opt-out flag (default is True) to read the contents of the source paper
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this logic is a bit complicated for users— what do you think about a “from_directory” class method on the SearchIndex class that you can use when you don’t need to build one? Then you can call that instead of adding an extra flag to this function. Logic looks good otherwise though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you were just a few seconds too slow on the PR comment haha.

I didn't decompose get_directory_index in this PR because I was trying to be backwards compatible. I will make a downstream PR that does this then

directory and if sync_index_w_directory is enabled also update the index.
"""
_settings = get_settings(settings)
index_settings = _settings.agent.index
semaphore = anyio.Semaphore(index_settings.concurrency)
paper_directory = await index_settings.finalize_paper_directory()

if index_name:
warnings.warn(
(
Expand All @@ -457,6 +457,17 @@ async def get_directory_index(
)
index_settings.name = index_name
del index_name

search_index = SearchIndex(
fields=[*SearchIndex.REQUIRED_FIELDS, "title", "year"],
index_name=index_settings.name or _settings.get_index_name(),
index_directory=index_settings.index_directory,
)
# NOTE: if the index was not previously built, its index_files will be empty.
# Otherwise, the index_files will not be empty
if not build:
return search_index

if not sync_index_w_directory:
warnings.warn(
(
Expand All @@ -470,12 +481,7 @@ async def get_directory_index(
index_settings.sync_with_paper_directory = sync_index_w_directory
del sync_index_w_directory

search_index = SearchIndex(
fields=[*SearchIndex.REQUIRED_FIELDS, "title", "year"],
index_name=index_settings.name or _settings.get_index_name(),
index_directory=index_settings.index_directory,
)

paper_directory = await index_settings.finalize_paper_directory()
metadata = await maybe_get_manifest(
filename=await index_settings.finalize_manifest_file(paper_directory)
)
Expand All @@ -492,10 +498,8 @@ async def get_directory_index(
logger.warning(
f"Indexing {len(valid_paper_dir_files)} files. This may take a few minutes."
)
# NOTE: if the index was not previously built, this will be empty.
# Otherwise, it will not be empty
index_unique_file_paths: set[str] = set((await search_index.index_files).keys())

index_unique_file_paths: set[str] = set((await search_index.index_files).keys())
if extra_index_files := (
index_unique_file_paths - {str(f) for f in valid_paper_dir_files}
):
Expand All @@ -512,6 +516,7 @@ async def get_directory_index(
f" folder ({paper_directory}).[/bold red]"
)

semaphore = anyio.Semaphore(index_settings.concurrency)
async with anyio.create_task_group() as tg:
for file_path in valid_paper_dir_files:
if index_settings.sync_with_paper_directory:
Expand Down
2 changes: 1 addition & 1 deletion paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def paper_search(
offset = self.previous_searches[search_key] = 0

logger.info(f"Starting paper search for {query!r}.")
index = await get_directory_index(settings=self.settings)
index = await get_directory_index(settings=self.settings, build=False)
results: list[Docs] = await index.query(
query,
top_n=self.settings.agent.search_count,
Expand Down
25 changes: 17 additions & 8 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ async def test_get_directory_index(agent_test_settings: Settings) -> None:
)
paper_dir = agent_test_settings.agent.index.paper_directory = Path(tempdir)

index_name = f"stub{uuid4()}" # Unique across test invocations
agent_test_settings.agent.index.name = index_name
index_name = agent_test_settings.agent.index.name = (
f"stub{uuid4()}" # Unique across test invocations
)
with patch.object(
SearchIndex, "save_index", autospec=True, wraps=SearchIndex.save_index
) as mock_save_index:
Expand Down Expand Up @@ -313,18 +314,26 @@ async def test_agent_sharing_state(
embedding_model = agent_test_settings.get_embedding_model()

answer = Answer(question="What is is a self-explanatory model?")
docs = Docs()
query = QueryRequest(query=answer.question, settings=agent_test_settings)
env_state = EnvironmentState(docs=docs, answer=answer)
env_state = EnvironmentState(docs=Docs(), answer=answer)
built_index = await get_directory_index(settings=agent_test_settings)
assert await built_index.count, "Index build did not work"

with subtests.test(msg=PaperSearch.__name__):
search_tool = PaperSearch(
settings=agent_test_settings, embedding_model=embedding_model
)
await search_tool.paper_search(
"XAI self explanatory model", min_year=None, max_year=None, state=env_state
)
assert env_state.docs.docs, "Search did not save any papers"
with patch.object(
SearchIndex, "save_index", autospec=True, wraps=SearchIndex.save_index
) as mock_save_index:
await search_tool.paper_search(
"XAI self explanatory model",
min_year=None,
max_year=None,
state=env_state,
)
assert env_state.docs.docs, "Search did not add any papers"
mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index"
assert all(
(isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable]
for d in env_state.docs.docs.values()
Expand Down