From 87f2b9d9fd554d29f16aead4269be883e8927bb5 Mon Sep 17 00:00:00 2001 From: Pamela Fox Date: Wed, 13 Mar 2024 04:04:28 -0700 Subject: [PATCH] Refactoring of prepdocs for easier integration with user upload feature (#1407) * Changes to prepare for upload * All pass except multilang * Made mypy happy * Fix the arguments * Unit tests for code changes * 3.9 compatibility * Remove unused args from prepdocs * Address feedback * Ignore typing error after reporting to SDK team --- scripts/prepdocs.ps1 | 10 +- scripts/prepdocs.py | 430 ++++++++++-------- scripts/prepdocs.sh | 12 +- scripts/prepdocslib/blobmanager.py | 15 +- scripts/prepdocslib/embeddings.py | 30 +- scripts/prepdocslib/filestrategy.py | 95 ++-- scripts/prepdocslib/htmlparser.py | 6 +- .../integratedvectorizerstrategy.py | 37 +- scripts/prepdocslib/listfilestrategy.py | 20 +- scripts/prepdocslib/parser.py | 6 - scripts/prepdocslib/pdfparser.py | 16 +- scripts/prepdocslib/searchmanager.py | 35 +- scripts/prepdocslib/strategy.py | 13 +- scripts/prepdocslib/textsplitter.py | 16 +- tests/test_blob_manager.py | 1 - tests/test_htmlparser.py | 2 +- tests/test_listfilestrategy.py | 11 + tests/test_prepdocs.py | 62 ++- tests/test_prepdocslib_textsplitter.py | 22 +- tests/test_searchmanager.py | 77 +++- 20 files changed, 529 insertions(+), 387 deletions(-) diff --git a/scripts/prepdocs.ps1 b/scripts/prepdocs.ps1 index dad1b0c090..9808342f1f 100755 --- a/scripts/prepdocs.ps1 +++ b/scripts/prepdocs.ps1 @@ -30,17 +30,11 @@ if ($env:AZURE_SEARCH_ANALYZER_NAME) { if ($env:AZURE_VISION_ENDPOINT) { $visionEndpointArg = "--visionendpoint $env:AZURE_VISION_ENDPOINT" } -if ($env:AZURE_VISION_KEY) { - $visionKeyArg = "--visionkey $env:AZURE_VISION_KEY" -} -# If vision keys are stored in keyvault provide the keyvault name and secret name +# If any keys are stored in keyvault provide the keyvault name and secret name if ($env:AZURE_KEY_VAULT_NAME) { $keyVaultName = "--keyvaultname $env:AZURE_KEY_VAULT_NAME" } -if ($env:VISION_SECRET_NAME) { - $visionSecretNameArg = "--visionsecretname $env:VISION_SECRET_NAME" -} if ($env:AZURE_SEARCH_SECRET_NAME) { $searchSecretNameArg = "--searchsecretname $env:AZURE_SEARCH_SECRET_NAME" } @@ -81,7 +75,7 @@ $argumentList = "./scripts/prepdocs.py $dataArg --verbose " + ` "--openaiservice `"$env:AZURE_OPENAI_SERVICE`" --openaideployment `"$env:AZURE_OPENAI_EMB_DEPLOYMENT`" " + ` "--openaikey `"$env:OPENAI_API_KEY`" --openaiorg `"$env:OPENAI_ORGANIZATION`" " + ` "--documentintelligenceservice $env:AZURE_DOCUMENTINTELLIGENCE_SERVICE " + ` -"$searchImagesArg $visionEndpointArg $visionKeyArg $visionSecretNameArg " + ` +"$searchImagesArg $visionEndpointArg " + ` "$adlsGen2StorageAccountArg $adlsGen2FilesystemArg $adlsGen2FilesystemPathArg " + ` "$tenantArg $aclArg " + ` "$disableVectorsArg $localPdfParserArg $localHtmlParserArg " + ` diff --git a/scripts/prepdocs.py b/scripts/prepdocs.py index b0bc539072..97add39df8 100644 --- a/scripts/prepdocs.py +++ b/scripts/prepdocs.py @@ -1,6 +1,7 @@ import argparse import asyncio -from typing import Any, Optional, Union +import logging +from typing import Optional, Union from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential @@ -11,7 +12,6 @@ from prepdocslib.embeddings import ( AzureOpenAIEmbeddingService, ImageEmbeddings, - OpenAIEmbeddings, OpenAIEmbeddingService, ) from prepdocslib.fileprocessor import FileProcessor @@ -32,50 +32,158 @@ from prepdocslib.textparser import TextParser from prepdocslib.textsplitter import SentenceTextSplitter, SimpleTextSplitter +logger = logging.getLogger("ingester") -def is_key_empty(key): - return key is None or len(key.strip()) == 0 +def clean_key_if_exists(key: Union[str, None]) -> Union[str, None]: + """Remove leading and trailing whitespace from a key if it exists. If the key is empty, return None.""" + if key is not None and key.strip() != "": + return key.strip() + return None -async def setup_file_strategy(credential: AsyncTokenCredential, args: Any) -> Strategy: - storage_creds = credential if is_key_empty(args.storagekey) else args.storagekey - blob_manager = BlobManager( - endpoint=f"https://{args.storageaccount}.blob.core.windows.net", - container=args.container, - account=args.storageaccount, + +async def setup_search_info( + search_service: str, + index_name: str, + azure_credential: AsyncTokenCredential, + search_key: Union[str, None] = None, + key_vault_name: Union[str, None] = None, + search_secret_name: Union[str, None] = None, +) -> SearchInfo: + if key_vault_name and search_secret_name: + async with SecretClient( + vault_url=f"https://{key_vault_name}.vault.azure.net", credential=azure_credential + ) as key_vault_client: + search_key = (await key_vault_client.get_secret(search_secret_name)).value # type: ignore[attr-defined] + + search_creds: Union[AsyncTokenCredential, AzureKeyCredential] = ( + azure_credential if search_key is None else AzureKeyCredential(search_key) + ) + + return SearchInfo( + endpoint=f"https://{search_service}.search.windows.net/", + credential=search_creds, + index_name=index_name, + ) + + +def setup_blob_manager( + azure_credential: AsyncTokenCredential, + storage_account: str, + storage_container: str, + storage_resource_group: str, + subscription_id: str, + search_images: bool, + storage_key: Union[str, None] = None, +): + storage_creds: Union[AsyncTokenCredential, str] = azure_credential if storage_key is None else storage_key + return BlobManager( + endpoint=f"https://{storage_account}.blob.core.windows.net", + container=storage_container, + account=storage_account, credential=storage_creds, - resourceGroup=args.storageresourcegroup, - subscriptionId=args.subscriptionid, - store_page_images=args.searchimages, - verbose=args.verbose, + resourceGroup=storage_resource_group, + subscriptionId=subscription_id, + store_page_images=search_images, ) + +def setup_list_file_strategy( + azure_credential: AsyncTokenCredential, + local_files: Union[str, None], + datalake_storage_account: Union[str, None], + datalake_filesystem: Union[str, None], + datalake_path: Union[str, None], + datalake_key: Union[str, None], +): + list_file_strategy: ListFileStrategy + if datalake_storage_account: + if datalake_filesystem is None or datalake_path is None: + raise ValueError("DataLake file system and path are required when using Azure Data Lake Gen2") + adls_gen2_creds: Union[AsyncTokenCredential, str] = azure_credential if datalake_key is None else datalake_key + logger.info(f"Using Data Lake Gen2 Storage Account {datalake_storage_account}") + list_file_strategy = ADLSGen2ListFileStrategy( + data_lake_storage_account=datalake_storage_account, + data_lake_filesystem=datalake_filesystem, + data_lake_path=datalake_path, + credential=adls_gen2_creds, + ) + elif local_files: + logger.info(f"Using local files in {local_files}") + list_file_strategy = LocalListFileStrategy(path_pattern=local_files) + else: + raise ValueError("Either local_files or datalake_storage_account must be provided.") + return list_file_strategy + + +def setup_embeddings_service( + azure_credential: AsyncTokenCredential, + openai_host: str, + openai_model_name: str, + openai_service: str, + openai_deployment: str, + openai_key: Union[str, None], + openai_org: Union[str, None], + disable_vectors: bool = False, + disable_batch_vectors: bool = False, +): + if disable_vectors: + logger.info("Not setting up embeddings service") + return None + + if openai_host != "openai": + azure_open_ai_credential: Union[AsyncTokenCredential, AzureKeyCredential] = ( + azure_credential if openai_key is None else AzureKeyCredential(openai_key) + ) + return AzureOpenAIEmbeddingService( + open_ai_service=openai_service, + open_ai_deployment=openai_deployment, + open_ai_model_name=openai_model_name, + credential=azure_open_ai_credential, + disable_batch=disable_batch_vectors, + ) + else: + if openai_key is None: + raise ValueError("OpenAI key is required when using the non-Azure OpenAI API") + return OpenAIEmbeddingService( + open_ai_model_name=openai_model_name, + credential=openai_key, + organization=openai_org, + disable_batch=disable_batch_vectors, + ) + + +def setup_file_processors( + azure_credential: AsyncTokenCredential, + document_intelligence_service: Union[str, None], + document_intelligence_key: Union[str, None] = None, + local_pdf_parser: bool = False, + local_html_parser: bool = False, + search_images: bool = False, +): html_parser: Parser pdf_parser: Parser doc_int_parser: DocumentAnalysisParser # check if Azure Document Intelligence credentials are provided - if args.documentintelligenceservice is not None: + if document_intelligence_service is not None: documentintelligence_creds: Union[AsyncTokenCredential, AzureKeyCredential] = ( - credential - if is_key_empty(args.documentintelligencekey) - else AzureKeyCredential(args.documentintelligencekey) + azure_credential if document_intelligence_key is None else AzureKeyCredential(document_intelligence_key) ) doc_int_parser = DocumentAnalysisParser( - endpoint=f"https://{args.documentintelligenceservice}.cognitiveservices.azure.com/", + endpoint=f"https://{document_intelligence_service}.cognitiveservices.azure.com/", credential=documentintelligence_creds, - verbose=args.verbose, ) - if args.localpdfparser or args.documentintelligenceservice is None: - pdf_parser = LocalPdfParser(verbose=args.verbose) + if local_pdf_parser or document_intelligence_service is None: + pdf_parser = LocalPdfParser() else: pdf_parser = doc_int_parser - if args.localhtmlparser or args.documentintelligenceservice is None: - html_parser = LocalHTMLParser(verbose=args.verbose) + if local_html_parser or document_intelligence_service is None: + html_parser = LocalHTMLParser() else: html_parser = doc_int_parser - sentence_text_splitter = SentenceTextSplitter(has_image_embeddings=args.searchimages) - file_processors = { + sentence_text_splitter = SentenceTextSplitter(has_image_embeddings=search_images) + return { ".pdf": FileProcessor(pdf_parser, sentence_text_splitter), ".html": FileProcessor(html_parser, sentence_text_splitter), ".json": FileProcessor(JsonParser(), SimpleTextSplitter()), @@ -91,160 +199,27 @@ async def setup_file_strategy(credential: AsyncTokenCredential, args: Any) -> St ".md": FileProcessor(TextParser(), sentence_text_splitter), ".txt": FileProcessor(TextParser(), sentence_text_splitter), } - use_vectors = not args.novectors - embeddings: Optional[OpenAIEmbeddings] = None - if use_vectors and args.openaihost != "openai": - azure_open_ai_credential: Union[AsyncTokenCredential, AzureKeyCredential] = ( - credential if is_key_empty(args.openaikey) else AzureKeyCredential(args.openaikey) - ) - embeddings = AzureOpenAIEmbeddingService( - open_ai_service=args.openaiservice, - open_ai_deployment=args.openaideployment, - open_ai_model_name=args.openaimodelname, - credential=azure_open_ai_credential, - disable_batch=args.disablebatchvectors, - verbose=args.verbose, - ) - elif use_vectors: - embeddings = OpenAIEmbeddingService( - open_ai_model_name=args.openaimodelname, - credential=args.openaikey, - organization=args.openaiorg, - disable_batch=args.disablebatchvectors, - verbose=args.verbose, - ) - - image_embeddings: Optional[ImageEmbeddings] = None - if args.searchimages: - image_embeddings = ImageEmbeddings( - endpoint=args.visionendpoint, - token_provider=get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default"), - verbose=args.verbose, - ) - - print("Processing files...") - list_file_strategy: ListFileStrategy - if args.datalakestorageaccount: - adls_gen2_creds = credential if is_key_empty(args.datalakekey) else args.datalakekey - print(f"Using Data Lake Gen2 Storage Account {args.datalakestorageaccount}") - list_file_strategy = ADLSGen2ListFileStrategy( - data_lake_storage_account=args.datalakestorageaccount, - data_lake_filesystem=args.datalakefilesystem, - data_lake_path=args.datalakepath, - credential=adls_gen2_creds, - verbose=args.verbose, - ) - else: - print(f"Using local files in {args.files}") - list_file_strategy = LocalListFileStrategy(path_pattern=args.files, verbose=args.verbose) - - if args.removeall: - document_action = DocumentAction.RemoveAll - elif args.remove: - document_action = DocumentAction.Remove - else: - document_action = DocumentAction.Add - - return FileStrategy( - list_file_strategy=list_file_strategy, - blob_manager=blob_manager, - file_processors=file_processors, - document_action=document_action, - embeddings=embeddings, - image_embeddings=image_embeddings, - search_analyzer_name=args.searchanalyzername, - use_acls=args.useacls, - category=args.category, - ) - - -async def setup_intvectorizer_strategy(credential: AsyncTokenCredential, args: Any) -> Strategy: - storage_creds = credential if is_key_empty(args.storagekey) else args.storagekey - blob_manager = BlobManager( - endpoint=f"https://{args.storageaccount}.blob.core.windows.net", - container=args.container, - account=args.storageaccount, - credential=storage_creds, - resourceGroup=args.storageresourcegroup, - subscriptionId=args.subscriptionid, - store_page_images=args.searchimages, - verbose=args.verbose, - ) - - use_vectors = not args.novectors - embeddings: Union[AzureOpenAIEmbeddingService, None] = None - if use_vectors and args.openaihost != "openai": - azure_open_ai_credential: Union[AsyncTokenCredential, AzureKeyCredential] = ( - credential if is_key_empty(args.openaikey) else AzureKeyCredential(args.openaikey) - ) - embeddings = AzureOpenAIEmbeddingService( - open_ai_service=args.openaiservice, - open_ai_deployment=args.openaideployment, - open_ai_model_name=args.openaimodelname, - credential=azure_open_ai_credential, - disable_batch=args.disablebatchvectors, - verbose=args.verbose, - ) - print("Processing files...") - list_file_strategy: ListFileStrategy - if args.datalakestorageaccount: - adls_gen2_creds = credential if is_key_empty(args.datalakekey) else args.datalakekey - print(f"Using Data Lake Gen2 Storage Account {args.datalakestorageaccount}") - list_file_strategy = ADLSGen2ListFileStrategy( - data_lake_storage_account=args.datalakestorageaccount, - data_lake_filesystem=args.datalakefilesystem, - data_lake_path=args.datalakepath, - credential=adls_gen2_creds, - verbose=args.verbose, +def setup_image_embeddings_service( + azure_credential: AsyncTokenCredential, vision_endpoint: Union[str, None], search_images: bool +) -> Union[ImageEmbeddings, None]: + image_embeddings_service: Optional[ImageEmbeddings] = None + if search_images: + if vision_endpoint is None: + raise ValueError("A computer vision endpoint is required when GPT-4-vision is enabled.") + image_embeddings_service = ImageEmbeddings( + endpoint=vision_endpoint, + token_provider=get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default"), ) - else: - print(f"Using local files in {args.files}") - list_file_strategy = LocalListFileStrategy(path_pattern=args.files, verbose=args.verbose) - - if args.removeall: - document_action = DocumentAction.RemoveAll - elif args.remove: - document_action = DocumentAction.Remove - else: - document_action = DocumentAction.Add + return image_embeddings_service - return IntegratedVectorizerStrategy( - list_file_strategy=list_file_strategy, - blob_manager=blob_manager, - document_action=document_action, - embeddings=embeddings, - subscription_id=args.subscriptionid, - search_service_user_assigned_id=args.searchserviceassignedid, - search_analyzer_name=args.searchanalyzername, - use_acls=args.useacls, - category=args.category, - ) +async def main(strategy: Strategy, setup_index: bool = True): + if setup_index: + await strategy.setup() -async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any): - search_key = args.searchkey - if args.keyvaultname and args.searchsecretname: - key_vault_client = SecretClient(vault_url=f"https://{args.keyvaultname}.vault.azure.net", credential=credential) - search_key = (await key_vault_client.get_secret(args.searchsecretname)).value - await key_vault_client.close() - - search_creds: Union[AsyncTokenCredential, AzureKeyCredential] = ( - credential if is_key_empty(search_key) else AzureKeyCredential(search_key) - ) - - search_info = SearchInfo( - endpoint=f"https://{args.searchservice}.search.windows.net/", - credential=search_creds, - index_name=args.index, - verbose=args.verbose, - ) - - if not args.remove and not args.removeall: - await strategy.setup(search_info) - - await strategy.run(search_info) + await strategy.run() if __name__ == "__main__": @@ -388,21 +363,11 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any): required=False, help="Optional, required if --searchimages is specified. Endpoint of Azure AI Vision service to use when embedding images.", ) - parser.add_argument( - "--visionkey", - required=False, - help="Required if --searchimages is specified. Use this Azure AI Vision key instead of the instead of the current user identity to login.", - ) parser.add_argument( "--keyvaultname", required=False, help="Required only if any keys must be fetched from the key vault.", ) - parser.add_argument( - "--visionsecretname", - required=False, - help="Required if --searchimages is specified and --keyvaultname is provided. Fetch the Azure AI Vision key from this key vault instead of using the current user identity to login.", - ) parser.add_argument( "--useintvectorization", required=False, @@ -410,6 +375,13 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any): ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") args = parser.parse_args() + + if args.verbose: + logging.basicConfig(format="%(message)s") + # We only set the level to INFO for our logger, + # to avoid seeing the noisy INFO level logs from the Azure SDKs + logger.setLevel(logging.INFO) + use_int_vectorization = args.useintvectorization and args.useintvectorization.lower() == "true" # Use the current user identity to connect to Azure services unless a key is explicitly set for any of them @@ -419,12 +391,94 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any): else AzureDeveloperCliCredential(tenant_id=args.tenantid, process_timeout=60) ) + if args.removeall: + document_action = DocumentAction.RemoveAll + elif args.remove: + document_action = DocumentAction.Remove + else: + document_action = DocumentAction.Add + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - ingestion_strategy = None + + search_info = loop.run_until_complete( + setup_search_info( + search_service=args.searchservice, + index_name=args.index, + azure_credential=azd_credential, + search_key=clean_key_if_exists(args.searchkey), + key_vault_name=args.keyvaultname, + search_secret_name=args.searchsecretname, + ) + ) + blob_manager = setup_blob_manager( + azure_credential=azd_credential, + storage_account=args.storageaccount, + storage_container=args.container, + storage_resource_group=args.storageresourcegroup, + subscription_id=args.subscriptionid, + search_images=args.searchimages, + storage_key=clean_key_if_exists(args.storagekey), + ) + list_file_strategy = setup_list_file_strategy( + azure_credential=azd_credential, + local_files=args.files, + datalake_storage_account=args.datalakestorageaccount, + datalake_filesystem=args.datalakefilesystem, + datalake_path=args.datalakepath, + datalake_key=clean_key_if_exists(args.datalakekey), + ) + openai_embeddings_service = setup_embeddings_service( + azure_credential=azd_credential, + openai_host=args.openaihost, + openai_model_name=args.openaimodelname, + openai_service=args.openaiservice, + openai_deployment=args.openaideployment, + openai_key=clean_key_if_exists(args.openaikey), + openai_org=args.openaiorg, + disable_vectors=args.novectors, + disable_batch_vectors=args.disablebatchvectors, + ) + + ingestion_strategy: Strategy if use_int_vectorization: - ingestion_strategy = loop.run_until_complete(setup_intvectorizer_strategy(azd_credential, args)) + ingestion_strategy = IntegratedVectorizerStrategy( + search_info=search_info, + list_file_strategy=list_file_strategy, + blob_manager=blob_manager, + document_action=document_action, + embeddings=openai_embeddings_service, + subscription_id=args.subscriptionid, + search_service_user_assigned_id=args.searchserviceassignedid, + search_analyzer_name=args.searchanalyzername, + use_acls=args.useacls, + category=args.category, + ) else: - ingestion_strategy = loop.run_until_complete(setup_file_strategy(azd_credential, args)) - loop.run_until_complete(main(ingestion_strategy, azd_credential, args)) + file_processors = setup_file_processors( + azure_credential=azd_credential, + document_intelligence_service=args.documentintelligenceservice, + document_intelligence_key=clean_key_if_exists(args.documentintelligencekey), + local_pdf_parser=args.localpdfparser, + local_html_parser=args.localhtmlparser, + search_images=args.searchimages, + ) + image_embeddings_service = setup_image_embeddings_service( + azure_credential=azd_credential, vision_endpoint=args.visionendpoint, search_images=args.searchimages + ) + + ingestion_strategy = FileStrategy( + search_info=search_info, + list_file_strategy=list_file_strategy, + blob_manager=blob_manager, + file_processors=file_processors, + document_action=document_action, + embeddings=openai_embeddings_service, + image_embeddings=image_embeddings_service, + search_analyzer_name=args.searchanalyzername, + use_acls=args.useacls, + category=args.category, + ) + + loop.run_until_complete(main(ingestion_strategy, setup_index=not args.remove and not args.removeall)) loop.close() diff --git a/scripts/prepdocs.sh b/scripts/prepdocs.sh index 6d0b2a5a54..3e9747ba0b 100755 --- a/scripts/prepdocs.sh +++ b/scripts/prepdocs.sh @@ -30,21 +30,11 @@ if [ -n "$AZURE_VISION_ENDPOINT" ]; then visionEndpointArg="--visionendpoint $AZURE_VISION_ENDPOINT" fi -visionKeyArg="" -if [ -n "$AZURE_VISION_KEY" ]; then - visionKeyArg="--visionkey $AZURE_VISION_KEY" -fi - keyVaultName="" if [ -n "$AZURE_KEY_VAULT_NAME" ]; then keyVaultName="--keyvaultname $AZURE_KEY_VAULT_NAME" fi -visionSecretNameArg="" -if [ -n "$VISION_SECRET_NAME" ]; then - visionSecretNameArg="--visionsecretname $VISION_SECRET_NAME" -fi - searchSecretNameArg="" if [ -n "$AZURE_SEARCH_SECRET_NAME" ]; then searchSecretNameArg="--searchsecretname $AZURE_SEARCH_SECRET_NAME" @@ -83,7 +73,7 @@ $searchAnalyzerNameArg $searchSecretNameArg \ --openaiservice "$AZURE_OPENAI_SERVICE" --openaideployment "$AZURE_OPENAI_EMB_DEPLOYMENT" \ --openaikey "$OPENAI_API_KEY" --openaiorg "$OPENAI_ORGANIZATION" \ --documentintelligenceservice "$AZURE_DOCUMENTINTELLIGENCE_SERVICE" \ -$searchImagesArg $visionEndpointArg $visionKeyArg $visionSecretNameArg \ +$searchImagesArg $visionEndpointArg \ $adlsGen2StorageAccountArg $adlsGen2FilesystemArg $adlsGen2FilesystemPathArg \ $tenantArg $aclArg \ $disableVectorsArg $localPdfParserArg $localHtmlParserArg \ diff --git a/scripts/prepdocslib/blobmanager.py b/scripts/prepdocslib/blobmanager.py index 2521b62eac..ec2b6747b9 100644 --- a/scripts/prepdocslib/blobmanager.py +++ b/scripts/prepdocslib/blobmanager.py @@ -1,5 +1,6 @@ import datetime import io +import logging import os import re from typing import List, Optional, Union @@ -17,6 +18,8 @@ from .listfilestrategy import File +logger = logging.getLogger("ingester") + class BlobManager: """ @@ -32,14 +35,12 @@ def __init__( resourceGroup: str, subscriptionId: str, store_page_images: bool = False, - verbose: bool = False, ): self.endpoint = endpoint self.credential = credential self.account = account self.container = container self.store_page_images = store_page_images - self.verbose = verbose self.resourceGroup = resourceGroup self.subscriptionId = subscriptionId self.user_delegation_key: Optional[UserDelegationKey] = None @@ -54,7 +55,7 @@ async def upload_blob(self, file: File) -> Optional[List[str]]: # Re-open and upload the original file with open(file.content.name, "rb") as reopened_file: blob_name = BlobManager.blob_name_from_file_name(file.content.name) - print(f"\tUploading blob for whole file -> {blob_name}") + logger.info(f"\tUploading blob for whole file -> {blob_name}") await container_client.upload_blob(blob_name, reopened_file, overwrite=True) if self.store_page_images and os.path.splitext(file.content.name)[1].lower() == ".pdf": @@ -83,12 +84,11 @@ async def upload_pdf_blob_images( try: font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", 20) except OSError: - print("\tUnable to find arial.ttf or FreeMono.ttf, using default font") + logger.info("\tUnable to find arial.ttf or FreeMono.ttf, using default font") for i in range(page_count): blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i) - if self.verbose: - print(f"\tConverting page {i} to image and uploading -> {blob_name}") + logger.info(f"\tConverting page {i} to image and uploading -> {blob_name}") doc = fitz.open(file.content.name) page = doc.load_page(i) @@ -154,8 +154,7 @@ async def remove_blob(self, path: Optional[str] = None): ) ) or (path is not None and blob_path == os.path.basename(path)): continue - if self.verbose: - print(f"\tRemoving blob {blob_path}") + logger.info(f"\tRemoving blob {blob_path}") await container_client.delete_blob(blob_path) @classmethod diff --git a/scripts/prepdocslib/embeddings.py b/scripts/prepdocslib/embeddings.py index d6193d1ce7..1e47cf8efd 100644 --- a/scripts/prepdocslib/embeddings.py +++ b/scripts/prepdocslib/embeddings.py @@ -1,3 +1,4 @@ +import logging from abc import ABC from typing import Awaitable, Callable, List, Optional, Union from urllib.parse import urljoin @@ -16,6 +17,8 @@ ) from typing_extensions import TypedDict +logger = logging.getLogger("ingester") + class EmbeddingBatch: """ @@ -35,17 +38,15 @@ class OpenAIEmbeddings(ABC): SUPPORTED_BATCH_AOAI_MODEL = {"text-embedding-ada-002": {"token_limit": 8100, "max_batch_size": 16}} - def __init__(self, open_ai_model_name: str, disable_batch: bool = False, verbose: bool = False): + def __init__(self, open_ai_model_name: str, disable_batch: bool = False): self.open_ai_model_name = open_ai_model_name self.disable_batch = disable_batch - self.verbose = verbose async def create_client(self) -> AsyncOpenAI: raise NotImplementedError def before_retry_sleep(self, retry_state): - if self.verbose: - print("Rate limited on the OpenAI embeddings API, sleeping before retrying...") + logger.info("Rate limited on the OpenAI embeddings API, sleeping before retrying...") def calculate_token_length(self, text: str): encoding = tiktoken.encoding_for_model(self.open_ai_model_name) @@ -96,8 +97,7 @@ async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]: with attempt: emb_response = await client.embeddings.create(model=self.open_ai_model_name, input=batch.texts) embeddings.extend([data.embedding for data in emb_response.data]) - if self.verbose: - print(f"Batch Completed. Batch size {len(batch.texts)} Token count {batch.token_length}") + logger.info(f"Batch Completed. Batch size {len(batch.texts)} Token count {batch.token_length}") return embeddings @@ -134,9 +134,8 @@ def __init__( open_ai_model_name: str, credential: Union[AsyncTokenCredential, AzureKeyCredential], disable_batch: bool = False, - verbose: bool = False, ): - super().__init__(open_ai_model_name, disable_batch, verbose) + super().__init__(open_ai_model_name, disable_batch) self.open_ai_service = open_ai_service self.open_ai_deployment = open_ai_deployment self.credential = credential @@ -171,14 +170,9 @@ class OpenAIEmbeddingService(OpenAIEmbeddings): """ def __init__( - self, - open_ai_model_name: str, - credential: str, - organization: Optional[str] = None, - disable_batch: bool = False, - verbose: bool = False, + self, open_ai_model_name: str, credential: str, organization: Optional[str] = None, disable_batch: bool = False ): - super().__init__(open_ai_model_name, disable_batch, verbose) + super().__init__(open_ai_model_name, disable_batch) self.credential = credential self.organization = organization @@ -192,10 +186,9 @@ class ImageEmbeddings: To learn more, please visit https://learn.microsoft.com/azure/ai-services/computer-vision/how-to/image-retrieval#call-the-vectorize-image-api """ - def __init__(self, endpoint: str, token_provider: Callable[[], Awaitable[str]], verbose: bool = False): + def __init__(self, endpoint: str, token_provider: Callable[[], Awaitable[str]]): self.token_provider = token_provider self.endpoint = endpoint - self.verbose = verbose async def create_embeddings(self, blob_urls: List[str]) -> List[List[float]]: endpoint = urljoin(self.endpoint, "computervision/retrieval:vectorizeImage") @@ -221,5 +214,4 @@ async def create_embeddings(self, blob_urls: List[str]) -> List[List[float]]: return embeddings def before_retry_sleep(self, retry_state): - if self.verbose: - print("Rate limited on the Vision embeddings API, sleeping before retrying...") + logger.info("Rate limited on the Vision embeddings API, sleeping before retrying...") diff --git a/scripts/prepdocslib/filestrategy.py b/scripts/prepdocslib/filestrategy.py index ff134a4c34..2496dfc0e1 100644 --- a/scripts/prepdocslib/filestrategy.py +++ b/scripts/prepdocslib/filestrategy.py @@ -1,12 +1,32 @@ +import logging from typing import List, Optional from .blobmanager import BlobManager from .embeddings import ImageEmbeddings, OpenAIEmbeddings from .fileprocessor import FileProcessor -from .listfilestrategy import ListFileStrategy +from .listfilestrategy import File, ListFileStrategy from .searchmanager import SearchManager, Section from .strategy import DocumentAction, SearchInfo, Strategy +logger = logging.getLogger("ingester") + + +async def parse_file( + file: File, file_processors: dict[str, FileProcessor], category: Optional[str] = None +) -> List[Section]: + key = file.file_extension() + processor = file_processors.get(key) + if processor is None: + logger.info(f"Skipping '{file.filename()}', no parser found.") + return [] + logger.info(f"Parsing '{file.filename()}'") + pages = [page async for page in processor.parser.parse(content=file.content)] + logger.info(f"Splitting '{file.filename()}' into sections") + sections = [ + Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages) + ] + return sections + class FileStrategy(Strategy): """ @@ -17,6 +37,7 @@ def __init__( self, list_file_strategy: ListFileStrategy, blob_manager: BlobManager, + search_info: SearchInfo, file_processors: dict[str, FileProcessor], document_action: DocumentAction = DocumentAction.Add, embeddings: Optional[OpenAIEmbeddings] = None, @@ -32,12 +53,13 @@ def __init__( self.embeddings = embeddings self.image_embeddings = image_embeddings self.search_analyzer_name = search_analyzer_name + self.search_info = search_info self.use_acls = use_acls self.category = category - async def setup(self, search_info: SearchInfo): + async def setup(self): search_manager = SearchManager( - search_info, + self.search_info, self.search_analyzer_name, self.use_acls, False, @@ -46,34 +68,21 @@ async def setup(self, search_info: SearchInfo): ) await search_manager.create_index() - async def run(self, search_info: SearchInfo): - search_manager = SearchManager(search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings) + async def run(self): + search_manager = SearchManager( + self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings + ) if self.document_action == DocumentAction.Add: files = self.list_file_strategy.list() async for file in files: try: - key = file.file_extension() - processor = self.file_processors.get(key) - if processor is None: - # skip file if no parser is found - if search_info.verbose: - print(f"Skipping '{file.filename()}'.") - continue - if search_info.verbose: - print(f"Parsing '{file.filename()}'") - pages = [page async for page in processor.parser.parse(content=file.content)] - if search_info.verbose: - print(f"Splitting '{file.filename()}' into sections") - sections = [ - Section(split_page, content=file, category=self.category) - for split_page in processor.splitter.split_pages(pages) - ] - - blob_sas_uris = await self.blob_manager.upload_blob(file) - blob_image_embeddings: Optional[List[List[float]]] = None - if self.image_embeddings and blob_sas_uris: - blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) - await search_manager.update_content(sections, blob_image_embeddings) + sections = await parse_file(file, self.file_processors, self.category) + if sections: + blob_sas_uris = await self.blob_manager.upload_blob(file) + blob_image_embeddings: Optional[List[List[float]]] = None + if self.image_embeddings and blob_sas_uris: + blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) + await search_manager.update_content(sections, blob_image_embeddings) finally: if file: file.close() @@ -85,3 +94,35 @@ async def run(self, search_info: SearchInfo): elif self.document_action == DocumentAction.RemoveAll: await self.blob_manager.remove_blob() await search_manager.remove_content() + + +class UploadUserFileStrategy: + """ + Strategy for ingesting a file that has already been uploaded to a ADLS2 storage account + """ + + def __init__( + self, + search_info: SearchInfo, + file_processors: dict[str, FileProcessor], + embeddings: Optional[OpenAIEmbeddings] = None, + image_embeddings: Optional[ImageEmbeddings] = None, + ): + self.file_processors = file_processors + self.embeddings = embeddings + self.image_embeddings = image_embeddings + self.search_info = search_info + self.search_manager = SearchManager(self.search_info, None, True, False, self.embeddings) + + async def add_file(self, file: File): + if self.image_embeddings: + logging.warning("Image embeddings are not currently supported for the user upload feature") + sections = await parse_file(file, self.file_processors) + if sections: + await self.search_manager.update_content(sections) + + async def remove_file(self, filename: str, oid: str): + if filename is None or filename == "": + logging.warning("Filename is required to remove a file") + return + await self.search_manager.remove_content(filename, oid) diff --git a/scripts/prepdocslib/htmlparser.py b/scripts/prepdocslib/htmlparser.py index 116bb5fb30..ee30ff9349 100644 --- a/scripts/prepdocslib/htmlparser.py +++ b/scripts/prepdocslib/htmlparser.py @@ -1,3 +1,4 @@ +import logging import re from typing import IO, AsyncGenerator @@ -6,6 +7,8 @@ from .page import Page from .parser import Parser +logger = logging.getLogger("ingester") + def cleanup_data(data: str) -> str: """Cleans up the given content using regexes @@ -35,8 +38,7 @@ async def parse(self, content: IO) -> AsyncGenerator[Page, None]: Returns: Page: The parsed html Page. """ - if self.verbose: - print(f"\tExtracting text from '{content.name}' using local HTML parser (BeautifulSoup)") + logger.info(f"\tExtracting text from '{content.name}' using local HTML parser (BeautifulSoup)") data = content.read() soup = BeautifulSoup(data, "html.parser") diff --git a/scripts/prepdocslib/integratedvectorizerstrategy.py b/scripts/prepdocslib/integratedvectorizerstrategy.py index 04b9e6c834..0c475b9f52 100644 --- a/scripts/prepdocslib/integratedvectorizerstrategy.py +++ b/scripts/prepdocslib/integratedvectorizerstrategy.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from azure.search.documents.indexes._generated.models import ( @@ -27,6 +28,8 @@ from .searchmanager import SearchManager from .strategy import DocumentAction, SearchInfo, Strategy +logger = logging.getLogger("ingester") + class IntegratedVectorizerStrategy(Strategy): """ @@ -37,6 +40,7 @@ def __init__( self, list_file_strategy: ListFileStrategy, blob_manager: BlobManager, + search_info: SearchInfo, embeddings: Optional[AzureOpenAIEmbeddingService], subscription_id: str, search_service_user_assigned_id: str, @@ -45,8 +49,8 @@ def __init__( use_acls: bool = False, category: Optional[str] = None, ): - if not embeddings: - raise Exception("Expecting AzureOpenAI embedding Service") + if not embeddings or not isinstance(embeddings, AzureOpenAIEmbeddingService): + raise Exception("Expecting AzureOpenAI embedding service") self.list_file_strategy = list_file_strategy self.blob_manager = blob_manager @@ -57,6 +61,7 @@ def __init__( self.search_analyzer_name = search_analyzer_name self.use_acls = use_acls self.category = category + self.search_info = search_info async def create_embedding_skill(self, index_name: str): skillset_name = f"{index_name}-skillset" @@ -114,9 +119,9 @@ async def create_embedding_skill(self, index_name: str): return skillset - async def setup(self, search_info: SearchInfo): + async def setup(self): search_manager = SearchManager( - search_info=search_info, + search_info=self.search_info, search_analyzer_name=self.search_analyzer_name, use_acls=self.use_acls, use_int_vectorization=True, @@ -130,7 +135,7 @@ async def setup(self, search_info: SearchInfo): await search_manager.create_index( vectorizers=[ AzureOpenAIVectorizer( - name=f"{search_info.index_name}-vectorizer", + name=f"{self.search_info.index_name}-vectorizer", kind="azureOpenAI", azure_open_ai_parameters=AzureOpenAIParameters( resource_uri=f"https://{self.embeddings.open_ai_service}.openai.azure.com", @@ -141,10 +146,10 @@ async def setup(self, search_info: SearchInfo): ) # create indexer client - ds_client = search_info.create_search_indexer_client() + ds_client = self.search_info.create_search_indexer_client() ds_container = SearchIndexerDataContainer(name=self.blob_manager.container) data_source_connection = SearchIndexerDataSourceConnection( - name=f"{search_info.index_name}-blob", + name=f"{self.search_info.index_name}-blob", type="azureblob", connection_string=self.blob_manager.get_managedidentity_connectionstring(), container=ds_container, @@ -152,13 +157,13 @@ async def setup(self, search_info: SearchInfo): ) await ds_client.create_or_update_data_source_connection(data_source_connection) - print("Search indexer data source connection updated.") + logger.info("Search indexer data source connection updated.") - embedding_skillset = await self.create_embedding_skill(search_info.index_name) + embedding_skillset = await self.create_embedding_skill(self.search_info.index_name) await ds_client.create_or_update_skillset(embedding_skillset) await ds_client.close() - async def run(self, search_info: SearchInfo): + async def run(self): if self.document_action == DocumentAction.Add: files = self.list_file_strategy.list() async for file in files: @@ -175,25 +180,25 @@ async def run(self, search_info: SearchInfo): await self.blob_manager.remove_blob() # Create an indexer - indexer_name = f"{search_info.index_name}-indexer" + indexer_name = f"{self.search_info.index_name}-indexer" indexer = SearchIndexer( name=indexer_name, description="Indexer to index documents and generate embeddings", - skillset_name=f"{search_info.index_name}-skillset", - target_index_name=search_info.index_name, - data_source_name=f"{search_info.index_name}-blob", + skillset_name=f"{self.search_info.index_name}-skillset", + target_index_name=self.search_info.index_name, + data_source_name=f"{self.search_info.index_name}-blob", # Map the metadata_storage_name field to the title field in the index to display the PDF title in the search results field_mappings=[FieldMapping(source_field_name="metadata_storage_name", target_field_name="title")], ) - indexer_client = search_info.create_search_indexer_client() + indexer_client = self.search_info.create_search_indexer_client() indexer_result = await indexer_client.create_or_update_indexer(indexer) # Run the indexer await indexer_client.run_indexer(indexer_name) await indexer_client.close() - print( + logger.info( f"Successfully created index, indexer: {indexer_result.name}, and skillset. Please navigate to search service in Azure Portal to view the status of the indexer." ) diff --git a/scripts/prepdocslib/listfilestrategy.py b/scripts/prepdocslib/listfilestrategy.py index d0b24876f1..c1885d4032 100644 --- a/scripts/prepdocslib/listfilestrategy.py +++ b/scripts/prepdocslib/listfilestrategy.py @@ -1,5 +1,6 @@ import base64 import hashlib +import logging import os import re import tempfile @@ -12,6 +13,8 @@ DataLakeServiceClient, ) +logger = logging.getLogger("ingester") + class File: """ @@ -32,7 +35,10 @@ def file_extension(self): def filename_to_id(self): filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename()) filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii") - return f"file-{filename_ascii}-{filename_hash}" + acls_hash = "" + if self.acls: + acls_hash = base64.b16encode(str(self.acls).encode("utf-8")).decode("ascii") + return f"file-{filename_ascii}-{filename_hash}{acls_hash}" def close(self): if self.content: @@ -58,9 +64,8 @@ class LocalListFileStrategy(ListFileStrategy): Concrete strategy for listing files that are located in a local filesystem """ - def __init__(self, path_pattern: str, verbose: bool = False): + def __init__(self, path_pattern: str): self.path_pattern = path_pattern - self.verbose = verbose async def list_paths(self) -> AsyncGenerator[str, None]: async for p in self._list_paths(self.path_pattern): @@ -95,8 +100,7 @@ def check_md5(self, path: str) -> bool: stored_hash = md5_f.read() if stored_hash and stored_hash.strip() == existing_hash.strip(): - if self.verbose: - print(f"Skipping {path}, no changes detected.") + logger.info(f"Skipping {path}, no changes detected.") return True # Write the hash @@ -117,13 +121,11 @@ def __init__( data_lake_filesystem: str, data_lake_path: str, credential: Union[AsyncTokenCredential, str], - verbose: bool = False, ): self.data_lake_storage_account = data_lake_storage_account self.data_lake_filesystem = data_lake_filesystem self.data_lake_path = data_lake_path self.credential = credential - self.verbose = verbose async def list_paths(self) -> AsyncGenerator[str, None]: async with DataLakeServiceClient( @@ -167,8 +169,8 @@ async def list(self) -> AsyncGenerator[File, None]: acls["groups"].append(acl_parts[1]) yield File(content=open(temp_file_path, "rb"), acls=acls) except Exception as data_lake_exception: - print(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file") + logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file") try: os.remove(temp_file_path) except Exception as file_delete_exception: - print(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}") + logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}") diff --git a/scripts/prepdocslib/parser.py b/scripts/prepdocslib/parser.py index 8ed9ec8f84..09d12e0ad6 100644 --- a/scripts/prepdocslib/parser.py +++ b/scripts/prepdocslib/parser.py @@ -9,12 +9,6 @@ class Parser(ABC): Abstract parser that parses content into Page objects """ - def __init__( - self, - verbose: bool = False, - ): - self.verbose = verbose - async def parse(self, content: IO) -> AsyncGenerator[Page, None]: if False: yield # pragma: no cover - this is necessary for mypy to type check diff --git a/scripts/prepdocslib/pdfparser.py b/scripts/prepdocslib/pdfparser.py index 5486cd6a0f..d1c39aa0e3 100644 --- a/scripts/prepdocslib/pdfparser.py +++ b/scripts/prepdocslib/pdfparser.py @@ -1,4 +1,5 @@ import html +import logging from typing import IO, AsyncGenerator, Union from azure.ai.documentintelligence.aio import DocumentIntelligenceClient @@ -10,6 +11,8 @@ from .page import Page from .parser import Parser +logger = logging.getLogger("ingester") + class LocalPdfParser(Parser): """ @@ -18,8 +21,7 @@ class LocalPdfParser(Parser): """ async def parse(self, content: IO) -> AsyncGenerator[Page, None]: - if self.verbose: - print(f"\tExtracting text from '{content.name}' using local PDF parser (pypdf)") + logger.info(f"\tExtracting text from '{content.name}' using local PDF parser (pypdf)") reader = PdfReader(content) pages = reader.pages @@ -37,20 +39,14 @@ class DocumentAnalysisParser(Parser): """ def __init__( - self, - endpoint: str, - credential: Union[AsyncTokenCredential, AzureKeyCredential], - model_id="prebuilt-layout", - verbose: bool = False, + self, endpoint: str, credential: Union[AsyncTokenCredential, AzureKeyCredential], model_id="prebuilt-layout" ): self.model_id = model_id self.endpoint = endpoint self.credential = credential - self.verbose = verbose async def parse(self, content: IO) -> AsyncGenerator[Page, None]: - if self.verbose: - print(f"Extracting text from '{content.name}' using Azure Document Intelligence") + logger.info(f"Extracting text from '{content.name}' using Azure Document Intelligence") async with DocumentIntelligenceClient( endpoint=self.endpoint, credential=self.credential diff --git a/scripts/prepdocslib/searchmanager.py b/scripts/prepdocslib/searchmanager.py index 4f9e14bdc3..8356936612 100644 --- a/scripts/prepdocslib/searchmanager.py +++ b/scripts/prepdocslib/searchmanager.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from typing import List, Optional @@ -25,6 +26,8 @@ from .strategy import SearchInfo from .textsplitter import SplitPage +logger = logging.getLogger("ingester") + class Section: """ @@ -60,8 +63,7 @@ def __init__( self.search_images = search_images async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None): - if self.search_info.verbose: - print(f"Ensuring search index {self.search_info.index_name} exists") + logger.info(f"Ensuring search index {self.search_info.index_name} exists") async with self.search_info.create_search_index_client() as search_index_client: fields = [ @@ -173,18 +175,12 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] ), ) if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]: - if self.search_info.verbose: - print(f"Creating {self.search_info.index_name} search index") + logger.info(f"Creating {self.search_info.index_name} search index") await search_index_client.create_index(index) else: - if self.search_info.verbose: - print(f"Search index {self.search_info.index_name} already exists") + logger.info(f"Search index {self.search_info.index_name} already exists") - async def update_content( - self, - sections: List[Section], - image_embeddings: Optional[List[List[float]]] = None, - ): + async def update_content(self, sections: List[Section], image_embeddings: Optional[List[List[float]]] = None): MAX_BATCH_SIZE = 1000 section_batches = [sections[i : i + MAX_BATCH_SIZE] for i in range(0, len(sections), MAX_BATCH_SIZE)] @@ -223,19 +219,20 @@ async def update_content( await search_client.upload_documents(documents) - async def remove_content(self, path: Optional[str] = None): - if self.search_info.verbose: - print(f"Removing sections from '{path or ''}' from search index '{self.search_info.index_name}'") + async def remove_content(self, path: Optional[str] = None, only_oid: Optional[str] = None): + logger.info(f"Removing sections from '{path or ''}' from search index '{self.search_info.index_name}'") async with self.search_info.create_search_client() as search_client: while True: filter = None if path is None else f"sourcefile eq '{os.path.basename(path)}'" result = await search_client.search("", filter=filter, top=1000, include_total_count=True) if await result.get_count() == 0: break - removed_docs = await search_client.delete_documents( - documents=[{"id": document["id"]} async for document in result] - ) - if self.search_info.verbose: - print(f"\tRemoved {len(removed_docs)} sections from index") + documents_to_remove = [] + async for document in result: + # If only_oid is set, only remove documents that have only this oid + if not only_oid or document["oids"] == [only_oid]: + documents_to_remove.append({"id": document["id"]}) + removed_docs = await search_client.delete_documents(documents_to_remove) + logger.info(f"\tRemoved {len(removed_docs)} sections from index") # It can take a few seconds for search results to reflect changes, so wait a bit await asyncio.sleep(2) diff --git a/scripts/prepdocslib/strategy.py b/scripts/prepdocslib/strategy.py index 4d4068c09b..e194dd64bd 100644 --- a/scripts/prepdocslib/strategy.py +++ b/scripts/prepdocslib/strategy.py @@ -16,17 +16,10 @@ class SearchInfo: To learn more, please visit https://learn.microsoft.com/azure/search/search-what-is-azure-search """ - def __init__( - self, - endpoint: str, - credential: Union[AsyncTokenCredential, AzureKeyCredential], - index_name: str, - verbose: bool = False, - ): + def __init__(self, endpoint: str, credential: Union[AsyncTokenCredential, AzureKeyCredential], index_name: str): self.endpoint = endpoint self.credential = credential self.index_name = index_name - self.verbose = verbose def create_search_client(self) -> SearchClient: return SearchClient(endpoint=self.endpoint, index_name=self.index_name, credential=self.credential) @@ -49,8 +42,8 @@ class Strategy(ABC): Abstract strategy for ingesting documents into a search service. It has a single setup step to perform any required initialization, and then a run step that actually ingests documents into the search service. """ - async def setup(self, search_info: SearchInfo): + async def setup(self): raise NotImplementedError - async def run(self, search_info: SearchInfo): + async def run(self): raise NotImplementedError diff --git a/scripts/prepdocslib/textsplitter.py b/scripts/prepdocslib/textsplitter.py index 8f06245a1d..4a715b1f40 100644 --- a/scripts/prepdocslib/textsplitter.py +++ b/scripts/prepdocslib/textsplitter.py @@ -1,3 +1,4 @@ +import logging from abc import ABC from typing import Generator, List @@ -5,6 +6,8 @@ from .page import Page, SplitPage +logger = logging.getLogger("ingester") + class TextSplitter(ABC): """ @@ -84,14 +87,13 @@ class SentenceTextSplitter(TextSplitter): Class that splits pages into smaller chunks. This is required because embedding models may not be able to analyze an entire page at once """ - def __init__(self, has_image_embeddings: bool, verbose: bool = False, max_tokens_per_section: int = 500): + def __init__(self, has_image_embeddings: bool, max_tokens_per_section: int = 500): self.sentence_endings = STANDARD_SENTENCE_ENDINGS + CJK_SENTENCE_ENDINGS self.word_breaks = STANDARD_WORD_BREAKS + CJK_WORD_BREAKS self.max_section_length = DEFAULT_SECTION_LENGTH self.sentence_search_limit = 100 self.max_tokens_per_section = max_tokens_per_section self.section_overlap = self.max_section_length // DEFAULT_OVERLAP_PERCENT - self.verbose = verbose self.has_image_embeddings = has_image_embeddings def split_page_by_max_tokens(self, page_num: int, text: str) -> Generator[SplitPage, None, None]: @@ -198,10 +200,9 @@ def find_page(offset): # If the section ends with an unclosed table, we need to start the next section with the table. # If table starts inside sentence_search_limit, we ignore it, as that will cause an infinite loop for tables longer than MAX_SECTION_LENGTH # If last table starts inside section_overlap, keep overlapping - if self.verbose: - print( - f"Section ends with unclosed table, starting next section with the table at page {find_page(start)} offset {start} table start {last_table_start}" - ) + logger.info( + f"Section ends with unclosed table, starting next section with the table at page {find_page(start)} offset {start} table start {last_table_start}" + ) start = min(end - self.section_overlap, start + last_table_start) else: start = end - self.section_overlap @@ -216,9 +217,8 @@ class SimpleTextSplitter(TextSplitter): This is required because embedding models may not be able to analyze an entire page at once """ - def __init__(self, max_object_length: int = 1000, verbose: bool = False): + def __init__(self, max_object_length: int = 1000): self.max_object_length = max_object_length - self.verbose = verbose def split_pages(self, pages: List[Page]) -> Generator[SplitPage, None, None]: all_text = "".join(page.text for page in pages) diff --git a/tests/test_blob_manager.py b/tests/test_blob_manager.py index 218856f5d2..fb054b0296 100644 --- a/tests/test_blob_manager.py +++ b/tests/test_blob_manager.py @@ -15,7 +15,6 @@ def blob_manager(monkeypatch): endpoint=f"https://{os.environ['AZURE_STORAGE_ACCOUNT']}.blob.core.windows.net", credential=MockAzureCredential(), container=os.environ["AZURE_STORAGE_CONTAINER"], - verbose=True, account=os.environ["AZURE_STORAGE_ACCOUNT"], resourceGroup=os.environ["AZURE_STORAGE_RESOURCE_GROUP"], subscriptionId=os.environ["AZURE_SUBSCRIPTION_ID"], diff --git a/tests/test_htmlparser.py b/tests/test_htmlparser.py index be21f14870..b582ab7da9 100644 --- a/tests/test_htmlparser.py +++ b/tests/test_htmlparser.py @@ -65,7 +65,7 @@ async def test_htmlparser_full(): """ ) file.name = "test.json" - htmlparser = LocalHTMLParser(verbose=True) + htmlparser = LocalHTMLParser() pages = [page async for page in htmlparser.parse(file)] assert len(pages) == 1 assert pages[0].page_num == 0 diff --git a/tests/test_listfilestrategy.py b/tests/test_listfilestrategy.py index bc72c4eba7..f35f11c008 100644 --- a/tests/test_listfilestrategy.py +++ b/tests/test_listfilestrategy.py @@ -47,6 +47,17 @@ def test_file_filename_to_id(): assert File(empty).filename_to_id() == "file-______pdf-E38395E382A1E382A4E383ABE5908D2E706466" +def test_file_filename_to_id_acls(): + empty = io.BytesIO() + empty.name = "foo.pdf" + filename_id = File(empty).filename_to_id() + filename_id2 = File(empty, acls={"oids": ["A-USER-ID"]}).filename_to_id() + filename_id3 = File(empty, acls={"groups": ["A-GROUP-ID"]}).filename_to_id() + filename_id4 = File(empty, acls={"oids": ["A-USER-ID"], "groups": ["A-GROUP-ID"]}).filename_to_id() + # Assert that all filenames are unique + assert len(set([filename_id, filename_id2, filename_id3, filename_id4])) == 4 + + @pytest.mark.asyncio async def test_locallistfilestrategy(): with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/test_prepdocs.py b/tests/test_prepdocs.py index 4343684bc0..7a4695a2db 100644 --- a/tests/test_prepdocs.py +++ b/tests/test_prepdocs.py @@ -1,3 +1,5 @@ +import logging + import openai import openai.types import pytest @@ -123,39 +125,37 @@ async def create_rate_limit_client(*args, **kwargs): @pytest.mark.asyncio -async def test_compute_embedding_ratelimiterror_batch(monkeypatch, capsys): - monkeypatch.setattr(tenacity.wait_random_exponential, "__call__", lambda x, y: 0) - with pytest.raises(tenacity.RetryError): - embeddings = AzureOpenAIEmbeddingService( - open_ai_service="x", - open_ai_deployment="x", - open_ai_model_name="text-embedding-ada-002", - credential=MockAzureCredential(), - disable_batch=False, - verbose=True, - ) - monkeypatch.setattr(embeddings, "create_client", create_rate_limit_client) - await embeddings.create_embeddings(texts=["foo"]) - captured = capsys.readouterr() - assert captured.out.count("Rate limited on the OpenAI embeddings API") == 14 +async def test_compute_embedding_ratelimiterror_batch(monkeypatch, caplog): + with caplog.at_level(logging.INFO): + monkeypatch.setattr(tenacity.wait_random_exponential, "__call__", lambda x, y: 0) + with pytest.raises(tenacity.RetryError): + embeddings = AzureOpenAIEmbeddingService( + open_ai_service="x", + open_ai_deployment="x", + open_ai_model_name="text-embedding-ada-002", + credential=MockAzureCredential(), + disable_batch=False, + ) + monkeypatch.setattr(embeddings, "create_client", create_rate_limit_client) + await embeddings.create_embeddings(texts=["foo"]) + assert caplog.text.count("Rate limited on the OpenAI embeddings API") == 14 @pytest.mark.asyncio -async def test_compute_embedding_ratelimiterror_single(monkeypatch, capsys): - monkeypatch.setattr(tenacity.wait_random_exponential, "__call__", lambda x, y: 0) - with pytest.raises(tenacity.RetryError): - embeddings = AzureOpenAIEmbeddingService( - open_ai_service="x", - open_ai_deployment="x", - open_ai_model_name="text-embedding-ada-002", - credential=MockAzureCredential(), - disable_batch=True, - verbose=True, - ) - monkeypatch.setattr(embeddings, "create_client", create_rate_limit_client) - await embeddings.create_embeddings(texts=["foo"]) - captured = capsys.readouterr() - assert captured.out.count("Rate limited on the OpenAI embeddings API") == 14 +async def test_compute_embedding_ratelimiterror_single(monkeypatch, caplog): + with caplog.at_level(logging.INFO): + monkeypatch.setattr(tenacity.wait_random_exponential, "__call__", lambda x, y: 0) + with pytest.raises(tenacity.RetryError): + embeddings = AzureOpenAIEmbeddingService( + open_ai_service="x", + open_ai_deployment="x", + open_ai_model_name="text-embedding-ada-002", + credential=MockAzureCredential(), + disable_batch=True, + ) + monkeypatch.setattr(embeddings, "create_client", create_rate_limit_client) + await embeddings.create_embeddings(texts=["foo"]) + assert caplog.text.count("Rate limited on the OpenAI embeddings API") == 14 class AuthenticationErrorMockEmbeddingsClient: @@ -177,7 +177,6 @@ async def test_compute_embedding_autherror(monkeypatch, capsys): open_ai_model_name="text-embedding-ada-002", credential=MockAzureCredential(), disable_batch=False, - verbose=True, ) monkeypatch.setattr(embeddings, "create_client", create_auth_error_limit_client) await embeddings.create_embeddings(texts=["foo"]) @@ -189,7 +188,6 @@ async def test_compute_embedding_autherror(monkeypatch, capsys): open_ai_model_name="text-embedding-ada-002", credential=MockAzureCredential(), disable_batch=True, - verbose=True, ) monkeypatch.setattr(embeddings, "create_client", create_auth_error_limit_client) await embeddings.create_embeddings(texts=["foo"]) diff --git a/tests/test_prepdocslib_textsplitter.py b/tests/test_prepdocslib_textsplitter.py index 3f8f82237c..6e5b1c6e66 100644 --- a/tests/test_prepdocslib_textsplitter.py +++ b/tests/test_prepdocslib_textsplitter.py @@ -17,13 +17,13 @@ def test_sentencetextsplitter_split_empty_pages(): - t = SentenceTextSplitter(False, True) + t = SentenceTextSplitter(has_image_embeddings=False) assert list(t.split_pages([])) == [] def test_sentencetextsplitter_split_small_pages(): - t = SentenceTextSplitter(has_image_embeddings=False, verbose=True) + t = SentenceTextSplitter(has_image_embeddings=False) split_pages = list(t.split_pages(pages=[Page(page_num=0, offset=0, text="Not a large page")])) assert len(split_pages) == 1 @@ -33,12 +33,12 @@ def test_sentencetextsplitter_split_small_pages(): @pytest.mark.asyncio async def test_sentencetextsplitter_list_parse_and_split(tmp_path, snapshot): - text_splitter = SentenceTextSplitter(False, True) - pdf_parser = LocalPdfParser(verbose=True) + text_splitter = SentenceTextSplitter(has_image_embeddings=False) + pdf_parser = LocalPdfParser() for pdf in Path("data").glob("*.pdf"): shutil.copy(str(pdf.absolute()), tmp_path) - list_file_strategy = LocalListFileStrategy(path_pattern=str(tmp_path / "*"), verbose=True) + list_file_strategy = LocalListFileStrategy(path_pattern=str(tmp_path / "*")) files = list_file_strategy.list() processed = 0 results = {} @@ -59,13 +59,13 @@ async def test_sentencetextsplitter_list_parse_and_split(tmp_path, snapshot): def test_simpletextsplitter_split_empty_pages(): - t = SimpleTextSplitter(True) + t = SimpleTextSplitter() assert list(t.split_pages([])) == [] def test_simpletextsplitter_split_small_pages(): - t = SimpleTextSplitter(verbose=True) + t = SimpleTextSplitter() split_pages = list(t.split_pages(pages=[Page(page_num=0, offset=0, text='{"test": "Not a large page"}')])) assert len(split_pages) == 1 @@ -75,7 +75,7 @@ def test_simpletextsplitter_split_small_pages(): def test_sentencetextsplitter_split_pages(): max_object_length = 10 - t = SimpleTextSplitter(max_object_length=max_object_length, verbose=True) + t = SimpleTextSplitter(max_object_length=max_object_length) split_pages = list(t.split_pages(pages=[Page(page_num=0, offset=0, text='{"test": "Not a large page"}')])) assert len(split_pages) == 3 @@ -98,13 +98,13 @@ def pytest_generate_tests(metafunc): @pytest.mark.asyncio async def test_sentencetextsplitter_multilang(test_doc, tmp_path): - text_splitter = SentenceTextSplitter(False, True) + text_splitter = SentenceTextSplitter(has_image_embeddings=False) bpe = tiktoken.encoding_for_model(ENCODING_MODEL) pdf_parser = LocalPdfParser() shutil.copy(str(test_doc.absolute()), tmp_path) - list_file_strategy = LocalListFileStrategy(path_pattern=str(tmp_path / "*"), verbose=True) + list_file_strategy = LocalListFileStrategy(path_pattern=str(tmp_path / "*")) files = list_file_strategy.list() processed = 0 async for file in files: @@ -132,7 +132,7 @@ async def test_sentencetextsplitter_multilang(test_doc, tmp_path): def test_split_tables(): - t = SentenceTextSplitter(has_image_embeddings=False, verbose=True) + t = SentenceTextSplitter(has_image_embeddings=False) test_text_without_table = """Contoso Electronics is a leader in the aerospace industry, providing advanced electronic components for both commercial and military aircraft. We specialize in creating cutting- diff --git a/tests/test_searchmanager.py b/tests/test_searchmanager.py index 121ac2852c..f04b6b1b11 100644 --- a/tests/test_searchmanager.py +++ b/tests/test_searchmanager.py @@ -34,7 +34,6 @@ def search_info(): endpoint="https://testsearchclient.blob.core.windows.net", credential=AzureKeyCredential("test"), index_name="test", - verbose=True, ) @@ -321,3 +320,79 @@ async def mock_delete_documents(self, documents): assert searched_filters[0] == "sourcefile eq 'foo.pdf'" assert len(deleted_documents) == 1, "It should have deleted one document" assert deleted_documents[0]["id"] == "file-foo_pdf-666F6F2E706466-page-0" + + +@pytest.mark.asyncio +async def test_remove_content_only_oid(monkeypatch, search_info): + class AsyncSearchResultsIterator: + def __init__(self): + self.results = [ + { + "@search.score": 1, + "id": "file-foo_pdf-666", + "content": "test content", + "category": "test", + "sourcepage": "foo.pdf#page=1", + "sourcefile": "foo.pdf", + "oids": [], + }, + { + "@search.score": 1, + "id": "file-foo_pdf-333", + "content": "test content", + "category": "test", + "sourcepage": "foo.pdf#page=1", + "sourcefile": "foo.pdf", + "oids": ["A-USER-ID", "B-USER-ID"], + }, + { + "@search.score": 1, + "id": "file-foo_pdf-222", + "content": "test content", + "category": "test", + "sourcepage": "foo.pdf#page=1", + "sourcefile": "foo.pdf", + "oids": ["A-USER-ID"], + }, + ] + + def __aiter__(self): + return self + + async def __anext__(self): + if len(self.results) == 0: + raise StopAsyncIteration + return self.results.pop() + + async def get_count(self): + return len(self.results) + + search_results = AsyncSearchResultsIterator() + + searched_filters = [] + + async def mock_search(self, *args, **kwargs): + self.filter = kwargs.get("filter") + searched_filters.append(self.filter) + return search_results + + monkeypatch.setattr(SearchClient, "search", mock_search) + + deleted_documents = [] + + async def mock_delete_documents(self, documents): + deleted_documents.extend(documents) + return documents + + monkeypatch.setattr(SearchClient, "delete_documents", mock_delete_documents) + + manager = SearchManager( + search_info, + ) + + await manager.remove_content("foo.pdf", only_oid="A-USER-ID") + + assert len(searched_filters) == 2, "It should have searched twice (with no results on second try)" + assert searched_filters[0] == "sourcefile eq 'foo.pdf'" + assert len(deleted_documents) == 1, "It should have deleted one document" + assert deleted_documents[0]["id"] == "file-foo_pdf-222"