|
| 1 | +import logging |
| 2 | +from abc import ABC |
| 3 | + |
| 4 | +import aiohttp |
| 5 | +from azure.core.credentials_async import AsyncTokenCredential |
| 6 | +from azure.identity.aio import get_bearer_token_provider |
| 7 | +from rich.progress import Progress |
| 8 | +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed |
| 9 | + |
| 10 | +logger = logging.getLogger("scripts") |
| 11 | + |
| 12 | + |
class MediaDescriber(ABC):
    """Base interface for objects that produce a text description of an image."""

    async def describe_image(self, image_bytes) -> str:
        """Describe the image given as raw bytes.

        The base implementation always raises ``NotImplementedError``;
        subclasses are expected to override it.
        """
        raise NotImplementedError()  # pragma: no cover
| 17 | + |
| 18 | + |
class ContentUnderstandingDescriber(MediaDescriber):
    """Describe images through the Azure AI Content Understanding REST API.

    The class registers a custom image analyzer (``create_analyzer``) and then
    submits images to it (``describe_image``). Both calls start a long-running
    operation on the service that is awaited with ``poll_api``.
    """

    CU_API_VERSION = "2024-12-01-preview"

    # Polling budget for long-running operations. The defaults (60 attempts,
    # 2 seconds apart) give an operation roughly two minutes to finish.
    POLL_MAX_ATTEMPTS = 60
    POLL_INTERVAL_SECONDS = 2

    # OAuth scope requested when acquiring bearer tokens for the service.
    _TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"

    # Operation statuses, compared case-insensitively. Anything that is neither
    # pending nor failed is treated as a finished operation.
    _PENDING_STATUSES = frozenset({"notstarted", "running"})
    _FAILED_STATUSES = frozenset({"failed", "canceled"})

    # Analyzer definition sent verbatim as the JSON body of the create request.
    # NOTE(review): the "descriptions" key under "fieldSchema" is plural while
    # every other level uses "description" - confirm against the service schema
    # before changing it; it is sent to the service as-is.
    analyzer_schema = {
        "analyzerId": "image_analyzer",
        "name": "Image understanding",
        "description": "Extract detailed structured information from images extracted from documents.",
        "baseAnalyzerId": "prebuilt-image",
        "scenario": "image",
        "config": {"returnDetails": False},
        "fieldSchema": {
            "name": "ImageInformation",
            "descriptions": "Description of image.",
            "fields": {
                "Description": {
                    "type": "string",
                    "description": "Description of the image. If the image has a title, start with the title. Include a 2-sentence summary. If the image is a chart, diagram, or table, include the underlying data in an HTML table tag, with accurate numbers. If the image is a chart, describe any axis or legends. The only allowed HTML tags are the table/thead/tr/td/tbody tags.",
                },
            },
        },
    }

    def __init__(self, endpoint: str, credential: AsyncTokenCredential):
        """Store the service endpoint and the credential used to obtain tokens.

        Args:
            endpoint: Base URL of the Content Understanding resource.
            credential: Async credential used to request bearer tokens.
        """
        self.endpoint = endpoint
        self.credential = credential

    def _analyzer_url(self, suffix: str = "") -> str:
        """Build the URL of this describer's analyzer, plus an optional action suffix."""
        analyzer_id = self.analyzer_schema["analyzerId"]
        # Strip a trailing slash so an endpoint such as "https://host/" does not
        # yield a "//contentunderstanding" path.
        base = self.endpoint.rstrip("/")
        return f"{base}/contentunderstanding/analyzers/{analyzer_id}{suffix}"

    async def poll_api(self, session, poll_url, headers):
        """Poll a long-running operation until it reaches a terminal status.

        Args:
            session: Open ``aiohttp`` client session used for the GET requests.
            poll_url: Operation URL to poll.
            headers: Request headers (including authorization) sent on each poll.

        Returns:
            The decoded JSON body of the first response whose status is neither
            pending nor failed.

        Raises:
            Exception: If the operation reports a failed or canceled status.
            tenacity.RetryError: If the operation is still pending after
                ``POLL_MAX_ATTEMPTS`` polls.
        """

        # ValueError is the "not finished yet" signal: it is the only exception
        # type that makes tenacity wait and poll again.
        @retry(
            stop=stop_after_attempt(self.POLL_MAX_ATTEMPTS),
            wait=wait_fixed(self.POLL_INTERVAL_SECONDS),
            retry=retry_if_exception_type(ValueError),
        )
        async def poll():
            async with session.get(poll_url, headers=headers) as response:
                response.raise_for_status()
                response_json = await response.json()

            raw_status = response_json["status"]
            status = str(raw_status).lower()
            if status in self._FAILED_STATUSES:
                # Log the payload so the reason for the failure is not lost.
                logger.error("Content Understanding operation ended with status '%s': %s", raw_status, response_json)
                raise Exception("Failed" if status == "failed" else "Canceled")
            if status in self._PENDING_STATUSES:
                # A queued ("NotStarted") operation is as unfinished as a running one.
                raise ValueError(raw_status)
            return response_json

        return await poll()

    async def create_analyzer(self):
        """Create the image analyzer on the service and wait until it is ready.

        Returns without error when the service answers 409, which the code
        treats as "the analyzer already exists".

        Raises:
            Exception: If the service answers with a status other than 201/409,
                if the response carries no ``Operation-Location`` header, or if
                the creation operation fails.
        """
        analyzer_id = self.analyzer_schema["analyzerId"]
        logger.info("Creating analyzer '%s'...", analyzer_id)

        token_provider = get_bearer_token_provider(self.credential, self._TOKEN_SCOPE)
        token = await token_provider()
        headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
        params = {"api-version": self.CU_API_VERSION}

        async with aiohttp.ClientSession() as session:
            async with session.put(
                url=self._analyzer_url(), params=params, headers=headers, json=self.analyzer_schema
            ) as response:
                if response.status == 409:
                    logger.info("Analyzer '%s' already exists.", analyzer_id)
                    return
                if response.status != 201:
                    data = await response.text()
                    raise Exception("Error creating analyzer", data)
                poll_url = response.headers.get("Operation-Location")

            # Fail with a clear message instead of passing None to the poller.
            if not poll_url:
                raise Exception("Error creating analyzer", "Response is missing the Operation-Location header")

            # The PUT response is released above; polling only needs the session.
            with Progress() as progress:
                progress.add_task("Creating analyzer...", total=None, start=False)
                await self.poll_api(session, poll_url, headers)

    async def describe_image(self, image_bytes: bytes) -> str:
        """Send an image to the analyzer and return the generated description.

        Args:
            image_bytes: Raw bytes of the image, sent as the request body.

        Returns:
            The ``valueString`` of the ``Description`` field of the first
            content item in the analysis result.

        Raises:
            aiohttp.ClientResponseError: If the analyze request is rejected.
            Exception: If the analysis operation fails.
            KeyError, IndexError: If the result does not have the expected shape.
        """
        logger.info("Sending image to Azure Content Understanding service...")
        async with aiohttp.ClientSession() as session:
            token = await self.credential.get_token(self._TOKEN_SCOPE)
            headers = {"Authorization": "Bearer " + token.token}
            params = {"api-version": self.CU_API_VERSION}
            async with session.post(
                url=self._analyzer_url(":analyze"),
                params=params,
                headers=headers,
                data=image_bytes,
            ) as response:
                response.raise_for_status()
                poll_url = response.headers["Operation-Location"]

            # The POST response is released above; polling only needs the session.
            with Progress() as progress:
                progress.add_task("Processing...", total=None, start=False)
                results = await self.poll_api(session, poll_url, headers)

            fields = results["result"]["contents"][0]["fields"]
            return fields["Description"]["valueString"]
0 commit comments