From 2abbc92f5d2ee88d9cb37f9ee0664d977a5093b0 Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 1 Oct 2024 12:10:23 -0700 Subject: [PATCH 1/5] Added build flag to get_directory_index, enabling one to bypass rebuilds --- paperqa/agents/search.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/paperqa/agents/search.py b/paperqa/agents/search.py index b96ee42f..d75c879b 100644 --- a/paperqa/agents/search.py +++ b/paperqa/agents/search.py @@ -423,10 +423,11 @@ async def process_file( WARN_IF_INDEXING_MORE_THAN = 999 -async def get_directory_index( +async def get_directory_index( # noqa: PLR0912 index_name: str | None = None, sync_index_w_directory: bool = True, settings: MaybeSettings = None, + build: bool = True, ) -> SearchIndex: """ Create a Tantivy index by reading from a directory of text files. @@ -436,15 +437,14 @@ async def get_directory_index( Args: index_name: Deprecated override on the name of the index. If unspecified, the default behavior is to generate the name from the input settings. - sync_index_w_directory: Sync the index (add or delete index files) with the - source paper directory. + sync_index_w_directory: Opt-out flag to sync the index (add or delete index + files) with the source paper directory. settings: Application settings. + build: Opt-out flag (default is True) to read the contents of the source paper + directory and if sync_index_w_directory is enabled also update the index. """ _settings = get_settings(settings) index_settings = _settings.agent.index - semaphore = anyio.Semaphore(index_settings.concurrency) - paper_directory = await index_settings.finalize_paper_directory() - if index_name: warnings.warn( ( @@ -457,6 +457,17 @@ async def get_directory_index( ) index_settings.name = index_name del index_name + + search_index = SearchIndex( + fields=[*SearchIndex.REQUIRED_FIELDS, "title", "year"], + index_name=index_settings.name or _settings.get_index_name(), + index_directory=index_settings.index_directory, + ) + # NOTE: if the index was not previously built, its index_files will be empty. + # Otherwise, the index_files will not be empty + if not build: + return search_index + if not sync_index_w_directory: warnings.warn( ( @@ -470,12 +481,7 @@ async def get_directory_index( index_settings.sync_with_paper_directory = sync_index_w_directory del sync_index_w_directory - search_index = SearchIndex( - fields=[*SearchIndex.REQUIRED_FIELDS, "title", "year"], - index_name=index_settings.name or _settings.get_index_name(), - index_directory=index_settings.index_directory, - ) - + paper_directory = await index_settings.finalize_paper_directory() metadata = await maybe_get_manifest( filename=await index_settings.finalize_manifest_file(paper_directory) ) @@ -492,10 +498,8 @@ async def get_directory_index( logger.warning( f"Indexing {len(valid_paper_dir_files)} files. This may take a few minutes." ) - # NOTE: if the index was not previously built, this will be empty. - # Otherwise, it will not be empty - index_unique_file_paths: set[str] = set((await search_index.index_files).keys()) + index_unique_file_paths: set[str] = set((await search_index.index_files).keys()) if extra_index_files := ( index_unique_file_paths - {str(f) for f in valid_paper_dir_files} ): @@ -512,6 +516,7 @@ async def get_directory_index( f" folder ({paper_directory}).[/bold red]" ) + semaphore = anyio.Semaphore(index_settings.concurrency) async with anyio.create_task_group() as tg: for file_path in valid_paper_dir_files: if index_settings.sync_with_paper_directory: From 3eed4906f5fd205bffa382715754d70c53eec06f Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 1 Oct 2024 12:40:57 -0700 Subject: [PATCH 2/5] Turned off rebuilding of the index every paper search --- paperqa/agents/env.py | 3 +++ paperqa/agents/main.py | 4 +++- paperqa/agents/tools.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index da265bd6..171ee0ba 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -129,6 +129,9 @@ def make_initial_state_and_tools(self) -> tuple[EnvironmentState, list[Tool]]: return self.state, self.tools async def reset(self) -> tuple[list[Message], list[Tool]]: + # NOTE: don't build the index here, as sometimes we asyncio.gather over this + # method, and our current design (as of v5.0.10) could hit race conditions + # because index building does not use file locks self._docs.clear_docs() self.state, self.tools = self.make_initial_state_and_tools() return ( diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index 49d8c54a..c5f62f6c 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -27,7 +27,7 @@ from .env import PaperQAEnvironment from .helpers import litellm_get_search_query, table_formatter from .models import AgentStatus, AnswerResponse, QueryRequest, SimpleProfiler -from .search import SearchDocumentStorage, SearchIndex +from .search import SearchDocumentStorage, SearchIndex, get_directory_index from .tools import EnvironmentState, GatherEvidence, GenerateAnswer, PaperSearch if TYPE_CHECKING: @@ -106,6 +106,8 @@ async def run_agent( f" query {query.model_dump()}." ) + # Build the index once here, and then all tools won't need to rebuild it + await get_directory_index(settings=query.settings) if isinstance(agent_type, str) and agent_type.lower() == FAKE_AGENT_TYPE: answer, agent_status = await run_fake_agent(query, docs, **runner_kwargs) elif tool_selector_or_none := query.settings.make_aviary_tool_selector(agent_type): diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py index 69e330ae..f18b1c2b 100644 --- a/paperqa/agents/tools.py +++ b/paperqa/agents/tools.py @@ -124,7 +124,7 @@ async def paper_search( offset = self.previous_searches[search_key] = 0 logger.info(f"Starting paper search for {query!r}.") - index = await get_directory_index(settings=self.settings) + index = await get_directory_index(settings=self.settings, build=False) results: list[Docs] = await index.query( query, top_n=self.settings.agent.search_count, From 5060c9c2d280ff5174cc5a3ff6329a4765928e23 Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 1 Oct 2024 12:50:21 -0700 Subject: [PATCH 3/5] Made it clear what is a search index and what is an answers index --- paperqa/agents/main.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index c5f62f6c..e805e20e 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -51,7 +51,7 @@ async def agent_query( if docs is None: docs = Docs() - search_index = SearchIndex( + answers_index = SearchIndex( fields=[*SearchIndex.REQUIRED_FIELDS, "question"], index_name="answers", index_directory=query.settings.agent.index.index_directory, @@ -63,7 +63,7 @@ async def agent_query( agent_logger.info(f"[bold blue]Answer: {response.answer.answer}[/bold blue]") - await search_index.add_document( + await answers_index.add_document( { "file_location": str(response.answer.id), "body": response.answer.answer, @@ -71,7 +71,7 @@ async def agent_query( }, document=response, ) - await search_index.save_index() + await answers_index.save_index() return response @@ -290,7 +290,7 @@ async def index_search( fields = [*SearchIndex.REQUIRED_FIELDS] if index_name == "answers": fields.append("question") - search_index = SearchIndex( + search_or_answers_index = SearchIndex( fields=fields, index_name=index_name, index_directory=index_directory, @@ -303,7 +303,7 @@ async def index_search( results = [ (AnswerResponse(**a[0]) if index_name == "answers" else a[0], a[1]) - for a in await search_index.query(query=query, keep_filenames=True) + for a in await search_or_answers_index.query(query=query, keep_filenames=True) ] if results: @@ -311,7 +311,7 @@ async def index_search( # Render the table to a string console.print(table_formatter(results)) else: - count = await search_index.count - agent_logger.info(f"No results found. Searched {count} docs") + count = await search_or_answers_index.count + agent_logger.info(f"No results found. Searched {count} docs.") return results From 2c7cb42e9c28966ab0f84db6d4523e186d9232fd Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 1 Oct 2024 13:48:04 -0700 Subject: [PATCH 4/5] Updated and added tests to reflect new behavior --- tests/test_agents.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_agents.py b/tests/test_agents.py index f52e3177..fa6acd8f 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -50,8 +50,9 @@ async def test_get_directory_index(agent_test_settings: Settings) -> None: ) paper_dir = agent_test_settings.agent.index.paper_directory = Path(tempdir) - index_name = f"stub{uuid4()}" # Unique across test invocations - agent_test_settings.agent.index.name = index_name + index_name = agent_test_settings.agent.index.name = ( + f"stub{uuid4()}" # Unique across test invocations + ) with patch.object( SearchIndex, "save_index", autospec=True, wraps=SearchIndex.save_index ) as mock_save_index: @@ -313,18 +314,26 @@ async def test_agent_sharing_state( embedding_model = agent_test_settings.get_embedding_model() answer = Answer(question="What is is a self-explanatory model?") - docs = Docs() query = QueryRequest(query=answer.question, settings=agent_test_settings) - env_state = EnvironmentState(docs=docs, answer=answer) + env_state = EnvironmentState(docs=Docs(), answer=answer) + built_index = await get_directory_index(settings=agent_test_settings) + assert await built_index.count, "Index build did not work" with subtests.test(msg=PaperSearch.__name__): search_tool = PaperSearch( settings=agent_test_settings, embedding_model=embedding_model ) - await search_tool.paper_search( - "XAI self explanatory model", min_year=None, max_year=None, state=env_state - ) - assert env_state.docs.docs, "Search did not save any papers" + with patch.object( + SearchIndex, "save_index", autospec=True, wraps=SearchIndex.save_index + ) as mock_save_index: + await search_tool.paper_search( + "XAI self explanatory model", + min_year=None, + max_year=None, + state=env_state, + ) + assert env_state.docs.docs, "Search did not add any papers" + mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index" assert all( (isinstance(d, Doc) or issubclass(d, Doc)) # type: ignore[unreachable] for d in env_state.docs.docs.values() From ee5cc6a5826317dfa86266459703b958f7d7f8c8 Mon Sep 17 00:00:00 2001 From: James Braza Date: Tue, 1 Oct 2024 15:26:46 -0700 Subject: [PATCH 5/5] Renamed search_or_answers_index to index_to_query to avoid a vague name --- paperqa/agents/main.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paperqa/agents/main.py b/paperqa/agents/main.py index e805e20e..b24b56e5 100644 --- a/paperqa/agents/main.py +++ b/paperqa/agents/main.py @@ -290,7 +290,7 @@ async def index_search( fields = [*SearchIndex.REQUIRED_FIELDS] if index_name == "answers": fields.append("question") - search_or_answers_index = SearchIndex( + index_to_query = SearchIndex( fields=fields, index_name=index_name, index_directory=index_directory, @@ -303,15 +303,14 @@ async def index_search( results = [ (AnswerResponse(**a[0]) if index_name == "answers" else a[0], a[1]) - for a in await search_or_answers_index.query(query=query, keep_filenames=True) + for a in await index_to_query.query(query=query, keep_filenames=True) ] - if results: console = Console(record=True) # Render the table to a string console.print(table_formatter(results)) else: - count = await search_or_answers_index.count + count = await index_to_query.count agent_logger.info(f"No results found. Searched {count} docs.") return results