diff --git a/.github/workflows/onnx-web.yml b/.github/workflows/onnx-web.yml new file mode 100644 index 000000000..d5bf8148f --- /dev/null +++ b/.github/workflows/onnx-web.yml @@ -0,0 +1,319 @@ +name: build +on: + push: + workflow_dispatch: +jobs: + github-pending: + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - run: "./common/scripts/github-status.sh pending" + build-api-coverage-3-10: + needs: github-pending + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: ".cache/pip" + key: python-3-10 + - run: apt-get -y update && apt-get -y install python3-opencv + - run: cd api + - run: "${{ github.workspace }}/common/scripts/make-venv.sh" + - run: make ci + build-api-coverage-3-9: + needs: github-pending + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.9 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: ".cache/pip" + key: python-3-9 + - run: apt-get -y update && apt-get -y install python3-opencv + - run: cd api + - run: "${{ github.workspace }}/common/scripts/make-venv.sh" + - run: make ci + build-api-coverage-3-8: + needs: github-pending + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.8 + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: ".cache/pip" + key: python-3-8 + - run: apt-get -y update && apt-get -y install python3-opencv + - run: cd api + - run: "${{ github.workspace }}/common/scripts/make-venv.sh" + - run: make ci + build-gui-bundle: + needs: github-pending + runs-on: + - ubuntu-latest + container: + image: docker.io/node:18 + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: gui/node_modules/ + key: "${{ runner.os }}-${{ hashFiles('gui/yarn.lock') }}" + - run: cd gui + - run: make ci +# # 'artifacts.coverage_report' was not transformed because there is no suitable equivalent in GitHub Actions +# # 'artifacts.junit' was not transformed because there is no suitable equivalent in GitHub Actions + - uses: actions/upload-artifact@v4.1.0 + if: success() + with: + name: "${{ github.job }}" + retention-days: 1 + path: gui/out/ + package-api-oci: + needs: + - build-api-coverage-3-8 + - build-api-coverage-3-9 + - build-api-coverage-3-10 + - build-gui-bundle + runs-on: + - ubuntu-latest + container: + image: docker.io/docker:20.10 + services: + docker.io/docker:20.10-dind: + image: docker.io/docker:20.10-dind + env: + DOCKER_CERT_PATH: "/shared/docker/client" + DOCKER_DRIVER: overlay2 + DOCKER_HOST: tcp://localhost:2376 + DOCKER_NAME: "${{ github.repository }}" + DOCKER_TLS_CERTDIR: "/shared/docker" + DOCKER_TLS_VERIFY: 1 + VERSION_TAG: "${{ github.ref }}" + IMAGE_ROOT: "${{ github.workspace }}" + IMAGE_SUFFIX: api + strategy: + matrix: + include: + - IMAGE_ARCH: cpu-buster + IMAGE_FILE: api/Containerfile.cpu.buster + - IMAGE_ARCH: cuda-ubuntu + IMAGE_FILE: api/Containerfile.cuda.ubuntu + - IMAGE_ARCH: rocm-ubuntu + IMAGE_FILE: api/Containerfile.rocm.ubuntu + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/download-artifact@v4.1.0 + - run: mkdir ${HOME}/.docker + - run: echo "${DOCKER_SECRET}" | base64 -d > 
${HOME}/.docker/config.json + - run: "${{ github.workspace }}/common/scripts/image-build.sh --push" + - run: rm -rfv ${HOME}/.docker + if: always() + package-gui-oci: + needs: build-gui-bundle + runs-on: + - ubuntu-latest + container: + image: docker.io/docker:20.10 + services: + docker.io/docker:20.10-dind: + image: docker.io/docker:20.10-dind + env: + DOCKER_CERT_PATH: "/shared/docker/client" + DOCKER_DRIVER: overlay2 + DOCKER_HOST: tcp://localhost:2376 + DOCKER_NAME: "${{ github.repository }}" + DOCKER_TLS_CERTDIR: "/shared/docker" + DOCKER_TLS_VERIFY: 1 + VERSION_TAG: "${{ github.ref }}" + IMAGE_ROOT: "${{ github.workspace }}/gui" + IMAGE_SUFFIX: gui + strategy: + matrix: + include: + - IMAGE_ARCH: nginx-alpine + IMAGE_FILE: Containerfile.nginx.alpine + - IMAGE_ARCH: nginx-bullseye + IMAGE_FILE: Containerfile.nginx.bullseye + - IMAGE_ARCH: node-alpine + IMAGE_FILE: Containerfile.node.alpine + - IMAGE_ARCH: node-buster + IMAGE_FILE: Containerfile.node.buster + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/download-artifact@v4.1.0 + - run: mkdir ${HOME}/.docker + - run: echo "${DOCKER_SECRET}" | base64 -d > ${HOME}/.docker/config.json + - run: cd gui + - run: "${{ github.workspace }}/common/scripts/image-build.sh --push" + - run: rm -rfv ${HOME}/.docker + if: always() + package-api-twine: + needs: + - build-api-coverage-3-8 + - build-api-coverage-3-10 + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + if: startsWith(github.ref, 'refs/tags') + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: ".cache/pip" + key: "${{ runner.os }}-${{ hashFiles('api/requirements/dev.txt') }}" + - uses: actions/download-artifact@v4.1.0 + - run: echo "${PIP_SECRET}" | base64 -d > $HOME/.pypirc + - run: cp -v README.md api/README.md + - run: cd api + - run: pip3 install -r requirements/dev.txt + - run: python3 -m build + - run: twine check dist/* + - run: twine upload --repository onnx-web dist/* + - uses: actions/upload-artifact@v4.1.0 + if: success() + with: + name: "${{ github.job }}" + retention-days: 7 + path: dist/ + package-api-twine-dry: + needs: + - build-api-coverage-3-8 + - build-api-coverage-3-10 + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + if: "!(startsWith(github.ref, 'refs/tags'))" + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: ".cache/pip" + key: "${{ runner.os }}-${{ hashFiles('api/requirements/dev.txt') }}" + - uses: actions/download-artifact@v4.1.0 + - run: cp -v README.md api/README.md + - run: cd api + - run: pip install build twine + - run: python -m build + - run: twine check dist/* + - uses: actions/upload-artifact@v4.1.0 + if: success() + with: + name: "${{ github.job }}" + retention-days: 7 + path: dist/ + package-gui-npm: + needs: build-gui-bundle + runs-on: + - ubuntu-latest + container: + image: docker.io/node:18 + if: startsWith(github.ref, 'refs/tags') + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: gui/node_modules/ + key: "${{ runner.os }}-${{ hashFiles('gui/yarn.lock') }}" + - uses: actions/download-artifact@v4.1.0 + - run: echo "${NPM_SECRET}" | base64 -d > $HOME/.npmrc + - run: cp -v README.md gui/README.md + - run: cd gui + - run: npm publish +# # 'artifacts.coverage_report' was not transformed because there is no suitable equivalent in GitHub Actions +# # 
'artifacts.junit' was not transformed because there is no suitable equivalent in GitHub Actions + - uses: actions/upload-artifact@v4.1.0 + if: success() + with: + name: "${{ github.job }}" + retention-days: 7 + path: "${{ github.workspace }}/*.tgz" + package-gui-npm-dry: + needs: build-gui-bundle + runs-on: + - ubuntu-latest + container: + image: docker.io/node:18 + if: "!(startsWith(github.ref, 'refs/tags'))" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/cache@v3.3.2 + with: + path: gui/node_modules/ + key: "${{ runner.os }}-${{ hashFiles('gui/yarn.lock') }}" + - uses: actions/download-artifact@v4.1.0 + - run: cp -v README.md gui/README.md + - run: cd gui + - run: npm pack +# # 'artifacts.coverage_report' was not transformed because there is no suitable equivalent in GitHub Actions +# # 'artifacts.junit' was not transformed because there is no suitable equivalent in GitHub Actions + - uses: actions/upload-artifact@v4.1.0 + if: success() + with: + name: "${{ github.job }}" + retention-days: 7 + path: "${{ github.workspace }}/*.tgz" + github-failure: + needs: + - package-api-oci + - package-gui-oci + - package-api-twine + - package-api-twine-dry + - package-gui-npm + - package-gui-npm-dry + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + if: failure() + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/download-artifact@v4.1.0 + - run: "./common/scripts/github-status.sh failure" + github-success: + needs: + - package-api-oci + - package-gui-oci + - package-api-twine + - package-api-twine-dry + - package-gui-npm + - package-gui-npm-dry + runs-on: + - ubuntu-latest + container: + image: docker.io/python:3.10 + if: success() + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.cache/pip" + steps: + - uses: actions/checkout@v4.1.0 + - uses: actions/download-artifact@v4.1.0 + - run: "./common/scripts/github-status.sh success" diff --git a/.vscode/launch.json b/.vscode/launch.json index 8e610783f..6ce654ba9 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,6 +1,14 @@ { "version": "0.2.0", "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "args": ["--name", "outpaint"], + "console": "integratedTerminal" + }, { "name": "Python: Remote Attach", "type": "python", @@ -16,6 +24,16 @@ } ], "justMyCode": false + }, + { + "name": "Python: Local Attach", + "type": "python", + "request": "attach", + "connect": { + "host": "127.0.0.1", + "port": 5679 + }, + "justMyCode": false } ] } \ No newline at end of file diff --git a/api/.gitignore b/api/.gitignore index 165cb3008..6f8e634ab 100644 --- a/api/.gitignore +++ b/api/.gitignore @@ -7,8 +7,10 @@ entry.py *.swp *.pyc +.cache/ __pycache__/ dist/ +filter/ htmlcov/ onnx_env/ venv/ diff --git a/api/Containerfile.cuda.ubuntu b/api/Containerfile.cuda.ubuntu index 26c10cd95..5b710c98f 100644 --- a/api/Containerfile.cuda.ubuntu +++ b/api/Containerfile.cuda.ubuntu @@ -1,4 +1,4 @@ -FROM docker.io/nvidia/cuda:11.7.1-runtime-ubuntu20.04 +FROM docker.io/nvidia/cuda:11.8.0-runtime-ubuntu20.04 WORKDIR /onnx-web/api diff --git a/api/convert.sh b/api/convert.sh new file mode 100755 index 000000000..7e9b3e330 --- /dev/null +++ b/api/convert.sh @@ -0,0 +1,28 @@ +#! /bin/sh + +set -eu + +if [ -n "${VIRTUAL_ENV+set}" ]; then + echo "Using current virtual env..." +else + if [ -d "onnx_env" ]; then + echo "Loading existing virtual env..." + . 
onnx_env/bin/activate + else + echo "Creating new virtual env..." + python -m venv onnx_env + . onnx_env/bin/activate + fi +fi + +echo "Downloading and converting models to ONNX format..." +python3 -m onnx_web.convert \ + --correction \ + --diffusion \ + --networks \ + --sources \ + --upscaling \ + --extras=${ONNX_WEB_EXTRA_MODELS:-../models/extras.json} \ + --token=${HF_TOKEN:-} \ + ${ONNX_WEB_EXTRA_ARGS:-} + diff --git a/api/gui/.gitignore b/api/gui/.gitignore index e9b37fd24..229bf879c 100644 --- a/api/gui/.gitignore +++ b/api/gui/.gitignore @@ -1,3 +1,4 @@ +bundle/main.css bundle/main.js config.json index.html diff --git a/api/launch.bat b/api/launch.bat index e1bd53e81..2f20a471a 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -5,11 +5,11 @@ echo "This launch.bat script is deprecated in favor of launch.ps1 and will be re echo "Downloading and converting models to ONNX format..." IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json) python -m onnx_web.convert ^ ---sources ^ ---diffusion ^ ---upscaling ^ --correction ^ +--diffusion ^ --networks ^ +--sources ^ +--upscaling ^ --extras=%ONNX_WEB_EXTRA_MODELS% ^ --token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS% diff --git a/api/launch.ps1 b/api/launch.ps1 index 4e3787554..420a3554a 100644 --- a/api/launch.ps1 +++ b/api/launch.ps1 @@ -3,11 +3,11 @@ echo "Downloading and converting models to ONNX format..." IF ($Env:ONNX_WEB_EXTRA_MODELS -eq "") {$Env:ONNX_WEB_EXTRA_MODELS="..\models\extras.json"} python -m onnx_web.convert ` ---sources ` ---diffusion ` ---upscaling ` --correction ` +--diffusion ` --networks ` +--sources ` +--upscaling ` --extras=$Env:ONNX_WEB_EXTRA_MODELS ` --token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS diff --git a/api/launch.sh b/api/launch.sh index 8e793771e..e4523bbca 100755 --- a/api/launch.sh +++ b/api/launch.sh @@ -17,10 +17,11 @@ fi echo "Downloading and converting models to ONNX format..." python3 -m onnx_web.convert \ - --sources \ + --correction \ --diffusion \ + --networks \ + --sources \ --upscaling \ - --correction \ --extras=${ONNX_WEB_EXTRA_MODELS:-../models/extras.json} \ --token=${HF_TOKEN:-} \ ${ONNX_WEB_EXTRA_ARGS:-} diff --git a/api/onnx_web/chain/blend_denoise_fastnlmeans.py b/api/onnx_web/chain/blend_denoise_fastnlmeans.py index c3de00e2b..ea39a4518 100644 --- a/api/onnx_web/chain/blend_denoise_fastnlmeans.py +++ b/api/onnx_web/chain/blend_denoise_fastnlmeans.py @@ -32,9 +32,9 @@ def run( logger.info("denoising source images") results = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): data = cv2.cvtColor(source, cv2.COLOR_RGB2BGR) data = cv2.fastNlMeansDenoisingColored(data, None, strength, strength) results.append(cv2.cvtColor(data, cv2.COLOR_BGR2RGB)) - return StageResult(arrays=results) + return StageResult.from_arrays(results, metadata=sources.metadata) diff --git a/api/onnx_web/chain/blend_denoise_localstd.py b/api/onnx_web/chain/blend_denoise_localstd.py index 389e30a8f..08f20d470 100644 --- a/api/onnx_web/chain/blend_denoise_localstd.py +++ b/api/onnx_web/chain/blend_denoise_localstd.py @@ -14,6 +14,11 @@ class BlendDenoiseLocalStdStage(BaseStage): + """ + Experimental stage to blend and denoise images using local means compared to local standard deviation. + Very slow. 
+ """ + max_tile = SizeChart.max def run( @@ -35,8 +40,9 @@ def run( return StageResult.from_arrays( [ remove_noise(source, threshold=strength, deviation=range)[0] - for source in sources.as_numpy() - ] + for source in sources.as_arrays() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 34e4f5351..84d11393c 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -7,7 +7,7 @@ from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -20,7 +20,7 @@ def run( _worker: WorkerContext, _server: ServerContext, _stage: StageParams, - _params: ImageParams, + params: ImageParams, sources: StageResult, *, height: int, @@ -35,7 +35,7 @@ def run( ) -> StageResult: logger.info("combining source images using grid layout") - images = sources.as_image() + images = sources.as_images() ref_image = images[0] size = Size(*ref_image.size) @@ -52,7 +52,12 @@ def run( n = order[i] output.paste(images[n], (x * size.width, y * size.height)) - return StageResult(images=[*images, output]) + result = StageResult(source=sources) + result.push_image( + output, + ImageMetadata(params, Size(width, height), ancestors=[sources.metadata]), + ) + return result def outputs( self, diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 4528946f6..e6ee29a4e 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -66,7 +66,7 @@ def run( pipe_params["strength"] = strength outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): if params.is_lpw(): logger.debug("using LPW pipeline for img2img") rng = torch.manual_seed(params.seed) @@ -102,7 +102,10 @@ def run( outputs.extend(result.images) - return StageResult(images=outputs) + metadata = [ + metadata.child(params, metadata.size) for metadata in sources.metadata + ] + return StageResult(images=outputs, metadata=metadata) def steps( self, diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 1b40a5fd0..5d1baa9e0 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -23,14 +23,15 @@ def run( *, alpha: float, stage_source: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("blending source images using linear interpolation") - return StageResult( - images=[ + return StageResult.from_images( + [ Image.blend(source, stage_source, alpha) - for source in sources.as_image() - ] + for source in sources.as_images() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 93abd1ee1..91d464fb5 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -27,7 +27,7 @@ def run( stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, tile_mask: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("blending image using mask") @@ -48,6 +48,7 @@ def run( return StageResult.from_images( [ Image.composite(stage_source_tile, source, mult_mask) - for source in sources.as_image() - ] + for source in 
sources.as_images() + ], + metadata=sources.metadata, ) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 84c8a9d90..803c82e9d 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -7,9 +7,10 @@ from PIL import Image from torchvision.transforms.functional import normalize -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext +from ..worker.context import ProgressCallback from .base import BaseStage from .result import StageResult @@ -28,23 +29,24 @@ def run( _params: ImageParams, sources: StageResult, *, - stage_source: Optional[Image.Image] = None, upscale: UpscaleParams, + callback: Optional[ProgressCallback] = None, + highres: Optional[HighresParams] = None, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: # adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and # https://pypi.org/project/codeformer-perceptor/ # import must be within the load function for patches to take effect - # TODO: rewrite and remove + from codeformer.basicsr.archs.codeformer_arch import CodeFormer from codeformer.basicsr.utils import img2tensor, tensor2img - from codeformer.basicsr.utils.registry import ARCH_REGISTRY from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper upscale = upscale.with_args(**kwargs) device = worker.get_device() - net = ARCH_REGISTRY.get("CodeFormer")( + net = CodeFormer( dim_embd=512, codebook_size=1024, n_head=8, @@ -68,7 +70,7 @@ def run( ) results = [] - for img in sources.as_numpy(): + for img in sources.as_arrays(): img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # clean all the intermediate results to process the next image face_helper.clean_all() @@ -122,4 +124,4 @@ def run( ) results.append(Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB))) - return StageResult.from_images(results) + return StageResult.from_images(results, metadata=sources.metadata) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index f3ce33f3f..d80232150 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -4,10 +4,16 @@ from PIL import Image -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -60,6 +66,8 @@ def run( sources: StageResult, *, upscale: UpscaleParams, + callback: Optional[ProgressCallback] = None, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: @@ -74,7 +82,7 @@ def run( gfpgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): cropped, restored, result = gfpgan.enhance( source, has_aligned=False, @@ -84,4 +92,4 @@ def run( ) outputs.append(result) - return StageResult.from_arrays(outputs) + return StageResult.from_arrays(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/edit_metadata.py b/api/onnx_web/chain/edit_metadata.py new file mode 100644 index 
000000000..02edaf783 --- /dev/null +++ b/api/onnx_web/chain/edit_metadata.py @@ -0,0 +1,54 @@ +from typing import Optional + +from ..params import ( + HighresParams, + ImageParams, + Size, + SizeChart, + StageParams, + UpscaleParams, +) +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + + +class EditMetadataStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: StageResult, + *, + callback: Optional[ProgressCallback] = None, + size: Optional[Size] = None, + upscale: Optional[UpscaleParams] = None, + highres: Optional[HighresParams] = None, + note: Optional[str] = None, + replace_params: Optional[ImageParams] = None, + **kwargs, + ) -> StageResult: + # Modify the source image's metadata using the provided parameters + for metadata in source.metadata: + if note is not None: + metadata.note = note + + if replace_params is not None: + metadata.params = replace_params + + if size is not None: + metadata.size = size + + if upscale is not None: + metadata.upscale = upscale + + if highres is not None: + metadata.highres = highres + + # Return the modified source image + return source diff --git a/api/onnx_web/chain/edit_safety.py b/api/onnx_web/chain/edit_safety.py new file mode 100644 index 000000000..d4635b2e4 --- /dev/null +++ b/api/onnx_web/chain/edit_safety.py @@ -0,0 +1,94 @@ +from logging import getLogger +from typing import Any, Optional + +from PIL import Image + +from ..errors import CancelledException +from ..output import save_metadata +from ..params import ImageParams, SizeChart, StageParams +from ..server import ServerContext +from ..server.model_cache import ModelTypes +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + +logger = getLogger(__name__) + + +class EditSafetyStage(BaseStage): + max_tile = SizeChart.max + + def load(self, server: ServerContext, device: str) -> Any: + # keep these within run to make this sort of like a plugin or peer dependency + from horde_safety.deep_danbooru_model import get_deep_danbooru_model + from horde_safety.interrogate import get_interrogator_no_blip + from horde_safety.nsfw_checker_class import NSFWChecker + + # check cache + cache_key = ("horde-safety",) + cache_checker = server.cache.get(ModelTypes.safety, cache_key) + if cache_checker is not None: + return cache_checker + + # set up + interrogator = get_interrogator_no_blip(device=device) + deep_danbooru_model = get_deep_danbooru_model(device=device) + + nsfw_checker = NSFWChecker( + interrogator, + deep_danbooru_model, + ) + + server.cache.set(ModelTypes.safety, cache_key, nsfw_checker) + + return nsfw_checker + + def run( + self, + worker: WorkerContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + sources: StageResult, + *, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> StageResult: + logger.info("checking results using horde safety") + + try: + # set up + torch_device = worker.device.torch_str() + nsfw_checker = self.load(server, torch_device) + block_nsfw = server.has_feature("horde-safety-nsfw") + is_csam = False + + # check each output + images = sources.as_images() + results = [] + for i, image in enumerate(images): + metadata = sources.metadata[i] + prompt = metadata.params.prompt + check = nsfw_checker.check_for_nsfw(image, prompt=prompt) + + if 
check.is_csam: + logger.warning("flagging csam result: %s, %s", i, prompt) + is_csam = True + + report_name = f"csam-report-{worker.job}-{i}" + report_path = save_metadata(server, report_name, metadata) + logger.info("saved csam report: %s", report_path) + elif check.is_nsfw and block_nsfw: + logger.warning("blocking nsfw image: %s, %s", i, prompt) + results.append(Image.new("RGB", image.size, color="black")) + else: + results.append(image) + + if is_csam: + logger.warning("blocking csam result") + raise CancelledException(reason="csam") + else: + return StageResult.from_images(results, metadata=sources.metadata) + except ImportError: + logger.warning("horde safety not installed") + return StageResult.empty() diff --git a/api/onnx_web/chain/edit_text.py b/api/onnx_web/chain/edit_text.py new file mode 100644 index 000000000..2aa453d2d --- /dev/null +++ b/api/onnx_web/chain/edit_text.py @@ -0,0 +1,42 @@ +from typing import Optional, Tuple + +from PIL import ImageDraw + +from ..params import ImageParams, SizeChart, StageParams +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + + +class EditTextStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + _worker: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: StageResult, + *, + text: str, + position: Tuple[int, int], + fill: str = "white", + stroke: str = "black", + stroke_width: int = 1, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> StageResult: + # Add text to each image in source at the given position + results = [] + + for image in source.as_images(): + image = image.copy() + draw = ImageDraw.Draw(image) + draw.text( + position, text, fill=fill, stroke_width=stroke_width, stroke_fill=stroke + ) + results.append(image) + + return StageResult.from_images(results, source.metadata) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 2a43e0512..f48f7c40d 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -2,9 +2,10 @@ from typing import Optional from ..chain.blend_img2img import BlendImg2ImgStage +from ..chain.edit_metadata import EditMetadataStage from ..chain.upscale import stage_upscale_correction from ..chain.upscale_simple import UpscaleSimpleStage -from ..params import HighresParams, ImageParams, StageParams, UpscaleParams +from ..params import HighresParams, ImageParams, SizeChart, StageParams, UpscaleParams from .pipeline import ChainPipeline logger = getLogger(__name__) @@ -41,6 +42,7 @@ def stage_highres( faces=False, scale=highres.scale, outscale=highres.scale, + upscale=True, ), chain=chain, overlap=params.vae_overlap, @@ -52,7 +54,9 @@ def stage_highres( stage, method=highres.method, overlap=params.vae_overlap, - upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), + upscale=upscale.with_args( + scale=highres.scale, outscale=highres.scale, upscale=True + ), ) chain.stage( @@ -63,4 +67,13 @@ def stage_highres( strength=highres.strength, ) + # add highres parameters to the image metadata and clear upscale stage + chain.stage( + EditMetadataStage(), + stage.with_args(outscale=1, tile_size=SizeChart.max), + highres=highres, + upscale=upscale, + replace_params=params, + ) + return chain diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index 28a088482..340956a9c 100644 --- a/api/onnx_web/chain/persist_disk.py +++ 
b/api/onnx_web/chain/persist_disk.py @@ -1,12 +1,12 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional from PIL import Image -from ..output import save_image -from ..params import ImageParams, Size, SizeChart, StageParams +from ..output import save_result +from ..params import ImageParams, SizeChart, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -18,21 +18,18 @@ class PersistDiskStage(BaseStage): def run( self, - _worker: WorkerContext, + worker: WorkerContext, server: ServerContext, _stage: StageParams, params: ImageParams, sources: StageResult, *, - output: List[str], - size: Optional[Size] = None, + callback: Optional[ProgressCallback] = None, stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: - logger.info("persisting %s images to disk: %s", len(sources), output) + logger.info("persisting %s images to disk", len(sources)) - for source, name in zip(sources.as_image(), output): - dest = save_image(server, name, source, params=params, size=size) - logger.info("saved image to %s", dest) + save_result(server, sources, worker.job) return sources diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 060afc4f4..c1825d1b0 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,13 +1,15 @@ from io import BytesIO +from json import dumps from logging import getLogger -from typing import List, Optional +from typing import Optional from boto3 import Session from PIL import Image +from ..output import make_output_names from ..params import ImageParams, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -17,23 +19,24 @@ class PersistS3Stage(BaseStage): def run( self, - _worker: WorkerContext, + worker: WorkerContext, server: ServerContext, _stage: StageParams, _params: ImageParams, sources: StageResult, *, - output: List[str], bucket: str, endpoint_url: Optional[str] = None, profile_name: Optional[str] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) - for source, name in zip(sources.as_image(), output): + image_names = make_output_names(server, worker.job, len(sources)) + for source, name in zip(sources.as_images(), image_names): data = BytesIO() source.save(data, format=server.image_format) data.seek(0) @@ -44,4 +47,18 @@ def run( except Exception: logger.exception("error saving image to S3") + metadata_names = make_output_names( + server, worker.job, len(sources), extension="json" + ) + for metadata, name in zip(sources.metadata, metadata_names): + data = BytesIO() + data.write(dumps(metadata.tojson(server, [name]))) + data.seek(0) + + try: + s3.upload_fileobj(data, bucket, name) + logger.info("saved metadata to s3://%s/%s", bucket, name) + except Exception: + logger.exception("error saving metadata to S3") + return sources diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index ff3fae810..fed2f01e3 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -22,21 +22,29 @@ class ChainProgress: + parent: ProgressCallback + step: int # current 
number of steps + prev: int # accumulator when step resets + + # TODO: should probably be moved to worker context as well + result: Optional[StageResult] + def __init__(self, parent: ProgressCallback, start=0) -> None: self.parent = parent self.step = start - self.total = 0 + self.prev = 0 + self.result = None def __call__(self, step: int, timestep: int, latents: Any) -> None: if step < self.step: # accumulate on resets - self.total += self.step + self.prev += self.step self.step = step self.parent(self.get_total(), timestep, latents) def get_total(self) -> int: - return self.step + self.total + return self.step + self.prev @classmethod def from_progress(cls, parent: ProgressCallback): @@ -50,6 +58,8 @@ class ChainPipeline: tiles as needed. """ + stages: List[PipelineStage] + def __init__( self, stages: Optional[List[PipelineStage]] = None, @@ -81,7 +91,7 @@ def run( result = self( worker, server, params, sources=sources, callback=callback, **kwargs ) - return result.as_image() + return result.as_images() def stage(self, callback: BaseStage, params: StageParams, **kwargs): self.stages.append((callback, params, kwargs)) @@ -94,12 +104,8 @@ def steps(self, params: ImageParams, size: Size) -> int: return steps - def outputs(self, params: ImageParams, sources: int) -> int: - outputs = sources - for callback, _params, kwargs in self.stages: - outputs = callback.outputs(kwargs.get("params", params), outputs) - - return outputs + def stages(self) -> int: + return len(self.stages) def __call__( self, @@ -110,14 +116,22 @@ def __call__( callback: Optional[ProgressCallback] = None, **pipeline_kwargs, ) -> StageResult: - """ - DEPRECATED: use `.run()` instead - """ if callback is None: callback = worker.get_progress_callback() - else: + + # wrap the progress counter in a one that can be reset if needed + if not isinstance(callback, ChainProgress): callback = ChainProgress.from_progress(callback) + # set estimated totals + if "size" in pipeline_kwargs and isinstance(pipeline_kwargs["size"], Size): + size = pipeline_kwargs["size"] + else: + size = sources.size() + + total_steps = self.steps(params, size) + worker.set_totals(total_steps, stages=len(self.stages), tiles=0) + start = monotonic() if len(sources) > 0: @@ -129,7 +143,7 @@ def __call__( logger.info("running pipeline without source images") stage_sources = sources - for stage_pipe, stage_params, stage_kwargs in self.stages: + for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages): name = stage_params.name or stage_pipe.__class__.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} @@ -139,6 +153,7 @@ def __call__( len(stage_sources), kwargs.keys(), ) + worker.set_stages(stage_i) per_stage_params = params if "params" in kwargs: @@ -146,25 +161,21 @@ def __call__( kwargs.pop("params") # the stage must be split and tiled if any image is larger than the selected/max tile size - must_tile = has_mask(stage_kwargs) or any( - [ - needs_tile( - stage_pipe.max_tile, - stage_params.tile_size, - size=kwargs.get("size", None), - source=source, - ) - for source in stage_sources.as_image() - ] + must_tile = has_mask(stage_kwargs) or needs_tile( + stage_pipe.max_tile, + stage_params.tile_size, + size=kwargs.get("size", None), + source=stage_sources.size(), ) tile = stage_params.tile_size if stage_pipe.max_tile > 0: tile = min(stage_pipe.max_tile, stage_params.tile_size) + worker.set_tiles(0) if must_tile: logger.info( - "image contains sources or is larger than tile size of %s, tiling stage", + "image has mask or 
is larger than tile size of %s, tiling stage", tile, ) @@ -172,15 +183,19 @@ def stage_tile( source_tile: List[Image.Image], tile_mask: Image.Image, dims: Tuple[int, int, int], + progress: Tuple[int, int], ) -> List[Image.Image]: for _i in range(worker.retries): try: + stage_input = StageResult( + images=source_tile, metadata=stage_sources.metadata + ) tile_result = stage_pipe.run( worker, server, stage_params, per_stage_params, - StageResult(images=source_tile), + stage_input, tile_mask=tile_mask, callback=callback, dims=dims, @@ -188,9 +203,11 @@ def stage_tile( ) if is_debug(): - for j, image in enumerate(tile_result.as_image()): + for j, image in enumerate(tile_result.as_images()): save_image(server, f"last-tile-{j}.png", image) + worker.set_tiles(current=progress[0], total=progress[1]) + return tile_result except CancelledException as err: worker.retries = 0 @@ -207,7 +224,7 @@ def stage_tile( raise RetryException("exhausted retries on tile") - stage_results = process_tile_order( + stage_result = process_tile_order( stage_params.tile_order, stage_sources, tile, @@ -216,7 +233,7 @@ def stage_tile( **kwargs, ) - stage_sources = StageResult(images=stage_results) + stage_sources = stage_result else: logger.debug( "image does not contain sources and is within tile size of %s, running stage", @@ -260,8 +277,12 @@ def stage_tile( len(stage_sources), ) + callback.result = ( + stage_sources # this has just been set to the result of the last stage + ) + if is_debug(): - for j, image in enumerate(stage_sources.as_image()): + for j, image in enumerate(stage_sources.as_images()): save_image(server, f"last-stage-{j}.png", image) end = monotonic() @@ -271,6 +292,8 @@ def stage_tile( duration, len(stage_sources), ) + + callback.result = stage_sources return stage_sources diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index fe98fbd38..6261f0454 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -5,7 +5,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -24,15 +24,16 @@ def run( origin: Size, size: Size, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): image = source.crop((origin.width, origin.height, size.width, size.height)) logger.info( "created thumbnail with dimensions: %sx%s", image.width, image.height ) outputs.append(image) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 9c65a8199..565c8c0ac 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -1,10 +1,11 @@ from logging import getLogger +from typing import Optional from PIL import Image from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -22,11 +23,12 @@ def run( *, size: Size, stage_source: Image.Image, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: outputs = [] - for source in sources.as_image(): + for source in 
sources.as_images(): image = source.copy() image = image.thumbnail((size.width, size.height)) @@ -37,4 +39,4 @@ def run( outputs.append(image) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 813f5863d..1e117832a 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,8 +1,350 @@ -from typing import List, Optional +from json import dumps +from logging import getLogger +from os import path +from re import compile +from typing import Any, List, Optional, Tuple import numpy as np from PIL import Image +from ..convert.utils import resolve_tensor +from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams +from ..server.context import ServerContext +from ..server.load import get_extra_hashes +from ..utils import coalesce, hash_file, load_config_str + +logger = getLogger(__name__) + +FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?") + + +class NetworkMetadata: + name: str + hash: str + weight: float + + def __init__(self, name: str, hash: str, weight: float) -> None: + self.name = name + self.hash = hash + self.weight = weight + + +class ImageMetadata: + ancestors: List["ImageMetadata"] + note: str + params: ImageParams + size: Size + + # models + inversions: List[NetworkMetadata] + loras: List[NetworkMetadata] + models: List[NetworkMetadata] + + # optional params + border: Optional[Border] + highres: Optional[HighresParams] + upscale: Optional[UpscaleParams] + + @staticmethod + def unknown_image() -> "ImageMetadata": + UNKNOWN_STR = "unknown" + return ImageMetadata( + ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0), + Size(0, 0), + ) + + def __init__( + self, + params: ImageParams, + size: Size, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, + inversions: Optional[List[NetworkMetadata]] = None, + loras: Optional[List[NetworkMetadata]] = None, + models: Optional[List[NetworkMetadata]] = None, + ancestors: Optional[List["ImageMetadata"]] = None, + ) -> None: + self.params = params + self.size = size + self.upscale = upscale + self.border = border + self.highres = highres + self.inversions = inversions or [] + self.loras = loras or [] + self.models = models or [] + self.ancestors = ancestors or [] + self.note = "" + + def child( + self, + params: ImageParams, + size: Size, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, + inversions: Optional[List[NetworkMetadata]] = None, + loras: Optional[List[NetworkMetadata]] = None, + models: Optional[List[NetworkMetadata]] = None, + ) -> "ImageMetadata": + return ImageMetadata( + params, + size, + upscale, + border, + highres, + inversions, + loras, + models, + [self], + ) + + def get_model_hash( + self, server: ServerContext, model: Optional[str] = None + ) -> Tuple[str, str]: + model_name = path.basename(path.normpath(model or self.params.model)) + logger.debug("getting model hash for %s", model_name) + + if model_name in server.hash_cache: + logger.debug("using cached model hash for %s", model_name) + return (model_name, server.hash_cache[model_name]) + + model_hash = get_extra_hashes().get(model_name, None) + if model_hash is None: + model_hash_path = path.join(self.params.model, "hash.txt") + if path.exists(model_hash_path): + with open(model_hash_path, "r") as f: + model_hash = 
f.readline().rstrip(",. \n\t\r") + + model_hash = model_hash or "unknown" + server.hash_cache[model_name] = model_hash + + return (model_name, model_hash) + + def get_network_hash( + self, server: ServerContext, network_name: str, network_type: str + ) -> Tuple[str, str]: + # run this again just in case the file path changes + network_path = resolve_tensor( + path.join(server.model_path, network_type, network_name) + ) + + if network_path in server.hash_cache: + logger.debug("using cached network hash for %s", network_path) + return (network_name, server.hash_cache[network_path]) + + network_hash = hash_file(network_path).upper() + server.hash_cache[network_path] = network_hash + + return (network_name, network_hash) + + def to_exif(self, server: ServerContext) -> str: + model_name, model_hash = self.get_model_hash(server) + hash_map = { + model_name: model_hash, + } + + inversion_hashes = "" + if self.inversions is not None: + inversion_pairs = [ + ( + name, + self.get_network_hash(server, name, "inversion")[1], + ) + for name, _weight in self.inversions + ] + inversion_hashes = ",".join( + [f"{name}: {hash}" for name, hash in inversion_pairs] + ) + hash_map.update(dict(inversion_pairs)) + + lora_hashes = "" + if self.loras is not None: + lora_pairs = [ + ( + name, + self.get_network_hash(server, name, "lora")[1], + ) + for name, _weight in self.loras + ] + lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs]) + hash_map.update(dict(lora_pairs)) + + return ( + f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n" + f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, " + f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, " + f"Model hash: {model_hash}, Model: {model_name}, " + f"Tool: onnx-web, Version: {server.server_version}, " + f'Inversion hashes: "{inversion_hashes}", ' + f'Lora hashes: "{lora_hashes}", ' + f"Hashes: {dumps(hash_map)}" + ) + + def tojson(self, server: ServerContext, output: List[str]): + json = { + "input_size": self.size.tojson(), + "outputs": output, + "params": self.params.tojson(), + "inversions": [], + "loras": [], + "models": [], + } + + # fix up some fields + model_name, model_hash = self.get_model_hash(server, self.params.model) + json["params"]["model"] = model_name + json["models"].append( + { + "hash": model_hash, + "name": model_name, + "weight": 1.0, + } + ) + + # add optional params + if self.border is not None: + json["border"] = self.border.tojson() + + if self.highres is not None: + json["highres"] = self.highres.tojson() + + if self.upscale is not None: + json["upscale"] = self.upscale.tojson() + + # calculate final output size + json["size"] = self.get_output_size().tojson() + + # hash and add models and networks + if self.inversions is not None: + for name, weight in self.inversions: + model_hash = self.get_network_hash(server, name, "inversion")[1] + json["inversions"].append( + {"name": name, "weight": weight, "hash": model_hash} + ) + + if self.loras is not None: + for name, weight in self.loras: + model_hash = self.get_network_hash(server, name, "lora")[1] + json["loras"].append( + {"name": name, "weight": weight, "hash": model_hash} + ) + + if self.models is not None: + for name, weight in self.models: + name, model_hash = self.get_model_hash(server) + json["models"].append( + {"name": name, "weight": weight, "hash": model_hash} + ) + + return json + + def get_output_size(self) -> Size: + output_size = self.size + if self.border is 
not None: + output_size = output_size.add_border(self.border) + + if self.highres is not None: + output_size = self.highres.resize(output_size) + + if self.upscale is not None: + output_size = self.upscale.resize(output_size) + + return output_size + + def with_args( + self, + params: Optional[ImageParams] = None, + size: Optional[Size] = None, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, + inversions: Optional[List[NetworkMetadata]] = None, + loras: Optional[List[NetworkMetadata]] = None, + models: Optional[List[NetworkMetadata]] = None, + ancestors: Optional[List["ImageMetadata"]] = None, + **kwargs, + ) -> "ImageMetadata": + logger.info("ignoring extra kwargs for metadata: %s", list(kwargs.keys())) + return ImageMetadata( + params or self.params, + size or self.size, + upscale=coalesce(upscale, self.upscale), + border=coalesce(border, self.border), + highres=coalesce(highres, self.highres), + inversions=coalesce(inversions, self.inversions), + loras=coalesce(loras, self.loras), + models=coalesce(models, self.models), + ancestors=coalesce(ancestors, self.ancestors), + ) + + @staticmethod + def from_exif(input: str) -> "ImageMetadata": + lines = input.splitlines() + prompt, maybe_negative, *rest = lines + + # process negative prompt or put that line back into rest + if maybe_negative.startswith("Negative prompt:"): + negative_prompt = maybe_negative[len("Negative prompt:") :] + negative_prompt = negative_prompt.strip() + else: + rest.insert(0, maybe_negative) + negative_prompt = None + + rest = " ".join(rest) + other_params = rest.split(",") + + # process other params + params = {} + size = None + for param in other_params: + key, value = param.split(":") + key = key.strip().lower() + value = value.strip() + + if key == "size": + width, height = value.split("x") + width = int(width.strip()) + height = int(height.strip()) + size = Size(width, height) + elif value.isdecimal(): + value = int(value) + elif FLOAT_PATTERN.match(value) is not None: + value = float(value) + + params[key] = value + + params = ImageParams( + "TODO", + "txt2img", # TODO: can this be detected? 
+ params["sampler"], + prompt, + params["cfg scale"], + params["steps"], + params["seed"], + negative_prompt, + ) + return ImageMetadata(params, size) + + @staticmethod + def from_json(input: str) -> "ImageMetadata": + data = load_config_str(input) + # TODO: enforce schema + + return ImageMetadata( + data["params"], + data["input_size"], + data.get("upscale", None), + data.get("border", None), + data.get("highres", None), + data.get("inversions", None), + data.get("loras", None), + data.get("models", None), + ) + + +ERROR_NO_METADATA = "metadata must be provided" + class StageResult: """ @@ -14,27 +356,50 @@ class StageResult: arrays: Optional[List[np.ndarray]] images: Optional[List[Image.Image]] + metadata: List[ImageMetadata] + + # output paths, filled in when the result is saved + outputs: Optional[List[str]] + thumbnails: Optional[List[str]] @staticmethod def empty(): return StageResult(images=[]) @staticmethod - def from_arrays(arrays: List[np.ndarray]): - return StageResult(arrays=arrays) + def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]): + return StageResult(arrays=arrays, metadata=metadata) @staticmethod - def from_images(images: List[Image.Image]): - return StageResult(images=images) + def from_images(images: List[Image.Image], metadata: List[ImageMetadata]): + return StageResult(images=images, metadata=metadata) - def __init__(self, arrays=None, images=None) -> None: - if arrays is not None and images is not None: - raise ValueError("stages must only return one type of result") - elif arrays is None and images is None: - raise ValueError("stages must return results") + def __init__( + self, + arrays: Optional[List[np.ndarray]] = None, + images: Optional[List[Image.Image]] = None, + metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional + source: Optional[Any] = None, + ) -> None: + data_provided = sum( + [arrays is not None, images is not None, source is not None] + ) + if data_provided > 1: + raise ValueError("results must only contain one type of data") + elif data_provided == 0: + raise ValueError("results must contain some data") - self.arrays = arrays - self.images = images + self.outputs = None + self.thumbnails = None + + if source is not None: + self.arrays = source.arrays + self.images = source.images + self.metadata = source.metadata + else: + self.arrays = arrays + self.images = images + self.metadata = metadata or [] def __len__(self) -> int: if self.arrays is not None: @@ -44,7 +409,7 @@ def __len__(self) -> int: else: return 0 - def as_numpy(self) -> List[np.ndarray]: + def as_arrays(self) -> List[np.ndarray]: if self.arrays is not None: return self.arrays elif self.images is not None: @@ -52,7 +417,7 @@ def as_numpy(self) -> List[np.ndarray]: else: return [] - def as_image(self) -> List[Image.Image]: + def as_images(self) -> List[Image.Image]: if self.images is not None: return self.images elif self.arrays is not None: @@ -60,6 +425,85 @@ def as_image(self) -> List[Image.Image]: else: return [] + def push_array(self, array: np.ndarray, metadata: ImageMetadata): + if self.arrays is not None: + self.arrays.append(array) + elif self.images is not None: + self.images.append(Image.fromarray(np.uint8(array), shape_mode(array))) + else: + self.arrays = [array] + + if metadata is not None: + self.metadata.append(metadata) + else: + raise ValueError(ERROR_NO_METADATA) + + def push_image(self, image: Image.Image, metadata: ImageMetadata): + if self.images is not None: + self.images.append(image) + elif self.arrays is not None: 
+ self.arrays.append(np.array(image)) + else: + self.images = [image] + + if metadata is not None: + self.metadata.append(metadata) + else: + raise ValueError(ERROR_NO_METADATA) + + def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata): + if self.arrays is not None: + self.arrays.insert(index, array) + elif self.images is not None: + self.images.insert( + index, Image.fromarray(np.uint8(array), shape_mode(array)) + ) + else: + self.arrays = [array] + + if metadata is not None: + self.metadata.insert(index, metadata) + else: + raise ValueError(ERROR_NO_METADATA) + + def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata): + if self.images is not None: + self.images.insert(index, image) + elif self.arrays is not None: + self.arrays.insert(index, np.array(image)) + else: + self.images = [image] + + if metadata is not None: + self.metadata.insert(index, metadata) + else: + raise ValueError(ERROR_NO_METADATA) + + def size(self) -> Size: + if self.images is not None: + return Size( + max([image.width for image in self.images], default=0), + max([image.height for image in self.images], default=0), + ) + elif self.arrays is not None: + return Size( + max([array.shape[0] for array in self.arrays], default=0), + max([array.shape[1] for array in self.arrays], default=0), + ) # TODO: which fields within the shape are width/height? + else: + return Size(0, 0) + + def validate(self) -> None: + """ + Make sure the data exists and that data and metadata match in length. + """ + + if self.arrays is None and self.images is None: + raise ValueError("no data in result") + + if len(self) != len(self.metadata): + raise ValueError("metadata and data do not match in length") + def shape_mode(arr: np.ndarray) -> str: if len(arr.shape) != 3: diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index d1b2eac2a..19554af2c 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -5,7 +5,7 @@ from ..params import ImageParams, Size, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -24,6 +24,7 @@ def run( size: Size, noise_source: Callable, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("generating image from noise source") @@ -36,13 +37,13 @@ def run( outputs = [] # TODO: looping over sources and ignoring params does not make much sense for a source stage - for source in sources.as_image(): + for source in sources.as_images(): output = noise_source(source, (size.width, size.height), (0, 0)) logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return StageResult(images=outputs) + return StageResult.from_images(outputs, metadata=sources.metadata) def outputs( self, diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index d9a53acac..3f427d68a 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py @@ -7,9 +7,9 @@ from ..params import ImageParams, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -27,6 +27,7 @@ def run( bucket: str, endpoint_url: 
Optional[str] = None, profile_name: Optional[str] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: session = Session(profile_name=profile_name) @@ -37,7 +38,7 @@ def run( "source images were passed to a source stage, new images will be appended" ) - outputs = sources.as_image() + outputs = sources.as_images() for key in source_keys: try: logger.info("loading image from s3://%s/%s", bucket, key) @@ -49,7 +50,9 @@ def run( except Exception: logger.exception("error loading image from S3") - return StageResult(outputs) + # TODO: attempt to load metadata from s3 or load it from the image itself (exif data) + metadata = [ImageMetadata.unknown_image()] * len(outputs) + return StageResult(outputs, metadata=metadata) def outputs( self, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 571e58ad4..891e47763 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -18,7 +18,7 @@ from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -115,7 +115,7 @@ def run( if params.is_lpw(): logger.debug("using LPW pipeline for txt2img") rng = torch.manual_seed(params.seed) - result = pipe.text2img( + output = pipe.text2img( prompt, height=latent_size.height, width=latent_size.width, @@ -141,7 +141,7 @@ def run( pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) - result = pipe( + output = pipe( prompt, height=latent_size.height, width=latent_size.width, @@ -155,10 +155,14 @@ def run( callback=callback, ) - outputs = sources.as_image() - outputs.extend(result.images) - logger.debug("produced %s outputs", len(outputs)) - return StageResult(images=outputs) + result = StageResult(source=sources) + for image in output.images: + result.push_image( + image, ImageMetadata(params, size, inversions=inversions, loras=loras) + ) + + logger.debug("produced %s outputs", len(result)) + return result def steps( self, diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index b6aa62cde..e25fa5907 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -7,9 +7,9 @@ from ..params import ImageParams, StageParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -25,6 +25,7 @@ def run( *, source_urls: List[str], stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("loading image from URL source") @@ -34,7 +35,7 @@ def run( "source images were passed to a source stage, new images will be appended" ) - outputs = sources.as_image() + outputs = sources.as_images() for url in source_urls: response = requests.get(url) output = Image.open(BytesIO(response.content)) @@ -42,7 +43,8 @@ def run( logger.info("final output image size: %sx%s", output.width, output.height) outputs.append(output) - return StageResult(images=outputs) + metadata = [ImageMetadata.unknown_image()] * len(outputs) + return StageResult(images=outputs, metadata=metadata) def outputs( self, diff --git a/api/onnx_web/chain/stages.py b/api/onnx_web/chain/stages.py index 0b3e6359a..bb6442337 100644 --- 
a/api/onnx_web/chain/stages.py +++ b/api/onnx_web/chain/stages.py @@ -9,6 +9,9 @@ from .blend_mask import BlendMaskStage from .correct_codeformer import CorrectCodeformerStage from .correct_gfpgan import CorrectGFPGANStage +from .edit_metadata import EditMetadataStage +from .edit_safety import EditSafetyStage +from .edit_text import EditTextStage from .persist_disk import PersistDiskStage from .persist_s3 import PersistS3Stage from .reduce_crop import ReduceCropStage @@ -17,6 +20,7 @@ from .source_s3 import SourceS3Stage from .source_txt2img import SourceTxt2ImgStage from .source_url import SourceURLStage +from .text_prompt import TextPromptStage from .upscale_bsrgan import UpscaleBSRGANStage from .upscale_highres import UpscaleHighresStage from .upscale_outpaint import UpscaleOutpaintStage @@ -38,6 +42,9 @@ "blend-mask": BlendMaskStage, "correct-codeformer": CorrectCodeformerStage, "correct-gfpgan": CorrectGFPGANStage, + "edit-metadata": EditMetadataStage, + "edit-safety": EditSafetyStage, + "edit-text": EditTextStage, "persist-disk": PersistDiskStage, "persist-s3": PersistS3Stage, "reduce-crop": ReduceCropStage, @@ -46,6 +53,7 @@ "source-s3": SourceS3Stage, "source-txt2img": SourceTxt2ImgStage, "source-url": SourceURLStage, + "text-prompt": TextPromptStage, "upscale-bsrgan": UpscaleBSRGANStage, "upscale-highres": UpscaleHighresStage, "upscale-outpaint": UpscaleOutpaintStage, diff --git a/api/onnx_web/chain/text_prompt.py b/api/onnx_web/chain/text_prompt.py new file mode 100644 index 000000000..fb0bd523f --- /dev/null +++ b/api/onnx_web/chain/text_prompt.py @@ -0,0 +1,95 @@ +from logging import getLogger +from random import randint +from re import match, sub +from typing import Optional + +from transformers import pipeline + +from ..diffusers.utils import split_prompt +from ..params import ImageParams, SizeChart, StageParams +from ..server import ServerContext +from ..worker import ProgressCallback, WorkerContext +from .base import BaseStage +from .result import StageResult + +logger = getLogger(__name__) + + +LENGTH_MARGIN = 15 +RETRY_LIMIT = 5 + + +class TextPromptStage(BaseStage): + max_tile = SizeChart.max + + def run( + self, + worker: WorkerContext, + server: ServerContext, + stage: StageParams, + params: ImageParams, + sources: StageResult, + *, + callback: Optional[ProgressCallback] = None, + prompt_filter: str, + remove_tokens: Optional[str] = None, + add_suffix: Optional[str] = None, + min_length: int = 80, + **kwargs, + ) -> StageResult: + device = worker.device.torch_str() + text_pipe = pipeline( + "text-generation", + model=prompt_filter, + device=device, + framework="pt", + ) + + prompt_parts = split_prompt(params.prompt) + prompt_results = [] + for prompt in prompt_parts: + retries = 0 + while len(prompt) < min_length and retries < RETRY_LIMIT: + max_length = len(prompt) + randint( + min_length - LENGTH_MARGIN, min_length + LENGTH_MARGIN + ) + logger.debug( + "extending input prompt to max length of %d from %s: %s", + max_length, + len(prompt), + prompt, + ) + + result = text_pipe( + prompt, max_length=max_length, num_return_sequences=1 + ) + prompt = result[0]["generated_text"].strip() + + if remove_tokens: + logger.debug( + "removing excluded tokens from prompt: %s", remove_tokens + ) + + remove_limit = 3 + while remove_limit > 0 and match(remove_tokens, prompt): + prompt = sub(remove_tokens, "", prompt) + remove_limit -= 1 + + if retries >= RETRY_LIMIT: + logger.warning( + "failed to extend input prompt to min length of %d, ended up with %d: %s", + min_length, + 
len(prompt), + prompt, + ) + + if add_suffix: + prompt = f"{prompt}, {add_suffix}" + logger.trace("adding suffix to prompt: %s", prompt) + + prompt_results.append(prompt) + + complete_prompt = " || ".join(prompt_results) + logger.debug("replacing input prompt: %s -> %s", params.prompt, complete_prompt) + params.prompt = complete_prompt + return sources diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index e8e1baff2..de66f80a5 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -9,7 +9,7 @@ from ..image.noise_source import noise_source_histogram from ..params import Size, TileOrder -from .result import StageResult +from .result import ImageMetadata, StageResult # from skimage.exposure import match_histograms @@ -26,7 +26,11 @@ class TileCallback(Protocol): """ def __call__( - self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int] + self, + sources: List[Image.Image], + mask: Image.Image, + dims: Tuple[int, int, int], + progress: Tuple[int, int], ) -> StageResult: """ Run this stage against a single tile. @@ -168,6 +172,7 @@ def blend_tiles( ) channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles]) + channels = min(channels, 3) # remove alpha channels for now scaled_size = (height * scale, width * scale, channels) count = np.zeros(scaled_size) @@ -177,6 +182,11 @@ def blend_tiles( equalized = np.array(tile_image).astype(np.float32) mask = np.ones_like(equalized[:, :, 0]) + # match channels by removing the alpha channel, if present + if equalized.shape[-1] > value.shape[-1]: + logger.debug("removing alpha channel from tile") + equalized = equalized[:, :, :channels] + if adj_tile < tile: # sort gradient points p1 = (adj_tile * scale) - 1 @@ -229,7 +239,7 @@ def blend_tiles( ] += equalized[ margin_top : equalized.shape[0] + margin_bottom, margin_left : equalized.shape[1] + margin_right, - :, + :channels, ] count[ writable_top:writable_bottom, writable_left:writable_right, : @@ -256,8 +266,8 @@ def process_tile_stack( tile_generator: TileGenerator, overlap: float = 0.5, **kwargs, -) -> List[Image.Image]: - sources = stack.as_image() +) -> StageResult: + sources = stack.as_images() width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None) mask = kwargs.get("mask", kwargs.get("stage_mask", None)) @@ -266,14 +276,15 @@ def process_tile_stack( if not mask: tile_mask = None + metadata: List[ImageMetadata] = stack.metadata tiles: List[Tuple[int, int, Image.Image]] = [] tile_coords = tile_generator(width, height, tile, overlap) - single_tile = len(tile_coords) == 1 + + total_tiles = len(tile_coords) + single_tile = total_tiles == 1 for counter, (left, top) in enumerate(tile_coords): - logger.info( - "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top - ) + logger.info("processing tile %s of %s, %sx%s", counter, total_tiles, left, top) right = left + tile bottom = top + tile @@ -308,7 +319,7 @@ def process_tile_stack( bottom_margin, ) tile_stack = add_margin( - stack.as_image(), + stack.as_images(), left, top, right, @@ -333,7 +344,6 @@ def process_tile_stack( ) tile_mask = Image.new("L", (tile, tile), color=0) tile_mask.paste(base_mask, (left_margin, top_margin)) - else: logger.debug("tiling normally") tile_stack = get_result_tile(stack, (left, top), Size(tile, tile)) @@ -341,12 +351,18 @@ def process_tile_stack( tile_mask = mask.crop((left, top, right, bottom)) for image_filter in filters: - tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) + 
tile_stack = image_filter( + tile_stack, tile_mask, (left, top, tile), (counter, total_tiles) + ) + # TODO: this should be inverted to extract them from the result if isinstance(tile_stack, list): - tile_stack = StageResult.from_images(tile_stack) + tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata) - tiles.append((left, top, tile_stack.as_image())) + # metadata gets replaced rather than combined, since it should be the same for each tile + # this will need to change if tiles can have individual metadata + metadata = tile_stack.metadata + tiles.append((left, top, tile_stack.as_images())) lefts, tops, stacks = list(zip(*tiles)) coords = list(zip(lefts, tops)) @@ -358,7 +374,7 @@ def process_tile_stack( stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles] result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap)) - return result + return StageResult(images=result, metadata=metadata) def process_tile_order( @@ -368,7 +384,7 @@ def process_tile_order( scale: int, filters: List[TileCallback], **kwargs, -) -> List[Image.Image]: +) -> StageResult: if order == TileOrder.grid: logger.debug("using grid tile order with tile size: %s", tile) return process_tile_stack( @@ -516,7 +532,7 @@ def get_result_tile( top, left = origin return [ layer.crop((top, left, top + tile.height, left + tile.width)) - for layer in result.as_image() + for layer in result.as_images() ] diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 08c077595..4b4d5752a 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -8,6 +8,7 @@ from ..models.onnx import OnnxModel from ..params import ( DeviceParams, + HighresParams, ImageParams, Size, SizeChart, @@ -16,7 +17,7 @@ ) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -47,7 +48,7 @@ def load( pipe = OnnxModel( server, model_path, - provider=device.ort_provider(), + provider=device.ort_provider("bsrgan"), sess_options=device.sess_options(), ) @@ -65,7 +66,9 @@ def run( sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: upscale = upscale.with_args(**kwargs) @@ -79,7 +82,7 @@ def run( bsrgan = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) @@ -105,7 +108,7 @@ def run( logger.debug("output image shape: %s", output.shape) outputs.append(output) - return StageResult(arrays=outputs) + return StageResult(arrays=outputs, metadata=sources.metadata) def steps( self, diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 32f891a69..ed86b97ee 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -42,7 +42,7 @@ def run( source, callback=callback, ) - for source in sources.as_image() + for source in sources.as_images() ] - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 
464f5920d..5d260db0c 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -19,7 +19,7 @@ from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult +from .result import ImageMetadata, StageResult logger = getLogger(__name__) @@ -61,8 +61,9 @@ def run( loras=loras, ) + sizes = [] outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): if is_debug(): save_image(server, "tile-source.png", source) save_image(server, "tile-mask.png", tile_mask) @@ -121,6 +122,14 @@ def run( callback=callback, ) + sizes.extend([size] * len(result.images)) outputs.extend(result.images) - return StageResult(images=outputs) + metadata = [ + ImageMetadata( + params, size, inversions=inversions, loras=loras, ancestors=[source] + ) + for size, source in zip(sizes, sources.metadata) + ] + + return StageResult(images=outputs, metadata=metadata) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 33338b437..ff9cb98d4 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -5,10 +5,16 @@ from PIL import Image from ..onnx import OnnxRRDBNet -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -48,10 +54,10 @@ def __init__( self.model = model self.device = device - model_file = "%s.%s" % (params.upscale_model, params.format) + model_file = f"{params.upscale_model}.onnx" model_path = path.join(server.model_path, model_file) - cache_key = (model_path, params.format) + cache_key = (model_path, params.scale) cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key) if cache_pipe is not None: logger.info("reusing existing Real ESRGAN pipeline") @@ -64,7 +70,7 @@ def __init__( model = OnnxRRDBNet( server, model_file, - provider=device.ort_provider(), + provider=device.ort_provider("esrgan"), sess_options=device.sess_options(), ) @@ -102,7 +108,9 @@ def run( sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) @@ -112,9 +120,12 @@ def run( ) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): output, _ = upsampler.enhance(source, outscale=upscale.outscale) logger.info("final output image size: %s", output.shape) outputs.append(output) - return StageResult(arrays=outputs) + for metadata in sources.metadata: + metadata.upscale = upscale + + return StageResult(arrays=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 7e939bd42..686f9c717 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -3,9 +3,9 @@ from PIL import Image -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import ImageParams, SizeChart, StageParams, UpscaleParams from ..server import ServerContext -from ..worker import WorkerContext +from ..worker import 
ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -13,6 +13,8 @@ class UpscaleSimpleStage(BaseStage): + max_tile = SizeChart.max + def run( self, _worker: WorkerContext, @@ -24,6 +26,7 @@ def run( method: str, upscale: UpscaleParams, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: if upscale.scale <= 1: @@ -33,7 +36,7 @@ def run( return sources outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): scaled_size = (source.width * upscale.scale, source.height * upscale.scale) if method == "bilinear": @@ -49,4 +52,4 @@ def run( else: logger.warning("unknown upscaling method: %s", method) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 6c8a300e8..d169fc3fe 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -59,7 +59,7 @@ def run( pipeline.unet.set_prompts(prompt_embeds) outputs = [] - for source in sources.as_image(): + for source in sources.as_images(): result = pipeline( prompt, source, @@ -73,4 +73,4 @@ def run( ) outputs.extend(result.images) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index ef7d421f5..2f6831aed 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -6,10 +6,17 @@ from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, SizeChart, StageParams, UpscaleParams +from ..params import ( + DeviceParams, + HighresParams, + ImageParams, + SizeChart, + StageParams, + UpscaleParams, +) from ..server import ModelTypes, ServerContext from ..utils import run_gc -from ..worker import WorkerContext +from ..worker import ProgressCallback, WorkerContext from .base import BaseStage from .result import StageResult @@ -40,7 +47,7 @@ def load( pipe = OnnxModel( server, model_path, - provider=device.ort_provider(), + provider=device.ort_provider("swinir"), sess_options=device.sess_options(), ) @@ -58,21 +65,27 @@ def run( sources: StageResult, *, upscale: UpscaleParams, + highres: Optional[HighresParams] = None, stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: upscale = upscale.with_args(**kwargs) if upscale.upscale_model is None: - logger.warning("no correction model given, skipping") + logger.warning("no upscale model given, skipping") return sources - logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) + logger.info( + "upscaling %sx with SwinIR model: %s", + upscale.outscale, + upscale.upscale_model, + ) device = worker.get_device() swinir = self.load(server, stage, upscale, device) outputs = [] - for source in sources.as_numpy(): + for source in sources.as_arrays(): # TODO: add support for grayscale (1-channel) images image = source / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) @@ -98,4 +111,4 @@ def run( logger.info("output image size: %s", output.shape) outputs.append(output) - return StageResult(images=outputs) + return StageResult(images=outputs, metadata=sources.metadata) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py 
index b7515b987..37d4bcf47 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -61,7 +61,7 @@ "archive": convert_extract_archive, "img2img": convert_diffusion_diffusers_optimum, "img2img-sdxl": convert_diffusion_diffusers_xl, - "inpaint": convert_diffusion_diffusers_legacy, + "inpaint": convert_diffusion_diffusers_optimum, "txt2img": convert_diffusion_diffusers_optimum, "txt2img-legacy": convert_diffusion_diffusers_legacy, "txt2img-sdxl": convert_diffusion_diffusers_xl, diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py index 0362494ae..5d72baa87 100644 --- a/api/onnx_web/convert/diffusion/checkpoint.py +++ b/api/onnx_web/convert/diffusion/checkpoint.py @@ -612,12 +612,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.weight" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") - new_checkpoint[ - f"down_blocks.{block_id}.downsamplers.0.conv.bias" - ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = ( + unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = ( + unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + ) paths = renew_resnet_paths(resnets) meta_path = { @@ -705,12 +705,12 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False index = list(output_block_list.values()).index( ["conv.bias", "conv.weight"] ) - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.weight" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] - new_checkpoint[ - f"up_blocks.{block_id}.upsamplers.0.conv.bias" - ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = ( + unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = ( + unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] + ) # Clear attentions as they have been attributed above. 
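# These assignments remap LDM checkpoint keys to their diffusers equivalents, e.g.
# output_blocks.{i}.{index}.conv.weight is stored as up_blocks.{block_id}.upsamplers.0.conv.weight
# (and likewise for the matching .bias key); the hunks above only re-wrap the right-hand
# side in parentheses to match newer black formatting, the mapping itself is unchanged.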
if len(attentions) == 2: @@ -818,12 +818,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): ] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") - new_checkpoint[ - f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" - ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = ( + vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = ( + vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") + ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} @@ -871,12 +871,12 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] - new_checkpoint[ - f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" - ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = ( + vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + ) + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = ( + vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] + ) paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} @@ -983,9 +983,9 @@ def convert_ldm_clip_checkpoint(checkpoint): "text_model." + key[len("cond_stage_model.transformer.") :] ] = checkpoint[key] else: - text_model_dict[ - key[len("cond_stage_model.transformer.") :] - ] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer.") :]] = ( + checkpoint[key] + ) text_model.load_state_dict(text_model_dict, strict=False) @@ -1109,9 +1109,9 @@ def convert_open_clip_checkpoint(checkpoint): else: logger.debug("no projection shape found, setting to 1024") d_model = 1024 - text_model_dict[ - "text_model.embeddings.position_ids" - ] = text_model.text_model.embeddings.get_buffer("position_ids") + text_model_dict["text_model.embeddings.position_ids"] = ( + text_model.text_model.embeddings.get_buffer("position_ids") + ) for key in keys: if ( diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 969d0ea2d..d5de94ea9 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -20,6 +20,7 @@ AutoencoderKL, OnnxRuntimeModel, OnnxStableDiffusionPipeline, + StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, StableDiffusionUpscalePipeline, @@ -48,13 +49,14 @@ remove_prefix, ) from .checkpoint import convert_extract_checkpoint +from .patches import patch_optimum logger = getLogger(__name__) CONVERT_PIPELINES = { "controlnet": OnnxStableDiffusionControlNetPipeline, "img2img": StableDiffusionPipeline, - "inpaint": StableDiffusionPipeline, + "inpaint": StableDiffusionInpaintPipeline, "lpw": StableDiffusionPipeline, "panorama": StableDiffusionPipeline, "pix2pix": StableDiffusionInstructPix2PixPipeline, @@ -841,6 +843,8 @@ def convert_diffusion_diffusers_optimum( del pipeline run_gc() + # patch Optimum for conversion and convert to ONNX + 
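# patch_optimum() is defined in convert/diffusion/patches.py (added below): it wraps
# optimum's model_patcher.override_arguments so that output_hidden_states is forced to
# True whenever the exported model accepts that argument, then delegates to the original
# optimum implementation before main_export runs.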
patch_optimum() main_export( temp_path, output=dest_path, @@ -850,6 +854,8 @@ def convert_diffusion_diffusers_optimum( "torch-fp16" ), # optimum's fp16 mode only works on CUDA or ROCm framework="pt", + library_name="diffusers", + do_validation=conversion.has_feature("optimum-validation"), ) if "hash" in model: diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 59708fa73..5d88e1a59 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -94,6 +94,7 @@ def convert_diffusion_diffusers_xl( "torch-fp16" ), # optimum's fp16 mode only works on CUDA or ROCm framework="pt", + do_validation=conversion.has_feature("optimum-validation"), ) if "hash" in model: diff --git a/api/onnx_web/convert/diffusion/patches.py b/api/onnx_web/convert/diffusion/patches.py new file mode 100644 index 000000000..2a420416c --- /dev/null +++ b/api/onnx_web/convert/diffusion/patches.py @@ -0,0 +1,40 @@ +""" +Patches for optimum's internal conversion process. +""" + +from logging import getLogger + +from optimum.exporters.onnx import model_patcher + +logger = getLogger(__name__) + +original_override_arguments = model_patcher.override_arguments + + +def override_override_arguments(args, kwargs, signature, model_kwargs=None): + """ + Override the arguments of the `override_arguments` function. + """ + logger.debug( + "overriding arguments for `override_arguments`: %s, %s, %s", + args, + kwargs, + signature, + ) + + if "output_hidden_states" in signature.parameters: + logger.debug("enabling hidden states for model") + parameter_names = list(signature.parameters.keys()) + hidden_states_index = parameter_names.index("output_hidden_states") + + # convert the arguments to a list for modification + arg_list = list(args) + arg_list[hidden_states_index] = True + args = tuple(arg_list) + + return original_override_arguments(args, kwargs, signature, model_kwargs) + + +def patch_optimum(): + logger.info("installing patches for optimum's internal conversion process") + model_patcher.override_arguments = override_override_arguments diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 6b4dca351..69c7b37bc 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -107,7 +107,7 @@ def download_progress(source: str, dest: str): stream=True, allow_redirects=True, headers={ - "User-Agent": "onnx-web-api", + "User-Agent": "onnx-web-api", # TODO: add version }, ) if req.status_code != 200: @@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]: logger.debug("loading tensor with Torch: %s", name) checkpoint = torch.load(name, map_location=map_location) except Exception: - logger.exception( - "error loading with Torch JIT, trying with Torch JIT: %s", name - ) + logger.exception("error loading with Torch, trying with Torch JIT: %s", name) checkpoint = torch.jit.load(name) return checkpoint diff --git a/api/onnx_web/device.py b/api/onnx_web/device.py new file mode 100644 index 000000000..32880e05b --- /dev/null +++ b/api/onnx_web/device.py @@ -0,0 +1,32 @@ +import gc +import threading +from logging import getLogger +from typing import List, Optional + +import torch + +from .params import DeviceParams + +logger = getLogger(__name__) + + +def run_gc(devices: Optional[List[DeviceParams]] = None): + logger.debug( + "running garbage collection with %s active threads", threading.active_count() + ) + gc.collect() + + if torch.cuda.is_available() and devices 
is not None: + for device in [d for d in devices if d.device.startswith("cuda")]: + logger.debug("running Torch garbage collection for device: %s", device) + with torch.cuda.device(device.torch_str()): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + mem_free, mem_total = torch.cuda.mem_get_info() + mem_pct = (1 - (mem_free / mem_total)) * 100 + logger.debug( + "CUDA VRAM usage: %s of %s (%.2f%%)", + (mem_total - mem_free), + mem_total, + mem_pct, + ) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 9e8a4e50b..350e033b2 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -13,11 +13,13 @@ from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors from ..convert.diffusion.textual_inversion import blend_textual_inversions from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline -from ..diffusers.utils import expand_prompt +from ..diffusers.utils import expand_prompt as encode_prompt_onnx_legacy from ..params import DeviceParams, ImageParams +from ..prompt.compel import encode_prompt_compel, encode_prompt_compel_sdxl from ..server import ModelTypes, ServerContext from ..torch_before_ort import InferenceSession from ..utils import run_gc +from .patches.scheduler import SchedulerPatch from .patches.unet import UNetWrapper from .patches.vae import VAEWrapper from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline @@ -53,6 +55,7 @@ available_pipelines = { "controlnet": OnnxStableDiffusionControlNetPipeline, + # "highres": OnnxStableDiffusionHighresPipeline, "img2img": OnnxStableDiffusionImg2ImgPipeline, "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline, "inpaint": OnnxStableDiffusionInpaintPipeline, @@ -158,7 +161,7 @@ def load_pipeline( logger.debug("loading new diffusion scheduler") scheduler = scheduler_type.from_pretrained( model, - provider=device.ort_provider(), + provider=device.ort_provider("scheduler"), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, @@ -179,7 +182,7 @@ def load_pipeline( logger.debug("loading new diffusion pipeline from %s", model) scheduler = scheduler_type.from_pretrained( model, - provider=device.ort_provider(), + provider=device.ort_provider("scheduler"), sess_options=device.sess_options(), subfolder="scheduler", torch_dtype=torch_dtype, @@ -231,6 +234,9 @@ def load_pipeline( ) else: if params.is_control(): + if "controlnet" not in components or components["controlnet"] is None: + raise ValueError("ControlNet is required for control pipelines") + logger.debug( "assembling SD pipeline for %s with ControlNet", pipeline_class.__name__, @@ -320,7 +326,7 @@ def load_controlnet(server: ServerContext, device: DeviceParams, params: ImagePa components["controlnet"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( cnet_path, - provider=device.ort_provider(), + provider=device.ort_provider("controlnet"), sess_options=device.sess_options(), ) ) @@ -448,7 +454,7 @@ def load_text_encoders( # session for te1 text_encoder_session = InferenceSession( text_encoder.SerializeToString(), - providers=[device.ort_provider("text-encoder")], + providers=[device.ort_provider("text-encoder", "sdxl")], sess_options=text_encoder_opts, ) text_encoder_session._model_path = path.join(model, "text_encoder") @@ -457,7 +463,7 @@ def load_text_encoders( # session for te2 text_encoder_2_session = InferenceSession( text_encoder_2.SerializeToString(), - providers=[device.ort_provider("text-encoder")], + 
providers=[device.ort_provider("text-encoder", "sdxl")], sess_options=text_encoder_2_opts, ) text_encoder_2_session._model_path = path.join(model, "text_encoder_2") @@ -511,7 +517,7 @@ def load_unet( if params.is_xl(): unet_session = InferenceSession( unet_model.SerializeToString(), - providers=[device.ort_provider("unet")], + providers=[device.ort_provider("unet", "sdxl")], sess_options=unet_opts, ) unet_session._model_path = path.join(model, "unet") @@ -551,7 +557,7 @@ def load_vae( logger.debug("loading VAE decoder from %s", vae_decoder) components["vae_decoder_session"] = OnnxRuntimeModel.load_model( vae_decoder, - provider=device.ort_provider("vae"), + provider=device.ort_provider("vae", "sdxl"), sess_options=device.sess_options(), ) components["vae_decoder_session"]._model_path = vae_decoder @@ -559,7 +565,7 @@ def load_vae( logger.debug("loading VAE encoder from %s", vae_encoder) components["vae_encoder_session"] = OnnxRuntimeModel.load_model( vae_encoder, - provider=device.ort_provider("vae"), + provider=device.ort_provider("vae", "sdxl"), sess_options=device.sess_options(), ) components["vae_encoder_session"]._model_path = vae_encoder @@ -637,6 +643,19 @@ def optimize_pipeline( logger.warning("error while enabling memory efficient attention: %s", e) +IMAGE_PIPELINES = [ + OnnxStableDiffusionControlNetPipeline, + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInstructPix2PixPipeline, + OnnxStableDiffusionLongPromptWeightingPipeline, + OnnxStableDiffusionPanoramaPipeline, + OnnxStableDiffusionUpscalePipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPanoramaPipeline, +] + + def patch_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, @@ -645,12 +664,32 @@ def patch_pipeline( ) -> None: logger.debug("patching SD pipeline") - if not params.is_lpw() and not params.is_xl(): - pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) + if server.has_feature("compel-prompts"): + logger.debug("patching prompt encoder with Compel") + if params.is_xl(): + pipe._encode_prompt = encode_prompt_compel_sdxl.__get__(pipe, pipeline) + else: + pipe._encode_prompt = encode_prompt_compel.__get__(pipe, pipeline) + else: + if not params.is_lpw() and not params.is_xl(): + logger.debug("patching prompt encoder with ONNX legacy method") + pipe._encode_prompt = encode_prompt_onnx_legacy.__get__(pipe, pipeline) + else: + logger.warning("no prompt encoder patch available") + + # the pipeline requested in params may not be the one currently being used, especially during the later img2img + # stages of a highres pipeline, so we need to check the pipeline type + is_text_pipeline = type(pipe) not in IMAGE_PIPELINES + logger.debug( + "patching pipeline scheduler for %s pipeline", + "txt2img" if is_text_pipeline else "img2img", + ) + original_scheduler = pipe.scheduler + pipe.scheduler = SchedulerPatch(server, original_scheduler, is_text_pipeline) + logger.debug("patching pipeline UNet") original_unet = pipe.unet pipe.unet = UNetWrapper(server, original_unet, params.is_xl()) - logger.debug("patched UNet with wrapper") if hasattr(pipe, "vae_decoder"): original_decoder = pipe.vae_decoder diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py new file mode 100644 index 000000000..dabe7a66d --- /dev/null +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -0,0 +1,158 @@ +from logging import getLogger +from typing import Any, Literal + +import numpy as np +import torch +from 
diffusers.schedulers.scheduling_utils import SchedulerOutput +from torch import FloatTensor, Tensor + +from ...server.context import ServerContext + +logger = getLogger(__name__) + + +class SchedulerPatch: + server: ServerContext + text_pipeline: bool + wrapped: Any + + def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool): + self.server = server + self.wrapped = scheduler + self.text_pipeline = text_pipeline + + def __getattr__(self, attr): + return getattr(self.wrapped, attr) + + def step( + self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor + ) -> SchedulerOutput: + result = self.wrapped.step(model_output, timestep, sample) + + if self.text_pipeline and self.server.has_feature("mirror-latents"): + logger.info("using experimental latent mirroring") + + if self.server.has_feature("mirror-latents-vertical"): + axis_of_symmetry = 2 + expand_dims = (0, 1, 3) + else: + axis_of_symmetry = 3 + expand_dims = (0, 1, 2) + + white_point = 2 + black_point = result.prev_sample.shape[axis_of_symmetry] // 4 + center_line = result.prev_sample.shape[axis_of_symmetry] // 2 + + gradient = linear_gradient( + white_point, black_point, center_line, expand_dims + ) + latents = result.prev_sample.numpy() + + # gradiated_latents = np.multiply(latents, gradient) + inverse_gradiated_latents = np.multiply( + np.flip(latents, axis=axis_of_symmetry), gradient + ) + latents += inverse_gradiated_latents + + mask = np.ones_like(latents).astype(np.float32) + # gradiated_mask = np.multiply(mask, gradient) + # flipping the mask would do nothing, we need to flip the gradient for this one + inverse_gradiated_mask = np.multiply( + mask, np.flip(gradient, axis=axis_of_symmetry) + ) + mask += inverse_gradiated_mask + + latents = np.where(mask > 0, latents / mask, latents) + + return SchedulerOutput( + prev_sample=torch.from_numpy(latents), + ) + else: + return result + + +def linear_gradient( + white_point: int, + black_point: int, + center_line: int, + expand_dims: tuple[int, ...] 
= (0, 1, 2), +) -> np.ndarray: + gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32) + gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1) + gradient = np.pad(gradient, (0, center_line - black_point), mode="constant") + gradient = np.reshape([gradient, np.flip(gradient)], -1) + return np.expand_dims(gradient, expand_dims) + + +def mirror_latents( + latents: np.ndarray, + gradient: np.ndarray, + center_line: int, + direction: Literal["horizontal", "vertical"], +) -> np.ndarray: + if direction == "horizontal": + pad_left = max(0, -center_line) + pad_right = max(0, 2 * center_line - latents.shape[3]) + + # create the symmetrical copies + padded_array = np.pad( + latents, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant" + ) + flipped_array = np.flip(padded_array, axis=3) + + # apply the gradient to both copies + padded_gradiated = np.multiply(padded_array, gradient) + flipped_gradiated = np.multiply(flipped_array, gradient) + + # produce masks + mask = np.ones_like(latents).astype(np.float32) + padded_mask = np.pad( + mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant" + ) + flipped_mask = np.flip(padded_mask, axis=3) + + padded_mask += np.multiply(padded_mask, gradient) + padded_mask += np.multiply(flipped_mask, gradient) + + # combine the two copies + result = padded_array + padded_gradiated + flipped_gradiated + result = np.where(padded_mask > 0, result / padded_mask, result) + return result[:, :, :, pad_left : pad_left + latents.shape[3]] + elif direction == "vertical": + pad_top = max(0, -center_line) + pad_bottom = max(0, 2 * center_line - latents.shape[2]) + + # create the symmetrical copies + padded_array = np.pad( + latents, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant" + ) + flipped_array = np.flip(padded_array, axis=2) + + # apply the gradient to both copies + padded_gradiated = np.multiply( + padded_array.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + flipped_gradiated = np.multiply( + flipped_array.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + + # produce masks + mask = np.ones_like(latents).astype(np.float32) + padded_mask = np.pad( + mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant" + ) + flipped_mask = np.flip(padded_mask, axis=2) + + padded_mask += np.multiply( + padded_mask.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + padded_mask += np.multiply( + flipped_mask.transpose(0, 1, 3, 2), gradient + ).transpose(0, 1, 3, 2) + + # combine the two copies + result = padded_array + padded_gradiated + flipped_gradiated + result = np.where(padded_mask > 0, result / padded_mask, result) + return flipped_array[:, :, pad_top : pad_top + latents.shape[2], :] + else: + raise ValueError("Invalid direction. 
Must be 'horizontal' or 'vertical'.") diff --git a/api/onnx_web/diffusers/pipelines/base.py b/api/onnx_web/diffusers/pipelines/base.py new file mode 100644 index 000000000..b4275637c --- /dev/null +++ b/api/onnx_web/diffusers/pipelines/base.py @@ -0,0 +1,268 @@ +from typing import List, Optional, Union + +import numpy as np +from diffusers.configuration_utils import FrozenDict +from diffusers.pipelines.onnx_utils import OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, logging +from transformers import CLIPImageProcessor, CLIPTokenizer + +logger = logging.get_logger(__name__) + + +class OnnxStableDiffusionBasePipeline(DiffusionPipeline): + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPImageProcessor + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate( + "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: Optional[int], + do_classifier_free_guidance: bool, + negative_prompt: Optional[str], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`np.ndarray`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`np.ndarray`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="max_length", return_tensors="np" + ).input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids.astype(np.int32) + )[0] + + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = self.text_encoder( + input_ids=uncond_input.input_ids.astype(np.int32) + )[0] + + if do_classifier_free_guidance: + negative_prompt_embeds = np.repeat( + negative_prompt_embeds, num_images_per_prompt, axis=0 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. 
Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) diff --git a/api/onnx_web/diffusers/pipelines/highres.py b/api/onnx_web/diffusers/pipelines/highres.py new file mode 100644 index 000000000..969792915 --- /dev/null +++ b/api/onnx_web/diffusers/pipelines/highres.py @@ -0,0 +1,417 @@ +import inspect +from logging import getLogger +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from transformers import CLIPImageProcessor, CLIPTokenizer + +from ...constants import LATENT_CHANNELS, LATENT_FACTOR, ONNX_MODEL +from ...convert.utils import onnx_export +from .base import OnnxStableDiffusionBasePipeline + +logger = getLogger(__name__) + + +class OnnxStableDiffusionHighresPipeline(OnnxStableDiffusionBasePipeline): + upscaler: OnnxRuntimeModel + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + upscaler: OnnxRuntimeModel = None, + ): + super().__init__( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, + ) + + self.upscaler = upscaler + + @torch.no_grad() + def text2img( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + num_upscale_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + 
callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + # check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_pooler_out = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # get the initial random noise unless the user supplied it + latents_dtype = prompt_embeds.dtype + latents_shape = ( + batch_size * num_images_per_prompt, + LATENT_CHANNELS, + height // LATENT_FACTOR, + width // LATENT_FACTOR, + ) + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif latents.shape != latents_shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}" + ) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
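# (with the default eta = 0.0 used here, DDIM sampling is deterministic for a fixed
# seed; eta = 1.0 corresponds to the stochastic, DDPM-like behavior)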
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + timestep_dtype = next( + ( + input.type + for input in self.unet.model.get_inputs() + if input.name == "timestep" + ), + "tensor(float)", + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + np.concatenate([latents] * 2) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + torch.from_numpy(latent_model_input), t + ) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), + t, + torch.from_numpy(latents), + **extra_step_kwargs, + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if self.upscaler is not None: + # 5. set upscale timesteps + self.scheduler.set_timesteps(num_upscale_steps) + timesteps = self.scheduler.timesteps + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = np.concatenate([latents] * batch_multiplier) + + # 5. Add noise to image (set to be 0): + # (see below notes from the author): + # "the This step theoretically can make the model work better on out-of-distribution inputs, but mostly + # just seems to make it match the input less, so it's turned off by default." 
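# With noise_level fixed at 0.0 below, inv_noise_level = (0**2 + 1) ** -0.5 is exactly
# 1.0, so the conditioning image passed to the latent upscaler is simply the latents
# (duplicated when classifier-free guidance is active) upsampled 2x with
# nearest-neighbor interpolation, left unscaled.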
+ noise_level = np.array([0.0], dtype=np.float32) + noise_level = np.concatenate([noise_level] * image.shape[0]) + inv_noise_level = (noise_level**2 + 1) ** (-0.5) + + image_cond = ( + F.interpolate(torch.tensor(image), scale_factor=2, mode="nearest") + * inv_noise_level[:, None, None, None] + ) + image_cond = image_cond.numpy().astype(prompt_embeds.dtype) + + noise_level_embed = np.concatenate( + [ + np.ones( + (text_pooler_out.shape[0], 64), dtype=text_pooler_out.dtype + ), + np.zeros( + (text_pooler_out.shape[0], 64), dtype=text_pooler_out.dtype + ), + ], + axis=1, + ) + + # upscaling latents + latents_shape = ( + batch_size * num_images_per_prompt, + LATENT_CHANNELS, + height * 2 // LATENT_FACTOR, + width * 2 // LATENT_FACTOR, + ) + latents = generator.randn(*latents_shape).astype(latents_dtype) + + timestep_condition = np.concatenate( + [noise_level_embed, text_pooler_out], axis=1 + ) + + num_warmup_steps = 0 + + with self.progress_bar(total=num_upscale_steps) as progress_bar: + for i, t in enumerate(timesteps): + sigma = self.scheduler.sigmas[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + np.concatenate([latents] * 2) + if do_classifier_free_guidance + else latents + ) + scaled_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + scaled_model_input = np.concatenate( + [scaled_model_input, image_cond], axis=1 + ) + # preconditioning parameter based on Karras et al. (2022) (table 1) + timestep = np.log(sigma) * 0.25 + + noise_pred = self.upscaler( + sample=scaled_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_condition, + ).sample + + # in original repo, the output contains a variance channel that's not used + noise_pred = noise_pred[:, :-1] + + # apply preconditioning, based on table 1 in Karras et al. 
(2022) + inv_sigma = 1 / (sigma**2 + 1) + noise_pred = ( + inv_sigma * latent_model_input + + self.scheduler.scale_model_input(sigma, t) * noise_pred + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + noise_pred, t, torch.from_numpy(latents) + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + else: + logger.debug("skipping latent upscaler, no model provided") + + # decode image + latents = 1 / 0.18215 * latents + + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [ + self.vae_decoder(latent_sample=latents[i : i + 1])[0] + for i in range(latents.shape[0]) + ] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) + + +def export_unet(pipeline, output_path, unet_sample_size=1024): + device = torch.device("cpu") + dtype = torch.float32 + + num_tokens = pipeline.text_encoder.config.max_position_embeddings + text_hidden_size = pipeline.text_encoder.config.hidden_size + + unet_inputs = ["sample", "timestep", "encoder_hidden_states", "timestep_cond"] + unet_in_channels = pipeline.unet.config.in_channels + unet_path = output_path / "unet" / ONNX_MODEL + + logger.info("exporting UNet to %s", unet_path) + onnx_export( + pipeline.unet, + model_args=( + torch.randn( + 2, + unet_in_channels, + unet_sample_size // LATENT_FACTOR, + unet_sample_size // LATENT_FACTOR, + ).to(device=device, dtype=dtype), + torch.randn(2).to(device=device, dtype=dtype), + torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), + torch.randn(2, 64, 64, 2).to( + device=device, dtype=dtype + ), # TODO: not the right shape + ), + output_path=unet_path, + ordered_input_names=unet_inputs, + # has to be different from "sample" for correct tracing + output_names=["out_sample"], + dynamic_axes={ + "sample": {0: "batch"}, # , 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, + opset=14, + half=False, + external_data=True, + v2=False, + ) + + +def load_and_export(output, source="stabilityai/sd-x2-latent-upscaler"): + from pathlib import Path + + from diffusers import StableDiffusionLatentUpscalePipeline + + upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained( + source, torch_dtype=torch.float32 + ) + export_unet(upscaler, Path(output)) + + +def load_and_run( + prompt, + output, + source="stabilityai/sd-x2-latent-upscaler", + checkpoint="../models/stable-diffusion-onnx-v1-5", +): + from diffusers import ( + EulerAncestralDiscreteScheduler, + StableDiffusionLatentUpscalePipeline, + ) + + upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(source) + highres = OnnxStableDiffusionHighresPipeline.from_pretrained(checkpoint) + scheduler = EulerAncestralDiscreteScheduler.from_pretrained( + f"{checkpoint}/scheduler" + ) 
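+ # note that export_unet/load_and_export above convert only the upscaler UNet to ONNX;
+ # this runner keeps the torch UNet and adapts it back through the RetorchModel shim below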
+ + # combine them + highres.scheduler = scheduler + highres.upscaler = RetorchModel(upscaler.unet) + + # run + result = highres.text2img(prompt, num_inference_steps=25, num_upscale_steps=25) + image = result.images[0] + image.save(output) + + +class RetorchModel: + """ + Shim back from ONNX to PyTorch + """ + + def __init__(self, model) -> None: + self.model = model + + def __call__(self, **kwargs): + inputs = { + k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v + for k, v in kwargs.items() + } + outputs = self.model(**inputs) + return UNet2DConditionOutput(sample=outputs.sample.numpy()) diff --git a/api/onnx_web/diffusers/pipelines/lpw.py b/api/onnx_web/diffusers/pipelines/lpw.py index 8febc277d..f8b1a4cf2 100644 --- a/api/onnx_web/diffusers/pipelines/lpw.py +++ b/api/onnx_web/diffusers/pipelines/lpw.py @@ -465,6 +465,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) """ + if version.parse( version.parse(diffusers.__version__).base_version ) >= version.parse("0.9.0"): diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 317a8515b..fa7968668 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -13,18 +13,17 @@ # limitations under the License. import inspect +from logging import getLogger from math import ceil from typing import Callable, List, Optional, Tuple, Union import numpy as np import PIL import torch -from diffusers.configuration_utils import FrozenDict from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from diffusers.utils import PIL_INTERPOLATION, deprecate, logging +from diffusers.utils import PIL_INTERPOLATION, deprecate from transformers import CLIPImageProcessor, CLIPTokenizer from ...chain.tile import make_tile_mask @@ -37,8 +36,9 @@ repair_nan, resize_latent_shape, ) +from .base import OnnxStableDiffusionBasePipeline -logger = logging.get_logger(__name__) +logger = getLogger(__name__) # inpaint constants @@ -96,18 +96,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape): return mask, masked_image -class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): - vae_encoder: OnnxRuntimeModel - vae_decoder: OnnxRuntimeModel - text_encoder: OnnxRuntimeModel - tokenizer: CLIPTokenizer - unet: OnnxRuntimeModel - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] - safety_checker: OnnxRuntimeModel - feature_extractor: CLIPImageProcessor - - _optional_components = ["safety_checker", "feature_extractor"] - +class OnnxStableDiffusionPanoramaPipeline(OnnxStableDiffusionBasePipeline): def __init__( self, vae_encoder: OnnxRuntimeModel, @@ -122,65 +111,7 @@ def __init__( window: Optional[int] = None, stride: Optional[int] = None, ): - super().__init__() - - self.window = window or DEFAULT_WINDOW - self.stride = stride or DEFAULT_STRIDE - - if ( - hasattr(scheduler.config, "steps_offset") - and scheduler.config.steps_offset != 1 - ): - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. 
`steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate( - "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False - ) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is True - ): - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate( - "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False - ) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( + super().__init__( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, @@ -189,173 +120,11 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + requires_safety_checker=requires_safety_checker, ) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def _encode_prompt( - self, - prompt: Union[str, List[str]], - num_images_per_prompt: Optional[int], - do_classifier_free_guidance: bool, - negative_prompt: Optional[str], - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - Args: - prompt (`str` or `List[str]`): - prompt to be encoded - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). - prompt_embeds (`np.ndarray`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`np.ndarray`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer( - prompt, padding="max_length", return_tensors="np" - ).input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = self.text_encoder( - input_ids=text_input_ids.astype(np.int32) - )[0] - - prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - negative_prompt_embeds = self.text_encoder( - input_ids=uncond_input.input_ids.astype(np.int32) - )[0] - - if do_classifier_free_guidance: - negative_prompt_embeds = np.repeat( - negative_prompt_embeds, num_images_per_prompt, axis=0 - ) - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def check_inputs( - self, - prompt: Union[str, List[str]], - height: Optional[int], - width: Optional[int], - callback_steps: int, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[np.ndarray] = None, - negative_prompt_embeds: Optional[np.ndarray] = None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) - - if (callback_steps is None) or ( - callback_steps is not None - and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and ( - not isinstance(prompt, str) and not isinstance(prompt, list) - ): - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) + self.window = window or DEFAULT_WINDOW + self.stride = stride or DEFAULT_STRIDE def get_views( self, panorama_height: int, panorama_width: int, window_size: int, stride: int diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index 5551267b4..d54877eed 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -598,7 +598,6 @@ def text2img( for i in range(latents.shape[0]) ] ) - image = self.watermark.apply_watermark(image) # TODO: add image_processor image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) @@ -917,7 +916,6 @@ def img2img( for i in range(latents.shape[0]) ] ) - image = self.watermark.apply_watermark(image) # TODO: add image_processor image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) diff --git a/api/onnx_web/diffusers/pipelines/pix2pix.py b/api/onnx_web/diffusers/pipelines/pix2pix.py index 2689fa4f9..d257315fd 100644 --- a/api/onnx_web/diffusers/pipelines/pix2pix.py +++ b/api/onnx_web/diffusers/pipelines/pix2pix.py @@ -88,6 +88,7 @@ class OnnxStableDiffusionInstructPix2PixPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" + vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index a3737d060..76822e3cb 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -5,31 +5,32 @@ from PIL import Image, ImageOps from ..chain import ( - BlendDenoiseFastNLMeansStage, BlendImg2ImgStage, BlendMaskStage, ChainPipeline, + EditSafetyStage, SourceTxt2ImgStage, + TextPromptStage, UpscaleOutpaintStage, ) from ..chain.highres import stage_highres -from ..chain.result import StageResult +from ..chain.result import ImageMetadata, StageResult from ..chain.upscale import split_upscale, stage_upscale_correction from ..image import expand_image -from ..output import save_image +from ..output import make_output_names, read_metadata, save_image, save_result from ..params import ( - Border, + ExperimentalParams, HighresParams, ImageParams, + RequestParams, Size, StageParams, - UpscaleParams, ) from ..server import ServerContext from ..server.load import get_source_filters from ..utils import is_debug, run_gc, show_system_toast from ..worker import WorkerContext -from .utils import get_latents_from_seed, parse_prompt +from .utils import get_latents_from_seed logger = getLogger(__name__) @@ -52,38 +53,64 @@ def get_highres_tile( return params.unet_tile +def add_safety_stage( + server: ServerContext, + pipeline: ChainPipeline, +) -> None: + if server.has_feature("horde-safety"): + pipeline.stage( + EditSafetyStage(), StageParams(tile_size=EditSafetyStage.max_tile) + ) + + +def add_prompt_filter( + server: ServerContext, + pipeline: ChainPipeline, + experimental: ExperimentalParams = None, +) -> None: + if experimental and experimental.prompt_editing.enabled: + if server.has_feature("prompt-filter"): + pipeline.stage( + TextPromptStage(), + StageParams(), + add_suffix=experimental.prompt_editing.add_suffix, + min_length=experimental.prompt_editing.min_length, + prompt_filter=experimental.prompt_editing.filter, + remove_tokens=experimental.prompt_editing.remove_tokens, + ) + else: + logger.warning("prompt editing is not supported by the server") + + def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - outputs: List[str], - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + # if using panorama, the pipeline will tile itself (views) tile_size = get_base_tile(params, size) # prepare the chain pipeline and first stage chain = ChainPipeline() + add_prompt_filter(server, chain, request.experimental) + chain.stage( SourceTxt2ImgStage(), StageParams( tile_size=tile_size, ), - size=size, + size=request.size, prompt_index=0, overlap=params.vae_overlap, ) # apply upscaling and correction, before highres highres_size = get_highres_tile(server, params, highres, tile_size) - if params.is_panorama(): - chain.stage( - BlendDenoiseFastNLMeansStage(), - StageParams(tile_size=highres_size), - ) - first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( @@ -111,48 +138,37 @@ def run_txt2img_pipeline( upscale=after_upscale, ) + add_safety_stage(server, chain) + # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) - progress = worker.get_progress_callback() - images = chain.run( + progress = worker.get_progress_callback(reset=True) + images = chain( 
worker, server, params, StageResult.empty(), callback=progress, latents=latents ) - _pairs, loras, inversions, _rest = parse_prompt(params) - - for image, output in zip(images, outputs): - logger.trace("saving output image %s: %s", output, image.size) - dest = save_image( - server, - output, - image, - params, - size, - upscale=upscale, - highres=highres, - inversions=inversions, - loras=loras, - ) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) # notify the user - show_system_toast(f"finished txt2img job: {dest}") - logger.info("finished txt2img job: %s", dest) + show_system_toast(f"finished txt2img job: {worker.job}") + logger.info("finished txt2img job: %s", worker.job) def run_img2img_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - outputs: List[str], - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, strength: float, source_filter: Optional[str] = None, ) -> None: + params = request.image + upscale = request.upscale + highres = request.highres + # run filter on the source image if source_filter is not None: f = get_source_filters().get(source_filter, None) @@ -215,51 +231,41 @@ def run_img2img_pipeline( chain=chain, ) + add_safety_stage(server, chain) + + # prep inputs + input_metadata = read_metadata(source) or ImageMetadata.unknown_image() + input_result = StageResult(images=[source], metadata=[input_metadata]) + # run and append the filtered source - progress = worker.get_progress_callback() - images = chain.run( - worker, server, params, StageResult(images=[source]), callback=progress + progress = worker.get_progress_callback(reset=True) + images = chain( + worker, + server, + params, + input_result, # terrible naming, I know + callback=progress, ) if source_filter is not None and source_filter != "none": - images.append(source) + images.push_image(source, ImageMetadata.unknown_image()) - # save with metadata - _pairs, loras, inversions, _rest = parse_prompt(params) - size = Size(*source.size) - - for image, output in zip(images, outputs): - dest = save_image( - server, - output, - image, - params, - size, - upscale=upscale, - highres=highres, - inversions=inversions, - loras=loras, - ) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up run_gc([worker.get_device()]) # notify the user - show_system_toast(f"finished img2img job: {dest}") - logger.info("finished img2img job: %s", dest) + show_system_toast(f"finished img2img job: {worker.job}") + logger.info("finished img2img job: %s", worker.job) def run_inpaint_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - outputs: List[str], - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, mask: Image.Image, - border: Border, noise_source: Any, mask_filter: Any, fill_color: str, @@ -267,6 +273,12 @@ def run_inpaint_pipeline( full_res_inpaint: bool, full_res_inpaint_padding: float, ) -> None: + border = request.border + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + logger.debug("building inpaint pipeline") tile_size = get_base_tile(params, size) @@ -279,7 +291,7 @@ def run_inpaint_pipeline( mask = ImageOps.contain(mask, (mask_max, mask_max)) mask = mask.crop((0, 0, source.width, source.height)) - source, mask, noise, full_size = expand_image( + source, mask, noise, _full_size = expand_image( source, mask, border, 
@@ -400,57 +412,74 @@ def run_inpaint_pipeline( chain=chain, ) + add_safety_stage(server, chain) + + # prep inputs + input_metadata = read_metadata(source) or ImageMetadata.unknown_image() + input_result = StageResult(images=[source], metadata=[input_metadata]) + # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) - progress = worker.get_progress_callback() - images = chain.run( + progress = worker.get_progress_callback(reset=True) + images = chain( worker, server, params, - StageResult(images=[source]), + input_result, callback=progress, latents=latents, ) - _pairs, loras, inversions, _rest = parse_prompt(params) - for image, output in zip(images, outputs): + # custom version of save for full-res inpainting + images.outputs = make_output_names(server, worker.job, len(images)) + for image, metadata, output in zip( + images.as_images(), images.metadata, images.outputs + ): if full_res_inpaint: if is_debug(): save_image(server, "adjusted-output.png", image) + mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size)) image = original_source image.paste(mini_image, box=adj_mask_border) - dest = save_image( + + save_image( server, output, image, - params, - size, - upscale=upscale, - border=border, - inversions=inversions, - loras=loras, + metadata, + ) + + if params.thumbnail: + images.thumbnails = make_output_names( + server, worker.job, len(images), suffix="thumbnail" ) + for image, thumbnail in zip(images.as_images(), images.thumbnails): + save_image( + server, + thumbnail, + image, + ) # clean up - del image run_gc([worker.get_device()]) # notify the user - show_system_toast(f"finished inpaint job: {dest}") - logger.info("finished inpaint job: %s", dest) + show_system_toast(f"finished inpaint job: {worker.job}") + logger.info("finished inpaint job: %s", worker.job) def run_upscale_pipeline( worker: WorkerContext, server: ServerContext, - params: ImageParams, - size: Size, - outputs: List[str], - upscale: UpscaleParams, - highres: HighresParams, + request: RequestParams, source: Image.Image, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + highres = request.highres + # set up the chain pipeline, no base stage for upscaling chain = ChainPipeline() tile_size = get_base_tile(params, size) @@ -484,51 +513,49 @@ def run_upscale_pipeline( chain=chain, ) + add_safety_stage(server, chain) + + # prep inputs + input_metadata = read_metadata(source) or ImageMetadata.unknown_image() + input_result = StageResult(images=[source], metadata=[input_metadata]) + # run and save - progress = worker.get_progress_callback() - images = chain.run( - worker, server, params, StageResult(images=[source]), callback=progress + progress = worker.get_progress_callback(reset=True) + images = chain( + worker, + server, + params, + input_result, + callback=progress, ) - _pairs, loras, inversions, _rest = parse_prompt(params) - for image, output in zip(images, outputs): - dest = save_image( - server, - output, - image, - params, - size, - upscale=upscale, - inversions=inversions, - loras=loras, - ) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up - del image run_gc([worker.get_device()]) # notify the user - show_system_toast(f"finished upscale job: {dest}") - logger.info("finished upscale job: %s", dest) + show_system_toast(f"finished upscale job: {worker.job}") + logger.info("finished upscale job: %s", worker.job) def run_blend_pipeline( worker: WorkerContext, server: ServerContext, - params: 
ImageParams, - size: Size, - outputs: List[str], - upscale: UpscaleParams, - # highres: HighresParams, + request: RequestParams, sources: List[Image.Image], mask: Image.Image, ) -> None: + params = request.image + size = request.size + upscale = request.upscale + # set up the chain pipeline and base stage chain = ChainPipeline() tile_size = get_base_tile(params, size) # resize mask to match source size - stage_source = sources[1] + stage_source = sources.pop() stage_mask = mask.resize(stage_source.size, Image.Resampling.BILINEAR) chain.stage( @@ -546,19 +573,29 @@ def run_blend_pipeline( chain=chain, ) + add_safety_stage(server, chain) + + # prep inputs + input_metadata = [ + read_metadata(source) or ImageMetadata.unknown_image() for source in sources + ] + input_result = StageResult(images=sources, metadata=input_metadata) + # run and save - progress = worker.get_progress_callback() - images = chain.run( - worker, server, params, StageResult(images=sources), callback=progress + progress = worker.get_progress_callback(reset=True) + images = chain( + worker, + server, + params, + input_result, + callback=progress, ) - for image, output in zip(images, outputs): - dest = save_image(server, output, image, params, size, upscale=upscale) + save_result(server, images, worker.job, save_thumbnails=params.thumbnail) # clean up - del image run_gc([worker.get_device()]) # notify the user - show_system_toast(f"finished blend job: {dest}") - logger.info("finished blend job: %s", dest) + show_system_toast(f"finished blend job: {worker.job}") + logger.info("finished blend job: %s", worker.job) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index ecf013734..06a3560bb 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -79,6 +79,17 @@ def expand_alternative_ranges(prompt: str) -> List[str]: return prompts +def split_clip_skip(prompt: str) -> Tuple[str, int]: + prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN) + + skip_clip_states = 0 + if len(clip_tokens) > 0: + skip_clip_states = int(clip_tokens[0][1]) + logger.info("skipping %s CLIP layers", skip_clip_states) + + return prompt, skip_clip_states + + @torch.no_grad() def expand_prompt( self: OnnxStableDiffusionPipeline, @@ -94,10 +105,7 @@ def expand_prompt( # tokenizer: CLIPTokenizer # encoder: OnnxRuntimeModel - prompt, clip_tokens = get_tokens_from_prompt(prompt, CLIP_TOKEN) - if len(clip_tokens) > 0: - skip_clip_states = int(clip_tokens[0][1]) - logger.info("skipping %s CLIP layers", skip_clip_states) + prompt, skip_clip_states = split_clip_skip(prompt) batch_size = len(prompt) if isinstance(prompt, list) else 1 prompt = expand_interval_ranges(prompt) @@ -403,9 +411,6 @@ def encode_prompt( num_images_per_prompt: int = 1, do_classifier_free_guidance: bool = True, ) -> List[np.ndarray]: - """ - TODO: does not work with SDXL, fix or turn into a pipeline patch - """ return [ pipe._encode_prompt( remove_tokens(prompt), @@ -475,9 +480,16 @@ def repair_nan(tile: np.ndarray) -> np.ndarray: return tile +def split_prompt(prompt: str) -> List[str]: + if "||" in prompt: + return prompt.split("||") + + return [prompt] + + def slice_prompt(prompt: str, slice: int) -> str: if "||" in prompt: - parts = prompt.split("||") + parts = split_prompt(prompt) return parts[min(slice, len(parts) - 1)] else: return prompt diff --git a/api/onnx_web/errors.py b/api/onnx_web/errors.py index 4ab6f79b8..ef7a768f5 100644 --- a/api/onnx_web/errors.py +++ b/api/onnx_web/errors.py @@ -1,3 +1,6 @@ +from typing 
import Optional + + class RetryException(Exception): """ Used when a chain pipeline has run out of retries. @@ -11,7 +14,12 @@ class CancelledException(Exception): Used when a job has been cancelled and needs to stop. """ - pass + reason: Optional[str] + + def __init__(self, *args: object, reason: Optional[str] = None) -> None: + super().__init__(*args) + + self.reason = reason class RequestException(Exception): diff --git a/api/onnx_web/models/swinir.py b/api/onnx_web/models/swinir.py index 81c969313..ff9649930 100644 --- a/api/onnx_web/models/swinir.py +++ b/api/onnx_web/models/swinir.py @@ -495,9 +495,9 @@ def __init__( qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, + drop_path=( + drop_path[i] if isinstance(drop_path, list) else drop_path + ), norm_layer=norm_layer, ) for i in range(depth) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 08704ffe3..570491321 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -1,165 +1,44 @@ from hashlib import sha256 from json import dumps from logging import getLogger -from os import path -from struct import pack from time import time -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional from piexif import ExifIFD, ImageIFD, dump from piexif.helper import UserComment from PIL import Image, PngImagePlugin -from .convert.utils import resolve_tensor -from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams +from .chain.result import ImageMetadata, StageResult +from .params import ImageParams, Param, Size from .server import ServerContext -from .server.load import get_extra_hashes -from .utils import base_join +from .utils import base_join, hash_value logger = getLogger(__name__) -HASH_BUFFER_SIZE = 2**22 # 4MB +def make_output_names( + server: ServerContext, + job_name: str, + count: int = 1, + offset: int = 0, + extension: Optional[str] = None, + suffix: Optional[str] = None, +) -> List[str]: + if suffix is not None: + job_name = f"{job_name}_{suffix}" -def hash_file(name: str): - sha = sha256() - with open(name, "rb") as f: - while True: - data = f.read(HASH_BUFFER_SIZE) - if not data: - break - - sha.update(data) - - return sha.hexdigest() - - -def hash_value(sha, param: Optional[Param]): - if param is None: - return - elif isinstance(param, bool): - sha.update(bytearray(pack("!B", param))) - elif isinstance(param, float): - sha.update(bytearray(pack("!f", param))) - elif isinstance(param, int): - sha.update(bytearray(pack("!I", param))) - elif isinstance(param, str): - sha.update(param.encode("utf-8")) - else: - logger.warning("cannot hash param: %s, %s", param, type(param)) - - -def json_params( - outputs: List[str], - params: ImageParams, - size: Size, - upscale: Optional[UpscaleParams] = None, - border: Optional[Border] = None, - highres: Optional[HighresParams] = None, - parent: Optional[Dict] = None, -) -> Any: - json = { - "input_size": size.tojson(), - "outputs": outputs, - "params": params.tojson(), - } - - json["params"]["model"] = path.basename(params.model) - json["params"]["scheduler"] = params.scheduler - - # calculate final output size - output_size = size - if border is not None: - json["border"] = border.tojson() - output_size = output_size.add_border(border) - - if highres is not None: - json["highres"] = highres.tojson() - output_size = highres.resize(output_size) - - if upscale is not None: - json["upscale"] = upscale.tojson() - output_size = 
upscale.resize(output_size) - - json["size"] = output_size.tojson() - - return json + return [ + f"{job_name}_{i}.{extension or server.image_format}" + for i in range(offset, count + offset) + ] -def str_params( - server: ServerContext, - params: ImageParams, - size: Size, - inversions: List[Tuple[str, float]] = None, - loras: List[Tuple[str, float]] = None, -) -> str: - model_name = path.basename(path.normpath(params.model)) - logger.debug("getting model hash for %s", model_name) - - model_hash = get_extra_hashes().get(model_name, None) - if model_hash is None: - model_hash_path = path.join(params.model, "hash.txt") - if path.exists(model_hash_path): - with open(model_hash_path, "r") as f: - model_hash = f.readline().rstrip(",. \n\t\r") - - model_hash = model_hash or "unknown" - hash_map = { - model_name: model_hash, - } - - inversion_hashes = "" - if inversions is not None: - inversion_pairs = [ - ( - name, - hash_file( - resolve_tensor(path.join(server.model_path, "inversion", name)) - ).upper(), - ) - for name, _weight in inversions - ] - inversion_hashes = ",".join( - [f"{name}: {hash}" for name, hash in inversion_pairs] - ) - hash_map.update(dict(inversion_pairs)) - - lora_hashes = "" - if loras is not None: - lora_pairs = [ - ( - name, - hash_file( - resolve_tensor(path.join(server.model_path, "lora", name)) - ).upper(), - ) - for name, _weight in loras - ] - lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs]) - hash_map.update(dict(lora_pairs)) - - return ( - f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n" - f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, " - f"Seed: {params.seed}, Size: {size.width}x{size.height}, " - f"Model hash: {model_hash}, Model: {model_name}, " - f"Tool: onnx-web, Version: {server.server_version}, " - f'Inversion hashes: "{inversion_hashes}", ' - f'Lora hashes: "{lora_hashes}", ' - f"Hashes: {dumps(hash_map)}" - ) - - -def make_output_name( - server: ServerContext, +def make_job_name( mode: str, params: ImageParams, size: Size, extras: Optional[List[Optional[Param]]] = None, - count: Optional[int] = None, - offset: int = 0, -) -> List[str]: - count = count or params.batch +) -> str: now = int(time()) sha = sha256() @@ -181,48 +60,77 @@ def make_output_name( for param in extras: hash_value(sha, param) - return [ - f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" - for i in range(offset, count + offset) - ] + return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}" + + +def save_result( + server: ServerContext, + result: StageResult, + base_name: str, + save_thumbnails: bool = False, +) -> List[str]: + images = result.as_images() + result.outputs = make_output_names(server, base_name, len(images)) + logger.debug("saving %s images: %s", len(images), result.outputs) + + outputs = [] + for image, metadata, filename in zip(images, result.metadata, result.outputs): + outputs.append( + save_image( + server, + filename, + image, + metadata, + ) + ) + + if save_thumbnails: + result.thumbnails = make_output_names( + server, + base_name, + len(images), + suffix="thumbnail", + ) + logger.debug("saving %s thumbnails: %s", len(images), result.thumbnails) + + thumbnails = [] + for image, filename in zip(images, result.thumbnails): + # TODO: only make a thumbnail if the image is larger than the thumbnail size + thumbnail = image.copy() + thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size)) + + thumbnails.append( + save_image( + server, + filename, + 
thumbnail, + ) + ) + + return outputs def save_image( server: ServerContext, output: str, image: Image.Image, - params: Optional[ImageParams] = None, - size: Optional[Size] = None, - upscale: Optional[UpscaleParams] = None, - border: Optional[Border] = None, - highres: Optional[HighresParams] = None, - inversions: List[Tuple[str, float]] = None, - loras: List[Tuple[str, float]] = None, + metadata: Optional[ImageMetadata] = None, ) -> str: path = base_join(server.output_path, output) if server.image_format == "png": exif = PngImagePlugin.PngInfo() - if params is not None: + if metadata is not None: exif.add_text("make", "onnx-web") exif.add_text( "maker note", - dumps( - json_params( - [output], - params, - size, - upscale=upscale, - border=border, - highres=highres, - ) - ), + dumps(metadata.tojson(server, [output])), ) exif.add_text("model", server.server_version) exif.add_text( "parameters", - str_params(server, params, size, inversions=inversions, loras=loras), + metadata.to_exif(server), ) image.save(path, format=server.image_format, pnginfo=exif) @@ -231,22 +139,11 @@ def save_image( { "0th": { ExifIFD.MakerNote: UserComment.dump( - dumps( - json_params( - [output], - params, - size, - upscale=upscale, - border=border, - highres=highres, - ) - ), + dumps(metadata.tojson(server, [output])), encoding="unicode", ), ExifIFD.UserComment: UserComment.dump( - str_params( - server, params, size, inversions=inversions, loras=loras - ), + metadata.to_exif(server), encoding="unicode", ), ImageIFD.Make: "onnx-web", @@ -256,35 +153,41 @@ def save_image( ) image.save(path, format=server.image_format, exif=exif) - if params is not None: - save_params( + if metadata is not None: + save_metadata( server, output, - params, - size, - upscale=upscale, - border=border, - highres=highres, + metadata, ) logger.debug("saved output image to: %s", path) return path -def save_params( +def save_metadata( server: ServerContext, output: str, - params: ImageParams, - size: Size, - upscale: Optional[UpscaleParams] = None, - border: Optional[Border] = None, - highres: Optional[HighresParams] = None, + metadata: ImageMetadata, ) -> str: path = base_join(server.output_path, f"{output}.json") - json = json_params( - output, params, size, upscale=upscale, border=border, highres=highres - ) + json = metadata.tojson(server, [output]) with open(path, "w") as f: f.write(dumps(json)) logger.debug("saved image params to: %s", path) return path + + +def read_metadata( + image: Image.Image, +) -> Optional[ImageMetadata]: + exif_data = image.getexif() + + if ImageIFD.Make in exif_data and exif_data[ImageIFD.Make] == "onnx-web": + return ImageMetadata.from_json(exif_data[ExifIFD.MakerNote]) + + if ExifIFD.UserComment in exif_data: + return ImageMetadata.from_exif(exif_data[ExifIFD.UserComment]) + + # this could return ImageMetadata.unknown_image(), but that would not indicate whether the input + # had metadata or not, so it's easier to return None and follow the call with `or ImageMetadata.unknown_image()` + return None diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 8d2551830..04aa3713f 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -5,12 +5,15 @@ from .models.meta import NetworkModel from .torch_before_ort import GraphOptimizationLevel, SessionOptions +from .utils import coalesce logger = getLogger(__name__) Param = Union[str, int, float] Point = Tuple[int, int] +UpscaleOrder = Literal["correction-first", "correction-last", "correction-both"] +UpscaleMethod = Literal["bilinear", "lanczos", 
"upscale"] class SizeChart(IntEnum): @@ -57,12 +60,20 @@ def tojson(self): "bottom": self.bottom, } - def with_args(self, **kwargs): + def with_args( + self, + left: Optional[int] = None, + right: Optional[int] = None, + top: Optional[int] = None, + bottom: Optional[int] = None, + **kwargs, + ): + logger.debug("ignoring extra kwargs for border: %s", list(kwargs.keys())) return Border( - kwargs.get("left", self.left), - kwargs.get("right", self.right), - kwargs.get("top", self.top), - kwargs.get("bottom", self.bottom), + left or self.left, + right or self.right, + top or self.top, + bottom or self.bottom, ) @classmethod @@ -81,6 +92,12 @@ def __iter__(self): def __str__(self) -> str: return "%sx%s" % (self.width, self.height) + def __eq__(self, other: Any) -> bool: + if isinstance(other, Size): + return self.width == other.width and self.height == other.height + + return False + def add_border(self, border: Border): return Size( border.left + self.width + border.right, @@ -105,10 +122,16 @@ def tojson(self) -> Dict[str, int]: "height": self.height, } - def with_args(self, **kwargs): + def with_args( + self, + height: Optional[int] = None, + width: Optional[int] = None, + **kwargs, + ): + logger.debug("ignoring extra kwargs for size: %s", list(kwargs.keys())) return Size( - kwargs.get("width", self.width), - kwargs.get("height", self.height), + width or self.width, + height or self.height, ) @@ -130,13 +153,22 @@ def __str__(self) -> str: return "%s - %s (%s)" % (self.device, self.provider, self.options) def ort_provider( - self, model_type: Optional[str] = None + self, + model_type: str, + suffix: Optional[str] = None, ) -> Union[str, Tuple[str, Any]]: - if model_type is not None: - # check if model has been pinned to CPU - # TODO: check whether the CPU device is allowed - if f"onnx-cpu-{model_type}" in self.optimizations: - return "CPUExecutionProvider" + # check if model has been pinned to CPU + # TODO: check whether the CPU device is allowed + if f"onnx-cpu-{model_type}" in self.optimizations: + logger.debug("pinning %s to CPU", model_type) + return "CPUExecutionProvider" + + if ( + suffix is not None + and f"onnx-cpu-{model_type}-{suffix}" in self.optimizations + ): + logger.debug("pinning %s-%s to CPU", model_type, suffix) + return "CPUExecutionProvider" if self.options is None: return self.provider @@ -213,6 +245,7 @@ class ImageParams: vae_tile: int vae_overlap: float denoise: int + thumbnail: int def __init__( self, @@ -236,6 +269,7 @@ def __init__( vae_overlap: float = 0.25, vae_tile: int = 512, denoise: int = 3, + thumbnail: int = 1, ) -> None: self.model = model self.pipeline = pipeline @@ -257,6 +291,7 @@ def __init__( self.vae_overlap = vae_overlap self.vae_tile = vae_tile self.denoise = denoise + self.thumbnail = thumbnail def do_cfg(self): return self.cfg > 1.0 @@ -328,6 +363,7 @@ def tojson(self) -> Dict[str, Optional[Param]]: "vae_overlap": self.vae_overlap, "vae_tile": self.vae_tile, "denoise": self.denoise, + "thumbnail": self.thumbnail, } def with_args(self, **kwargs): @@ -352,6 +388,7 @@ def with_args(self, **kwargs): kwargs.get("vae_overlap", self.vae_overlap), kwargs.get("vae_tile", self.vae_tile), kwargs.get("denoise", self.denoise), + kwargs.get("thumbnail", self.thumbnail), ) @@ -375,13 +412,18 @@ def __init__( def with_args( self, + name: Optional[str] = None, + outscale: Optional[int] = None, + tile_order: Optional[str] = None, + tile_size: Optional[int] = None, **kwargs, ): + logger.debug("ignoring extra kwargs for stage: %s", list(kwargs.keys())) return 
StageParams( - name=kwargs.get("name", self.name), - outscale=kwargs.get("outscale", self.outscale), - tile_order=kwargs.get("tile_order", self.tile_order), - tile_size=kwargs.get("tile_size", self.tile_size), + name=coalesce(name, self.name), + outscale=coalesce(outscale, self.outscale), + tile_order=coalesce(tile_order, self.tile_order), + tile_size=coalesce(tile_size, self.tile_size), ) @@ -395,14 +437,11 @@ def __init__( faces=True, face_outscale: int = 1, face_strength: float = 0.5, - format: Literal["onnx", "pth"] = "onnx", # TODO: deprecated, remove outscale: int = 1, scale: int = 4, pre_pad: int = 0, tile_pad: int = 10, - upscale_order: Literal[ - "correction-first", "correction-last", "correction-both" - ] = "correction-first", + upscale_order: UpscaleOrder = "correction-first", ) -> None: self.upscale_model = upscale_model self.correction_model = correction_model @@ -411,7 +450,6 @@ def __init__( self.faces = faces self.face_outscale = face_outscale self.face_strength = face_strength - self.format = format self.outscale = outscale self.pre_pad = pre_pad self.scale = scale @@ -427,7 +465,6 @@ def rescale(self, scale: int): faces=self.faces, face_outscale=self.face_outscale, face_strength=self.face_strength, - format=self.format, outscale=scale, scale=scale, pre_pad=self.pre_pad, @@ -436,13 +473,19 @@ def rescale(self, scale: int): ) def resize(self, size: Size) -> Size: - face_outscale = self.face_outscale - if self.upscale_order == "correction-both": - face_outscale *= self.face_outscale + face_outscale = 1 + if self.faces: + face_outscale = self.face_outscale + if self.upscale_order == "correction-both": + face_outscale *= self.face_outscale + + upscale_outscale = 1 + if self.upscale: + upscale_outscale = self.outscale return Size( - size.width * self.outscale * face_outscale, - size.height * self.outscale * face_outscale, + size.width * face_outscale * upscale_outscale, + size.height * face_outscale * upscale_outscale, ) def tojson(self): @@ -454,7 +497,6 @@ def tojson(self): "faces": self.faces, "face_outscale": self.face_outscale, "face_strength": self.face_strength, - "format": self.format, "outscale": self.outscale, "pre_pad": self.pre_pad, "scale": self.scale, @@ -462,21 +504,36 @@ def tojson(self): "upscale_order": self.upscale_order, } - def with_args(self, **kwargs): + def with_args( + self, + upscale_model: Optional[str] = None, + correction_model: Optional[str] = None, + denoise: Optional[float] = None, + upscale: Optional[bool] = None, + faces: Optional[bool] = None, + face_outscale: Optional[int] = None, + face_strength: Optional[float] = None, + outscale: Optional[int] = None, + scale: Optional[int] = None, + pre_pad: Optional[int] = None, + tile_pad: Optional[int] = None, + upscale_order: Optional[UpscaleOrder] = None, + **kwargs, + ): + logger.debug("ignoring extra kwargs for upscale: %s", list(kwargs.keys())) return UpscaleParams( - kwargs.get("upscale_model", self.upscale_model), - kwargs.get("correction_model", self.correction_model), - kwargs.get("denoise", self.denoise), - kwargs.get("upscale", self.upscale), - kwargs.get("faces", self.faces), - kwargs.get("face_outscale", self.face_outscale), - kwargs.get("face_strength", self.face_strength), - kwargs.get("format", self.format), - kwargs.get("outscale", self.outscale), - kwargs.get("scale", self.scale), - kwargs.get("pre_pad", self.pre_pad), - kwargs.get("tile_pad", self.tile_pad), - kwargs.get("upscale_order", self.upscale_order), + upscale_model=coalesce(upscale_model, self.upscale_model), + 
correction_model=coalesce(correction_model, self.correction_model), + denoise=coalesce(denoise, self.denoise), + upscale=coalesce(upscale, self.upscale), + faces=coalesce(faces, self.faces), + face_outscale=coalesce(face_outscale, self.face_outscale), + face_strength=coalesce(face_strength, self.face_strength), + outscale=coalesce(outscale, self.outscale), + scale=coalesce(scale, self.scale), + pre_pad=coalesce(pre_pad, self.pre_pad), + tile_pad=coalesce(tile_pad, self.tile_pad), + upscale_order=coalesce(upscale_order, self.upscale_order), ) @@ -487,7 +544,7 @@ def __init__( scale: int, steps: int, strength: float, - method: Literal["bilinear", "lanczos", "upscale"] = "lanczos", + method: UpscaleMethod = "lanczos", iterations: int = 1, ): self.enabled = enabled @@ -516,3 +573,122 @@ def tojson(self): "steps": self.steps, "strength": self.strength, } + + def with_args( + self, + enabled: Optional[bool] = None, + scale: Optional[int] = None, + steps: Optional[int] = None, + strength: Optional[float] = None, + method: Optional[UpscaleMethod] = None, + iterations: Optional[int] = None, + **kwargs, + ): + logger.debug("ignoring extra kwargs for highres: %s", list(kwargs.keys())) + return HighresParams( + enabled=coalesce(enabled, self.enabled), + scale=coalesce(scale, self.scale), + steps=coalesce(steps, self.steps), + strength=coalesce(strength, self.strength), + method=coalesce(method, self.method), + iterations=coalesce(iterations, self.iterations), + ) + + +class LatentSymmetryParams: + enabled: bool + gradient_start: float + gradient_end: float + line_of_symmetry: float + + def __init__( + self, + enabled: bool, + gradient_start: float, + gradient_end: float, + line_of_symmetry: float, + ) -> None: + self.enabled = enabled + self.gradient_start = gradient_start + self.gradient_end = gradient_end + self.line_of_symmetry = line_of_symmetry + + +class PromptEditingParams: + enabled: bool + filter: str + remove_tokens: str + add_suffix: str + min_length: int + + def __init__( + self, + enabled: bool, + filter: str, + remove_tokens: str, + add_suffix: str, + min_length: int, + ) -> None: + self.enabled = enabled + self.filter = filter + self.remove_tokens = remove_tokens + self.add_suffix = add_suffix + self.min_length = min_length + + +class ExperimentalParams: + latent_symmetry: LatentSymmetryParams + prompt_editing: PromptEditingParams + + def __init__( + self, + latent_symmetry: LatentSymmetryParams, + prompt_editing: PromptEditingParams, + ) -> None: + self.latent_symmetry = latent_symmetry + self.prompt_editing = prompt_editing + + +class RequestParams: + device: DeviceParams + image: ImageParams + size: Optional[Size] + border: Optional[Border] + upscale: Optional[UpscaleParams] + highres: Optional[HighresParams] + experimental: Optional[ExperimentalParams] + + def __init__( + self, + device: DeviceParams, + image: ImageParams, + size: Optional[Size] = None, + border: Optional[Border] = None, + upscale: Optional[UpscaleParams] = None, + highres: Optional[HighresParams] = None, + experimental: Optional[ExperimentalParams] = None, + ) -> None: + self.device = device + self.image = image + self.size = size + self.border = border + self.upscale = upscale + self.highres = highres + self.experimental = experimental + + +def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: + if val is None: + return SizeChart.auto + + if type(val) is int: + return val + + if type(val) is str: + for size in SizeChart: + if val == size.name: + return size + + return int(val) + + raise 
ValueError("invalid size") diff --git a/api/onnx_web/prompt/base.py b/api/onnx_web/prompt/base.py new file mode 100644 index 000000000..6e1e5654e --- /dev/null +++ b/api/onnx_web/prompt/base.py @@ -0,0 +1,190 @@ +from typing import Any, List, Optional, Union + + +class PromptNetwork: + type: str + name: str + strength: float + + def __init__(self, type: str, name: str, strength: float) -> None: + self.type = type + self.name = name + self.strength = strength + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and other.type == self.type + and other.name == self.name + and other.strength == self.strength + ) + + def __repr__(self) -> str: + return f"PromptNetwork({self.type}, {self.name}, {self.strength})" + + +class PromptPhrase: + phrase: str + weight: float + + def __init__(self, phrase: str, weight: float = 1.0) -> None: + self.phrase = phrase + self.weight = weight + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and other.phrase == self.phrase + and other.weight == self.weight + ) + + def __repr__(self) -> str: + return f"PromptPhrase({self.phrase}, {self.weight})" + + +class PromptRegion: + top: int + left: int + bottom: int + right: int + prompt: str + append: bool + + def __init__( + self, + top: int, + left: int, + bottom: int, + right: int, + prompt: str, + append: bool, + ) -> None: + self.top = top + self.left = left + self.bottom = bottom + self.right = right + self.prompt = prompt + self.append = append + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and other.top == self.top + and other.left == self.left + and other.bottom == self.bottom + and other.right == self.right + and other.prompt == self.prompt + and other.append == self.append + ) + + def __repr__(self) -> str: + return f"PromptRegion({self.top}, {self.left}, {self.bottom}, {self.right}, {self.prompt}, {self.append})" + + +class PromptSeed: + top: int + left: int + bottom: int + right: int + seed: int + + def __init__(self, top: int, left: int, bottom: int, right: int, seed: int) -> None: + self.top = top + self.left = left + self.bottom = bottom + self.right = right + self.seed = seed + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and other.top == self.top + and other.left == self.left + and other.bottom == self.bottom + and other.right == self.right + and other.seed == self.seed + ) + + def __repr__(self) -> str: + return f"PromptSeed({self.top}, {self.left}, {self.bottom}, {self.right}, {self.seed})" + + +class Prompt: + clip_skip: int + networks: List[PromptNetwork] + positive_phrases: List[PromptPhrase] + negative_phrases: List[PromptPhrase] + region_prompts: List[PromptRegion] + region_seeds: List[PromptSeed] + + def __init__( + self, + networks: Optional[List[PromptNetwork]], + positive_phrases: List[PromptPhrase], + negative_phrases: List[PromptPhrase], + region_prompts: List[PromptRegion], + region_seeds: List[PromptSeed], + clip_skip: int, + ) -> None: + self.positive_phrases = positive_phrases + self.negative_phrases = negative_phrases + self.networks = networks or [] + self.region_prompts = region_prompts or [] + self.region_seeds = region_seeds or [] + self.clip_skip = clip_skip + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, self.__class__) + and other.networks == self.networks + and other.positive_phrases == self.positive_phrases + and other.negative_phrases == self.negative_phrases + and 
other.region_prompts == self.region_prompts + and other.region_seeds == self.region_seeds + and other.clip_skip == self.clip_skip + ) + + def __repr__(self) -> str: + return f"Prompt({self.networks}, {self.positive_phrases}, {self.negative_phrases}, {self.region_prompts}, {self.region_seeds}, {self.clip_skip})" + + def collapse_runs(self) -> None: + self.positive_phrases = collapse_phrases(self.positive_phrases) + self.negative_phrases = collapse_phrases(self.negative_phrases) + + +def collapse_phrases( + nodes: List[Union[Any]], +) -> List[Union[Any]]: + """ + Combine phrases with the same weight. + """ + + weight = None + tokens = [] + phrases = [] + + def flush_tokens(): + nonlocal weight, tokens + if len(tokens) > 0: + phrase = " ".join(tokens) + phrases.append(PromptPhrase([phrase], weight)) + tokens = [] + weight = None + + for node in nodes: + if isinstance(node, str): + node = PromptPhrase(node) + elif isinstance(node, (PromptNetwork, PromptRegion, PromptSeed)): + flush_tokens() + phrases.append(node) + continue + + if node.weight == weight: + tokens.extend(node.phrase) + else: + flush_tokens() + tokens = node.phrase + weight = node.weight + + flush_tokens() + return phrases diff --git a/api/onnx_web/prompt/compel.py b/api/onnx_web/prompt/compel.py new file mode 100644 index 000000000..cc2db1694 --- /dev/null +++ b/api/onnx_web/prompt/compel.py @@ -0,0 +1,170 @@ +from types import SimpleNamespace +from typing import List, Optional, Union + +import numpy as np +import torch +from compel import Compel, ReturnedEmbeddingsType +from diffusers import OnnxStableDiffusionPipeline + +from ..diffusers.utils import split_clip_skip + + +def get_inference_session(model): + if hasattr(model, "session") and model.session is not None: + return model.session + + if hasattr(model, "model") and model.model is not None: + return model.model + + raise ValueError("model does not have an inference session") + + +def wrap_encoder(text_encoder): + class WrappedEncoder: + device = "cpu" + + def __init__(self, text_encoder): + self.text_encoder = text_encoder + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, token_ids, attention_mask, output_hidden_states=None, return_dict=True + ): + """ + If `output_hidden_states` is None, return pooled embeds. + """ + dtype = np.int32 + session = get_inference_session(self.text_encoder) + if session.get_inputs()[0].type == "tensor(int64)": + dtype = np.int64 + + # TODO: does compel use attention masks? 
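+ # the ONNX text encoder is called with token ids only, so any attention mask from compel is dropped;
+ # when output_hidden_states is None the outputs are read as (text_embeds, last_hidden_state), and when
+ # it is True they are read as (last_hidden_state, pooler_output, *hidden_states), which per the
+ # clip-skip note in encode_prompt_compel below needs a model exported with return_hidden_states=True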
+ outputs = text_encoder(input_ids=token_ids.numpy().astype(dtype)) + + if output_hidden_states is None: + return SimpleNamespace( + text_embeds=torch.from_numpy(outputs[0]), + last_hidden_state=torch.from_numpy(outputs[1]), + ) + elif output_hidden_states is True: + hidden_states = [torch.from_numpy(state) for state in outputs[2:]] + return SimpleNamespace( + last_hidden_state=torch.from_numpy(outputs[0]), + pooler_output=torch.from_numpy(outputs[1]), + hidden_states=hidden_states, + ) + else: + return SimpleNamespace( + last_hidden_state=torch.from_numpy(outputs[0]), + pooler_output=torch.from_numpy(outputs[1]), + ) + + def __getattr__(self, name): + return getattr(self.text_encoder, name) + + return WrappedEncoder(text_encoder) + + +@torch.no_grad() +def encode_prompt_compel( + self: OnnxStableDiffusionPipeline, + prompt: str, + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, +) -> np.ndarray: + """ + Text encoder patch for SD v1 and v2. + + Using clip skip requires an ONNX model compiled with `return_hidden_states=True`. + """ + prompt, skip_clip_states = split_clip_skip(prompt) + + embeddings_type = ( + ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED + if skip_clip_states == 0 + else ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED + ) + wrapped_encoder = wrap_encoder(self.text_encoder) + compel = Compel( + tokenizer=self.tokenizer, + text_encoder=wrapped_encoder, + returned_embeddings_type=embeddings_type, + ) + + prompt_embeds = compel(prompt) + + if negative_prompt is None: + negative_prompt = "" + + negative_prompt_embeds = compel(negative_prompt) + + if negative_prompt_embeds is not None: + [prompt_embeds, negative_prompt_embeds] = ( + compel.pad_conditioning_tensors_to_same_length( + [prompt_embeds, negative_prompt_embeds] + ) + ) + + prompt_embeds = prompt_embeds.numpy().astype(np.float32) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32) + + prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + +@torch.no_grad() +def encode_prompt_compel_sdxl( + self: OnnxStableDiffusionPipeline, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, +) -> np.ndarray: + wrapped_encoder = wrap_encoder(self.text_encoder) + wrapped_encoder_2 = wrap_encoder(self.text_encoder_2) + compel = Compel( + tokenizer=[self.tokenizer, self.tokenizer_2], + text_encoder=[wrapped_encoder, wrapped_encoder_2], + returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, + requires_pooled=[False, True], + ) + + prompt, _skip_clip_states = split_clip_skip(prompt) + prompt_embeds, prompt_pooled = compel(prompt) + + negative_pooled = None + if negative_prompt is None: + negative_prompt = "" + + negative_prompt_embeds, negative_pooled = compel(negative_prompt) + + if negative_prompt_embeds is not None: + [prompt_embeds, negative_prompt_embeds] = ( + compel.pad_conditioning_tensors_to_same_length( + [prompt_embeds, negative_prompt_embeds] + ) + ) + + prompt_embeds = 
prompt_embeds.numpy().astype(np.float32) + prompt_pooled = prompt_pooled.numpy().astype(np.float32) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.numpy().astype(np.float32) + negative_pooled = negative_pooled.numpy().astype(np.float32) + + return ( + prompt_embeds, + negative_prompt_embeds, + prompt_pooled, + negative_pooled, + ) diff --git a/api/onnx_web/prompt/grammar.py b/api/onnx_web/prompt/grammar.py index 20127030b..ed43a8e9c 100644 --- a/api/onnx_web/prompt/grammar.py +++ b/api/onnx_web/prompt/grammar.py @@ -2,6 +2,12 @@ from arpeggio import EOF, OneOrMore, PTNodeVisitor, RegExMatch +from .utils import collapse_phrases, flatten + + +def token_delimiter(): + return ":" + def token(): return RegExMatch(r"\w+") @@ -11,6 +17,66 @@ def token_run(): return OneOrMore(token) +def decimal(): + return RegExMatch(r"\d+\.\d*") + + +def integer(): + return RegExMatch(r"\d+") + + +def token_clip_skip(): + return ("clip", token_delimiter, "skip", token_delimiter, integer) + + +def token_inversion(): + return ("inversion", token_delimiter, token_run, token_delimiter, decimal) + + +def token_lora(): + return ("lora", token_delimiter, token_run, token_delimiter, decimal) + + +def token_region(): + return ( + "region", + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + decimal, + token_delimiter, + decimal, + token_delimiter, + token_run, + ) + + +def token_reseed(): + return ( + "reseed", + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + integer, + token_delimiter, + integer, + ) + + +def token_inner(): + return [token_clip_skip, token_inversion, token_lora, token_region, token_reseed] + + def phrase_inner(): return [phrase, token_run] @@ -23,15 +89,19 @@ def neg_phrase(): return ("[", OneOrMore(phrase_inner), "]") +def token_phrase(): + return ("<", OneOrMore(token_inner), ">") + + def phrase(): - return [pos_phrase, neg_phrase, token_run] + return [pos_phrase, neg_phrase, token_phrase, token_run] def prompt(): return OneOrMore(phrase), EOF -class PromptPhrase: +class PhraseNode: def __init__(self, tokens: Union[List[str], str], weight: float = 1.0) -> None: self.tokens = tokens self.weight = weight @@ -46,6 +116,26 @@ def __eq__(self, other: object) -> bool: return False +class TokenNode: + def __init__(self, type: str, name: str, *rest): + self.type = type + self.name = name + self.rest = rest + + def __repr__(self) -> str: + return f"<{self.type}:{self.name}:{self.rest}>" + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return ( + other.type == self.type + and other.name == self.name + and other.rest == self.rest + ) + + return False + + class OnnxPromptVisitor(PTNodeVisitor): def __init__(self, defaults=True, weight=0.5, **kwargs): super().__init__(defaults, **kwargs) @@ -53,34 +143,64 @@ def __init__(self, defaults=True, weight=0.5, **kwargs): self.neg_weight = weight self.pos_weight = 1.0 + weight + def visit_decimal(self, node, children): + return float(node.value) + + def visit_integer(self, node, children): + return int(node.value) + def visit_token(self, node, children): return str(node.value) + def visit_token_clip_skip(self, node, children): + return TokenNode("clip", "skip", children[0]) + + def visit_token_inversion(self, node, children): + return TokenNode("inversion", children[0][0], children[1]) + + def visit_token_lora(self, node, children): + return 
TokenNode("lora", children[0][0], children[1]) + + def visit_token_region(self, node, children): + return TokenNode("region", None, children) + + def visit_token_reseed(self, node, children): + return TokenNode("reseed", None, children) + def visit_token_run(self, node, children): return children def visit_phrase_inner(self, node, children): - if isinstance(children[0], PromptPhrase): - return children[0] - else: - return PromptPhrase(children[0]) + return [ + ( + child + if isinstance(child, (PhraseNode, TokenNode, list)) + else PhraseNode(child) + ) + for child in children + ] def visit_pos_phrase(self, node, children): - c = children[0] - if isinstance(c, PromptPhrase): - return PromptPhrase(c.tokens, c.weight * self.pos_weight) - elif isinstance(c, str): - return PromptPhrase(c, self.pos_weight) + return parse_phrase(children, self.pos_weight) def visit_neg_phrase(self, node, children): - c = children[0] - if isinstance(c, PromptPhrase): - return PromptPhrase(c.tokens, c.weight * self.neg_weight) - elif isinstance(c, str): - return PromptPhrase(c, self.neg_weight) + return parse_phrase(children, self.neg_weight) def visit_phrase(self, node, children): - return children[0] + return list(flatten(children)) def visit_prompt(self, node, children): - return children + return collapse_phrases(list(flatten(children)), PhraseNode, TokenNode) + + +def parse_phrase(child, weight): + if isinstance(child, PhraseNode): + return PhraseNode(child.tokens, child.weight * weight) + elif isinstance(child, str): + return PhraseNode([child], weight) + elif isinstance(child, list): + # TODO: when this is a list of strings, create a single node with all of them + # if all(isinstance(c, str) for c in child): + # return PhraseNode(child, weight) + + return [parse_phrase(c, weight) for c in child] diff --git a/api/onnx_web/prompt/parser.py b/api/onnx_web/prompt/parser.py index 8bab4dbfe..ee4c1e76a 100644 --- a/api/onnx_web/prompt/parser.py +++ b/api/onnx_web/prompt/parser.py @@ -1,9 +1,10 @@ -from typing import Literal +from typing import Literal, Union import numpy as np from arpeggio import ParserPython, visit_parse_tree -from .grammar import OnnxPromptVisitor +from .base import Prompt, PromptNetwork, PromptPhrase, PromptRegion, PromptSeed +from .grammar import OnnxPromptVisitor, PhraseNode, TokenNode from .grammar import prompt as prompt_base @@ -22,8 +23,10 @@ def parse_prompt_onnx(pipeline, prompt: str, debug=False) -> np.ndarray: parser = ParserPython(prompt_base, debug=debug) visitor = OnnxPromptVisitor() - ast = parser.parse(prompt) - return visit_parse_tree(ast, visitor) + lst = parser.parse(prompt) + ast = visit_parse_tree(lst, visitor) + + return ast def parse_prompt_vanilla(pipeline, prompt: str) -> np.ndarray: @@ -45,3 +48,52 @@ def parse_prompt( return parse_prompt_vanilla(pipeline, prompt) else: raise ValueError("invalid prompt parser") + + +def compile_prompt_onnx(prompt: str) -> Prompt: + ast = parse_prompt_onnx(None, prompt) + + tokens = [node for node in ast if isinstance(node, TokenNode)] + clip_skip = [token.rest[0] for token in tokens if token.type == "clip"] + networks = [ + PromptNetwork(token.type, token.name, token.rest[0]) + for token in tokens + if token.type in ["lora", "inversion"] + ] + regions = [PromptRegion(*token.rest) for token in tokens if token.type == "region"] + reseeds = [PromptSeed(*token.rest) for token in tokens if token.type == "reseed"] + + phrases = [ + compile_prompt_phrase(node) + for node in ast + if isinstance(node, (list, PhraseNode, str)) + ] + phrases = 
list(flatten(phrases)) + # TODO: collapse phrases with the same weight + + return Prompt( + networks=networks, + positive_phrases=phrases, + negative_phrases=[], + region_prompts=regions, + region_seeds=reseeds, + clip_skip=next(iter(clip_skip), 0), + ) + + +def compile_prompt_phrase(node: Union[PhraseNode, str]) -> PromptPhrase: + if isinstance(node, list): + return [compile_prompt_phrase(subnode) for subnode in node] + + if isinstance(node, str): + return PromptPhrase([node]) + + return PromptPhrase(node.tokens, node.weight) + + +def flatten(val): + if isinstance(val, list): + for subval in val: + yield from flatten(subval) + else: + yield val diff --git a/api/onnx_web/prompt/utils.py b/api/onnx_web/prompt/utils.py new file mode 100644 index 000000000..811ce6105 --- /dev/null +++ b/api/onnx_web/prompt/utils.py @@ -0,0 +1,48 @@ +from typing import Any, List, Union + + +def flatten(lst): + for el in lst: + if isinstance(el, list): + yield from flatten(el) + else: + yield el + + +def collapse_phrases( + nodes: List[Union[Any]], + phrase, + token, +) -> List[Union[Any]]: + """ + Combine phrases with the same weight. + """ + + weight = None + tokens = [] + phrases = [] + + def flush_tokens(): + nonlocal weight, tokens + if len(tokens) > 0: + phrases.append(phrase(tokens, weight)) + tokens = [] + weight = None + + for node in nodes: + if isinstance(node, str): + node = phrase([node]) + elif isinstance(node, token): + flush_tokens() + phrases.append(node) + continue + + if node.weight == weight: + tokens.extend(node.tokens) + else: + flush_tokens() + tokens = [*node.tokens] + weight = node.weight + + flush_tokens() + return phrases diff --git a/api/onnx_web/server/admin.py b/api/onnx_web/server/admin.py index bdbc9adeb..54358f9d1 100644 --- a/api/onnx_web/server/admin.py +++ b/api/onnx_web/server/admin.py @@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor): pool.recycle(recycle_all=True) logger.info("restarted worker pool") - return jsonify(pool.status()) + return jsonify(pool.summary()) def worker_status(server: ServerContext, pool: DevicePoolExecutor): if not check_admin(server): return make_response(jsonify({})), 401 - return jsonify(pool.status()) + return jsonify(pool.summary()) def get_extra_models(server: ServerContext): @@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExe app.route("/api/extras", methods=["PUT"])( wrap_route(update_extra_models, server) ), - app.route("/api/restart", methods=["POST"])( + app.route("/api/worker/restart", methods=["POST"])( wrap_route(restart_workers, server, pool=pool) ), - app.route("/api/status")(wrap_route(worker_status, server, pool=pool)), + app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)), ] diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index ebccbfe52..c191488d6 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -1,14 +1,14 @@ from io import BytesIO from logging import getLogger from os import path -from typing import Any, Dict +from typing import Any, Dict, List, Optional from flask import Flask, jsonify, make_response, request, url_for from jsonschema import validate from PIL import Image from ..chain import CHAIN_STAGES, ChainPipeline -from ..chain.result import StageResult +from ..chain.result import ImageMetadata, StageResult from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers from ..diffusers.run import ( run_blend_pipeline, @@ -18,8 +18,8 @@ 
run_upscale_pipeline, ) from ..diffusers.utils import replace_wildcards -from ..output import json_params, make_output_name -from ..params import Size, StageParams, TileOrder +from ..output import make_job_name +from ..params import Size, StageParams, TileOrder, get_size from ..transformers.run import run_txt2txt_pipeline from ..utils import ( base_join, @@ -28,12 +28,13 @@ get_boolean, get_from_list, get_from_map, + get_list, get_not_empty, - get_size, load_config, load_config_str, sanitize_name, ) +from ..worker.command import JobStatus, JobType, Progress from ..worker.pool import DevicePoolExecutor from .context import ServerContext from .load import ( @@ -46,16 +47,17 @@ get_mask_filters, get_network_models, get_noise_sources, + get_prompt_filters, get_source_filters, get_upscaling_models, get_wildcard_data, ) from .params import ( build_border, - build_highres, build_upscale, + get_request_data, + get_request_params, pipeline_from_json, - pipeline_from_request, ) from .utils import wrap_route @@ -92,6 +94,89 @@ def error_reply(err: str): return response +EMPTY_PROGRESS = Progress(0, 0) + + +def job_reply(name: str, queue: int = 0): + return jsonify( + { + "name": name, + "queue": Progress(queue, queue).tojson(), + "status": JobStatus.PENDING, + "stages": EMPTY_PROGRESS.tojson(), + "steps": EMPTY_PROGRESS.tojson(), + "tiles": EMPTY_PROGRESS.tojson(), + } + ) + + +def image_reply( + server: ServerContext, + name: str, + status: str, + queue: Progress = None, + stages: Progress = None, + steps: Progress = None, + tiles: Progress = None, + metadata: Optional[List[ImageMetadata]] = None, + outputs: Optional[List[str]] = None, + thumbnails: Optional[List[str]] = None, + reason: Optional[str] = None, +) -> Dict[str, Any]: + if queue is None: + queue = EMPTY_PROGRESS + + if stages is None: + stages = EMPTY_PROGRESS + + if steps is None: + steps = EMPTY_PROGRESS + + if tiles is None: + tiles = EMPTY_PROGRESS + + data = { + "name": name, + "status": status, + "queue": queue.tojson(), + "stages": stages.tojson(), + "steps": steps.tojson(), + "tiles": tiles.tojson(), + } + + if reason is not None: + data["reason"] = reason + + if outputs is not None: + if metadata is None: + logger.error("metadata is required with outputs") + return error_reply("metadata is required with outputs") + + if len(metadata) != len(outputs): + logger.error("metadata and outputs must be the same length") + return error_reply("metadata and outputs must be the same length") + + data["metadata"] = [m.tojson(server, [o]) for m, o in zip(metadata, outputs)] + data["outputs"] = outputs + + if thumbnails is not None: + if len(thumbnails) != len(outputs): + logger.error("thumbnails and outputs must be the same length") + return error_reply("thumbnails and outputs must be the same length") + + data["thumbnails"] = thumbnails + + return data + + +def multi_image_reply(results: List[Dict[str, Any]]): + return jsonify( + { + "results": results, + } + ) + + def url_from_rule(rule) -> str: options = {} for arg in rule.arguments: @@ -104,7 +189,10 @@ def introspect(server: ServerContext, app: Flask): return { "name": "onnx-web", "routes": [ - {"path": url_from_rule(rule), "methods": list(rule.methods or []).sort()} + { + "path": url_from_rule(rule), + "methods": list(rule.methods or []), + } for rule in app.url_map.iter_rules() ], } @@ -116,10 +204,12 @@ def list_extra_strings(server: ServerContext): def list_filters(server: ServerContext): mask_filters = list(get_mask_filters().keys()) + prompt_filters = get_prompt_filters() 
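+ # prompt filters are plain model names rather than callables, so the list is returned as-is (unlike the mask and source filter maps)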
source_filters = list(get_source_filters().keys()) return jsonify( { "mask": mask_filters, + "prompt": prompt_filters, "source": source_filters, } ) @@ -171,81 +261,62 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): return error_reply("source image is required") source = Image.open(BytesIO(source_file.read())).convert("RGB") - size = Size(source.width, source.height) - device, params, _size = pipeline_from_request(server, "img2img") - upscale = build_upscale() - highres = build_highres() + data = get_request_data() + data_params = data.get("params", data) source_filter = get_from_list( - request.args, "sourceFilter", list(get_source_filters().keys()) + data_params, "sourceFilter", list(get_source_filters().keys()) ) strength = get_and_clamp_float( - request.args, + data_params, "strength", get_config_value("strength"), get_config_value("strength", "max"), get_config_value("strength", "min"), ) - replace_wildcards(params, get_wildcard_data()) + params = get_request_params(server, JobType.IMG2IMG.value) + params.size = Size(source.width, source.height) + replace_wildcards(params.image, get_wildcard_data()) - output_count = params.batch - if source_filter is not None and source_filter != "none": - logger.debug( - "including filtered source with outputs, filter: %s", source_filter - ) - output_count += 1 - - output = make_output_name( - server, "img2img", params, size, extras=[strength], count=output_count + job_name = make_job_name( + JobType.IMG2IMG.value, params.image, params.size, extras=[strength] ) - - job_name = output[0] - pool.submit( + queue = pool.submit( job_name, + JobType.IMG2IMG, run_img2img_pipeline, server, params, - output, - upscale, - highres, source, strength, - needs_device=device, + needs_device=params.device, source_filter=source_filter, ) logger.info("img2img job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return job_reply(job_name, queue=queue) def txt2img(server: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(server, "txt2img") - upscale = build_upscale() - highres = build_highres() - - replace_wildcards(params, get_wildcard_data()) + params = get_request_params(server, JobType.TXT2IMG.value) + replace_wildcards(params.image, get_wildcard_data()) - output = make_output_name(server, "txt2img", params, size, count=params.batch) - - job_name = output[0] - pool.submit( + job_name = make_job_name(JobType.TXT2IMG.value, params.image, params.size) + queue = pool.submit( job_name, + JobType.TXT2IMG, run_txt2img_pipeline, server, params, - size, - output, - upscale, - highres, - needs_device=device, + needs_device=params.device, ) logger.info("txt2img job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return job_reply(job_name, queue=queue) def inpaint(server: ServerContext, pool: DevicePoolExecutor): @@ -265,42 +336,40 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): mask.alpha_composite(mask_top_layer) mask.convert(mode="L") + data = get_request_data() + data_params = data.get("params", data) full_res_inpaint = get_boolean( - request.args, "fullresInpaint", get_config_value("fullresInpaint") + data_params, "fullresInpaint", get_config_value("fullresInpaint") ) full_res_inpaint_padding = get_and_clamp_float( - request.args, + data_params, "fullresInpaintPadding", get_config_value("fullresInpaintPadding"), get_config_value("fullresInpaintPadding", "max"), 
get_config_value("fullresInpaintPadding", "min"), ) - device, params, _size = pipeline_from_request(server, "inpaint") - expand = build_border() - upscale = build_upscale() - highres = build_highres() + params = get_request_params(server, JobType.INPAINT.value) + replace_wildcards(params.image, get_wildcard_data()) - fill_color = get_not_empty(request.args, "fillColor", "white") - mask_filter = get_from_map(request.args, "filter", get_mask_filters(), "none") - noise_source = get_from_map(request.args, "noise", get_noise_sources(), "histogram") + fill_color = get_not_empty(data_params, "fillColor", "white") + mask_filter = get_from_map(data_params, "filter", get_mask_filters(), "none") + noise_source = get_from_map(data_params, "noise", get_noise_sources(), "histogram") tile_order = get_from_list( - request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral] + data_params, + "tileOrder", + [TileOrder.grid, TileOrder.kernel, TileOrder.spiral], ) - tile_order = TileOrder.spiral - - replace_wildcards(params, get_wildcard_data()) - output = make_output_name( - server, - "inpaint", - params, + job_name = make_job_name( + JobType.INPAINT.value, + params.image, size, extras=[ - expand.left, - expand.right, - expand.top, - expand.bottom, + params.border.left, + params.border.right, + params.border.top, + params.border.bottom, mask_filter.__name__, noise_source.__name__, fill_color, @@ -308,35 +377,26 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): ], ) - job_name = output[0] - pool.submit( + queue = pool.submit( job_name, + JobType.INPAINT, run_inpaint_pipeline, server, params, - size, - output, - upscale, - highres, source, mask, - expand, noise_source, mask_filter, fill_color, tile_order, full_res_inpaint, full_res_inpaint_padding, - needs_device=device, + needs_device=params.device, ) logger.info("inpaint job queued for: %s", job_name) - return jsonify( - json_params( - output, params, size, upscale=upscale, border=expand, highres=highres - ) - ) + return job_reply(job_name, queue=queue) def upscale(server: ServerContext, pool: DevicePoolExecutor): @@ -346,31 +406,23 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") - device, params, size = pipeline_from_request(server) - upscale = build_upscale() - highres = build_highres() + params = get_request_params(server, JobType.UPSCALE.value) + replace_wildcards(params.image, get_wildcard_data()) - replace_wildcards(params, get_wildcard_data()) - - output = make_output_name(server, "upscale", params, size) - - job_name = output[0] - pool.submit( + job_name = make_job_name("upscale", params.image, params.size) + queue = pool.submit( job_name, + JobType.UPSCALE, run_upscale_pipeline, server, params, - size, - output, - upscale, - highres, source, - needs_device=device, + needs_device=params.device, ) logger.info("upscale job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale, highres=highres)) + return job_reply(job_name, queue=queue) # keys that are specially parsed by params and should not show up in with_args @@ -466,25 +518,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("running chain pipeline with %s stages", len(pipeline.stages)) - output = make_output_name( - server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0) - ) - job_name = output[0] + job_name = make_job_name("chain", base_params, base_size) # build and run chain pipeline - 
pool.submit( + queue = pool.submit( job_name, + JobType.CHAIN, pipeline, server, base_params, StageResult.empty(), - output=output, size=base_size, needs_device=device, ) - step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size)) - return jsonify(json_params(output, step_params, base_size)) + return job_reply(job_name, queue=queue) def blend(server: ServerContext, pool: DevicePoolExecutor): @@ -505,48 +553,41 @@ def blend(server: ServerContext, pool: DevicePoolExecutor): source = Image.open(BytesIO(source_file.read())).convert("RGB") sources.append(source) - device, params, size = pipeline_from_request(server) - upscale = build_upscale() + params = get_request_params(server) - output = make_output_name(server, "upscale", params, size) - job_name = output[0] - pool.submit( + job_name = make_job_name("blend", params.image, params.size) + queue = pool.submit( job_name, + JobType.BLEND, run_blend_pipeline, server, params, - size, - output, - upscale, - # TODO: highres sources, mask, - needs_device=device, + needs_device=params.device, ) logger.info("upscale job queued for: %s", job_name) - return jsonify(json_params(output, params, size, upscale=upscale)) + return job_reply(job_name, queue=queue) def txt2txt(server: ServerContext, pool: DevicePoolExecutor): - device, params, size = pipeline_from_request(server) + params = get_request_params(server) - output = make_output_name(server, "txt2txt", params, size) - job_name = output[0] + job_name = make_job_name("txt2txt", params.image, params.size) logger.info("upscale job queued for: %s", job_name) - pool.submit( + queue = pool.submit( job_name, + JobType.TXT2TXT, run_txt2txt_pipeline, server, params, - size, - output, - needs_device=device, + needs_device=params.device, ) - return jsonify(json_params(output, params, size)) + return job_reply(job_name, queue=queue) def cancel(server: ServerContext, pool: DevicePoolExecutor): @@ -566,9 +607,9 @@ def ready(server: ServerContext, pool: DevicePoolExecutor): return error_reply("output name is required") output_file = sanitize_name(output_file) - pending, progress = pool.done(output_file) + status, progress, _queue = pool.status(output_file) - if pending: + if status == JobStatus.PENDING: return ready_reply(pending=True) if progress is None: @@ -582,16 +623,103 @@ def ready(server: ServerContext, pool: DevicePoolExecutor): ) # is a missing image really an error? 
yes will display the retry button return ready_reply( - ready=progress.finished, - progress=progress.progress, - failed=progress.failed, - cancelled=progress.cancelled, + ready=(status == JobStatus.SUCCESS), + progress=progress.steps.current, + failed=(status == JobStatus.FAILED), + cancelled=(status == JobStatus.CANCELLED), ) +def job_create(server: ServerContext, pool: DevicePoolExecutor): + return chain(server, pool) + + +def job_cancel(server: ServerContext, pool: DevicePoolExecutor): + legacy_job_name = request.args.get("job", None) + job_list = get_list(request.args, "jobs") + + if legacy_job_name is not None: + job_list.append(legacy_job_name) + + if len(job_list) == 0: + return error_reply("at least one job name is required") + elif len(job_list) > 10: + return error_reply("too many jobs") + + results: List[Dict[str, str]] = [] + for job_name in job_list: + job_name = sanitize_name(job_name) + cancelled = pool.cancel(job_name) + results.append( + { + "name": job_name, + "status": JobStatus.CANCELLED if cancelled else JobStatus.PENDING, + } + ) + + return multi_image_reply(results) + + +def job_status(server: ServerContext, pool: DevicePoolExecutor): + legacy_job_name = request.args.get("job", None) + job_list = get_list(request.args, "jobs") + + if legacy_job_name is not None: + job_list.append(legacy_job_name) + + if len(job_list) == 0: + return error_reply("at least one job name is required") + elif len(job_list) > 10: + return error_reply("too many jobs") + + records = [] + + for job_name in job_list: + job_name = sanitize_name(job_name) + status, progress, queue = pool.status(job_name) + + if progress is not None: + metadata = None + outputs = None + thumbnails = None + + if progress.result is not None: + metadata = progress.result.metadata + outputs = progress.result.outputs + thumbnails = progress.result.thumbnails + + records.append( + image_reply( + server, + job_name, + status, + stages=progress.stages, + steps=progress.steps, + tiles=progress.tiles, + metadata=metadata, + outputs=outputs, + thumbnails=thumbnails, + reason=progress.reason, + ) + ) + else: + records.append(image_reply(server, job_name, status, queue=queue)) + + return jsonify(records) + + def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor): return [ app.route("/api")(wrap_route(introspect, server, app=app)), + # job routes + app.route("/api/job", methods=["POST"])( + wrap_route(job_create, server, pool=pool) + ), + app.route("/api/job/cancel", methods=["PUT"])( + wrap_route(job_cancel, server, pool=pool) + ), + app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)), + # settings routes app.route("/api/settings/filters")(wrap_route(list_filters, server)), app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)), app.route("/api/settings/models")(wrap_route(list_models, server)), @@ -602,6 +730,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)), app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)), app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)), + # legacy job routes app.route("/api/img2img", methods=["POST"])( wrap_route(img2img, server, pool=pool) ), @@ -619,6 +748,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu ), app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)), app.route("/api/blend", 
methods=["POST"])(wrap_route(blend, server, pool=pool)), + # deprecated routes app.route("/api/cancel", methods=["PUT"])( wrap_route(cancel, server, pool=pool) ), diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 6f77522bc..bbad066ad 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,7 +1,7 @@ from logging import getLogger from os import environ, path from secrets import token_urlsafe -from typing import List, Optional +from typing import Dict, List, Optional import torch @@ -16,6 +16,7 @@ DEFAULT_IMAGE_FORMAT = "png" DEFAULT_SERVER_VERSION = "v0.12.0" DEFAULT_SHOW_PROGRESS = True +DEFAULT_THUMBNAIL_SIZE = 1024 DEFAULT_WORKER_RETRIES = 3 @@ -42,6 +43,8 @@ class ServerContext: feature_flags: List[str] plugins: List[str] debug: bool + thumbnail_size: int + hash_cache: Dict[str, str] def __init__( self, @@ -67,6 +70,8 @@ def __init__( feature_flags: Optional[List[str]] = None, plugins: Optional[List[str]] = None, debug: bool = False, + thumbnail_size: Optional[int] = DEFAULT_THUMBNAIL_SIZE, + hash_cache: Optional[Dict[str, str]] = None, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -90,6 +95,8 @@ def __init__( self.feature_flags = feature_flags or [] self.plugins = plugins or [] self.debug = debug + self.thumbnail_size = thumbnail_size + self.hash_cache = hash_cache or {} self.cache = ModelCache(self.cache_limit) @@ -127,6 +134,9 @@ def from_environ(cls, env=environ): feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"), plugins=get_list(env, "ONNX_WEB_PLUGINS", ""), debug=get_boolean(env, "ONNX_WEB_DEBUG", False), + thumbnail_size=int( + env.get("ONNX_WEB_THUMBNAIL_SIZE", DEFAULT_THUMBNAIL_SIZE) + ), ) def get_setting(self, flag: str, default: str) -> Optional[str]: diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 3c76f9cc4..28f79a867 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -35,7 +35,7 @@ from ..models.meta import NetworkModel from ..params import DeviceParams from ..torch_before_ort import get_available_providers -from ..utils import load_config, merge +from ..utils import load_config, merge, recursive_get from .context import ServerContext logger = getLogger(__name__) @@ -83,6 +83,12 @@ "segment": source_filter_segment, "scribble": source_filter_scribble, } +prompt_filters = [ + "AUTOMATIC/promptgen-lexart", + "AUTOMATIC/promptgen-majinai-safe", + "AUTOMATIC/promptgen-majinai-unsafe", + "Gustavosta/MagicPrompt-Stable-Diffusion", +] # Available ORT providers available_platforms: List[DeviceParams] = [] @@ -148,12 +154,17 @@ def get_noise_sources(): return noise_sources +def get_prompt_filters(): + return prompt_filters + + def get_source_filters(): return source_filters def get_config_value(key: str, subkey: str = "default", default=None): - return config_params.get(key, {}).get(subkey, default) + val = recursive_get(config_params, key.split("."), default_value={}) + return val.get(subkey, default) def load_extras(server: ServerContext): @@ -233,9 +244,9 @@ def load_extras(server: ServerContext): inversion_name, model_name, ) - labels[ - f"inversion.{inversion_name}" - ] = inversion["label"] + labels[f"inversion.{inversion_name}"] = ( + inversion["label"] + ) if "loras" in model: for lora in model["loras"]: diff --git a/api/onnx_web/server/model_cache.py b/api/onnx_web/server/model_cache.py index 6525d4ae3..8c1578581 100644 --- a/api/onnx_web/server/model_cache.py +++ b/api/onnx_web/server/model_cache.py @@ -12,6 +12,7 @@ class 
ModelTypes(str, Enum): diffusion = "diffusion" scheduler = "scheduler" upscaling = "upscaling" + safety = "safety" class ModelCache: diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index b19f541a7..a1b9e5d08 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union from flask import request @@ -8,8 +8,12 @@ from ..params import ( Border, DeviceParams, + ExperimentalParams, HighresParams, ImageParams, + LatentSymmetryParams, + PromptEditingParams, + RequestParams, Size, UpscaleParams, ) @@ -19,6 +23,7 @@ get_boolean, get_from_list, get_not_empty, + load_config_str, ) from .context import ServerContext from .load import ( @@ -118,6 +123,7 @@ get_config_value("steps", "min"), ) tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae")) + thumbnail = get_boolean(data, "thumbnail", get_config_value("thumbnail")) unet_overlap = get_and_clamp_float( data, "unet_overlap", @@ -169,6 +175,7 @@ unet_tile=unet_tile, vae_overlap=vae_overlap, vae_tile=vae_tile, + thumbnail=thumbnail, ) return params @@ -196,11 +203,8 @@ def build_size( def build_border( - data: Dict[str, str] = None, + data: Dict[str, str], ) -> Border: - if data is None: - data = request.args - left = get_and_clamp_int( data, "left", @@ -234,11 +238,8 @@ def build_border( def build_upscale( - data: Dict[str, str] = None, + data: Dict[str, str], ) -> UpscaleParams: - if data is None: - data = request.args - upscale = get_boolean(data, "upscale", False) denoise = get_and_clamp_float( data, @@ -289,7 +290,6 @@ faces=faces, face_outscale=face_outscale, face_strength=face_strength, - format="onnx", outscale=outscale, scale=scale, upscale_order=upscale_order, @@ -297,11 +297,8 @@ def build_highres( - data: Dict[str, str] = None, + data: Dict[str, str], ) -> HighresParams: - if data is None: - data = request.args - enabled = get_boolean(data, "highres", get_config_value("highres")) iterations = get_and_clamp_int( data, @@ -343,6 +340,90 @@ ) +def build_latent_symmetry( + data: Dict[str, str] = None, +) -> LatentSymmetryParams: + if data is None: + data = request.args + + enabled = get_boolean(data, "enabled", get_config_value("latentSymmetry.enabled")) + + gradient_start = get_and_clamp_float( + data, + "gradientStart", + get_config_value("latentSymmetry.gradientStart"), + get_config_value("latentSymmetry.gradientStart", "max"), + get_config_value("latentSymmetry.gradientStart", "min"), + ) + + gradient_end = get_and_clamp_float( + data, + "gradientEnd", + get_config_value("latentSymmetry.gradientEnd"), + get_config_value("latentSymmetry.gradientEnd", "max"), + get_config_value("latentSymmetry.gradientEnd", "min"), + ) + + line_of_symmetry = get_and_clamp_float( + data, + "lineOfSymmetry", + get_config_value("latentSymmetry.lineOfSymmetry"), + get_config_value("latentSymmetry.lineOfSymmetry", "max"), + get_config_value("latentSymmetry.lineOfSymmetry", "min"), + ) + + return LatentSymmetryParams(enabled, gradient_start, gradient_end, line_of_symmetry) + + +def build_prompt_editing( + data: Dict[str, str] = None, +) -> PromptEditingParams: + if data is None: + data = request.args + + enabled = get_boolean(data, "enabled", get_config_value("promptEditing.enabled")) + + add_suffix = data.get("addSuffix", get_config_value("promptEditing.addSuffix")) +
min_length = get_and_clamp_int( + data, + "minLength", + get_config_value("promptEditing.minLength"), + get_config_value("promptEditing.minLength", "max"), + get_config_value("promptEditing.minLength", "min"), + ) + prompt_filter = data.get("promptFilter", get_config_value("promptEditing.filter")) + remove_tokens = data.get( + "removeTokens", get_config_value("promptEditing.removeTokens") + ) + + return PromptEditingParams( + enabled, prompt_filter, remove_tokens, add_suffix, min_length + ) + + +def build_experimental( + data: Dict[str, str] = None, +) -> ExperimentalParams: + if data is None: + data = request.args + + latent_symmetry_data = data.get("latentSymmetry", {}) + latent_symmetry = build_latent_symmetry(latent_symmetry_data) + + prompt_editing_data = data.get("promptEditing", {}) + prompt_editing = build_prompt_editing(prompt_editing_data) + + return ExperimentalParams(latent_symmetry, prompt_editing) + + +def is_json_request() -> bool: + return request.mimetype == "application/json" + + +def is_json_form_request() -> bool: + return request.mimetype == "multipart/form-data" and "json" in request.form + + PipelineParams = Tuple[Optional[DeviceParams], ImageParams, Size] @@ -351,27 +432,11 @@ def pipeline_from_json( data: Dict[str, Union[str, Dict[str, str]]], default_pipeline: str = "txt2img", ) -> PipelineParams: - """ - Like pipeline_from_request but expects a nested structure. - """ - device = build_device(server, data.get("device", data)) params = build_params(server, default_pipeline, data.get("params", data)) size = build_size(server, data.get("params", data)) - return (device, params, size) - - -def pipeline_from_request( - server: ServerContext, - default_pipeline: str = "txt2img", -) -> PipelineParams: user = request.remote_addr - - device = build_device(server, request.args) - params = build_params(server, default_pipeline, request.args) - size = build_size(server, request.args) - logger.info( "request from %s: %s steps of %s using %s in %s on %s, %sx%s, %s, %s - %s", user, @@ -388,3 +453,49 @@ def pipeline_from_request( ) return (device, params, size) + + +def get_request_data(key: Optional[str] = None) -> Any: + if is_json_request(): + json = request.json + elif is_json_form_request(): + json = load_config_str(request.form.get("json")) + else: + json = None + + if key is not None and json is not None: + json = json.get(key) + + return json or request.args + + +def get_request_params( + server: ServerContext, default_pipeline: str = None +) -> RequestParams: + data = get_request_data() + + device, params, size = pipeline_from_json(server, data, default_pipeline) + + border = build_border(get_dict_or_self(data, "border")) + upscale = build_upscale(get_dict_or_self(data, "upscale")) + highres = build_highres(get_dict_or_self(data, "highres")) + experimental = build_experimental(get_dict_or_self(data, "experimental")) + + return RequestParams( + device, + params, + size=size, + border=border, + upscale=upscale, + highres=highres, + experimental=experimental, + ) + + +def get_dict_or_self(obj: Dict[str, Any], key: str) -> Any: + if key in obj: + value = obj[key] + if isinstance(value, dict): + return value + + return obj diff --git a/api/onnx_web/transformers/run.py b/api/onnx_web/transformers/run.py index eb7896399..e45ab4f7c 100644 --- a/api/onnx_web/transformers/run.py +++ b/api/onnx_web/transformers/run.py @@ -12,7 +12,6 @@ def run_txt2txt_pipeline( _server: ServerContext, params: ImageParams, _size: Size, - output: str, ) -> None: from transformers import 
AutoTokenizer, GPTJForCausalLM @@ -38,4 +37,4 @@ def run_txt2txt_pipeline( print("Server says: %s" % result_text) - logger.info("finished txt2txt job: %s", output) + logger.info("finished txt2txt job: %s", worker.job) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index d047ec0d5..74a8dac20 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,18 +1,16 @@ -import gc import importlib import json -import threading +from functools import reduce +from hashlib import sha256 from json import JSONDecodeError from logging import getLogger from os import environ, path from platform import system -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from struct import pack +from typing import Any, Dict, List, Optional, Sequence, TypeVar -import torch from yaml import safe_load -from .params import DeviceParams, SizeChart - logger = getLogger(__name__) SAFE_CHARS = "._-" @@ -32,8 +30,18 @@ def is_debug() -> bool: return get_boolean(environ, "DEBUG", False) +def recursive_get(d, keys, default_value=None): + empty_dict = {} + val = reduce(lambda c, k: c.get(k, empty_dict), keys, d) + + if val == empty_dict: + return default_value + + return val + + def get_boolean(args: Any, key: str, default_value: bool) -> bool: - val = args.get(key, str(default_value)) + val = recursive_get(args, key.split("."), default_value=str(default_value)) if isinstance(val, bool): return val @@ -42,19 +50,22 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool: def get_list(args: Any, key: str, default="") -> List[str]: - return split_list(args.get(key, default)) + val = recursive_get(args, key.split("."), default_value=default) + return split_list(val) def get_and_clamp_float( args: Any, key: str, default_value: float, max_value: float, min_value=0.0 ) -> float: - return min(max(float(args.get(key, default_value)), min_value), max_value) + val = recursive_get(args, key.split("."), default_value=default_value) + return min(max(float(val), min_value), max_value) def get_and_clamp_int( args: Any, key: str, default_value: int, max_value: int, min_value=1 ) -> int: - return min(max(int(args.get(key, default_value)), min_value), max_value) + val = recursive_get(args, key.split("."), default_value=default_value) + return min(max(int(val), min_value), max_value) TElem = TypeVar("TElem") @@ -93,43 +104,14 @@ def get_not_empty(args: Any, key: str, default: TElem) -> TElem: return val -def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: - if val is None: - return SizeChart.auto - - if type(val) is int: - return val - - if type(val) is str: - for size in SizeChart: - if val == size.name: - return size - - return int(val) - - raise ValueError("invalid size") +def run_gc(devices: Optional[List[Any]] = None): + """ + Deprecated, use `onnx_web.device.run_gc` instead. 
+ """ + from .device import run_gc as run_gc_impl - -def run_gc(devices: Optional[List[DeviceParams]] = None): - logger.debug( - "running garbage collection with %s active threads", threading.active_count() - ) - gc.collect() - - if torch.cuda.is_available() and devices is not None: - for device in [d for d in devices if d.device.startswith("cuda")]: - logger.debug("running Torch garbage collection for device: %s", device) - with torch.cuda.device(device.torch_str()): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - mem_free, mem_total = torch.cuda.mem_get_info() - mem_pct = (1 - (mem_free / mem_total)) * 100 - logger.debug( - "CUDA VRAM usage: %s of %s (%.2f%%)", - (mem_total - mem_free), - mem_total, - mem_pct, - ) + logger.debug("calling deprecated run_gc, please use onnx_web.device.run_gc instead") + run_gc_impl(devices) def sanitize_name(name): @@ -218,3 +200,45 @@ def load_config_str(raw: str) -> Dict: return json.loads(raw) except JSONDecodeError: return safe_load(raw) + + +HASH_BUFFER_SIZE = 2**22 # 4MB + + +def hash_file(name: str): + sha = sha256() + with open(name, "rb") as f: + while True: + data = f.read(HASH_BUFFER_SIZE) + if not data: + break + + sha.update(data) + + return sha.hexdigest() + + +def hash_value(sha, param: Optional[Any]): + if param is None: + return None + elif isinstance(param, bool): + sha.update(bytearray(pack("!B", param))) + elif isinstance(param, float): + sha.update(bytearray(pack("!f", param))) + elif isinstance(param, int): + sha.update(bytearray(pack("!I", param))) + elif isinstance(param, str): + sha.update(param.encode("utf-8")) + else: + logger.warning("cannot hash param: %s, %s", param, type(param)) + + +def coalesce(*args, throw=False): + for arg in args: + if arg is not None: + return arg + + if throw: + raise ValueError("no value found") + + return None diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 1d7db225c..40f4ed77e 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -1,34 +1,107 @@ -from typing import Any, Callable, Dict +from enum import Enum +from typing import Any, Callable, Dict, Optional + + +class JobStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + CANCELLED = "cancelled" + UNKNOWN = "unknown" + + +class JobType(str, Enum): + TXT2TXT = "txt2txt" + TXT2IMG = "txt2img" + IMG2IMG = "img2img" + INPAINT = "inpaint" + UPSCALE = "upscale" + BLEND = "blend" + CHAIN = "chain" + + +class Progress: + """ + Generic counter with current and expected/final/total value. Can be used to count up or down. + + Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty" + when the current value is zero. + """ + + current: int + total: int + + def __init__(self, current: int, total: int) -> None: + self.current = current + self.total = total + + def __str__(self) -> str: + return "%s/%s" % (self.current, self.total) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Progress): + return self.current == other.current and self.total == other.total + + return False + + def tojson(self): + return { + "current": self.current, + "total": self.total, + } + + def is_complete(self) -> bool: + return self.current >= self.total + + def is_empty(self) -> bool: + # TODO: what if total is also 0? 
+ return self.current == 0 + + def update(self, current: int) -> "Progress": + return Progress(current, self.total) class ProgressCommand: device: str job: str - finished: bool - progress: int - cancelled: bool - failed: bool + job_type: str + status: JobStatus + reason: Optional[str] + result: Optional[Any] # really StageResult but that would be a very circular import + steps: Progress + stages: Progress + tiles: Progress def __init__( self, job: str, + job_type: str, device: str, - finished: bool, - progress: int, - cancelled: bool = False, - failed: bool = False, + status: JobStatus, + steps: Progress, + stages: Progress, + tiles: Progress, + result: Any = None, + reason: Optional[str] = None, ): self.job = job + self.job_type = job_type self.device = device - self.finished = finished - self.progress = progress - self.cancelled = cancelled - self.failed = failed + self.status = status + + # progress info + self.steps = steps + self.stages = stages + self.tiles = tiles + self.result = result + self.reason = reason class JobCommand: device: str name: str + job_type: str fn: Callable[..., None] args: Any kwargs: Dict[str, Any] @@ -37,12 +110,14 @@ def __init__( self, name: str, device: str, + job_type: str, fn: Callable[..., None], args: Any, kwargs: Dict[str, Any], ): self.device = device self.name = name + self.job_type = job_type self.fn = fn self.args = args self.kwargs = kwargs diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2d6d02781..1a874dde0 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -2,21 +2,23 @@ from os import getpid from typing import Any, Callable, Optional +import numpy as np from torch.multiprocessing import Queue, Value from ..errors import CancelledException from ..params import DeviceParams -from .command import JobCommand, ProgressCommand +from .command import JobCommand, JobStatus, Progress, ProgressCommand logger = getLogger(__name__) -ProgressCallback = Callable[[int, int, Any], None] +ProgressCallback = Callable[[int, int, np.ndarray], None] class WorkerContext: cancel: "Value[bool]" job: Optional[str] + job_type: Optional[str] name: str pending: "Queue[JobCommand]" active_pid: "Value[int]" @@ -26,6 +28,12 @@ class WorkerContext: timeout: float retries: int initial_retries: int + callback: Optional[Any] + + # progress state + steps: Progress + stages: Progress + tiles: Progress def __init__( self, @@ -41,6 +49,7 @@ def __init__( timeout: float, ): self.job = None + self.job_type = None self.name = name self.device = device self.cancel = cancel @@ -53,10 +62,20 @@ def __init__( self.initial_retries = retries self.retries = retries self.timeout = timeout + self.callback = None + self.steps = Progress(0, 0) + self.stages = Progress(0, 0) + self.tiles = Progress(0, 0) + + def start(self, job: JobCommand) -> None: + # set job name and type + self.job = job.name + self.job_type = job.job_type - def start(self, job: str) -> None: - self.job = job + # reset retries self.retries = self.initial_retries + + # clear flags self.set_cancel(cancel=False) self.set_idle(idle=False) @@ -79,20 +98,31 @@ def get_device(self) -> DeviceParams: """ return self.device - def get_progress(self) -> int: - if self.last_progress is not None: - return self.last_progress.progress + def get_progress(self) -> Progress: + return self.get_last_steps() + + def get_last_steps(self) -> Progress: + return self.steps - return 0 + def get_last_stages(self) -> Progress: + return self.stages - def get_progress_callback(self) -> 
ProgressCallback: + def get_last_tiles(self) -> Progress: + return self.tiles + + def get_progress_callback(self, reset=False) -> ProgressCallback: from ..chain.pipeline import ChainProgress + if not reset and self.callback is not None: + return self.callback + def on_progress(step: int, timestep: int, latents: Any): - on_progress.step = step - self.set_progress(step) + self.set_progress( + step, + ) - return ChainProgress.from_progress(on_progress) + self.callback = ChainProgress.from_progress(on_progress) + return self.callback def set_cancel(self, cancel: bool = True) -> None: with self.cancel.get_lock(): @@ -102,47 +132,93 @@ def set_idle(self, idle: bool = True) -> None: with self.idle.get_lock(): self.idle.value = idle - def set_progress(self, progress: int) -> None: + def set_progress(self, steps: int, stages: int = None, tiles: int = None) -> None: if self.job is None: raise RuntimeError("no job on which to set progress") if self.is_cancelled(): raise CancelledException("job has been cancelled") - logger.debug("setting progress for job %s to %s", self.job, progress) + # update current progress counters + self.steps = self.steps.update(steps) + + if stages is not None: + self.stages = self.stages.update(stages) + + if tiles is not None: + self.tiles = self.tiles.update(tiles) + + # TODO: result should really be part of context at this point + result = None + if self.callback is not None: + result = self.callback.result + + # send progress to worker pool + logger.debug("setting progress for job %s to %s", self.job, steps) self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - False, - progress, - self.is_cancelled(), - False, + JobStatus.RUNNING, + steps=self.steps, + stages=self.stages, + tiles=self.tiles, + result=result, ) - self.progress.put( self.last_progress, block=False, ) + def set_steps(self, current: int, total: int = 0) -> None: + if total > 0: + self.steps = Progress(current, total) + else: + self.steps = self.steps.update(current) + + def set_stages(self, current: int, total: int = 0) -> None: + if total > 0: + self.stages = Progress(current, total) + else: + self.stages = self.stages.update(current) + + def set_tiles(self, current: int, total: int = 0) -> None: + if total > 0: + self.tiles = Progress(current, total) + else: + self.tiles = self.tiles.update(current) + + def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None: + self.steps = Progress(0, steps) + self.stages = Progress(0, stages) + self.tiles = Progress(0, tiles) + def finish(self) -> None: if self.job is None: logger.warning("setting finished without an active job") else: logger.debug("setting finished for job %s", self.job) + + result = None + if self.callback is not None: + result = self.callback.result + self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - False, + JobStatus.SUCCESS, + steps=self.steps, + stages=self.stages, + tiles=self.tiles, + result=result, ) self.progress.put( self.last_progress, block=False, ) - def fail(self) -> None: + def fail(self, reason: Optional[str] = None) -> None: if self.job is None: logger.warning("setting failure without an active job") else: @@ -150,11 +226,14 @@ def fail(self) -> None: try: self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - True, + JobStatus.FAILED, + steps=self.steps, + stages=self.stages, + tiles=self.tiles, + 
reason=reason, + # TODO: should this include partial results? ) self.progress.put( self.last_progress, @@ -162,25 +241,3 @@ def fail(self) -> None: ) except Exception: logger.exception("error setting failure on job %s", self.job) - - -class JobStatus: - name: str - device: str - progress: int - cancelled: bool - finished: bool - - def __init__( - self, - name: str, - device: DeviceParams, - progress: int = 0, - cancelled: bool = False, - finished: bool = False, - ) -> None: - self.name = name - self.device = device.device - self.progress = progress - self.cancelled = cancelled - self.finished = finished diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index d210421e5..cc9d990bc 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,7 +8,7 @@ from ..params import DeviceParams from ..server import ServerContext -from .command import JobCommand, ProgressCommand +from .command import JobCommand, JobStatus, Progress, ProgressCommand from .context import WorkerContext from .utils import Interval from .worker import worker_main @@ -201,6 +201,10 @@ def cancel(self, key: str) -> bool: should be cancelled on the next progress callback. """ + if key in self.cancelled_jobs: + logger.debug("cancelling already cancelled job: %s", key) + return True + for job in self.finished_jobs: if job.job == key: logger.debug("cannot cancel finished job: %s", key) @@ -209,6 +213,9 @@ def cancel(self, key: str) -> bool: for job in self.pending_jobs: if job.name == key: self.pending_jobs.remove(job) + self.cancelled_jobs.append( + key + ) # ensure workers never pick up this job and the status endpoint knows about it later logger.info("cancelled pending job: %s", key) return True @@ -221,28 +228,33 @@ def cancel(self, key: str) -> bool: self.cancelled_jobs.append(key) return True - def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]: + def status( + self, key: str + ) -> Tuple[JobStatus, Optional[ProgressCommand], Optional[Progress]]: """ Check if a job has been finished and report the last progress update. - - If the job is still pending, the first item will be True and there will be no ProgressCommand. 
""" + + if key in self.cancelled_jobs: + logger.debug("checking status for cancelled job: %s", key) + return (JobStatus.CANCELLED, None, None) + if key in self.running_jobs: logger.debug("checking status for running job: %s", key) - return (False, self.running_jobs[key]) + return (JobStatus.RUNNING, self.running_jobs[key], None) for job in self.finished_jobs: if job.job == key: logger.debug("checking status for finished job: %s", key) - return (False, job) + return (job.status, job, None) - for job in self.pending_jobs: + for i, job in enumerate(self.pending_jobs): if job.name == key: logger.debug("checking status for pending job: %s", key) - return (True, None) + return (JobStatus.PENDING, None, Progress(i, len(self.pending_jobs))) logger.trace("checking status for unknown job: %s", key) - return (False, None) + return (JobStatus.UNKNOWN, None, None) def join(self): logger.info("stopping worker pool") @@ -383,12 +395,13 @@ def recycle(self, recycle_all=False): def submit( self, key: str, + job_type: str, fn: Callable[..., None], /, *args, needs_device: Optional[DeviceParams] = None, **kwargs, - ) -> None: + ) -> int: device_idx = self.get_next_device(needs_device=needs_device) device = self.devices[device_idx].device logger.info( @@ -399,56 +412,66 @@ def submit( ) # build and queue job - job = JobCommand(key, device, fn, args, kwargs) + job = JobCommand(key, device, job_type, fn, args, kwargs) self.pending_jobs.append(job) - def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]: + # return position in queue + return len(self.pending_jobs) + + def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]: """ Returns a tuple of: job/device, progress, progress, finished, cancelled, failed """ - return { - "cancelled": [], - "finished": [ + + jobs: List[Tuple[str, int, JobStatus]] = [] + jobs.extend( + [ ( - job.job, - job.progress, - False, - job.finished, - job.cancelled, - job.failed, + job, + 0, + JobStatus.CANCELLED, ) - for job in self.finished_jobs - ], - "pending": [ + for job in self.cancelled_jobs + ] + ) + jobs.extend( + [ ( job.name, 0, - True, - False, - False, - False, + JobStatus.PENDING, ) for job in self.pending_jobs - ], - "running": [ + ] + ) + jobs.extend( + [ ( name, - job.progress, - False, - job.finished, - job.cancelled, - job.failed, + job.steps, + job.status, ) for name, job in self.running_jobs.items() - ], - "total": [ + ] + ) + jobs.extend( + [ + ( + job.job, + job.steps, + job.status, + ) + for job in self.finished_jobs + ] + ) + + return { + "jobs": jobs, + "workers": [ ( device, total, self.workers[device].is_alive(), - False, - False, - False, ) for device, total in self.total_jobs.items() ], @@ -476,20 +499,18 @@ def finish_job(self, progress: ProgressCommand): self.cancelled_jobs.remove(progress.job) def update_job(self, progress: ProgressCommand): - if progress.finished: + if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]: return self.finish_job(progress) # move from pending to running - logger.debug( - "progress update for job: %s to %s", progress.job, progress.progress - ) + logger.debug("progress update for job: %s to %s", progress.job, progress.steps) self.running_jobs[progress.job] = progress self.pending_jobs[:] = [ job for job in self.pending_jobs if job.name != progress.job ] # increment job counter if this is the start of a new job - if progress.progress == 0: + if progress.steps == 0: if progress.device in self.total_jobs: self.total_jobs[progress.device] += 1 else: diff --git a/api/onnx_web/worker/worker.py 
b/api/onnx_web/worker/worker.py index 55ebcaac1..260f6674e 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -5,7 +5,7 @@ from setproctitle import setproctitle -from ..errors import RetryException +from ..errors import CancelledException, RetryException from ..server import ServerContext, apply_patches from ..torch_before_ort import get_available_providers from .context import WorkerContext @@ -57,7 +57,7 @@ def worker_main( logger.info("worker %s got job: %s", worker.device.device, job.name) # clear flags and save the job name - worker.start(job.name) + worker.start(job) logger.info("starting job: %s", job.name) # reset progress, which does a final check for cancellation @@ -82,13 +82,16 @@ def worker_main( logger.exception("value error in worker, exiting") worker.fail() return exit(EXIT_ERROR) + except CancelledException as e: + logger.warning("job was cancelled, continuing") + worker.fail(e.reason or "cancelled") except Exception as e: e_str = str(e) # restart the worker on memory errors for e_mem in MEMORY_ERRORS: if e_mem in e_str: logger.error("detected out-of-memory error, exiting: %s", e) - worker.fail() + worker.fail("oom") return exit(EXIT_MEMORY) # carry on for other errors diff --git a/api/params.json b/api/params.json index 7f7d5e766..4b5f3ef43 100644 --- a/api/params.json +++ b/api/params.json @@ -1,5 +1,11 @@ { "version": "0.12.0", + "motd": { + "de": "Willkommen bei onnx-web!", + "en": "Welcome to onnx-web!", + "es": "Bienvenido a onnx-web!", + "fr": "Bienvenue sur onnx-web!" + }, "batch": { "default": 1, "min": 1, @@ -111,6 +117,29 @@ "default": "", "keys": [] }, + "latentSymmetry": { + "enabled": { + "default": false + }, + "gradientStart": { + "default": 0.0, + "min": 0, + "max": 0.5, + "step": 0.01 + }, + "gradientEnd": { + "default": 0.25, + "min": 0.0, + "max": 0.5, + "step": 0.01 + }, + "lineOfSymmetry": { + "default": 0.5, + "min": 0, + "max": 1, + "step": 0.01 + } + }, "left": { "default": 0, "min": 0, @@ -157,6 +186,25 @@ "default": "an astronaut eating a hamburger", "keys": [] }, + "promptEditing": { + "default": false, + "addSuffix": { + "default": "" + }, + "filter": { + "default": "none", + "keys": [] + }, + "minLength": { + "default": 80, + "min": 1, + "max": 200, + "step": 1 + }, + "removeTokens": { + "default": "" + } + }, "right": { "default": 0, "min": 0, @@ -201,6 +249,9 @@ "spiral" ] }, + "thumbnail": { + "default": true + }, "top": { "default": 0, "min": 0, @@ -249,4 +300,4 @@ "max": 8192, "step": 8 } -} +} \ No newline at end of file diff --git a/api/requirements/nvidia.txt b/api/requirements/nvidia.txt index 245421707..084748a6b 100644 --- a/api/requirements/nvidia.txt +++ b/api/requirements/nvidia.txt @@ -1,4 +1,4 @@ ---extra-index-url https://download.pytorch.org/whl/cu121 +--extra-index-url https://download.pytorch.org/whl/cu118 torch==2.1.1 torchvision==0.16.1 onnxruntime-gpu==1.16.3 diff --git a/api/scripts/test-diffusers.py b/api/scripts/test-diffusers.py index 5852eb6b7..3fba1ce40 100644 --- a/api/scripts/test-diffusers.py +++ b/api/scripts/test-diffusers.py @@ -1,12 +1,6 @@ from diffusers import OnnxStableDiffusionPipeline from os import path -import cv2 -import numpy as np -import onnxruntime as ort -import torch -import time - cfg = 8 steps = 22 height = 512 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png index 9af7aed24..5bf77eb84 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png +++ 
b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d4106859b033cf57d0f2b08c36cff38beb5a67f6e99e2eddefad998223be4182 -size 498897 +oid sha256:a85dac3ec1ca707f9faac3c5689019e1e08d7edaf0b1cfeb097b91a26f4ea509 +size 497978 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-cloud-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-cloud-0.png index 198de189c..ec330db66 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-cloud-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-cloud-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c8fd64b0e15d0a60ac40f4b55d4d2447b687eb3361827fdb287840763903cfc7 -size 478703 +oid sha256:a57b8b2abbe75db72e4d859db9e1cfba5cd7e516157a1e4542dc5181bd4595da +size 496043 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png index 84d7abfdf..5bf77eb84 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-deis-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ce7ac5d1989ba5997b23d15adbcfd2f4565da47d67e746f2d1f75c780d090f8c -size 498915 +oid sha256:a85dac3ec1ca707f9faac3c5689019e1e08d7edaf0b1cfeb097b91a26f4ea509 +size 497978 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png index 7919d7d70..0eaaf35fa 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-dpm-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:93def1c9b1355ed33d25916df16f037a11ba4ba3ee6bcd487d58818371fc7ad5 -size 526093 +oid sha256:608de6683949bc2dff1fcaa92f46205e2e16db88122a14bda75e6abc60933a54 +size 513274 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png index 004080506..f2c4b16b0 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-heun-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bb9043d076c9084fb3f74fa5f58ec734c781ef403ae9723be261dff27e191f5e -size 494092 +oid sha256:9b0aadbe42afc5482a53b4f587292d7c024a676fce9722b4a5404f566d324321 +size 494617 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-taters-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-taters-0.png index 4e0d79855..33b779c0c 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-taters-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-taters-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b97e71085b7166e6676a9c74590982f8f755c0c7256202a921aa0b2e70f06bd8 -size 451466 +oid sha256:e626b72827d603e1be3164035ad60d471ba6fdaefc165c5f033ceba088235c25 +size 462202 diff --git a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-unipc-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-unipc-0.png index 722a83bea..1aa8578f4 100644 --- a/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-unipc-0.png +++ b/api/scripts/test-refs/txt2img-sd-v1-5-512-muffin-unipc-0.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1cff7dfca10127a9169f0f51ee6d3f46b9bec0bbe30c149aed183a4fab69844a -size 513803 +oid sha256:b1161740efad3946fcdf26b65c4031da3591b29c9d384373c16c99fc6515bb6b +size 502265 diff 
--git a/api/scripts/test-refs/txt2img-sd-v1-5-highres-muffin-0.png b/api/scripts/test-refs/txt2img-sd-v1-5-highres-muffin-0.png new file mode 100644 index 000000000..408d2dc51 --- /dev/null +++ b/api/scripts/test-refs/txt2img-sd-v1-5-highres-muffin-0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cef498ba1ba0ee35f8f898688a5748d3d0de2209a50df237be64ba0b2dd6f49 +size 1543938 diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index 1af512c5c..f0f5da9a8 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -60,71 +60,71 @@ def __init__( TEST_DATA = [ TestCase( "txt2img-sd-v1-5-256-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=256&height=256&thumbnail=false", ), TestCase( "txt2img-sd-v1-5-512-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&thumbnail=false", ), TestCase( "txt2img-sd-v1-5-512-muffin-deis", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=deis&thumbnail=false", mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v1-5-512-muffin-dpm", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-multi", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-multi&thumbnail=false", ), TestCase( "txt2img-sd-v1-5-512-muffin-heun", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=heun&thumbnail=false", mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v1-5-512-muffin-unipc", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=unipc-multi&thumbnail=false", mse_threshold=LOOSE_TEST, ), TestCase( "txt2img-sd-v2-1-512-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&thumbnail=false", ), TestCase( "txt2img-sd-v2-1-768-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v2-1&width=768&height=768&unet_tile=768&thumbnail=false", max_attempts=SLOW_TEST, ), TestCase( "txt2img-openjourney-256-muffin", - "txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney&width=256&height=256", + "txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney&width=256&height=256&thumbnail=false", ), TestCase( "txt2img-openjourney-512-muffin", - "txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney", + "txt2img?prompt=mdjrny-v4+style+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-openjourney&thumbnail=false", ), TestCase( "txt2img-knollingcase-512-muffin", - "txt2img?prompt=knollingcase+display+case+with+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-knollingcase", + "txt2img?prompt=knollingcase+display+case+with+a+giant+muffin&seed=0&scheduler=ddim&model=diffusion-knollingcase&thumbnail=false", ), TestCase( "img2img-sd-v1-5-512-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( 
"img2img-sd-v1-5-256-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&unet_tile=256&thumbnail=false", source="txt2img-sd-v1-5-256-muffin-0", ), TestCase( "inpaint-v1-512-white", - "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", + "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", mask="mask-white", ), TestCase( "inpaint-v1-512-black", - "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting", + "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", ), @@ -132,7 +132,7 @@ def __init__( "outpaint-even-256", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" - "&top=256&bottom=256&left=256&right=256" + "&top=256&bottom=256&left=256&right=256&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -143,7 +143,7 @@ def __init__( "outpaint-vertical-512", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" - "&top=512&bottom=512&left=0&right=0" + "&top=512&bottom=512&left=0&right=0&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -154,7 +154,7 @@ def __init__( "outpaint-horizontal-512", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" - "&top=0&bottom=0&left=512&right=512" + "&top=0&bottom=0&left=512&right=512&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -163,17 +163,17 @@ def __init__( ), TestCase( "upscale-resrgan-x2-1024-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2&upscale=true", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x2-plus&scale=2&outscale=2&upscale=true&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-resrgan-x4-2048-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4&upscale=true", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4&upscale=true&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "blend-512-muffin-black", - "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", + "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&thumbnail=false", mask="mask-black", source=[ "txt2img-sd-v1-5-512-muffin-0", @@ -182,7 +182,7 @@ def __init__( ), TestCase( "blend-512-muffin-white", - "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", + "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&thumbnail=false", mask="mask-white", source=[ "txt2img-sd-v1-5-512-muffin-0", @@ -191,7 +191,7 @@ def __init__( ), TestCase( "blend-512-muffin-blend", - "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim", + "blend?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&thumbnail=false", mask="mask-blend", source=[ "txt2img-sd-v1-5-512-muffin-0", @@ -200,30 +200,30 @@ def __init__( ), TestCase( "txt2img-sd-v1-5-512-muffin-taters", - 
"txt2img?prompt=+a+giant+muffin+made+of+mashed+potatoes&seed=0&scheduler=unipc-multi", + "txt2img?prompt=+a+giant+muffin+made+of+mashed+potatoes&seed=0&scheduler=unipc-multi&thumbnail=false", ), TestCase( "txt2img-sd-v1-5-512-muffin-cloud", - "txt2img?prompt=+a+giant+muffin+made+of+cloud-all&seed=0&scheduler=unipc-multi", + "txt2img?prompt=+a+giant+muffin+made+of+cloud-all&seed=0&scheduler=unipc-multi&thumbnail=false", ), TestCase( "upscale-swinir-x4-2048-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-swinir-real-large-x4&scale=4&outscale=4&upscale=true", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-swinir-real-large-x4&scale=4&outscale=4&upscale=true&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-codeformer-512-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-gfpgan-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", ), TestCase( "upscale-sd-x4-2048-muffin", - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-stable-diffusion-x4&scale=4&outscale=4&upscale=true", + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-stable-diffusion-x4&scale=4&outscale=4&upscale=true&thumbnail=false", source="txt2img-sd-v1-5-512-muffin-0", max_attempts=VERY_SLOW_TEST, ), @@ -231,7 +231,7 @@ def __init__( "outpaint-panorama-even-256", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=fill-mask" - "&top=256&bottom=256&left=256&right=256&pipeline=panorama" + "&top=256&bottom=256&left=256&right=256&pipeline=panorama&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -242,7 +242,7 @@ def __init__( "outpaint-panorama-vertical-512", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=histogram" - "&top=512&bottom=512&left=0&right=0&pipeline=panorama" + "&top=512&bottom=512&left=0&right=0&pipeline=panorama&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -253,7 +253,7 @@ def __init__( "outpaint-panorama-horizontal-512", ( "inpaint?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&model=stable-diffusion-onnx-v1-inpainting&noise=histogram" - "&top=0&bottom=0&left=512&right=512&pipeline=panorama" + "&top=0&bottom=0&left=512&right=512&pipeline=panorama&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", mask="mask-black", @@ -264,7 +264,7 @@ def __init__( "upscale-resrgan-x4-codeformer-2048-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4" - "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=SLOW_TEST, @@ -273,7 +273,7 @@ def __init__( 
"upscale-resrgan-x4-gfpgan-2048-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus&scale=4&outscale=4" - "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=SLOW_TEST, @@ -282,7 +282,7 @@ def __init__( "upscale-swinir-x4-codeformer-2048-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-swinir-real-large-x4&scale=4&outscale=4" - "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=SLOW_TEST, @@ -291,7 +291,7 @@ def __init__( "upscale-swinir-x4-gfpgan-2048-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-swinir-real-large-x4&scale=4&outscale=4" - "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=SLOW_TEST, @@ -300,7 +300,7 @@ def __init__( "upscale-sd-x4-codeformer-2048-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-stable-diffusion-x4&scale=4&outscale=4" - "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "&correction=correction-codeformer&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=VERY_SLOW_TEST, @@ -308,32 +308,32 @@ def __init__( TestCase( "upscale-sd-x4-gfpgan-2048-muffin", ( - "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-stable-diffusion-x4" - "&scale=4&outscale=4&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true" + "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-stable-diffusion-x4&scale=4" + "&outscale=4&correction=correction-gfpgan&faces=true&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-512-muffin-0", max_attempts=VERY_SLOW_TEST, ), TestCase( "txt2img-panorama-1024x768-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true&thumbnail=false", max_attempts=VERY_SLOW_TEST, ), TestCase( "img2img-panorama-1024x768-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true&thumbnail=false", source="txt2img-panorama-1024x768-muffin-0", max_attempts=VERY_SLOW_TEST, ), TestCase( "txt2img-sd-v1-5-tall-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=512&height=768&unet_tile=768&thumbnail=false", ), TestCase( "upscale-resrgan-x4-tall-muffin", ( "upscale?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&upscaling=upscaling-real-esrgan-x4-plus" - 
"&scale=4&outscale=4&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0&upscale=true" + "&scale=4&outscale=4&correction=correction-gfpgan&faces=false&faceOutscale=1&faceStrength=1.0&upscale=true&thumbnail=false" ), source="txt2img-sd-v1-5-tall-muffin-0", max_attempts=SLOW_TEST, @@ -342,7 +342,7 @@ def __init__( "txt2img-sdxl-muffin", ( "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=1024&unet_tile=1024" - "&pipeline=txt2img-sdxl&model=diffusion-sdxl-base" + "&pipeline=txt2img-sdxl&model=diffusion-sdxl-base&thumbnail=false" ), max_attempts=SLOW_TEST, ), @@ -350,7 +350,7 @@ def __init__( "txt2img-sdxl-lcm-muffin", ( "txt2img?prompt=+a+giant+muffin&seed=0&scheduler=lcm&width=1024&height=1024" - "&unet_tile=1024&pipeline=txt2img-sdxl&model=diffusion-sdxl-base&cfg=1.5&steps=10" + "&unet_tile=1024&pipeline=txt2img-sdxl&model=diffusion-sdxl-base&cfg=1.5&steps=10&thumbnail=false" ), max_attempts=SLOW_TEST, mse_threshold=LOOSE_TEST, @@ -359,7 +359,7 @@ def __init__( "txt2img-sdxl-turbo-muffin", ( "txt2img?prompt=a+giant+muffin&seed=0&scheduler=dpm-sde&width=512&height=512&unet_tile=512" - "&pipeline=txt2img-sdxl&model=diffusion-sdxl-turbo&cfg=1&steps=5" + "&pipeline=txt2img-sdxl&model=diffusion-sdxl-turbo&cfg=1&steps=5&thumbnail=false" ), max_attempts=SLOW_TEST, mse_threshold=LOOSE_TEST, @@ -368,11 +368,42 @@ def __init__( "txt2img-sd-v1-5-lcm-muffin", ( "txt2img?prompt=+a+giant+muffin&seed=0&scheduler=lcm&width=512&height=512&unet_tile=512" - "&pipeline=txt2img&cfg=1.5&steps=10" + "&pipeline=txt2img&cfg=1.5&steps=10&thumbnail=false" ), max_attempts=SLOW_TEST, mse_threshold=VERY_LOOSE_TEST, ), + # SD v1.5 highres + TestCase( + "txt2img-sd-v1-5-highres-muffin", + ( + "txt2img?batch=1&cfg=5.00&eta=0.00&steps=25&tiled_vae=true&unet_overlap=0.50&unet_tile=512&vae_overlap=0.50" + "&vae_tile=512&scheduler=ddim&seed=-1&prompt=a+giant+muffin&upscaling=upscaling-real-esrgan-x2-plus" + "&correction=correction-gfpgan-v1-3&control=&width=512&height=512&upscale=false" + "&upscaleOrder=correction-first&highres=true&highresIterations=1&highresMethod=upscale&highresScale=2" + "&highresSteps=50&highresStrength=0.20&thumbnail=false" + ), + max_attempts=VERY_SLOW_TEST, + mse_threshold=LOOSE_TEST, + ), + # SDXL highres + TestCase( + "txt2img-sdxl-highres-muffin", + ( + "txt2img?batch=1&cfg=5.00&eta=0.00&steps=25&tiled_vae=true&unet_overlap=0.50&unet_tile=1024&vae_overlap=0.50" + "&vae_tile=1024&scheduler=ddim&seed=-1&prompt=a+giant+muffin&upscaling=upscaling-real-esrgan-x2-plus" + "&correction=correction-gfpgan-v1-3&control=&width=1024&height=1024&upscale=false&pipeline=txt2img-sdxl" + "&upscaleOrder=correction-first&highres=true&highresIterations=1&highresMethod=upscale&highresScale=2" + "&highresSteps=50&highresStrength=0.20&model=diffusion-sdxl-base&thumbnail=false" + ), + max_attempts=200, + mse_threshold=LOOSE_TEST, + ) + # TODO: highres panorama + # TODO: grid mode + # TODO: grid highres + # TODO: batch size > 1 + # TODO: highres batch size > 1 # TODO: non-square controlnet ] @@ -463,7 +494,7 @@ def generate_images(host: str, test: TestCase) -> Optional[str]: resp = requests.post(f"{host}/api/{test.query}", files=files) if resp.status_code == 200: json = resp.json() - return json.get("outputs") + return json.get("name") else: logger.warning("generate request failed: %s: %s", resp.status_code, resp.text) raise TestError("error generating image") @@ -484,16 +515,28 @@ def check_ready(host: str, key: str) -> bool: logger.warning("ready request failed: %s", 
resp.status_code) raise TestError("error getting image status") +def check_outputs(host: str, key: str) -> List[str]: + resp = requests.get(f"{host}/api/job/status?jobs={key}") + if resp.status_code == 200: + json = resp.json() + outputs = json[0].get("outputs", []) + return outputs + + logger.warning("getting outputs failed: %s: %s", resp.status_code, resp.text) + raise TestError("error getting image outputs") + +def download_images(host: str, key: str) -> List[Image.Image]: + outputs = check_outputs(host, key) -def download_images(host: str, keys: List[str]) -> List[Image.Image]: images = [] - for key in keys: - resp = requests.get(f"{host}/output/{key}") + for key in outputs: + url = f"{host}/output/{key}" + resp = requests.get(url) if resp.status_code == 200: logger.debug("downloading image: %s", key) images.append(Image.open(BytesIO(resp.content))) else: - logger.warning("download request failed: %s", resp.status_code) + logger.warning("download request failed: %s: %s", url, resp.status_code) raise TestError("error downloading image") return images @@ -528,14 +571,14 @@ def run_test( Generate an image, wait for it to be ready, and calculate the MSE from the reference. """ - keys = generate_images(host, test) - if keys is None: + job = generate_images(host, test) + if job is None: return TestResult.failed(test.name, "could not generate image") ready = False for attempt in tqdm(range(test.max_attempts * time_mult)): - if check_ready(host, keys[0]): - logger.debug("image is ready: %s", keys) + if check_ready(host, job): + logger.debug("image is ready: %s", job) ready = True break else: @@ -545,17 +588,22 @@ def run_test( if not ready: return TestResult.failed(test.name, "image was not ready in time") - results = download_images(host, keys) + results = download_images(host, job) if results is None or len(results) == 0: return TestResult.failed(test.name, "could not download image") passed = False for i in range(len(results)): result = results[i] - result.save(test_path(path.join("test-results", f"{test.name}-{i}.png"))) + result_name = f"{test.name}-{i}.png" + result.save(test_path(path.join("test-results", result_name))) + + ref_name = test_path(path.join("test-refs", result_name)) + if not path.exists(ref_name): + return TestResult.failed(test.name, f"no reference image for {result_name}") - ref_name = test_path(path.join("test-refs", f"{test.name}-{i}.png")) - ref = Image.open(ref_name) if path.exists(ref_name) else None + ref = Image.open(ref_name) + logger.warning("comparing image %s to %s", result, ref) mse = find_mse(result, ref) threshold = test.mse_threshold * mse_mult diff --git a/api/tests/chain/test_blend_denoise_fastnlmeans.py b/api/tests/chain/test_blend_denoise_fastnlmeans.py new file mode 100644 index 000000000..14a42be8d --- /dev/null +++ b/api/tests/chain/test_blend_denoise_fastnlmeans.py @@ -0,0 +1,51 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage +from onnx_web.chain.result import ImageMetadata, StageResult +from tests.helpers import test_params, test_size + + +class TestBlendDenoiseFastNLMeansStage(unittest.TestCase): + def test_run(self): + # Create a dummy image + size = test_size() + image = Image.new("RGB", (size.width, size.height), color="white") + + # Create a dummy StageResult object + sources = StageResult.from_images( + [image], + metadata=[ + ImageMetadata( + test_params(), + size, + ) + ], + ) + + # Create an instance of BlendDenoiseLocalStdStage + stage = 
BlendDenoiseFastNLMeansStage() + + # Call the run method with dummy parameters + result = stage.run( + _worker=None, + _server=None, + _stage=None, + _params=None, + sources=sources, + strength=5, + range=4, + stage_source=None, + callback=None, + ) + + # Assert that the result is an instance of StageResult + self.assertIsInstance(result, StageResult) + + # Assert that the result contains the denoised image + self.assertEqual(len(result), 1) + self.assertEqual(result.size(), size) + + # Assert that the metadata is preserved + self.assertEqual(result.metadata, sources.metadata) diff --git a/api/tests/chain/test_blend_denoise_localstd.py b/api/tests/chain/test_blend_denoise_localstd.py new file mode 100644 index 000000000..cb908bed7 --- /dev/null +++ b/api/tests/chain/test_blend_denoise_localstd.py @@ -0,0 +1,50 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size + + +class TestBlendDenoiseLocalStdStage(unittest.TestCase): + def test_run(self): + # Create a dummy image + image = Image.new("RGB", (64, 64), color="white") + + # Create a dummy StageResult object + sources = StageResult.from_images( + [image], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0), + Size(64, 64), + ) + ], + ) + + # Create an instance of BlendDenoiseLocalStdStage + stage = BlendDenoiseLocalStdStage() + + # Call the run method with dummy parameters + result = stage.run( + _worker=None, + _server=None, + _stage=None, + _params=None, + sources=sources, + strength=5, + range=4, + stage_source=None, + callback=None, + ) + + # Assert that the result is an instance of StageResult + self.assertIsInstance(result, StageResult) + + # Assert that the result contains the denoised image + self.assertEqual(len(result), 1) + self.assertEqual(result.size(), Size(64, 64)) + + # Assert that the metadata is preserved + self.assertEqual(result.metadata, sources.metadata) diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py index 0e6188b1c..146208294 100644 --- a/api/tests/chain/test_blend_grid.py +++ b/api/tests/chain/test_blend_grid.py @@ -3,7 +3,8 @@ from PIL import Image from onnx_web.chain.blend_grid import BlendGridStage -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size class BlendGridStageTests(unittest.TestCase): @@ -15,9 +16,17 @@ def test_stage(self): Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "white"), + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), ] + * 4, ) result = stage.run(None, None, None, None, sources, height=2, width=2) + result.validate() self.assertEqual(len(result), 5) - self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0)) + self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py index 9d6f71d92..d7ab22115 100644 --- a/api/tests/chain/test_blend_img2img.py +++ b/api/tests/chain/test_blend_img2img.py @@ -3,8 +3,8 @@ from PIL import Image from onnx_web.chain.blend_img2img import BlendImg2ImgStage -from onnx_web.chain.result import StageResult -from onnx_web.params import ImageParams +from onnx_web.chain.result import 
ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size from onnx_web.server.context import ServerContext from onnx_web.worker.context import WorkerContext from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models @@ -39,9 +39,16 @@ def test_stage(self): sources = StageResult( images=[ Image.new("RGB", (64, 64), "black"), - ] + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), + ], ) result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) + result.validate() self.assertEqual(len(result), 1) - self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0)) + self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py index 76a2715a3..aab545fab 100644 --- a/api/tests/chain/test_blend_linear.py +++ b/api/tests/chain/test_blend_linear.py @@ -3,7 +3,8 @@ from PIL import Image from onnx_web.chain.blend_linear import BlendLinearStage -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size class BlendLinearStageTests(unittest.TestCase): @@ -12,12 +13,19 @@ def test_stage(self): sources = StageResult( images=[ Image.new("RGB", (64, 64), "black"), - ] + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), + ], ) stage_source = Image.new("RGB", (64, 64), "white") result = stage.run( None, None, None, None, sources, alpha=0.5, stage_source=stage_source ) + result.validate() self.assertEqual(len(result), 1) - self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127)) + self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127)) diff --git a/api/tests/chain/test_blend_mask.py b/api/tests/chain/test_blend_mask.py index f168fab95..6611e3a65 100644 --- a/api/tests/chain/test_blend_mask.py +++ b/api/tests/chain/test_blend_mask.py @@ -23,5 +23,6 @@ def test_empty(self): stage_source=Image.new("RGBA", (64, 64)), dims=(0, 0, SizeChart.auto), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py index 1498beebf..e1ac41dea 100644 --- a/api/tests/chain/test_correct_codeformer.py +++ b/api/tests/chain/test_correct_codeformer.py @@ -42,5 +42,6 @@ def test_empty(self): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(""), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_correct_gfpgan.py b/api/tests/chain/test_correct_gfpgan.py index 9f8b6cb3e..a90449fb3 100644 --- a/api/tests/chain/test_correct_gfpgan.py +++ b/api/tests/chain/test_correct_gfpgan.py @@ -6,13 +6,14 @@ from onnx_web.server.context import ServerContext from onnx_web.server.hacks import apply_patches from onnx_web.worker.context import WorkerContext -from tests.helpers import test_device, test_needs_onnx_models +from tests.helpers import test_device, test_needs_models -TEST_MODEL = "../models/correction-gfpgan-v1-3" +TEST_MODEL_NAME = "correction-gfpgan-v1-3" +TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth" class CorrectGFPGANStageTests(unittest.TestCase): - @test_needs_onnx_models([TEST_MODEL]) + @test_needs_models([TEST_MODEL]) def test_empty(self): server = ServerContext(model_path="../models", output_path="../outputs") apply_patches(server) @@ -33,12 +34,13 @@ def 
test_empty(self): sources = StageResult.empty() result = stage.run( worker, - None, + server, None, None, sources, highres=HighresParams(False, 1, 0, 0), - upscale=UpscaleParams(TEST_MODEL), + upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_edit_metadata.py b/api/tests/chain/test_edit_metadata.py new file mode 100644 index 000000000..fda0bf01e --- /dev/null +++ b/api/tests/chain/test_edit_metadata.py @@ -0,0 +1,41 @@ +import unittest +from unittest.mock import MagicMock + +from onnx_web.chain.edit_metadata import EditMetadataStage + + +class TestEditMetadataStage(unittest.TestCase): + def setUp(self): + self.stage = EditMetadataStage() + + def test_run_with_no_changes(self): + source = MagicMock() + source.metadata = [] + + result = self.stage.run(None, None, None, None, source) + + self.assertEqual(result, source) + + def test_run_with_note_change(self): + source = MagicMock() + source.metadata = [MagicMock()] + note = "New note" + + result = self.stage.run(None, None, None, None, source, note=note) + + self.assertEqual(result, source) + self.assertEqual(result.metadata[0].note, note) + + def test_run_with_replace_params_change(self): + source = MagicMock() + source.metadata = [MagicMock()] + replace_params = MagicMock() + + result = self.stage.run( + None, None, None, None, source, replace_params=replace_params + ) + + self.assertEqual(result, source) + self.assertEqual(result.metadata[0].params, replace_params) + + # Add more test cases for other parameters... diff --git a/api/tests/chain/test_edit_text.py b/api/tests/chain/test_edit_text.py new file mode 100644 index 000000000..4d0b80cef --- /dev/null +++ b/api/tests/chain/test_edit_text.py @@ -0,0 +1,48 @@ +import unittest + +import numpy as np +from PIL import Image + +from onnx_web.chain.edit_text import EditTextStage +from onnx_web.chain.result import StageResult + + +class TestEditTextStage(unittest.TestCase): + def test_run(self): + # Create a sample image + image = Image.new("RGB", (100, 100), color="black") + + # Create an instance of EditTextStage + stage = EditTextStage() + + # Define the input parameters + text = "Hello, World!" 
+ position = (10, 10) + fill = "white" + stroke = "white" + stroke_width = 2 + + # Create a mock source StageResult + source = StageResult.from_images([image], metadata={}) + + # Call the run method + result = stage.run( + None, + None, + None, + None, + source, + text=text, + position=position, + fill=fill, + stroke=stroke, + stroke_width=stroke_width, + ) + + # Assert the output + self.assertEqual(len(result.as_images()), 1) + # self.assertEqual(result.metadata, {}) + + # Verify the modified image + modified_image = result.as_images()[0] + self.assertEqual(np.max(np.array(modified_image)), 255) diff --git a/api/tests/chain/test_reduce_crop.py b/api/tests/chain/test_reduce_crop.py index bfc7adc4d..81629df2f 100644 --- a/api/tests/chain/test_reduce_crop.py +++ b/api/tests/chain/test_reduce_crop.py @@ -20,5 +20,6 @@ def test_empty(self): origin=Size(0, 0), size=Size(128, 128), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_thumbnail.py b/api/tests/chain/test_reduce_thumbnail.py index 8b1296722..a162bf5e4 100644 --- a/api/tests/chain/test_reduce_thumbnail.py +++ b/api/tests/chain/test_reduce_thumbnail.py @@ -24,5 +24,6 @@ def test_empty(self): size=Size(128, 128), stage_source=stage_source, ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_result.py b/api/tests/chain/test_result.py new file mode 100644 index 000000000..02f998904 --- /dev/null +++ b/api/tests/chain/test_result.py @@ -0,0 +1,35 @@ +import unittest + +from onnx_web.chain.result import ImageMetadata + + +class ImageMetadataTests(unittest.TestCase): + def test_from_exif_normal(self): + exif_data = """test prompt +Negative prompt: negative prompt +Sampler: ddim, CFG scale: 4.0, Steps: 30, Seed: 5 +""" + + metadata = ImageMetadata.from_exif(exif_data) + self.assertEqual(metadata.params.prompt, "test prompt") + self.assertEqual(metadata.params.negative_prompt, "negative prompt") + self.assertEqual(metadata.params.scheduler, "ddim") + self.assertEqual(metadata.params.cfg, 4.0) + self.assertEqual(metadata.params.steps, 30) + self.assertEqual(metadata.params.seed, 5) + + def test_from_exif_split(self): + exif_data = """test prompt +Negative prompt: negative prompt +Sampler: ddim, +CFG scale: 4.0, +Steps: 30, Seed: 5 +""" + + metadata = ImageMetadata.from_exif(exif_data) + self.assertEqual(metadata.params.prompt, "test prompt") + self.assertEqual(metadata.params.negative_prompt, "negative prompt") + self.assertEqual(metadata.params.scheduler, "ddim") + self.assertEqual(metadata.params.cfg, 4.0) + self.assertEqual(metadata.params.steps, 30) + self.assertEqual(metadata.params.seed, 5) diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py index 37c99bfac..40b0c437b 100644 --- a/api/tests/chain/test_source_noise.py +++ b/api/tests/chain/test_source_noise.py @@ -22,5 +22,6 @@ def test_empty(self): size=Size(128, 128), noise_source=noise_source_fill_edge, ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py index 59bbb72f5..8587859fa 100644 --- a/api/tests/chain/test_source_s3.py +++ b/api/tests/chain/test_source_s3.py @@ -22,5 +22,6 @@ def test_empty(self): bucket="test", source_keys=[], ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py index 4d03dedb0..59d3f9904 100644 --- a/api/tests/chain/test_source_url.py +++ 
b/api/tests/chain/test_source_url.py @@ -21,5 +21,6 @@ def test_empty(self): size=Size(128, 128), source_urls=[], ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index c27cb0776..93f8e23d5 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -2,7 +2,7 @@ from PIL import Image -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult from onnx_web.chain.tile import ( complete_tile, generate_tile_grid, @@ -125,16 +125,28 @@ def test_spiral_50_overlap(self): class TestProcessTileStack(unittest.TestCase): def test_grid_full(self): source = Image.new("RGB", (64, 64)) - blend = process_tile_stack( - StageResult(images=[source]), 32, 1, [], generate_tile_grid + result = process_tile_stack( + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + 32, + 1, + [], + generate_tile_grid, ) + images = result.as_images() - self.assertEqual(blend[0].size, (64, 64)) + self.assertEqual(len(images), 1) + self.assertEqual(images[0].size, (64, 64)) def test_grid_partial(self): source = Image.new("RGB", (72, 72)) - blend = process_tile_stack( - StageResult(images=[source]), 32, 1, [], generate_tile_grid + result = process_tile_stack( + StageResult(images=[source], metadata=[ImageMetadata.unknown_image()]), + 32, + 1, + [], + generate_tile_grid, ) + images = result.as_images() - self.assertEqual(blend[0].size, (72, 72)) + self.assertEqual(len(images), 1) + self.assertEqual(images[0].size, (72, 72)) diff --git a/api/tests/chain/test_upscale_bsrgan.py b/api/tests/chain/test_upscale_bsrgan.py index f93b800c7..cca11f115 100644 --- a/api/tests/chain/test_upscale_bsrgan.py +++ b/api/tests/chain/test_upscale_bsrgan.py @@ -37,5 +37,6 @@ def test_empty(self): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(TEST_MODEL), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py index 096eea544..788d36128 100644 --- a/api/tests/chain/test_upscale_highres.py +++ b/api/tests/chain/test_upscale_highres.py @@ -18,5 +18,6 @@ def test_empty(self): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(""), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_outpaint.py b/api/tests/chain/test_upscale_outpaint.py index 261a8d457..19a67dd7e 100644 --- a/api/tests/chain/test_upscale_outpaint.py +++ b/api/tests/chain/test_upscale_outpaint.py @@ -46,5 +46,6 @@ def test_empty(self): dims=(), tile_mask=Image.new("RGB", (64, 64)), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_resrgan.py b/api/tests/chain/test_upscale_resrgan.py index f832767f1..f464947e4 100644 --- a/api/tests/chain/test_upscale_resrgan.py +++ b/api/tests/chain/test_upscale_resrgan.py @@ -35,5 +35,6 @@ def test_empty(self): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_swinir.py b/api/tests/chain/test_upscale_swinir.py index dfa9676e8..ce23695db 100644 --- a/api/tests/chain/test_upscale_swinir.py +++ b/api/tests/chain/test_upscale_swinir.py @@ -35,5 +35,6 @@ def test_empty(self): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(TEST_MODEL), ) + result.validate() self.assertEqual(len(result), 0) diff --git 
a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index a39b0b8b0..61c345a6e 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -1,11 +1,14 @@ import unittest +from unittest.mock import MagicMock, patch import numpy as np import torch from onnx import GraphProto, ModelProto, NodeProto from onnx.numpy_helper import from_array +from onnx_web.constants import ONNX_MODEL from onnx_web.convert.diffusion.lora import ( + blend_loras, blend_node_conv_gemm, blend_node_matmul, blend_weights_loha, @@ -226,6 +229,31 @@ def test_xl_keys(self): def test_node_dtype(self): pass + @patch("onnx_web.convert.diffusion.lora.load") + @patch("onnx_web.convert.diffusion.lora.load_tensor") + def test_blend_loras_load_str(self, mock_load_tensor, mock_load): + loras = [("loras/model1.safetensors", 0.5), ("loras/safetensors.onnx", 0.5)] + model_type = "unet" + model_index = 2 + xl = True + + mock_load.return_value = MagicMock() + mock_load_tensor.return_value = MagicMock() + + # Call the blend_loras function + blended_model = blend_loras( + None, ONNX_MODEL, loras, model_type, model_index, xl + ) + + # Assert that the InferenceSession is called with the correct arguments + mock_load.assert_called_once_with(ONNX_MODEL) + + # Assert that the model is loaded successfully + self.assertEqual(blended_model, mock_load.return_value) + + # Assert that the blending logic is executed correctly + # (assertions specific to the blending logic can be added here) + class BlendWeightsLoHATests(unittest.TestCase): def test_blend_t1_t2(self): diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 34b0bf9b3..1a314f25d 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -1,9 +1,18 @@ import unittest +from os import path +from unittest import mock +from unittest.mock import MagicMock, patch +from onnx_web.constants import ONNX_MODEL from onnx_web.convert.utils import ( DEFAULT_OPSET, ConversionContext, + build_cache_paths, download_progress, + fix_diffusion_name, + get_first_exists, + load_tensor, + load_torch, remove_prefix, resolve_tensor, source_format, @@ -30,6 +39,32 @@ def test_download_example(self): path = download_progress("https://example.com", "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com") + @patch("onnx_web.convert.utils.Path") + @patch("onnx_web.convert.utils.requests") + @patch("onnx_web.convert.utils.shutil") + @patch("onnx_web.convert.utils.tqdm") + def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path): + source = "http://example.com/image.jpg" + dest = "/path/to/destination/image.jpg" + + dest_path_mock = MagicMock() + mock_path.return_value.expanduser.return_value.resolve.return_value = ( + dest_path_mock + ) + dest_path_mock.exists.return_value = False + dest_path_mock.absolute.return_value = "test" + mock_requests.get.return_value.status_code = 200 + mock_requests.get.return_value.headers.get.return_value = "1000" + mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock() + + result = download_progress(source, dest) + + mock_path.assert_called_once_with(dest) + dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True) + dest_path_mock.open.assert_called_once_with("wb") + mock_shutil.copyfileobj.assert_called_once() + self.assertEqual(result, str(dest_path_mock.absolute.return_value)) + class TupleToSourceTests(unittest.TestCase): def test_basic_tuple(self): @@ -221,14 +256,6 @@ def 
test_without_prefix(self): self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") -class LoadTorchTests(unittest.TestCase): - pass - - -class LoadTensorTests(unittest.TestCase): - pass - - class ResolveTensorTests(unittest.TestCase): @test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) def test_resolve_existing(self): @@ -239,3 +266,257 @@ def test_resolve_existing(self): def test_resolve_missing(self): self.assertIsNone(resolve_tensor("missing")) + + +TORCH_MODEL = "model.pth" + + +class LoadTorchTests(unittest.TestCase): + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_torch_with_torch_load(self, mock_torch, mock_logger): + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.return_value = checkpoint + + result = load_torch(TORCH_MODEL, map_location) + + mock_logger.debug.assert_called_once_with( + "loading tensor with Torch: %s", TORCH_MODEL + ) + mock_torch.load.assert_called_once_with(TORCH_MODEL, map_location=map_location) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger): + checkpoint = MagicMock() + mock_torch.load.side_effect = Exception() + mock_torch.jit.load.return_value = checkpoint + + result = load_torch(TORCH_MODEL) + + mock_logger.debug.assert_called_once_with( + "loading tensor with Torch: %s", TORCH_MODEL + ) + mock_logger.exception.assert_called_once_with( + "error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL + ) + mock_torch.jit.load.assert_called_once_with(TORCH_MODEL) + self.assertEqual(result, checkpoint) + + +LOAD_TENSOR_LOG = "loading tensor: %s" + + +class LoadTensorTests(unittest.TestCase): + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger): + name = "model" + map_location = "cpu" + checkpoint = MagicMock() + mock_path.exists.return_value = True + mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")] + mock_torch.load.return_value = checkpoint + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) + mock_path.splitext.assert_called_once_with(name) + mock_path.exists.assert_called_once_with(name) + mock_torch.load.assert_called_once_with(name, map_location=map_location) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.environ") + @patch("onnx_web.convert.utils.safetensors") + def test_load_tensor_with_safetensors_extension( + self, mock_safetensors, mock_environ, mock_logger + ): + name = "model.safetensors" + checkpoint = MagicMock() + mock_environ.__getitem__.return_value = "1" + mock_safetensors.torch.load_file.return_value = checkpoint + + result = load_tensor(name) + + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) + mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu") + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger): + name = "model.pt" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(name, map_location) + + 
mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) + mock_torch.load.assert_has_calls( + [ + mock.call(name, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger): + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(ONNX_MODEL, map_location) + + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, ONNX_MODEL)]) + mock_logger.warning.assert_called_once_with( + "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx" + ) + mock_torch.load.assert_has_calls( + [ + mock.call(ONNX_MODEL, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger): + name = "model.xyz" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, name)]) + mock_logger.warning.assert_called_once_with( + "unknown tensor type, falling back to PyTorch: %s", "xyz" + ) + mock_torch.load.assert_has_calls( + [ + mock.call(name, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger): + name = "model" + map_location = "cpu" + mock_torch.load.side_effect = Exception() + + with self.assertRaises(ValueError): + load_tensor(name, map_location) + + +class FixDiffusionNameTests(unittest.TestCase): + def test_fix_diffusion_name_with_valid_name(self): + name = "diffusion-model" + result = fix_diffusion_name(name) + self.assertEqual(result, name) + + @patch("onnx_web.convert.utils.logger") + def test_fix_diffusion_name_with_invalid_name(self, logger): + name = "model" + expected_result = "diffusion-model" + result = fix_diffusion_name(name) + + self.assertEqual(result, expected_result) + logger.warning.assert_called_once_with( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + + +CACHE_PATH = "/path/to/cache" + + +class BuildCachePathsTests(unittest.TestCase): + def test_build_cache_paths_without_format(self): + client = "client1" + + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, ONNX_MODEL, client, CACHE_PATH) + + expected_paths = [ + path.join(CACHE_PATH, ONNX_MODEL), + path.join("/path/to/cache/client1", ONNX_MODEL), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_format(self): + name = "model" + client = "client2" + model_format = "onnx" + + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) + + expected_paths = [ + path.join(CACHE_PATH, ONNX_MODEL), + path.join("/path/to/cache/client2", ONNX_MODEL), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_existing_extension(self): + client = "client3" + model_format = "onnx" + + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths( + conversion, TORCH_MODEL, client, CACHE_PATH, model_format + ) 
+ + expected_paths = [ + path.join(CACHE_PATH, TORCH_MODEL), + path.join("/path/to/cache/client3", TORCH_MODEL), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_empty_extension(self): + name = "model" + client = "client4" + model_format = "onnx" + + conversion = ConversionContext(cache_path=CACHE_PATH) + result = build_cache_paths(conversion, name, client, CACHE_PATH, model_format) + + expected_paths = [ + path.join(CACHE_PATH, ONNX_MODEL), + path.join("/path/to/cache/client4", ONNX_MODEL), + ] + self.assertEqual(result, expected_paths) + + +class GetFirstExistsTests(unittest.TestCase): + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.logger") + def test_get_first_exists_with_existing_path(self, mock_logger, mock_path): + paths = ["path1", "path2", "path3"] + mock_path.exists.side_effect = [False, True, False] + mock_path.return_value = MagicMock() + + result = get_first_exists(paths) + + mock_logger.debug.assert_called_once_with( + "model already exists in cache, skipping fetch: %s", "path2" + ) + self.assertEqual(result, "path2") + + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.logger") + def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path): + paths = ["path1", "path2", "path3"] + mock_path.exists.return_value = False + + result = get_first_exists(paths) + + mock_logger.debug.assert_not_called() + self.assertIsNone(result) diff --git a/api/tests/helpers.py b/api/tests/helpers.py index f0c10edd2..4803e1093 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -3,7 +3,7 @@ from typing import List from unittest import skipUnless -from onnx_web.params import DeviceParams +from onnx_web.params import DeviceParams, ImageParams, Size from onnx_web.worker.context import WorkerContext @@ -23,6 +23,14 @@ def test_device() -> DeviceParams: return DeviceParams("cpu", "CPUExecutionProvider") +def test_size() -> Size: + return Size(64, 64) + + +def test_params() -> ImageParams: + return ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0) + + def test_worker() -> WorkerContext: cancel = Value("L", 0) logs = Queue() diff --git a/api/tests/image/test_ade_palette.py b/api/tests/image/test_ade_palette.py new file mode 100644 index 000000000..7749b689b --- /dev/null +++ b/api/tests/image/test_ade_palette.py @@ -0,0 +1,9 @@ +import unittest + +from onnx_web.image.ade_palette import ade_palette + + +class TestADEPalette(unittest.TestCase): + def test_palette_length(self): + palette = ade_palette() + self.assertEqual(len(palette), 150, "Palette length should be 150") diff --git a/api/tests/image/test_laion_face.py b/api/tests/image/test_laion_face.py new file mode 100644 index 000000000..48c0a7d2d --- /dev/null +++ b/api/tests/image/test_laion_face.py @@ -0,0 +1,69 @@ +import unittest + +import numpy as np + +from onnx_web.image.laion_face import draw_pupils, generate_annotation, reverse_channels + + +class TestLaionFace(unittest.TestCase): + @unittest.skip("need to prepare a good input image") + def test_draw_pupils(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + + # Create a dummy landmark list + class LandmarkList: + def __init__(self, landmarks): + self.landmark = landmarks + + # Create a dummy drawing spec + class DrawingSpec: + def __init__(self, color): + self.color = color + + # Create some dummy landmarks + landmarks = [ + # Add your landmarks here + ] + + # Create a dummy drawing spec + drawing_spec = DrawingSpec(color=(255, 0, 0)) # Red 
color + + # Call the draw_pupils function + draw_pupils(image, LandmarkList(landmarks), drawing_spec) + + self.assertNotEqual(np.sum(image), 0, "Image should be modified") + + @unittest.skip("need to prepare a good input image") + def test_generate_annotation(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + + # Call the generate_annotation function + result = generate_annotation(image, max_faces=1, min_confidence=0.5) + + self.assertEqual( + result.shape, + image.shape, + "Result shape should be the same as the input image", + ) + self.assertNotEqual(np.sum(result), 0, "Result should not be all zeros") + + +class TestReverseChannels(unittest.TestCase): + def test_reverse_channels(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + layer = np.ones((100, 100), dtype=np.uint8) + image[:, :, 0] = layer + + # Call the reverse_channels function + reversed_image = reverse_channels(image) + + self.assertEqual( + image.shape, reversed_image.shape, "Image shape should be the same" + ) + self.assertTrue( + np.array_equal(reversed_image[:, :, 2], layer), + "Channels should be reversed", + ) diff --git a/api/tests/image/test_source_filter.py b/api/tests/image/test_source_filter.py index fb44073ef..85dc2c5c4 100644 --- a/api/tests/image/test_source_filter.py +++ b/api/tests/image/test_source_filter.py @@ -1,11 +1,24 @@ import unittest +from os import path +import numpy as np from PIL import Image from onnx_web.image.source_filter import ( + filter_model_path, + pil_to_cv2, + source_filter_canny, + source_filter_depth, + source_filter_face, source_filter_gaussian, + source_filter_hed, + source_filter_mlsd, source_filter_noise, source_filter_none, + source_filter_normal, + source_filter_openpose, + source_filter_scribble, + source_filter_segment, ) from onnx_web.server.context import ServerContext @@ -35,3 +48,119 @@ def test_basic(self): source = Image.new("RGB", dims) result = source_filter_noise(server, source) self.assertEqual(result.size, dims) + + +class PILToCV2Tests(unittest.TestCase): + def test_conversion(self): + dims = (64, 64) + source = Image.new("RGB", dims) + result = pil_to_cv2(source) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.shape, (dims[1], dims[0], 3)) + self.assertEqual(result.dtype, np.uint8) + + +class FilterModelPathTests(unittest.TestCase): + def test_filter_model_path(self): + server = ServerContext() + filter_name = "gaussian" + expected_path = path.join(server.model_path, "filter", filter_name) + result = filter_model_path(server, filter_name) + self.assertEqual(result, expected_path) + + +class SourceFilterFaceTests(unittest.TestCase): # Added new test class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_face(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterSegmentTests( + unittest.TestCase +): # Added SourceFilterSegmentTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_segment(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterMLSDTests(unittest.TestCase): # Added SourceFilterMLSDTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_mlsd(server, source) + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterNormalTests(unittest.TestCase): # Added 
SourceFilterNormalTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_normal(server, source) + + # normal will resize inputs to 384x384 + self.assertEqual(result.size, (384, 384)) + + +class SourceFilterHEDTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_hed(server, source) + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterScribbleTests( + unittest.TestCase +): # Added SourceFilterScribbleTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_scribble(server, source) + + # scribble will resize inputs to 512x512 + self.assertEqual(result.size, (512, 512)) + + +class SourceFilterDepthTests( + unittest.TestCase +): # Added SourceFilterDepthTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_depth(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterCannyTests( + unittest.TestCase +): # Added SourceFilterCannyTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_canny(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterOpenPoseTests( + unittest.TestCase +): # Added SourceFilterOpenPoseTests class + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_openpose(server, source) + + # openpose will resize inputs to 512x512 + self.assertEqual(result.size, (512, 512)) diff --git a/api/tests/mocks.py b/api/tests/mocks.py index ef95d7544..fad18c09b 100644 --- a/api/tests/mocks.py +++ b/api/tests/mocks.py @@ -11,6 +11,7 @@ class MockPipeline: # stubs _encode_prompt: Optional[Any] + scheduler: Optional[Any] unet: Optional[Any] vae_decoder: Optional[Any] vae_encoder: Optional[Any] @@ -23,6 +24,7 @@ def __init__(self) -> None: self.xformers = None self._encode_prompt = None + self.scheduler = None self.unet = None self.vae_decoder = None self.vae_encoder = None diff --git a/api/tests/prompt/test_compel.py b/api/tests/prompt/test_compel.py new file mode 100644 index 000000000..c29a67e5c --- /dev/null +++ b/api/tests/prompt/test_compel.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import MagicMock + +import numpy as np + +from onnx_web.prompt.compel import ( + encode_prompt_compel, + encode_prompt_compel_sdxl, + get_inference_session, + wrap_encoder, +) + + +class TestCompelHelpers(unittest.TestCase): + def test_get_inference_session_missing(self): + self.assertRaises(ValueError, get_inference_session, None) + + def test_get_inference_session_onnx_session(self): + model = MagicMock() + model.model = None + model.session = "session" + self.assertEqual(get_inference_session(model), "session") + + def test_get_inference_session_onnx_model(self): + model = MagicMock() + model.model = "model" + model.session = None + self.assertEqual(get_inference_session(model), "model") + + def test_wrap_encoder(self): + text_encoder = MagicMock() + wrapped = wrap_encoder(text_encoder) + self.assertEqual(wrapped.device, "cpu") + self.assertEqual(wrapped.text_encoder, text_encoder) + + +class TestCompelEncodePrompt(unittest.TestCase): + def test_encode_basic(self): + pipeline = MagicMock() + pipeline.text_encoder = MagicMock() +
pipeline.text_encoder.return_value = [ + np.array([[1], [2]]), + np.array([[3], [4]]), + ] + pipeline.tokenizer = MagicMock() + pipeline.tokenizer.model_max_length = 1 + + embeds = encode_prompt_compel(pipeline, "prompt", 1, True) + np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]]) + + +class TestCompelEncodePromptSDXL(unittest.TestCase): + @unittest.skip("need to fix the tensor shapes") + def test_encode_basic(self): + text_encoder_output = MagicMock() + text_encoder_output.hidden_states = [[0], [1], [2], [3]] + + def call_text_encoder(*args, return_dict=False, **kwargs): + print("call_text_encoder", return_dict) + if return_dict: + return text_encoder_output + + return [np.array([[1]]), np.array([[3]]), np.array([[5]]), np.array([[7]])] + + pipeline = MagicMock() + pipeline.text_encoder.side_effect = call_text_encoder + pipeline.text_encoder_2.side_effect = call_text_encoder + pipeline.tokenizer.model_max_length = 1 + pipeline.tokenizer_2.model_max_length = 1 + + embeds = encode_prompt_compel_sdxl(pipeline, "prompt", 1, True) + np.testing.assert_equal(embeds, [[[3, 3]], [[3, 3]]]) + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py index 15c91d6ca..ca167757e 100644 --- a/api/tests/prompt/test_parser.py +++ b/api/tests/prompt/test_parser.py @@ -1,39 +1,128 @@ import unittest -from onnx_web.prompt.grammar import PromptPhrase -from onnx_web.prompt.parser import parse_prompt_onnx +from onnx_web.prompt.base import PromptNetwork, PromptPhrase +from onnx_web.prompt.grammar import PhraseNode, TokenNode +from onnx_web.prompt.parser import compile_prompt_onnx, parse_prompt_onnx class ParserTests(unittest.TestCase): def test_single_word_phrase(self): res = parse_prompt_onnx(None, "foo (bar) bin", debug=False) self.assertListEqual( - [str(i) for i in res], + res, [ - str(["foo"]), - str(PromptPhrase(["bar"], weight=1.5)), - str(["bin"]), + PhraseNode(["foo"]), + PhraseNode(["bar"], weight=1.5), + PhraseNode(["bin"]), ], ) def test_multi_word_phrase(self): res = parse_prompt_onnx(None, "foo bar (middle words) bin bun", debug=False) self.assertListEqual( - [str(i) for i in res], + res, [ - str(["foo", "bar"]), - str(PromptPhrase(["middle", "words"], weight=1.5)), - str(["bin", "bun"]), + PhraseNode(["foo", "bar"]), + PhraseNode(["middle", "words"], weight=1.5), + PhraseNode(["bin", "bun"]), ], ) def test_nested_phrase(self): res = parse_prompt_onnx(None, "foo (((bar))) bin", debug=False) self.assertListEqual( - [str(i) for i in res], + res, [ - str(["foo"]), - str(PromptPhrase(["bar"], weight=(1.5**3))), - str(["bin"]), + PhraseNode(["foo"]), + PhraseNode(["bar"], weight=(1.5**3)), + PhraseNode(["bin"]), + ], + ) + + def test_clip_skip_token(self): + res = parse_prompt_onnx(None, "foo bin", debug=False) + self.assertListEqual( + res, + [ + PhraseNode(["foo"]), + TokenNode("clip", "skip", 2), + PhraseNode(["bin"]), + ], + ) + + def test_lora_token(self): + res = parse_prompt_onnx(None, "foo bin", debug=False) + self.assertListEqual( + res, + [ + PhraseNode(["foo"]), + TokenNode("lora", "name", 1.5), + PhraseNode(["bin"]), + ], + ) + + def test_region_token(self): + res = parse_prompt_onnx( + None, "foo bin", debug=False + ) + self.assertListEqual( + res, + [ + PhraseNode(["foo"]), + TokenNode("region", None, [1, 2, 3, 4, 0.5, 0.75, ["prompt"]]), + PhraseNode(["bin"]), + ], + ) + + def test_reseed_token(self): + res = parse_prompt_onnx(None, "foo bin", debug=False) + self.assertListEqual( + res, + [ + PhraseNode(["foo"]), + 
TokenNode("reseed", None, [1, 2, 3, 4, 12345]), + PhraseNode(["bin"]), + ], + ) + + def test_compile_tokens(self): + prompt = compile_prompt_onnx("foo bar (baz) ") + + self.assertEqual(prompt.clip_skip, 2) + self.assertEqual(prompt.networks, [PromptNetwork("lora", "qux", 1.5)]) + self.assertEqual( + prompt.positive_phrases, + [ + PromptPhrase(["foo"]), + PromptPhrase(["bar"]), + PromptPhrase(["baz"], weight=1.5), + ], + ) + + def test_compile_weights(self): + prompt = compile_prompt_onnx("foo ((bar)) baz [[qux]] bun ([nest] me)") + + self.assertEqual( + prompt.positive_phrases, + [ + PromptPhrase(["foo"]), + PromptPhrase(["bar"], weight=2.25), + PromptPhrase(["baz"]), + PromptPhrase(["qux"], weight=0.25), + PromptPhrase(["bun"]), + PromptPhrase(["nest"], weight=0.75), + PromptPhrase(["me"], weight=1.5), + ], + ) + + def test_compile_runs(self): + prompt = compile_prompt_onnx("foo bar (baz) ") + prompt.collapse_runs() + + self.assertEqual( + prompt.positive_phrases, + [ + PromptPhrase(["foo bar"]), + PromptPhrase(["baz"], weight=1.5), ], ) diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 322712e4b..11c04426f 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -17,11 +17,13 @@ Border, HighresParams, ImageParams, + RequestParams, Size, TileOrder, UpscaleParams, ) from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobCommand from onnx_web.worker.context import WorkerContext from tests.helpers import ( TEST_MODEL_DIFFUSION_SD15, @@ -45,9 +47,10 @@ def test_basic(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -57,11 +60,14 @@ def test_basic(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand( + "test-txt2img-basic", "test", "test", run_txt2img_pipeline, [], {} + ) + ) - run_txt2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( TEST_MODEL_DIFFUSION_SD15, "txt2img", @@ -71,16 +77,22 @@ def test_basic(self): 1, 1, ), - Size(256, 256), - ["test-txt2img-basic.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + size=Size(256, 256), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), + ) + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, ) - self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) - output = Image.open("../outputs/test-txt2img-basic.png") - self.assertEqual(output.size, (256, 256)) - # TODO: test contents of image + self.assertTrue(path.exists("../outputs/test-txt2img-basic_0.png")) + + with Image.open("../outputs/test-txt2img-basic_0.png") as output: + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_batch(self): @@ -91,9 +103,10 @@ def test_batch(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -103,11 +116,14 @@ def test_batch(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand( + "test-txt2img-batch", "test", "test", run_txt2img_pipeline, [], {} + ) + ) - run_txt2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( TEST_MODEL_DIFFUSION_SD15, 
"txt2img", @@ -118,18 +134,23 @@ def test_batch(self): 1, batch=2, ), - Size(256, 256), - ["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + size=Size(256, 256), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), ) - self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) - self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-batch_0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-batch_1.png")) - output = Image.open("../outputs/test-txt2img-batch-0.png") - self.assertEqual(output.size, (256, 256)) - # TODO: test contents of image + with Image.open("../outputs/test-txt2img-batch_0.png") as output: + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_highres(self): @@ -140,9 +161,10 @@ def test_highres(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -152,11 +174,14 @@ def test_highres(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand( + "test-txt2img-highres", "test", "test", run_txt2img_pipeline, [], {} + ) + ) - run_txt2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( TEST_MODEL_DIFFUSION_SD15, "txt2img", @@ -167,15 +192,20 @@ def test_highres(self): 1, unet_tile=256, ), - Size(256, 256), - ["test-txt2img-highres.png"], - UpscaleParams("test", scale=2, outscale=2), - HighresParams(True, 2, 0, 0), + size=Size(256, 256), + upscale=UpscaleParams("test", scale=2, outscale=2), + highres=HighresParams(True, 2, 0, 0), + ) + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, ) - self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) - output = Image.open("../outputs/test-txt2img-highres.png") - self.assertEqual(output.size, (512, 512)) + self.assertTrue(path.exists("../outputs/test-txt2img-highres_0.png")) + with Image.open("../outputs/test-txt2img-highres_0.png") as output: + self.assertEqual(output.size, (512, 512)) @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_highres_batch(self): @@ -186,9 +216,10 @@ def test_highres_batch(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -198,11 +229,19 @@ def test_highres_batch(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand( + "test-txt2img-highres-batch", + "test", + "test", + run_txt2img_pipeline, + [], + {}, + ) + ) - run_txt2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( TEST_MODEL_DIFFUSION_SD15, "txt2img", @@ -213,62 +252,75 @@ def test_highres_batch(self): 1, batch=2, ), - Size(256, 256), - ["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"], - UpscaleParams("test"), - HighresParams(True, 2, 0, 0), + size=Size(256, 256), + upscale=UpscaleParams("test"), + highres=HighresParams(True, 2, 0, 0), + ) + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, ) - 
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) - self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch_0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch_1.png")) - output = Image.open("../outputs/test-txt2img-highres-batch-0.png") - self.assertEqual(output.size, (512, 512)) + with Image.open("../outputs/test-txt2img-highres-batch_0.png") as output: + self.assertEqual(output.size, (512, 512)) class TestImg2ImgPipeline(unittest.TestCase): @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_basic(self): worker = test_worker() - worker.start("test") + worker.start( + JobCommand("test-img2img", "test", "test", run_txt2img_pipeline, [], {}) + ) source = Image.new("RGB", (64, 64), "black") - run_img2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + test_device(), ImageParams( TEST_MODEL_DIFFUSION_SD15, - "txt2img", + "img2img", TEST_SCHEDULER, TEST_PROMPT, 3.0, 1, 1, ), - ["test-img2img.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), + ) + run_img2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, source, 1.0, ) - self.assertTrue(path.exists("../outputs/test-img2img.png")) + self.assertTrue(path.exists("../outputs/test-img2img_0.png")) class TestInpaintPipeline(unittest.TestCase): @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) def test_basic_white(self): worker = test_worker() - worker.start("test") + worker.start( + JobCommand( + "test-inpaint-white", "test", "test", run_txt2img_pipeline, [], {} + ) + ) source = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "white") - run_inpaint_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + test_device(), ImageParams( TEST_MODEL_DIFFUSION_SD15_INPAINT, - "txt2img", + "inpaint", TEST_SCHEDULER, TEST_PROMPT, 3.0, @@ -276,13 +328,18 @@ def test_basic_white(self): 1, unet_tile=64, ), - Size(*source.size), - ["test-inpaint-white.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + border=Border.even(0), + size=Size(*source.size), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), + ) + + run_inpaint_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, source, mask, - Border.even(0), noise_source_uniform, mask_filter_none, "white", @@ -291,21 +348,24 @@ def test_basic_white(self): 0.0, ) - self.assertTrue(path.exists("../outputs/test-inpaint-white.png")) + self.assertTrue(path.exists("../outputs/test-inpaint-white_0.png")) @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) def test_basic_black(self): worker = test_worker() - worker.start("test") + worker.start( + JobCommand( + "test-inpaint-black", "test", "test", run_txt2img_pipeline, [], {} + ) + ) source = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "black") - run_inpaint_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + test_device(), ImageParams( TEST_MODEL_DIFFUSION_SD15_INPAINT, - "txt2img", + "inpaint", TEST_SCHEDULER, TEST_PROMPT, 3.0, @@ -313,13 +373,18 @@ def test_basic_black(self): 1, unet_tile=64, ), - Size(*source.size), - ["test-inpaint-black.png"], - 
UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + border=Border.even(0), + size=Size(*source.size), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), + ) + + run_inpaint_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, source, mask, - Border.even(0), noise_source_uniform, mask_filter_none, "black", @@ -328,7 +393,7 @@ def test_basic_black(self): 0.0, ) - self.assertTrue(path.exists("../outputs/test-inpaint-black.png")) + self.assertTrue(path.exists("../outputs/test-inpaint-black_0.png")) class TestUpscalePipeline(unittest.TestCase): @@ -341,9 +406,10 @@ def test_basic(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -353,12 +419,13 @@ def test_basic(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand("test-upscale", "test", "test", run_upscale_pipeline, [], {}) + ) source = Image.new("RGB", (64, 64), "black") - run_upscale_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( "../models/upscaling-stable-diffusion-x4", "txt2img", @@ -368,14 +435,18 @@ def test_basic(self): 1, 1, ), - Size(256, 256), - ["test-upscale.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), + size=Size(256, 256), + upscale=UpscaleParams("test"), + highres=HighresParams(False, 1, 0, 0), + ) + run_upscale_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, source, ) - self.assertTrue(path.exists("../outputs/test-upscale.png")) + self.assertTrue(path.exists("../outputs/test-upscale_0.png")) class TestBlendPipeline(unittest.TestCase): @@ -387,9 +458,10 @@ def test_basic(self): active = Value("L", 0) idle = Value("L", 0) + device = test_device() worker = WorkerContext( "test", - test_device(), + device, cancel, logs, pending, @@ -399,13 +471,14 @@ def test_basic(self): 3, 0.1, ) - worker.start("test") + worker.start( + JobCommand("test-blend", "test", "test", run_blend_pipeline, [], {}) + ) source = Image.new("RGBA", (64, 64), "black") mask = Image.new("RGBA", (64, 64), "white") - run_blend_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), + params = RequestParams( + device, ImageParams( TEST_MODEL_DIFFUSION_SD15, "txt2img", @@ -416,11 +489,15 @@ def test_basic(self): 1, unet_tile=64, ), - Size(64, 64), - ["test-blend.png"], - UpscaleParams("test"), + size=Size(64, 64), + upscale=UpscaleParams("test"), + ) + run_blend_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + params, [source, source], mask, ) - self.assertTrue(path.exists("../outputs/test-blend.png")) + self.assertTrue(path.exists("../outputs/test-blend_0.png")) diff --git a/api/tests/test_diffusers/test_scheduler.py b/api/tests/test_diffusers/test_scheduler.py new file mode 100644 index 000000000..bdef9a57f --- /dev/null +++ b/api/tests/test_diffusers/test_scheduler.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import MagicMock + +import numpy as np +import torch +from diffusers.schedulers.scheduling_utils import SchedulerOutput +from numpy.testing import assert_array_equal + +from onnx_web.diffusers.patches.scheduler import ( + SchedulerPatch, + linear_gradient, + mirror_latents, +) + + +class SchedulerPatchTests(unittest.TestCase): + def test_scheduler_step(self): + wrapped_scheduler = MagicMock() + 
wrapped_scheduler.step.return_value = SchedulerOutput(None) + scheduler = SchedulerPatch(None, wrapped_scheduler, None) + model_output = torch.FloatTensor([1.0, 2.0, 3.0]) + timestep = torch.Tensor([0.1]) + sample = torch.FloatTensor([0.5, 0.6, 0.7]) + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, SchedulerOutput) + + def test_mirror_latents_horizontal(self): + latents = np.array( + [ # batch + [ # channels + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], + ], + ] + ) + white_point = 0 + black_point = 1 + center_line = 2 + direction = "horizontal" + gradient = linear_gradient(white_point, black_point, center_line) + mirrored_latents = mirror_latents(latents, gradient, center_line, direction) + assert_array_equal(mirrored_latents, latents) + + def test_mirror_latents_vertical(self): + latents = np.array( + [ # batch + [ # channels + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], + ], + ] + ) + white_point = 0 + black_point = 1 + center_line = 3 + direction = "vertical" + gradient = linear_gradient(white_point, black_point, center_line) + mirrored_latents = mirror_latents(latents, gradient, center_line, direction) + assert_array_equal( + mirrored_latents, + [ + [ + [[0, 0, 0], [0, 0, 0], [10, 11, 12], [7, 8, 9]], + ] + ], + ) diff --git a/api/tests/test_utils.py b/api/tests/test_utils.py new file mode 100644 index 000000000..04c2818d6 --- /dev/null +++ b/api/tests/test_utils.py @@ -0,0 +1,99 @@ +import unittest + +from onnx_web.utils import ( + get_and_clamp_float, + get_and_clamp_int, + get_boolean, + get_from_list, + get_from_map, + get_list, + get_not_empty, + split_list, +) + + +class TestUtils(unittest.TestCase): + def test_split_list_empty(self): + self.assertEqual(split_list(""), []) + self.assertEqual(split_list(" "), []) + self.assertEqual(split_list(" , "), []) + + def test_split_list_single(self): + self.assertEqual(split_list("a"), ["a"]) + self.assertEqual(split_list(" a "), ["a"]) + self.assertEqual(split_list(" a, "), ["a"]) + self.assertEqual(split_list(" a , "), ["a"]) + + def test_split_list_multiple(self): + self.assertEqual(split_list("a,b"), ["a", "b"]) + self.assertEqual(split_list(" a , b "), ["a", "b"]) + self.assertEqual(split_list(" a, b "), ["a", "b"]) + self.assertEqual(split_list(" a ,b "), ["a", "b"]) + + def test_get_boolean_empty(self): + self.assertFalse(get_boolean({}, "key", False)) + self.assertTrue(get_boolean({}, "key", True)) + + def test_get_boolean_true(self): + self.assertTrue(get_boolean({"key": True}, "key", False)) + self.assertTrue(get_boolean({"key": True}, "key", True)) + + def test_get_boolean_false(self): + self.assertFalse(get_boolean({"key": False}, "key", False)) + self.assertFalse(get_boolean({"key": False}, "key", True)) + + def test_get_list_empty(self): + self.assertEqual(get_list({}, "key", ""), []) + self.assertEqual(get_list({}, "key", "a"), ["a"]) + + def test_get_list_exists(self): + self.assertEqual(get_list({"key": "a,b"}, "key", ""), ["a", "b"]) + self.assertEqual(get_list({"key": "a,b"}, "key", "c"), ["a", "b"]) + + def test_get_and_clamp_float_empty(self): + self.assertEqual(get_and_clamp_float({}, "key", 0.0, 1.0), 0.0) + self.assertEqual(get_and_clamp_float({}, "key", 1.0, 1.0), 1.0) + + def test_get_and_clamp_float_clamped(self): + self.assertEqual(get_and_clamp_float({"key": -1.0}, "key", 0.0, 1.0), 0.0) + self.assertEqual(get_and_clamp_float({"key": 2.0}, "key", 0.0, 1.0), 1.0) + + def test_get_and_clamp_float_normal(self): + self.assertEqual(get_and_clamp_float({"key": 0.5}, 
"key", 0.0, 1.0), 0.5) + + def test_get_and_clamp_int_empty(self): + self.assertEqual(get_and_clamp_int({}, "key", 0, 1), 1) + self.assertEqual(get_and_clamp_int({}, "key", 1, 1), 1) + + def test_get_and_clamp_int_clamped(self): + self.assertEqual(get_and_clamp_int({"key": 0}, "key", 1, 1), 1) + self.assertEqual(get_and_clamp_int({"key": 2}, "key", 1, 1), 1) + + def test_get_and_clamp_int_normal(self): + self.assertEqual(get_and_clamp_int({"key": 1}, "key", 0, 1), 1) + + def test_get_from_list_empty(self): + self.assertEqual(get_from_list({}, "key", ["a", "b"]), "a") + self.assertEqual(get_from_list({}, "key", ["a", "b"], "a"), "a") + + def test_get_from_list_exists(self): + self.assertEqual(get_from_list({"key": "a"}, "key", ["a", "b"]), "a") + self.assertEqual(get_from_list({"key": "b"}, "key", ["a", "b"]), "b") + + def test_get_from_list_invalid(self): + self.assertEqual(get_from_list({"key": "c"}, "key", ["a", "b"]), "a") + + def test_get_from_map_empty(self): + self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "a"), 1) + self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "b"), 2) + + def test_get_from_map_exists(self): + self.assertEqual(get_from_map({"key": "a"}, "key", {"a": 1, "b": 2}, "a"), 1) + self.assertEqual(get_from_map({"key": "b"}, "key", {"a": 1, "b": 2}, "a"), 2) + + def test_get_not_empty_empty(self): + self.assertEqual(get_not_empty({}, "key", "a"), "a") + self.assertEqual(get_not_empty({"key": ""}, "key", "a"), "a") + + def test_get_not_empty_exists(self): + self.assertEqual(get_not_empty({"key": "b"}, "key", "a"), "b") diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index ea7091563..ea1708187 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -5,21 +5,31 @@ from onnx_web.params import DeviceParams from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobStatus, Progress from onnx_web.worker.pool import DevicePoolExecutor +from tests.helpers import test_device TEST_JOIN_TIMEOUT = 0.2 lock = Event() -def test_job(*args, **kwargs): +def lock_job(*args, **kwargs): lock.wait() -def wait_job(*args, **kwargs): +def sleep_job(*args, **kwargs): sleep(0.5) +def progress_job(worker, *args, **kwargs): + worker.set_progress(1) + + +def fail_job(*args, **kwargs): + raise RuntimeError("job failed") + + class TestWorkerPool(unittest.TestCase): # lock: Optional[Event] pool: Optional[DevicePoolExecutor] @@ -37,30 +47,33 @@ def test_no_devices(self): self.pool.start() def test_fake_worker(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() self.assertEqual(len(self.pool.workers), 1) def test_cancel_pending(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() - self.pool.submit("test", wait_job, lock=lock) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", "test", sleep_job, lock=lock) + self.assertEqual( + self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1)) + ) self.assertTrue(self.pool.cancel("test")) - self.assertEqual(self.pool.done("test"), (False, None)) + self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None, None)) + @unittest.skip("TODO") def test_cancel_running(self): pass def test_next_device(self): - device = 
DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() @@ -82,29 +95,33 @@ def test_done_running(self): """ TODO: flaky """ - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor( server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) + + lock.clear() self.pool.start(lock) - self.pool.submit("test", test_job) + self.pool.submit("test", "test", lock_job) sleep(5.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + status, _progress, _status = self.pool.status("test") + self.assertEqual(status, JobStatus.RUNNING) def test_done_pending(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start(lock) - self.pool.submit("test1", test_job) - self.pool.submit("test2", test_job) - self.assertTrue(self.pool.done("test2"), (True, None)) + self.pool.submit("test1", "test", lock_job) + self.pool.submit("test2", "test", lock_job) + self.assertEqual( + self.pool.status("test2"), (JobStatus.PENDING, None, Progress(1, 2)) + ) lock.set() @@ -112,25 +129,66 @@ def test_done_finished(self): """ TODO: flaky """ - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor( server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) self.pool.start() - self.pool.submit("test", wait_job) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", "test", sleep_job) + self.assertEqual( + self.pool.status("test"), (JobStatus.PENDING, None, Progress(0, 1)) + ) sleep(5.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + status, _progress, _queue = self.pool.status("test") + self.assertEqual(status, JobStatus.SUCCESS) + @unittest.skip("TODO") def test_recycle_live(self): pass + @unittest.skip("TODO") def test_recycle_dead(self): pass + @unittest.skip("TODO") def test_running_status(self): pass + + @unittest.skip("TODO") + def test_progress_update(self): + pass + + def test_progress_finished(self): + device = test_device() + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + + self.pool.submit("test", "test", progress_job) + sleep(5.0) + + status, progress, _queue = self.pool.status("test") + self.assertEqual(status, JobStatus.SUCCESS) + self.assertEqual(progress.steps.current, 1) + + def test_progress_failed(self): + device = test_device() + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + + self.pool.submit("test", "test", fail_job) + sleep(5.0) + + status, progress, _queue = self.pool.status("test") + self.assertEqual(status, JobStatus.FAILED) + self.assertEqual(progress.steps.current, 0) diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py index 6365fac9e..a0684fff4 100644 --- a/api/tests/worker/test_worker.py +++ b/api/tests/worker/test_worker.py @@ -40,7 +40,7 @@ def exit(exit_status): nonlocal status status = exit_status - job = JobCommand("test", "test", main_interrupt, [], {}) + job = JobCommand("test", "test", "test", main_interrupt, [], {}) cancel = 
Value("L", False) logs = Queue() pending = Queue() @@ -75,7 +75,7 @@ def exit(exit_status): nonlocal status status = exit_status - job = JobCommand("test", "test", main_retry, [], {}) + job = JobCommand("test", "test", "test", main_retry, [], {}) cancel = Value("L", False) logs = Queue() pending = Queue() @@ -144,7 +144,7 @@ def exit(exit_status): nonlocal status status = exit_status - job = JobCommand("test", "test", main_memory, [], {}) + job = JobCommand("test", "test", "test", main_memory, [], {}) cancel = Value("L", False) logs = Queue() pending = Queue() diff --git a/docs/api.md b/docs/api.md index 93892a909..b462d891b 100644 --- a/docs/api.md +++ b/docs/api.md @@ -8,6 +8,10 @@ - [GUI bundle](#gui-bundle) - [`GET /`](#get-) - [`GET /`](#get-path) + - [Jobs](#jobs) + - [`POST /api/job`](#post-apijob) + - [`PUT /api/job/cancel`](#put-apijobcancel) + - [`GET /api/job/status`](#get-apijobstatus) - [Settings and parameters](#settings-and-parameters) - [`GET /api`](#get-api) - [`GET /api/settings/filters`](#get-apisettingsfilters) @@ -51,11 +55,132 @@ Usually includes: - `config.json` - `index.html` +### Jobs + +#### `POST /api/job` + +Create a new job that will run a chain pipeline provided in the `POST` body. + +#### `PUT /api/job/cancel` + +Cancel one or more jobs by name, provided in the `jobs` query parameter. + +Request: + +```shell +> curl http://localhost:5000/api/job/status?jobs=job-1,job-2,job-3,job-4 +``` + +Response: + +```json +[ + { + "name": "job-1", + "status": "cancelled" + }, + { + "name": "job-2", + "status": "pending" + } +] +``` + +#### `GET /api/job/status` + +Get the status of one or more jobs by name, provided in the `jobs` query parameter. + +Request: + +```shell +> curl http://localhost:5000/api/job/status?jobs=job-1,job-2,job-3,job-4 +``` + +Response: + +```json +[ + { + "metadata": [ + // metadata for each output + ], + "name": "job-1", + "outputs": [ + "txt2img_job_1.png" + ], + "queue": { + "current": 0, + "total": 0 + }, + "stages": { + "current": 4, + "total": 6 + }, + "status": "running", + "steps": { + "current": 120, + "total": 78 + }, + "tiles": { + "current": 0, + "total": 16 + } + }, + { + "name": "job-2", + "queue": { + "current": 2, + "total": 3 + }, + "stages": { + "current": 0, + "total": 0 + }, + "status": "pending", + "steps": { + "current": 0, + "total": 0 + }, + "tiles": { + "current": 0, + "total": 0 + } + } +] +``` + ### Settings and parameters #### `GET /api` -Introspection endpoint. +API introspection endpoint. + +Returns a JSON document with all of the available API endpoints and valid methods for them: + +```json +{ + "name": "onnx-web", + "routes": [ + { + "methods": [ + "HEAD", + "GET", + "OPTIONS" + ], + "path": "/static/:filename" + }, + { + "methods": [ + "HEAD", + "GET", + "OPTIONS" + ], + "path": "/" + }, + ... + ] +} +``` #### `GET /api/settings/filters` diff --git a/docs/index.md b/docs/index.md index c0cfffc3f..e839f1ac8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and -servers with a seamless user experience. +multi-GPU servers with a seamless user experience. 
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers, including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each @@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features: - includes both the API and GUI bundle in a single container - runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services -## Contents - -- [onnx-web](#onnx-web) - - [Features](#features) - - [Contents](#contents) - - [Setup](#setup) - - [Adding your own models](#adding-your-own-models) - - [Usage](#usage) - - [Known errors and solutions](#known-errors-and-solutions) - - [Running the containers](#running-the-containers) - - [Credits](#credits) - ## Setup There are a few ways to run onnx-web: diff --git a/docs/server-admin.md b/docs/server-admin.md index 530fdcec2..5bb3864a1 100644 --- a/docs/server-admin.md +++ b/docs/server-admin.md @@ -17,6 +17,7 @@ Please see [the user guide](user-guide.md) for descriptions of the client and ea - [Hosting the client](#hosting-the-client) - [Customizing the client config](#customizing-the-client-config) - [Configuration](#configuration) + - [Client Parameters](#client-parameters) - [Debug Mode](#debug-mode) - [Environment Variables](#environment-variables) - [Client Variables](#client-variables) @@ -185,6 +186,23 @@ custom config using: Configuration is still very simple, loading models from a directory and parameters from a single JSON file. Some additional configuration can be done through environment variables starting with `ONNX_WEB`. +For the web GUI or client, most configuration is provided by the server. Some additional options can be set using +the query string. 
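+
+For example, assuming the GUI bundle is hosted separately from the API server, several of the parameters described under Client Parameters below can be combined in a single URL (the client host here is only a placeholder):
+
+```shell
+# open the client against a local API server, with faster status polling and German localization
+# client.example.com is a placeholder for wherever the GUI bundle is hosted
+xdg-open "https://client.example.com/?api=http://localhost:5000&interval=500&lng=de"
+```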
+ +### Client Parameters + +- `api` + - the root URL of the server you intend to use + - `?api=http://localhost:5000` + - `?api=https://generate.your-server.ai` +- `interval` + - the polling interval for status updates + - `?interval=500` for LCM and Turbo +- `lng` + - the language to use for localization + - `de`, `en`, `es`, and `fr` are supported + - `dev` will show localization labels + ### Debug Mode Setting the `DEBUG` variable to any value except `false` will enable debug mode, which will print garbage @@ -308,6 +326,10 @@ These extra images can be helpful when debugging inpainting, especially poorly b - `panorama-highres` - when using the panorama pipeline with highres, prefer panorama views over stage tiling +- `horde-safety` and `horde-safety-nsfw` + - enable [the horde-safety plugin](https://github.com/Haidra-Org/horde-safety) + - plugin always blocks CSAM, enable `horde-safety-nsfw` to block NSFW as well + - recommended for shared servers, may be required by some model licenses ### Pipeline Optimizations @@ -330,12 +352,14 @@ These extra images can be helpful when debugging inpainting, especially poorly b - `onnx-*` - `onnx-cpu-*` - CPU offloading for individual models - - `onnx-cpu-text-encoder` - - recommended for SDXL highres - - `onnx-cpu-unet` - - not recommended - - `onnx-cpu-vae` - - may be necessary for SDXL highres + - upscaling models: `bsrgan`, `esrgan`, `swinir` + - diffusion models: `controlnet`, `scheduler`, `text-encoder`, `unet`, `vae` + - diffusion models can use `-sdxl` suffix for CPU offloading only in SDXL pipelines + - common flags: + - `onnx-cpu-text-encoder` + - recommended for SDXL highres + - `onnx-cpu-vae-sdxl` + - may be necessary for SDXL highres with limited VRAM - `onnx-deterministic-compute` - enable ONNX deterministic compute - `onnx-fp16` diff --git a/gui/Makefile b/gui/Makefile index b4d21f7a0..f35c134cf 100644 --- a/gui/Makefile +++ b/gui/Makefile @@ -22,6 +22,7 @@ docs-local: build: deps yarn tsc + cp -v src/components/main.css out/src/components/ build-shebang: build sed -i '1s;^;#! 
/usr/bin/env node\n\n;g' $(shell pwd)/out/src/main.js @@ -35,6 +36,7 @@ bundle: build # copy everything into the server's default path cp -v src/index.html ../api/gui/ cp -v src/config.json ../api/gui/ + cp -v out/bundle/main.css ../api/gui/bundle/ cp -v out/bundle/main.js ../api/gui/bundle/ COVER_OPTS := --all \ diff --git a/gui/package.json b/gui/package.json index afc832b55..f0cbc1834 100644 --- a/gui/package.json +++ b/gui/package.json @@ -15,7 +15,9 @@ "@mui/material": "^5.12.0", "@tanstack/react-query": "^4.0.5", "@types/lodash": "^4.14.192", - "@types/node": "^18.15.11", + "@types/node": "^20.11.0", + "@yornaath/batshit": "^0.9.0", + "allotment": "^1.19.5", "browser-bunyan": "^1.8.0", "exifreader": "^4.13.0", "i18next": "^22.4.14", @@ -25,9 +27,9 @@ "react": "^18.2.0", "react-dom": "^18.2.0", "react-i18next": "^12.2.0", - "react-use": "^17.4.0", + "react-use": "^17.4.3", "semver": "^7.4.0", - "tslib": "^2.5.0", + "tslib": "^2.6.2", "zustand": "^4.3.7" }, "devDependencies": { diff --git a/gui/src/Motd.tsx b/gui/src/Motd.tsx new file mode 100644 index 000000000..f0feb7286 --- /dev/null +++ b/gui/src/Motd.tsx @@ -0,0 +1,30 @@ +import React, { useState } from 'react'; +import { Alert, Collapse, IconButton } from '@mui/material'; +import { Close } from '@mui/icons-material'; +import { useTranslation } from 'react-i18next'; + +export function Motd() { + const [open, setOpen] = useState(true); + const { t } = useTranslation(); + + return + { + setOpen(false); + }} + > + + + } + severity='info' + sx={{ mb: 2 }} + > + {t('motd')} + + ; +} diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 8ea871eaa..6c12e1741 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -1,13 +1,15 @@ +/* eslint-disable max-params */ +/* eslint-disable camelcase */ /* eslint-disable max-lines */ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; +import { create as batcher, keyResolver, windowedFiniteBatchScheduler } from '@yornaath/batshit'; import { ServerParams } from '../config.js'; +import { FIXED_FLOAT, FIXED_INTEGER, STATUS_SUCCESS } from '../constants.js'; +import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js'; import { FilterResponse, - ImageResponse, - ImageResponseWithRetry, ModelResponse, - ReadyResponse, RetryParams, WriteExtrasResponse, } from '../types/api.js'; @@ -16,11 +18,16 @@ import { ExtrasFile } from '../types/model.js'; import { BaseImgParams, BlendParams, + ExperimentalParams, HighresParams, + ImageSize, + Img2ImgJSONParams, Img2ImgParams, + InpaintJSONParams, InpaintParams, ModelParams, OutpaintParams, + OutpaintPixels, Txt2ImgParams, UpscaleParams, UpscaleReqParams, @@ -28,23 +35,8 @@ import { import { range } from '../utils.js'; import { ApiClient } from './base.js'; -/** - * Fixed precision for integer parameters. - */ -export const FIXED_INTEGER = 0; - -/** - * Fixed precision for float parameters. - * - * The GUI limits the input steps based on the server parameters, but this does limit - * the maximum precision that can be sent back to the server, and may have to be - * increased in the future. 
- */ -export const FIXED_FLOAT = 2; -export const STATUS_SUCCESS = 200; - -export function equalResponse(a: ImageResponse, b: ImageResponse): boolean { - return a.outputs === b.outputs; +export function equalResponse(a: JobResponse, b: JobResponse): boolean { + return a.name === b.name; } /** @@ -57,15 +49,150 @@ export function joinPath(...parts: Array): string { /** * Build the URL to an API endpoint, given the API root and a list of segments. */ -export function makeApiUrl(root: string, ...path: Array) { +export function makeApiURL(root: string, ...path: Array): URL { return new URL(joinPath('api', ...path), root); } +export interface ImageJSON { + model: ModelParams; + base: BaseImgParams; + size?: ImageSize; + border?: OutpaintPixels; + upscale?: UpscaleParams; + highres?: HighresParams; + img2img?: Img2ImgJSONParams; + inpaint?: InpaintJSONParams; + experimental?: ExperimentalParams; +} + +export interface JSONInner { + [key: string]: string | number | boolean | undefined | JSONInner; +} + +export interface JSONBody extends JSONInner { + params: JSONInner; +} + +export function makeImageJSON(params: ImageJSON): string { + const { model, base, img2img, inpaint, size, border, upscale, highres, experimental } = params; + + const body: JSONBody = { + device: { + platform: model.platform, + }, + params: { + // model params + model: model.model, + pipeline: model.pipeline, + upscaling: model.upscaling, + correction: model.correction, + control: model.control, + // image params + batch: base.batch, + cfg: base.cfg, + eta: base.eta, + steps: base.steps, + tiled_vae: base.tiled_vae, + unet_overlap: base.unet_overlap, + unet_tile: base.unet_tile, + vae_overlap: base.vae_overlap, + vae_tile: base.vae_tile, + scheduler: base.scheduler, + seed: base.seed, + prompt: base.prompt, + negativePrompt: base.negativePrompt, + }, + }; + + if (doesExist(img2img)) { + body.params = { + ...body.params, + loopback: img2img.loopback, + sourceFilter: img2img.sourceFilter, + strength: img2img.strength, + }; + } + + if (doesExist(inpaint)) { + body.params = { + ...body.params, + filter: inpaint.filter, + noise: inpaint.noise, + strength: inpaint.strength, + fillColor: inpaint.fillColor, + tileOrder: inpaint.tileOrder, + }; + } + + if (doesExist(size)) { + body.params = { + ...body.params, + width: size.width, + height: size.height, + }; + } + + if (doesExist(border) && border.enabled) { + body.border = { + left: border.left, + right: border.right, + top: border.top, + bottom: border.bottom, + }; + } + + if (doesExist(upscale)) { + body.upscale = { + enabled: upscale.enabled, + upscaleOrder: upscale.upscaleOrder, + denoise: upscale.denoise, + scale: upscale.scale, + outscale: upscale.outscale, + faces: upscale.faces, + faceOutscale: upscale.faceOutscale, + faceStrength: upscale.faceStrength, + }; + } + + if (doesExist(highres)) { + body.highres = { + highres: highres.enabled, + highresIterations: highres.highresIterations, + highresMethod: highres.highresMethod, + highresScale: highres.highresScale, + highresSteps: highres.highresSteps, + highresStrength: highres.highresStrength, + }; + } + + if (doesExist(experimental)) { + body.experimental = { + latentSymmetry: { + enabled: experimental.latentSymmetry.enabled, + gradientStart: experimental.latentSymmetry.gradientStart, + gradientEnd: experimental.latentSymmetry.gradientEnd, + lineOfSymmetry: experimental.latentSymmetry.lineOfSymmetry, + }, + promptEditing: { + enabled: experimental.promptEditing.enabled, + addSuffix: experimental.promptEditing.addSuffix, + 
minLength: experimental.promptEditing.minLength, + promptFilter: experimental.promptEditing.filter, + removeTokens: experimental.promptEditing.removeTokens, + }, + }; + } + + return JSON.stringify(body); +} + /** * Build the URL for an image request, including all of the base image parameters. + * + * @deprecated use `makeImageJSON` and `makeApiURL` instead */ export function makeImageURL(root: string, type: string, params: BaseImgParams): URL { - const url = makeApiUrl(root, type); + const url = makeApiURL(root, type); url.searchParams.append('batch', params.batch.toFixed(FIXED_INTEGER)); url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT)); @@ -96,6 +223,8 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): /** * Append the model parameters to an existing URL. + * + * @deprecated use `makeImageJSON` instead */ export function appendModelToURL(url: URL, params: ModelParams) { url.searchParams.append('model', params.model); @@ -108,6 +237,8 @@ export function appendModelToURL(url: URL, params: ModelParams) { /** * Append the upscale parameters to an existing URL. + * + * @deprecated use `makeImageJSON` instead */ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { url.searchParams.append('upscale', String(upscale.enabled)); @@ -126,6 +257,11 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { } } +/** + * Append the highres parameters to an existing URL. + * + * @deprecated use `makeImageJSON` instead + */ export function appendHighresToURL(url: URL, highres: HighresParams) { if (highres.enabled) { url.searchParams.append('highres', String(highres.enabled)); @@ -140,14 +276,14 @@ export function appendHighresToURL(url: URL, highres: HighresParams) { /** * Make an API client using the given API root and fetch client. 
*/ -export function makeClient(root: string, token: Maybe = undefined, f = fetch): ApiClient { - function parseRequest(url: URL, options: RequestInit): Promise { - return f(url, options).then((res) => parseApiResponse(root, res)); +export function makeClient(root: string, batchInterval: number, token: Maybe = undefined, f = fetch): ApiClient { + function parseRequest(url: URL, options: RequestInit): Promise { + return f(url, options).then((res) => parseJobResponse(root, res)); } - return { + const client = { async extras(): Promise { - const path = makeApiUrl(root, 'extras'); + const path = makeApiURL(root, 'extras'); if (doesExist(token)) { path.searchParams.append('token', token); @@ -157,7 +293,7 @@ export function makeClient(root: string, token: Maybe = undefined, f = f return await res.json() as ExtrasFile; }, async writeExtras(extras: ExtrasFile): Promise { - const path = makeApiUrl(root, 'extras'); + const path = makeApiURL(root, 'extras'); if (doesExist(token)) { path.searchParams.append('token', token); @@ -170,82 +306,75 @@ export function makeClient(root: string, token: Maybe = undefined, f = f return await res.json() as WriteExtrasResponse; }, async filters(): Promise { - const path = makeApiUrl(root, 'settings', 'filters'); + const path = makeApiURL(root, 'settings', 'filters'); const res = await f(path); return await res.json() as FilterResponse; }, async models(): Promise { - const path = makeApiUrl(root, 'settings', 'models'); + const path = makeApiURL(root, 'settings', 'models'); const res = await f(path); return await res.json() as ModelResponse; }, async noises(): Promise> { - const path = makeApiUrl(root, 'settings', 'noises'); + const path = makeApiURL(root, 'settings', 'noises'); const res = await f(path); return await res.json() as Array; }, async params(): Promise { - const path = makeApiUrl(root, 'settings', 'params'); + const path = makeApiURL(root, 'settings', 'params'); const res = await f(path); return await res.json() as ServerParams; }, async schedulers(): Promise> { - const path = makeApiUrl(root, 'settings', 'schedulers'); + const path = makeApiURL(root, 'settings', 'schedulers'); const res = await f(path); return await res.json() as Array; }, async pipelines(): Promise> { - const path = makeApiUrl(root, 'settings', 'pipelines'); + const path = makeApiURL(root, 'settings', 'pipelines'); const res = await f(path); return await res.json() as Array; }, async platforms(): Promise> { - const path = makeApiUrl(root, 'settings', 'platforms'); + const path = makeApiURL(root, 'settings', 'platforms'); const res = await f(path); return await res.json() as Array; }, async strings(): Promise; }>> { - const path = makeApiUrl(root, 'settings', 'strings'); + const path = makeApiURL(root, 'settings', 'strings'); const res = await f(path); return await res.json() as Record; }>; }, async wildcards(): Promise> { - const path = makeApiUrl(root, 'settings', 'wildcards'); + const path = makeApiURL(root, 'settings', 'wildcards'); const res = await f(path); return await res.json() as Array; }, - async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { - const url = makeImageURL(root, 'img2img', params); - appendModelToURL(url, model); - - url.searchParams.append('loopback', params.loopback.toFixed(FIXED_INTEGER)); - url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); - - if (doesExist(params.sourceFilter)) { - url.searchParams.append('sourceFilter', params.sourceFilter); - } - - if 
(doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } - - if (doesExist(highres)) { - appendHighresToURL(url, highres); - } + async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { + const url = makeApiURL(root, 'img2img'); + const json = makeImageJSON({ + model, + base: params, + upscale, + highres, + img2img: params, + experimental, + }); - const body = new FormData(); - body.append('source', params.source, 'source'); + const form = new FormData(); + form.append('json', json); + form.append('source', params.source, 'source'); - const image = await parseRequest(url, { - body, + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'img2img', model, @@ -254,31 +383,26 @@ export function makeClient(root: string, token: Maybe = undefined, f = f }, }; }, - async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { - const url = makeImageURL(root, 'txt2img', params); - appendModelToURL(url, model); - - if (doesExist(params.width)) { - url.searchParams.append('width', params.width.toFixed(FIXED_INTEGER)); - } - - if (doesExist(params.height)) { - url.searchParams.append('height', params.height.toFixed(FIXED_INTEGER)); - } - - if (doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } + async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { + const url = makeApiURL(root, 'txt2img'); + const json = makeImageJSON({ + model, + base: params, + size: params, + upscale, + highres, + experimental, + }); - if (doesExist(highres)) { - appendHighresToURL(url, highres); - } + const form = new FormData(); + form.append('json', json); - const image = await parseRequest(url, { + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'txt2img', model, @@ -288,33 +412,28 @@ export function makeClient(root: string, token: Maybe = undefined, f = f }, }; }, - async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { - const url = makeImageURL(root, 'inpaint', params); - appendModelToURL(url, model); - - url.searchParams.append('filter', params.filter); - url.searchParams.append('noise', params.noise); - url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); - url.searchParams.append('fillColor', params.fillColor); - - if (doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } - - if (doesExist(highres)) { - appendHighresToURL(url, highres); - } + async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { + const url = makeApiURL(root, 'inpaint'); + const json = makeImageJSON({ + model, + base: params, + upscale, + highres, + inpaint: params, + experimental, + }); - const body = new FormData(); - body.append('mask', params.mask, 'mask'); - body.append('source', params.source, 'source'); + const form = new FormData(); + form.append('json', json); + form.append('mask', params.mask, 'mask'); + form.append('source', params.source, 'source'); - const image = await parseRequest(url, { - body, + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'inpaint', model, @@ -323,50 +442,29 @@ export function 
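// Editor's note: the img2img/txt2img/inpaint hunks above all switch from per-field query
// parameters to a single JSON document posted next to the binary inputs in one multipart
// form. A minimal sketch of that request shape, using only what the hunks show (the 'json'
// and 'source' form fields and a POST to the pipeline endpoint); makeApiURL and
// makeImageJSON are the helpers defined earlier in this file.
async function postImg2ImgSketch(root: string, model: ModelParams, params: Img2ImgParams): Promise<Response> {
  const url = makeApiURL(root, 'img2img');
  const json = makeImageJSON({ model, base: params, img2img: params });

  const form = new FormData();
  form.append('json', json);                        // all scalar parameters travel in this field
  form.append('source', params.source, 'source');   // binary inputs stay as separate form parts

  // the real client wraps this in parseRequest(), which parses the body into a JobResponse
  return fetch(url.toString(), { method: 'POST', body: form });
}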
makeClient(root: string, token: Maybe = undefined, f = f }, }; }, - async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { - const url = makeImageURL(root, 'inpaint', params); - appendModelToURL(url, model); - - url.searchParams.append('filter', params.filter); - url.searchParams.append('noise', params.noise); - url.searchParams.append('strength', params.strength.toFixed(FIXED_FLOAT)); - url.searchParams.append('fillColor', params.fillColor); - url.searchParams.append('tileOrder', params.tileOrder); - - if (doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } - - if (doesExist(highres)) { - appendHighresToURL(url, highres); - } - - if (doesExist(params.left)) { - url.searchParams.append('left', params.left.toFixed(FIXED_INTEGER)); - } - - if (doesExist(params.right)) { - url.searchParams.append('right', params.right.toFixed(FIXED_INTEGER)); - } - - if (doesExist(params.top)) { - url.searchParams.append('top', params.top.toFixed(FIXED_INTEGER)); - } - - if (doesExist(params.bottom)) { - url.searchParams.append('bottom', params.bottom.toFixed(FIXED_INTEGER)); - } + async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise { + const url = makeApiURL(root, 'inpaint'); + const json = makeImageJSON({ + model, + base: params, + border: params, + upscale, + highres, + inpaint: params, + experimental, + }); - const body = new FormData(); - body.append('mask', params.mask, 'mask'); - body.append('source', params.source, 'source'); + const form = new FormData(); + form.append('json', json); + form.append('mask', params.mask, 'mask'); + form.append('source', params.source, 'source'); - const image = await parseRequest(url, { - body, + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'outpaint', model, @@ -375,33 +473,25 @@ export function makeClient(root: string, token: Maybe = undefined, f = f }, }; }, - async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { - const url = makeApiUrl(root, 'upscale'); - appendModelToURL(url, model); - - if (doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } - - if (doesExist(highres)) { - appendHighresToURL(url, highres); - } - - url.searchParams.append('prompt', params.prompt); - - if (doesExist(params.negativePrompt)) { - url.searchParams.append('negativePrompt', params.negativePrompt); - } + async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise { + const url = makeApiURL(root, 'upscale'); + const json = makeImageJSON({ + model, + base: params, + upscale, + highres, + }); - const body = new FormData(); - body.append('source', params.source, 'source'); + const form = new FormData(); + form.append('json', json); + form.append('source', params.source, 'source'); - const image = await parseRequest(url, { - body, + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'upscale', model, @@ -410,28 +500,29 @@ export function makeClient(root: string, token: Maybe = undefined, f = f }, }; }, - async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise { - const url = makeApiUrl(root, 'blend'); - appendModelToURL(url, model); - - if (doesExist(upscale)) { - appendUpscaleToURL(url, upscale); - } + async blend(model: 
ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise { + const url = makeApiURL(root, 'blend'); + const json = makeImageJSON({ + model, + base: params as unknown as BaseImgParams, // TODO: fix this + upscale, + }); - const body = new FormData(); - body.append('mask', params.mask, 'mask'); + const form = new FormData(); + form.append('json', json); + form.append('mask', params.mask, 'mask'); for (const i of range(params.sources.length)) { const name = `source:${i.toFixed(0)}`; - body.append(name, params.sources[i], name); + form.append(name, params.sources[i], name); } - const image = await parseRequest(url, { - body, + const job = await parseRequest(url, { + body: form, method: 'POST', }); return { - image, + job, retry: { type: 'blend', model, @@ -440,8 +531,8 @@ export function makeClient(root: string, token: Maybe = undefined, f = f } }; }, - async chain(model: ModelParams, chain: ChainPipeline): Promise { - const url = makeApiUrl(root, 'chain'); + async chain(model: ModelParams, chain: ChainPipeline): Promise { + const url = makeApiURL(root, 'job'); const body = JSON.stringify({ ...chain, platform: model.platform, @@ -456,23 +547,23 @@ export function makeClient(root: string, token: Maybe = undefined, f = f method: 'POST', }); }, - async ready(key: string): Promise { - const path = makeApiUrl(root, 'ready'); - path.searchParams.append('output', key); + async status(keys: Array): Promise> { + const path = makeApiURL(root, 'job', 'status'); + path.searchParams.append('jobs', keys.join(',')); const res = await f(path); - return await res.json() as ReadyResponse; + return await res.json() as Array; }, - async cancel(key: string): Promise { - const path = makeApiUrl(root, 'cancel'); - path.searchParams.append('output', key); + async cancel(keys: Array): Promise> { + const path = makeApiURL(root, 'job', 'cancel'); + path.searchParams.append('jobs', keys.join(',')); const res = await f(path, { method: 'PUT', }); - return res.status === STATUS_SUCCESS; + return await res.json() as Array; }, - async retry(retry: RetryParams): Promise { + async retry(retry: RetryParams): Promise { switch (retry.type) { case 'blend': return this.blend(retry.model, retry.params, retry.upscale); @@ -491,7 +582,7 @@ export function makeClient(root: string, token: Maybe = undefined, f = f } }, async restart(): Promise { - const path = makeApiUrl(root, 'restart'); + const path = makeApiURL(root, 'worker', 'restart'); if (doesExist(token)) { path.searchParams.append('token', token); @@ -502,8 +593,8 @@ export function makeClient(root: string, token: Maybe = undefined, f = f }); return res.status === STATUS_SUCCESS; }, - async status(): Promise> { - const path = makeApiUrl(root, 'status'); + async workers(): Promise> { + const path = makeApiURL(root, 'worker', 'status'); if (doesExist(token)) { path.searchParams.append('token', token); @@ -512,6 +603,32 @@ export function makeClient(root: string, token: Maybe = undefined, f = f const res = await f(path); return res.json(); }, + outputURL(image: SuccessJobResponse, index: number): string { + return new URL(joinPath('output', image.outputs[index]), root).toString(); + }, + thumbnailURL(image: SuccessJobResponse, index: number): Maybe { + if (doesExist(image.thumbnails) && doesExist(image.thumbnails[index])) { + return new URL(joinPath('output', image.thumbnails[index]), root).toString(); + } + + return undefined; + }, + }; + + const batchStatus = batcher({ + fetcher: async (jobs: Array) => client.status(jobs), + resolver: keyResolver('name'), + scheduler: 
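// Editor's note: ready()/cancel() become status()/cancel() over arrays of job names, and the
// client now owns URL construction for finished outputs via outputURL()/thumbnailURL(). A
// hedged polling sketch; the import paths are assumptions, the method signatures are the ones
// shown in the surrounding hunks.
import { ApiClient } from './base.js';
import { JobStatus, SuccessJobResponse } from '../types/api-v2.js';

async function waitForFirstOutput(client: ApiClient, name: string): Promise<string> {
  for (;;) {
    const [job] = await client.status([name]);               // one job name in, one JobResponse out
    if (job.status === JobStatus.SUCCESS) {
      return client.outputURL(job as SuccessJobResponse, 0); // first output image of the job
    }
    if (job.status === JobStatus.FAILED || job.status === JobStatus.UNKNOWN) {
      throw new Error(`job did not complete: ${name}`);
    }
    await new Promise((done) => setTimeout(done, 1000));     // poll interval is an arbitrary choice
  }
}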
windowedFiniteBatchScheduler({ + windowMs: batchInterval, + maxBatchSize: 10, + }), + }); + + return { + ...client, + async status(keys): Promise> { + return Promise.all(keys.map((key) => batchStatus.fetch(key))); + }, }; } @@ -521,24 +638,9 @@ export function makeClient(root: string, token: Maybe = undefined, f = f * The server sends over the output key, and the client is in the best position to turn * that into a full URL, since it already knows the root URL of the server. */ -export async function parseApiResponse(root: string, res: Response): Promise { - type LimitedResponse = Omit & { outputs: Array }; - +export async function parseJobResponse(root: string, res: Response): Promise { if (res.status === STATUS_SUCCESS) { - const data = await res.json() as LimitedResponse; - - const outputs = data.outputs.map((output) => { - const url = new URL(joinPath('output', output), root).toString(); - return { - key: output, - url, - }; - }); - - return { - ...data, - outputs, - }; + return await res.json() as JobResponse; } else { throw new Error('request error'); } diff --git a/gui/src/client/base.ts b/gui/src/client/base.ts index 62ef440a6..e9d114871 100644 --- a/gui/src/client/base.ts +++ b/gui/src/client/base.ts @@ -1,12 +1,31 @@ +import { Maybe } from '@apextoaster/js-utils'; import { ServerParams } from '../config.js'; -import { ExtrasFile } from '../types/model.js'; -import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; +import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js'; +import { FilterResponse, ModelResponse, RetryParams, WriteExtrasResponse } from '../types/api.js'; import { ChainPipeline } from '../types/chain.js'; -import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js'; +import { ExtrasFile } from '../types/model.js'; +import { + BlendParams, + ExperimentalParams, + HighresParams, + Img2ImgParams, + InpaintParams, + ModelParams, + OutpaintParams, + Txt2ImgParams, + UpscaleParams, + UpscaleReqParams, +} from '../types/params.js'; export interface ApiClient { + /** + * Get the first extras file. + */ extras(): Promise; + /** + * Update the first extras file. + */ writeExtras(extras: ExtrasFile): Promise; /** @@ -51,54 +70,60 @@ export interface ApiClient { translation: Record; }>>; + /** + * Get the available wildcards. + */ wildcards(): Promise>; /** * Start a txt2img pipeline. */ - txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an im2img pipeline. */ - img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an inpaint pipeline. */ - inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an outpaint pipeline. 
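// Editor's note: with the wrapper above, ApiClient.status() batches polling requests: calls
// made within `batchInterval` milliseconds are coalesced into one /job/status request of up
// to 10 names, and keyResolver('name') matches each JobResponse back to the caller that asked
// for it. A usage sketch only; the batching helpers' import is outside this diff, and the
// root URL and interval here are example values.
async function pollTwoJobs() {
  const client = makeClient('http://localhost:5000', 200);

  const [[cardA], [cardB]] = await Promise.all([
    client.status(['card-a']),
    client.status(['card-b']),   // same 200ms window, so both names go out in a single request
  ]);

  return [cardA, cardB];
}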
*/ - outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams, experimental?: ExperimentalParams): Promise; /** * Start an upscale pipeline. */ - upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; + upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise; /** * Start a blending pipeline. */ - blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; + blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; - chain(model: ModelParams, chain: ChainPipeline): Promise; + /** + * Start a custom chain pipeline. + */ + chain(model: ModelParams, chain: ChainPipeline): Promise; /** * Check whether job has finished and its output is ready. */ - ready(key: string): Promise; + status(keys: Array): Promise>; /** * Cancel an existing job. */ - cancel(key: string): Promise; + cancel(keys: Array): Promise>; /** * Retry a previous job using the same parameters. */ - retry(params: RetryParams): Promise; + retry(params: RetryParams): Promise; /** * Restart the image job workers. @@ -108,5 +133,9 @@ export interface ApiClient { /** * Check the status of the image job workers. */ - status(): Promise>; + workers(): Promise>; + + outputURL(image: SuccessJobResponse, index: number): string; + + thumbnailURL(image: SuccessJobResponse, index: number): Maybe; } diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 06dcd6b0d..eefbaa12c 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -48,7 +48,7 @@ export const LOCAL_CLIENT = { async params() { throw new NoServerError(); }, - async ready(key) { + async status(key) { throw new NoServerError(); }, async cancel(key) { @@ -78,7 +78,13 @@ export const LOCAL_CLIENT = { async restart() { throw new NoServerError(); }, - async status() { + async workers() { throw new NoServerError(); - } + }, + outputURL(image, index) { + throw new NoServerError(); + }, + thumbnailURL(image, index) { + throw new NoServerError(); + }, } as ApiClient; diff --git a/gui/src/client/utils.ts b/gui/src/client/utils.ts index 967c818da..903bbfebe 100644 --- a/gui/src/client/utils.ts +++ b/gui/src/client/utils.ts @@ -48,14 +48,16 @@ export function newSeed(): number { return Math.floor(Math.random() * MAX_SEED); } +// eslint-disable-next-line @typescript-eslint/no-magic-numbers +export const RANDOM_SEED = [-1, '-1']; + export function replaceRandomSeeds(key: string, values: Array): Array { if (key !== 'seed') { return values; } return values.map((it) => { - // eslint-disable-next-line @typescript-eslint/no-magic-numbers - if (it === '-1' || it === -1) { + if (RANDOM_SEED.includes(it)) { return newSeed(); } @@ -97,11 +99,19 @@ export function expandRanges(range: string): Array { export const GRID_TILE_SIZE = 8192; // eslint-disable-next-line max-params -export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { +export function makeTxt2ImgGridPipeline( + grid: PipelineGrid, + model: ModelParams, + params: Txt2ImgParams, + upscale?: UpscaleParams, + highres?: HighresParams, +): ChainPipeline { const pipeline: ChainPipeline = { defaults: { ...model, ...params, + ...(upscale ?? {}), + ...(highres ?? 
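// Editor's note: a small usage sketch for the seed helpers in the utils.ts hunk above.
// RANDOM_SEED now covers both the numeric and string forms of -1, so grid values typed either
// way are replaced with a fresh newSeed() value; the import path is an assumption.
import { replaceRandomSeeds } from './utils.js';

const cfgs = replaceRandomSeeds('cfg', [5, 7, 9]);           // not the seed column: returned unchanged
const seeds = replaceRandomSeeds('seed', [-1, 1234, '-1']);  // -> [<random seed>, 1234, <random seed>]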
{}), }, stages: [], }; diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index 20a520ed9..1793923a3 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -1,17 +1,25 @@ -import { doesExist, mustExist } from '@apextoaster/js-utils'; +import { mustExist } from '@apextoaster/js-utils'; import { Grid, Typography } from '@mui/material'; -import { ReactNode, useContext } from 'react'; import * as React from 'react'; +import { ReactNode, useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; +import { STANDARD_SPACING } from '../constants.js'; import { OnnxState, StateContext } from '../state/full.js'; +import { JobStatus } from '../types/api-v2.js'; import { ErrorCard } from './card/ErrorCard.js'; import { ImageCard } from './card/ImageCard.js'; import { LoadingCard } from './card/LoadingCard.js'; -export function ImageHistory() { +export interface ImageHistoryProps { + width: number; +} + +export function ImageHistory(props: ImageHistoryProps) { + const { width } = props; + const store = mustExist(useContext(StateContext)); const { history, limit } = useStore(store, selectParams, shallow); const { removeHistory } = useStore(store, selectActions, shallow); @@ -25,22 +33,29 @@ export function ImageHistory() { const limited = history.slice(0, limit); for (const item of limited) { - const key = item.image.outputs[0].key; - - if (doesExist(item.ready) && item.ready.ready) { - if (item.ready.cancelled || item.ready.failed) { - children.push([key, ]); - continue; - } + const key = item.image.name; - children.push([key, ]); - continue; + switch (item.image.status) { + case JobStatus.SUCCESS: + children.push([key, ]); + break; + case JobStatus.UNKNOWN: + case JobStatus.FAILED: + children.push([key, ]); + break; + default: + children.push([key, ]); + break; } - - children.push([key, ]); } - return {children.map(([key, child]) => {child})}; + return { + // eslint-disable-next-line @typescript-eslint/no-magic-numbers + children.map(([key, child]) => {child}) + }; } export function selectActions(state: OnnxState) { diff --git a/gui/src/components/LoadingScreen.tsx b/gui/src/components/LoadingScreen.tsx index 8f41df300..b26a41dd7 100644 --- a/gui/src/components/LoadingScreen.tsx +++ b/gui/src/components/LoadingScreen.tsx @@ -2,6 +2,8 @@ import { Box, CircularProgress, Stack, Typography } from '@mui/material'; import * as React from 'react'; import { useTranslation } from 'react-i18next'; +import { STANDARD_SPACING } from '../constants'; + export function LoadingScreen() { const { t } = useTranslation(); @@ -13,7 +15,7 @@ export function LoadingScreen() { }}> diff --git a/gui/src/components/OnnxError.tsx b/gui/src/components/OnnxError.tsx index 7e0cbe009..916c7f6bf 100644 --- a/gui/src/components/OnnxError.tsx +++ b/gui/src/components/OnnxError.tsx @@ -2,6 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material'; import * as React from 'react'; import { ReactNode } from 'react'; +import { STANDARD_MARGIN, STANDARD_SPACING } from '../constants.js'; import { STATE_KEY } from '../state/full.js'; import { Logo } from './Logo.js'; @@ -20,11 +21,11 @@ export function OnnxError(props: OnnxErrorProps) { return ( - + - - + + {props.children} This is a web UI for running ONNX models with GPU acceleration or in software, running locally or on a diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx 
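// Editor's note: ImageHistory now branches on the JobStatus of each history entry instead of
// the old ready/cancelled flags. The JSX was elided in the hunk above, so this sketch only
// restates the branching as a plain function; which card component renders for each branch is
// inferred from the ErrorCard/ImageCard/LoadingCard imports and the previous logic.
import { JobStatus } from '../types/api-v2.js';

type HistoryCard = 'image' | 'error' | 'loading';

export function cardForStatus(status: JobStatus): HistoryCard {
  switch (status) {
    case JobStatus.SUCCESS:
      return 'image';    // finished jobs render an ImageCard
    case JobStatus.UNKNOWN:
    case JobStatus.FAILED:
      return 'error';    // failed or expired jobs render an ErrorCard
    default:
      return 'loading';  // pending and running jobs render a LoadingCard with progress
  }
}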
index 69c2db202..abd772762 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -1,13 +1,18 @@ import { mustExist } from '@apextoaster/js-utils'; import { TabContext, TabList, TabPanel } from '@mui/lab'; -import { Box, Container, CssBaseline, Divider, Tab, useMediaQuery } from '@mui/material'; +import { Box, Container, CssBaseline, Divider, Stack, Tab, useMediaQuery } from '@mui/material'; import { createTheme, ThemeProvider } from '@mui/material/styles'; +import { Allotment } from 'allotment'; import * as React from 'react'; import { useContext, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; +import { LAYOUT_MIN, LAYOUT_PROPORTIONS, LAYOUT_STYLES, STANDARD_MARGIN, STANDARD_SPACING } from '../constants.js'; +import { Motd } from '../Motd.js'; import { OnnxState, StateContext } from '../state/full.js'; +import { Layout } from '../state/settings.js'; import { ImageHistory } from './ImageHistory.js'; import { Logo } from './Logo.js'; import { Blend } from './tab/Blend.js'; @@ -19,11 +24,22 @@ import { Txt2Img } from './tab/Txt2Img.js'; import { Upscale } from './tab/Upscale.js'; import { getTab, getTheme, TAB_LABELS } from './utils.js'; -export function OnnxWeb() { +import 'allotment/dist/style.css'; +import './main.css'; + +export interface OnnxWebProps { + motd: boolean; +} + +export function OnnxWeb(props: OnnxWebProps) { /* checks for system light/dark mode preference */ const prefersDarkMode = useMediaQuery('(prefers-color-scheme: dark)'); const store = mustExist(useContext(StateContext)); const stateTheme = useStore(store, selectTheme); + const historyWidth = useStore(store, selectHistoryWidth); + const direction = useStore(store, selectDirection); + + const layout = LAYOUT_STYLES[direction]; const theme = useMemo( () => createTheme({ @@ -34,49 +50,15 @@ export function OnnxWeb() { [prefersDarkMode, stateTheme], ); - const [hash, setHash] = useHash(); - return ( - - + + - - - { - setHash(idx); - }}> - {TAB_LABELS.map((name) => )} - - - - - - - - - - - - - - - - - - - - - - - - - - - - + {props.motd && } + {renderBody(direction, historyWidth)} ); @@ -85,3 +67,99 @@ export function OnnxWeb() { export function selectTheme(state: OnnxState) { return state.theme; } + +export function selectDirection(state: OnnxState) { + return state.layout; +} + +export function selectHistoryWidth(state: OnnxState) { + return state.historyWidth; +} + +function renderBody(direction: Layout, historyWidth: number) { + if (direction === 'vertical') { + return ; + } else { + return ; + } +} + +// used for both horizontal and vertical +export interface BodyProps { + direction: Layout; + width: number; +} + +export function HorizontalBody(props: BodyProps) { + const layout = LAYOUT_STYLES[props.direction]; + + return + + + + + ; +} + +export function VerticalBody(props: BodyProps) { + const layout = LAYOUT_STYLES[props.direction]; + + return + + + + + + ; +} + +export interface TabGroupProps { + direction: Layout; + panelClass?: string; +} + +export function TabGroup(props: TabGroupProps) { + const layout = LAYOUT_STYLES[props.direction]; + + const [hash, setHash] = useHash(); + const { t } = useTranslation(); + + return + + + { + setHash(idx); + }}> + {TAB_LABELS.map((name) => )} + + + + + + + + + + + + + + + + + + + + + + + + + ; +} diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index f24104ac2..cb5822d20 100644 --- 
a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -18,13 +18,22 @@ import { defaultTo, isString } from 'lodash'; import * as React from 'react'; import { useContext, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import * as useDropModule from 'react-use/lib/useDrop'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; +import { STANDARD_SPACING } from '../constants.js'; import { OnnxState, StateContext } from '../state/full.js'; -import { ImageMetadata } from '../types/api.js'; +import { AnyImageMetadata } from '../types/api-v2.js'; import { DeepPartial } from '../types/model.js'; import { BaseImgParams, HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; +import { downloadAsJson } from '../utils.js'; + +// useDrop has a really weird export +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const useDrop = (useDropModule.default as any).default as typeof useDropModule['default']; + +export type PartialImageMetadata = DeepPartial; export const ALLOWED_EXTENSIONS = ['.json','.jpg','.jpeg','.png','.txt','.webp']; export const EXTENSION_FILTER = ALLOWED_EXTENSIONS.join(','); @@ -50,7 +59,36 @@ export function Profiles(props: ProfilesProps) { const [profileName, setProfileName] = useState(''); const { t } = useTranslation(); - return + async function loadFromMetadata(metadata: PartialImageMetadata) { + // TODO: load model parameters + + if (doesExist(metadata.params)) { + props.setParams(metadata.params); + } + + if (doesExist(metadata.highres)) { + props.setHighres(metadata.highres); + } + + if (doesExist(metadata.upscale)) { + props.setUpscale(metadata.upscale); + } + } + + async function loadFromFile(file: File) { + await loadParamsFromFile(file).then(loadFromMetadata); + } + + useDrop({ + onFiles(files, event) { + event.preventDefault(); + const file = files[0]; + // eslint-disable-next-line @typescript-eslint/no-floating-promises + loadFromFile(file); + }, + }); + + return { const state = store.getState(); saveProfile({ - model: props.selectModel(state), - params: props.selectParams(state), name: profileName, highres: props.selectHighres(state), + model: props.selectModel(state), + params: props.selectParams(state), upscale: props.selectUpscale(state), }); setDialogOpen(false); @@ -142,21 +180,7 @@ export function Profiles(props: ProfilesProps) { if (doesExist(files) && files.length > 0) { const file = mustExist(files[0]); // eslint-disable-next-line @typescript-eslint/no-floating-promises - loadParamsFromFile(file).then((newParams) => { - // TODO: load model parameters - - if (doesExist(newParams.params)) { - props.setParams(newParams.params); - } - - if (doesExist(newParams.highres)) { - props.setHighres(newParams.highres); - } - - if (doesExist(newParams.upscale)) { - props.setUpscale(newParams.upscale); - } - }); + loadFromFile(file); } }} onClick={(event) => { @@ -166,9 +190,9 @@ export function Profiles(props: ProfilesProps) { @@ -125,3 +97,68 @@ export function selectActions(state: OnnxState) { setReady: state.setReady, }; } + +export function selectStatus(data: Maybe>, defaultData: JobResponse) { + if (doesExist(data) && data.length > 0) { + return { + steps: data[0].steps, + stages: data[0].stages, + tiles: data[0].tiles, + }; + } + + return { + steps: defaultData.steps, + stages: defaultData.stages, + tiles: defaultData.tiles, + }; +} + +export function getPercent(current: number, total: number): number { + if (current > total) { + // steps was not 
complete, show 99% until done + return LOADING_OVERAGE; + } + + const pct = current / total; + return Math.ceil(pct * LOADING_PERCENT); +} + +export function getProgress(data: Maybe>) { + if (doesExist(data)) { + return data[0].steps.current; + } + + return 0; +} + +export function getTotal(data: Maybe>) { + if (doesExist(data)) { + return data[0].steps.total; + } + + return 0; +} + +function getStatus(data: Maybe>) { + if (doesExist(data)) { + return data[0].status; + } + + return JobStatus.PENDING; +} + +function getQueue(data: Maybe>) { + if (doesExist(data) && data[0].status === JobStatus.PENDING) { + const { current, total } = data[0].queue; + return { + current: visibleIndex(current), + total: total.toFixed(0), + }; + } + + return { + current: '0', + total: '0', + }; +} diff --git a/gui/src/components/control/ExperimentalControl.tsx b/gui/src/components/control/ExperimentalControl.tsx new file mode 100644 index 000000000..393dc8fc0 --- /dev/null +++ b/gui/src/components/control/ExperimentalControl.tsx @@ -0,0 +1,193 @@ +/* eslint-disable camelcase */ +import { mustDefault, mustExist } from '@apextoaster/js-utils'; +import { Accordion, AccordionDetails, AccordionSummary, Checkbox, FormControlLabel, Stack, TextField } from '@mui/material'; +import { useTranslation } from 'react-i18next'; +import * as React from 'react'; +import { useContext } from 'react'; +import { useQuery } from '@tanstack/react-query'; +import { useStore } from 'zustand'; + +import { STALE_TIME, STANDARD_SPACING } from '../../constants.js'; +import { NumericField } from '../input/NumericField.js'; +import { QueryList } from '../input/QueryList.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { ExperimentalParams } from '../../types/params.js'; + +export interface ExperimentalControlProps { + selectExperimental(state: OnnxState): ExperimentalParams; + setExperimental(params: Record): void; +} + +export function ExperimentalControl(props: ExperimentalControlProps) { + // eslint-disable-next-line @typescript-eslint/unbound-method + const { selectExperimental, setExperimental } = props; + + const store = mustExist(React.useContext(StateContext)); + const experimental = useStore(store, selectExperimental); + + const { params } = mustExist(useContext(ConfigContext)); + const { t } = useTranslation(); + + const client = mustExist(React.useContext(ClientContext)); + const filters = useQuery(['filters'], async () => client.filters(), { + staleTime: STALE_TIME, + }); + + return + {t('experimental.label')} + + + + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + enabled: experimental.promptEditing.enabled === false, + }, + }); + }} + />} + /> + f.prompt, + }} + value={mustDefault(experimental.promptEditing.filter, '')} + onChange={(prompt_filter) => { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + filter: prompt_filter, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + removeTokens: event.target.value, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + addSuffix: event.target.value, + }, + }); + }} + /> + { + setExperimental({ + promptEditing: { + ...experimental.promptEditing, + minLength: prompt_editing_min_length, + }, + }); + }} + /> + + + { + setExperimental({ + latentSymmetry: { + ...experimental.latentSymmetry, + enabled: experimental.latentSymmetry.enabled === false, + }, + }); + }} + />} + /> + { + 
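// Editor's note: a quick worked example of getPercent() above, assuming LOADING_PERCENT is 100
// and LOADING_OVERAGE is 99 (both constants are defined outside this hunk; the 99 matches the
// "show 99% until done" comment).
getPercent(32, 50);  // ceil(32 / 50 * 100) = 64
getPercent(50, 50);  // 100
getPercent(55, 50);  // current > total while the job is still running, so the bar holds at 99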
setExperimental({ + latentSymmetry: { + ...experimental.latentSymmetry, + gradientStart: latent_symmetry_gradient_start, + }, + }); + }} + /> + { + setExperimental({ + latentSymmetry: { + ...experimental.latentSymmetry, + gradientEnd: latent_symmetry_gradient_end, + }, + }); + }} + /> + { + setExperimental({ + latentSymmetry: { + ...experimental.latentSymmetry, + lineOfSymmetry: latent_symmetry_line_of_symmetry, + }, + }); + }} + /> + + + + ; +} diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index c0faa7b56..4d55cf4c8 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -10,7 +10,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { STALE_TIME } from '../../config.js'; +import { STALE_TIME, STANDARD_SPACING } from '../../constants.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { BaseImgParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; @@ -46,7 +46,7 @@ export function ImageControl(props: ImageControlProps) { staleTime: STALE_TIME, }); - return + return ): void; + tab: JobType; } export function ModelControl(props: ModelControlProps) { // eslint-disable-next-line @typescript-eslint/unbound-method - const { model, setModel } = props; + const { model, setModel, tab } = props; const client = mustExist(useContext(ClientContext)); const { t } = useTranslation(); @@ -33,7 +35,7 @@ export function ModelControl(props: ModelControlProps) { staleTime: STALE_TIME, }); - return + return filterValidPipelines(result, tab), }} value={model.pipeline} onChange={(pipeline) => { @@ -113,3 +116,46 @@ export function ModelControl(props: ModelControlProps) { >{t('admin.restart')} ; } + +// plugin pipelines will show up on all tabs for now +export const PIPELINE_TABS: Record> = { + 'txt2img': [JobType.TXT2IMG], + 'txt2img-sdxl': [JobType.TXT2IMG], + 'panorama': [JobType.TXT2IMG, JobType.IMG2IMG], + 'panorama-sdxl': [JobType.TXT2IMG, JobType.IMG2IMG], + 'lpw': [JobType.TXT2IMG, JobType.IMG2IMG, JobType.INPAINT], + 'img2img': [JobType.IMG2IMG], + 'img2img-sdxl': [JobType.IMG2IMG], + 'controlnet': [JobType.IMG2IMG], + 'pix2pix': [JobType.IMG2IMG], + 'inpaint': [JobType.INPAINT], + 'upscale': [JobType.UPSCALE], +}; + +export const DEFAULT_PIPELINE: Record = { + [JobType.TXT2IMG]: 'txt2img', + [JobType.IMG2IMG]: 'img2img', + [JobType.INPAINT]: 'inpaint', + [JobType.UPSCALE]: 'upscale', + [JobType.BLEND]: '', + [JobType.CHAIN]: '', +}; + +export const FIRST_A = -1; +export const FIRST_B = +1; + +export function filterValidPipelines(pipelines: Array, tab: JobType): Array { + const defaultPipeline = DEFAULT_PIPELINE[tab]; + return pipelines.filter((pipeline) => PIPELINE_TABS[pipeline].includes(tab)).sort((a, b) => { + // put validPipelines.default first + if (a === defaultPipeline) { + return FIRST_A; + } + + if (b === defaultPipeline) { + return FIRST_B; + } + + return a.localeCompare(b); + }); +} diff --git a/gui/src/components/control/VariableControl.tsx b/gui/src/components/control/VariableControl.tsx index bf5139e18..f10c4d44e 100644 --- a/gui/src/components/control/VariableControl.tsx +++ b/gui/src/components/control/VariableControl.tsx @@ -7,6 +7,7 @@ import { useStore } from 'zustand'; import { PipelineGrid } from '../../client/utils.js'; import { OnnxState, StateContext } from '../../state/full.js'; import { 
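// Editor's note: a hedged example of the pipeline filtering above. Only pipelines mapped to
// the current tab in PIPELINE_TABS are offered, with that tab's DEFAULT_PIPELINE sorted to the
// front and the rest alphabetical; filterValidPipelines and JobType are the names used in the
// ModelControl hunk.
filterValidPipelines(['txt2img', 'panorama', 'controlnet', 'img2img', 'lpw'], JobType.IMG2IMG);
// -> ['img2img', 'controlnet', 'lpw', 'panorama']
//    ('txt2img' is dropped for this tab; 'img2img' leads as the tab default, others sort A-Z)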
VARIABLE_PARAMETERS } from '../../types/chain.js'; +import { STANDARD_SPACING } from '../../constants.js'; export interface VariableControlProps { selectGrid: (state: OnnxState) => PipelineGrid; @@ -20,7 +21,7 @@ export function VariableControl(props: VariableControlProps) { const grid = useStore(store, props.selectGrid); const stack = [ - + + Columns props.setGrid({ @@ -78,7 +79,7 @@ export function VariableControl(props: VariableControlProps) { ); } - return {...stack}; + return {...stack}; } export function parameterList(exclude?: Array) { diff --git a/gui/src/components/error/ParamsVersion.tsx b/gui/src/components/error/ParamsVersion.tsx index bb141af08..bc5df7eb1 100644 --- a/gui/src/components/error/ParamsVersion.tsx +++ b/gui/src/components/error/ParamsVersion.tsx @@ -1,7 +1,7 @@ import { Alert, AlertTitle, Typography } from '@mui/material'; import * as React from 'react'; -import { PARAM_VERSION } from '../../config.js'; +import { PARAM_VERSION } from '../../constants'; export interface ParamsVersionErrorProps { root: string; diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index 8c0d50263..761dbfee1 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -5,6 +5,7 @@ import { memo, useContext, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; +import { STANDARD_SPACING } from '../../constants.js'; import { OnnxState, StateContext } from '../../state/full.js'; export interface EditableListProps { @@ -30,7 +31,7 @@ export function EditableList(props: EditableListProps) { const [nextSource, setNextSource] = useState(''); const RenderMemo = useMemo(() => memo(renderItem), [renderItem]); - return + return {items.map((model, idx) => (props: EditableListProps) { onRemove={removeItem} /> )} - + ; @@ -35,7 +37,7 @@ export function ImageInput(props: ImageInputProps) { } } - return + return