Skip to content

Commit

Permalink
Added progress bar to get_directory_index with env var enable/disable
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza committed Oct 3, 2024
1 parent 085d880 commit d0c3a3c
Showing 1 changed file with 59 additions and 16 deletions.
75 changes: 59 additions & 16 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import csv
import json
import logging
Expand All @@ -8,14 +9,15 @@
import pickle
import warnings
import zlib
from collections.abc import Sequence
from collections.abc import Awaitable, Callable, Collection, Sequence
from enum import StrEnum, auto
from io import StringIO
from typing import TYPE_CHECKING, Any, ClassVar, cast
from uuid import UUID

import anyio
from pydantic import BaseModel
from rich.progress import Progress
from tantivy import ( # pylint: disable=no-name-in-module
Document,
Index,
Expand Down Expand Up @@ -374,6 +376,7 @@ async def process_file(
manifest: dict[str, Any],
semaphore: anyio.Semaphore,
settings: Settings,
progress_bar_update: Callable[[], Awaitable] | None = None,
) -> None:
abs_file_path = (
pathlib.Path(settings.agent.index.paper_directory).absolute() / rel_file_path
Expand Down Expand Up @@ -413,6 +416,8 @@ async def process_file(
f"Error parsing {file_location}, skipping index for this file."
)
await search_index.mark_failed_document(file_location)
if progress_bar_update:
await progress_bar_update()
return

this_doc = next(iter(tmp_docs.docs.values()))
Expand All @@ -434,8 +439,41 @@ async def process_file(
)
logger.info(f"Complete ({title}).")

# Update progress bar for either a new or previously indexed file
if progress_bar_update:
await progress_bar_update()


WARN_IF_INDEXING_MORE_THAN = 999
ENV_VAR_MATCH: Collection[str] = {"1", "true"}


def _make_progress_bar_update(
sync_index_w_directory: bool, total: int
) -> tuple[contextlib.AbstractContextManager, Callable[[], Awaitable] | None]:
# Disable should override enable
env_var_disable = (
os.environ.get("PQA_INDEX_DISABLE_PROGRESS_BAR", "").lower() in ENV_VAR_MATCH
)
env_var_enable = (
os.environ.get("PQA_INDEX_ENABLE_PROGRESS_BAR", "").lower() in ENV_VAR_MATCH
)
try:
is_cli = is_running_under_cli() # pylint: disable=used-before-assignment
except NameError: # Work around circular import
from . import is_running_under_cli

is_cli = is_running_under_cli()

if sync_index_w_directory and not env_var_disable and (is_cli or env_var_enable):
progress = Progress()
task_id = progress.add_task("Indexing...", total=total)

async def progress_bar_update() -> None:
progress.update(task_id, advance=1)

return progress, progress_bar_update
return contextlib.nullcontext(), None


async def get_directory_index( # noqa: PLR0912
Expand Down Expand Up @@ -532,21 +570,26 @@ async def get_directory_index( # noqa: PLR0912
)

semaphore = anyio.Semaphore(index_settings.concurrency)
async with anyio.create_task_group() as tg:
for rel_file_path in valid_papers_rel_file_paths:
if index_settings.sync_with_paper_directory:
tg.start_soon(
process_file,
rel_file_path,
search_index,
manifest,
semaphore,
_settings,
)
else:
logger.debug(
f"File {rel_file_path} found in paper directory {paper_directory}."
)
progress_bar, progress_bar_update_fn = _make_progress_bar_update(
index_settings.sync_with_paper_directory, total=len(valid_papers_rel_file_paths)
)
with progress_bar:
async with anyio.create_task_group() as tg:
for rel_file_path in valid_papers_rel_file_paths:
if index_settings.sync_with_paper_directory:
tg.start_soon(
process_file,
rel_file_path,
search_index,
manifest,
semaphore,
_settings,
progress_bar_update_fn,
)
else:
logger.debug(
f"File {rel_file_path} found in paper directory {paper_directory}."
)

if search_index.changed:
await search_index.save_index()
Expand Down

0 comments on commit d0c3a3c

Please sign in to comment.