diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cda61f7d..1c9d66b9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -45,8 +45,6 @@ jobs: uses: actions/setup-python@v4 with: python-version: "3.10" - cache: pip - cache-dependency-path: requirements-dev.txt - name: Install dependencies run: | python -m pip install uv @@ -108,6 +106,8 @@ jobs: build-wheel: name: Build Wheel runs-on: ubuntu-latest + outputs: + wheel_name: ${{ steps.set_wheel_name.outputs.wheel_name }} steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4.5.0 @@ -117,8 +117,9 @@ jobs: run: python -m pip install wheel - name: Build package run: python setup.py bdist_wheel - - name: Rename wheel - run: mv dist/*.whl dist/imaginAIry-0.0.1b0-py3-none-any.whl + - name: Set wheel filename + id: set_wheel_name + run: echo "wheel_name=$(ls dist/*.whl)" >> "$GITHUB_OUTPUT" - uses: actions/upload-artifact@v3 with: name: wheels @@ -142,11 +143,14 @@ jobs: name: wheels path: dist - name: Install built wheel + env: + WHEEL_FILENAME: ${{ needs.build-wheel.outputs.wheel_name }} run: | - python -m pip install dist/imaginAIry-0.0.1b0-py3-none-any.whl + python -m pip install uv + uv pip install --system ${{ needs.build-wheel.outputs.wheel_name }} - name: Generate an image run: | - imagine fruit --steps 10 --size 512 --seed 1 + imagine fruit --steps 3 --size 512 --seed 1 - uses: actions/upload-artifact@v3 with: name: images @@ -174,14 +178,17 @@ jobs: path: dist - name: Install built wheel shell: bash -l {0} + env: + WHEEL_FILENAME: ${{ steps.set-wheel-name.outputs.WHEEL_FILENAME }} run: | conda activate test-env - python -m pip install dist/imaginAIry-0.0.1b0-py3-none-any.whl + python -m pip install uv + uv pip install ${{ needs.build-wheel.outputs.wheel_name }} - name: Generate an image shell: bash -l {0} run: | conda activate test-env - imagine fruit --steps 10 --size 512 --seed 1 + imagine fruit --steps 3 --size 512 --seed 1 - uses: actions/upload-artifact@v3 with: name: images diff --git a/README.md b/README.md index 7c200866..cc5906d4 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ Options: ### Whats New [See full Changelog here](./docs/changelog.md) +**14.2.1** +- feature: integrates spandrel for upscaling. **14.2.0** - ๐ŸŽ‰ feature: add image prompt support via `--image-prompt` and `--image-prompt-strength` @@ -352,14 +354,28 @@ When writing strength modifiers keep in mind that pixel values are between 0 and -### Upscaling [by RealESRGAN](https://github.com/xinntao/Real-ESRGAN) -```bash ->> imagine "colorful smoke" --steps 40 --upscale -# upscale an existing image ->> aimg upscale my-image.jpg -``` -
-Python Example +## Image Upscaling +Upscale images easily. + +=== "CLI" + ```bash + aimg upscale assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg --upscale-model real-hat + ``` + +=== "Python" + ```py + from imaginairy.api.upscale import upscale + + img = upscale(img="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg") + img.save("colorful_smoke.upscaled.jpg") + + ``` + โžก๏ธ + + +Upscaling uses [Spandrel](https://github.com/chaiNNer-org/spandrel) to make it easy to use different upscaling models. +You can view different integrated models by running `aimg upscale --list-models`, and then use it with `--upscale-model `. +Also accepts url's if you want to upscale an image with a different model. Control the new file format/location with --format. ```python from imaginairy.enhancers.upscale_realesrgan import upscale_image @@ -368,9 +384,7 @@ img = Image.open("my-image.jpg") big_img = upscale_image(i) ``` -
- โžก๏ธ - + ### Tiled Images ```bash diff --git a/docs/changelog.md b/docs/changelog.md index aa7bfe46..6003b50d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,8 @@ ## ChangeLog +**14.2.1** + +- feature: integrates spandrel for upscaling **14.2.0** - ๐ŸŽ‰ feature: add image prompt support via `--image-prompt` and `--image-prompt-strength` @@ -9,6 +12,7 @@ - fix: dependency issues **14.1.0** + - ๐ŸŽ‰ feature: make video generation smooth by adding frame interpolation - feature: SDXL weights in the compvis format can now be used - feature: allow video generation at any size specified by user @@ -21,15 +25,18 @@ **14.0.4** + - docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/ - build: remove fairscale dependency - fix: video generation was broken **14.0.3** + - fix: several critical bugs with package - tests: add a wheel smoketest to detect these issues in the future **14.0.0** + - ๐ŸŽ‰ video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) - add `--videogen` to any image generation to create a short video from the generated image - or use `aimg videogen` to generate a video from an image @@ -63,15 +70,18 @@ For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the - broken: samplers other than ddim **13.2.1** + - fix: pydantic models for http server working now. Fixes #380 - fix: install triton so annoying message is gone **13.2.0** + - fix: allow tile_mode to be set to True or False for backward compatibility - fix: various pydantic issues have been resolved - feature: switch to pydantic 2.3 (faster but was a pain to migrate) **13.1.0** + - feature: *api server now has feature parity with the python API*. View the docs at http://127.0.0.1:8000/docs after running `aimg server` - `ImaginePrompt` is now a pydantic model and can thus be sent over the rest API - images are expected in base64 string format @@ -81,11 +91,13 @@ For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the - docs: add a discord link **13.0.1** + - feature: show full stack trace when there is an api error - fix: make lack of support for python 3.11 explicit - fix: add some routes to match StableStudio routes **13.0.0** + - ๐ŸŽ‰ feature: multi-controlnet support. pass in multiple `--control-mode`, `--control-image`, and `--control-image-raw` arguments. - ๐ŸŽ‰ feature: add colorization controlnet. improve `aimg colorize` command - ๐ŸŽ‰๐Ÿงช feature: Graphical Web Interface [StableStudio](https://github.com/Stability-AI/StableStudio). run `aimg server` and visit http://127.0.0.1:8000/ @@ -103,12 +115,15 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - build: check for torch version at runtime (fixes #329) **12.0.3** + - fix: exclude broken versions of timm as dependencies **12.0.2** + - fix: move normal map preprocessor for conda compatibility **12.0.1** + - fix: use correct device for depth images on mps. Fixes #300 **12.0.0** @@ -123,6 +138,7 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - fix: filenames start numbers after latest image, even if some previous images were deleted **11.1.1** + - fix: fix globbing bug with input image path handling - fix: changed sample to True to generate caption using blip model @@ -135,6 +151,7 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - fix: fix model downloads that were broken by [library change in transformers 4.27.0](https://github.com/huggingface/transformers/commit/8f3b4a1d5bd97045541c43179efe8cd9c58adb76) **11.0.0** + - all these changes together mean same seed/sampler will not be guaranteed to produce same image (thus the version bump) - fix: image composition didn't work very well. Works well now but probably very slow on non-cuda platforms - fix: remove upscaler tiling message @@ -142,6 +159,7 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - fix: img2img was broken for all samplers except plms and ddim when init image strength was >~0.25 **10.2.0** + - feature: input raw control images (a pose, canny map, depth map, etc) directly using `--control-image-raw` This is opposed to current behavior of extracting the control signal from an input image via `--control-image` - feature: `aimg model-list` command lists included models @@ -153,11 +171,13 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - docs: pypi docs now link properly to github automatically **10.1.0** + - feature: ๐ŸŽ‰ ControlNet integration! Control the structure of generated images. - feature: `aimg colorize` attempts to use controlnet to colorize images - feature: `--caption-text` command adds text at the bottom left of an image **10.0.1** + - fix: `edit` was broken **10.0.0** @@ -171,9 +191,11 @@ MacOS M1, [torch will not be able to use the M1 when generating images.](https:/ - perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU. **9.0.2** + - fix: edit interface was broken **9.0.1** + - fix: use entry_points for windows since setup.py scripts doesn't work on windows [#239](https://github.com/brycedrennan/imaginAIry/issues/239) **9.0.0** @@ -187,9 +209,11 @@ batch editing of images as requested in [#229](https://github.com/brycedrennan/i - docs: add directions on how to change model cache path **8.3.1** + - fix: init-image-strength type **8.3.0** + - feature: create `gifs` or `mp4s` from any images made in a single run with `--compilation-anim gif` - feature: create a series of images or edits by iterating over a parameter with the `--arg-schedule` argument - feature: `openjourney-v1` and `openjourney-v2` models added. available via `--model openjourney-v2` @@ -199,9 +223,11 @@ batch editing of images as requested in [#229](https://github.com/brycedrennan/i - fix: tile mode was broken since latest perf improvements **8.2.0** + - feature: added `aimg system-info` command to help debug issues **8.1.0** + - feature: some memory optimizations and documentation - feature: surprise-me improvements - feature: image sizes can now be multiples of 8 instead of 64. Inputs will be silently rounded down. @@ -211,44 +237,55 @@ batch editing of images as requested in [#229](https://github.com/brycedrennan/i - fix: make captioning work with alpha pngs **8.0.5** + - fix: bypass huggingface cache retrieval bug **8.0.4** + - fix: limit attention slice size on MacOS machines with 64gb (#175) **8.0.3** + - fix: use python 3.7 compatible lru_cache - fix: use windows compatible filenames **8.0.2** + - fix: hf_hub_download() got an unexpected keyword argument 'token' **8.0.1** + - fix: spelling mistake of "surprise" **8.0.0** + - feature: ๐ŸŽ‰ edit images with instructions alone! - feature: when editing an image add `--gif` to create a comparision gif - feature: `aimg edit --surprise-me --gif my-image.jpg` for some fun pre-programmed edits - feature: prune-ckpt command also removes the non-ema weights **7.6.0** + - fix: default model config was broken - feature: print version with `--version` - feature: ability to load safetensors - feature: ๐ŸŽ‰ outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400` **7.4.3** + - fix: handle old pytorch lightning imports with a graceful failure (fixes #161) - fix: handle failed image generations better (fixes #83) **7.4.2** + - fix: run face enhancement on GPU for 10x speedup **7.4.1** + - fix: incorrect config files being used for non-1.0 models **7.4.0** + - feature: ๐ŸŽ‰ finetune your own image model. kind of like dreambooth. Read instructions on ["Concept Training"](docs/concept-training.md) page - feature: image prep command. crops to face or other interesting parts of photo - fix: back-compat for hf_hub_download @@ -256,33 +293,41 @@ batch editing of images as requested in [#229](https://github.com/brycedrennan/i - feature: allow specification of model config file **7.3.0** + - feature: ๐ŸŽ‰ depth-based image-to-image generations (and inpainting) - fix: k_euler_a produces more consistent images per seed (randomization respects the seed again) **7.2.0** + - feature: ๐ŸŽ‰ tile in a single dimension ("x" or "y"). This enables, with a bit of luck, generation of 360 VR images. Try this for example: `imagine --tile-x -w 1024 -h 512 "360 degree equirectangular panorama photograph of the mountains" --upscale` **7.1.1** + - fix: memory/speed regression introduced in 6.1.0 - fix: model switching now clears memory better, thus avoiding out of memory errors **7.1.0** + - feature: ๐ŸŽ‰ Stable Diffusion 2.1. Generated people are no longer (completely) distorted. Use with `--model SD-2.1` or `--model SD-2.0-v` **7.0.0** + - feature: negative prompting. `--negative-prompt` or `ImaginePrompt(..., negative_prompt="ugly, deformed, extra arms, etc")` - feature: a default negative prompt is added to all generations. Images in SD-2.0 don't look bad anymore. Images in 1.5 look improved as well. **6.1.2** + - fix: add back in memory-efficient algorithms **6.1.1** + - feature: xformers will be used if available (for faster generation) - fix: version metadata was broken **6.1.0** + - feature: use different default steps and image sizes depending on sampler and model selected - fix: #110 use proper version in image metadata - refactor: solvers all have their own class that inherits from ImageSolver @@ -294,13 +339,16 @@ Use with `--model SD-2.1` or `--model SD-2.0-v` - 768x768 model working for all samplers except PLMS (`--model SD-2.0-v `) **5.1.0** + - feature: add progress image callback **5.0.1** + - fix: support larger images on M1. Fixes #8 - fix: support CPU generation by disabling autocast on CPU. Fixes #81 **5.0.0** + - feature: ๐ŸŽ‰ inpainting support using new inpainting model from RunwayML. It works really well! By default, the inpainting model will automatically be used for any image-masking task - feature: ๐ŸŽ‰ new default sampler makes image generation more than twice as fast @@ -312,10 +360,12 @@ inpainting model will automatically be used for any image-masking task - fix: larger image sizes now work on macOS. fixes #8 **4.1.0** + - feature: allow dynamic switching between models/weights `--model SD-1.5` or `--model SD-1.4` or `--model path/my-custom-weights.ckpt`) - feature: log total progress when generating images (image X out of Y) **4.0.0** + - feature: stable diffusion 1.5 (slightly improved image quality) - feature: dilation and erosion of masks Previously the `+` and `-` characters in a mask (example: `face{+0.1}`) added to the grayscale value of any masked areas. This wasn't very useful. The new behavior is that the mask will expand or contract by the number of pixel specified. The technical terms for this are dilation and erosion. This allows much greater control over the masked area. @@ -325,39 +375,49 @@ inpainting model will automatically be used for any image-masking task - ci: minor logging improvements **3.0.1** + - fix: k-samplers were broken **3.0.0** + - feature: improved safety filter **2.4.0** + - ๐ŸŽ‰ feature: prompt expansion - feature: make (blip) photo captions more descriptive **2.3.1** + - fix: face fidelity default was broken **2.3.0** + - feature: model weights file can be specified via `--model-weights-path` argument at the command line - fix: set face fidelity default back to old value - fix: handle small images without throwing exception. credit to @NiclasEriksen - docs: add setuptools-rust as dependency for macos **2.2.1** + - fix: init image is fully ignored if init-image-strength = 0 **2.2.0** + - feature: face enhancement fidelity is now configurable **2.1.0** + - [improved masking accuracy from clipseg](https://github.com/timojl/clipseg/issues/8#issuecomment-1259150865) **2.0.3** + - fix memory leak in face enhancer - fix blurry inpainting - fix for pillow compatibility **2.0.0** + - ๐ŸŽ‰ fix: inpainted areas correlate with surrounding image, even at 100% generation strength. Previously if the generation strength was high enough the generated image would be uncorrelated to the rest of the surrounding image. It created terrible looking images. - ๐ŸŽ‰ feature: interactive prompt added. access by running `aimg` @@ -370,35 +430,44 @@ would be uncorrelated to the rest of the surrounding image. It created terrible - fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1 **1.6.2** + - fix: another bfloat16 fix **1.6.1** + - fix: make sure image tensors come to the CPU as float32 so there aren't compatibility issues with non-bfloat16 cpus **1.6.0** + - fix: *maybe* address #13 with `expected scalar type BFloat16 but found Float` - at minimum one can specify `--precision full` now and that will probably fix the issue - feature: tile mode can now be specified per-prompt **1.5.3** + - fix: missing config file for describe feature **1.5.1** + - img2img now supported with PLMS (instead of just DDIM) - added image captioning feature `aimg describe dog.jpg` => `a brown dog sitting on grass` - added new commandline tool `aimg` for additional image manipulation functionality **1.4.0** + - support multiple additive targets for masking with `|` symbol. Example: "fruit|stem|fruit stem" **1.3.0** + - added prompt based image editing. Example: "fruit => gold coins" - test coverage improved **1.2.0** + - allow urls as init-images **previous** + - img2img actually does # of steps you specify - performance optimizations - numerous other changes diff --git a/docs/docs/Python/upscale.md b/docs/docs/Python/upscale.md new file mode 100644 index 00000000..7e0082a5 --- /dev/null +++ b/docs/docs/Python/upscale.md @@ -0,0 +1 @@ +::: imaginairy.api.upscale.upscale diff --git a/docs/index.md b/docs/index.md index 9c96f441..edc2fe76 100644 --- a/docs/index.md +++ b/docs/index.md @@ -314,6 +314,29 @@ allow the tool to generate one for you.

+## Image Upscaling +Upscale images easily. + +=== "CLI" + ```bash + aimg upscale assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg --upscale-model real-hat + ``` + +=== "Python" + ```py + from imaginairy.api.upscale import upscale + + img = upscale(img="assets/000206_856637805_PLMS40_PS7.5_colorful_smoke.jpg") + img.save("colorful_smoke.upscaled.jpg") + + ``` + โžก๏ธ + + +Upscaling uses [Spandrel](https://github.com/chaiNNer-org/spandrel) to make it easy to use different upscaling models. +You can view different integrated models by running `aimg upscale --list-models`, and then use it with `--upscale-model `. +Also accepts url's if you want to upscale an image with a different model. Control the new file format/location with --format. + ## Video Generation @@ -329,4 +352,6 @@ allow the tool to generate one for you. generate_video(input_path="assets/rocket-wide.png") ``` - \ No newline at end of file + + + diff --git a/imaginairy/api/generate_compvis.py b/imaginairy/api/generate_compvis.py index 7e1106a1..50429860 100644 --- a/imaginairy/api/generate_compvis.py +++ b/imaginairy/api/generate_compvis.py @@ -30,7 +30,7 @@ def _generate_single_image( from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.face_restoration_codeformer import enhance_faces - from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.enhancers.upscale import upscale_image from imaginairy.modules.midas.api import torch_image_to_depth_map from imaginairy.samplers import SOLVER_LOOKUP from imaginairy.samplers.editing import CFGEditingDenoiser @@ -534,7 +534,7 @@ def _generate_composition_image( result = _generate_single_image(composition_prompt, dtype=dtype) img = result.images["generated"] while img.width < target_width: - from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.enhancers.upscale import upscale_image img = upscale_image(img) diff --git a/imaginairy/api/generate_refiners.py b/imaginairy/api/generate_refiners.py index 5f0ed66f..6203101c 100644 --- a/imaginairy/api/generate_refiners.py +++ b/imaginairy/api/generate_refiners.py @@ -35,7 +35,7 @@ def generate_single_image( from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.face_restoration_codeformer import enhance_faces - from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.enhancers.upscale import upscale_image from imaginairy.samplers import SolverName from imaginairy.schema import ImagineResult from imaginairy.utils import get_device, randn_seeded @@ -603,7 +603,7 @@ def _generate_composition_image( ) img = result.images["generated"] while img.width < target_width: - from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.enhancers.upscale import upscale_image if prompt.fix_faces: from imaginairy.enhancers.face_restoration_codeformer import enhance_faces @@ -612,7 +612,7 @@ def _generate_composition_image( logger.info("Fixing ๐Ÿ˜Š 's in ๐Ÿ–ผ using CodeFormer...") img = enhance_faces(img, fidelity=prompt.fix_faces_fidelity) with logging_context.timing("upscaling"): - img = upscale_image(img, ultrasharp=True) + img = upscale_image(img, upscaler_model="ultrasharp") img = img.resize( (target_width, target_height), diff --git a/imaginairy/api/upscale.py b/imaginairy/api/upscale.py new file mode 100644 index 00000000..9f7d5fe4 --- /dev/null +++ b/imaginairy/api/upscale.py @@ -0,0 +1,57 @@ +from typing import TYPE_CHECKING, Union + +from imaginairy.config import DEFAULT_UPSCALE_MODEL + +if TYPE_CHECKING: + from PIL import Image + + from imaginairy.schema import LazyLoadingImage + + +def upscale_image( + img: "Union[LazyLoadingImage, Image.Image, str]", + upscale_model: str = DEFAULT_UPSCALE_MODEL, + tile_size: int = 512, + tile_pad: int = 50, + repetition: int = 1, + device=None, +) -> "Image.Image": + """ + Upscales an image using a specified super-resolution model. + + It accepts an image in various forms: a LazyLoadingImage instance, a PIL Image, + or a string representing a URL or file path. Supports different upscaling models, customizable tile size, padding, + and the number of repetitions for upscaling. It can use tiles to manage memory usage on large images and supports multiple passes for upscaling. + + Args: + img (LazyLoadingImage | Image.Image | str): The input image. + upscale_model (str, optional): Upscaling model to use. Defaults to realesrgan-x2-plus + tile_size (int, optional): Size of the tiles used for processing the image. Defaults to 512. + tile_pad (int, optional): Padding size for each tile. Defaults to 50. + repetition (int, optional): Number of times the upscaling is repeated. Defaults to 1. + device: The device (CPU/GPU) to be used for computation. If None, the best available device is used. + + Returns: + Image.Image: The upscaled image as a PIL Image object. + """ + from PIL import Image + + from imaginairy.enhancers.upscale import upscale_image + from imaginairy.schema import LazyLoadingImage + + if isinstance(img, str): + if img.startswith("https://"): + img = LazyLoadingImage(url=img) + else: + img = LazyLoadingImage(filepath=img) + elif isinstance(img, Image.Image): + img = LazyLoadingImage(img=img) + + return upscale_image( + img, + upscale_model, + tile_size=tile_size, + tile_pad=tile_pad, + repetition=repetition, + device=device, + ) diff --git a/imaginairy/cli/upscale.py b/imaginairy/cli/upscale.py index 5894e363..95ce2e76 100644 --- a/imaginairy/cli/upscale.py +++ b/imaginairy/cli/upscale.py @@ -1,19 +1,25 @@ """Command for upscaling images with AI""" import logging +import os.path +from datetime import datetime, timezone import click +from imaginairy.config import DEFAULT_UPSCALE_MODEL + logger = logging.getLogger(__name__) +DEFAULT_FORMAT_TEMPLATE = "{original_filename}.upscaled{file_extension}" + -@click.argument("image_filepaths", nargs=-1) +@click.argument("image_filepaths", nargs=-1, required=False) @click.option( "--outdir", default="./outputs/upscaled", show_default=True, type=click.Path(), - help="Where to write results to.", + help="Where to write results to. Default will be where the directory of the original file.", ) @click.option("--fix-faces", is_flag=True) @click.option( @@ -22,34 +28,109 @@ type=float, help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.", ) +@click.option( + "--upscale-model", + multiple=True, + type=str, + default=[DEFAULT_UPSCALE_MODEL], + show_default=True, + help="Specify one or more upscale models to use.", +) +@click.option("--list-models", is_flag=True, help="View available upscale models.") +@click.option( + "--format", + "format_template", + default="{original_filename}.upscaled{file_extension}", + type=str, + help="Formats the file name. Default value will save '{original_filename}.upscaled{file_extension}' to the original directory." + " {original_filename}: original name without the extension;" + "{file_sequence_number:pad}: sequence number in directory, can make zero-padded (e.g., 06 for six digits).;" + " {algorithm}: upscaling algorithm; " + "{now:%Y-%m-%d:%H-%M-%S}: current date and time, customizable using standard strftime format codes. " + "Use 'DEV' to config to save in standard imaginAIry format '{file_sequence_number:06}_{algorithm}_{original_filename}.upscaled{file_extension}'. ", +) @click.command("upscale") -def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity): +def upscale_cmd( + image_filepaths, + outdir, + fix_faces, + fix_faces_fidelity, + upscale_model, + list_models, + format_template, +): """ Upscale an image 4x using AI. """ - import os.path - - from tqdm import tqdm from imaginairy.enhancers.face_restoration_codeformer import enhance_faces - from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.enhancers.upscale import upscale_image, upscale_model_lookup from imaginairy.schema import LazyLoadingImage from imaginairy.utils import glob_expand_paths + from imaginairy.utils.format_file_name import format_filename, get_url_file_name + from imaginairy.utils.log_utils import configure_logging + + configure_logging() + + if list_models: + for model_name in upscale_model_lookup: + click.echo(f"{model_name}") + return os.makedirs(outdir, exist_ok=True) image_filepaths = glob_expand_paths(image_filepaths) - for p in tqdm(image_filepaths): - savepath = os.path.join(outdir, os.path.basename(p)) + + if not image_filepaths: + click.echo( + "Error: No valid image file paths found. Please check the provided file paths." + ) + return + + if format_template == "DEV": + format_template = "{file_sequence_number:06}_{algorithm}_{original_filename}.upscaled{file_extension}" + elif format_template == "DEFAULT": + format_template = DEFAULT_FORMAT_TEMPLATE + + for n, p in enumerate(image_filepaths): if p.startswith("http"): img = LazyLoadingImage(url=p) else: img = LazyLoadingImage(filepath=p) - logger.info( - f"Upscaling {p} from {img.width}x{img.height} to {img.width * 4}x{img.height * 4} and saving it to {savepath}" - ) + orig_height = img.height + for model in upscale_model: + logger.info( + f"Upscaling ({n + 1}/{len(image_filepaths)}) {p} ({img.width}x{img.height})..." + ) + + img = upscale_image(img, model) + if fix_faces: + img = enhance_faces(img, fidelity=fix_faces_fidelity) + + if format_template == DEFAULT_FORMAT_TEMPLATE: + outdir = os.path.dirname(p) + "/" + + file_base_name, extension = os.path.splitext(os.path.basename(p)) + base_count = len(os.listdir(outdir)) + + now = datetime.now(timezone.utc) - img = upscale_image(img) - if fix_faces: - img = enhance_faces(img, fidelity=fix_faces_fidelity) + if model.startswith(("https://", "http://")): + model_name = get_url_file_name(model) + else: + model_name = model - img.save(os.path.join(outdir, os.path.basename(p))) + new_file_name_data = { + "original_filename": file_base_name, + "output_path": outdir, + "file_sequence_number": base_count, + "algorithm": model_name, + "now": now, + "file_extension": extension, + } + new_file_name = format_filename(format_template, new_file_name_data) + new_file_path = os.path.join(outdir, new_file_name) + img.save(new_file_path) + scale = int(img.height / orig_height) + logger.info( + f"Upscaled {scale}x to {img.width}x{img.height} and saved to {new_file_path}" + ) diff --git a/imaginairy/config.py b/imaginairy/config.py index 75af3027..9cce76ac 100644 --- a/imaginairy/config.py +++ b/imaginairy/config.py @@ -5,6 +5,7 @@ DEFAULT_MODEL_WEIGHTS = "sd15" DEFAULT_SOLVER = "ddim" +DEFAULT_UPSCALE_MODEL = "realesrgan-x2-plus" DEFAULT_NEGATIVE_PROMPT = ( "Ugly, duplication, duplicates, mutilation, deformed, mutilated, mutation, twisted body, disfigured, bad anatomy, " diff --git a/imaginairy/enhancers/upscale.py b/imaginairy/enhancers/upscale.py new file mode 100644 index 00000000..ca7ba279 --- /dev/null +++ b/imaginairy/enhancers/upscale.py @@ -0,0 +1,113 @@ +import logging +from typing import TYPE_CHECKING, Union + +from imaginairy.config import DEFAULT_UPSCALE_MODEL +from imaginairy.utils import get_device + +if TYPE_CHECKING: + from PIL import Image + + from imaginairy.schema import LazyLoadingImage + + +upscale_model_lookup = { + # RealESRGAN + "ultrasharp": "https://huggingface.co/lokCX/4x-Ultrasharp/resolve/1856559b50de25116a7c07261177dd128f1f5664/4x-UltraSharp.pth", + "realesrgan-x4-plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "realesrgan-x2-plus": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + # ESRGAN + "esrgan-x4": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + # HAT + "real-hat": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/Real_HAT_GAN_SRx4.safetensors", + "real-hat-sharper": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/Real_HAT_GAN_sharper.safetensors", + "4xNomos8kHAT-L": "https://huggingface.co/imaginairy/model-weights/resolve/main/weights/super-resolution/hat/4xNomos8kHAT-L_otf.safetensors", +} +logger = logging.getLogger(__name__) + + +def upscale_image( + img: "Union[LazyLoadingImage, Image.Image]", + upscaler_model: str = DEFAULT_UPSCALE_MODEL, + tile_size: int = 512, + tile_pad: int = 50, + repetition: int = 1, + device=None, +) -> "Image.Image": + """ + Upscales an image using a specified super-resolution model. + + Supports various upscaling models defined in the `upscale_model_lookup` dictionary, as well as direct URLs to models. + It can process the image in tiles (to manage memory usage on large images) and supports multiple passes for upscaling. + + Args: + img (LazyLoadingImage | Image.Image): The input image to be upscaled. + upscaler_model (str, optional): Key for the upscaling model to use. Defaults to DEFAULT_UPSCALE_MODEL. + tile_size (int, optional): Size of the tiles used for processing the image. Defaults to 512. + tile_pad (int, optional): Padding size for each tile. Defaults to 50. + repetition (int, optional): Number of times the upscaling is repeated. Defaults to 1. + device: The device (CPU/GPU) to be used for computation. If None, the best available device is used. + + Returns: + Image.Image: The upscaled image as a PIL Image object. + """ + import torch + import torchvision.transforms.functional as F + from spandrel import ImageModelDescriptor, ModelLoader + + from imaginairy.utils.downloads import get_cached_url_path + from imaginairy.utils.tile_up import tile_process + + device = device or get_device() + + if upscaler_model in upscale_model_lookup: + model_url = upscale_model_lookup[upscaler_model] + model_path = get_cached_url_path(model_url) + elif upscaler_model.startswith(("https://", "http://")): + model_url = upscaler_model + model_path = get_cached_url_path(model_url) + else: + model_path = upscaler_model + + model = ModelLoader().load_from_file(model_path) + logger.debug(f"Upscaling image with model {model.architecture}@{upscaler_model}") + + assert isinstance(model, ImageModelDescriptor) + + model.to(torch.device(device)).eval() + + image_tensor = load_image(img).to(device) + with torch.no_grad(): + for _ in range(repetition): + if tile_size > 0: + image_tensor = tile_process( + image_tensor, + scale=model.scale, + model=model, + tile_size=tile_size, + tile_pad=tile_pad, + ) + else: + image_tensor = model(image_tensor) + + image_tensor = image_tensor.squeeze(0) + image = F.to_pil_image(image_tensor) + image = image.resize((img.width * model.scale, img.height * model.scale)) + + return image + + +def load_image(img: "Union[LazyLoadingImage, Image.Image]"): + """ + Converts a LazyLoadingImage or PIL Image into a PyTorch tensor. + """ + from torchvision import transforms + + from imaginairy.schema import LazyLoadingImage + + if isinstance(img, LazyLoadingImage): + img = img.as_pillow() + transform = transforms.ToTensor() + image_tensor = transform(img) + + image_tensor = image_tensor.unsqueeze(0) + return image_tensor.to(get_device()) diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index 65ed5474..d13b76bc 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -242,6 +242,7 @@ def glob_expand_paths(paths): expanded_paths.append(p) else: expanded_paths.extend(glob.glob(os.path.expanduser(p))) + return expanded_paths diff --git a/imaginairy/utils/format_file_name.py b/imaginairy/utils/format_file_name.py new file mode 100644 index 00000000..33d90d6b --- /dev/null +++ b/imaginairy/utils/format_file_name.py @@ -0,0 +1,19 @@ +import os +from urllib.parse import urlparse + + +def format_filename(format_template: str, data: dict) -> str: + """ + Formats the filename based on the provided template and variables. + """ + if not isinstance(format_template, str): + raise TypeError("format argument must be a string") + + filename = format_template.format(**data) + return filename + + +def get_url_file_name(url): + parsed = urlparse(url) + model_name, _ = os.path.splitext(os.path.basename(parsed.path)) + return model_name diff --git a/imaginairy/utils/log_utils.py b/imaginairy/utils/log_utils.py index 90d29122..de1d2218 100644 --- a/imaginairy/utils/log_utils.py +++ b/imaginairy/utils/log_utils.py @@ -488,6 +488,16 @@ def disable_common_warnings(): "ignore", category=UserWarning, message=r"Arguments other than a weight.*" ) warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r".*?torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument..*?", + ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r".*?is not currently supported on the MPS backend and will fall back.*?", + ) def suppress_annoying_logs_and_warnings(): diff --git a/imaginairy/utils/tile_up.py b/imaginairy/utils/tile_up.py new file mode 100644 index 00000000..44123371 --- /dev/null +++ b/imaginairy/utils/tile_up.py @@ -0,0 +1,91 @@ +import logging +import math +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch + from torch import Tensor + +logger = logging.getLogger(__name__) + + +def tile_process( + img: "Tensor", + scale: int, + model: "torch.nn.Module", + tile_size: int = 512, + tile_pad: int = 50, +) -> "Tensor": + """ + Process an image by tiling it, processing each tile, and then merging them back into one image. + + Args: + img (Tensor): The input image tensor. + scale (int): The scale factor for the image. + tile_size (int): The size of each tile. + tile_pad (int): The padding for each tile. + model (torch.nn.Module): The model used for processing the tile. + + Returns: + Tensor: The processed output image. + """ + import torch + + batch, channel, height, width = img.shape + output_height = height * scale + output_width = width * scale + output_shape = (batch, channel, output_height, output_width) + + # Initialize the output tensor + output = img.new_zeros(output_shape) + tiles_x = math.ceil(width / tile_size) + tiles_y = math.ceil(height / tile_size) + logger.debug(f"Tiling with {tiles_x}x{tiles_y} ({tiles_x*tiles_y}) tiles") + + for y in range(tiles_y): + for x in range(tiles_x): + # Calculate the input tile coordinates with and without padding + ofs_x, ofs_y = x * tile_size, y * tile_size + input_start_x, input_end_x = ofs_x, min(ofs_x + tile_size, width) + input_start_y, input_end_y = ofs_y, min(ofs_y + tile_size, height) + padded_start_x, padded_end_x = ( + max(input_start_x - tile_pad, 0), + min(input_end_x + tile_pad, width), + ) + padded_start_y, padded_end_y = ( + max(input_start_y - tile_pad, 0), + min(input_end_y + tile_pad, height), + ) + + # Extract the input tile + input_tile = img[ + :, :, padded_start_y:padded_end_y, padded_start_x:padded_end_x + ] + + # Process the tile + with torch.no_grad(): + output_tile = model(input_tile) + + # Calculate the output tile coordinates + output_start_x, output_end_x = input_start_x * scale, input_end_x * scale + output_start_y, output_end_y = input_start_y * scale, input_end_y * scale + tile_output_start_x = (input_start_x - padded_start_x) * scale + tile_output_end_x = ( + tile_output_start_x + (input_end_x - input_start_x) * scale + ) + tile_output_start_y = (input_start_y - padded_start_y) * scale + tile_output_end_y = ( + tile_output_start_y + (input_end_y - input_start_y) * scale + ) + + # Place the processed tile in the output image + output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = ( + output_tile[ + :, + :, + tile_output_start_y:tile_output_end_y, + tile_output_start_x:tile_output_end_x, + ] + ) + + return output diff --git a/imaginairy/vendored/clip/clip.py b/imaginairy/vendored/clip/clip.py index 197464cf..69f41c6e 100644 --- a/imaginairy/vendored/clip/clip.py +++ b/imaginairy/vendored/clip/clip.py @@ -8,7 +8,6 @@ import torch from PIL import Image -from pkg_resources import packaging from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor from tqdm import tqdm @@ -23,9 +22,6 @@ BICUBIC = Image.BICUBIC -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): - warnings.warn("PyTorch version 1.7.1 or higher is recommended") - __all__ = ["available_models", "load", "tokenize"] _tokenizer = _Tokenizer() @@ -272,10 +268,7 @@ def tokenize( sot_token = _tokenizer.encoder["<|startoftext|>"] eot_token = _tokenizer.encoder["<|endoftext|>"] all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - else: - result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) + result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: diff --git a/mkdocs.yml b/mkdocs.yml index 383651fd..13e5c9a3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,8 +49,10 @@ nav: - imagine_image_files(): docs/Python/imagine-image-files.md - generate_video(): docs/Python/generate-video.md - colorize_img(): docs/Python/colorize-img.md + - upscale(): docs/Python/upscale.md - ImaginePrompt: docs/Python/ImaginePrompt.md - ControlInput: docs/Python/ControlInput.md - LazyLoadingImage: docs/Python/LazyLoadingImage.md - WeightedPrompt: docs/Python/WeightedPrompt.md + - Changelog: changelog.md diff --git a/requirements-dev.txt b/requirements-dev.txt index d8275291..11f68013 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ anyio==4.3.0 # starlette babel==2.14.0 # via mkdocs-material -build==1.1.1 +build==1.2.1 # via pip-tools certifi==2024.2.2 # via @@ -43,55 +43,57 @@ colorama==0.4.6 # mkdocs-material coverage==7.4.4 # via -r requirements-dev.in -diffusers==0.27.0 +diffusers==0.27.2 # via imaginAIry (setup.py) einops==0.7.0 - # via imaginAIry (setup.py) + # via + # imaginAIry (setup.py) + # spandrel exceptiongroup==1.2.0 # via # anyio # pytest -fastapi==0.110.0 +fastapi==0.110.1 # via imaginAIry (setup.py) -filelock==3.13.1 +filelock==3.13.4 # via # diffusers # huggingface-hub # torch # transformers -fsspec==2024.2.0 +fsspec==2024.3.1 # via # huggingface-hub # torch -ftfy==6.1.3 +ftfy==6.2.0 # via # imaginAIry (setup.py) # open-clip-torch ghp-import==2.1.0 # via mkdocs -griffe==0.42.0 +griffe==0.42.2 # via mkdocstrings-python h11==0.14.0 # via # httpcore # uvicorn -httpcore==1.0.4 +httpcore==1.0.5 # via httpx httpx==0.27.0 # via -r requirements-dev.in -huggingface-hub==0.21.4 +huggingface-hub==0.22.2 # via # diffusers # open-clip-torch # timm # tokenizers # transformers -idna==3.6 +idna==3.7 # via # anyio # httpx # requests -importlib-metadata==7.0.2 +importlib-metadata==7.1.0 # via diffusers iniconfig==2.0.0 # via pytest @@ -105,16 +107,15 @@ jinja2==3.1.3 # torch kornia==0.7.2 # via imaginAIry (setup.py) -kornia-rs==0.1.1 +kornia-rs==0.1.3 # via kornia -markdown==3.5.2 +markdown==3.6 # via # mkdocs # mkdocs-autorefs # mkdocs-click # mkdocs-material # mkdocstrings - # mkdocstrings-python # pymdown-extensions markupsafe==2.1.5 # via @@ -133,15 +134,15 @@ mkdocs-autorefs==1.0.1 # via mkdocstrings mkdocs-click==0.8.1 # via -r requirements-dev.in -mkdocs-material==9.5.13 +mkdocs-material==9.5.18 # via -r requirements-dev.in mkdocs-material-extensions==1.3.1 # via mkdocs-material -mkdocstrings[python]==0.24.1 +mkdocstrings[python]==0.24.3 # via # -r requirements-dev.in # mkdocstrings-python -mkdocstrings-python==1.9.0 +mkdocstrings-python==1.9.2 # via mkdocstrings mpmath==1.3.0 # via sympy @@ -149,7 +150,7 @@ mypy==1.9.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via mypy -networkx==3.2.1 +networkx==3.3 # via torch numpy==1.24.4 # via @@ -159,6 +160,7 @@ numpy==1.24.4 # jaxtyping # opencv-python # scipy + # spandrel # torchvision # transformers omegaconf==2.3.0 @@ -180,7 +182,7 @@ paginate==0.5.6 # via mkdocs-material pathspec==0.12.1 # via mkdocs -pillow==10.2.0 +pillow==10.3.0 # via # diffusers # imaginAIry (setup.py) @@ -193,17 +195,17 @@ platformdirs==4.2.0 # mkdocstrings pluggy==1.4.0 # via pytest -protobuf==5.26.0 +protobuf==5.26.1 # via # imaginAIry (setup.py) # open-clip-torch psutil==5.9.8 # via imaginAIry (setup.py) -pydantic==2.6.4 +pydantic==2.7.0 # via # fastapi # imaginAIry (setup.py) -pydantic-core==2.16.3 +pydantic-core==2.18.1 # via pydantic pygments==2.17.2 # via mkdocs-material @@ -223,7 +225,7 @@ pytest==8.1.1 # pytest-asyncio # pytest-randomly # pytest-sugar -pytest-asyncio==0.23.5.post1 +pytest-asyncio==0.23.6 # via -r requirements-dev.in pytest-randomly==3.15.0 # via -r requirements-dev.in @@ -243,7 +245,7 @@ pyyaml==6.0.1 # transformers pyyaml-env-tag==0.1 # via mkdocs -regex==2023.12.25 +regex==2024.4.16 # via # diffusers # mkdocs-material @@ -259,15 +261,16 @@ requests==2.31.0 # transformers responses==0.25.0 # via -r requirements-dev.in -ruff==0.3.3 +ruff==0.3.7 # via -r requirements-dev.in -safetensors==0.4.2 +safetensors==0.4.3 # via # diffusers # imaginAIry (setup.py) + # spandrel # timm # transformers -scipy==1.12.0 +scipy==1.13.0 # via # imaginAIry (setup.py) # torchdiffeq @@ -279,7 +282,9 @@ sniffio==1.3.1 # via # anyio # httpx -starlette==0.36.3 +spandrel==0.3.1 + # via imaginAIry (setup.py) +starlette==0.37.2 # via fastapi sympy==1.12 # via torch @@ -300,20 +305,22 @@ tomli==2.0.1 # pip-tools # pyproject-hooks # pytest -torch==2.2.1 +torch==2.2.2 # via # imaginAIry (setup.py) # kornia # open-clip-torch + # spandrel # timm # torchdiffeq # torchvision torchdiffeq==0.2.3 # via imaginAIry (setup.py) -torchvision==0.17.1 +torchvision==0.17.2 # via # imaginAIry (setup.py) # open-clip-torch + # spandrel # timm tqdm==4.66.2 # via @@ -321,19 +328,19 @@ tqdm==4.66.2 # imaginAIry (setup.py) # open-clip-torch # transformers -transformers==4.38.2 +transformers==4.39.3 # via imaginAIry (setup.py) typeguard==2.13.3 # via jaxtyping -types-pillow==10.2.0.20240311 +types-pillow==10.2.0.20240415 # via -r requirements-dev.in -types-psutil==5.9.5.20240311 +types-psutil==5.9.5.20240316 # via -r requirements-dev.in -types-requests==2.31.0.20240311 +types-requests==2.31.0.20240406 # via -r requirements-dev.in -types-tqdm==4.66.0.20240106 +types-tqdm==4.66.0.20240417 # via -r requirements-dev.in -typing-extensions==4.10.0 +typing-extensions==4.11.0 # via # anyio # fastapi @@ -341,6 +348,7 @@ typing-extensions==4.10.0 # mypy # pydantic # pydantic-core + # spandrel # torch # uvicorn urllib3==2.2.1 @@ -348,7 +356,7 @@ urllib3==2.2.1 # requests # responses # types-requests -uvicorn==0.28.0 +uvicorn==0.29.0 # via imaginAIry (setup.py) watchdog==4.0.0 # via mkdocs diff --git a/setup.py b/setup.py index 0c79e4a9..90b0f2c2 100644 --- a/setup.py +++ b/setup.py @@ -108,6 +108,7 @@ def get_git_revision_hash() -> str: "triton>=2.0.0; sys_platform!='darwin' and platform_machine!='aarch64' and sys_platform == 'linux'", "kornia>=0.6", "uvicorn>=0.16.0", + "spandrel>=0.1.8", # "xformers>=0.0.22; sys_platform!='darwin' and platform_machine!='aarch64'", ], # don't specify maximum python versions as it can cause very long dependency resolution issues as the resolver diff --git a/tests/test_enhancers/test_upscale.py b/tests/test_enhancers/test_upscale.py new file mode 100644 index 00000000..350ba609 --- /dev/null +++ b/tests/test_enhancers/test_upscale.py @@ -0,0 +1,43 @@ +from unittest.mock import Mock, patch + +import pytest +from click.testing import CliRunner +from PIL import Image + +from imaginairy.cli.upscale import ( + upscale_cmd, +) +from tests import TESTS_FOLDER + + +@pytest.fixture() +def mock_pil_save(): + with patch.object(Image, "save", autospec=True) as mock_save: + yield mock_save + + +def test_upscale_cmd_format_option(): + runner = CliRunner() + + mock_img = Mock() + mock_img.save = Mock() + mock_img.height = 1000 + + with patch.multiple( + "imaginairy.enhancers.upscale", upscale_image=Mock(return_value=mock_img) + ), patch( + "imaginairy.utils.glob_expand_paths", + new=Mock(return_value=[f"{TESTS_FOLDER}/data/sand_upscale_difficult.jpg"]), + ): + result = runner.invoke( + upscale_cmd, + [ + "tests/data/sand_upscale_difficult.jpg", + "--format", + "{original_filename}_upscaled_{file_sequence_number}_{algorithm}_{now}", + ], + ) + + assert result.exit_code == 0 + assert "saved to " in result.output + mock_img.save.assert_called() # Check if save method was called diff --git a/tests/test_utils/test_format_file_name.py b/tests/test_utils/test_format_file_name.py new file mode 100644 index 00000000..f21029f6 --- /dev/null +++ b/tests/test_utils/test_format_file_name.py @@ -0,0 +1,64 @@ +from datetime import datetime, timezone + +import pytest + + +def format_filename(format_template: str, data: dict) -> str: + """ + Formats the filename based on the provided template and variables. + """ + if not isinstance(format_template, str): + raise TypeError("format argument must be a string") + + filename = format_template.format(**data) + filename += data["ext"] + return filename + + +base_data = { + "original": "file", + "number": 1, + "algorithm": "alg", + "now": datetime(2023, 1, 23, 12, 30, 45, tzinfo=timezone.utc), + "ext": ".jpg", +} + + +@pytest.mark.parametrize( + ("format_str", "data", "expected"), + [ + ("{original}_{algorithm}", base_data, "file_alg.jpg"), + ( + "{original}_{number}_{now}", + base_data, + "file_1_2023-01-23 12:30:45+00:00.jpg", + ), + ("", base_data, ".jpg"), + ("{original}", {}, KeyError), + ("{nonexistent_key}", base_data, KeyError), + (123, base_data, TypeError), + ("{original}_@#$_{algorithm}", base_data, "file_@#$_alg.jpg"), + ("{original}" * 100, base_data, "file" * 100 + ".jpg"), + ( + "{original}_{number}", + {"original": "file", "number": 123, "ext": ".jpg"}, + "file_123.jpg", + ), + ( + "{now}", + {"now": "2023/01/23", "ext": ".log"}, + "2023/01/23.log", + ), + ("{original}", {"original": "file", "ext": ""}, "file"), + ], +) +def test_format_filename(format_str, data, expected): + if isinstance(expected, type) and issubclass(expected, Exception): + try: + format_filename(format_str, data) + except expected: + assert True, f"Expected {expected} to be raised" + except Exception: + raise + else: + assert format_filename(format_str, data) == expected