Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 232 additions & 3 deletions langextract/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
from langextract import progress
from langextract import prompting
from langextract import resolver as resolver_lib
from langextract import retry_utils
from langextract.core import base_model
from langextract.core import data
from langextract.core import exceptions
from langextract.core import format_handler as fh
from langextract.core import types as core_types


class DocumentRepeatError(exceptions.LangExtractError):
Expand Down Expand Up @@ -202,6 +204,139 @@ def __init__(
"Annotator initialized with format_handler: %s", format_handler
)

def _process_batch_with_retry(
    self,
    batch_prompts: list[str],
    batch: list[chunking.TextChunk],
    retry_transient_errors: bool = True,
    max_retries: int = 3,
    retry_initial_delay: float = 1.0,
    retry_backoff_factor: float = 2.0,
    retry_max_delay: float = 60.0,
    **kwargs,
) -> Iterator[list[core_types.ScoredOutput]]:
  """Processes a batch of prompts, retrying failed chunks individually.

  The whole batch is attempted first. If batch inference raises a
  transient error (e.g. a 503 "model overloaded"), each chunk is
  re-processed one at a time with retry/backoff so that chunks which
  succeed are preserved instead of being lost to one failing sibling.

  Args:
    batch_prompts: Prompts for the batch, parallel to `batch`.
    batch: TextChunk objects corresponding to `batch_prompts`.
    retry_transient_errors: Whether to retry on transient errors.
    max_retries: Maximum number of retry attempts per chunk.
    retry_initial_delay: Initial delay in seconds before the first retry.
    retry_backoff_factor: Multiplier applied to the delay between retries.
    retry_max_delay: Upper bound in seconds on the delay between retries.
    **kwargs: Additional arguments passed to the language model.

  Yields:
    Lists of ScoredOutputs, one per chunk, in input order.

  Raises:
    ValueError: If `batch_prompts` and `batch` differ in length.
    Exception: Non-transient inference errors are re-raised immediately;
      transient errors that persist after per-chunk retries are exhausted
      are re-raised as well.
  """
  if len(batch_prompts) != len(batch):
    # zip() below would silently truncate a mismatched pairing, dropping
    # chunks without any error; fail loudly instead.
    raise ValueError(
        f"batch_prompts ({len(batch_prompts)}) and batch ({len(batch)})"
        " must have the same length."
    )

  try:
    # Materialize eagerly so a lazily-raised inference error surfaces
    # here, where the per-chunk fallback can still handle it.
    batch_results = list(
        self._language_model.infer(
            batch_prompts=batch_prompts,
            **kwargs,
        )
    )

    yield from batch_results
    return

  except Exception as e:
    if not retry_utils.is_transient_error(e):
      raise

    logging.warning(
        "Batch processing failed with transient error: %s. "
        "Falling back to individual chunk processing with retry.",
        str(e),
    )

  # Collect every per-chunk result before yielding anything so that a
  # later chunk failure aborts the document rather than emitting a
  # partial batch downstream.
  individual_results = []

  for i, (prompt, chunk) in enumerate(zip(batch_prompts, batch)):
    try:
      chunk_result = self._process_single_chunk_with_retry(
          prompt=prompt,
          chunk=chunk,
          retry_transient_errors=retry_transient_errors,
          max_retries=max_retries,
          retry_initial_delay=retry_initial_delay,
          retry_backoff_factor=retry_backoff_factor,
          retry_max_delay=retry_max_delay,
          **kwargs,
      )
      individual_results.append(chunk_result)

    except Exception as e:
      logging.error(
          "Failed to process chunk %d after retries: %s. "
          "Chunk info: document_id=%s, text_length=%d. "
          "Stopping document processing.",
          i,
          str(e),
          chunk.document_id,
          len(chunk.chunk_text),
      )
      raise

  yield from individual_results

