diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..672a2a3 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,8 @@ +/.github +/.venv +/tmp +/.env +/.env.* +/docker-compose.yml +/Dockerfile +/README.md \ No newline at end of file diff --git a/.github/actions/create-image/action.yaml b/.github/actions/create-image/action.yaml new file mode 100644 index 0000000..2cc2d1e --- /dev/null +++ b/.github/actions/create-image/action.yaml @@ -0,0 +1,67 @@ +name: "Create Docker Image" +description: "Builds a docker image and tags it" +inputs: + IMAGE_NAME: + description: "The image name" + required: true + VERSION: + description: "The version of the image" + required: true + TAG: + description: "The tag of the image, in addition to the version" + required: true + OTHER_TAGS: + description: "Any additional tags, passed directly to docker/metadata-action" + DOCKERFILE: + description: "The path to the Dockerfile" + required: true + GITHUB_TOKEN: + description: "The github token" + required: true + +runs: + using: "composite" + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set environment variables + shell: bash + run: | + echo COMMITED_AT=$(git show -s --format=%cI ${{ github.sha }}) >> $GITHUB_ENV + echo REVISION=$(git rev-parse --short HEAD) >> $GITHUB_ENV + + - name: Collect docker image metadata + id: meta-data + uses: docker/metadata-action@v5 + with: + images: ${{ inputs.IMAGE_NAME }} + labels: | + org.opencontainers.image.created=${{ env.COMMITED_AT }} + org.opencontainers.image.version=v${{ inputs.VERSION }} + org.opencontainers.image.maintainer=EBP Schweiz AG + flavor: | + latest=${{ inputs.TAG == 'latest' }} + tags: | + type=raw,value=${{ inputs.TAG }} + type=raw,value=${{ inputs.VERSION }} + ${{ inputs.OTHER_TAGS }} + + - name: Log in to the GitHub container registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ inputs.GITHUB_TOKEN }} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: ./ + file: ${{ inputs.DOCKERFILE }} + push: true + tags: ${{ steps.meta-data.outputs.tags }} + labels: ${{ steps.meta-data.outputs.labels }} + no-cache: true + build-args: | + APP_VERSION=${{ inputs.VERSION }} diff --git a/.github/actions/tag-commit/action.yaml b/.github/actions/tag-commit/action.yaml new file mode 100644 index 0000000..d178e9b --- /dev/null +++ b/.github/actions/tag-commit/action.yaml @@ -0,0 +1,34 @@ +name: "Tag Commit" +description: "Creates or updates a commit tag" +inputs: + TAG_NAME: + description: "The tag's name" + required: true + SHA: + description: "The SHA of the commit to be tagged" + required: true + +runs: + using: "composite" + steps: + - name: Create/update tag + uses: actions/github-script@v7 + env: + TAG: ${{ inputs.TAG_NAME }} + SHA: ${{ inputs.SHA }} + with: + script: | + github.rest.git.createRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: `refs/tags/${process.env.TAG}`, + sha: process.env.SHA + }).catch(err => { + if (err.status !== 422) throw err; + github.rest.git.updateRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: `tags/${process.env.TAG}`, + sha: process.env.SHA + }); + }) diff --git a/.github/scripts/find-version.js b/.github/scripts/find-version.js new file mode 100644 index 0000000..a459912 --- /dev/null +++ b/.github/scripts/find-version.js @@ -0,0 +1,115 @@ +const findNextVersion = (tags, branch) => { + const version = findMostRecentVersion(tags); + if (version == null) { + return { + major: 1, + minor: 0, + patch: 0, + preRelease: 1, + }; + } + if (branch.startsWith("feature/")) { + // It's a minor feature. + + // If the previous version was a full release or a patch dev release, + // we are a completely new minor dev release. + // Otherwise, the previous version was itself a minor dev release, + // and we can reuse its number. + if (version.preRelease == null || version.patch !== 0) { + version.minor += 1; + version.patch = 0; + } + } else { + // It's a patch. + + // If the previous version was a full release, + // we are a completely new patch dev release. + // Otherwise, we can simply reuse the previous version's number. + if (version.preRelease == null) { + version.patch += 1; + } + } + + version.preRelease ??= 0; + version.preRelease += 1; + return version; +}; + +const findMostRecentVersion = (tags) => { + const versions = findAllVersions(tags); + if (versions.length === 0) { + return null; + } + return versions[0]; +}; + +const findOutdatedVersions = (tags, recentTag) => { + const recentVersion = parseVersion(recentTag); + if (recentVersion == null) { + throw new Error(`recent tag '${recentTag}' is not a version number`); + } + const versions = findAllVersions(tags); + return versions.filter( + (version) => + // Select all pre-releases that appear before the most recent one. + version.preRelease != null && compareVersions(recentVersion, version) > 0 + ); +}; + +const findAllVersions = (tags) => { + return tags + .map(parseVersion) + .filter((it) => it != null) + .sort((a, b) => compareVersions(a, b) * -1); +}; + +const SEMANTIC_VERSION_PATTERN = /^\d+\.\d+\.\d+(?:-dev\d+)?$/; +const parseVersion = (tag) => { + if (!SEMANTIC_VERSION_PATTERN.test(tag)) { + return null; + } + const [major, minor, patch, preRelease] = tag.split(/[.\-]/); + return { + major: parseInt(major), + minor: parseInt(minor), + patch: parseInt(patch), + preRelease: preRelease && parseInt(preRelease.substring(3)), + }; +}; + +const compareVersions = (a, b) => { + if (a.major !== b.major) { + return a.major - b.major; + } + if (a.minor !== b.minor) { + return a.minor - b.minor; + } + if (a.patch !== b.patch) { + return a.patch - b.patch; + } + if (a.preRelease !== b.preRelease) { + if (a.preRelease == null) { + return 1; + } + if (b.preRelease == null) { + return -1; + } + return a.preRelease - b.preRelease; + } + return 0; +}; + +const makeVersionTag = ({ major, minor, patch, preRelease }) => { + const tag = `${major}.${minor}.${patch}`; + if (preRelease == null) { + return tag; + } + return `${tag}-dev${preRelease}`; +}; + +module.exports = { + findNextVersion, + findMostRecentVersion, + findOutdatedVersions, + makeVersionTag, +}; diff --git a/.github/scripts/remove-packages.js b/.github/scripts/remove-packages.js new file mode 100644 index 0000000..b5073ea --- /dev/null +++ b/.github/scripts/remove-packages.js @@ -0,0 +1,51 @@ +const { Octokit } = require("@octokit/rest"); + +const removePackageVersions = async (imageUrl, imageVersions) => { + const octokit = new Octokit({ + auth: process.env.GITHUB_TOKEN, + }); + + const [_imageHost, imageOwner, imageName] = imageUrl.split("/"); + const imageIds = await loadOutdatedVersionIds(octokit, imageOwner, imageName, imageVersions); + for (const imageId of imageIds) { + await octokit.rest.packages.deletePackageVersionForOrg({ + package_type: "container", + package_name: imageName, + org: imageOwner, + package_version_id: imageId, + }); + } +}; + +const loadOutdatedVersionIds = async (octokit, imageOwner, imageName, versions) => { + let page = 0; + versions = new Set(versions); + + const ids = new Set(); + while (true) { + const response = await octokit.rest.packages.getAllPackageVersionsForPackageOwnedByOrg({ + package_type: "container", + package_name: imageName, + org: imageOwner, + page, + }); + if (response.data.length === 0) { + break; + } + for (const entry of response.data) { + // Match any of the requested version's ids, + // as well as any ids that do not have a tag anymore, i.e. are fully unused. + const { tags } = entry.metadata.container; + const matchedTags = tags.filter((tag) => versions.delete(tag)); + if (tags.length === 0 || matchedTags.length !== 0) { + ids.add(entry.id); + } + } + page += 1; + } + return ids; +}; + +module.exports = { + removePackageVersions, +}; diff --git a/.github/workflows/publish-edge.yml b/.github/workflows/publish-edge.yml new file mode 100644 index 0000000..0c9e2fe --- /dev/null +++ b/.github/workflows/publish-edge.yml @@ -0,0 +1,112 @@ +name: Publish Edge + +on: + push: + branches: + - "develop" + + workflow_dispatch: + inputs: + version: + type: string + description: | + Version number (e.g. 1.2.3-dev1). + Leave empty to determine the next version automatically. + required: false + default: "" + is-edge: + type: boolean + description: "Tag the commit and published image with `edge`." + default: true + +permissions: write-all + +env: + IS_EDGE: ${{ github.event_name == 'push' || github.event.inputs.is-edge == 'true' }} + +jobs: + determine_version: + name: "determine version" + runs-on: ubuntu-latest + outputs: + version: ${{ steps.find_version.outputs.result || github.event.inputs.version }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + if: ${{ github.event.inputs.version == '' }} + - name: Get tags of edge commit + id: get_edge_tags + if: ${{ github.event.inputs.version == '' }} + run: | + git fetch --tags + EDGE_COMMIT=$(git rev-list -n 1 edge 2>/dev/null || git rev-parse HEAD) + EDGE_TAGS=$(printf "%s," $(git tag --contains $EDGE_COMMIT)) + EDGE_TAGS=${EDGE_TAGS%,} + echo "edge_tags=$EDGE_TAGS" >> "$GITHUB_OUTPUT" + - name: Find next version + id: find_version + if: ${{ github.event.inputs.version == '' }} + uses: actions/github-script@v7 + env: + EDGE_TAGS: ${{ steps.get_edge_tags.outputs.edge_tags }} + with: + result-encoding: string + script: | + const { findNextVersion } = require('./.github/scripts/find-version.js'); + const tags = process.env.EDGE_TAGS.split(','); + const targetBranch = context.payload.ref.replace('refs/heads/', ''); + + const pullRequests = await github.rest.pulls.list({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'closed', + base: targetBranch, + sort: 'updated', + direction: 'desc' + }); + + const mergedPullRequest = pullRequests.data.find(pr => pr.merge_commit_sha === context.payload.after); + const sourceBranch = mergedPullRequest == null + ? targetBranch + : mergedPullRequest.head.ref.replace('refs/heads/', '') + + const version = findNextVersion(tags, sourceBranch); + return `${version.major}.${version.minor}.${version.patch}-dev${version.preRelease}`; + + build_and_push_api: + name: "build and push api" + needs: + - determine_version + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Create image + uses: ./.github/actions/create-image + with: + IMAGE_NAME: ${{ vars.BASE_IMAGE_NAME }}-api + TAG: ${{ env.IS_EDGE == 'true' && 'edge' || '' }} + VERSION: ${{ needs.determine_version.outputs.version }} + DOCKERFILE: Dockerfile + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + tag_commit: + name: "tag commit" + needs: + - determine_version + - build_and_push_api + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: tag edge + if: ${{ env.IS_EDGE == 'true' }} + uses: ./.github/actions/tag-commit + with: + TAG_NAME: edge + SHA: ${{ github.sha }} + - name: tag version + uses: ./.github/actions/tag-commit + with: + TAG_NAME: ${{ needs.determine_version.outputs.version }} + SHA: ${{ github.sha }} diff --git a/.github/workflows/publish-rc.yml b/.github/workflows/publish-rc.yml new file mode 100644 index 0000000..43ef96e --- /dev/null +++ b/.github/workflows/publish-rc.yml @@ -0,0 +1,58 @@ +name: Publish Release Candidate + +on: + push: + branches: + - "main" + + workflow_dispatch: + inputs: + base: + type: string + description: | + The tag of the commit that will be published as release-candidate. + Make sure that you also select that tag as the workflow's run location. + required: false + default: "edge" + +permissions: write-all + +env: + BASE: ${{ github.event.inputs.base || 'edge' }} + +jobs: + tag_rc_image_api: + name: tag rc image api + runs-on: ubuntu-latest + steps: + - name: Login to GitHub Packages + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin + + - name: Pull docker image + run: docker pull ${{ vars.BASE_IMAGE_NAME }}-api:${{ env.BASE }} + + - name: Tag docker image + run: docker tag ${{ vars.BASE_IMAGE_NAME }}-api:${{ env.BASE }} ${{ vars.BASE_IMAGE_NAME }}-api:release-candidate + + - name: Push docker image + run: docker push ${{ vars.BASE_IMAGE_NAME }}-api:release-candidate + + tag_rc_commit: + name: "tag rc commit" + needs: + - tag_rc_image_api + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Get base commit + id: get_base_commit + run: | + git fetch --tags + BASE_COMMIT=$(git rev-list -n 1 $BASE) + echo "sha=$BASE_COMMIT" >> "$GITHUB_OUTPUT" + - name: tag release-candidate + uses: ./.github/actions/tag-commit + with: + TAG_NAME: release-candidate + SHA: ${{ steps.get_base_commit.outputs.sha }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..842173b --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,173 @@ +name: Release + +on: + workflow_dispatch: + inputs: + base: + type: string + description: | + The tag of the commit that will be released. + Make sure that you also select that tag as the workflow's run location. + required: false + default: "release-candidate" + is-edge: + type: boolean + description: | + Assign the `edge` tag to this release. + default: true + is-release-candidate: + type: boolean + description: | + Assign the `release-candidate` tag to this release. + default: true + +permissions: write-all + +env: + BASE: ${{ github.event.inputs.base || 'release-candidate' }} + IS_EDGE: ${{ github.event.inputs.is-edge == 'true' }} + IS_RC: ${{ github.event.inputs.is-release-candidate == 'true' }} + +jobs: + determine_version: + name: "determine version" + runs-on: ubuntu-latest + outputs: + version: ${{ steps.find_version.outputs.result }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Get tags of base commit + id: get_base_tags + run: | + git fetch --tags + BASE_COMMIT=$(git rev-list -n 1 release-candidate) + BASE_TAGS=$(printf "%s," $(git tag --contains $BASE_COMMIT)) + BASE_TAGS=${BASE_TAGS%,} + echo "base_tags=$BASE_TAGS" >> "$GITHUB_OUTPUT" + - name: Find next version + id: find_version + uses: actions/github-script@v7 + env: + BASE_TAGS: ${{ steps.get_base_tags.outputs.base_tags }} + with: + result-encoding: string + script: | + const { findMostRecentVersion, makeVersionTag } = require('./.github/scripts/find-version.js'); + const tags = process.env.BASE_TAGS.split(','); + const version = findMostRecentVersion(tags); + version.preRelease = null; + return makeVersionTag(version); + + build_and_push_api: + name: "build and push api" + needs: + - determine_version + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Create image + uses: ./.github/actions/create-image + with: + IMAGE_NAME: ${{ vars.BASE_IMAGE_NAME }}-api + TAG: latest + OTHER_TAGS: | + type=raw,value=${{ env.IS_EDGE == 'true' && 'edge' || '' }} + type=raw,value=${{ env.IS_RC == 'true' && 'release-candidate' || '' }} + VERSION: ${{ needs.determine_version.outputs.version }} + DOCKERFILE: ./apps/server-asset-sg/docker/Dockerfile + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + tag_commit: + name: "tag commit" + needs: + - determine_version + - build_and_push_api + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: tag latest + uses: ./.github/actions/tag-commit + with: + TAG_NAME: latest + SHA: ${{ github.sha }} + - name: Tag release-candidate + if: ${{ env.IS_RC == 'true' }} + uses: ./.github/actions/tag-commit + with: + TAG_NAME: release-candidate + SHA: ${{ github.sha }} + - name: Tag edge + if: ${{ env.IS_EDGE == 'true' }} + uses: ./.github/actions/tag-commit + with: + TAG_NAME: edge + SHA: ${{ github.sha }} + - name: Tag version + uses: ./.github/actions/tag-commit + with: + TAG_NAME: ${{ needs.determine_version.outputs.version }} + SHA: ${{ github.sha }} + + create_release: + name: "create release" + needs: + - determine_version + - build_and_push_api + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Create release + uses: softprops/action-gh-release@v2 + with: + tag_name: "${{ needs.determine_version.outputs.version }}" + name: "swissgeol-ocr-api v${{ needs.determine_version.outputs.version }}" + generate_release_notes: true + make_latest: true + + cleanup: + name: "cleanup" + needs: + - determine_version + - create_release + - tag_commit + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Setup node + run: | + npm install @octokit/rest + - name: Get tags + id: get_tags + run: | + git fetch --tags + TAGS=$(printf "%s," $(git tag)) + TAGS=${TAGS%,} + echo "tags=$TAGS" >> "$GITHUB_OUTPUT" + - name: Remove outdated versions + uses: actions/github-script@v7 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BASE_IMAGE_NAME: ${{ vars.BASE_IMAGE_NAME }} + CURRENT_VERSION: ${{ needs.determine_version.outputs.version }} + TAGS: ${{ steps.get_tags.outputs.tags }} + with: + script: | + const { findOutdatedVersions, makeVersionTag } = require('./.github/scripts/find-version.js'); + const { removePackageVersions } = require('./.github/scripts/remove-packages.js'); + + const tags = process.env.TAGS.split(','); + const outdatedVersions = findOutdatedVersions(tags, process.env.CURRENT_VERSION).map(makeVersionTag); + for (const version of outdatedVersions) { + await github.rest.git.deleteRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: `tags/${version}`, + }); + } + + await removePackageVersions(`${process.env.BASE_IMAGE_NAME}-api`, outdatedVersions); diff --git a/.gitignore b/.gitignore index 8fabc93..154dd07 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,10 @@ /venv +/.venv /.env* + +__pycache__ + +# JetBrains IDE configurations +.idea/ + +/tmp \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..640411c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,9 @@ +FROM python:3.12-alpine3.19 + +RUN apk add --quiet --no-cache ghostscript + +WORKDIR /app +COPY . . + +RUN pip install --root-user-action=ignore -r requirements.txt --quiet +ENTRYPOINT ["fastapi", "run", "api.py"] \ No newline at end of file diff --git a/README.md b/README.md index f64ad6d..f00629f 100644 --- a/README.md +++ b/README.md @@ -1,83 +1,139 @@ -# Swissgeol.ch OCR service +# swissgeol.ch OCR service -Source code for the OCR scripts that are used at Swiss [ -Federal Office of Topography swisstopo](https://www.swisstopo.admin.ch/) for digitising geological documents for internal use as well as for publication on the [swissgeol.ch](https://www.swissgeol.ch/) platform. +Source code for the OCR scripts that are used at the Swiss [ +Federal Office of Topography swisstopo](https://www.swisstopo.admin.ch/) +for digitising geological documents for internal use as well as for publication on +the [swissgeol.ch](https://www.swissgeol.ch/) platform. -The script `main.py` processes PDF files, calls the [AWS Textract](https://aws.amazon.com/de/textract/) service for each page to apply OCR, and uses the PyMuPDF and Reportlab libraries to put the detected text into the PDF document (enabling selecting and searching for text in any PDF viewer). +OCR processing is supported both in script form and as REST API. +To process PDF files, the [AWS Textract](https://aws.amazon.com/de/textract/) service is called for each page. +The detected text is then put into the PDF document by using the PyMuPDF and Reportlab libraries. +This enables selecting and searching for text in any PDF viewer. -The resulting functionality is similar to the [OCRmyPDF](https://ocrmypdf.readthedocs.io/en/latest/) software, but with AWS Textract as the underlying OCR model instead of [Tesseract](https://tesseract-ocr.github.io/). Tesseract is open-source while AWS Textract is a commercial API. However, AWS Textract is more scalable and gives better quality results on our inputs, which is more important for our use cases. +The resulting functionality is similar to the [OCRmyPDF](https://ocrmypdf.readthedocs.io/en/latest/) software, +but with AWS Textract as the underlying OCR model instead of [Tesseract](https://tesseract-ocr.github.io/). +Tesseract is open-source while AWS Textract is a commercial API. +However, AWS Textract is more scalable and gives better quality results on our inputs, +which is more important for our use cases. Additional features: + - If necessary, PDF pages rescaled, and images are cropped and/or converted from JPX to JPG. - PDF pages that are already "digitally born" are detected, and can be skipped when applying OCR. -- When a scanned PDF page already contains digital text from an older OCR run, this text can be removed, and the OCR can be re-applied. -- Pages with large dimensions are cut into smaller sections, that are sent separately to the AWS Textract service in multiple requests. Indeed, AWS Textract has certain [limits on file size and page dimensions](https://docs.aws.amazon.com/textract/latest/dg/limits-document.html), and even within those limits, the quality of the results is better when the input dimensions are smaller. - -### Roadmap - -- Allow deploying this OCR script as a microservice (adding an API, logging and monitoring, configurability. etc.), that can be integrated into the applications [assets.swissgeol.ch](https://assets.swissgeol.ch/) ([Github Repo](https://github.com/swisstopo/swissgeol-assets-suite)) and [boreholes.swissgeol.ch](https://boreholes.swissgeol.ch/) ([Github Repo](https://github.com/swisstopo/swissgeol-boreholes-suite)). +- When a scanned PDF page already contains digital text from an older OCR run, this text can be removed, and the OCR can + be re-applied. +- Pages with large dimensions are cut into smaller sections, that are sent separately to the AWS Textract service in + multiple requests. Indeed, AWS Textract has + certain [limits on file size and page dimensions](https://docs.aws.amazon.com/textract/latest/dg/limits-document.html), + and even within those limits, the quality of the results is better when the input dimensions are smaller. ## Installation Example using a virtual environment and `pip install`: + ``` python -m venv venv source venv/bin/activate pip install -r requirements.txt ``` -## Configuration examples +## Usage +The script can be executed like any normal Python script file: +```bash +python main.py +``` + +The API is built on [FastAPI](https://fastapi.tiangolo.com/) and can be run by its CLI: +```bash +fastapi run api.py +``` + +## Configuration Environment variables are read from the file `.env`. -If an environment variable `OCR_PROFILE` is specified, then environment variables are additionally read from `.env.{OCR_PROFILE}`, with the values from this file potentially overriding the values from `.env`. +If an environment variable `OCR_PROFILE` is specified, then environment variables are additionally read +from `.env.{OCR_PROFILE}`, with the values from this file potentially overriding the values from `.env`. For example, run the script as `OCR_PROFILE=assets python -m main` to use the environment variables from `.env.assets`. -### AWS credentials +> The API and Script require different configurations. +> Please ensure that you are using the correct environment variables depending on what you want to execute. + +### Script Configuration -AWS credentials can be provided using a [credentials file](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) (`~/.aws/credentials`). The environment variable `AWS_TEXTRACT_PROFILE` in the configuration examples below refers to a profile in that file. +AWS credentials can be provided using +a [credentials file](https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html) (`~/.aws/credentials`). The +environment variable `AWS_TEXTRACT_PROFILE` in the configuration examples below refers to a profile in that file. -### `.env.assets` +#### `.env.assets` - Reads and writes asset files directly from/to S3. - Applies OCR more defensively, only to pages without pre-existing visible digital text. -- Uses a higher confidence threshold (0.7), because for publication on assets.swissgeol.ch, we'd rather not put any OCR'd text in the document at all, rather than putting nonsense in the document. +- Uses a higher confidence threshold (0.7), because for publication on assets.swissgeol.ch, we'd rather not put any OCR' + d text in the document at all, rather than putting nonsense in the document. -``` +```sh AWS_TEXTRACT_PROFILE=default -OCR_INPUT_TYPE=S3 -OCR_INPUT_AWS_PROFILE=s3-assets -OCR_INPUT_S3_BUCKET=swissgeol-assets-swisstopo -OCR_INPUT_S3_PREFIX=asset/asset_files/ -OCR_INPUT_CLEANUP_TMP_FILES=TRUE -OCR_INPUT_IGNORE_EXISTING=TRUE +INPUT_TYPE=S3 +INPUT_AWS_PROFILE=s3-assets +INPUT_S3_BUCKET=swissgeol-assets-swisstopo +INPUT_S3_PREFIX=asset/asset_files/ +INPUT_IGNORE_EXISTING=TRUE -OCR_OUTPUT_TYPE=S3 -OCR_OUTPUT_AWS_PROFILE=s3-assets -OCR_OUTPUT_S3_BUCKET=swissgeol-assets-swisstopo -OCR_OUTPUT_S3_PREFIX=asset/asset_files_new_ocr/ -OCR_OUTPUT_CLEANUP_TMP_FILES=TRUE +OUTPUT_TYPE=S3 +OUTPUT_AWS_PROFILE=s3-assets +OUTPUT_S3_BUCKET=swissgeol-assets-swisstopo +OUTPUT_S3_PREFIX=asset/asset_files_new_ocr/ CONFIDENCE_THRESHOLD=0.7 +CLEANUP_TMP_FILES=TRUE ``` -### `env.boreholes` +#### `env.boreholes` - Read and writes files from/to a local directory. -- Applies OCR more aggressively, also e.g. to images inside digitally-born PDF documents, as long as the newly detected text does not overlap with any pre-existing digital text. -- Uses a lower confidence threshold (0.45), as especially for extracting stratigraphy data, it is better to know all places where some text is located in the document, even when we are not so sure how to actually read the text. +- Applies OCR more aggressively, also e.g. to images inside digitally-born PDF documents, as long as the newly detected + text does not overlap with any pre-existing digital text. +- Uses a lower confidence threshold (0.45), as especially for extracting stratigraphy data, it is better to know all + places where some text is located in the document, even when we are not so sure how to actually read the text. -``` +```sh AWS_TEXTRACT_PROFILE=default -OCR_INPUT_TYPE=path -OCR_INPUT_PATH=/home/stijn/bohrprofile-zurich/ +INPUT_TYPE=path +INPUT_PATH=/home/stijn/bohrprofile-zurich/ -OCR_OUTPUT_TYPE=path -OCR_OUTPUT_PATH=/home/stijn/bohrprofile-zurich/ocr/ +OUTPUT_TYPE=path +OUTPUT_PATH=/home/stijn/bohrprofile-zurich/ocr/ CONFIDENCE_THRESHOLD=0.45 -OCR_STRATEGY_AGGRESSIVE=TRUE +USE_AGGRESSIVE_STRATEGY=TRUE ``` + +### API Configuration + +```sh +# The directory at which temporary files are to be stored. +TMP_PATH=tmp/ + +# The local AWS profile that will be used to access Textract. +# +# If left empty, the credentials will be read from the environment. +# This allows the use of service accounts when deploying to K8s. +AWS_PROFILE=swisstopo-ngm + +# Alternatives to `AWS_PROFILE` to allow you to specify the access keys directly. +# AWS_ACCESS_KEY= +# AWS_SECRET_ACCESS_KEY= + +S3_INPUT_BUCKET=swissgeol-assets-swisstopo +S3_INPUT_FOLDER=asset_files/ + +S3_OUTPUT_BUCKET=swissgeol-assets-swisstopo +S3_OUTPUT_FOLDER=new_ocr_output/ + +CONFIDENCE_THRESHOLD=0.7 +``` + diff --git a/api.http b/api.http new file mode 100644 index 0000000..49e22e0 --- /dev/null +++ b/api.http @@ -0,0 +1,18 @@ +@file = 4178.pdf +// @file = 10000.pdf + +### Start File Processing +POST http://localhost:8000 +Content-Type: application/json + +{ + "file": "{{file}}" +} + +### Collect File Result +POST http://localhost:8000/collect +Content-Type: application/json + +{ + "file": "{{file}}" +} diff --git a/api.py b/api.py new file mode 100644 index 0000000..397e0b3 --- /dev/null +++ b/api.py @@ -0,0 +1,133 @@ +import logging +import os +import shutil +import uuid +from typing import Annotated + +from fastapi import FastAPI, Depends, status, HTTPException, BackgroundTasks, Response +from pydantic import BaseModel, Field +from starlette.responses import JSONResponse + +import ocr +from aws import aws +from utils import task +from utils.settings import ApiSettings, api_settings + +app = FastAPI() + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + + +class StartPayload(BaseModel): + file: str = Field(min_length=1) + + +if api_settings().skip_processing: + logging.warning("SKIP_PROCESSING is active, files will always be marked as completed without being proceed") + + +@app.post("/") +def start( + payload: StartPayload, + settings: Annotated[ApiSettings, Depends(api_settings)], + background_tasks: BackgroundTasks, +): + if not payload.file.endswith('.pdf'): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"message": "input must be a PDF file"} + ) + + aws_client = aws.connect(settings) + has_file = aws_client.exists_file( + settings.s3_input_bucket, + f'{settings.s3_input_folder}{payload.file}', + ) + if not has_file: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "file does not exist"} + ) + + task.start(payload.file, background_tasks, lambda: process(payload, aws_client, settings)) + return Response(status_code=status.HTTP_204_NO_CONTENT) + + +class CollectPayload(BaseModel): + file: str = Field(min_length=1) + + +@app.post("/collect") +def collect( + payload: CollectPayload, +): + result = task.collect_result(payload.file) + if result is None and not task.has_task(payload.file): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "OCR is not running for this file"} + ) + + has_finished = result is not None + if not has_finished: + logging.info(f"Processing of '{payload.file}' has not yet finished.") + return JSONResponse(status_code=status.HTTP_200_OK, content={ + "has_finished": False, + "data": None, + }) + + if result.ok: + logging.info(f"Processing of '{payload.file}' has been successful.") + return JSONResponse(status_code=status.HTTP_200_OK, content={ + "has_finished": True, + "data": result.value, + }) + + logging.info(f"Processing of '{payload.file}' has failed.") + return JSONResponse(status_code=status.HTTP_200_OK, content={ + "has_finished": True, + "error": "Internal Server Error", + }) + + +def process( + payload: StartPayload, + aws_client: aws.Client, + settings: Annotated[ApiSettings, Depends(api_settings)], +): + if settings.skip_processing: + # Sleep between 30 seconds to 2 minutes to simulate processing time. + # sleep(randint(30, 120)) + return + + task_id = f"{uuid.uuid4()}" + tmp_dir = os.path.join(settings.tmp_path, task_id) + os.makedirs(tmp_dir, exist_ok=True) + + input_path = os.path.join(tmp_dir, "input.pdf") + output_path = os.path.join(tmp_dir, "output.pdf") + + aws.load_file( + aws_client.bucket(settings.s3_input_bucket), + f'{settings.s3_input_folder}{payload.file}', + input_path, + ) + + ocr.process( + input_path, + output_path, + tmp_dir, + aws_client.textract, + settings.confidence_threshold, + settings.use_aggressive_strategy, + ) + + aws.store_file( + aws_client.bucket(settings.s3_output_bucket), + f'{settings.s3_output_folder}{payload.file}', + output_path, + ) + + shutil.rmtree(tmp_dir) + return () diff --git a/util/__init__.py b/aws/__init__.py similarity index 100% rename from util/__init__.py rename to aws/__init__.py diff --git a/aws/aws.py b/aws/aws.py new file mode 100644 index 0000000..6b40ca4 --- /dev/null +++ b/aws/aws.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +import boto3 +from botocore.exceptions import ClientError +from mypy_boto3_s3 import S3ServiceResource +from mypy_boto3_s3.service_resource import Bucket +from mypy_boto3_textract import TextractClient as Textractor + +from utils.settings import ApiSettings + +type S3Bucket = any + + +@dataclass +class Client: + s3: S3ServiceResource + textract: Textractor + + def bucket(self, name: str) -> Bucket: + return self.s3.Bucket(name) + + def exists_file(self, bucket_name: str, key: str) -> bool: + try: + self.s3.Object(bucket_name, key).load() + return True + except ClientError as e: + if e.response['Error']['Code'] == '404': + return False + else: + raise e + + +def connect(settings: ApiSettings) -> Client: + has_profile = is_set(settings.aws_profile) + + if has_profile: + session = open_session_by_profile(settings.aws_profile) + else: + session = open_session_by_service_role() + + return Client( + s3=session.resource('s3'), + textract=session.client('textract') + ) + + +def is_set(value: str | None) -> bool: + return value is not None and len(value) > 0 + + +def open_session_by_profile(profile: str) -> boto3.Session: + return boto3.Session(profile_name=profile) + + +def open_session_by_service_role() -> boto3.Session: + return boto3.Session() + + +def load_file(bucket: Bucket, key: str, local_path: str): + bucket.download_file(key, local_path) + + +def store_file(bucket: Bucket, key: str, local_path: str): + bucket.upload_file(local_path, key) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..92ee889 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,10 @@ +services: + api: + image: swissgeol-ocr/api:dev + build: + context: . + volumes: + - .env:.env + - .env.assets:.env.assets + ports: + - "8000:8000" \ No newline at end of file diff --git a/main.py b/main.py index 25c7f21..45111d2 100644 --- a/main.py +++ b/main.py @@ -1,153 +1,91 @@ -import boto3 -from textractor import Textractor - -from util.source import S3AssetSource, FileAssetSource -from util.target import S3AssetTarget, FileAssetTarget, AssetTarget -from util.util import process_page, clean_old_ocr, new_ocr_needed, draw_ocr_text_page, clean_old_ocr_aggressive -from util.crop import crop_images -from util.resize import resize_page -from pathlib import Path -from dotenv import dotenv_values -import os -import fitz -import subprocess -import sys - - -def load_target(config): - if config["OCR_OUTPUT_TYPE"] == "S3": - s3_session = boto3.Session(profile_name=config["OCR_OUTPUT_AWS_PROFILE"]) - s3 = s3_session.resource('s3') - return S3AssetTarget( - s3_bucket=s3.Bucket(config["OCR_OUTPUT_S3_BUCKET"]), - s3_prefix=config["OCR_OUTPUT_S3_PREFIX"], - output_path_fn=lambda filename: Path(sys.path[0], "tmp", "new_" + filename), - do_cleanup=(config["OCR_OUTPUT_CLEANUP_TMP_FILES"] == "TRUE") - ) - elif config["OCR_OUTPUT_TYPE"] == "path": - return FileAssetTarget( - out_path=Path(config["OCR_OUTPUT_PATH"]) - ) - else: - print("No OCR_OUTPUT_TYPE specified.") - sys.exit(1) - - -def load_source(config, target: AssetTarget): - if config["OCR_INPUT_IGNORE_EXISTING"] == "TRUE": - ignore_filenames = target.existing_filenames() - print("Found {} existing objects in output path.".format(len(ignore_filenames))) - else: - ignore_filenames = [] - - if config["OCR_INPUT_TYPE"] == "S3": - s3_session = boto3.Session(profile_name=config["OCR_INPUT_AWS_PROFILE"]) - s3 = s3_session.resource('s3') - - return S3AssetSource( - s3_bucket=s3.Bucket(config["OCR_INPUT_S3_BUCKET"]), - s3_prefix=config["OCR_INPUT_S3_PREFIX"], - allow_override=False, - input_path_fn=lambda filename: Path(sys.path[0], "tmp", filename), - do_cleanup=(config["OCR_INPUT_CLEANUP_TMP_FILES"] == "TRUE"), - ignore_filenames=ignore_filenames - ) - elif config["OCR_INPUT_TYPE"] == "path": - return FileAssetSource( - in_path=Path(config["OCR_INPUT_PATH"]), - ignore_filenames=ignore_filenames - ) - else: - print("No OCR_INPUT_TYPE specified.") - sys.exit(1) - - -def process(filename, in_path, out_path, extractor, confidence_threshold, aggressive_strategy): - - in_doc = fitz.open(in_path) - out_doc = fitz.open(in_path) - - in_page_count = in_doc.page_count - print(f"{in_page_count} pages") - for page_index, new_page in enumerate(out_doc): - page_number = page_index + 1 - print(f"Page {page_number}") - - new_page = resize_page(in_doc, out_doc, page_index) - crop_images(new_page, out_doc) - if aggressive_strategy: - ignore_rects = clean_old_ocr_aggressive(new_page) - else: - if new_ocr_needed(new_page): - clean_old_ocr(new_page) - ignore_rects = [] - else: - continue - tmp_path_prefix = os.path.join(sys.path[0], "tmp", "{}_page{}".format(filename, page_number)) - text_layer_path = os.path.join(sys.path[0], "tmp", "{}_page{}.pdf".format(filename, page_number)) - lines_to_draw = process_page(out_doc, new_page, extractor, tmp_path_prefix, confidence_threshold, ignore_rects) - draw_ocr_text_page(new_page, text_layer_path, lines_to_draw) - out_doc.save(out_path, garbage=3, deflate=True) - - # Verify that we can read the written document, and that it still has the same number of pages. Some corrupt input - # documents might lead to an empty or to a corrupt output document, sometimes even without throwing an error. (See - # LGD-283.) This check should detect such cases. - doc = fitz.open(out_path) - out_page_count = doc.page_count - if in_page_count != out_page_count: - raise ValueError( - "Output document contains {} pages instead of {}".format(out_page_count, in_page_count) - ) - - -def main(): - if 'OCR_PROFILE' in os.environ: - print(f"Loading env variables from .env and .env.{os.environ['OCR_PROFILE']}.") - config = { - **dotenv_values(".env"), - **dotenv_values(f".env.{os.environ['OCR_PROFILE']}"), - } - else: - print(f"Loading env variables from .env.") - config = dotenv_values(".env") - - extractor = Textractor(profile_name=config["AWS_TEXTRACT_PROFILE"]) - confidence_threshold = float(config["CONFIDENCE_THRESHOLD"]) - aggressive_strategy = config["OCR_STRATEGY_AGGRESSIVE"] == "TRUE" - print(f"Using confidence threshold {confidence_threshold} and aggressive strategy {aggressive_strategy}.") - - target = load_target(config) - source = load_source(config, target) - - for asset_item in source.iterator(): - asset_item.load() - out_path = target.local_path(asset_item) - - print() - print(asset_item.filename) - try: - process(asset_item.filename, asset_item.local_path, out_path, extractor, confidence_threshold, aggressive_strategy) - except ValueError as e: - gs_preprocess_path = os.path.join(sys.path[0], "tmp", "gs_pre_" + asset_item.filename) - print("Encountered ValueError: {}. Trying Ghostscript preprocessing.".format(e)) - subprocess.call([ - "ghostscript", - "-sDEVICE=pdfwrite", - "-dCompatibilityLevel=1.4", - "-dPDFSETTINGS=/default", - "-dNOPAUSE", - "-dQUIET", - "-dBATCH", - "-sOutputFile={}".format(gs_preprocess_path), - asset_item.local_path - ]) - process(asset_item.filename, gs_preprocess_path, out_path, extractor, confidence_threshold, aggressive_strategy) - os.remove(gs_preprocess_path) - - asset_item.cleanup() - target.save(asset_item) - target.cleanup(asset_item) - - -if __name__ == '__main__': - main() +import os +import shutil +import sys +from pathlib import Path + +import boto3 +from textractor import Textractor + +import ocr +from ocr.source import S3AssetSource, FileAssetSource +from ocr.target import S3AssetTarget, FileAssetTarget, AssetTarget +from utils.settings import script_settings, ScriptSettings + + +def load_target(settings: ScriptSettings): + if settings.output_type == 's3': + s3_session = boto3.Session(profile_name=settings.output_aws_profile) + s3 = s3_session.resource('s3') + return S3AssetTarget( + s3_bucket=s3.Bucket(settings.output_s3_bucket), + s3_prefix=settings.output_s3_prefix, + output_path_fn=lambda filename: Path(sys.path[0], "tmp", "new_" + filename), + ) + elif settings.output_type == 'path': + return FileAssetTarget( + out_path=Path(settings.output_path) + ) + else: + print("No output type specified.") + sys.exit(1) + + +def load_source(settings: ScriptSettings, target: AssetTarget): + if settings.input_type == 's3': + ignore_filenames = target.existing_filenames() + print("Found {} existing objects in output path.".format(len(ignore_filenames))) + else: + ignore_filenames = [] + + if settings.input_type == "s3": + s3_session = boto3.Session(profile_name=settings.input_aws_profile) + s3 = s3_session.resource('s3') + + return S3AssetSource( + s3_bucket=s3.Bucket(settings.input_s3_bucket), + s3_prefix=settings.input_s3_prefix, + allow_override=False, + input_path_fn=lambda filename: Path(sys.path[0], "tmp", filename), + ignore_filenames=ignore_filenames + ) + elif settings.input_type == "path": + return FileAssetSource( + in_path=Path(settings.input_path), + ignore_filenames=ignore_filenames + ) + else: + print("No input type specified.") + sys.exit(1) + + +def main(): + settings = script_settings() + extractor = Textractor(profile_name=settings.textract_aws_profile) + + target = load_target(settings) + source = load_source(settings, target) + + for asset_item in source.iterator(): + asset_item.load() + out_path = target.local_path(asset_item) + tmp_dir = os.path.join(settings.tmp_path, asset_item.filename) + + print() + print(asset_item.filename) + ocr.process( + str(asset_item.local_path), + str(out_path), + tmp_dir, + extractor.textract_client, + settings.confidence_threshold, + settings.use_aggressive_strategy, + ) + + target.save(asset_item) + + if settings.cleanup_tmp_files: + shutil.rmtree(tmp_dir) + + +if __name__ == '__main__': + main() diff --git a/ocr/__init__.py b/ocr/__init__.py new file mode 100644 index 0000000..cd468c3 --- /dev/null +++ b/ocr/__init__.py @@ -0,0 +1,110 @@ +import os +import subprocess + +import fitz +from mypy_boto3_textract import TextractClient as Textractor +from pymupdf.mupdf import PDF_ENCRYPT_KEEP + +from ocr.crop import crop_images +from ocr.resize import resize_page +from ocr.util import process_page, clean_old_ocr, new_ocr_needed, draw_ocr_text_page, clean_old_ocr_aggressive + + +def process( + input_path: str, + output_path: str, + tmp_dir: str, + textractor: Textractor, + confidence_threshold: float, + use_aggressive_strategy: bool, +): + try: + process_pdf( + input_path, + output_path, + tmp_dir, + textractor, + confidence_threshold, + use_aggressive_strategy + ) + except ValueError as e: + gs_preprocess_path = os.path.join(tmp_dir, "gs.pdf") + print(f"Encountered ValueError: {e}. Trying Ghostscript preprocessing.") + subprocess.call([ + "gs", + "-sDEVICE=pdfwrite", + "-dCompatibilityLevel=1.4", + "-dPDFSETTINGS=/default", + "-dNOPAUSE", + "-dQUIET", + "-dBATCH", + "-sOutputFile={}".format(gs_preprocess_path), + input_path, + ]) + process_pdf( + gs_preprocess_path, + output_path, + tmp_dir, + textractor, + confidence_threshold, + use_aggressive_strategy, + ) + + +def process_pdf( + in_path: str, + out_path: str, + tmp_dir: str, + textractor: Textractor, + confidence_threshold: float, + use_aggressive_strategy: bool, +): + tmp_out_path = os.path.join(tmp_dir, f"output.pdf") + + in_doc = fitz.open(in_path) + out_doc = fitz.open(in_path) + + os.makedirs(tmp_dir, exist_ok=True) + + in_page_count = in_doc.page_count + print(f"{in_page_count} pages") + + out_doc.save(tmp_out_path, garbage=3, deflate=True) + out_doc.close() + out_doc = fitz.open(tmp_out_path) + for page_index, new_page in enumerate(iter(in_doc)): + page_number = page_index + 1 + print(f"Page {page_number}") + + new_page = resize_page(in_doc, out_doc, page_index) + crop_images(new_page, out_doc) + if use_aggressive_strategy: + ignore_rects = clean_old_ocr_aggressive(new_page) + else: + if new_ocr_needed(new_page): + clean_old_ocr(new_page) + ignore_rects = [] + else: + continue + tmp_path_prefix = os.path.join(tmp_dir, f"page{page_number}") + text_layer_path = os.path.join(tmp_dir, f"page{page_number}.pdf") + lines_to_draw = process_page(out_doc, new_page, textractor, tmp_path_prefix, confidence_threshold, ignore_rects) + draw_ocr_text_page(new_page, text_layer_path, lines_to_draw) + out_doc.save(tmp_out_path, incremental=True, encryption=PDF_ENCRYPT_KEEP) + + out_doc.close() + out_doc = fitz.open(tmp_out_path) + out_doc.save(out_path, garbage=3, deflate=True) + in_doc.close() + out_doc.close() + + # Verify that we can read the written document, and that it still has the same number of pages. Some corrupt input + # documents might lead to an empty or to a corrupt output document, sometimes even without throwing an error. (See + # LGD-283.) This check should detect such cases. + doc = fitz.open(out_path) + out_page_count = doc.page_count + if in_page_count != out_page_count: + raise ValueError( + "Output document contains {} pages instead of {}".format(out_page_count, in_page_count) + ) + doc.close() diff --git a/util/applyocr.py b/ocr/applyocr.py similarity index 96% rename from util/applyocr.py rename to ocr/applyocr.py index f8f7f3f..a0d5fb4 100644 --- a/util/applyocr.py +++ b/ocr/applyocr.py @@ -1,7 +1,7 @@ import fitz -from util.readingorder import sort_lines, TextLine -from util.textract import combine_text_lines, textract, clip_rects, MAX_DIMENSION_POINTS -from textractor import Textractor +from ocr.readingorder import sort_lines, TextLine +from ocr.textract import combine_text_lines, textract, clip_rects, MAX_DIMENSION_POINTS +from mypy_boto3_textract import TextractClient as Textractor from uuid import uuid4 diff --git a/util/crop.py b/ocr/crop.py similarity index 100% rename from util/crop.py rename to ocr/crop.py diff --git a/util/ocrresult.py b/ocr/ocrresult.py similarity index 100% rename from util/ocrresult.py rename to ocr/ocrresult.py diff --git a/util/readingorder.py b/ocr/readingorder.py similarity index 100% rename from util/readingorder.py rename to ocr/readingorder.py diff --git a/util/resize.py b/ocr/resize.py similarity index 100% rename from util/resize.py rename to ocr/resize.py diff --git a/util/source.py b/ocr/source.py similarity index 84% rename from util/source.py rename to ocr/source.py index 7559b92..ef4a0ea 100644 --- a/util/source.py +++ b/ocr/source.py @@ -1,106 +1,89 @@ -import os -from collections.abc import Iterator -from abc import abstractmethod -from dataclasses import dataclass -from typing import Callable -from pathlib import Path - - -class AssetItem: - local_path: Path - filename: str - - @abstractmethod - def load(self): - pass - - @abstractmethod - def cleanup(self): - pass - - -@dataclass -class FileAssetItem(AssetItem): - local_path: Path - filename: str - - def load(self): - pass - - def cleanup(self): - pass - - -@dataclass -class S3AssetItem(AssetItem): - s3_bucket: any - s3_key: str - local_path: Path - allow_override: bool - do_cleanup: bool - - def __post_init__(self): - self.filename = S3AssetItem.key_to_filename(self.s3_key) - - def load(self): - if self.allow_override or not os.path.exists(self.local_path): - self.s3_bucket.download_file(self.s3_key, self.local_path) - - def cleanup(self): - if self.do_cleanup: - os.remove(self.local_path) - - @staticmethod - def key_to_filename(key): - return key.split("/")[-1] - - -class AssetSource: - @abstractmethod - def iterator(self) -> Iterator[AssetItem]: - pass - - -@dataclass -class FileAssetSource(AssetSource): - in_path: Path - ignore_filenames: set[str] - - def iterator(self) -> Iterator[AssetItem]: - if self.in_path.is_dir(): - return ( - FileAssetItem(local_path=path, filename=os.path.basename(path)) - for path in sorted(self.in_path.glob("*")) - if os.path.basename(path).endswith(".pdf") and os.path.basename(path) not in self.ignore_filenames - ) - else: - return [ - FileAssetItem(local_path=self.in_path, filename=os.path.basename(self.in_path)) - ] - - -@dataclass -class S3AssetSource(AssetSource): - s3_bucket: any - s3_prefix: str - input_path_fn: Callable[[str], Path] - allow_override: bool - do_cleanup: bool - ignore_filenames: set[str] - - def iterator(self) -> Iterator[AssetItem]: - objs = list(self.s3_bucket.objects.filter(Prefix=self.s3_prefix)) - - return ( - S3AssetItem( - s3_bucket=self.s3_bucket, - s3_key=obj.key, - local_path=self.input_path_fn(S3AssetItem.key_to_filename(obj.key)), - allow_override=self.allow_override, - do_cleanup=self.do_cleanup - ) - for obj in objs - if obj.size - if obj.key.lower().endswith(".pdf") - if S3AssetItem.key_to_filename(obj.key) not in self.ignore_filenames - ) +import os +from collections.abc import Iterator +from abc import abstractmethod +from dataclasses import dataclass +from typing import Callable +from pathlib import Path + + +class AssetItem: + local_path: Path + filename: str + + @abstractmethod + def load(self): + pass + + +@dataclass +class FileAssetItem(AssetItem): + local_path: Path + filename: str + + def load(self): + pass + +@dataclass +class S3AssetItem(AssetItem): + s3_bucket: any + s3_key: str + local_path: Path + allow_override: bool + + def __post_init__(self): + self.filename = S3AssetItem.key_to_filename(self.s3_key) + + def load(self): + if self.allow_override or not os.path.exists(self.local_path): + self.s3_bucket.download_file(self.s3_key, self.local_path) + + @staticmethod + def key_to_filename(key): + return key.split("/")[-1] + + +class AssetSource: + @abstractmethod + def iterator(self) -> Iterator[AssetItem]: + pass + + +@dataclass +class FileAssetSource(AssetSource): + in_path: Path + ignore_filenames: set[str] + + def iterator(self) -> Iterator[AssetItem]: + if self.in_path.is_dir(): + return ( + FileAssetItem(local_path=path, filename=os.path.basename(path)) + for path in sorted(self.in_path.glob("*")) + if os.path.basename(path).endswith(".pdf") and os.path.basename(path) not in self.ignore_filenames + ) + else: + return iter((FileAssetItem(local_path=self.in_path, filename=os.path.basename(self.in_path)),)) + + +@dataclass +class S3AssetSource(AssetSource): + s3_bucket: any + s3_prefix: str + input_path_fn: Callable[[str], Path] + allow_override: bool + ignore_filenames: set[str] + + def iterator(self) -> Iterator[AssetItem]: + objs = list(self.s3_bucket.objects.filter(Prefix=self.s3_prefix)) + + return ( + S3AssetItem( + s3_bucket=self.s3_bucket, + s3_key=obj.key, + local_path=self.input_path_fn(S3AssetItem.key_to_filename(obj.key)), + allow_override=self.allow_override, + ) + for obj in objs + if obj.size + if obj.key.lower().endswith(".pdf") + if S3AssetItem.key_to_filename(obj.key) not in self.ignore_filenames + ) diff --git a/util/target.py b/ocr/target.py similarity index 81% rename from util/target.py rename to ocr/target.py index 1d25513..0b2c036 100644 --- a/util/target.py +++ b/ocr/target.py @@ -1,70 +1,57 @@ -import os -from abc import abstractmethod -from dataclasses import dataclass -from typing import Callable -from pathlib import Path - -from util.source import AssetItem, S3AssetItem - - -class AssetTarget: - @abstractmethod - def save(self, item: AssetItem): - pass - - @abstractmethod - def local_path(self, item: AssetItem) -> Path: - pass - - @abstractmethod - def cleanup(self, item: AssetItem): - pass - - @abstractmethod - def existing_filenames(self) -> set[str]: - pass - - -@dataclass -class FileAssetTarget(AssetTarget): - out_path: Path - - def save(self, item: AssetItem): - pass - - def local_path(self, item: AssetItem) -> Path: - return Path(self.out_path, item.filename) - - def cleanup(self, item: AssetItem): - pass - - def existing_filenames(self) -> set[str]: - return { - os.path.basename(path) - for path in sorted(self.out_path.glob("*")) - } - - -@dataclass -class S3AssetTarget(AssetTarget): - s3_bucket: any - s3_prefix: str - output_path_fn: Callable[[str], Path] - do_cleanup: bool - - def save(self, item: AssetItem): - self.s3_bucket.upload_file(self.local_path(item), self.s3_prefix + item.filename) - - def local_path(self, item: AssetItem) -> Path: - return self.output_path_fn(item.filename) - - def cleanup(self, item: AssetItem): - if self.do_cleanup: - os.remove(self.local_path(item)) - - def existing_filenames(self) -> set[str]: - return { - S3AssetItem.key_to_filename(obj.key) - for obj in self.s3_bucket.objects.filter(Prefix=self.s3_prefix) - } - +import os +from abc import abstractmethod +from dataclasses import dataclass +from typing import Callable +from pathlib import Path + +from ocr.source import AssetItem, S3AssetItem + + +class AssetTarget: + @abstractmethod + def save(self, item: AssetItem): + pass + + @abstractmethod + def local_path(self, item: AssetItem) -> Path: + pass + + @abstractmethod + def existing_filenames(self) -> set[str]: + pass + + +@dataclass +class FileAssetTarget(AssetTarget): + out_path: Path + + def save(self, item: AssetItem): + pass + + def local_path(self, item: AssetItem) -> Path: + return Path(self.out_path, item.filename) + + def existing_filenames(self) -> set[str]: + return { + os.path.basename(path) + for path in sorted(self.out_path.glob("*")) + } + + +@dataclass +class S3AssetTarget(AssetTarget): + s3_bucket: any + s3_prefix: str + output_path_fn: Callable[[str], Path] + + def save(self, item: AssetItem): + self.s3_bucket.upload_file(self.local_path(item), self.s3_prefix + item.filename) + + def local_path(self, item: AssetItem) -> Path: + return self.output_path_fn(item.filename) + + def existing_filenames(self) -> set[str]: + return { + S3AssetItem.key_to_filename(obj.key) + for obj in self.s3_bucket.objects.filter(Prefix=self.s3_prefix) + } diff --git a/util/textract.py b/ocr/textract.py similarity index 95% rename from util/textract.py rename to ocr/textract.py index 9b5f964..7cddbde 100644 --- a/util/textract.py +++ b/ocr/textract.py @@ -4,14 +4,15 @@ import os import backoff from botocore.exceptions import ClientError -from textractor import Textractor +from mypy_boto3_textract import TextractClient as Textractor from trp.t_pipeline import add_page_orientation import trp.trp2 as t2 import trp as t1 import textractcaller.t_call as t_call import statistics -from util.readingorder import TextLine + +from ocr.readingorder import TextLine MAX_DIMENSION_POINTS = 2000 @@ -70,7 +71,6 @@ def textract(doc: fitz.Document, extractor: Textractor, tmp_file_path: str, clip page.set_cropbox(old_cropbox.intersect(page.mediabox)) document = call_textract(extractor, tmp_file_path) - os.remove(tmp_file_path) if document is None: return [] @@ -87,16 +87,17 @@ def backoff_hdlr(details): @backoff.on_exception(backoff.expo, ClientError, on_backoff=backoff_hdlr, - base=2) + base=2, + max_tries=3) def call_textract(extractor: Textractor, tmp_file_path: str) -> t1.Document | None: try: j = t_call.call_textract( input_document=tmp_file_path, - boto3_textract_client=extractor.textract_client, + boto3_textract_client=extractor, call_mode=t_call.Textract_Call_Mode.FORCE_SYNC ) t_document: t2.TDocument = t2.TDocumentSchema().load(j) - except extractor.textract_client.exceptions.InvalidParameterException: + except extractor.exceptions.InvalidParameterException: print("Encountered InvalidParameterException from Textract. Page might require more than 10MB memory. Skipping page.") return None diff --git a/util/util.py b/ocr/util.py similarity index 95% rename from util/util.py rename to ocr/util.py index a61eaf7..b8708ea 100644 --- a/util/util.py +++ b/ocr/util.py @@ -1,14 +1,11 @@ -import os - import fitz - -from reportlab.pdfgen import canvas +from mypy_boto3_textract import TextractClient as Textractor from reportlab.pdfbase import pdfmetrics +from reportlab.pdfgen import canvas from reportlab.pdfgen.textobject import PDFTextObject -from textractor import Textractor -from util.readingorder import TextLine, TextWord -from util.applyocr import OCR +from ocr.applyocr import OCR +from ocr.readingorder import TextLine, TextWord def process_page( @@ -163,15 +160,13 @@ def draw_ocr_text_page( c.showPage() c.save() - text_layer_doc = fitz.open(text_layer_path) - original_rotation = page.rotation - page.set_rotation(0) - page.show_pdf_page(page.rect, text_layer_doc, rotate=original_rotation) - page.set_rotation(original_rotation) - os.remove(text_layer_path) + with fitz.open(text_layer_path) as text_layer_doc: + original_rotation = page.rotation + page.set_rotation(0) + page.show_pdf_page(page.rect, text_layer_doc, rotate=original_rotation) + page.set_rotation(original_rotation) return - def new_ocr_needed(page: fitz.Page) -> bool: bboxes = page.get_bboxlog() @@ -180,6 +175,7 @@ def new_ocr_needed(page: fitz.Page) -> bool: if (boxType == "fill-text" or boxType == "stroke-text") and not fitz.Rect(rectangle).is_empty: print(" skipped") return False + pass return True diff --git a/requirements.txt b/requirements.txt index 3a51527..37588f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,10 @@ backoff pymupdf reportlab python-dotenv +python-multipart +pydantic +pydantic-settings +fastapi[standard] +boto3-stubs[essentials] +boto3-stubs[s3] +boto3-stubs[textract] \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/settings.py b/utils/settings.py new file mode 100644 index 0000000..c31d548 --- /dev/null +++ b/utils/settings.py @@ -0,0 +1,61 @@ +import os +from functools import lru_cache +from typing import Literal + +from dotenv import load_dotenv +from pydantic_settings import BaseSettings + + +class SharedSettings(BaseSettings): + tmp_path: str + + confidence_threshold: float + use_aggressive_strategy: bool = False + + +class ApiSettings(SharedSettings): + aws_profile: str | None = None + skip_processing: bool = False + + s3_input_bucket: str + s3_input_folder: str + + s3_output_bucket: str + s3_output_folder: str + + +class ScriptSettings(SharedSettings): + cleanup_tmp_files: bool + + textract_aws_profile: str + + input_type: Literal['path', 's3'] + input_path: str | None = None + input_aws_profile: str | None = None + input_s3_bucket: str | None = None + input_s3_prefix: str | None = None + input_ignore_existing: bool + + output_type: Literal['path', 's3'] + output_path: str | None = None + output_aws_profile: str | None = None + output_s3_bucket: str | None = None + output_s3_prefix: str | None = None + + +print(f"Loading env variables from '.env'.") +load_dotenv() +if 'OCR_PROFILE' in os.environ: + env_file = f".env.{os.environ['OCR_PROFILE']}" + print(f"Loading env variables from '{env_file}'.") + load_dotenv(env_file) + + +@lru_cache +def api_settings(): + return ApiSettings() + + +@lru_cache +def script_settings(): + return ScriptSettings() diff --git a/utils/task.py b/utils/task.py new file mode 100644 index 0000000..429f9ad --- /dev/null +++ b/utils/task.py @@ -0,0 +1,63 @@ +import logging +import threading +import typing +import uuid +from dataclasses import dataclass +from typing import Dict, TypeVar + +from fastapi import BackgroundTasks + +Result = TypeVar("Result") + + +@dataclass +class Task: + file: str + result: Result | None = None + + +@dataclass +class Output: + ok: bool + value: Result | RuntimeError + + +active_tasks: Dict[str, Task] = {} +active_tasks_lock = threading.Lock() + + +def start(file: str, background_tasks: BackgroundTasks, target: typing.Callable[[], Result]) -> bool: + with active_tasks_lock: + if file in active_tasks: + return False + active_tasks[file] = Task(file=file) + background_tasks.add_task(lambda: run(file, target)) + return True + + +def has_task(file: str) -> bool: + with active_tasks_lock: + return file in active_tasks + + +def collect_result(file: str) -> Output | None: + with active_tasks_lock: + task = active_tasks.get(file) + if task is None or task.result is None: + return None + del active_tasks[file] + return task.result + + +def run(file: str, target: typing.Callable[[], Result]): + try: + logging.info(f"Starting task for file '{file}'.") + value = target() + result = Output(ok=True, value=value) + logging.info(f"Task for file '{file}' has been completed.") + except Exception as e: + logging.exception(f"Processing of '{file}' failed") + result = Output(ok=False, value=e) + + with active_tasks_lock: + active_tasks.get(file).result = result