def _process_single_chunk_with_retry(
    self,
    prompt: str,
    chunk: chunking.TextChunk,
    retry_transient_errors: bool = True,
    max_retries: int = 3,
    retry_initial_delay: float = 1.0,
    retry_backoff_factor: float = 2.0,
    retry_max_delay: float = 60.0,
    **kwargs,
) -> list[core_types.ScoredOutput]:
  """Runs inference for one chunk, retrying transient failures.

  Args:
    prompt: The prompt for this chunk.
    chunk: The TextChunk object being processed.
    retry_transient_errors: Whether to retry on transient errors.
    max_retries: Maximum number of retry attempts.
    retry_initial_delay: Initial delay in seconds before the first retry.
    retry_backoff_factor: Multiplier for the delay between retries.
    retry_max_delay: Maximum delay in seconds between retries.
    **kwargs: Additional arguments for the language model.

  Returns:
    The ScoredOutput list produced for this chunk.

  Raises:
    exceptions.InferenceOutputError: If inference yields no results.
  """

  def _infer_once() -> list[core_types.ScoredOutput]:
    # Single-prompt inference; retry_wrapper below adds the retry loop.
    results = list(
        self._language_model.infer(
            batch_prompts=[prompt],
            **kwargs,
        )
    )
    if not results:
      raise exceptions.InferenceOutputError(
          f"No results returned for chunk in document {chunk.document_id}"
      )
    return results[0]

  # Build the retry decorator from the caller-supplied policy and apply
  # it explicitly rather than via @-syntax on a nested def.
  retry_wrapper = retry_utils.retry_chunk_processing(
      max_retries=max_retries,
      initial_delay=retry_initial_delay,
      backoff_factor=retry_backoff_factor,
      max_delay=retry_max_delay,
      enabled=retry_transient_errors,
  )
  return retry_wrapper(_infer_once)()

def annotate_documents(
self,
documents: Iterable[data.Document],
Expand All @@ -211,6 +346,11 @@ def annotate_documents(
debug: bool = True,
extraction_passes: int = 1,
show_progress: bool = True,
retry_transient_errors: bool = True,
max_retries: int = 3,
retry_initial_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_max_delay: float = 60.0,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Annotates a sequence of documents with NLP extractions.
Expand All @@ -234,6 +374,11 @@ def annotate_documents(
Values > 1 reprocess tokens multiple times, potentially increasing
costs with the potential for a more thorough extraction.
show_progress: Whether to show progress bar. Defaults to True.
retry_transient_errors: Whether to retry on transient errors. Defaults to True.
max_retries: Maximum number of retry attempts. Defaults to 3.
retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
**kwargs: Additional arguments passed to LanguageModel.infer and Resolver.

Yields:
Expand All @@ -253,6 +398,11 @@ def annotate_documents(
batch_length,
debug,
show_progress,
retry_transient_errors,
max_retries,
retry_initial_delay,
retry_backoff_factor,
retry_max_delay,
**kwargs,
)
else:
Expand All @@ -264,6 +414,11 @@ def annotate_documents(
debug,
extraction_passes,
show_progress,
retry_transient_errors,
max_retries,
retry_initial_delay,
retry_backoff_factor,
retry_max_delay,
**kwargs,
)

Expand All @@ -275,9 +430,32 @@ def _annotate_documents_single_pass(
batch_length: int,
debug: bool,
show_progress: bool = True,
retry_transient_errors: bool = True,
max_retries: int = 3,
retry_initial_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_max_delay: float = 60.0,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Single-pass annotation logic (original implementation)."""
"""Single-pass annotation logic (original implementation).

Args:
documents: Iterable of documents to annotate
resolver: Resolver for processing inference results
max_char_buffer: Maximum character buffer for chunking
batch_length: Number of chunks to process in each batch
debug: Whether to enable debug logging
show_progress: Whether to show progress bar
retry_transient_errors: Whether to retry on transient errors
max_retries: Maximum number of retry attempts
retry_initial_delay: Initial delay before retry
retry_backoff_factor: Backoff multiplier for retries
retry_max_delay: Maximum delay between retries
**kwargs: Additional arguments passed to language model

Yields:
AnnotatedDocument objects with extracted data
"""

logging.info("Starting document annotation.")
doc_iter, doc_iter_for_chunks = itertools.tee(documents, 2)
Expand Down Expand Up @@ -321,8 +499,15 @@ def _annotate_documents_single_pass(
)
progress_bar.set_description(desc)

batch_scored_outputs = self._language_model.infer(
# Process batch with individual chunk retry capability
batch_scored_outputs = self._process_batch_with_retry(
batch_prompts=batch_prompts,
batch=batch,
retry_transient_errors=retry_transient_errors,
max_retries=max_retries,
retry_initial_delay=retry_initial_delay,
retry_backoff_factor=retry_backoff_factor,
retry_max_delay=retry_max_delay,
**kwargs,
)

Expand Down Expand Up @@ -419,9 +604,33 @@ def _annotate_documents_sequential_passes(
debug: bool,
extraction_passes: int,
show_progress: bool = True,
retry_transient_errors: bool = True,
max_retries: int = 3,
retry_initial_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_max_delay: float = 60.0,
**kwargs,
) -> Iterator[data.AnnotatedDocument]:
"""Sequential extraction passes logic for improved recall."""
"""Sequential extraction passes logic for improved recall.

Args:
documents: Iterable of documents to annotate
resolver: Resolver for processing inference results
max_char_buffer: Maximum character buffer for chunking
batch_length: Number of chunks to process in each batch
debug: Whether to enable debug logging
extraction_passes: Number of extraction passes to perform
show_progress: Whether to show progress bar
retry_transient_errors: Whether to retry on transient errors
max_retries: Maximum number of retry attempts
retry_initial_delay: Initial delay before retry
retry_backoff_factor: Backoff multiplier for retries
retry_max_delay: Maximum delay between retries
**kwargs: Additional arguments passed to language model

Yields:
AnnotatedDocument objects with merged extracted data
"""

logging.info(
"Starting sequential extraction passes for improved recall with %d"
Expand All @@ -446,6 +655,11 @@ def _annotate_documents_sequential_passes(
batch_length,
debug=(debug and pass_num == 0),
show_progress=show_progress if pass_num == 0 else False,
retry_transient_errors=retry_transient_errors,
max_retries=max_retries,
retry_initial_delay=retry_initial_delay,
retry_backoff_factor=retry_backoff_factor,
retry_max_delay=retry_max_delay,
**kwargs,
):
doc_id = annotated_doc.document_id
Expand Down Expand Up @@ -494,6 +708,11 @@ def annotate_text(
debug: bool = True,
extraction_passes: int = 1,
show_progress: bool = True,
retry_transient_errors: bool = True,
max_retries: int = 3,
retry_initial_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_max_delay: float = 60.0,
**kwargs,
) -> data.AnnotatedDocument:
"""Annotates text with NLP extractions for text input.
Expand All @@ -511,6 +730,11 @@ def annotate_text(
standard single extraction. Values > 1 reprocess tokens multiple times,
potentially increasing costs.
show_progress: Whether to show progress bar. Defaults to True.
retry_transient_errors: Whether to retry on transient errors. Defaults to True.
max_retries: Maximum number of retry attempts. Defaults to 3.
retry_initial_delay: Initial delay before retry in seconds. Defaults to 1.0.
retry_backoff_factor: Backoff multiplier for retries. Defaults to 2.0.
retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.
**kwargs: Additional arguments for inference and resolver_lib.

Returns:
Expand Down Expand Up @@ -540,6 +764,11 @@ def annotate_text(
debug,
extraction_passes,
show_progress,
retry_transient_errors,
max_retries,
retry_initial_delay,
retry_backoff_factor,
retry_max_delay,
**kwargs,
)
)
Expand Down
21 changes: 21 additions & 0 deletions langextract/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def extract(
prompt_validation_level: pv.PromptValidationLevel = pv.PromptValidationLevel.WARNING,
prompt_validation_strict: bool = False,
show_progress: bool = True,
retry_transient_errors: bool = True,
max_retries: int = 3,
retry_initial_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_max_delay: float = 60.0,
) -> typing.Any:
"""Extracts structured information from text.

Expand Down Expand Up @@ -150,6 +155,12 @@ def extract(
prompt_validation_strict: When True and prompt_validation_level is ERROR,
raises on non-exact matches (MATCH_FUZZY, MATCH_LESSER). Defaults to False.
show_progress: Whether to show progress bar during extraction. Defaults to True.
retry_transient_errors: Whether to automatically retry on transient errors
like 503 "model overloaded". Defaults to True.
max_retries: Maximum number of retry attempts for transient errors. Defaults to 3.
retry_initial_delay: Initial delay in seconds before first retry. Defaults to 1.0.
retry_backoff_factor: Multiplier for exponential backoff between retries. Defaults to 2.0.
retry_max_delay: Maximum delay between retries in seconds. Defaults to 60.0.

Returns:
An AnnotatedDocument with the extracted information when input is a
Expand Down Expand Up @@ -320,6 +331,16 @@ def extract(
format_handler=format_handler,
)

# Add retry parameters to alignment kwargs
retry_kwargs = {
"retry_transient_errors": retry_transient_errors,
"max_retries": max_retries,
"retry_initial_delay": retry_initial_delay,
"retry_backoff_factor": retry_backoff_factor,
"retry_max_delay": retry_max_delay,
}
alignment_kwargs.update(retry_kwargs)

if isinstance(text_or_documents, str):
return annotator.annotate_text(
text=text_or_documents,
Expand Down
2 changes: 2 additions & 0 deletions langextract/providers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from absl import logging

from langextract import retry_utils
from langextract.core import base_model
from langextract.core import data
from langextract.core import exceptions
Expand Down Expand Up @@ -179,6 +180,7 @@ def __init__(
k: v for k, v in (kwargs or {}).items() if k in _API_CONFIG_KEYS
}

@retry_utils.retry_chunk_processing()
def _process_single_prompt(
self, prompt: str, config: dict
) -> core_types.ScoredOutput:
Expand Down
Loading
Loading